/*
* See Aho, Corasick Comm ACM 18:6, and Hoffman and O'Donell JACM 29:1
* for detail of the matching algorithm
*/

#ifndef va_start
#include <stdarg.h>
#endif

#include <stdio.h>
#include <stdlib.h>

#define _twig_assert(a,str) \
       do{if(!(a)){\
               fprintf(stderr, "twig internal error %s\n", (str));\
               abort();\
       }}while(0)

/* these are user defined, so can be macros */
#ifndef mtValue
int             mtValue(NODEPTR);
#endif
#ifndef mtSetNodes
void            mtSetNodes(NODEPTR, int, NODEPTR);
#endif
#ifndef mtGetNodes
NODEPTR         mtGetNodes(NODEPTR, int);
#endif

/* made by twig proper */
NODEPTR         mtAction(int, __match **, skeleton *);
short           mtEvalCost(__match *, __match **, skeleton *);

/* stuff defined here */
void            _addmatches(int, skeleton *, __match *);
void            _closure(int, int, skeleton *);
void            _do(skeleton *, __match *, int);
void            _evalleaves(__match **);
__match         **_getleaves(__match *, skeleton *);
int             _machstep(short, short);
void            _match(void);
void            _matchinit(void);
void            _merge(skeleton *, skeleton *);
NODEPTR         _mtG(NODEPTR, ...);
skeleton        *_walk(skeleton *, int);
__match         *_allocmatch(void);
skeleton        *_allocskel(void);
__partial       *_allocpartial(void);
void            _freematch(__match *);
void            _freeskel(skeleton *);
void            _freepartial(__partial *);
void            _prskel(skeleton *, int);

int mtDebug = 0;
int treedebug = 0;
extern short mtStart;
extern short mtTable[], mtAccept[], mtMap[], mtPaths[], mtPathStart[];
#ifdef LINE_XREF
       extern short mtLine[];
#endif

#ifndef MPBLOCKSIZ
#define MPBLOCKSIZ 10000
#endif
__match *_mpblock[MPBLOCKSIZ], **_mpbtop;

/*
* See sym.h in the preprocessor for details
* Basically used to support eh $%n$ construct.
*/
__match **
_getleaves(__match *mp, skeleton *skp)
{
       skeleton *stack[MAXDEPTH];
       skeleton **stp = stack;
       short *sip = &mtPaths[mtPathStart[mp->tree]];
       __match **mmp = _mpbtop;

       __match **mmp2 = mmp;
       _mpbtop += *sip++ + 1;
       _twig_assert(_mpbtop <= &_mpblock[MPBLOCKSIZ], "match block overflow");

       for(;;)
               switch(*sip++){
               case ePUSH:
                       *stp++ = skp;
                       skp = skp->leftchild;
                       break;
               case eNEXT:
                       skp = skp->sibling;
                       break;
               case eEVAL:
                       mp = skp->succ[M_DETAG(*sip++)];
                       _twig_assert(mp != 0, "bad eEVAL");
                       *mmp++ = mp;
                       break;
               case ePOP:
                       skp = *--stp;
                       break;
               case eSTOP:
                       *mmp = 0;
                       return mmp2;
               }
}

void
_do(skeleton *sp, __match *winner, int evalall)
{
       skeleton *skel;

       if(winner == 0) {
               _prskel(sp, 0);
               fprintf(stderr, "no winner");
               return;
       }

       skel = winner->skel;
       if(winner->mode == xDEFER || evalall && winner->mode != xTOPDOWN)
               REORDER(winner->lleaves);
       skel->root = mtAction(winner->tree, winner->lleaves, sp);
       mtSetNodes(skel->parent, skel->nson, skel->root);
}

void
_evalleaves(__match **mpp)
{
       __match *mp;
       while(*mpp){
               mp = *mpp++;
               _do(mp->skel, mp, 1);
       }
}

