#include "common.h"
#include "smtpd.h"
#include <ip.h>

enum {
       NORELAY = 0,
       DNSVERIFY,
       SAVEBLOCK,
       DOMNAME,
       OURNETS,
       OURDOMS,

       IP = 0,
       STRING,
};


typedef struct Keyword Keyword;

struct Keyword {
       char    *name;
       int     code;
};

static Keyword options[] = {
       "norelay",              NORELAY,
       "verifysenderdom",      DNSVERIFY,
       "saveblockedmsg",       SAVEBLOCK,
       "defaultdomain",        DOMNAME,
       "ournets",              OURNETS,
       "ourdomains",           OURDOMS,
       0,                      NONE,
};

static Keyword actions[] = {
       "allow",                ACCEPT,
       "block",                BLOCKED,
       "deny",                 DENIED,
       "dial",                 DIALUP,
       "delay",                DELAY,
       0,                      NONE,
};

static  int     hisaction;
static  List    ourdoms;
static  List    badguys;

static  char*   getline(Biobuf*);

static int
findkey(char *val, Keyword *p)
{

       for(; p->name; p++)
               if(strcmp(val, p->name) == 0)
                               break;
       return p->code;
}

char*
actstr(int a)
{
       static char buf[32];
       Keyword *p;

       for(p = actions; p->name; p++)
               if(p->code == a)
                       return p->name;
       if(a == NONE)
               return "none";
       sprint(buf, "%d", a);
       return buf;
}

int
getaction(char *s, char *type)
{
       char buf[1024];
       Keyword *k;

       if(s == nil || *s == 0)
               return ACCEPT;

       for(k = actions; k->name != 0; k++){
               snprint(buf, sizeof buf, "/mail/ratify/%s/%s/%s", k->name, type, s);
               if(access(buf,0) >= 0)
                       return k->code;
       }
       return ACCEPT;
}

int
istrusted(char *s)
{
       char buf[Pathlen];

       if(s == nil || *s == 0)
               return 0;

       snprint(buf, sizeof buf, "/mail/ratify/trusted/%s", s);
       return access(buf, 0) >= 0;
}

void
getconf(void)
{
       Biobuf *bp;
       char *cp, *p;
       String *s;
       char buf[512];

       trusted = istrusted(nci->rsys);
       hisaction = getaction(nci->rsys, "ip");
       if(debug){
               fprint(2, "istrusted(%s)=%d\n", nci->rsys, trusted);
               fprint(2, "getaction(%s, ip)=%s\n", nci->rsys, actstr(hisaction));
       }
       snprint(buf, sizeof(buf), "%s/smtpd.conf", UPASLIB);
       bp = sysopen(buf, "r", 0);
       if(bp == 0)
               return;

       for(;;){
               cp = getline(bp);
               if(cp == 0)
                       break;
               p = cp + strlen(cp) + 1;
               switch(findkey(cp, options)){
               case NORELAY:
                       if(fflag == 0 && strcmp(p, "on") == 0)
                               fflag++;
                       break;
               case DNSVERIFY:
                       if(rflag == 0 && strcmp(p, "on") == 0)
                               rflag++;
                       break;
               case SAVEBLOCK:
                       if(sflag == 0 && strcmp(p, "on") == 0)
                               sflag++;
                       break;
               case DOMNAME:
                       if(dom == 0)
                               dom = strdup(p);
                       break;
               case OURNETS:
                       while(trusted == 0 && *p){
                               trusted = ipcheck(p);
                               p += strlen(p) + 1;
                       }
                       break;
               case OURDOMS:
                       while(*p){
                               s = s_new();
                               s_append(s, p);
                               listadd(&ourdoms, s);
                               p += strlen(p) + 1;
                       }
                       break;
               default:
                       break;
               }
       }
       sysclose(bp);
}

/*
*      match a user name.  the only meta-char is '*' which matches all
*      characters.  we only allow it as "*", which matches anything or
*      an * at the end of the name (e.g., "username*") which matches
*      trailing characters.
*/
static int
usermatch(char *pathuser, char *specuser)
{
       int n;

       n = strlen(specuser) - 1;
       if(specuser[n] == '*'){
               if(n == 0)              /* match everything */
                       return 0;
               return strncmp(pathuser, specuser, n);
       }
       return strcmp(pathuser, specuser);
}

