#include        "u.h"
#include        "../port/lib.h"
#include        "mem.h"
#include        "dat.h"
#include        "fns.h"
#include        "../port/error.h"

#include        "ip.h"

#include        "libsec.h"

typedef struct Esphdr Esphdr;
typedef struct Esptail Esptail;
typedef struct Userhdr Userhdr;
typedef struct Esppriv Esppriv;
typedef struct Espcb Espcb;
typedef struct Algorithm Algorithm;
typedef struct Esprc4 Esprc4;

#define DPRINT if(0)print

enum
{
       IP_ESPPROTO     = 50,
       EsphdrSize      = 28,   // includes IP header
       IphdrSize       = 20,   // options have been stripped
       EsptailSize     = 2,    // does not include pad or auth data
       UserhdrSize     = 4,    // user visable header size - if enabled
};

struct Esphdr
{
       /* ip header */
       uchar   vihl;           /* Version and header length */
       uchar   tos;            /* Type of service */
       uchar   length[2];      /* packet length */
       uchar   id[2];          /* Identification */
       uchar   frag[2];        /* Fragment information */
       uchar   Unused;
       uchar   espproto;       /* Protocol */
       uchar   espplen[2];     /* Header plus data length */
       uchar   espsrc[4];      /* Ip source */
       uchar   espdst[4];      /* Ip destination */

       /* esp header */
       uchar   espspi[4];      /* Security parameter index */
       uchar   espseq[4];      /* Sequence number */
};

struct Esptail
{
       uchar   pad;
       uchar   nexthdr;
};

/* header as seen by the user */
struct Userhdr
{
       uchar   nexthdr;        // next protocol
       uchar   unused[3];
};

struct Esppriv
{
       ulong   in;
       ulong   inerrors;
};

/*
*  protocol specific part of Conv
*/
struct Espcb
{
       int     incoming;
       int     header;         // user user level header
       ulong   spi;
       ulong   seq;            // last seq sent
       ulong   window;         // for replay attacks
       char    *espalg;
       void    *espstate;      // other state for esp
       int     espivlen;       // in bytes
       int     espblklen;
       int     (*cipher)(Espcb*, uchar *buf, int len);
       char    *ahalg;
       void    *ahstate;       // other state for esp
       int     ahlen;          // auth data length in bytes
       int     ahblklen;
       int     (*auth)(Espcb*, uchar *buf, int len, uchar *hash);
};

struct Algorithm
{
       char    *name;
       int     keylen;         // in bits
       void    (*init)(Espcb*, char* name, uchar *key, int keylen);
};


enum {
       RC4forward      = 10*1024*1024, // maximum skip forward
       RC4back = 100*1024,             // maximum look back
};

struct Esprc4
{
       ulong cseq;     // current byte sequence number
       RC4state current;

       int ovalid;     // old is valid
       ulong lgseq; // last good sequence
       ulong oseq;     // old byte sequence number
       RC4state old;
};

static  Conv* convlookup(Proto *esp, ulong spi);
static  char *setalg(Espcb *ecb, char **f, int n, Algorithm *alg);
static  void nullespinit(Espcb*, char*, uchar *key, int keylen);
static  void nullahinit(Espcb*, char*, uchar *key, int keylen);
static  void shaahinit(Espcb*, char*, uchar *key, int keylen);
static  void md5ahinit(Espcb*, char*, uchar *key, int keylen);
static  void desespinit(Espcb *ecb, char *name, uchar *k, int n);
static  void rc4espinit(Espcb *ecb, char *name, uchar *k, int n);

static Algorithm espalg[] =
{
       "null",                 0,      nullespinit,
       "des_56_cbc",           64,     desespinit,
       "rc4_128",              128,    rc4espinit,
       nil,                    0,      nil,
};

static Algorithm ahalg[] =
{
       "null",                 0,      nullahinit,
       "hmac_sha1_96",         128,    shaahinit,
       "hmac_md5_96",          128,    md5ahinit,
       nil,                    0,      nil,
};

