/* minimal stateless DHCPv6 server for network boot */
#include <u.h>
#include <libc.h>
#include <ip.h>
#include <bio.h>
#include <ndb.h>
enum {
	Eaddrlen	= 6,	/* bytes in an ethernet (link-layer) address */

	/* DHCPv6 message types, numbered consecutively from 1 */
	SOLICIT		= 1,
	ADVERTISE	= 2,
	REQUEST		= 3,
	CONFIRM		= 4,
	RENEW		= 5,
	REBIND		= 6,
	REPLY		= 7,
	RELEASE		= 8,
	DECLINE		= 9,
	RECONFIGURE	= 10,
	INFOREQ		= 11,	/* information-request */
	RELAYFORW	= 12,	/* relay-forward */
	RELAYREPL	= 13,	/* relay-reply */
};
/*
 * state for the packet currently being answered; main reuses a
 * single Req for every packet it reads.
 */
typedef struct Req Req;
struct Req
{
int tra;	/* 24-bit transaction id taken from the request header */
Udphdr *udp;	/* udp header at the start of the received buffer */
Ipifc *ifc;	/* interface owning the address the packet arrived on */
int ncid;	/* length of cid, set by oclientid */
uchar cid[256];	/* client identifier copied from the request */
uchar ips[IPaddrlen*8];	/* addresses found for the client (room for 8) */
int nips;	/* number of addresses stored in ips */
Ndb *db;	/* database returned by opendb for this packet */
Ndbtuple *t;	/* tuples looked up by oro for the requested options; nil otherwise */
struct {
int t;	/* message type */
uchar *p;	/* start of the options (after the 4-byte header) */
uchar *e;	/* end of the received data */
} req;
struct {
int t;	/* message type being sent */
uchar *p;	/* next free byte in the output buffer */
uchar *e;	/* end of the output buffer */
} resp;
};
/*
 * one entry per DHCPv6 option this server can emit.
 */
typedef struct Otab Otab;
struct Otab
{
int t;	/* option code; 0 terminates the table */
/*
 * writes the option body into a buffer of the given size and
 * returns its length, or -1 to emit nothing.
 */
int (*f)(uchar *, int, Otab*, Req*);
/*
 * up to three ndb attribute names that oro asks ndbipinfo for;
 * getv6ips ignores a leading '@' when matching tuples.
 */
char *q[3];
int done;	/* set once the option was attempted for this packet; cleared by clearotab */
};
/* option table; defined, zero-terminated, at the end of the file */
static Otab otab[];
/* interface list, refreshed in place by readipifc (findifc, openlisten) */
static Ipifc *ipifcs;
/* startup time offset set in main; embedded in the server DUID by oserverid */
static ulong starttime;
/* database file from -f; passed to ndbopen as-is, nil if -f was not given */
static char *ndbfile;
/* network mount point, overridable with -x */
static char *netmtpt = "/net";
/* incremented by each -d; nonzero also keeps the process in the foreground */
static int debug;
/* ::1, excluded from the interface-network lookups in lookup() */
static uchar v6loopback[IPaddrlen] = {
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 1
};
/*
 * return the ndb database for ndbfile, opening it on first use.
 * once open, ndbchanged is consulted at most once a minute and the
 * database is reopened when it reports stale data.  returns nil if
 * the database could not be opened.
 */
static Ndb *
opendb(void)
{
	static Ndb *db;
	static ulong lastcheck;
	ulong now;

	now = time(nil);
	if(db == nil){
		db = ndbopen(ndbfile);
		if(db != nil)
			lastcheck = now;
		return db;
	}

	/* rate-limit staleness checks to one per 60 seconds */
	if(now < lastcheck + 60)
		return db;

	if(ndbchanged(db))
		ndbreopen(db);
	lastcheck = now;
	return db;
}
/*
 * re-read the interface list of net and return the interface
 * that has ip as one of its local addresses, or nil.
 */