static int
dommatch(char *pathdom, char *specdom)
{
       int n;

       if (*specdom == '*'){
               if (specdom[1] == '.' && specdom[2]){
                       specdom += 2;
                       n = strlen(pathdom) - strlen(specdom);
                       if(n == 0 || (n > 0 && pathdom[n-1] == '.'))
                               return strcmp(pathdom + n, specdom);
                       return n;
               }
       }
       return strcmp(pathdom, specdom);
}

/*
*  figure out action for this sender
*/
int
blocked(String *path)
{
       String *lpath;
       int action;

       if(debug)
               fprint(2, "blocked(%s)\n", s_to_c(path));

       /* if the sender's IP address is blessed, ignore sender email address */
       if(trusted){
               if(debug)
                       fprint(2, "\ttrusted => trusted\n");
               return TRUSTED;
       }

       /* if sender's IP address is blocked, ignore sender email address */
       if(hisaction != ACCEPT){
               if(debug)
                       fprint(2, "\thisaction=%s => %s\n", actstr(hisaction), actstr(hisaction));
               return hisaction;
       }

       /* convert to lower case */
       lpath = s_copy(s_to_c(path));
       s_tolower(lpath);

       /* classify */
       action = getaction(s_to_c(lpath), "account");
       if(debug)
               fprint(2, "\tgetaction account %s => %s\n", s_to_c(lpath), actstr(action));
       s_free(lpath);
       return action;
}

/*
* get a canonicalized line: a string of null-terminated lower-case
* tokens with a two null bytes at the end.
*/
static char*
getline(Biobuf *bp)
{
       char c, *cp, *p, *q;
       int n;

       static char *buf;
       static int bufsize;

       for(;;){
               cp = Brdline(bp, '\n');
               if(cp == 0)
                       return 0;
               n = Blinelen(bp);
               cp[n-1] = 0;
               if(buf == 0 || bufsize < n + 1){
                       bufsize += 512;
                       if(bufsize < n + 1)
                               bufsize = n + 1;
                       buf = realloc(buf, bufsize);
                       if(buf == 0)
                               break;
               }
               q = buf;
               for (p = cp; *p; p++){
                       c = *p;
                       if(c == '\\' && p[1])   /* we don't allow \<newline> */
                               c = *++p;
                       else
                       if(c == '#')
                               break;
                       else
                       if(c == ' ' || c == '\t' || c == ',')
                               if(q == buf || q[-1] == 0)
                                       continue;
                               else
                                       c = 0;
                       *q++ = tolower(c);
               }
               if(q != buf){
                       if(q[-1])
                               *q++ = 0;
                       *q = 0;
                       break;
               }
       }
       return buf;
}

static int
isourdom(char *s)
{
       Link *l;

       if(strchr(s, '.') == nil)
               return 1;

       for(l = ourdoms.first; l; l = l->next){
               if(dommatch(s, s_to_c(l->p)) == 0)
                       return 1;
       }
       return 0;
}

int
forwarding(String *path)
{
       char *cp, *s;
       String *lpath;

       if(debug)
               fprint(2, "forwarding(%s)\n", s_to_c(path));

       /* first check if they want loopback */
       lpath = s_copy(s_to_c(s_restart(path)));
       if(nci->rsys && *nci->rsys){
               cp = s_to_c(lpath);
               if(strncmp(cp, "[]!", 3) == 0){
found:
                       s_append(path, "[");
                       s_append(path, nci->rsys);
                       s_append(path, "]!");
                       s_append(path, cp + 3);
                       s_terminate(path);
                       s_free(lpath);
                       return 0;
               }
               cp = strchr(cp,'!');                    /* skip our domain and check next */
               if(cp++ && strncmp(cp, "[]!", 3) == 0)
                       goto found;
       }

       /* if mail is from a trusted IP addr, allow it to forward */
       if(trusted) {
               s_free(lpath);
               return 0;
       }

       /* sender is untrusted; ensure receiver is in one of our domains */
       for(cp = s_to_c(lpath); *cp; cp++)              /* convert receiver lc */
               *cp = tolower(*cp);

       for(s = s_to_c(lpath); cp = strchr(s, '!'); s = cp + 1){
               *cp = 0;
               if(!isourdom(s)){
                       s_free(lpath);
                       return 1;
               }
       }
       s_free(lpath);
       return 0;
}

