/*      $NetBSD: tcp.c,v 1.11 2011/08/30 21:14:06 joerg Exp $   */
/*      $KAME: tcp.c,v 1.10 2002/08/20 23:01:01 itojun Exp $    */

/*
* Copyright (C) 1997 and 1998 WIDE Project.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
* 3. Neither the name of the project nor the names of its contributors
*    may be used to endorse or promote products derived from this software
*    without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE PROJECT AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED.  IN NO EVENT SHALL THE PROJECT OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*/

#include <sys/param.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/ioctl.h>
#include <sys/time.h>
#include <sys/wait.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>
#include <errno.h>
#include <fcntl.h>
#include <signal.h>

#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>

#include "faithd.h"

/*
* Relay buffer shared by relay() and send_data(): tblen is the number of
* bytes read into tcpbuf, tboff is how many of them have already been
* written to the sending socket.
*/
static char tcpbuf[16*1024];
       /* bigger than MSS and possibly smaller than the window size */
static ssize_t tblen, tboff;
/* non-zero while the urgent byte in atmark_buf[0] still has to be sent */
static int oob_exists;
/* select() descriptor sets used by relay() and updated by send_data() */
static fd_set readfds, writefds, exceptfds;
/* holds the single out-of-band byte read at the urgent mark */
static char atmark_buf[2];
/* parent side: pid of the relay child */
static pid_t cpid = (pid_t)0;
/* child side: pid of the parent; stays 0 in the parent process */
static pid_t ppid = (pid_t)0;
/*
* Parent side: time of the last SIGUSR1 received from the child (written by
* sig_ctimeout()).  Child side: time the last SIGUSR1 was sent.
* NOTE(review): written from a signal handler; volatile sig_atomic_t is the
* type the C standard guarantees for that -- time_t is assumed to be safe.
*/
volatile time_t child_lastactive = (time_t)0;
/* parent side: time the parent's own relay direction last saw traffic */
static time_t parent_lastactive = (time_t)0;

static void sig_ctimeout(int);
static void sig_child(int) __dead;
static void notify_inactive(void);
static void notify_active(void);
static void send_data(int, int, const char *, int);
static void relay(int, int, const char *, int) __dead;

/*
* Inactivity timer:
* - child side (ppid != 0) sends SIGUSR1 to the parent at most once every
*   (FAITH_TIMEOUT/4) seconds while traffic is active; while traffic is
*   inactive, no SIGUSR1 is sent.
* - parent side (ppid == 0) checks the time of the last SIGUSR1 it has seen
*   and ends the relay once both directions have been idle for longer than
*   FAITH_TIMEOUT.
*/
static void
/*ARGSUSED*/
sig_ctimeout(int sig)
{
       /*
        * SIGUSR1 handler, parent side: record that the child reported
        * traffic.  Only async-signal-safe work may be done here.  time(2)
        * is on the POSIX async-signal-safe list, but syslog(3) is not and
        * could deadlock if this signal interrupts a syslog() call in the
        * main flow.  With dflag set, notify_inactive() already logs whether
        * the child side is active.
        */
       child_lastactive = time(NULL);
}

/*
* SIGCHLD handler, parent side: the parent terminates if the child dies.
* Reaps the child without blocking, logs a warning unless the child exited
* normally with status 0, then ends the connection.
*/
static void
/*ARGSUSED*/
sig_child(int sig)
{
       int status;
       pid_t pid;

       /* POSIX equivalent of the legacy wait3(&status, WNOHANG, NULL) */
       pid = waitpid((pid_t)-1, &status, WNOHANG);
       /*
        * WEXITSTATUS() is only meaningful after WIFEXITED(); a child killed
        * by a signal has a zero exit-status byte and must be reported too.
        */
       if (pid > 0 && !(WIFEXITED(status) && WEXITSTATUS(status) == 0))
               syslog(LOG_WARNING, "child %ld exit status 0x%x",
                   (long)pid, status);
       /*
        * NOTE(review): syslog() and the exit path are not async-signal-safe;
        * kept because the relay tears itself down from here by design.
        */
       exit_success("terminate connection due to child termination");
}

static void
notify_inactive()
{
       time_t t;

       /* only on parent side... */
       if (ppid)
               return;

       /* parent side should check for timeout. */
       t = time(NULL);
       if (dflag) {
               syslog(LOG_DEBUG, "parent side %sactive, child side %sactive",
                       (FAITH_TIMEOUT < t - parent_lastactive) ? "in" : "",
                       (FAITH_TIMEOUT < t - child_lastactive) ? "in" : "");
       }

       if (FAITH_TIMEOUT < t - child_lastactive
        && FAITH_TIMEOUT < t - parent_lastactive) {
               /* both side timeouted */
               (void)signal(SIGCHLD, SIG_DFL);
               (void)kill(cpid, SIGTERM);
               (void)wait(NULL);
               exit_failure("connection timeout");
               /* NOTREACHED */
       }
}

