/*
   PKCIPE - public key based configuration tool for CIPE

   proto.c - the PKCIPE protocol handler

   Copyright 2000 Olaf Titz <olaf@bigred.inka.de>

   Modified by Lukasz Engel <Lukasz.Engel@softax.pl> (C) 2002

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU General Public License
   as published by the Free Software Foundation; either version
   2 of the License, or (at your option) any later version.
*/
/* $Id: proto.c,v 1.15 2003/05/10 20:07:44 olaf81825 Exp $ */

#include <alloca.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <openssl/bn.h>
#include <openssl/crypto.h>
#include <openssl/dh.h>
#include <openssl/pem.h>
#include <openssl/rand.h>
#include "pkcipe.h"
#include "dhkey.h"

/* Path template for a peer's public key file; the %s is replaced by the
   identity the peer announces in its PKT_IDENT packet (handlePacket).
   Can be overridden at build time. */
#ifndef KEYDIR
#define KEYDIR "/etc/cipe/pk/%s"
#endif

/* The DH object: created by initDH() from the parameter tables below;
   handlePacket() generates our key pair in it and derives the shared
   secret with it. */
static DH *dhp=NULL;
/* DH prime and generator as big-endian byte strings (fed to BN_bin2bn).
   Presumably DH_PRIME_BIN/DH_GEN_BIN come from dhkey.h - confirm. */
static const unsigned char dh_prime_bin[] = DH_PRIME_BIN;
static const unsigned char dh_gen_bin[] = DH_GEN_BIN;

/* Combined handshake nonce: doProtocol() stores our 16 bytes here,
   handlePacket() XORs the peer's nonce into it, and getKeys() uses the
   result as the P_SHA1 seed for protocol versions 2 and 3. */
#define NONCESLEN 16
static char nonces[NONCESLEN]={0};

/* Protocol version in use; getKeys() picks the key derivation by it.
   Starts at the highest version we support (presumably lowered during
   option negotiation elsewhere - not visible here). */
int protoVersion=MAX_SUPPORTED_PROTOCOL;

/* Our identity and key (myKey signs our messages; its RSA modulus
   fingerprint is sent in PKT_IDENT). peerIdentity is a heap copy of the
   identity the peer announced; peerKey is the public key loaded for it
   from KEYDIR, or NULL if none could be loaded/verified. */
char *myIdentity=NULL;
char *peerIdentity=NULL;
EVP_PKEY *myKey=NULL;
static EVP_PKEY *peerKey=NULL;

/* Peer address as passed to doProtocol(); getMySocket() connects a UDP
   socket towards it to learn which local address we use. */
static struct in_addr connectAddr;

int initDH(void)
{
    if (dhp)
	return 1;
    dhp=DH_new();
    dhp->p=BN_bin2bn(dh_prime_bin, DH_PRIME_BYTES, NULL);
    dhp->g=BN_bin2bn(dh_gen_bin, DH_GEN_BYTES, NULL);

#if 0 /* testing */
    {
	int c;
	if (!DH_check(dhp, &c) || c) {
	    Log(LOG_ERR, "DH_check: %d", c);
	    SSLprinterror(LOG_ERR);
	    return -1;
	}
    }
    Log(LOG_INFO, "DH_size: %d*8", DH_size(dhp));
#endif
    return 0;
}


