/*      $NetBSD: in6_l2tp.c,v 1.23 2023/09/01 11:23:39 andvar Exp $     */

/*
* Copyright (c) 2017 Internet Initiative Japan Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE NETBSD FOUNDATION, INC. AND CONTRIBUTORS
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR CONTRIBUTORS
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/

#include <sys/cdefs.h>
__KERNEL_RCSID(0, "$NetBSD: in6_l2tp.c,v 1.23 2023/09/01 11:23:39 andvar Exp $");

#ifdef _KERNEL_OPT
#include "opt_l2tp.h"
#endif

#include <sys/param.h>
#include <sys/systm.h>
#include <sys/socket.h>
#include <sys/sockio.h>
#include <sys/mbuf.h>
#include <sys/errno.h>
#include <sys/ioctl.h>
#include <sys/syslog.h>
#include <sys/kernel.h>

#include <net/if.h>
#include <net/route.h>
#include <net/if_ether.h>

#include <netinet/in.h>
#include <netinet/in_systm.h>
#include <netinet/ip.h>
#include <netinet/ip_var.h>
#include <netinet/ip_private.h>
#include <netinet/in_l2tp.h>
#include <netinet/in_var.h>
#include <netinet/ip_encap.h>

#include <netinet/ip6.h>
#include <netinet6/ip6_var.h>
#include <netinet6/ip6_private.h>
#include <netinet6/in6_l2tp.h>

#ifdef ALTQ
#include <altq/altq.h>
#endif

/* TODO: IP_TCPMSS support */
#undef IP_TCPMSS
#ifdef IP_TCPMSS
#include <netinet/ip_tcpmss.h>
#endif

#include <net/if_l2tp.h>

/* Default hop limit for the outer IPv6 header of encapsulated frames. */
#define L2TP_HLIM6              64

/*
 * Hop limit that in6_l2tp_output() writes into the outer IPv6 header.
 * Left non-static and non-const; presumably so it can be adjusted from
 * elsewhere (NOTE(review): confirm where, if anywhere, it is changed).
 */
int ip6_l2tp_hlim = L2TP_HLIM6;

static int in6_l2tp_input(struct mbuf **mp, int *offp, int proto, void *eparg);
static int in6_l2tp_match(struct mbuf *m, int off, int proto, void *arg);

/*
 * Encapsulation switch handed to encap_attach_addr() by in6_l2tp_attach().
 * Only the IPv6 input hook is provided; there is no control-input
 * (pr_ctlinput) handler.
 */
static const struct encapsw in6_l2tp_encapsw = {
       .encapsw6 = {
               .pr_input = in6_l2tp_input,
               .pr_ctlinput = NULL,
       },
};

/*
 * Encapsulate a frame for transmission over an L2TPv3-over-IPv6 tunnel
 * and pass it to ip6_output().
 *
 * The packet built here is laid out as:
 *     outer IPv6 header | session ID (4 bytes) | [cookie] | frame
 *
 * var: tunnel variant; the caller must hold a reference to it
 *      (asserted with l2tp_heldref_variant()).  lv_psrc and lv_pdst must
 *      be AF_INET6 addresses.
 * m:   frame to send.  It is freed explicitly on the nesting, unspecified
 *      destination and routing-failure paths, and handed to ip6_output()
 *      on success.  On the ENOBUFS paths the mbuf helpers are assumed to
 *      have released it (NOTE(review): confirm for M_GET_ALIGNED_HDR).
 *
 * Returns 0 or an errno:
 *   - the l2tp_check_nesting() error (the only path that bumps
 *     if_oerrors);
 *   - ENETUNREACH when the destination is unspecified, no route is found,
 *     or the route would send the packet back through this interface;
 *   - ENOBUFS when an mbuf prepend/pullup fails;
 *   - otherwise, ip6_output()'s result.
 * EINVAL is returned only if IP_TCPMSS support is compiled in (it is
 * currently #undef'd at the top of this file).
 */