static char*
espconnect(Conv *c, char **argv, int argc)
{
       char *p, *pp;
       char *e = nil;
       ulong spi;
       Espcb *ecb = (Espcb*)c->ptcl;

       switch(argc) {
       default:
               e = "bad args to connect";
               break;
       case 2:
               p = strchr(argv[1], '!');
               if(p == nil){
                       e = "malformed address";
                       break;
               }
               *p++ = 0;
               parseip(c->raddr, argv[1]);
               findlocalip(c->p->f, c->laddr, c->raddr);
               ecb->incoming = 0;
               ecb->seq = 0;
               if(strcmp(p, "*") == 0) {
                       qlock(c->p);
                       for(;;) {
                               spi = nrand(1<<16) + 256;
                               if(convlookup(c->p, spi) == nil)
                                       break;
                       }
                       qunlock(c->p);
                       ecb->spi = spi;
                       ecb->incoming = 1;
                       qhangup(c->wq, nil);
               } else {
                       spi = strtoul(p, &pp, 10);
                       if(pp == p) {
                               e = "malformed address";
                               break;
                       }
                       ecb->spi = spi;
                       qhangup(c->rq, nil);
               }
               nullespinit(ecb, "null", nil, 0);
               nullahinit(ecb, "null", nil, 0);
       }
       Fsconnected(c, e);

       return e;
}


static int
espstate(Conv *c, char *state, int n)
{
       return snprint(state, n, "%s", c->inuse?"Open\n":"Closed\n");
}

static void
espcreate(Conv *c)
{
       c->rq = qopen(64*1024, 1, 0, 0);
       c->wq = qopen(64*1024, 0, 0, 0);
}

static void
espclose(Conv *c)
{
       Espcb *ecb;

       qclose(c->rq);
       qclose(c->wq);
       qclose(c->eq);
       ipmove(c->laddr, IPnoaddr);
       ipmove(c->raddr, IPnoaddr);

       ecb = (Espcb*)c->ptcl;
       free(ecb->espstate);
       free(ecb->ahstate);
       memset(ecb, 0, sizeof(Espcb));
}

void
espkick(Conv *c)
{
       Esphdr *eh;
       Esptail *et;
       Userhdr *uh;
       Espcb *ecb;
       Block *bp;
       int nexthdr;
       int payload;
       int pad;
       int align;
       uchar *auth;

       bp = qget(c->wq);
       if(bp == nil)
               return;

       qlock(c);
       ecb = c->ptcl;

       if(ecb->header) {
               /* make sure the message has a User header */
               bp = pullupblock(bp, UserhdrSize);
               if(bp == nil) {
                       qunlock(c);
                       return;
               }
               uh = (Userhdr*)bp->rp;
               nexthdr = uh->nexthdr;
               bp->rp += UserhdrSize;
       } else {
               nexthdr = 0;  // what should this be?
       }

       payload = BLEN(bp) + ecb->espivlen;

       /* Make space to fit ip header */
       bp = padblock(bp, EsphdrSize + ecb->espivlen);

       align = 4;
       if(ecb->espblklen > align)
               align = ecb->espblklen;
       if(align % ecb->ahblklen != 0)
               panic("espkick: ahblklen is important after all");
       pad = (align-1) - (payload + EsptailSize-1)%align;

       /*
        * Make space for tail
        * this is done by calling padblock with a negative size
        * Padblock does not change bp->wp!
        */
       bp = padblock(bp, -(pad+EsptailSize+ecb->ahlen));
       bp->wp += pad+EsptailSize+ecb->ahlen;

       eh = (Esphdr *)(bp->rp);
       et = (Esptail*)(bp->rp + EsphdrSize + payload + pad);

       // fill in tail
       et->pad = pad;
       et->nexthdr = nexthdr;

       ecb->cipher(ecb, bp->rp+EsphdrSize, payload+pad+EsptailSize);
       auth = bp->rp + EsphdrSize + payload + pad + EsptailSize;

       // fill in head
       hnputl(eh->espspi, ecb->spi);
       hnputl(eh->espseq, ++ecb->seq);
       v6tov4(eh->espsrc, c->laddr);
       v6tov4(eh->espdst, c->raddr);
       eh->espproto = IP_ESPPROTO;
       eh->frag[0] = 0;
       eh->frag[1] = 0;

       ecb->auth(ecb, bp->rp+IphdrSize, (EsphdrSize-IphdrSize)+payload+pad+EsptailSize, auth);

       qunlock(c);
//print("esp: pass down: %uld\n", BLEN(bp));
       ipoput(c->p->f, bp, 0, c->ttl, c->tos);
}

