/*      $NetBSD: netbsd32_socket.c,v 1.56 2021/01/19 03:41:22 simonb Exp $      */

/*
* Copyright (c) 1998, 2001 Matthew R. Green
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
* OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
* IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
* AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*/

#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD: netbsd32_socket.c,v 1.56 2021/01/19 03:41:22 simonb Exp $");

#include <sys/param.h>
#include <sys/systm.h>
#define msg __msg /* Don't ask me! */
#include <sys/mount.h>
#include <sys/socket.h>
#include <sys/sockio.h>
#include <sys/socketvar.h>
#include <sys/mbuf.h>
#include <sys/ktrace.h>
#include <sys/file.h>
#include <sys/filedesc.h>
#include <sys/syscallargs.h>
#include <sys/proc.h>
#include <sys/dirent.h>

#include <compat/netbsd32/netbsd32.h>
#include <compat/netbsd32/netbsd32_syscallargs.h>
#include <compat/netbsd32/netbsd32_conv.h>

/*
* XXX Assumes that struct sockaddr is compatible.
*/

#define CMSG32_ALIGN(n) (((n) + ALIGNBYTES32) & ~ALIGNBYTES32)
#define CMSG32_ASIZE    CMSG32_ALIGN(sizeof(struct cmsghdr))
#define CMSG32_DATA(cmsg) (__CASTV(u_char *, cmsg) + CMSG32_ASIZE)
#define CMSG32_MSGNEXT(ucmsg, kcmsg) \
   (__CASTV(char *, kcmsg) + CMSG32_ALIGN((ucmsg)->cmsg_len))
#define CMSG32_MSGEND(mhdr) \
   (__CASTV(char *, (mhdr)->msg_control) + (mhdr)->msg_controllen)

#define CMSG32_NXTHDR(mhdr, ucmsg, kcmsg)       \
   __CASTV(struct cmsghdr *,  \
       CMSG32_MSGNEXT(ucmsg, kcmsg) + \
       CMSG32_ASIZE > CMSG32_MSGEND(mhdr) ? 0 : \
       CMSG32_MSGNEXT(ucmsg, kcmsg))
#define CMSG32_FIRSTHDR(mhdr) \
   __CASTV(struct cmsghdr *, \
       (mhdr)->msg_controllen < sizeof(struct cmsghdr) ? 0 : \
       (mhdr)->msg_control)

#define CMSG32_SPACE(l) (CMSG32_ALIGN(sizeof(struct cmsghdr)) + CMSG32_ALIGN(l))
#define CMSG32_LEN(l)   (CMSG32_ALIGN(sizeof(struct cmsghdr)) + (l))

static int
copyout32_msg_control_mbuf(struct lwp *l, struct msghdr *mp, u_int *len,
   struct mbuf *m, char **q, bool *truncated)
{
       struct cmsghdr *cmsg, cmsg32;
       size_t i, j;
       int error;

       *truncated = false;
       cmsg = mtod(m, struct cmsghdr *);
       do {
               if ((char *)cmsg == mtod(m, char *) + m->m_len)
                       break;
               if ((char *)cmsg > mtod(m, char *) + m->m_len - sizeof(*cmsg))
                       return EINVAL;
               cmsg32 = *cmsg;
               j = cmsg->cmsg_len - CMSG_LEN(0);
               i = cmsg32.cmsg_len = CMSG32_LEN(j);
               if (i > *len) {
                       mp->msg_flags |= MSG_CTRUNC;
                       if (cmsg->cmsg_level == SOL_SOCKET
                           && cmsg->cmsg_type == SCM_RIGHTS) {
                               *truncated = true;
                               return 0;
                       }
                       j -= i - *len;
                       i = *len;
               }

               ktrkuser(mbuftypes[MT_CONTROL], cmsg, cmsg->cmsg_len);
               error = copyout(&cmsg32, *q, MIN(i, sizeof(cmsg32)));
               if (error)
                       return error;
               if (i > CMSG32_LEN(0)) {
                       error = copyout(CMSG_DATA(cmsg), *q + CMSG32_LEN(0),
                           i - CMSG32_LEN(0));
                       if (error)
                               return error;
               }
               j = CMSG32_SPACE(cmsg->cmsg_len - CMSG_LEN(0));
               if (*len >= j) {
                       *len -= j;
                       *q += j;
               } else {
                       *q += i;
                       *len = 0;
               }
               cmsg = (void *)((char *)cmsg + CMSG_ALIGN(cmsg->cmsg_len));
       } while (*len > 0);