skeleton *
_walk(skeleton *sp, int ostate)
{
       int state, nstate, nson;
       __partial *pp;
       __match *mp;
       skeleton *nsp, *lastchild = 0;
       NODEPTR son, root;

       root = sp->root;
       nson = 1;
       sp->mincost = INFINITY;
       state = _machstep(ostate, mtValue(root));

       while(son = mtGetNodes(root, nson)){
               nstate = _machstep(state, MV_BRANCH(nson));
               nsp = _allocskel();
               nsp->root = son;
               nsp->parent = root;
               nsp->nson = nson;
               _walk(nsp, nstate);
               if(COSTLESS(nsp->mincost, INFINITY)){
                       _twig_assert(nsp->winner->mode == xREWRITE, "bad mode");
                       if(mtDebug || treedebug) {
                               fprintf(stderr, "rewrite\n");
                               _prskel(nsp, 0);
                       }
                       _do(nsp, nsp->winner, 0);
                       _freeskel(nsp);
                       continue;
               }
               _merge(sp, nsp);
               if(lastchild == 0)
                       sp->leftchild = nsp;
               else
                       lastchild->sibling = nsp;
               lastchild = nsp;
               nson++;
       }

       for(pp = sp->partial; pp < &sp->partial[sp->treecnt]; pp++)
               if(pp->bits & 01){
                       mp = _allocmatch();
                       mp->tree = pp->treeno;
                       _addmatches(ostate, sp, mp);
               }
       if(son == 0 && nson == 1)
               _closure(state, ostate, sp);

       sp->rightchild = lastchild;
       if(root == 0){
               COST c;
               __match *win;
               int i;

               nsp = sp->leftchild;
               c = INFINITY;
               win = 0;
               for(i = 0; i < MAXLABELS; i++){
                       mp = nsp->succ[i];
                       if(mp != 0 && COSTLESS(mp->cost, c)){
                               c = mp->cost;
                               win = mp;
                       }
               }
               if(mtDebug || treedebug)
                       _prskel(nsp, 0);
               _do(nsp, win, 0);
       }
       if(mtDebug)
               _prskel(sp, 0);
       return sp;
}

static short _nodetab[MAXNDVAL], _labeltab[MAXLABELS];

/*
* Convert the start state which has a large branching factor into
* a index table.  This must be called before the matcher is used.
*/
void
_matchinit(void)
{
       short *sp;

       for(sp = _nodetab; sp < &_nodetab[MAXNDVAL]; sp++)
               *sp = HANG;
       for(sp = _labeltab; sp < &_labeltab[MAXLABELS]; sp++)
               *sp = HANG;
       sp = &mtTable[mtStart];
       _twig_assert(*sp == TABLE, "mtTable[mtStart]!=TABLE");
       for(++sp; *sp != -1; sp += 2){
               if(MI_NODE(*sp))
                       _nodetab[M_DETAG(*sp)] = sp[1];
               else if(MI_LABEL(*sp))
                       _labeltab[M_DETAG(*sp)] = sp[1];
       }
}

int
_machstep(short state, short input)
{
       short *stp = &mtTable[state];
       int start = 0;

       if(state == HANG)
               return input == MV_BRANCH(1) ? mtStart : HANG;
rescan:
       if(stp == &mtTable[mtStart]){
               if(MI_NODE(input))
                       return _nodetab[M_DETAG(input)];
               if(MI_LABEL(input))
                       return _labeltab[M_DETAG(input)];
       }

       for(;;){
               if(*stp == ACCEPT)
                       stp += 2;
               if(*stp == TABLE){
                       stp++;
                       while(*stp != -1)
                               if(input == *stp)
                                       return stp[1];
                               else
                                       stp += 2;
                       stp++;
               }
               if(*stp != FAIL){
                       if(start)
                               return HANG;
                       else{
                               stp = &mtTable[mtStart];
                               start = 1;
                               goto rescan;
                       }
               }else{
                       stp = &mtTable[stp[1]];
                       goto rescan;
               }
       }
}