/* Reserve a socket, get its correct address. This socket is later used
   by ciped. The address is determined by connecting to the peer. Since
   we don't know the right port yet, use the canonical dummy (discard
   service) and let ciped re-connect.
   This requires that "connectAddr" contains the right address. When
   the pkcipe TCP connection goes through a tunnel or proxy, the -r
   argument is needed, or this routine could select the wrong local
   address (in particular 127.0.0.1).

   sa: receives the local address and port of the new socket.
   me: optional local address specification, handed to getaddr() with
       "udp" (presumably it fills in sb); NULL binds to the wildcard
       address and an ephemeral port.
   Returns the socket descriptor, or -1 if no socket could be created or
   its local address could not be read (nothing is left open then).
   bind() and connect() failures are only reported, not treated as fatal.
*/
static int getMySocket(struct sockaddr_in *sa, const char* me)
{
    socklen_t n;
    struct sockaddr_in sb;
    int f=socket(AF_INET, SOCK_DGRAM, 0);

    if (f<0) {
        perror("getMySocket: socket");
        return -1;
    }
    memset(&sb, 0, sizeof(sb));
    sb.sin_family=AF_INET;
    if (me) {
        getaddr(me, &sb, "udp");
    }
    if (bind(f, (struct sockaddr*)&sb, sizeof(sb))<0) {
        perror("getMySocket: bind");
    }
    sb.sin_port=htons(9);
    sb.sin_addr=connectAddr;
    if (connect(f, (struct sockaddr*)&sb, sizeof(sb))<0) {
        perror("getMySocket: connect");
    }
    n=sizeof(*sa);
    if (getsockname(f, (struct sockaddr*)sa, &n)<0) {
        perror("getMySocket: getsockname");
        /* *sa is undefined now; returning the socket would make the
           caller advertise a garbage address. */
        close(f);
        return -1;
    }
    return f;
}

/* This is really not a fingerprint, but a simple checksum.
   Used mainly for catching configuration errors (wrong key).

   The big-endian byte representation of the number is cut into 32-bit
   words which are XORed together; trailing bytes (length % 4) are
   ignored. PKfingerprint() values are sent to the peer in PKT_IDENT and
   compared there, so this exact definition must stay stable.
*/

/* XOR of all complete big-endian 32-bit words in b[0..l). */
static unsigned int fingerprintWords(const unsigned char *b, int l)
{
    unsigned int r=0;
    for (; l>=4; b+=4, l-=4)
        r^=((unsigned int)b[0]<<24) | ((unsigned int)b[1]<<16) |
           ((unsigned int)b[2]<<8)  |  (unsigned int)b[3];
    return r;
}

unsigned int BNfingerprint(const BIGNUM *a)
{
    int l=BN_num_bytes(a);
    /* Callers in this file pass an RSA modulus, so the buffer is small.
       alloca cannot report failure, hence no NULL check. Words are built
       byte by byte: no unaligned/type-punned unsigned int reads and no
       sizeof(unsigned int)==4 assumption, same result as ntohl(). */
    unsigned char *b=alloca(l);
    BN_bn2bin(a, b);
    return fingerprintWords(b, l);
}
#define PKfingerprint(k) BNfingerprint((k)->pkey.rsa->n)

/* Derive the session keys from the DH shared secret and install them
   (setSendKey/setRecvKey, and the "key" option for the CIPE daemon).
   buf:  the DH shared secret; protocol 1 reads its first 48 bytes.
   len:  number of valid secret bytes in buf (used by protocols 2 and 3).
   disc: sign of comparing the peer's DH pubkey with ours; >0 means the
         peer's pubkey is the bigger one.
   Returns 0 on success, -1 if P_SHA1_init fails; aborts on an unknown
   protoVersion. */