static Ipifc*
findifc(char *net, uchar ip[IPaddrlen])
{
	Ipifc *nif;
	Iplifc *lp;

	ipifcs = readipifc(net, ipifcs, -1);
	nif = ipifcs;
	while(nif != nil){
		lp = nif->lifc;
		while(lp != nil){
			if(ipcmp(lp->ip, ip) == 0)
				return nif;
			lp = lp->next;
		}
		nif = nif->next;
	}
	return nil;
}
/*
 * return the logical interface of ifc whose network contains ip
 * (ip masked with its mask equals its net), or nil if ip is not
 * on any of ifc's networks.
 */
static Iplifc*
localonifc(Ipifc *ifc, uchar ip[IPaddrlen])
{
	Iplifc *lp;
	uchar masked[IPaddrlen];

	lp = ifc->lifc;
	while(lp != nil){
		maskip(ip, lp->mask, masked);
		if(ipcmp(masked, lp->net) == 0)
			break;
		lp = lp->next;
	}
	return lp;
}
/*
 * announce on the dhcp6s udp port of net in "headers" mode (so each
 * read starts with a Udphdr).  Then join ff02::1:2, the DHCPv6
 * servers/relays multicast group, on every link-local address.
 * Interfaces bound to /dev/null (presumably loopback) are skipped.
 * Returns an fd open on the conversation's data file.  Failing to
 * join a group is only reported; every other failure is fatal.
 */
static int
openlisten(char *net)
{
	int fd, cfd;
	char data[128], devdir[40];
	Ipifc *ifc;
	Iplifc *lifc;

	/* net comes from the -x flag; refuse rather than overflow data */
	if(strlen(net) + sizeof("/udp!*!dhcp6s") > sizeof(data))
		sysfatal("network path too long: %s", net);
	snprint(data, sizeof(data), "%s/udp!*!dhcp6s", net);
	cfd = announce(data, devdir);
	if(cfd < 0)
		sysfatal("can't announce: %r");
	if(fprint(cfd, "headers") < 0)
		sysfatal("can't set header mode: %r");

	ipifcs = readipifc(net, ipifcs, -1);
	for(ifc = ipifcs; ifc != nil; ifc = ifc->next){
		if(strcmp(ifc->dev, "/dev/null") == 0)
			continue;
		for(lifc = ifc->lifc; lifc != nil; lifc = lifc->next){
			if(!ISIPV6LINKLOCAL(lifc->ip))
				continue;
			if(fprint(cfd, "addmulti %I ff02::1:2", lifc->ip) < 0)
				fprint(2, "addmulti: %I: %r\n", lifc->ip);
		}
	}

	snprint(data, sizeof(data), "%s/data", devdir);
	fd = open(data, ORDWR);
	if(fd < 0)
		sysfatal("open udp data: %r");
	return fd;
}
/*
 * find the first option with code x in the DHCPv6 option list [p, e).
 * Each option is a 2-byte code and a 2-byte length (big endian),
 * followed by that many data bytes.  Returns a pointer to the option
 * data and stores its length in *plen (when plen is non-nil).  Returns
 * nil, with *plen = 0, when no complete option x is found.  Scanning
 * stops at the first option that would run past e.
 */
static uchar*
gettlv(int x, int *plen, uchar *p, uchar *e)
{
	int t, l;

	if(plen != nil)
		*plen = 0;
	/*
	 * compare remaining byte counts instead of forming p+4+l:
	 * l is untrusted and that pointer could lie far past the buffer.
	 */
	while(e - p >= 4){
		t = p[0]<<8 | p[1];
		l = p[2]<<8 | p[3];
		if(l > e - p - 4)
			break;
		if(t == x){
			if(plen != nil)
				*plen = l;
			return p+4;
		}
		p += 4+l;
	}
	return nil;
}
/*
 * copy into ip the IPv6 addresses of every tuple in the list t
 * (followed via ->entry) whose attribute is attr.  A leading '@' on
 * attr is ignored.  Values that do not parse, or that are IPv4, are
 * skipped.  At most n bytes of ip are written.  Returns the number
 * of bytes stored, always a multiple of IPaddrlen.
 */