int
masquerade(String *path, char *him)
{
       char *cp, *s;
       String *lpath;
       int rv = 0;

       if(debug)
               fprint(2, "masquerade(%s)\n", s_to_c(path));

       if(trusted)
               return 0;
       if(path == nil)
               return 0;

       lpath = s_copy(s_to_c(path));

       /* sender is untrusted; ensure receiver is in one of our domains */
       for(cp = s_to_c(lpath); *cp; cp++)              /* convert receiver lc */
               *cp = tolower(*cp);
       s = s_to_c(lpath);

       /* scan first element of ! or last element of @ paths */
       if((cp = strchr(s, '!')) != nil){
               *cp = 0;
               if(isourdom(s))
                       rv = 1;
       } else if((cp = strrchr(s, '@')) != nil){
               if(isourdom(cp + 1))
                       rv = 1;
       } else {
               if(isourdom(him))
                       rv = 1;
       }

       s_free(lpath);
       return rv;
}

int
isbadguy(void)
{
       Link *l;

       /* check if this IP address is banned */
       for(l = badguys.first; l; l = l->next)
               if(ipcheck(s_to_c(l->p)))
                       return 1;

       return 0;
}

void
addbadguy(char *p)
{
       listadd(&badguys, s_copy(p));
};

char*
dumpfile(char *sender)
{
       int i, fd;
       ulong h;
       static char buf[512];
       char *cp;

       if (sflag == 1){
               cp = ctime(time(0));
               cp[7] = 0;
               if(cp[8] == ' ')
                       sprint(buf, "%s/queue.dump/%s%c", SPOOL, cp + 4, cp[9]);
               else
                       sprint(buf, "%s/queue.dump/%s%c%c", SPOOL, cp + 4, cp[8], cp[9]);
               cp = buf + strlen(buf);
               if(access(buf, 0) < 0 && sysmkdir(buf, 0777) < 0)
                       return "/dev/null";
               h = 0;
               while(*sender)
                       h = h*257 + *sender++;
               for(i = 0; i < 50; i++){
                       h += lrand();
                       sprint(cp, "/%lud", h);
                       if(access(buf, 0) >= 0)
                               continue;
                       fd = create(buf, ORDWR, 0666);
                       if(fd >= 0){
                               if(debug)
                                       fprint(2, "saving in %s\n", buf);
                               close(fd);
                               return buf;
                       }
               }
       }
       return "/dev/null";
}

char *validator = "/mail/lib/validateaddress";

int
recipok(char *user)
{
       char *cp, *p, c;
       char buf[512];
       int n;
       Biobuf *bp;
       int pid;
       Waitmsg *w;

       if(shellchars(user)){
               syslog(0, "smtpd", "shellchars in user name");
               return 0;
       }

       if(access(validator, AEXEC) == 0)
       switch(pid = fork()) {
       case -1:
               break;
       case 0:
               execl(validator, "validateaddress", user, nil);
               exits(0);
       default:
               while(w = wait()) {
                       if(w->pid != pid)
                               continue;
                       if(w->msg[0] != 0){
                               /*
                               syslog(0, "smtpd", "validateaddress %s: %s", user, w->msg);
                               */
                               return 0;
                       }
                       break;
               }
       }

       snprint(buf, sizeof(buf), "%s/names.blocked", UPASLIB);
       bp = sysopen(buf, "r", 0);
       if(bp == 0)
               return 1;
       for(;;){
               cp = Brdline(bp, '\n');
               if(cp == 0)
                       break;
               n = Blinelen(bp);
               cp[n-1] = 0;

               while(*cp == ' ' || *cp == '\t')
                       cp++;
               for(p = cp; c = *p; p++){
                       if(c == '#')
                               break;
                       if(c == ' ' || c == '\t')
                               break;
               }
               if(p > cp){
                       *p = 0;
                       if(cistrcmp(user, cp) == 0){
                               syslog(0, "smtpd", "names.blocked blocks %s", user);
                               Bterm(bp);
                               return 0;
                       }
               }
       }
       Bterm(bp);
       return 1;
}

/*
*  a user can opt out of spam filtering by creating
*  a file in his mail directory named 'nospamfiltering'.
*/
int
optoutofspamfilter(char *addr)
{
       char *p, *f;
       int rv;

       p = strchr(addr, '!');
       if(p)
               p++;
       else
               p = addr;


       rv = 0;
       f = smprint("/mail/box/%s/nospamfiltering", p);
       if(f != nil){
               rv = access(f, 0) == 0;
               free(f);
       }

       return rv;
}