/*      $NetBSD: server.c,v 1.12 2021/08/08 20:54:49 nia Exp $  */

/*-
* Copyright (c) 2006 Itronix Inc.
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
* 3. The name of Itronix Inc. may not be used to endorse
*    or promote products derived from this software without specific
*    prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY ITRONIX INC. ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL ITRONIX INC. BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
/*-
* Copyright (c) 2009 The NetBSD Foundation, Inc.
* Copyright (c) 2004 Maksim Yevmenkin <[email protected]>
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*
* $FreeBSD: src/usr.sbin/bluetooth/sdpd/server.c,v 1.2 2005/12/06 17:56:36 emax Exp $
*/

#include <sys/cdefs.h>
__RCSID("$NetBSD: server.c,v 1.12 2021/08/08 20:54:49 nia Exp $");

#include <sys/select.h>
#include <sys/stat.h>
#include <sys/ucred.h>
#include <sys/un.h>

#include <assert.h>
#include <bluetooth.h>
#include <errno.h>
#include <grp.h>
#include <pwd.h>
#include <sdp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "sdpd.h"

/* File-local helpers used by the public server_* entry points below. */
static bool     server_open_control     (server_t *, char const *);
static bool     server_open_l2cap       (server_t *);
static void     server_accept_client    (server_t *, int);
static bool     server_process_request  (server_t *, int);
static void     server_close_fd         (server_t *, int);
static bool     server_auth_check       (server_t *, void *);

/*
 * number of groups we allocate space for in cmsg; server_auth_check()
 * only examines this many supplementary groups of a peer's credentials
 */
#define MAX_GROUPS      20

/*
 * Initialize server
 *
 * Zeroes *srv, records the optional privileged group name, allocates the
 * descriptor index plus control/input/output buffers, initialises the
 * record database and opens the control and L2CAP listening sockets.
 * On any failure everything acquired so far is released through
 * server_shutdown() and false is returned.
 */
bool
server_init(server_t *srv, char const *control, char const *sgroup)
{

	assert(srv != NULL);
	assert(control != NULL);

	memset(srv, 0, sizeof(*srv));
	FD_ZERO(&srv->fdset);
	srv->sgroup = sgroup;
	srv->fdmax = -1;		/* no descriptors registered yet */

	/* buffer sizes; the PDU header is received/sent separately */
	srv->ctllen = CMSG_SPACE(SOCKCREDSIZE(MAX_GROUPS));
	srv->imtu = SDP_LOCAL_MTU - sizeof(sdp_pdu_t);
	srv->omtu = L2CAP_MTU_DEFAULT - sizeof(sdp_pdu_t);

	/* acquire resources in order, stopping at the first failure */
	if ((srv->fdidx = calloc(FD_SETSIZE, sizeof(fd_idx_t))) == NULL)
		log_crit("Failed to allocate fd index");
	else if ((srv->ctlbuf = malloc(srv->ctllen)) == NULL)
		log_crit("Malloc cmsg buffer (len=%zu) failed.", srv->ctllen);
	else if ((srv->ibuf = malloc(srv->imtu)) == NULL)
		log_crit("Malloc input buffer (imtu=%d) failed.", srv->imtu);
	else if ((srv->obuf = malloc(srv->omtu)) == NULL)
		log_crit("Malloc output buffer (omtu=%d) failed.", srv->omtu);
	else if (db_init(srv)
	    && server_open_control(srv, control)
	    && server_open_l2cap(srv))
		return true;

	server_shutdown(srv);
	return false;
}

/*
 * Open local control socket
 *
 * Creates a PF_LOCAL stream socket bound to the path "control" (any
 * existing file there is unlinked first), makes it readable/writable by
 * everyone, starts listening and registers the descriptor in the server
 * index as a control server socket.  Returns false, after logging, on
 * any failure; no descriptor is left open in that case.
 */
