// SPDX-FileCopyrightText: 2022 UnionTech Software Technology Co., Ltd.
//
// SPDX-License-Identifier: GPL-3.0-or-later

package main

import (
	"bytes"
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/Netflix/go-expect"
)

// getFileBlock runs `debugfs dev -R 'stat file'` and returns the text that
// follows "(0):" in its output, i.e. the value debugfs prints for the file's
// logical block 0.
//
// debugfs is executed directly instead of through `sh -c`, so dev and file are
// never interpreted by a shell, and a failure to start debugfs is returned as
// such instead of being hidden behind the exit status of a pipeline.
//
// The filtering reproduces `grep '(0):' | cut -d: -f2`: every output line that
// contains "(0):" contributes its second ':'-separated field. An error is
// returned when debugfs fails or when no such field is found.
func getFileBlock(dev, file string) (string, error) {
	cmdStr := fmt.Sprintf("debugfs %s -R 'stat %s'", dev, file)
	log.Println("$", cmdStr)
	out, err := exec.Command("debugfs", dev, "-R", "stat "+file).Output()
	if err != nil {
		return "", err
	}

	var fields [][]byte
	for _, line := range bytes.Split(out, []byte("\n")) {
		if !bytes.Contains(line, []byte("(0):")) {
			continue
		}
		// The line contains at least one ':', so parts[1] always exists.
		parts := bytes.SplitN(line, []byte(":"), 3)
		fields = append(fields, parts[1])
	}

	blockID := string(bytes.TrimSpace(bytes.Join(fields, []byte("\n"))))
	if blockID == "" {
		return "", errors.New("empty block id")
	}
	return blockID, nil
}

// setFsDirty runs `debugfs dev -w -R dirty` and returns whatever error the
// command execution reports.
func setFsDirty(dev string) error {
	cmd := exec.Command("debugfs", dev, "-w", "-R", "dirty")
	return cmd.Run()
}

// modifyFileInode drives an interactive `debugfs dev -w -R 'mi file'` session
// through a pseudo terminal and replaces the direct block value thisFBlock
// with newFBlock.
//
// For every chunk of output ending in a bracketed word (presumably the
// "[current value]" part of a debugfs field prompt) it answers with an empty
// line to keep the value, except when the chunk mentions "Direct Block" and
// "[thisFBlock]", in which case newFBlock is sent. The session ends after the
// prompt containing "Triple Indirect Block" has been answered.
//
// It returns nil immediately when both block ids are already equal, and
// otherwise any error from creating the console, starting debugfs, talking to
// it, or waiting for it to exit.
func modifyFileInode(dev, file, thisFBlock, newFBlock string) error {
	if thisFBlock == newFBlock {
		log.Println("already done")
		return nil
	}

	console, err := expect.NewConsole(expect.WithStdout(os.Stdout))
	if err != nil {
		return err
	}
	// Release the console's resources on every return path.
	defer console.Close()

	// debugfs $dev -w -R 'mi $file'
	log.Printf("$ debugfs %s -w -R 'mi %s'\n", dev, file)
	cmd := exec.Command("debugfs", dev, "-w", "-R", "mi "+file)
	cmd.Stdout = console.Tty()
	cmd.Stdin = console.Tty()
	cmd.Stderr = console.Tty()

	err = cmd.Start()
	if err != nil {
		return err
	}

	// abort stops the child and reaps it so that a failed conversation does
	// not leave a debugfs process behind; the original cause is returned.
	abort := func(cause error) error {
		_ = cmd.Process.Kill() // best effort, the process may already be gone
		_ = cmd.Wait()
		return cause
	}

	// Compiled once instead of on every loop iteration.
	promptRe := regexp.MustCompile(`\[\w+]`)
	currentValue := fmt.Sprintf("[%s]", thisFBlock)

	for {
		buf, err := console.Expect(expect.Regexp(promptRe))
		if err != nil {
			return abort(err)
		}
		log.Println("buf is", buf)

		// An empty line keeps the value shown in the prompt.
		reply := "\n"
		if strings.Contains(buf, "Direct Block") &&
			strings.Contains(buf, currentValue) {
			reply = newFBlock + "\n"
		}

		_, err = console.Send(reply)
		if err != nil {
			return abort(err)
		}

		// This is the last prompt of the session.
		if strings.Contains(buf, "Triple Indirect Block") {
			break
		}
	}

	return cmd.Wait()
}