void
_addmatches(int ostate, skeleton *sp, __match *np)
{
       int label;
       int state;
       __match *mp;

       label = mtMap[np->tree];

       /*
        * this is a very poor substitute for good design of the DFA.
        * What we need is a special case that allows any label to be accepted
        * by the start state but we don't want the start state to recognize
        * them after a failure.
        */
       state = _machstep(ostate, MV_LABEL(label));
       if(ostate != mtStart && state == HANG){
               _freematch(np);
               return;
       }

       np->lleaves = _getleaves(np, sp);
       np->skel = sp;
       if((np->mode = mtEvalCost(np, np->lleaves, sp)) == xABORT){
               _freematch(np);
               return;
       }

       if(mp = sp->succ[label]){
               if(!COSTLESS(np->cost, mp->cost)){
                       _freematch(np);
                       return;
               }
               _freematch(mp);
       }
       if(COSTLESS(np->cost, sp->mincost)){
               if(np->mode == xREWRITE){
                       sp->mincost = np->cost;
                       sp->winner = np;
               }else{
                       sp->mincost = INFINITY;
                       sp->winner = 0;
               }
       }
       sp->succ[label] = np;
       _closure(state, ostate, sp);
}

void
_closure(int state, int ostate, skeleton *skp)
{
       short *sp = &mtTable[state];
       __match *mp;

       if(state == HANG || *sp != ACCEPT)
               return;

       for(sp = &mtAccept[sp[1]]; *sp != -1; sp += 2)
               if(sp[1] == 0){
                       mp = _allocmatch();
                       mp->tree = *sp;
                       _addmatches(ostate, skp, mp);
               }else{
                       __partial *pp;
                       __partial *lim = &skp->partial[skp->treecnt];
                       for(pp = skp->partial; pp < lim; pp++)
                               if(pp->treeno == *sp)
                                       break;
                       if(pp == lim){
                               skp->treecnt++;
                               pp->treeno = *sp;
                               pp->bits = 1 << sp[1];
                       }else
                               pp->bits |= 1 << sp[1];
               }
}

void
_merge(skeleton *old, skeleton *new)
{
       __partial *op = old->partial, *np = new->partial;
       int nson = new->nson;
       __partial *lim = np + new->treecnt;
       if(nson == 1){
               old->treecnt = new->treecnt;
               for(; np < lim; op++, np++) {
                       op->treeno = np->treeno;
                       op->bits = np->bits / 2;
               }
       }else{
               __partial *newer = _allocpartial();
               __partial *newerp = newer;
               int cnt;
               lim = op + old->treecnt;
               for(cnt = new->treecnt; cnt-- ; np++){
                       for(op = old->partial; op < lim; op++)
                               if(op->treeno == np->treeno){
                                       newerp->treeno = op->treeno;
                                       newerp++->bits = op->bits & np->bits / 2;
                                       break;
                               }
               }
               _freepartial(old->partial);
               old->partial = newer;
               old->treecnt = newerp-newer;
       }
}

/* memory management */

#define BLKF    100

typedef union __matchalloc{
       __match                 it;
       union __matchalloc      *next;
}__matchalloc;
static __matchalloc *__matchfree = 0;
#ifdef CHECKMEM
staic int a_matches;
#endif

__match *
_allocmatch(void)
{
       static int count = 0;
       static __matchalloc *block = 0;
       __matchalloc *m;

       if(__matchfree){
               m = __matchfree;
               __matchfree = __matchfree->next;
       }else{
               if(count == 0){
                       block = (void *)malloc(BLKF * sizeof *block);
                       count = BLKF;
               }
               m = block++;
               count--;
       }
#ifdef CHECKMEM
               a_matches++;
#endif
       return (__match *)m;
}

void
_freematch(__match *mp)
{
       __matchalloc *m = (__matchalloc *)mp;
       m->next = __matchfree;
       __matchfree = m;
#ifdef CHECKMEM
       a_matches--;
#endif
}

typedef union __partalloc{
       __partial               it[MAXTREES];
       union __partalloc       *next;
}__partalloc;
static __partalloc *__partfree = 0;
#ifdef CHECKMEM
static int a_partials;
#endif