static bool
server_open_control(server_t *srv, char const *control)
{
	struct sockaddr_un	un;
	int			opt, fd;

	/*
	 * strlcpy() below would silently truncate an overlong path, so we
	 * would bind() one name but chmod() another.  Refuse it up front,
	 * before unlinking anything.
	 */
	if (strlen(control) >= sizeof(un.sun_path)) {
		log_crit("Control socket path too long (%s)", control);
		return false;
	}

	if (unlink(control) == -1 && errno != ENOENT) {
		log_crit("Could not unlink(%s). %s (%d)",
		    control, strerror(errno), errno);

		return false;
	}

	fd = socket(PF_LOCAL, SOCK_STREAM, 0);
	if (fd == -1) {
		log_crit("Could not create control socket. %s (%d)",
		    strerror(errno), errno);

		return false;
	}

	/* fdset and fdidx[] can only describe descriptors below FD_SETSIZE */
	if (fd >= FD_SETSIZE) {
		log_crit("Control socket descriptor too large");
		close(fd);
		return false;
	}

	/* ask for peer credentials; they are checked by server_auth_check() */
	opt = 1;
	if (setsockopt(fd, SOL_LOCAL, LOCAL_CREDS, &opt, sizeof(opt)) == -1)
		log_crit("Warning: No credential checks on control socket");

	memset(&un, 0, sizeof(un));
	un.sun_len = sizeof(un);
	un.sun_family = AF_LOCAL;
	strlcpy(un.sun_path, control, sizeof(un.sun_path));

	if (bind(fd, (struct sockaddr *) &un, sizeof(un)) == -1) {
		log_crit("Could not bind control socket. %s (%d)",
		    strerror(errno), errno);

		close(fd);
		return false;
	}

	if (chmod(control, S_IRUSR|S_IWUSR|S_IRGRP|S_IWGRP|S_IROTH|S_IWOTH) == -1) {
		log_crit("Could not set permissions on control socket. %s (%d)",
		    strerror(errno), errno);

		close(fd);
		return false;
	}

	if (listen(fd, 5) == -1) {
		log_crit("Could not listen on control socket. %s (%d)",
		    strerror(errno), errno);

		close(fd);
		return false;
	}

	/* Add control descriptor to index */
	if (fd > srv->fdmax)
		srv->fdmax = fd;

	FD_SET(fd, &srv->fdset);
	srv->fdidx[fd].valid = true;
	srv->fdidx[fd].server = true;
	srv->fdidx[fd].control = true;
	srv->fdidx[fd].priv = false;
	return true;
}

/*
 * Open L2CAP server socket
 *
 * Creates a Bluetooth L2CAP SOCK_SEQPACKET socket, sets its incoming MTU
 * to srv->imtu, binds it to the SDP PSM on BDADDR_ANY, starts listening
 * and registers the descriptor in the server index as a (non-control)
 * server socket.  Returns false, after logging, on any failure; no
 * descriptor is left open in that case.
 */
static bool
server_open_l2cap(server_t *srv)
{
	struct sockaddr_bt	sa;
	int			fd;

	fd = socket(PF_BLUETOOTH, SOCK_SEQPACKET, BTPROTO_L2CAP);
	if (fd == -1) {
		log_crit("Could not create L2CAP socket. %s (%d)",
		    strerror(errno), errno);

		return false;
	}

	/* fdset and fdidx[] can only describe descriptors below FD_SETSIZE */
	if (fd >= FD_SETSIZE) {
		log_crit("L2CAP socket descriptor too large");
		close(fd);
		return false;
	}

	if (setsockopt(fd, BTPROTO_L2CAP, SO_L2CAP_IMTU,
	    &srv->imtu, sizeof(srv->imtu)) == -1) {
		log_crit("Could not set L2CAP Incoming MTU. %s (%d)",
		    strerror(errno), errno);

		close(fd);
		return false;
	}

	memset(&sa, 0, sizeof(sa));
	sa.bt_len = sizeof(sa);
	sa.bt_family = AF_BLUETOOTH;
	sa.bt_psm = L2CAP_PSM_SDP;
	bdaddr_copy(&sa.bt_bdaddr, BDADDR_ANY);

	if (bind(fd, (struct sockaddr *) &sa, sizeof(sa)) == -1) {
		log_crit("Could not bind L2CAP socket. %s (%d)",
		    strerror(errno), errno);

		close(fd);
		return false;
	}

	if (listen(fd, 5) == -1) {
		log_crit("Could not listen on L2CAP socket. %s (%d)",
		    strerror(errno), errno);

		close(fd);
		return false;
	}

	/* Add L2CAP descriptor to index */
	if (fd > srv->fdmax)
		srv->fdmax = fd;

	FD_SET(fd, &srv->fdset);
	srv->fdidx[fd].valid = true;
	srv->fdidx[fd].server = true;
	srv->fdidx[fd].control = false;
	srv->fdidx[fd].priv = false;
	return true;
}