       return 0;
}

static int
copyout32_msg_control(struct lwp *l, struct msghdr *mp, struct mbuf *control)
{
       int len, error = 0;
       struct mbuf *m;
       char *q;
       bool truncated;

       len = mp->msg_controllen;
       if (len <= 0 || control == 0) {
               mp->msg_controllen = 0;
               free_control_mbuf(l, control, control);
               return 0;
       }

       q = (char *)mp->msg_control;

       for (m = control; len > 0 && m != NULL; m = m->m_next) {
               error = copyout32_msg_control_mbuf(l, mp, &len, m, &q,
                   &truncated);
               if (truncated) {
                       m = control;
                       break;
               }
               if (error)
                       break;
       }

       free_control_mbuf(l, control, m);

       mp->msg_controllen = q - (char *)mp->msg_control;
       return error;
}

static int
msg_recv_copyin(struct lwp *l, const struct netbsd32_msghdr *msg32,
   struct msghdr *msg, struct iovec *aiov)
{
       int error;
       size_t iovsz;
       struct iovec *iov = aiov;

       iovsz = msg32->msg_iovlen * sizeof(struct iovec);
       if (msg32->msg_iovlen > UIO_SMALLIOV) {
               if (msg32->msg_iovlen > IOV_MAX)
                       return EMSGSIZE;
               iov = kmem_alloc(iovsz, KM_SLEEP);
       }

       error = netbsd32_to_iovecin(NETBSD32PTR64(msg32->msg_iov), iov,
           msg32->msg_iovlen);
       if (error)
               goto out;

       netbsd32_to_msghdr(msg32, msg);
       msg->msg_iov = iov;
out:
       if (iov != aiov)
               kmem_free(iov, iovsz);
       return error;
}

static int
msg_recv_copyout(struct lwp *l, struct netbsd32_msghdr *msg32,
   struct msghdr *msg, struct netbsd32_msghdr *arg,
   struct mbuf *from, struct mbuf *control)
{
       int error = 0;

       if (msg->msg_control != NULL)
               error = copyout32_msg_control(l, msg, control);

       if (error == 0)
               error = copyout_sockname(msg->msg_name, &msg->msg_namelen, 0,
                       from);

       if (from != NULL)
               m_free(from);
       if (error)
               return error;

       msg32->msg_namelen = msg->msg_namelen;
       msg32->msg_controllen = msg->msg_controllen;
       msg32->msg_flags = msg->msg_flags;
       ktrkuser("msghdr", msg, sizeof(*msg));
       if (arg == NULL)
               return 0;
       return copyout(msg32, arg, sizeof(*arg));
}

int
netbsd32_recvmsg(struct lwp *l, const struct netbsd32_recvmsg_args *uap,
   register_t *retval)
{
       /* {
               syscallarg(int) s;
               syscallarg(netbsd32_msghdrp_t) msg;
               syscallarg(int) flags;
       } */
       struct netbsd32_msghdr  msg32;
       struct iovec aiov[UIO_SMALLIOV];
       struct msghdr   msg;
       int             error;
       struct mbuf     *from, *control;

       error = copyin(SCARG_P32(uap, msg), &msg32, sizeof(msg32));
       if (error)
               return error;

       if ((error = msg_recv_copyin(l, &msg32, &msg, aiov)) != 0)
               return error;

       msg.msg_flags = SCARG(uap, flags) & MSG_USERFLAGS;
       error = do_sys_recvmsg(l, SCARG(uap, s), &msg,
           &from, msg.msg_control != NULL ? &control : NULL, retval);
       if (error != 0)
               goto out;

       error = msg_recv_copyout(l, &msg32, &msg, SCARG_P32(uap, msg),
           from, control);
out:
       if (msg.msg_iov != aiov)
               kmem_free(msg.msg_iov, msg.msg_iovlen * sizeof(struct iovec));
       return error;
}