__partial *
_allocpartial(void)
{
       static int count = 0;
       static __partalloc *block = 0;
       __partalloc *p;

       if(__partfree != 0){
               p = __partfree;
               __partfree = __partfree->next;
       }else{
               if(count == 0){
                       block = (void *)malloc(BLKF * sizeof *block);
                       count = BLKF;
               }
               p = block++;
               count--;
       }
#ifdef CHECKMEM
               a_partials++;
#endif
       return (__partial *)p;
}

void
_freepartial(__partial *pp)
{
       __partalloc *p = (__partalloc *)pp;
       p->next = __partfree;
       __partfree = p;
#ifdef CHECKMEM
       a_partials--;
#endif
}

typedef union __skelalloc{
       skeleton                it;
       union __skelalloc       *next;
}__skelalloc;
static __skelalloc *__skelfree = 0;

skeleton *
_allocskel(void)
{
       static int count = 0;
       static __skelalloc *block = 0;
       __skelalloc *sf;
       skeleton *s;
       int i;

       if(__skelfree){
               sf = __skelfree;
               __skelfree = sf->next;
       }else{
               if(count == 0){
                       block = (void *)malloc(BLKF * sizeof *block);
                       count = BLKF;
               }
               sf = block++;
               count--;
       }
       s = (skeleton *)sf;
       s->sibling = 0;
       s->leftchild = 0;
       i = 0;
       while(i < MAXLABELS)
               s->succ[i++] = 0;
       s->treecnt = 0;
       s->partial = _allocpartial();
       return s;
}

void
_freeskel(skeleton *s)
{
       int i;
       __match *mp;
       __skelalloc *sf;

       if(s == 0)
               return;
       if(s->leftchild)
               _freeskel(s->leftchild);
       if(s->sibling)
               _freeskel(s->sibling);
       _freepartial(s->partial);
       i = 0;
       while(i < MAXLABELS)
               if(mp = s->succ[i++])
                       _freematch(mp);
       sf = (__skelalloc *)s;
       sf->next = __skelfree;
       __skelfree = sf;
}

void
_match(void)
{
       skeleton *sp;
       sp = _allocskel();
       sp->root = 0;
       _mpbtop = _mpblock;
       _freeskel(_walk(sp, HANG));
#ifdef CHECKMEM
       _twig_assert(a_matches == 0, "__match memory leak");
       _twig_assert(a_partials == 0, "__partial memory leak");
#endif
}

NODEPTR
_mtG(NODEPTR root, ...)
{
       va_list a;
       int i;

       va_start(a, root);
       while((i = va_arg(a, int)) != -1)
               root = mtGetNodes(root, i);
       va_end(a);
       return root;
}

/* diagnostic routines */
void
_prskel(skeleton *skp, int lvl)
{
       int i;
       __match *mp;
       __partial *pp;
       if(skp==0)
               return;
       for(i = lvl; i > 0; i--)
               fprintf(stderr, "  ");
       fprintf(stderr, "###\n");
       for(i = lvl; i > 0; i--)
               fprintf(stderr, "  ");
       for(i = 0; i < MAXLABELS; i++)
               if(mp = skp->succ[i])
#ifdef LINE_XREF
                       fprintf(stderr, "[%d<%d> %d]", mp->tree,
                               mtLine[mp->tree], mp->cost);
#else
                       fprintf(stderr, "[%d %d]", mp->tree, mp->cost);
#endif
       fprintf(stderr, "\n");
       for(i = lvl; i > 0; i--)
               fprintf(stderr, "  ");
       for(i = 0, pp=skp->partial; i < skp->treecnt; i++, pp++)
#ifdef LINE_XREF
                       fprintf(stderr, "(%d<%d> %x)", pp->treeno, mtLine[pp->treeno],
                               pp->bits);
#else
                       fprintf(stderr, "(%d %x)", pp->treeno, pp->bits);
#endif
       fprintf(stderr, "\n");
       _prskel(skp->leftchild, lvl + 2);
       _prskel(skp->sibling, lvl);
}