/*
 * Shutdown server
 *
 * Frees every service record, closes every registered descriptor and
 * releases the buffers allocated by server_init().  On return *srv is
 * zeroed except that fdmax is -1 ("no descriptors"), the same state
 * server_init() starts from.  This makes the function safe on a server
 * whose initialisation stopped part way (fdidx may be NULL), and safe
 * to call a second time.
 */
void
server_shutdown(server_t *srv)
{
	record_t	*r;
	int		fd;

	assert(srv != NULL);

	while ((r = LIST_FIRST(&srv->rlist)) != NULL) {
		LIST_REMOVE(r, next);
		free(r);
	}

	/*
	 * fdidx is NULL if its allocation failed or after a previous
	 * shutdown; in both cases there is nothing to close.  Note that
	 * server_close_fd() may lower fdmax while we iterate.
	 */
	if (srv->fdidx != NULL) {
		for (fd = 0; fd <= srv->fdmax; fd++) {
			if (srv->fdidx[fd].valid)
				server_close_fd(srv, fd);
		}
	}

	free(srv->fdidx);
	free(srv->ctlbuf);
	free(srv->ibuf);
	free(srv->obuf);

	memset(srv, 0, sizeof(*srv));
	srv->fdmax = -1;
}

/*
 * Do one server iteration
 *
 * Blocks in select() until at least one registered descriptor is
 * readable.  Readable server sockets get a new client accepted; readable
 * client sockets have one request processed and are closed when that
 * fails.  Returns false only when select() fails for a reason other
 * than EINTR.
 */
bool
server_do(server_t *srv)
{
	fd_set	ready;
	int	nready, i;

	assert(srv != NULL);

	/* select() overwrites its set, so hand it a copy of the master set */
	ready = srv->fdset;

	nready = select(srv->fdmax + 1, &ready, NULL, NULL, NULL);
	if (nready == -1) {
		if (errno != EINTR) {
			log_err("Could not select(%d, %p). %s (%d)",
			    srv->fdmax + 1, &ready, strerror(errno), errno);

			return false;
		}

		return true;	/* interrupted by a signal: just try again */
	}

	/* stop scanning as soon as every ready descriptor has been handled */
	for (i = 0; nready > 0 && i <= srv->fdmax; i++) {
		if (FD_ISSET(i, &ready)) {
			assert(srv->fdidx[i].valid);

			if (srv->fdidx[i].server)
				server_accept_client(srv, i);
			else if (!server_process_request(srv, i))
				server_close_fd(srv, i);

			nready--;
		}
	}

	return true;
}

/*
 * Accept new client connection and register it with index
 *
 * The new descriptor inherits the control/L2CAP type of the listening
 * socket "fd".  For L2CAP clients the local address and the peer's
 * outgoing MTU (minus the PDU header, capped at srv->omtu) are recorded;
 * control clients use srv->omtu and a zero address.  Descriptors that
 * cannot be stored in the index (>= FD_SETSIZE) are closed immediately.
 */
static void
server_accept_client(server_t *srv, int fd)
{
	struct sockaddr_bt	local;
	socklen_t		slen;
	fd_idx_t	       *client;
	uint16_t		mtu;
	int			cfd;
	bool			ctl = srv->fdidx[fd].control;

	/* retry accept() while it is interrupted by signals */
	for (;;) {
		cfd = accept(fd, NULL, NULL);
		if (cfd != -1 || errno != EINTR)
			break;
	}

	if (cfd == -1) {
		log_err("Could not accept connection on %s socket. %s (%d)",
		    ctl ? "control" : "L2CAP",
		    strerror(errno), errno);

		return;
	}

	if (cfd >= FD_SETSIZE) {
		log_crit("File descriptor too large");
		close(cfd);
		return;
	}

	assert(!FD_ISSET(cfd, &srv->fdset));
	assert(!srv->fdidx[cfd].valid);

	memset(&local, 0, sizeof(local));
	mtu = srv->omtu;

	if (!ctl) {
		/* a failed lookup leaves "local" zeroed, i.e. BDADDR_ANY */
		slen = sizeof(local);
		if (getsockname(cfd, (struct sockaddr *)&local, &slen) == -1)
			log_warning("getsockname failed, using BDADDR_ANY");

		slen = sizeof(mtu);
		if (getsockopt(cfd, BTPROTO_L2CAP, SO_L2CAP_OMTU, &mtu, &slen) != -1)
			mtu -= sizeof(sdp_pdu_t);	/* room for the PDU header */
		else
			log_warning("Could not get L2CAP OMTU, using %d", mtu);
	}

	/* Add client descriptor to the index */
	if (srv->fdmax < cfd)
		srv->fdmax = cfd;

	FD_SET(cfd, &srv->fdset);

	client = &srv->fdidx[cfd];
	client->valid = true;
	client->server = false;
	client->control = ctl;
	client->priv = false;
	client->omtu = (mtu < srv->omtu) ? mtu : srv->omtu;
	client->offset = 0;
	bdaddr_copy(&client->bdaddr, &local.bt_bdaddr);

	log_debug("new %s client on fd#%d",
	    client->control ? "control" : "L2CAP", cfd);
}