int
in6_l2tp_output(struct l2tp_variant *var, struct mbuf *m)
{
       struct rtentry *rt;
       struct route *ro_pc;
       kmutex_t *lock_pc;
       struct l2tp_softc *sc;
       struct ifnet *ifp;
       /*
        * NOTE(review): var is dereferenced here, before the
        * KASSERT(var != NULL) below, so that assertion cannot catch a
        * NULL var.
        */
       struct sockaddr_in6 *sin6_src = satosin6(var->lv_psrc);
       struct sockaddr_in6 *sin6_dst = satosin6(var->lv_pdst);
       /*
        * Capsule IP header, assembled on the stack.  ip6_plen is
        * accumulated in host byte order and converted with HTONS()
        * before the header is copied into the mbuf.
        */
       struct ip6_hdr ip6hdr;
       int error;
       uint32_t sess_id;

       KASSERT(var != NULL);
       KASSERT(l2tp_heldref_variant(var));
       KASSERT(sin6_src != NULL && sin6_dst != NULL);
       KASSERT(sin6_src->sin6_family == AF_INET6
           && sin6_dst->sin6_family == AF_INET6);

       sc = var->lv_softc;
       ifp = &sc->l2tp_ec.ec_if;
       /* Refuse to send if l2tp_check_nesting() reports an error. */
       error = l2tp_check_nesting(ifp, m);
       if (error) {
               m_freem(m);
               goto looped;
       }

       /*
        * Bidirectional configured tunnel mode: a peer address is
        * required.
        */
       if (IN6_IS_ADDR_UNSPECIFIED(&sin6_dst->sin6_addr)) {
               m_freem(m);
               if ((ifp->if_flags & IFF_DEBUG) != 0)
                       log(LOG_DEBUG, "%s: ENETUNREACH\n", __func__);
               return ENETUNREACH;
       }

#ifdef NOTYET
/* TODO: support ALTQ for inner frame */
#ifdef ALTQ
       ALTQ_SAVE_PAYLOAD(m, AF_ETHER);
#endif
#endif

       memset(&ip6hdr, 0, sizeof(ip6hdr));
       ip6hdr.ip6_src = sin6_src->sin6_addr;
       ip6hdr.ip6_dst = sin6_dst->sin6_addr;
       /* unlike IPv4, IP version must be filled by caller of ip6_output() */
       ip6hdr.ip6_vfc = 0x60;
       ip6hdr.ip6_nxt = IPPROTO_L2TP;
       ip6hdr.ip6_hlim = ip6_l2tp_hlim;
       /*
        * Outer IP payload length:
        *     session ID + cookie (if enabled) + frame.
        */
       ip6hdr.ip6_plen = 0;
       /* session-id length */
       ip6hdr.ip6_plen += sizeof(uint32_t);
       if (var->lv_use_cookie == L2TP_COOKIE_ON) {
               /* cookie length */
               ip6hdr.ip6_plen += var->lv_peer_cookie_len;
       }

/* TODO: IP_TCPMSS support */
#ifdef IP_TCPMSS
       m = l2tp_tcpmss_clamp(ifp, m);
       if (m == NULL)
               return EINVAL;
#endif

       /*
        * Payload length.
        *
        * NOTE: payload length may be changed in ip_tcpmss(). Typical case
        * is missing of TCP mss option in original TCP header.
        */
       ip6hdr.ip6_plen += m->m_pkthdr.len;
       /* From here on ip6_plen is in network byte order. */
       HTONS(ip6hdr.ip6_plen);

       if (var->lv_use_cookie == L2TP_COOKIE_ON) {
               /*
                * Prepend the peer's session cookie in network byte order.
                * A 4-byte cookie is written with htonl().  Any other
                * lv_peer_cookie_len is written as 8 bytes.
                * NOTE(review): this assumes lv_peer_cookie_len is 4 or 8.
                * Any other value would write past the prepended space.
                */
               uint32_t cookie_32;
               uint64_t cookie_64;
               M_PREPEND(m, var->lv_peer_cookie_len, M_DONTWAIT);
               /* Make sure the cookie bytes are contiguous for mtod(). */
               if (m && m->m_len < var->lv_peer_cookie_len)
                       m = m_pullup(m, var->lv_peer_cookie_len);
               if (m == NULL)
                       return ENOBUFS;
               if (var->lv_peer_cookie_len == 4) {
                       cookie_32 = htonl((uint32_t)var->lv_peer_cookie);
                       memcpy(mtod(m, void *), &cookie_32, sizeof(uint32_t));
               } else {
                       cookie_64 = htobe64(var->lv_peer_cookie);
                       memcpy(mtod(m, void *), &cookie_64, sizeof(uint64_t));
               }
       }

       /* prepend session-ID (the peer's ID, network byte order) */
       sess_id = htonl(var->lv_peer_sess_id);
       M_PREPEND(m, sizeof(uint32_t), M_DONTWAIT);
       if (m && m->m_len < sizeof(uint32_t))
               m = m_pullup(m, sizeof(uint32_t));
       if (m == NULL)
               return ENOBUFS;
       memcpy(mtod(m, uint32_t *), &sess_id, sizeof(uint32_t));

       /* prepend new IP header, aligned so it can be accessed in place */
       M_PREPEND(m, sizeof(struct ip6_hdr), M_DONTWAIT);
       if (m == NULL)
               return ENOBUFS;
       if (M_GET_ALIGNED_HDR(&m, struct ip6_hdr, false) != 0)
               return ENOBUFS;
       memcpy(mtod(m, struct ip6_hdr *), &ip6hdr, sizeof(struct ip6_hdr));

       /*
        * if_tunnel_get_ro() and if_tunnel_put_ro() bracket every use of
        * the per-CPU cached route ro_pc, including the ip6_output() call
        * below.  Every exit taken after this point calls
        * if_tunnel_put_ro().
        */
       if_tunnel_get_ro(sc->l2tp_ro_percpu, &ro_pc, &lock_pc);
       if ((rt = rtcache_lookup(ro_pc, var->lv_pdst)) == NULL) {
               if_tunnel_put_ro(sc->l2tp_ro_percpu, lock_pc);
               m_freem(m);
               return ENETUNREACH;
       }

       /*
        * If the route constitutes infinite encapsulation, punt.  The
        * cached route is also discarded (rtcache_free).
        */
       if (rt->rt_ifp == ifp) {
               rtcache_unref(rt, ro_pc);
               rtcache_free(ro_pc);
               if_tunnel_put_ro(sc->l2tp_ro_percpu, lock_pc);
               m_freem(m);
               return ENETUNREACH;     /* XXX */
       }
       rtcache_unref(rt, ro_pc);

       /*
        * To avoid inappropriate rewrite of checksum,
        * clear csum flags.
        */
       m->m_pkthdr.csum_flags  = 0;

       error = ip6_output(m, 0, ro_pc, 0, NULL, NULL, NULL);
       if_tunnel_put_ro(sc->l2tp_ro_percpu, lock_pc);
       return(error);

looped:
       /*
        * Reached only from the l2tp_check_nesting() failure path, so
        * error is non-zero here.
        */
       if (error)
               if_statinc(ifp, if_oerrors);

       return error;
}