static int
getv6ips(uchar *ip, int n, Ndbtuple *t, char *attr)
{
	int r;

	if(*attr == '@')
		attr++;
	r = 0;
	/*
	 * parseip writes a full IPaddrlen bytes even for values that are
	 * later rejected, so check for room before every parse; callers
	 * may pass an n that is not a multiple of IPaddrlen.
	 */
	for(; t != nil && n - r >= IPaddrlen; t = t->entry){
		if(strcmp(t->attr, attr) != 0)
			continue;
		if(parseip(ip, t->val) == -1 || isv4(ip))
			continue;
		ip += IPaddrlen;
		r += IPaddrlen;
	}
	return r;
}
/*
 * find ndb entries whose ether attribute matches mac and collect
 * their IPv6 ip attributes into ip, writing at most n bytes.
 * Returns the number of bytes stored.
 */
static int
lookupips(uchar *ip, int n, Ndb *db, uchar mac[Eaddrlen])
{
	Ndbtuple *t;
	Ndbs s;
	char ea[256];
	int got;

	snprint(ea, sizeof ea, "%E", mac);
	got = 0;
	for(t = ndbsearch(db, &s, "ether", ea); t != nil; t = ndbsnext(&s, "ether", ea)){
		got += getv6ips(ip + got, n - got, t, "ip");
		ndbfree(t);
		if(got >= n)
			break;
	}
	return got;
}
/*
 * guess the client's ethernet address.  A client identifier at least
 * 4+Eaddrlen bytes long that starts with 00 01 00 01 or 00 03 00 01
 * (DUID types 1 and 3 with hardware type 1, the layout oserverid
 * builds for itself) ends with the mac.  Otherwise the mac is rebuilt
 * from the EUI-64 interface id of the source address: flip bit 1 of
 * byte 8 and skip bytes 11-12.  The result may point into a static
 * buffer that the next call overwrites.
 */
static uchar*
clientea(Req *r)
{
	static uchar ea[Eaddrlen];
	uchar *a;
	u32int duid;

	if(r->ncid >= 4+Eaddrlen){
		duid = (u32int)r->cid[0]<<24 | r->cid[1]<<16 | r->cid[2]<<8 | r->cid[3];
		if(duid == 0x00010001 || duid == 0x00030001)
			return &r->cid[r->ncid - Eaddrlen];
	}

	a = r->udp->raddr;
	ea[0] = a[8] ^ 2;
	memmove(ea+1, a+9, 2);
	memmove(ea+3, a+13, 3);
	return ea;
}
/* clear the per-packet done flag of every option table entry */
static void
clearotab(void)
{
	Otab *o;

	o = otab;
	while(o->t != 0){
		o->done = 0;
		o++;
	}
}
/* return the option table entry for option code t, or nil */
static Otab*
findotab(int t)
{
	Otab *o;

	for(o = otab; o->t != 0 && o->t != t; o++)
		;
	return o->t != 0 ? o : nil;
}
/*
 * append option t to the response: run its handler on the space
 * after a 4-byte header, then fill in the code and length.
 * Each option is attempted at most once per packet (Otab.done).
 * Returns the option body length, or -1 if there is no room, no
 * handler, the option was already attempted, or the handler declined.
 */