/*
 * Route a received, correctly sized request PDU (header in srv->pdu,
 * parameters in srv->ibuf) to its handler.  Returns the handler's SDP
 * error code (0 on success), or INVALID_REQUEST_SYNTAX for an unknown
 * PDU id.
 */
static uint16_t
server_dispatch_request(server_t *srv, int fd)
{

	switch (srv->pdu.pid) {
	case SDP_PDU_SERVICE_SEARCH_REQUEST:
		return service_search_request(srv, fd);

	case SDP_PDU_SERVICE_ATTRIBUTE_REQUEST:
		return service_attribute_request(srv, fd);

	case SDP_PDU_SERVICE_SEARCH_ATTRIBUTE_REQUEST:
		return service_search_attribute_request(srv, fd);

#ifdef SDP_COMPAT
	case SDP_PDU_SERVICE_REGISTER_REQUEST:
		return compat_register_request(srv, fd);

	case SDP_PDU_SERVICE_CHANGE_REQUEST:
		return compat_change_request(srv, fd);
#endif

	case SDP_PDU_RECORD_INSERT_REQUEST:
		return record_insert_request(srv, fd);

	case SDP_PDU_RECORD_UPDATE_REQUEST:
		return record_update_request(srv, fd);

	case SDP_PDU_RECORD_REMOVE_REQUEST:
		return record_remove_request(srv, fd);

	default:
		return SDP_ERROR_CODE_INVALID_REQUEST_SYNTAX;
	}
}

/*
 * Process request from the client
 *
 * Receives one request (PDU header into srv->pdu, parameters into
 * srv->ibuf), updates the descriptor's privilege flag when the message
 * carries SCM_CREDS credentials, dispatches the request and sends back
 * either the handler's response from srv->obuf or an ErrorResponse.
 * Returns false when the client disconnected or an I/O error occurred,
 * in which case the caller closes the descriptor.
 */
static bool
server_process_request(server_t *srv, int fd)
{
	struct msghdr	msg;
	struct iovec	iov[2];
	struct cmsghdr	*cmsg;
	const char	*kind;
	ssize_t		len;
	uint16_t	error;

	assert(FD_ISSET(fd, &srv->fdset));
	assert(srv->fdidx[fd].valid);
	assert(!srv->fdidx[fd].server);

	kind = srv->fdidx[fd].control ? "control" : "L2CAP";

	/* scatter: header into srv->pdu, parameters into the input buffer */
	iov[0].iov_base = &srv->pdu;
	iov[0].iov_len = sizeof(srv->pdu);
	iov[1].iov_base = srv->ibuf;
	iov[1].iov_len = srv->imtu;

	memset(&msg, 0, sizeof(msg));
	msg.msg_iov = iov;
	msg.msg_iovlen = __arraycount(iov);
	msg.msg_control = srv->ctlbuf;
	msg.msg_controllen = srv->ctllen;

	for (;;) {
		len = recvmsg(fd, &msg, 0);
		if (len != -1 || errno != EINTR)
			break;
	}

	if (len == -1) {
		log_err("Could not receive SDP request on %s socket. %s (%d)",
		    kind, strerror(errno), errno);

		return false;
	}

	if (len == 0) {
		log_info("Client on %s socket has disconnected", kind);
		return false;
	}

	if ((msg.msg_flags & MSG_TRUNC) != 0)
		log_info("Truncated message on %s socket", kind);

	/* peer credentials, when present, decide if records may be modified */
	cmsg = CMSG_FIRSTHDR(&msg);
	if (cmsg != NULL
	    && cmsg->cmsg_level == SOL_SOCKET
	    && cmsg->cmsg_type == SCM_CREDS
	    && cmsg->cmsg_len >= CMSG_LEN(SOCKCREDSIZE(0)))
		srv->fdidx[fd].priv = server_auth_check(srv, CMSG_DATA(cmsg));

	srv->pdu.len = be16toh(srv->pdu.len);

	/* the datagram must be exactly one header plus its declared payload */
	if ((uint32_t)len >= sizeof(srv->pdu)
	    && (uint32_t)len == sizeof(srv->pdu) + srv->pdu.len)
		error = server_dispatch_request(srv, fd);
	else
		error = SDP_ERROR_CODE_INVALID_PDU_SIZE;

	if (error != 0) {
		/* discard any continuation state and reply with the error */
		srv->fdidx[fd].offset = 0;
		db_unselect(srv, fd);
		srv->pdu.pid = SDP_PDU_ERROR_RESPONSE;
		srv->pdu.len = sizeof(error);
		be16enc(srv->obuf, error);
		log_debug("sending ErrorResponse (error=0x%04x)", error);
	}

	/* gather: header from srv->pdu, pdu.len bytes from the output buffer */
	iov[0].iov_base = &srv->pdu;
	iov[0].iov_len = sizeof(srv->pdu);
	iov[1].iov_base = srv->obuf;
	iov[1].iov_len = srv->pdu.len;

	srv->pdu.len = htobe16(srv->pdu.len);

	memset(&msg, 0, sizeof(msg));
	msg.msg_iov = iov;
	msg.msg_iovlen = __arraycount(iov);

	for (;;) {
		len = sendmsg(fd, &msg, 0);
		if (len != -1 || errno != EINTR)
			break;
	}

	if (len == -1) {
		log_err("Could not send SDP response on %s socket. %s (%d)",
		    srv->fdidx[fd].control ? "control" : "L2CAP",
		    strerror(errno), errno);

		return false;
	}

	return true;
}

