/*
 * NOT USED ANY MORE
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <inttypes.h>

#include <xen/callback.h>

#include "xenner.h"
#include "mm.h"
#include "cpufeature.h"

/* ------------------------------------------------------------------ */

/*
 * emulate_rdtsc - emulate the rdtsc instruction for the guest.
 * @xen: the VM whose register file receives the result
 *
 * The counter value is get_systime() taken relative to xen->boot and is
 * split into the EDX:EAX pair, the way rdtsc reports it.
 */
static void emulate_rdtsc(struct xenvm *xen)
{
    uint64_t tsc = get_systime() - xen->boot;

    xen->regs.rdx = tsc >> 32;          /* high 32 bits */
    xen->regs.rax = (uint32_t)tsc;      /* low 32 bits, upper half cleared */
}

/*
 * decode_reg - map a ModR/M register field to a pointer into @regs.
 * @regs:  guest general purpose register file
 * @modrm: the instruction's ModR/M byte
 * @rm:    non-zero to decode the r/m field (bits 0-2),
 *         zero to decode the reg field (bits 3-5)
 *
 * Returns a pointer to the selected 64-bit register.  Only the eight
 * legacy registers are decoded; REX extension bits are not considered.
 */
static uint64_t *decode_reg(struct kvm_regs *regs, uint8_t modrm, int rm)
{
    int shift = rm ? 0 : 3;

    /* each case must return on its own; a fall-through would always
     * select the last register (rdi) */
    switch ((modrm >> shift) & 0x07) {
    case 0: return (uint64_t*)&regs->rax;
    case 1: return (uint64_t*)&regs->rcx;
    case 2: return (uint64_t*)&regs->rdx;
    case 3: return (uint64_t*)&regs->rbx;
    case 4: return (uint64_t*)&regs->rsp;
    case 5: return (uint64_t*)&regs->rbp;
    case 6: return (uint64_t*)&regs->rsi;
    case 7: return (uint64_t*)&regs->rdi;
    }
    return NULL; /* not reached: the 3-bit field is fully covered above */
}

/*
 * emulate_x86 - decode and emulate a single guest instruction.
 * @xen: the VM; xen->regs / xen->sregs are read and may be updated
 * @rip: guest virtual address of the first instruction byte
 *
 * An optional 0x66 operand-size prefix may precede one of:
 *   0f 09          wbinvd            ignored
 *   0f 0b          ud2a              returns 0 so the guest takes the trap
 *   0f 20 /r       mov %cr0/3/4,reg  value taken from xen->sregs
 *   0f 22 /r       mov reg,%cr0/4    write ignored (logged when it differs)
 *   0f 30/31/32    wrmsr / rdtsc / rdmsr
 *   0f a2          cpuid
 *   e4/e5 ib, ec/ed   in   port not accessed, destination set to all-ones
 *   e6/e7 ib, ee/ef   out  port not accessed, instruction skipped
 *
 * Returns the number of bytes consumed (prefix included), 0 for ud2a,
 * or -1 when the instruction is not recognised.
 */