int
netbsd32_recvmmsg(struct lwp *l, const struct netbsd32_recvmmsg_args *uap,
   register_t *retval)
{
       /* {
               syscallarg(int)                         s;
               syscallarg(netbsd32_mmsghdr_t)          mmsg;
               syscallarg(unsigned int)                vlen;
               syscallarg(unsigned int)                flags;
               syscallarg(netbsd32_timespecp_t)        timeout;
       } */
       struct mmsghdr mmsg;
       struct netbsd32_mmsghdr mmsg32, *mmsg32p = SCARG_P32(uap, mmsg);
       struct netbsd32_msghdr *msg32 = &mmsg32.msg_hdr;
       struct socket *so;
       struct msghdr *msg = &mmsg.msg_hdr;
       int error, s;
       struct mbuf *from, *control;
       struct timespec ts, now;
       struct netbsd32_timespec ts32;
       unsigned int vlen, flags, dg;
       struct iovec aiov[UIO_SMALLIOV];

       ts.tv_sec = 0;  // XXX: gcc
       ts.tv_nsec = 0;
       if (SCARG_P32(uap, timeout)) {
               if ((error = copyin(SCARG_P32(uap, timeout), &ts32,
                   sizeof(ts32))) != 0)
                       return error;
               getnanotime(&now);
               netbsd32_to_timespec(&ts32, &ts);
               timespecadd(&now, &ts, &ts);
       }

       s = SCARG(uap, s);
       if ((error = fd_getsock(s, &so)) != 0)
               return error;

       /*
        * If so->so_rerror holds a deferred error return it now.
        */
       if (so->so_rerror) {
               error = so->so_rerror;
               so->so_rerror = 0;
               fd_putfile(s);
               return error;
       }

       vlen = SCARG(uap, vlen);
       if (vlen > 1024)
               vlen = 1024;

       from = NULL;
       flags = SCARG(uap, flags) & MSG_USERFLAGS;

       for (dg = 0; dg < vlen;) {
               error = copyin(mmsg32p + dg, &mmsg32, sizeof(mmsg32));
               if (error)
                       break;

               if ((error = msg_recv_copyin(l, msg32, msg, aiov)) != 0)
                       return error;

               msg->msg_flags = flags & ~MSG_WAITFORONE;

               if (from != NULL) {
                       m_free(from);
                       from = NULL;
               }

               error = do_sys_recvmsg_so(l, s, so, msg, &from,
                   msg->msg_control != NULL ? &control : NULL, retval);
               if (error) {
                       if (error == EAGAIN && dg > 0)
                               error = 0;
                       break;
               }
               error = msg_recv_copyout(l, msg32, msg, NULL,
                   from, control);
               from = NULL;
               if (error)
                       break;

               mmsg32.msg_len = *retval;

               error = copyout(&mmsg32, mmsg32p + dg, sizeof(mmsg32));
               if (error)
                       break;

               dg++;
               if (msg->msg_flags & MSG_OOB)
                       break;

               if (SCARG_P32(uap, timeout)) {
                       getnanotime(&now);
                       timespecsub(&now, &ts, &now);
                       if (now.tv_sec > 0)
                               break;
               }

               if (flags & MSG_WAITFORONE)
                       flags |= MSG_DONTWAIT;

       }

       if (from != NULL)
               m_free(from);

       *retval = dg;

       /*
        * If we succeeded at least once, return 0, hopefully so->so_rerror
        * will catch it next time.
        */
       if (error && dg > 0) {
               so->so_rerror = error;
               error = 0;
       }

       fd_putfile(s);

       return error;
}