/*
 * Close descriptor and remove it from index
 *
 * Releases the records selected by and owned by this descriptor, closes
 * it, clears it from the select set and index, and recomputes fdmax if
 * it was the highest registered descriptor.  fdmax becomes -1 (the value
 * server_init() starts from) when no registered descriptor remains.
 */
static void
server_close_fd(server_t *srv, int fd)
{

	assert(FD_ISSET(fd, &srv->fdset));
	assert(srv->fdidx[fd].valid);

	db_unselect(srv, fd);	/* release selected records */
	db_release(srv, fd);	/* expire owned records */

	close(fd);
	FD_CLR(fd, &srv->fdset);
	srv->fdidx[fd].valid = false;

	log_debug("client on fd#%d closed", fd);

	/*
	 * Scan down for the new highest valid descriptor.  Scanning through
	 * 0 (rather than stopping at it) avoids reporting fd 0 as the
	 * maximum when it is not actually registered.
	 */
	if (fd == srv->fdmax) {
		while (fd >= 0 && !srv->fdidx[fd].valid)
			fd--;

		srv->fdmax = fd;
	}
}

/*
 * check credentials, return true when permitted to modify service records
 *
 * "data" points at a struct sockcred received as SCM_CREDS control data.
 * Root (real or effective uid 0) is always permitted.  Otherwise the
 * peer must be in the group named by srv->sgroup, either as its real or
 * effective gid or as one of its first MAX_GROUPS supplementary groups.
 * If that group name cannot be resolved it is logged and srv->sgroup is
 * cleared, so later checks deny non-root peers without another lookup.
 */
static bool
server_auth_check(server_t *srv, void *data)
{
	const struct sockcred	*cred = data;
	struct group		*grp;
	int			n, ngroups;

	if (cred == NULL)
		return false;

	if (cred->sc_uid == 0 || cred->sc_euid == 0)
		return true;

	if (srv->sgroup == NULL)
		return false;

	grp = getgrnam(srv->sgroup);
	if (grp == NULL) {
		log_err("No gid for group '%s'", srv->sgroup);
		srv->sgroup = NULL;
		return false;
	}

	if (cred->sc_gid == grp->gr_gid || cred->sc_egid == grp->gr_gid)
		return true;

	/*
	 * The receive buffer only has room for MAX_GROUPS groups, so ignore
	 * any beyond that.  Use a local count rather than rewriting the
	 * caller's control-message buffer.
	 */
	ngroups = cred->sc_ngroups;
	if (ngroups > MAX_GROUPS) {
		log_info("Credentials truncated, lost %d groups",
		    ngroups - MAX_GROUPS);

		ngroups = MAX_GROUPS;
	}

	for (n = 0; n < ngroups; n++) {
		if (cred->sc_groups[n] == grp->gr_gid)
			return true;
	}

	return false;
}