void
espiput(Proto *esp, Ipifc*, Block *bp)
{
       Esphdr *eh;
       Esptail *et;
       Userhdr *uh;
       Conv *c;
       Espcb *ecb;
       uchar raddr[IPaddrlen], laddr[IPaddrlen];
       Fs *f;
       uchar *auth;
       ulong spi;
       int payload, nexthdr;

       f = esp->f;

       bp = pullupblock(bp, EsphdrSize+EsptailSize);
       if(bp == nil) {
               netlog(f, Logesp, "esp: short packet\n");
               return;
       }

       eh = (Esphdr*)(bp->rp);
       spi = nhgetl(eh->espspi);
       v4tov6(raddr, eh->espsrc);
       v4tov6(laddr, eh->espdst);

       qlock(esp);
       /* Look for a conversation structure for this port */
       c = convlookup(esp, spi);
       if(c == nil) {
               qunlock(esp);
               netlog(f, Logesp, "esp: no conv %I -> %I!%d\n", raddr,
                       laddr, spi);
               icmpnoconv(f, bp);
               freeblist(bp);
               return;
       }

       qlock(c);
       qunlock(esp);

       ecb = c->ptcl;
       // too hard to do decryption/authentication on block lists
       if(bp->next)
               bp = concatblock(bp);

       if(BLEN(bp) < EsphdrSize + ecb->espivlen + EsptailSize + ecb->ahlen) {
               qunlock(c);
               netlog(f, Logesp, "esp: short block %I -> %I!%d\n", raddr,
                       laddr, spi);
               freeb(bp);
               return;
       }

       eh = (Esphdr*)(bp->rp);
       auth = bp->wp - ecb->ahlen;
       if(!ecb->auth(ecb, eh->espspi, auth-eh->espspi, auth)) {
               qunlock(c);
print("esp: bad auth %I -> %I!%ld\n", raddr, laddr, spi);
               netlog(f, Logesp, "esp: bad auth %I -> %I!%d\n", raddr,
                       laddr, spi);
               freeb(bp);
               return;
       }

       payload = BLEN(bp)-EsphdrSize-ecb->ahlen;
       if(payload<=0 || payload%4 != 0 || payload%ecb->espblklen!=0) {
               qunlock(c);
               netlog(f, Logesp, "esp: bad length %I -> %I!%d payload=%d BLEN=%d\n", raddr,
                       laddr, spi, payload, BLEN(bp));
               freeb(bp);
               return;
       }
       if(!ecb->cipher(ecb, bp->rp+EsphdrSize, payload)) {
               qunlock(c);
print("esp: cipher failed %I -> %I!%ld: %r\n", raddr, laddr, spi);
               netlog(f, Logesp, "esp: cipher failed %I -> %I!%d: %r\n", raddr,
                       laddr, spi);
               freeb(bp);
               return;
       }

       payload -= EsptailSize;
       et = (Esptail*)(bp->rp + EsphdrSize + payload);
       payload -= et->pad + ecb->espivlen;
       nexthdr = et->nexthdr;
       if(payload <= 0) {
               qunlock(c);
               netlog(f, Logesp, "esp: short packet after decrypt %I -> %I!%d\n", raddr,
                       laddr, spi);
               freeb(bp);
               return;
       }

       // trim packet
       bp->rp += EsphdrSize + ecb->espivlen;
       bp->wp = bp->rp + payload;
       if(ecb->header) {
               // assume UserhdrSize < EsphdrSize
               bp->rp -= UserhdrSize;
               uh = (Userhdr*)bp->rp;
               memset(uh, 0, UserhdrSize);
               uh->nexthdr = nexthdr;
       }

       if(qfull(c->rq)){
               netlog(f, Logesp, "esp: qfull %I -> %I.%uld\n", raddr,
                       laddr, spi);
               freeblist(bp);
       }else {
//print("esp: pass up: %uld\n", BLEN(bp));
               qpass(c->rq, bp);
       }

       qunlock(c);
}