static int
addoption(Req *r, int t)
{
	Otab *ent;
	uchar *hdr;
	int len;

	if(r->resp.p+4 > r->resp.e)
		return -1;
	ent = findotab(t);
	if(ent == nil || ent->f == nil || ent->done)
		return -1;
	ent->done = 1;

	/*
	 * the handler may call addoption itself (oro does), which
	 * advances r->resp.p, so only read it after the call.
	 */
	len = (*ent->f)(r->resp.p+4, r->resp.e - (r->resp.p+4), ent, r);
	if(len < 0)
		return -1;
	hdr = r->resp.p;
	if(hdr+4+len > r->resp.e)
		return -1;

	hdr[0] = t>>8;
	hdr[1] = t;
	hdr[2] = len>>8;
	hdr[3] = len;
	if(debug)
		fprint(2, "%d(%.*H)\n", t, len, hdr+4);
	r->resp.p = hdr+4+len;
	return len;
}
/* print the command-line synopsis and exit with status "usage" */
static void
usage(void)
{
	fprint(2, "%s [-d] [-f ndbfile] [-x netmtpt]\n", argv0);
	exits("usage");
}
void
main(int argc, char *argv[])
{
uchar ibuf[4096], obuf[4096];
Req r[1];
int fd, n, i;
fmtinstall('H', encodefmt);
fmtinstall('I', eipfmt);
fmtinstall('E', eipfmt);
ARGBEGIN {
case 'd':
debug++;
break;
case 'f':
ndbfile = EARGF(usage());
break;
case 'x':
netmtpt = EARGF(usage());
break;
default:
usage();
} ARGEND;
starttime = time(nil) - 946681200UL;
if(opendb() == nil)
sysfatal("opendb: %r");
fd = openlisten(netmtpt);
/* put process in background */
if(!debug)
switch(rfork(RFNOTEG|RFPROC|RFFDG)) {
default:
exits(nil);
case -1:
sysfatal("fork: %r");
case 0:
break;
}
while((n = read(fd, ibuf, sizeof(ibuf))) > 0){
if(n < Udphdrsize+4)
continue;
r->udp = (Udphdr*)ibuf;
if(isv4(r->udp->raddr))
continue;
if((r->ifc = findifc(netmtpt, r->udp->ifcaddr)) == nil)
continue;
if(localonifc(r->ifc, r->udp->raddr) == nil)
continue;
memmove(obuf, ibuf, Udphdrsize);
r->req.p = ibuf+Udphdrsize;
r->req.e = ibuf+n;
r->resp.p = obuf+Udphdrsize;
r->resp.e = &obuf[sizeof(obuf)];
r->tra = r->req.p[1]<<16 | r->req.p[2]<<8 | r->req.p[3];
r->req.t = r->req.p[0];
if(debug)
fprint(2, "%I->%I(%s) typ=%d tra=%x\n",
r->udp->raddr, r->udp->laddr, r->ifc->dev,
r->req.t, r->tra);
switch(r->req.t){
default:
continue;
case SOLICIT:
r->resp.t = ADVERTISE;
break;
case REQUEST:
case INFOREQ:
r->resp.t = REPLY;
break;
}
r->resp.p[0] = r->resp.t;
r->resp.p[1] = r->tra>>16;
r->resp.p[2] = r->tra>>8;
r->resp.p[3] = r->tra;
r->req.p += 4;
r->resp.p += 4;
r->t = nil;
clearotab();
/* Server Identifier */
if(addoption(r, 2) < 0)
continue;
/* Client Identifier */
if(addoption(r, 1) < 0)
continue;
if((r->db = opendb()) == nil)
continue;
r->nips = lookupips(r->ips, sizeof(r->ips), r->db, clientea(r))/IPaddrlen;
if(debug){
for(i=0; i<r->nips; i++)
fprint(2, "ip=%I\n", r->ips+i*IPaddrlen);
}
addoption(r, 3);
addoption(r, 6);
write(fd, obuf, r->resp.p-obuf);
if(debug) fprint(2, "\n");
}
exits(nil);
}
/*
 * option 1 (client identifier): echo the request's client identifier
 * into w and remember it in r->cid/r->ncid.  Returns -1 if the request
 * has none, or it is larger than r->cid or than the n bytes at w.
 */
static int
oclientid(uchar *w, int n, Otab*, Req *r)
{
	uchar *id;
	int idlen;

	id = gettlv(1, &idlen, r->req.p, r->req.e);
	if(id == nil || idlen > sizeof(r->cid) || idlen > n)
		return -1;
	memmove(r->cid, id, idlen);
	r->ncid = idlen;
	memmove(w, id, idlen);
	return idlen;
}
/*
 * option 2 (server identifier): write our DUID into w.  It holds type
 * 1 (link-layer address plus time), hardware type 1 (ethernet),
 * starttime, and the mac of the receiving interface.  Returns -1 if
 * there is no room or the mac cannot be read.  Also returns -1 when
 * the request names a different server identifier, which makes main
 * drop the packet.
 */