static int
copyin32_msg_control(struct lwp *l, struct msghdr *mp)
{
       /*
        * Handle cmsg if there is any.
        */
       struct cmsghdr *cmsg, cmsg32, *cc;
       struct mbuf *ctl_mbuf;
       ssize_t resid = mp->msg_controllen;
       size_t clen, cidx = 0, cspace;
       uint8_t *control;
       int error;

       ctl_mbuf = m_get(M_WAIT, MT_CONTROL);
       clen = MLEN;
       control = mtod(ctl_mbuf, void *);
       memset(control, 0, clen);

       for (cc = CMSG32_FIRSTHDR(mp); cc; cc = CMSG32_NXTHDR(mp, &cmsg32, cc))
       {
               error = copyin(cc, &cmsg32, sizeof(cmsg32));
               if (error)
                       goto failure;

               /*
                * Sanity check the control message length.
                */
               if (resid < 0 ||
                   cmsg32.cmsg_len > (size_t)resid ||
                   cmsg32.cmsg_len < sizeof(cmsg32)) {
                       error = EINVAL;
                       goto failure;
               }

               cspace = CMSG_SPACE(cmsg32.cmsg_len - CMSG32_LEN(0));

               /* Check the buffer is big enough */
               if (__predict_false(cidx + cspace > clen)) {
                       uint8_t *nc;
                       size_t nclen;

                       nclen = cidx + cspace;
                       if (nclen >= (size_t)PAGE_SIZE) {
                               error = EINVAL;
                               goto failure;
                       }
                       nc = realloc(clen <= MLEN ? NULL : control,
                                    nclen, M_TEMP, M_WAITOK);
                       if (!nc) {
                               error = ENOMEM;
                               goto failure;
                       }
                       if (cidx <= MLEN) {
                               /* Old buffer was in mbuf... */
                               memcpy(nc, control, cidx);
                               memset(nc + cidx, 0, nclen - cidx);
                       } else {
                               memset(nc + nclen, 0, nclen - clen);
                       }
                       control = nc;
                       clen = nclen;
               }

               /* Copy header */
               cmsg = (void *)&control[cidx];
               cmsg->cmsg_len = CMSG_LEN(cmsg32.cmsg_len - CMSG32_LEN(0));
               cmsg->cmsg_level = cmsg32.cmsg_level;
               cmsg->cmsg_type = cmsg32.cmsg_type;

               /* Copyin the data */
               error = copyin(CMSG32_DATA(cc), CMSG_DATA(cmsg),
                   cmsg32.cmsg_len - CMSG32_LEN(0));
               if (error)
                       goto failure;
               ktrkuser(mbuftypes[MT_CONTROL], cmsg, cmsg->cmsg_len);

               resid -= CMSG32_ALIGN(cmsg32.cmsg_len);
               cidx += CMSG_ALIGN(cmsg->cmsg_len);
       }

       /* If we allocated a buffer, attach to mbuf */
       if (cidx > MLEN) {
               MEXTADD(ctl_mbuf, control, clen, M_MBUF, NULL, NULL);
               ctl_mbuf->m_flags |= M_EXT_RW;
       }
       control = NULL;
       mp->msg_controllen = ctl_mbuf->m_len = CMSG_ALIGN(cidx);

       mp->msg_control = ctl_mbuf;
       mp->msg_flags |= MSG_CONTROLMBUF;


       return 0;

failure:
       if (control != mtod(ctl_mbuf, void *))
               free(control, M_MBUF);
       m_free(ctl_mbuf);
       return error;
}

static int
msg_send_copyin(struct lwp *l, const struct netbsd32_msghdr *msg32,
   struct msghdr *msg, struct iovec *aiov)
{
       int error;
       struct iovec *iov = aiov;
       struct netbsd32_iovec *iov32;
       size_t iovsz;

       netbsd32_to_msghdr(msg32, msg);
       msg->msg_flags = 0;

       if (CMSG32_FIRSTHDR(msg)) {
               error = copyin32_msg_control(l, msg);
               if (error)
                       return error;
               /* From here on, msg->msg_control is allocated */
       } else {
               msg->msg_control = NULL;
               msg->msg_controllen = 0;
       }

       iovsz = msg->msg_iovlen * sizeof(struct iovec);
       if ((u_int)msg->msg_iovlen > UIO_SMALLIOV) {
               if ((u_int)msg->msg_iovlen > IOV_MAX) {
                       error = EMSGSIZE;
                       goto out;
               }
               iov = kmem_alloc(iovsz, KM_SLEEP);
       }

       iov32 = NETBSD32PTR64(msg32->msg_iov);
       error = netbsd32_to_iovecin(iov32, iov, msg->msg_iovlen);
       if (error)
               goto out;
       msg->msg_iov = iov;
       return 0;
out:
       if (msg->msg_control)
               m_free(msg->msg_control);
       if (iov != aiov)
               kmem_free(iov, iovsz);
       return error;
}

int
netbsd32_sendmsg(struct lwp *l, const struct netbsd32_sendmsg_args *uap,
   register_t *retval)
{
       /* {
               syscallarg(int) s;
               syscallarg(const netbsd32_msghdrp_t) msg;
               syscallarg(int) flags;
       } */
       struct msghdr msg;
       struct netbsd32_msghdr msg32;
       struct iovec aiov[UIO_SMALLIOV];
       int error;

       error = copyin(SCARG_P32(uap, msg), &msg32, sizeof(msg32));
       if (error)
               return error;

       if ((error = msg_send_copyin(l, &msg32, &msg, aiov)) != 0)
               return error;

       error = do_sys_sendmsg(l, SCARG(uap, s), &msg, SCARG(uap, flags),
           retval);
       /* msg.msg_control freed by do_sys_sendmsg() */

       if (msg.msg_iov != aiov)
               kmem_free(msg.msg_iov, msg.msg_iovlen * sizeof(struct iovec));
       return error;
}