char*
espctl(Conv *c, char **f, int n)
{
       Espcb *ecb = c->ptcl;
       char *e = nil;

       if(strcmp(f[0], "esp") == 0)
               e = setalg(ecb, f, n, espalg);
       else if(strcmp(f[0], "ah") == 0)
               e = setalg(ecb, f, n, ahalg);
       else if(strcmp(f[0], "header") == 0)
               ecb->header = 1;
       else if(strcmp(f[0], "noheader") == 0)
               ecb->header = 0;
       else
               e = "unknown control request";
       return e;
}

void
espadvise(Proto *esp, Block *bp, char *msg)
{
       Esphdr *h;
       Conv *c;
       ulong spi;

       h = (Esphdr*)(bp->rp);

       spi = nhgets(h->espspi);
       qlock(esp);
       c = convlookup(esp, spi);
       if(c != nil) {
               qhangup(c->rq, msg);
               qhangup(c->wq, msg);
       }
       qunlock(esp);
       freeblist(bp);
}

int
espstats(Proto *esp, char *buf, int len)
{
       Esppriv *upriv;

       upriv = esp->priv;
       return snprint(buf, len, "%lud %lud\n",
               upriv->in,
               upriv->inerrors);
}

static int
esplocal(Conv *c, char *buf, int len)
{
       Espcb *ecb = c->ptcl;
       int n;

       qlock(c);
       if(ecb->incoming)
               n = snprint(buf, len, "%I!%uld\n", c->laddr, ecb->spi);
       else
               n = snprint(buf, len, "%I\n", c->laddr);
       qunlock(c);
       return n;
}

static int
espremote(Conv *c, char *buf, int len)
{
       Espcb *ecb = c->ptcl;
       int n;

       qlock(c);
       if(ecb->incoming)
               n = snprint(buf, len, "%I\n", c->raddr);
       else
               n = snprint(buf, len, "%I!%uld\n", c->raddr, ecb->spi);
       qunlock(c);
       return n;
}

static  Conv*
convlookup(Proto *esp, ulong spi)
{
       Conv *c, **p;
       Espcb *ecb;

       for(p=esp->conv; *p; p++){
               c = *p;
               ecb = c->ptcl;
               if(ecb->incoming && ecb->spi == spi)
                       return c;
       }
       return nil;
}

static char *
setalg(Espcb *ecb, char **f, int n, Algorithm *alg)
{
       uchar *key;
       int i, nbyte, nchar;
       int c;

       if(n < 2)
               return "bad format";
       for(; alg->name; alg++)
               if(strcmp(f[1], alg->name) == 0)
                       break;
       if(alg->name == nil)
               return "unknown algorithm";

       if(n != 3)
               return "bad format";
       nbyte = (alg->keylen + 7) >> 3;
       nchar = strlen(f[2]);
       for(i=0; i<nchar; i++) {
               c = f[2][i];
               if(c >= '0' && c <= '9')
                       f[2][i] -= '0';
               else if(c >= 'a' && c <= 'f')
                       f[2][i] -= 'a'-10;
               else if(c >= 'A' && c <= 'F')
                       f[2][i] -= 'A'-10;
               else
                       return "bad character in key";
       }
       key = smalloc(nbyte);
       for(i=0; i<nchar && i*2<nbyte; i++) {
               c = f[2][nchar-i-1];
               if(i&1)
                       c <<= 4;
               key[i>>1] |= c;
       }

       alg->init(ecb, alg->name, key, alg->keylen);
       free(key);
       return nil;
}

static int
nullcipher(Espcb*, uchar*, int)
{
       return 1;
}

static void
nullespinit(Espcb *ecb, char *name, uchar*, int)
{
       ecb->espalg = name;
       ecb->espblklen = 1;
       ecb->espivlen = 0;
       ecb->cipher = nullcipher;
}

static int
nullauth(Espcb*, uchar*, int, uchar*)
{
       return 1;
}

static void
nullahinit(Espcb *ecb, char *name, uchar*, int)
{
       ecb->ahalg = name;
       ecb->ahblklen = 1;
       ecb->ahlen = 0;
       ecb->auth = nullauth;
}