// mount runs `mount -t fsType dev mountPoint` and returns the error reported
// by executing that command, if any.
func mount(dev, mountPoint, fsType string) error {
	args := []string{"-t", fsType, dev, mountPoint}
	cmd := exec.Command("mount", args...)
	return cmd.Run()
}

// umount runs `umount dev` and returns the error reported by executing that
// command, if any.
func umount(dev string) error {
	cmd := exec.Command("umount", dev)
	return cmd.Run()
}

// getFsTypeWithFstype determines the filesystem type of dev with the external
// `fstype` command. The output is expected to start with "FSTYPE=<type>"; the
// parsed type is returned. An error is returned when the command fails, when
// the output cannot be parsed, or when the reported type is "unknown"
// (compared case-insensitively).
func getFsTypeWithFstype(dev string) (string, error) {
	raw, err := exec.Command("fstype", dev).Output()
	if err != nil {
		return "", err
	}

	var fsType string
	if _, err := fmt.Sscanf(string(raw), "FSTYPE=%s\n", &fsType); err != nil {
		return "", err
	}

	switch {
	case strings.EqualFold(fsType, "unknown"):
		return "", errors.New("fs type is unknown")
	default:
		return fsType, nil
	}
}

// getFsTypeWithBlkid determines the filesystem type of dev with blkid and
// returns the command's standard output with surrounding whitespace removed.
// An error from running the command is returned unchanged.
func getFsTypeWithBlkid(dev string) (string, error) {
	// Example: blkid -o value -s TYPE /dev/nvme0n1p2
	args := []string{"-o", "value", "-s", "TYPE", dev}
	stdout, err := exec.Command("blkid", args...).Output()
	if err != nil {
		return "", err
	}
	return string(bytes.TrimSpace(stdout)), nil
}

// getFsType returns the filesystem type of dev. It tries
// getFsTypeWithFstype first and, if that reports an error, falls back to
// getFsTypeWithBlkid, whose result and error are returned as they are.
func getFsType(dev string) (string, error) {
	if detected, err := getFsTypeWithFstype(dev); err == nil {
		return detected, nil
	}
	// fstype was not usable for this device; let blkid decide.
	return getFsTypeWithBlkid(dev)
}

// _opts holds the values of the command line options.
var _opts struct {
	// method is the number of the break method selected with -m.
	method int
}

// init registers the -m option on the default flag set; the option defaults
// to method 1.
func init() {
	const defaultMethod = 1
	flag.IntVar(&_opts.method, "m", defaultMethod, "method")
}