int
netbsd32_sendmmsg(struct lwp *l, const struct netbsd32_sendmmsg_args *uap,
   register_t *retval)
{
       /* {
               syscallarg(int)                 s;
               syscallarg(const netbsd32_mmsghdr_t)    mmsg;
               syscallarg(unsigned int)        vlen;
               syscallarg(unsigned int)        flags;
       } */
       struct mmsghdr mmsg;
       struct netbsd32_mmsghdr mmsg32, *mmsg32p = SCARG_P32(uap, mmsg);
       struct netbsd32_msghdr *msg32 = &mmsg32.msg_hdr;
       struct socket *so;
       file_t *fp;
       struct msghdr *msg = &mmsg.msg_hdr;
       int error, s;
       unsigned int vlen, flags, dg;
       struct iovec aiov[UIO_SMALLIOV];

       s = SCARG(uap, s);
       if ((error = fd_getsock1(s, &so, &fp)) != 0)
               return error;

       vlen = SCARG(uap, vlen);
       if (vlen > 1024)
               vlen = 1024;

       flags = SCARG(uap, flags) & MSG_USERFLAGS;

       for (dg = 0; dg < vlen;) {
               error = copyin(mmsg32p + dg, &mmsg32, sizeof(mmsg32));
               if (error)
                       break;
               if ((error = msg_send_copyin(l, msg32, msg, aiov)) != 0)
                       break;

               msg->msg_flags = flags;

               error = do_sys_sendmsg_so(l, s, so, fp, msg, flags, retval);
               if (msg->msg_iov != aiov) {
                       kmem_free(msg->msg_iov,
                           msg->msg_iovlen * sizeof(struct iovec));
               }
               if (error)
                       break;

               ktrkuser("msghdr", msg, sizeof(*msg));
               mmsg.msg_len = *retval;
               netbsd32_from_mmsghdr(&mmsg32, &mmsg);
               error = copyout(&mmsg32, mmsg32p + dg, sizeof(mmsg32));
               if (error)
                       break;
               dg++;
       }

       *retval = dg;

       fd_putfile(s);

       /*
        * If we succeeded at least once, return 0.
        */
       if (dg)
               return 0;
       return error;
}

int
netbsd32_recvfrom(struct lwp *l, const struct netbsd32_recvfrom_args *uap,
   register_t *retval)
{
       /* {
               syscallarg(int) s;
               syscallarg(netbsd32_voidp) buf;
               syscallarg(netbsd32_size_t) len;
               syscallarg(int) flags;
               syscallarg(netbsd32_sockaddrp_t) from;
               syscallarg(netbsd32_intp) fromlenaddr;
       } */
       struct msghdr   msg;
       struct iovec    aiov;
       int             error;
       struct mbuf     *from;

       if (SCARG(uap, len) > NETBSD32_SSIZE_MAX)
               return EINVAL;

       msg.msg_name = NULL;
       msg.msg_iov = &aiov;
       msg.msg_iovlen = 1;
       aiov.iov_base = SCARG_P32(uap, buf);
       aiov.iov_len = SCARG(uap, len);
       msg.msg_control = NULL;
       msg.msg_flags = SCARG(uap, flags) & MSG_USERFLAGS;

       error = do_sys_recvmsg(l, SCARG(uap, s), &msg, &from, NULL, retval);
       if (error != 0)
               return error;

       error = copyout_sockname(SCARG_P32(uap, from),
           SCARG_P32(uap, fromlenaddr), MSG_LENUSRSPACE, from);
       if (from != NULL)
               m_free(from);
       return error;
}

int
netbsd32_sendto(struct lwp *l, const struct netbsd32_sendto_args *uap,
   register_t *retval)
{
       /* {
               syscallarg(int) s;
               syscallarg(const netbsd32_voidp) buf;
               syscallarg(netbsd32_size_t) len;
               syscallarg(int) flags;
               syscallarg(const netbsd32_sockaddrp_t) to;
               syscallarg(int) tolen;
       } */
       struct msghdr msg;
       struct iovec aiov;

       if (SCARG(uap, len) > NETBSD32_SSIZE_MAX)
               return EINVAL;

       msg.msg_name = SCARG_P32(uap, to); /* XXX kills const */
       msg.msg_namelen = SCARG(uap, tolen);
       msg.msg_iov = &aiov;
       msg.msg_iovlen = 1;
       msg.msg_control = 0;
       aiov.iov_base = SCARG_P32(uap, buf);    /* XXX kills const */
       aiov.iov_len = SCARG(uap, len);
       msg.msg_flags = 0;
       return do_sys_sendmsg(l, SCARG(uap, s), &msg, SCARG(uap, flags),
           retval);
}