void
seanq_hmac_sha1(uchar hash[SHA1dlen], uchar *t, long tlen, uchar *key, long klen)
{
       uchar ipad[65], opad[65];
       int i;
       DigestState *digest;
       uchar innerhash[SHA1dlen];

       for(i=0; i<64; i++){
               ipad[i] = 0x36;
               opad[i] = 0x5c;
       }
       ipad[64] = opad[64] = 0;
       for(i=0; i<klen; i++){
               ipad[i] ^= key[i];
               opad[i] ^= key[i];
       }
       digest = sha1(ipad, 64, nil, nil);
       sha1(t, tlen, innerhash, digest);
       digest = sha1(opad, 64, nil, nil);
       sha1(innerhash, SHA1dlen, hash, digest);
}

static int
shaauth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
{
       uchar hash[SHA1dlen];
       int r;

       memset(hash, 0, SHA1dlen);
       seanq_hmac_sha1(hash, t, tlen, (uchar*)ecb->ahstate, 16);
       r = memcmp(auth, hash, ecb->ahlen) == 0;
       memmove(auth, hash, ecb->ahlen);
       return r;
}

static void
shaahinit(Espcb *ecb, char *name, uchar *key, int klen)
{
       if(klen != 128)
               panic("shaahinit: bad keylen");
       klen >>= 8;     // convert to bytes

       ecb->ahalg = name;
       ecb->ahblklen = 1;
       ecb->ahlen = 12;
       ecb->auth = shaauth;
       ecb->ahstate = smalloc(klen);
       memmove(ecb->ahstate, key, klen);
}

void
seanq_hmac_md5(uchar hash[MD5dlen], uchar *t, long tlen, uchar *key, long klen)
{
       uchar ipad[65], opad[65];
       int i;
       DigestState *digest;
       uchar innerhash[MD5dlen];

       for(i=0; i<64; i++){
               ipad[i] = 0x36;
               opad[i] = 0x5c;
       }
       ipad[64] = opad[64] = 0;
       for(i=0; i<klen; i++){
               ipad[i] ^= key[i];
               opad[i] ^= key[i];
       }
       digest = md5(ipad, 64, nil, nil);
       md5(t, tlen, innerhash, digest);
       digest = md5(opad, 64, nil, nil);
       md5(innerhash, MD5dlen, hash, digest);
}

static int
md5auth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
{
       uchar hash[MD5dlen];
       int r;

       memset(hash, 0, MD5dlen);
       seanq_hmac_md5(hash, t, tlen, (uchar*)ecb->ahstate, 16);
       r = memcmp(auth, hash, ecb->ahlen) == 0;
       memmove(auth, hash, ecb->ahlen);
       return r;
}

static void
md5ahinit(Espcb *ecb, char *name, uchar *key, int klen)
{
       if(klen != 128)
               panic("md5ahinit: bad keylen");
       klen >>= 3;     // convert to bytes


       ecb->ahalg = name;
       ecb->ahblklen = 1;
       ecb->ahlen = 12;
       ecb->auth = md5auth;
       ecb->ahstate = smalloc(klen);
       memmove(ecb->ahstate, key, klen);
}

static int
descipher(Espcb *ecb, uchar *p, int n)
{
       uchar tmp[8];
       uchar *pp, *tp, *ip, *eip, *ep;
       DESstate *ds = ecb->espstate;

       ep = p + n;
       if(ecb->incoming) {
               memmove(ds->ivec, p, 8);
               p += 8;
               while(p < ep){
                       memmove(tmp, p, 8);
                       block_cipher(ds->expanded, p, 1);
                       tp = tmp;
                       ip = ds->ivec;
                       for(eip = ip+8; ip < eip; ){
                               *p++ ^= *ip;
                               *ip++ = *tp++;
                       }
               }
       } else {
               memmove(p, ds->ivec, 8);
               for(p += 8; p < ep; p += 8){
                       pp = p;
                       ip = ds->ivec;
                       for(eip = ip+8; ip < eip; )
                               *pp++ ^= *ip++;
                       block_cipher(ds->expanded, p, 0);
                       memmove(ds->ivec, p, 8);
               }
       }
       return 1;
}