static void
notify_active()
{
       if (ppid) {
               /* child side: notify parent of active traffic */
               time_t t;
               t = time(NULL);
               if (FAITH_TIMEOUT / 4 < t - child_lastactive) {
                       if (kill(ppid, SIGUSR1) < 0) {
                               exit_failure("terminate connection due to parent termination");
                               /* NOTREACHED */
                       }
                       child_lastactive = t;
               }
       } else {
               /* parent side */
               parent_lastactive = time(NULL);
       }
}

static void
/*ARGSUSED*/
send_data(int s_rcv, int s_snd, const char *service, int direction)
{
       ssize_t cc;

       if (oob_exists) {
               cc = send(s_snd, atmark_buf, 1, MSG_OOB);
               if (cc == -1)
                       goto retry_or_err;
               oob_exists = 0;
               if (s_rcv >= FD_SETSIZE)
                       exit_failure("descriptor too big");
               FD_SET(s_rcv, &exceptfds);
       }

       for (; tboff < tblen; tboff += cc) {
               cc = write(s_snd, tcpbuf + tboff, (size_t)(tblen - tboff));
               if (cc < 0)
                       goto retry_or_err;
       }
#ifdef DEBUG
       if (tblen) {
               if (tblen >= sizeof(tcpbuf))
                       tblen = sizeof(tcpbuf) - 1;
               tcpbuf[tblen] = '\0';
               syslog(LOG_DEBUG, "from %s (%dbytes): %s",
                      direction == 1 ? "client" : "server", tblen, tcpbuf);
       }
#endif /* DEBUG */
       tblen = 0; tboff = 0;
       if (s_snd >= FD_SETSIZE)
               exit_failure("descriptor too big");
       FD_CLR(s_snd, &writefds);
       if (s_rcv >= FD_SETSIZE)
               exit_failure("descriptor too big");
       FD_SET(s_rcv, &readfds);
       return;
   retry_or_err:
       if (errno != EAGAIN)
               exit_failure("writing relay data failed: %s", strerror(errno));
       if (s_snd >= FD_SETSIZE)
               exit_failure("descriptor too big");
       FD_SET(s_snd, &writefds);
}