static int emulate_x86(struct xenvm *xen, uint64_t rip)
{
    uint8_t *instr;
    int prefix = 0;     /* prefix bytes consumed */
    int skip;           /* bytes handled so far, prefix included */
    int opsize16 = 0;   /* 0x66 operand-size override present */
    int in_bits = 0;    /* operand width of an emulated "in", 0 if none */

    instr = guest_vaddr_to_ptr(xen, rip);
    d2printf("%s: eip %" PRIx64 "  instr %02x %02x %02x %02x  %02x %02x %02x %02x\n",
	     __FUNCTION__, rip,
	     instr[0], instr[1], instr[2], instr[3],
	     instr[4], instr[5], instr[6], instr[7]);

    /* prefixes */
    if (instr[0] == 0x66) {
	opsize16 = 1;
	prefix = 1;
    }
    skip = prefix;

    /* instructions */
    switch (instr[skip]) {
    case 0x0f:
	switch (instr[skip+1]) {
	case 0x09:
	    /* wbinvd -- ignore for now, vmexit should come close enough ;) */
	    skip += 2;
	    break;
	case 0x0b:
	    /* ud2a */
	    banner_print(xen, "ud2a -> kernel BUG()");
	    return 0; /* bounce to kernel so we get a fancy trace */
	case 0x20:
	{
	    /* mov %crN,%reg: ModR/M reg field selects N, r/m the GPR */
	    uint64_t *reg = decode_reg(&xen->regs, instr[skip+2], 1);
	    switch ((instr[skip+2] >> 3) & 0x07) {
	    case 0:
		d1printf("%s: read cr0\n", __FUNCTION__);
		*reg = xen->sregs.cr0;
		skip += 3;
		break;
	    case 3:
		d1printf("%s: read cr3\n", __FUNCTION__);
		*reg = xen->sregs.cr3;
		skip += 3;
		break;
	    case 4:
		d1printf("%s: read cr4\n", __FUNCTION__);
		*reg = xen->sregs.cr4;
		skip += 3;
		break;
	    default:
		/* other control registers: reported as unknown below */
		break;
	    }
	    break;
	}
	case 0x22:
	{
	    /* mov %reg,%crN: the write is dropped, only logged */
	    uint64_t *reg = decode_reg(&xen->regs, instr[skip+2], 1);
	    switch ((instr[skip+2] >> 3) & 0x07) {
	    case 0:
		if (xen->sregs.cr0 != *reg)
		    d0printf("%s: ignore cr0 write (0x%llx -> 0x%" PRIx64 ")\n",
			     __FUNCTION__, xen->sregs.cr0, *reg);
		skip += 3;
		break;
	    case 4:
		if (xen->sregs.cr4 != *reg)
		    d0printf("%s: ignore cr4 write (0x%llx -> 0x%" PRIx64 ")\n",
			     __FUNCTION__, xen->sregs.cr4, *reg);
		skip += 3;
		break;
	    default:
		/* other control registers: reported as unknown below */
		break;
	    }
	    break;
	}
	case 0x30:
	    /* wrmsr */
	    emulate_wrmsr(xen);
	    skip += 2;
	    break;
	case 0x31:
	    /* rdtsc */
	    emulate_rdtsc(xen);
	    skip += 2;
	    break;
	case 0x32:
	    /* rdmsr */
	    emulate_rdmsr(xen);
	    skip += 2;
	    break;
	case 0xa2:
	    /* cpuid */
	    emulate_cpuid(xen);
	    skip += 2;
	    break;
	default:
	    break;
	}
	break;
    case 0xe4: /* in     <imm8>,%al  /  <imm8>,%eax (%ax with 0x66) */
    case 0xe5:
	/* odd opcodes are the wide form; width is decided by the prefix */
	in_bits = (instr[skip] & 1) ? (opsize16 ? 16 : 32) : 8;
	skip += 2;
	break;
    case 0xec: /* in     (%dx),%al  /  (%dx),%eax (%ax with 0x66) */
    case 0xed:
	in_bits = (instr[skip] & 1) ? (opsize16 ? 16 : 32) : 8;
	skip += 1;
	break;
    case 0xe6: /* out    %al,<imm8>: two bytes, no fall-through */
    case 0xe7:
	skip += 2;
	break;
    case 0xee: /* out    %al,(%dx) */
    case 0xef:
	skip += 1;
	break;
    default:
	break;
    }

    /* unknown instruction (a lone prefix byte does not count as handled) */
    if (skip == prefix) {
	d0printf("%s: rip %" PRIx64
		 "   instr %02x %02x %02x %02x  %02x %02x %02x %02x"
		 "   [ FAILED ]\n", __FUNCTION__, rip,
		 instr[0], instr[1], instr[2], instr[3],
		 instr[4], instr[5], instr[6], instr[7]);
	return -1;
    }

    /* I/O instruction: the port is not accessed, read back all-ones
     * in exactly the operand width so neighbouring bits are untouched */
    switch (in_bits) {
    case 8:
	xen->regs.rax |= 0xff;
	break;
    case 16:
	xen->regs.rax |= 0xffff;
	break;
    case 32:
	xen->regs.rax |= 0xffffffff;
	break;
    default:
	break;
    }

    d2printf("%s: %d bytes handled\n", __FUNCTION__, skip);
    return skip;
}

/* ------------------------------------------------------------------ */