// method1 is specific to the ext4 family of filesystems: it creates the
// abnormal situation in which two files refer to the same block id.
//
// Steps: mount dev on /mnt, make sure /media/root/file1.txt and
// /media/root/file2.txt exist on it, unmount, look up the first block of both
// files, mark the filesystem dirty, and finally rewrite the inode of file2 so
// that its block points at the block of file1.
//
// Any failure terminates the program through log.Fatal/log.Fatalf. Once the
// device has been mounted, an unmount is always attempted before exiting
// because of a file preparation error, so the device is not left mounted.
func method1(dev string) {
	const (
		mountPoint = "/mnt"
		file1      = "/media/root/file1.txt"
		file2      = "/media/root/file2.txt"
	)

	if err := os.MkdirAll(mountPoint, 0755); err != nil {
		log.Fatalf("mkdir /mnt failed: %v", err)
	}

	fsType, err := getFsType(dev)
	if err != nil {
		log.Fatalf("get device fsType failed: %v", err)
	}
	log.Println("DBG: fsType:", fsType)

	if err := mount(dev, mountPoint, fsType); err != nil {
		log.Fatalf("mount device %s failed: %v", dev, err)
	}

	// Prepare both files while mounted, but postpone the fatal exit until
	// after the unmount attempt.
	prepErr := writeFileIfMissing(filepath.Join(mountPoint, file1), []byte("this is file 1"))
	if prepErr == nil {
		prepErr = writeFileIfMissing(filepath.Join(mountPoint, file2), []byte("this is file 2"))
	}

	umountErr := umount(dev)
	if prepErr != nil {
		log.Fatal(prepErr)
	}
	if umountErr != nil {
		log.Fatalf("umount failed: %v", umountErr)
	}

	f1Block, err := getFileBlock(dev, file1)
	if err != nil {
		log.Fatalf("get file1 block failed: %v", err)
	}
	fmt.Println("f1block:", f1Block)

	f2Block, err := getFileBlock(dev, file2)
	if err != nil {
		log.Fatalf("get file2 block failed: %v", err)
	}
	fmt.Println("f2block:", f2Block)

	if err := setFsDirty(dev); err != nil {
		log.Fatalf("set dirty failed: %v", err)
	}

	if err := modifyFileInode(dev, file2, f2Block, f1Block); err != nil {
		log.Fatalf("mi file failed: %v", err)
	}
}

// writeFileIfMissing creates the parent directory of path and writes content
// to path unless the file already exists. Unlike a plain os.IsNotExist check,
// any other os.Stat failure is reported instead of being ignored.
func writeFileIfMissing(path string, content []byte) error {
	dir := filepath.Dir(path)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("mkdir %s failed: %v", dir, err)
	}

	_, err := os.Stat(path)
	if err == nil {
		// Keep an existing file untouched.
		return nil
	}
	if !os.IsNotExist(err) {
		return fmt.Errorf("stat %s failed: %v", path, err)
	}

	if err := ioutil.WriteFile(path, content, 0644); err != nil {
		return fmt.Errorf("write %s failed: %v", path, err)
	}
	return nil
}

// exitCodeNeedReboot is the process exit code (2) that signals that a reboot
// is required.
const exitCodeNeedReboot = 2

// method2 damages the super block of dev by overwriting the first three
// 512-byte sectors of the device with zeros using dd. The process then exits
// with exitCodeNeedReboot; if dd fails, it terminates through log.Fatal
// instead.
func method2(dev string) {
	// Equivalent shell command:
	// dd if=/dev/zero of=/dev/sda1 bs=512 count=3
	ddArgs := []string{"if=/dev/zero", "of=" + dev, "bs=512", "count=3"}
	if err := exec.Command("dd", ddArgs...).Run(); err != nil {
		log.Fatal("run dd cmd failed:", err)
	}

	os.Exit(exitCodeNeedReboot)
}

// printHelp writes the usage text of the tool to standard error.
func printHelp() {
	const usage = `File System Breaker

Usage:
	fs-break -m <method> <rootPartitionDevice>

Options:
	-h --help    Show help message.
	-m <method>  Set the break method, method can be 1~2.
`
	fmt.Fprint(os.Stderr, usage)
}

// main parses the command line, takes the first positional argument as the
// target device and runs the break method selected with -m. Without a
// positional argument it prints the help text and exits with status 1. An
// unknown method number only produces a warning in the log.
func main() {
	log.SetFlags(log.Lshortfile)
	flag.Usage = printHelp
	flag.Parse()

	if flag.NArg() == 0 {
		printHelp()
		os.Exit(1)
	}
	dev := flag.Arg(0)

	// Dispatch table from method number to implementation.
	methods := map[int]func(string){
		1: method1,
		2: method2,
	}

	run, known := methods[_opts.method]
	if !known {
		log.Printf("WARN: invalid method %d\n", _opts.method)
		return
	}
	run(dev)
}