/*
* Relay one direction of the connection.  Data is read from s_rcv into
* tcpbuf and written to s_snd via send_data().  The out-of-band byte at the
* urgent mark is forwarded as well.  This function never returns; the loop
* ends only through exit_success() (EOF on s_rcv) or exit_failure() (errors).
*
* s_rcv:     descriptor to read from
* s_snd:     descriptor to write to
* service:   name used in log messages
* direction: 1 for the child's s_src -> s_dst relay, 0 for the parent's
*            s_dst -> s_src relay; passed through to send_data()
*
* The global readfds/writefds/exceptfds serve both as the "interest" sets
* handed to select() and as its result sets.  select() overwrites them with
* the ready descriptors, and the loop keeps using those as the interest for
* the next round.  They are restored from the o*fds copies only on timeout.
*/
static void
relay(int s_rcv, int s_snd, const char *service, int direction)
{
       int atmark, error, maxfd;
       struct timeval tv;
       fd_set oreadfds, owritefds, oexceptfds;

       FD_ZERO(&readfds);
       FD_ZERO(&writefds);
       FD_ZERO(&exceptfds);
       /*
        * NOTE(review): F_SETFD sets descriptor flags (FD_CLOEXEC), not file
        * status flags, so this call does not make s_snd non-blocking; that
        * would need F_GETFL/F_SETFL.  Unless s_snd is made non-blocking
        * elsewhere, writes in send_data() block.  The fd-set bookkeeping
        * below appears to rely on that (a pending s_snd would drop out of
        * writefds if select() did not report it writable).  Confirm before
        * changing this call.
        */
       (void)fcntl(s_snd, F_SETFD, O_NONBLOCK);
       /* these copies are overwritten again at the top of the loop */
       oreadfds = readfds; owritefds = writefds; oexceptfds = exceptfds;
       if (s_rcv >= FD_SETSIZE)
               exit_failure("descriptor too big");
       /* start by watching s_rcv for normal data and for urgent data */
       FD_SET(s_rcv, &readfds);
       FD_SET(s_rcv, &exceptfds);
       oob_exists = 0;
       maxfd = (s_rcv > s_snd) ? s_rcv : s_snd;

       for (;;) {
               /* wake up every FAITH_TIMEOUT/4 seconds for the idle check */
               tv.tv_sec = FAITH_TIMEOUT / 4;
               tv.tv_usec = 0;
               oreadfds = readfds;
               owritefds = writefds;
               oexceptfds = exceptfds;
               error = select(maxfd + 1, &readfds, &writefds, &exceptfds, &tv);
               if (error == -1) {
                       /* interrupted by a signal (e.g. SIGUSR1): retry */
                       if (errno == EINTR)
                               continue;
                       exit_failure("select: %s", strerror(errno));
               } else if (error == 0) {
                       /* timeout: restore the interest sets, check idleness */
                       readfds = oreadfds;
                       writefds = owritefds;
                       exceptfds = oexceptfds;
                       notify_inactive();
                       continue;
               }

               /* activity notification */
               notify_active();

               /*
                * Exceptional condition on s_rcv.  If positioned at the urgent
                * mark, read the single urgent byte and let send_data()
                * forward it with MSG_OOB.
                * NOTE(review): reading the urgent byte with plain read()
                * presumably relies on SO_OOBINLINE being set on s_rcv
                * elsewhere -- confirm.  Because the ready sets become the
                * next interest sets, s_rcv also seems to drop out of
                * exceptfds after a round without a pending exception
                * (send_data() re-adds it only after forwarding an OOB byte).
                */
               if (FD_ISSET(s_rcv, &exceptfds)) {
                       error = ioctl(s_rcv, SIOCATMARK, &atmark);
                       if (error != -1 && atmark == 1) {
                               ssize_t cc;
oob_read_retry:
                               cc = read(s_rcv, atmark_buf, 1);
                               if (cc == 1) {
                                       if (s_rcv >= FD_SETSIZE)
                                               exit_failure("descriptor too big");
                                       FD_CLR(s_rcv, &exceptfds);
                                       if (s_snd >= FD_SETSIZE)
                                               exit_failure("descriptor too big");
                                       FD_SET(s_snd, &writefds);
                                       oob_exists = 1;
                               } else if (cc == -1) {
                                       if (errno == EINTR)
                                               goto oob_read_retry;
                                       exit_failure("reading oob data failed"
                                                    ": %s",
                                                    strerror(errno));
                               }
                       }
               }
               /*
                * Read the next chunk.  While it is pending, stop reading
                * s_rcv and wait for s_snd to accept the data instead.
                */
               if (FD_ISSET(s_rcv, &readfds)) {
                   relaydata_read_retry:
                       tblen = read(s_rcv, tcpbuf, sizeof(tcpbuf));
                       tboff = 0;

                       switch (tblen) {
                       case -1:
                               if (errno == EINTR)
                                       goto relaydata_read_retry;
                               exit_failure("reading relay data failed: %s",
                                            strerror(errno));
                               /* NOTREACHED */
                       case 0:
                               /*
                                * EOF: shut down the read side of s_snd.  The
                                * opposite-direction relay process reads from
                                * that same socket, so it sees EOF as well.
                                */
                               /* to close opposite-direction relay process */
                               (void)shutdown(s_snd, 0);

                               (void)close(s_rcv);
                               (void)close(s_snd);
                               exit_success("terminating %s relay", service);
                               /* NOTREACHED */
                       default:
                               if (s_rcv >= FD_SETSIZE)
                                       exit_failure("descriptor too big");
                               FD_CLR(s_rcv, &readfds);
                               if (s_snd >= FD_SETSIZE)
                                       exit_failure("descriptor too big");
                               FD_SET(s_snd, &writefds);
                               break;
                       }
               }
               /* output (data or an OOB byte) is pending: try to flush it now */
               if (FD_ISSET(s_snd, &writefds))
                       send_data(s_rcv, s_snd, service, direction);
       }
}

void
tcp_relay(int s_src, int s_dst, const char *service)
{
       syslog(LOG_INFO, "starting %s relay", service);

       child_lastactive = parent_lastactive = time(NULL);

       cpid = fork();
       switch (cpid) {
       case -1:
               exit_failure("tcp_relay: can't fork grand child: %s",
                   strerror(errno));
               /* NOTREACHED */
       case 0:
               /* child process: relay going traffic */
               ppid = getppid();
               /* this is child so reopen log */
               closelog();
               openlog(logname, LOG_PID | LOG_NOWAIT, LOG_DAEMON);
               relay(s_src, s_dst, service, 1);
               /* NOTREACHED */
       default:
               /* parent process: relay coming traffic */
               ppid = (pid_t)0;
               (void)signal(SIGUSR1, sig_ctimeout);
               (void)signal(SIGCHLD, sig_child);
               relay(s_dst, s_src, service, 0);
               /* NOTREACHED */
       }
}