/*
 * bounce_trap_32 - redirect a trap or callback into the 32-bit guest kernel.
 * @xen:    the VM
 * @esp:    emu-side frame: [0] eip, [1] cs, [2] eflags, [3] esp, [4] ss.
 *          For trap 14 it points one word lower, at the error code.
 * @trapno: guest trap table index to bounce to, or negative for none
 * @cbno:   guest callback index to bounce to, or negative for none;
 *          when both are given the callback target wins
 *
 * Copies the frame onto the guest kernel stack, switching to the
 * tss ss1:esp1 stack when entering from a less privileged ring.  It then
 * rewrites the emu frame so that returning from it enters the guest
 * handler directly.  Calls vm_kill() when no handler is registered or the
 * ring transition is not allowed.
 *
 * Returns 0.
 */
static int bounce_trap_32(struct xenvm *xen, uint32_t *esp, int trapno, int cbno)
{
    /* start at 0 so "no trapno and no cbno" hits the !cs check below
     * instead of reading indeterminate values */
    uint32_t *kesp, eip = 0, cs = 0;
    int stack_switch = 0;
    int error_code = 0;
    int k = 0;

    if (trapno >= 0) {
	/* trap bounce */
	eip  = xen->xentr.tr32[trapno].address;
	cs   = xen->xentr.tr32[trapno].cs;
    }
    if (cbno >= 0) {
	/* callback */
	eip  = xen->xencb.cb32[cbno].eip;
	cs   = xen->xencb.cb32[cbno].cs;
    }

    if (14 == trapno) {
	/* page fault: hand cr2 to the guest via vcpu 0's shared info */
	xen->shinfo.sh32->vcpu_info[0].arch.cr2 = xen->sregs.cr2;
	error_code = 1;
    }

    /* step over the error code word so esp[0..4] is the iret frame */
    if (error_code)
	esp++;

    if (!cs)
	vm_kill(xen, "no guest trap handler", esp[3]);

    /* compare the RPL of the interrupted cs with the handler's cs */
    if ((esp[1] & 0x03) < (cs & 0x03))
	vm_kill(xen, "bounce trap: illegal ring switch\n", 0);
    if ((esp[1] & 0x03) > (cs & 0x03))
	stack_switch = 1;

    /* prepare guest stack: copy from emu, so the handler
     * jumps straight back without round-trip via emu */
    if (stack_switch) {
	d2printf("%s: switching stack to %" PRIx32 ":%" PRIx32 "\n",
		 __FUNCTION__, xen->tss_32->ss1, xen->tss_32->esp1);
	kesp = guest_vaddr_to_ptr(xen, xen->tss_32->esp1);
	kesp[-(++k)] = esp[4];   // push ss
	kesp[-(++k)] = esp[3];   // push esp
    } else {
	kesp = guest_vaddr_to_ptr(xen, esp[3]);
    }

    kesp[-(++k)] = esp[2];       // push eflags
    kesp[-(++k)] = esp[1];       // push cs
    kesp[-(++k)] = esp[0];       // push eip
    if (error_code)
	kesp[-(++k)] = esp[-1];  // push error code

    /* prepare emu stack */
    esp[0]  = eip;
    esp[1]  = cs;
    esp[2] &= EFLAGS_TRAPMASK;
    if (stack_switch) {
	esp[4] = xen->tss_32->ss1;
	esp[3] = xen->tss_32->esp1;
    }
    esp[3] -= 4*k;               /* account for the k words pushed above */

    d2printf("%s: trap %d cb %d | code %x:%x stack %x:%x\n", __FUNCTION__,
	     trapno, cbno, esp[1], esp[0], esp[4], esp[3]);
    return 0;
}

/*
 * bounce_trap_64 - redirect a trap into the 64-bit guest kernel.
 * @xen:       the VM
 * @emu_stack: emu-side frame: [0] rip, [1] cs, [2] rflags, [3] rsp, [4] ss.
 *             For trap 14 it points one slot lower, at the error code.
 * @trapno:    guest trap table index to bounce to
 *
 * Pushes ss, rsp, rflags, cs (with RPL cleared), rip, the optional error
 * code, r11 and rcx onto the guest stack.  The guest stack is the
 * interrupted rsp rounded down to 16 bytes.  The emu frame is then
 * rewritten so returning from it enters the guest handler.  Only the
 * kernel-mode path is implemented, and the handler cs is hard-coded
 * (see FIXME).
 *
 * Returns 0.
 */