INLINE int getKeys(char *buf, int len, int disc)
{
    int i;
    P_SHA1_CTX *ctx;
    char b[40];

    debug((DEB_KEY, "getKeys %d %d", len, disc));
    switch (protoVersion) {
    case 1:
        /* One key is 16 bytes.
           We take the first one iff our DH pubkey is the bigger one. */
        /* BROKEN: improper use of the DH result. The bits are not
           uniformly distributed, much less the high order ones. */
        i=(disc>0)?16:0;
        setSendKey(buf+i);
        setRecvKey(buf+16-i);
        /* The third key is for the CIPE daemon. */
        setOption("key", hexstr(buf+32, 16), OF_REQVAL);
        return 0;

    case 2:
    case 3:
        /* Generate three blocks from the buffer using P_SHA1.
           The first block is the sending key for the bigger DH pubkey.
           The second block is the sending key for the smaller DH pubkey.
           The third block is the CIPE key; all 20 bytes are handed on as
           hex (presumably the daemon only uses 16 of them - confirm). */
#ifdef DEBUG
        if (debugging&DEB_KEY)
            hexdump(nonces, NONCESLEN);
#endif
        if (!(ctx=P_SHA1_init(buf, len, nonces, NONCESLEN)))
            return -1;
        P_SHA1_block(ctx, b);
        P_SHA1_block(ctx, b+20);
        i=(disc>0)?20:0;
        setSendKey(b+i);
        setRecvKey(b+20-i);
        P_SHA1_block(ctx, b);
        setOption("key", hexstr(b, 20), OF_REQVAL);
        P_SHA1_free(ctx);
        /* b held live key material. Wipe it with OPENSSL_cleanse, which,
           unlike memset on a dead buffer, is not optimized away. */
        OPENSSL_cleanse(b, sizeof(b));
        return 0;

    default:
        abort();
    }
}


/* The protocol requires that both sides send packets alternately.
   The connection is switched to encrypted as soon as both sides have
   received PKT_DHKEY.

   Protocol state diagram:

   rec    state     send
         (start)
            v       NONCE
         Snonce
   NONCE    v       DHKEY
         Sdhkey
   DHKEY    v       IDENT
         Sident
   IDENT    v       SIGN
          Ssign
   SIGN     v       OPT
  ______> Sopt
 |        / | \                            (any)              (any)
 |OPT    /  |  ---------------               |                  |
 -------/   |          READY | READY         | ERROR      ERROR | DONE
  OPT       |                |               v                  v
        OPT | READY          |             Serr               Sdone
            v                |               |
          Sready             |         DONE  | (close)
   READY    v       DONE     |               v
          Sdone <____________|             (out)
   (eof)    v
          (out)

*/

/* Process one received packet and return the next protocol state.
   fd:    the connection to the peer; all replies are sent there.
   pkt:   the received packet; pkt[0] is the packet type (PKT_*), the
          payload follows. Text payloads are used as C strings starting
          at pkt+1 - presumably packetRecv() NUL-terminates; confirm.
   len:   packet length including the type byte.
   state: the current state, see the diagram above.
   Protocol errors are reported to the peer with a PKT_ERROR packet and
   yield Serr; a received PKT_ERROR is answered with PKT_DONE (Sdone).
   NOTE: the helper macros pkttyp, Send and SendErrorRet defined below
   stay defined for the rest of the file.
*/
INLINE pState handlePacket(int fd, unsigned char *pkt, int len, pState state)
{
    unsigned char buf[PKTMAXLEN];
    int i=0;

    #define pkttyp (pkt[0])
    #define Send(t,s) do{                                               \
        int k=snprintf(buf, sizeof(buf), "%c%s", (t), (s));             \
        packetSend(fd, buf, k);                                         \
                        }while(0)

    #define SendErrorRet(s) do{                                         \
        int k=snprintf(buf, sizeof(buf), "%c%s", PKT_ERROR, (s));       \
        packetSend(fd, buf, k);                                         \
        Log(LOG_ERR, "handlePacket: error: %s", buf+1);                 \
        return Serr;                                                    \
                        }while(0)

#ifdef DEBUG
    memset(buf, 0xbb, sizeof(buf));