static void
desespinit(Espcb *ecb, char *name, uchar *k, int n)
{
       uchar key[8];
       uchar ivec[8];
       int i;

       // bits to bytes
       n = (n+7)>>3;
       if(n > 8)
               n = 8;
       memset(key, 0, sizeof(key));
       memmove(key, k, n);
       for(i=0; i<8; i++)
               ivec[i] = nrand(256);
       ecb->espalg = name;
       ecb->espblklen = 8;
       ecb->espivlen = 8;
       ecb->cipher = descipher;
       ecb->espstate = smalloc(sizeof(DESstate));
       setupDESstate(ecb->espstate, key, ivec);
}

static int
rc4cipher(Espcb *ecb, uchar *p, int n)
{
       Esprc4 *esprc4;
       RC4state tmpstate;
       ulong seq;
       long d, dd;

       if(n < 4)
               return 0;

       esprc4 = ecb->espstate;
       if(ecb->incoming) {
               seq = nhgetl(p);
               p += 4;
               n -= 4;
               d = seq-esprc4->cseq;
               if(d == 0) {
                       rc4(&esprc4->current, p, n);
                       esprc4->cseq += n;
                       if(esprc4->ovalid) {
                               dd = esprc4->cseq - esprc4->lgseq;
                               if(dd > RC4back)
                                       esprc4->ovalid = 0;
                       }
               } else if(d > 0) {
print("missing packet: %uld %ld\n", seq, d);
                       // this link is hosed
                       if(d > RC4forward) {
                               strcpy(up->errstr, "rc4cipher: skipped too much");
                               return 0;
                       }
                       esprc4->lgseq = seq;
                       if(!esprc4->ovalid) {
                               esprc4->ovalid = 1;
                               esprc4->oseq = esprc4->cseq;
                               memmove(&esprc4->old, &esprc4->current, sizeof(RC4state));
                       }
                       rc4skip(&esprc4->current, d);
                       rc4(&esprc4->current, p, n);
                       esprc4->cseq = seq+n;
               } else {
print("reordered packet: %uld %ld\n", seq, d);
                       dd = seq - esprc4->oseq;
                       if(!esprc4->ovalid || -d > RC4back || dd < 0) {
                               strcpy(up->errstr, "rc4cipher: too far back");
                               return 0;
                       }
                       memmove(&tmpstate, &esprc4->old, sizeof(RC4state));
                       rc4skip(&tmpstate, dd);
                       rc4(&tmpstate, p, n);
                       return 1;
               }

               // move old state up
               if(esprc4->ovalid) {
                       dd = esprc4->cseq - RC4back - esprc4->oseq;
                       if(dd > 0) {
                               rc4skip(&esprc4->old, dd);
                               esprc4->oseq += dd;
                       }
               }
       } else {
               hnputl(p, esprc4->cseq);
               p += 4;
               n -= 4;
               rc4(&esprc4->current, p, n);
               esprc4->cseq += n;
       }
       return 1;
}

static void
rc4espinit(Espcb *ecb, char *name, uchar *k, int n)
{
       Esprc4 *esprc4;

       // bits to bytes
       n = (n+7)>>3;
       esprc4 = smalloc(sizeof(Esprc4));
       memset(esprc4, 0, sizeof(Esprc4));
       setupRC4state(&esprc4->current, k, n);
       ecb->espalg = name;
       ecb->espblklen = 4;
       ecb->espivlen = 4;
       ecb->cipher = rc4cipher;
       ecb->espstate = esprc4;
}

void
espinit(Fs *fs)
{
       Proto *esp;

       esp = smalloc(sizeof(Proto));
       esp->priv = smalloc(sizeof(Esppriv));
       esp->name = "esp";
       esp->kick = espkick;
       esp->connect = espconnect;
       esp->announce = nil;
       esp->ctl = espctl;
       esp->state = espstate;
       esp->create = espcreate;
       esp->close = espclose;
       esp->rcv = espiput;
       esp->advise = espadvise;
       esp->stats = espstats;
       esp->local = esplocal;
       esp->remote = espremote;
       esp->ipproto = IP_ESPPROTO;
       esp->nc = Nchans;
       esp->ptclsize = sizeof(Espcb);

       Fsproto(fs, esp);
}