static int bounce_trap_64(struct xenvm *xen, uint64_t *emu_stack, int trapno)
{
    uint64_t *gst_stack, rip, cs, stack_cs, rsp, ss;
    int error_code = 0;
    int k = 0;

    rip  = xen->xentr.tr64[trapno].address;
    cs   = 0xe033; // FIXME: xen->xentr.tr64[trapno].cs;

    if (14 == trapno) {
	/* page fault: hand cr2 to the guest via vcpu 0's shared info */
	xen->shinfo.sh64->vcpu_info[0].arch.cr2 = xen->sregs.cr2;
	error_code = 1;
    }

    /* step over the error code slot so emu_stack[0..4] is the iret frame */
    if (error_code)
	emu_stack++;

    /* checked after the adjustment so emu_stack[3] is the frame's rsp */
    if (!cs)
	vm_kill(xen, "no guest trap handler", emu_stack[3]);

    stack_cs = emu_stack[1];

    /* prepare guest stack: copy from emu, so the handler
     * jumps straight back without round-trip via emu */
    if (0 /* user mode */) {
	vm_kill(xen, "bounce trap 64: fixme", 0);
    } else {
	stack_cs &= ~3;         /* signal kernel mode */
	rsp = emu_stack[3] & ~0xf; /* align stack */
	ss  = emu_stack[4];
	gst_stack = guest_vaddr_to_ptr(xen, rsp);
    }

    gst_stack[-(++k)] = emu_stack[4];       // push ss
    gst_stack[-(++k)] = emu_stack[3];       // push rsp
    gst_stack[-(++k)] = emu_stack[2];       // push rflags
    gst_stack[-(++k)] = stack_cs;           // push cs
    gst_stack[-(++k)] = emu_stack[0];       // push rip
    if (error_code)
	gst_stack[-(++k)] = emu_stack[-1];  // push error code

    gst_stack[-(++k)] = xen->regs.r11;       // push r11
    gst_stack[-(++k)] = xen->regs.rcx;       // push rcx

    /* prepare emu stack */
    emu_stack[0]  = rip;
    emu_stack[1]  = cs;
    emu_stack[2] &= EFLAGS_TRAPMASK;
    emu_stack[4]  = ss;
    emu_stack[3]  = rsp;
    emu_stack[3] -= 8*k;                     /* k quadwords pushed above */

    d0printf("%s: trap %d | code %" PRIx64 ":%" PRIx64
	     " stack %" PRIx64 ":%" PRIx64 "\n", __FUNCTION__,
	     trapno, emu_stack[1], emu_stack[0], emu_stack[4], emu_stack[3]);
    return 0;
}

/* ------------------------------------------------------------------ */

/*
 * has_emu_marker - test whether guest code at @addr starts with the
 * explicit emulation marker: ud2a (0f 0b) followed by the bytes "xen".
 */
static int has_emu_marker(struct xenvm *xen, uint64_t addr)
{
    static const uint8_t marker[5] = { 0x0f, 0x0b, 'x', 'e', 'n' };

    return memcmp(guest_vaddr_to_ptr(xen, addr), marker, sizeof(marker)) == 0;
}

/*
 * do_instr_emu - emulate the instruction at the interrupted guest ip.
 * @xen: the VM
 *
 * The emu frame at xen->regs.rsp holds the guest ip in slot 0.  A leading
 * emulation marker (5 bytes) is skipped, and the following instruction is
 * handed to emulate_x86().  If anything was consumed, the guest ip is
 * advanced past it.  Otherwise trap 6 (invalid opcode) is bounced to the
 * guest kernel.  An emulation failure calls vm_kill().
 *
 * Returns 0.
 */
