#include "os.h"
#include <mp.h>
#include <libsec.h>

RSApriv*
rsafill(mpint *n, mpint *e, mpint *d, mpint *p, mpint *q)
{
       mpint *c2, *kq, *kp, *x;
       RSApriv *rsa;

       // make sure we're not being hoodwinked
       if(!probably_prime(p, 10) || !probably_prime(q, 10)){
               werrstr("rsafill: p or q not prime");
               return nil;
       }
       x = mpnew(0);
       mpmul(p, q, x);
       if(mpcmp(n, x) != 0){
               werrstr("rsafill: n != p*q");
               mpfree(x);
               return nil;
       }
       c2 = mpnew(0);
       mpsub(p, mpone, c2);
       mpsub(q, mpone, x);
       mpmul(c2, x, x);
       mpmul(e, d, c2);
       mpmod(c2, x, x);
       if(mpcmp(x, mpone) != 0){
               werrstr("rsafill: e*d != 1 mod (p-1)*(q-1)");
               mpfree(x);
               mpfree(c2);
               return nil;
       }

       // compute chinese remainder coefficient
       mpinvert(p, q, c2);

       // for crt a**k mod p == (a**(k mod p-1)) mod p
       kq = mpnew(0);
       kp = mpnew(0);
       mpsub(p, mpone, x);
       mpmod(d, x, kp);
       mpsub(q, mpone, x);
       mpmod(d, x, kq);

       rsa = rsaprivalloc();
       rsa->pub.ek = mpcopy(e);
       rsa->pub.n = mpcopy(n);
       rsa->dk = mpcopy(d);
       rsa->kp = kp;
       rsa->kq = kq;
       rsa->p = mpcopy(p);
       rsa->q = mpcopy(q);
       rsa->c2 = c2;

       mpfree(x);

       return rsa;
}