/*
 * IPv6 input hook for L2TPv3, installed through in6_l2tp_encapsw.
 *
 * *mp:   received packet (must carry a packet header).
 * *offp: offset of the L2TP session ID, i.e. the end of the outer IPv6
 *        header chain.
 * proto: protocol number, passed through to rip6_input() for control
 *        packets.
 *
 * Control packets are those with session ID 0.  They are handed to
 * rip6_input() so a userland daemon can process them, and rip6_input()'s
 * return value is passed back.
 *
 * Data packets go through these steps:
 *   1. The session is looked up by ID.
 *   2. The packet is checked: the interface is up, the tunnel is still
 *      configured and in L2TP_STATE_UP, and the cookie matches when
 *      enabled.
 *   3. The outer IPv6 header, session ID and cookie are stripped.
 *   4. The remaining frame is passed to l2tp_input().
 * Packets failing a check are freed.  Every data path returns
 * IPPROTO_DONE.
 */
static int
in6_l2tp_input(struct mbuf **mp, int *offp, int proto, void *eparg __unused)
{
       struct mbuf *m = *mp;
       int off = *offp;

       struct ifnet *l2tpp;
       struct l2tp_softc *sc;
       struct l2tp_variant *var;
       uint32_t sess_id;
       uint32_t cookie_32;
       uint64_t cookie_64;
       struct psref psref;

       KASSERT((m->m_flags & M_PKTHDR) != 0);

       /* Too short to carry a session ID. */
       if (m->m_pkthdr.len < off + sizeof(uint32_t)) {
               m_freem(m);
               return IPPROTO_DONE;
       }

       /* get L2TP session ID */
       m_copydata(m, off, sizeof(uint32_t), (void *)&sess_id);
       NTOHL(sess_id);
#ifdef L2TP_DEBUG
       log(LOG_DEBUG, "%s: sess_id = %" PRIu32 "\n", __func__, sess_id);
#endif
       if (sess_id == 0) {
               int rv;
               /*
                * L2TPv3 control packet received.
                * userland daemon(l2tpd?) should process.
                */
               SOFTNET_LOCK_IF_NET_MPSAFE();
               rv = rip6_input(mp, offp, proto);
               SOFTNET_UNLOCK_IF_NET_MPSAFE();
               return rv;
       }

       /* On success this holds a psref on var, released at "out". */
       var = l2tp_lookup_session_ref(sess_id, &psref);
       if (var == NULL) {
               m_freem(m);
               IP_STATINC(IP_STAT_NOL2TP);
               return IPPROTO_DONE;
       }

       sc = var->lv_softc;
       /*
        * l2tpp is the address of a member of *sc, so it can never be
        * NULL.  Only the IFF_UP state needs checking.
        */
       l2tpp = &sc->l2tp_ec.ec_if;

       if ((l2tpp->if_flags & IFF_UP) == 0) {
#ifdef L2TP_DEBUG
               log(LOG_DEBUG, "%s: l2tpp is down\n", __func__);
#endif
               m_freem(m);
               IP_STATINC(IP_STAT_NOL2TP);
               goto out;
       }

       /*
        * Another CPU may have run l2tp_delete_tunnel(), which leaves the
        * tunnel endpoint addresses unset.
        */
       if (var->lv_psrc == NULL || var->lv_pdst == NULL) {
               m_freem(m);
               IP_STATINC(IP_STAT_NOL2TP);
               goto out;
       }

       if (var->lv_state != L2TP_STATE_UP) {
               m_freem(m);
               goto out;
       }
       /* Strip the outer IPv6 header(s) and the session ID. */
       m_adj(m, off + sizeof(uint32_t));

       if (var->lv_use_cookie == L2TP_COOKIE_ON) {
               /*
                * A 4-byte cookie is compared as 32 bits.  Any other
                * lv_my_cookie_len is read as 8 bytes.  Validate the
                * packet against the number of bytes actually read, not
                * against lv_my_cookie_len, so m_copydata() always stays
                * within the packet.
                */
               int cookie_len = (var->lv_my_cookie_len == 4) ?
                   (int)sizeof(uint32_t) : (int)sizeof(uint64_t);

               if (m->m_pkthdr.len < cookie_len) {
                       m_freem(m);
                       goto out;
               }
               if (var->lv_my_cookie_len == 4) {
                       m_copydata(m, 0, sizeof(uint32_t), (void *)&cookie_32);
                       NTOHL(cookie_32);
                       if (cookie_32 != var->lv_my_cookie) {
                               m_freem(m);
                               goto out;
                       }
                       m_adj(m, sizeof(uint32_t));
               } else {
                       m_copydata(m, 0, sizeof(uint64_t), (void *)&cookie_64);
                       BE64TOH(cookie_64);
                       if (cookie_64 != var->lv_my_cookie) {
                               m_freem(m);
                               goto out;
                       }
                       m_adj(m, sizeof(uint64_t));
               }
       }

/* TODO: IP_TCPMSS support */
#ifdef IP_TCPMSS
       m = l2tp_tcpmss_clamp(l2tpp, m);
       if (m == NULL)
               goto out;
#endif
       l2tp_input(m, l2tpp);

out:
       l2tp_putref_variant(var, &psref);
       return IPPROTO_DONE;
}