static int
oserverid(uchar *w, int n, Otab*, Req *r)
{
	int len;
	uchar *p;

	if(n < 4+4+Eaddrlen)
		return -1;
	w[0] = 0, w[1] = 1;	/* duid type: link layer address + time*/
	w[2] = 0, w[3] = 1;	/* hw type: ethernet */
	w[4] = starttime>>24;
	w[5] = starttime>>16;
	w[6] = starttime>>8;
	w[7] = starttime;
	/* without a mac the duid would carry stale output-buffer bytes */
	if(myetheraddr(w+8, r->ifc->dev) < 0)
		return -1;

	/* check if server id matches from the request */
	p = gettlv(2, &len, r->req.p, r->req.e);
	if(p != nil && (len != 4+4+Eaddrlen || memcmp(w, p, 4+4+Eaddrlen) != 0))
		return -1;
	return 4+4+Eaddrlen;
}
/*
 * option 3 (IA_NA): copy the first 12 bytes of the client's IA_NA
 * (IAID, T1, T2) and append one option-5 address record per address
 * in r->ips, with both lifetime fields set to all ones.  Returns -1
 * if the request carries no usable IA_NA or the result exceeds n.
 */
static int
oiana(uchar *w, int n, Otab*, Req *r)
{
	uchar *ia, *q;
	int i, ialen, need;

	ia = gettlv(3, &ialen, r->req.p, r->req.e);
	if(ia == nil || ialen < 3*4)
		return -1;
	need = 3*4 + r->nips*(4+IPaddrlen+2*4);
	if(need > n)
		return -1;

	memmove(w, ia, 3*4);
	q = w + 3*4;
	for(i = 0; i < r->nips; i++, q += 4+IPaddrlen+2*4){
		q[0] = 0;
		q[1] = 5;
		q[2] = 0;
		q[3] = IPaddrlen+2*4;
		memmove(q+4, &r->ips[i*IPaddrlen], IPaddrlen);
		memset(q+4+IPaddrlen, 255, 2*4);
	}
	return need;
}
/*
 * ask ndbipinfo for the ac attributes in av.  If addresses were found
 * for the client, each of them is the key.  Otherwise the key is the
 * network of each non-v4, non-loopback, non-link-local address on the
 * receiving interface.  Returns the concatenated tuples (the caller
 * frees them with ndbfree), or nil.
 */
static Ndbtuple*
lookup(Req *r, char *av[], int ac)
{
	Ndbtuple *all;
	Iplifc *lifc;
	char *key;
	int i;

	if(ac <= 0)
		return nil;
	all = nil;

	if(r->nips <= 0){
		/* no target addresses: use the ipv6 networks on the interface */
		for(lifc = r->ifc->lifc; lifc != nil; lifc = lifc->next){
			if(isv4(lifc->ip) || ipcmp(lifc->ip, v6loopback) == 0 || ISIPV6LINKLOCAL(lifc->ip))
				continue;
			key = smprint("%I", lifc->net);
			all = ndbconcatenate(all, ndbipinfo(r->db, "ip", key, av, ac));
			free(key);
		}
		return all;
	}

	for(i = 0; i < r->nips; i++){
		key = smprint("%I", r->ips + i*IPaddrlen);
		all = ndbconcatenate(all, ndbipinfo(r->db, "ip", key, av, ac));
		free(key);
	}
	return all;
}
/*
 * option 6 (option request): collect the ndb attributes wanted by each
 * requested option that has a table entry and was not attempted yet.
 * Look them all up once into r->t, then append each requested option
 * directly via nested addoption calls.  Always returns -1 so no
 * option 6 is emitted itself.  r->t is freed and reset before
 * returning.
 */