int do_instr_emu(struct xenvm *xen)
{
    uint32_t *frame32;
    uint64_t *frame64;
    int len = 0;
    int rc;

    need_regs(xen);
    need_sregs(xen);

    switch (xen->mode) {
    case XENMODE_32:
    case XENMODE_PAE:
	frame32 = emu_vaddr_to_ptr(xen, xen->regs.rsp);
	if (has_emu_marker(xen, frame32[0])) {
	    d2printf("%s: emu prefix\n", __FUNCTION__);
	    len = 5;
	}
	rc = emulate_x86(xen, frame32[0] + len);
	if (rc == -1)
	    vm_kill(xen, "instr emu failure", 0);
	len += rc;
	if (len != 0) {
	    frame32[0] += len;
	    d2printf("%s: continue at 0x%" PRIx32 "\n", __FUNCTION__, frame32[0]);
	} else {
	    bounce_trap_32(xen, frame32, 6, -1);
	}
	break;

    case XENMODE_64:
	frame64 = emu_vaddr_to_ptr(xen, xen->regs.rsp);
	if (has_emu_marker(xen, frame64[0])) {
	    d2printf("%s: emu prefix\n", __FUNCTION__);
	    len = 5;
	}
	rc = emulate_x86(xen, frame64[0] + len);
	if (rc == -1)
	    vm_kill(xen, "instr emu failure", 0);
	len += rc;
	if (len != 0) {
	    frame64[0] += len;
	    d2printf("%s: continue at 0x%" PRIx64 "\n", __FUNCTION__, frame64[0]);
	} else {
	    bounce_trap_64(xen, frame64, 6);
	}
	break;

    default:
	d0printf("%s: unhandled xen mode %d\n", __FUNCTION__, xen->mode);
	break;
    }

    flush_regs(xen);
    return 0;
}

/*
 * do_general_protection - handle a general protection fault from the guest.
 * @xen: the VM
 *
 * The emu frame at xen->regs.rsp holds the fault's error code in slot 0,
 * the faulting ip in slot 1, cs in slot 2 and the stack pointer in slot 4.
 * A fault raised with RPL 0 in cs means the emu itself faulted, which is
 * fatal.  A guest fault with error code 0 is handed to emulate_x86(), and
 * the ip is advanced on success.  Every other case decodes the error code
 * (selector index, TI, IDT, EXT bits), logs it, and kills the VM.
 *
 * Returns 0.
 */
int do_general_protection(struct xenvm *xen)
{
    uint32_t *esp = NULL;
    uint64_t *rsp = NULL;
    int rc;

    need_regs(xen);
    need_sregs(xen);
    switch (xen->mode) {
    case XENMODE_32:
    case XENMODE_PAE:
	esp = emu_vaddr_to_ptr(xen, xen->regs.rsp);
	if (!(esp[2] & 3))
	    vm_kill(xen, "emu (ring0) general protection fault", 0);
	if (0 == esp[0]) {
	    switch (rc = emulate_x86(xen, esp[1])) {
	    case -1:
		vm_kill(xen, "gpf32: emu failure", esp[4]);
		break;
	    case 0:
		vm_kill(xen, "gpf32: FIXME: bounce trap", esp[4]);
		break;
	    default:
		esp[1] += rc;
		break;
	    }
	    break;  /* leaves the mode switch: the fault has been handled */
	}
	d0printf("%s: index 0x%x%s%s%s\n", __FUNCTION__,
		 esp[0] >> 3,
		 (esp[0] & 0x04) ? ", TI"  : "",
		 (esp[0] & 0x02) ? ", IDT" : "",
		 (esp[0] & 0x01) ? ", EXT" : "");
	vm_kill(xen, "gpf32", esp[4]);
	break;
    case XENMODE_64:
	rsp = emu_vaddr_to_ptr(xen, xen->regs.rsp);
	if (!(rsp[2] & 3))
	    vm_kill(xen, "emu (ring0) general protection fault", 0);
	if (0 == rsp[0]) {
	    switch (rc = emulate_x86(xen, rsp[1])) {
	    case -1:
		vm_kill(xen, "gpf64: emu failure", rsp[4]);
		break;
	    case 0:
		vm_kill(xen, "gpf64: FIXME: bounce trap", rsp[4]);
		break;
	    default:
		rsp[1] += rc;
		break;
	    }
	    break;  /* leaves the mode switch: the fault has been handled */
	}
	d0printf("%s: index 0x%" PRIx64 "%s%s%s\n", __FUNCTION__,
		 rsp[0] >> 3,
		 (rsp[0] & 0x04) ? ", TI"  : "",
		 (rsp[0] & 0x02) ? ", IDT" : "",
		 (rsp[0] & 0x01) ? ", EXT" : "");
	vm_kill(xen, "gpf64", rsp[4]);
	break;
    default:
	/* report instead of silently ignoring the fault */
	d0printf("%s: unhandled xen mode %d\n", __FUNCTION__, xen->mode);
	break;
    }
    flush_regs(xen);
    return 0;
}