#endif
    if (pkttyp==PKT_ERROR) {
        Log(LOG_NOTICE, "handlePacket: received ERROR: %s", pkt+1);
        Send(PKT_DONE, "error");
        return Sdone;
    }
    if (pkttyp==PKT_DEBUG) {
        Log(LOG_DEBUG, "handlePacket: peer: %s", pkt+1);
        return state;
    }

    switch (state) {

    case Snonce:
        if (pkttyp!=PKT_NONCE)
            goto stateerr;
        if (len<8)
            SendErrorRet("nonce too short");
        /* Mix (at most NONCESLEN bytes of) the peer's nonce into ours. */
        if (len>NONCESLEN+1)
            len=NONCESLEN+1;
        for (i=0; i<len-1; ++i)
            nonces[i]^=pkt[i+1];
        if (!DH_generate_key(dhp)) {
            SSLprinterror(LOG_ERR);
            SendErrorRet("internal DH keygen failed");
        }
        packetSendBN(fd, PKT_DHKEY, dhp->pub_key);
        return Sdhkey;

    case Sdhkey:
        if (pkttyp!=PKT_DHKEY)
            goto stateerr;
        if (mlock(buf, sizeof(buf))<0)
            Log(LOG_ERR, "handlePacket: mlock: %m"); /* not fatal */
        {
            int klen;
            BIGNUM *a=packetExtrBN(pkt, len);
            if (!a)
                SendErrorRet("bad DH key packet");
            /* The sign of this comparison decides which derived key is
               used for which direction (see getKeys). */
            i=BN_cmp(a, dhp->pub_key);
            if (!i) {
                /* Our own pubkey reflected back. A peer can send this on
                   purpose, so refuse it instead of abort()ing. */
                Log(LOG_ALERT, "handlePacket: Two equal DH pubkeys???");
                BN_free(a);
                SendErrorRet("equal DH pubkeys");
            }
            klen=DH_compute_key(buf, a, dhp);
            BN_free(a);
            if (klen<0) {
                SSLprinterror(LOG_ERR);
                SendErrorRet("internal bad DH");
            }
            /* Get all keys, switch to encrypted. DH_compute_key can
               return fewer than DH_size() bytes (leading zeros dropped);
               use exactly the bytes it wrote, never stale stack data. */
            i=getKeys(buf, klen, i);
        }
        /* Clear key material. Not munlocking here is probably overzealous
           since we are on the stack, but better safe... */
        memset(buf, 0, sizeof(buf));
        if (i<0)
            SendErrorRet("handlePacket: getKeys failed");

        /* Send out identity next. */
        i=snprintf(buf, sizeof(buf), "%c%08x %s", PKT_IDENT,
                   PKfingerprint(myKey), myIdentity);
        packetSend(fd, buf, i);
        return Sident;

    case Sident:
        if (pkttyp!=PKT_IDENT)
            goto stateerr;
        /* Packet is Tffffffff iiiii...: type byte, 8 hex digits of key
           fingerprint, one separator, then the identity. */
        if (len<=10)
            SendErrorRet("handlePacket: identity too short");
        if (len>MAXIDLEN+10)
            SendErrorRet("handlePacket: identity too long");
        peerIdentity=strdup((char *)pkt+10);
        if (!peerIdentity)
            SendErrorRet("handlePacket: out of memory");
        /* The identity becomes part of a file name (KEYDIR); refuse
           anything that could point outside that directory. */
        if (!peerIdentity[0] || peerIdentity[0]=='.' ||
            strchr(peerIdentity, '/')) {
            free(peerIdentity);
            peerIdentity=NULL;
            SendErrorRet("handlePacket: invalid identity");
        }
        /* FIXME error handling */
        (void)lockMaster();
        i=lockPeer();
        (void)unlockMaster();
        if (i<0)
            SendErrorRet("handlePacket: Lock failed");
        if (i>0)
            SendErrorRet("handlePacket: peer is talking");
        snprintf(buf, sizeof(buf), KEYDIR, peerIdentity);
        debug((DEB_PROTO, "Using peer key %s", buf));
        if (secchk(buf, 0077, 0022, 0)==0) {
            FILE *f=fopen(buf, "r");
            if (f) {
                unsigned int k=0, l=0;
                peerKey=PEM_read_PUBKEY(f, NULL, NULL, NULL);
                if (!peerKey) {
                    /* Unusable key file: peerKey stays NULL, so the
                       signature check in Ssign will fail. */
                    SSLprinterror(LOG_ERR);
                } else {
                    l=PKfingerprint(peerKey);
                    pkt[9]='\0';
                    if (sscanf((char *)pkt+1, "%x", &k)!=1 || k!=l) {
                        Log(LOG_NOTICE,
                            "handlePacket: Key fingerprint mismatch %08x %08x",
                            k, l);
                        EVP_PKEY_free(peerKey);
                        peerKey=NULL;
                    } else {
                        readOptions(f);
                    }
                }
                fclose(f);
            } else {
                /* file has vanished just after secchk() - not good */
                Log(LOG_ERR, "handlePacket: %s: %m", buf);
            }
        } else {
            /* Don't tell the peer (yet) that we don't know him */
            Log(LOG_NOTICE, "handlePacket: %s: unknown identity",
                buf);
        }
        /* Sign our sent stuff. */
        buf[0]=PKT_SIGN;
        i=EVP_PKEY_size(myKey);
        if (signFinal(buf+1, &i, myKey)<=0) {
            SSLprinterror(LOG_ERR);
            SendErrorRet("Signing failed");
        }
        packetSend(fd, buf, i+1);
        return Ssign;

    case Ssign:
        if (pkttyp!=PKT_SIGN)
            goto stateerr;
        /* Verify the peer's signature. */
        if (!peerKey)
            goto sigerr;
        if ((i=vrfyFinal(pkt+1, len-1, peerKey))==1) {
            struct sockaddr_in sa;
            debug((DEB_PROTO, "Good signature"));
            /* Determine own UDP address, send this as first option. */
            if ((cipeSocket=getMySocket(&sa,getOption("me", NULL)))<0)
                SendErrorRet("Could not obtain UDP socket");
            i=snprintf(buf, sizeof(buf), "%cme=%s:%d", PKT_OPT_REQ,
                       inet_ntoa(sa.sin_addr), ntohs(sa.sin_port));
            setOption("me", buf+4, OF_DEFAULT); /* later? */
            packetSend(fd, buf, i);
            return Sopt;
        }
        debug((DEB_PROTO, "Bad signature"));
        SSLprinterror(LOG_ERR);
    sigerr:
        SendErrorRet("Signature check failed");

    case Sopt:
        switch(pkttyp) {
        case PKT_OPT_REQ:
#if 0
        case PKT_OPT_ACK:
        case PKT_OPT_NAK:
        case PKT_OPT_REJ:
#endif
            return negotiate(fd, pkt, len);
        case PKT_READY:
            return ready(fd);
        }
        goto stateerr;

    case Sready:
        if (pkttyp!=PKT_READY)
            goto stateerr;
        Send(PKT_DONE, "ready");
        return Sdone;

    case Serr:
        Send(PKT_DONE, "error");
        return Sdone;

    default:
        SendErrorRet("internal bad state");
    }

 stateerr:
    i=snprintf(buf, sizeof(buf), "%cunexpected packet %02x/%02x",
               PKT_ERROR, state, pkttyp);
    packetSend(fd, buf, i);
    Log(LOG_NOTICE, "handlePacket: %s", buf+1);
    return Serr;
}

int doProtocol(int fd, struct in_addr addr)
{
    unsigned char buf[PKTMAXLEN];
    int l;
    pState state=Snonce;
    connectAddr=addr; /* not nice */
    /* 0123456789abcdef012345678
       T........ssssuuuu */
    buf[0]=PKT_NONCE;
    RAND_bytes(buf+1, 8);
    {
        struct timeval tv;
        gettimeofday(&tv, NULL);
        *(long*)(buf+9)=htonl(tv.tv_sec);
        *(long*)(buf+13)=htonl(tv.tv_usec);
    }
    memcpy(nonces, buf+1, 16);
    packetSend(fd, buf, 17);

    while (state!=Sdone) {
	if ((l=packetRecv(fd, buf, sizeof(buf)))<0) {
	    /* read error or unexpected EOF */
	    Log(LOG_NOTICE, "doProtocol: read error: %m");
	    return -1;
	}
	debug((DEB_PROTO,
	       "handlePacket %d %02x/%02x", l, state, buf[0]));
	state=handlePacket(fd, buf, l, state);
    }
    return 0;
}