static int
oro(uchar*, int, Otab *o, Req *r)
{
	uchar *p;
	char *av[100];
	int i, j, l, ac;
	Ndbtuple *t;

	p = gettlv(6, &l, r->req.p, r->req.e);
	if(p == nil || l < 2)
		return -1;
	/* codes are 2 bytes each; ignore an odd trailing byte */
	l &= ~1;

	ac = 0;
	for(i=0; i<l; i+=2){
		/* big-endian 16-bit code: high byte must be shifted left */
		if((o = findotab(p[i]<<8 | p[i+1])) == nil || o->done)
			continue;
		for(j=0; j<3 && o->q[j]!=nil && ac<nelem(av); j++)
			av[ac++] = o->q[j];
	}
	r->t = lookup(r, av, ac);
	if(debug){
		fprint(2, "ndb(");
		for(t = r->t; t != nil; t = t->entry){
			fprint(2, "%s=%s ", t->attr, t->val);
			if(t->entry != nil && t->entry != t->line)
				fprint(2, "\n");
		}
		fprint(2, ")\n");
	}

	/* process the options */
	for(i=0; i<l; i+=2)
		addoption(r, p[i]<<8 | p[i+1]);

	ndbfree(r->t);
	r->t = nil;
	return -1;
}
/*
 * server-address options (option 23 in otab): list the IPv6 values
 * of attribute o->q[0] found in the tuples oro looked up.
 */
static int
oservers(uchar *w, int n, Otab *o, Req *r)
{
	char *attr;

	attr = o->q[0];
	return getv6ips(w, n, r->t, attr);
}
/*
 * option 24 (domain search list): encode each o->q[0] tuple in q->t
 * as an uncompressed DNS name: length-prefixed labels followed by a
 * zero byte.  Names are skipped when utf2idn rejects them, or when
 * they contain an empty label or one longer than 63 bytes, since
 * neither fits that encoding.  Returns the encoded length, or -1 if
 * the list does not fit in n bytes.
 */
static int
odomainlist(uchar *w, int n, Otab *o, Req *q)
{
	char val[256];
	Ndbtuple *t;
	int l, r, start, bad;
	char *s;

	r = 0;
	for(t = q->t; t != nil; t = t->entry){
		if(strcmp(t->attr, o->q[0]) != 0)
			continue;
		if(utf2idn(t->val, val, sizeof(val)) <= 0)
			continue;
		start = r;
		bad = 0;
		for(s = val; *s != 0; s++){
			for(l = 0; s[l] != 0 && s[l] != '.'; l++)
				;
			/* a 0 length would end the name early; >63 is not a label length */
			if(l == 0 || l > 63){
				bad = 1;
				break;
			}
			if(r+1+l > n)
				return -1;
			w[r++] = l;
			memmove(w+r, s, l);
			r += l;
			s += l;
			if(*s != '.')
				break;
		}
		if(bad){
			r = start;	/* discard the partial encoding */
			continue;
		}
		if(r >= n)
			return -1;
		w[r++] = 0;
	}
	return r;
}
/*
 * option 59 (boot file url): use the bootf value verbatim when it
 * contains "://".  Otherwise build a tftp url from the first IPv6
 * tftp address in the looked-up tuples.  Returns -1 if bootf (or,
 * for the tftp case, a tftp address) is missing.
 */
static int
obootfileurl(uchar *w, int n, Otab *, Req *q)
{
	Ndbtuple *bf;
	uchar srv[IPaddrlen];

	bf = ndbfindattr(q->t, q->t, "bootf");
	if(bf == nil)
		return -1;
	if(strstr(bf->val, "://") != nil)
		return snprint((char*)w, n, "%s", bf->val);
	if(getv6ips(srv, sizeof(srv), q->t, "tftp") == 0)
		return -1;
	return snprint((char*)w, n, "tftp://[%I]/%s", srv, bf->val);
}
/*
 * option handlers, terminated by the zero entry.  Columns: option code,
 * handler, and the ndb attributes oro requests for that option.
 */
static Otab otab[] = {
{ 1, oclientid, },	/* client identifier */
{ 2, oserverid, },	/* server identifier */
{ 3, oiana, },	/* IA_NA address assignment */
{ 6, oro, },	/* option request: emits the options it lists */
{ 23, oservers, "@dns" },	/* dns server addresses */
{ 24, odomainlist, "dnsdomain" },	/* domain search list */
{ 59, obootfileurl, "bootf", "@tftp", },	/* boot file url */
{ 0 },
};