/*
* This function is used by encap6_lookup() to decide priority of the encaptab.
* This priority is compared to the match length between mbuf's source/destination
* IPv6 address pair and encaptab's one.
* l2tp(4) does not use address pairs to search matched encaptab, so this
* function must return the length bigger than or equals to IPv6 address pair to
* avoid wrong encaptab.
*/
static int
in6_l2tp_match(struct mbuf *m, int off, int proto, void *arg)
{
       struct l2tp_softc *sc = arg;
       struct l2tp_variant *var;
       struct psref psref;
       uint32_t sess_id;
       int rv = 0;

       KASSERT(proto == IPPROTO_L2TP);

       var = l2tp_getref_variant(sc, &psref);
       if (__predict_false(var == NULL))
               return rv;

       /*
        * If the packet contains no session ID it cannot match
        */
       if (m_length(m) < off + sizeof(uint32_t)) {
               rv = 0 ;
               goto out;
       }

       /* get L2TP session ID */
       m_copydata(m, off, sizeof(uint32_t), (void *)&sess_id);
       NTOHL(sess_id);
       if (sess_id == 0) {
               /*
                * L2TPv3 control packet received.
                * userland daemon(l2tpd?) should process.
                */
               rv = 128 * 2;
       } else if (sess_id == var->lv_my_sess_id)
               rv = 128 * 2;
       else
               rv = 0;

out:
       l2tp_putref_variant(var, &psref);
       return rv;
}

/*
 * Register the IPv6 L2TP encapsulation handler for this variant's
 * (lv_psrc, lv_pdst) address pair.  in6_l2tp_match() is the match
 * callback and the softc is its argument.
 *
 * Returns:
 *   - EINVAL when the variant has no softc;
 *   - EEXIST whenever encap_attach_addr() fails, whatever the cause;
 *   - 0 on success, with the returned cookie stored in lv_encap_cookie.
 */
int
in6_l2tp_attach(struct l2tp_variant *var)
{
       struct l2tp_softc *sc = var->lv_softc;

       if (sc == NULL)
               return EINVAL;

       var->lv_encap_cookie = encap_attach_addr(AF_INET6, IPPROTO_L2TP,
           var->lv_psrc, var->lv_pdst, in6_l2tp_match, &in6_l2tp_encapsw, sc);

       return (var->lv_encap_cookie != NULL) ? 0 : EEXIST;
}

/*
 * Unregister the encapsulation handler installed by in6_l2tp_attach().
 *
 * On success lv_encap_cookie is cleared and 0 is returned.  On failure
 * encap_detach()'s error is returned and lv_encap_cookie is left intact.
 */
int
in6_l2tp_detach(struct l2tp_variant *var)
{
       int rv = encap_detach(var->lv_encap_cookie);

       if (rv != 0)
               return rv;

       var->lv_encap_cookie = NULL;
       return 0;
}