#include "os.h"
#include <mp.h>
#include <libsec.h>

mpint*
dh_new(DHstate *dh, mpint *p, mpint *q, mpint *g)
{
       mpint *pm1;
       int n;

       memset(dh, 0, sizeof(*dh));
       if(mpcmp(g, mpone) <= 0)
               return nil;

       n = mpsignif(p);
       pm1 = mpnew(n);
       mpsub(p, mpone, pm1);
       dh->p = mpcopy(p);
       dh->g = mpcopy(g);
       dh->q = mpcopy(q != nil ? q : pm1);
       dh->x = mpnew(mpsignif(dh->q));
       dh->y = mpnew(n);
       for(;;){
               mpnrand(dh->q, genrandom, dh->x);
               mpexp(dh->g, dh->x, dh->p, dh->y);
               if(mpcmp(dh->y, mpone) > 0 && mpcmp(dh->y, pm1) < 0)
                       break;
       }
       mpfree(pm1);

       return dh->y;
}

mpint*
dh_finish(DHstate *dh, mpint *y)
{
       mpint *k = nil;

       if(y == nil || dh->x == nil || dh->p == nil || dh->q == nil)
               goto Out;

       /* y > 1 */
       if(mpcmp(y, mpone) <= 0)
               goto Out;

       k = mpnew(mpsignif(dh->p));

       /* y < p-1 */
       mpsub(dh->p, mpone, k);
       if(mpcmp(y, k) >= 0){
Bad:
               mpfree(k);
               k = nil;
               goto Out;
       }

       /* y**q % p == 1 if q < p-1 */
       if(mpcmp(dh->q, k) < 0){
               mpexp(y, dh->q, dh->p, k);
               if(mpcmp(k, mpone) != 0)
                       goto Bad;
       }

       mpexp(y, dh->x, dh->p, k);

Out:
       mpfree(dh->p);
       mpfree(dh->q);
       mpfree(dh->g);
       mpfree(dh->x);
       mpfree(dh->y);
       memset(dh, 0, sizeof(*dh));
       return k;
}