/*
* Copyright (c) 2015 The TCPDUMP project
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
*    notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
*    notice, this list of conditions and the following disclaimer in the
*    documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
* COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
* ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*
* Initial contribution by Andrew Darqui ([email protected]).
*/

#include <sys/cdefs.h>
#ifndef lint
__RCSID("$NetBSD: print-resp.c,v 1.6 2024/09/02 16:15:32 christos Exp $");
#endif

/* \summary: REdis Serialization Protocol (RESP) printer */

#include <config.h>

#include "netdissect-stdinc.h"
#include "netdissect.h"
#include <limits.h>

#include "extract.h"


/*
* For information regarding RESP, see: https://redis.io/topics/protocol
*/

/*
* RESP type markers.  resp_parse() switches on the first byte of a
* message and compares it against these values; any other byte is
* handed to the inline-command printer.
*/
#define RESP_SIMPLE_STRING    '+'
#define RESP_ERROR            '-'
#define RESP_INTEGER          ':'
#define RESP_BULK_STRING      '$'
#define RESP_ARRAY            '*'

/*
* Fixed annotations printed in place of a value when a bulk string or
* array length is zero, null (-1), too large, negative, or malformed.
* The 'ndo' macro argument does not appear in the expansion; ND_PRINT()
* presumably picks up the caller's 'ndo' variable implicitly - confirm
* against netdissect.h.
*/
#define resp_print_empty(ndo)            ND_PRINT(" empty")
#define resp_print_null(ndo)             ND_PRINT(" null")
#define resp_print_length_too_large(ndo) ND_PRINT(" length too large")
#define resp_print_length_negative(ndo)  ND_PRINT(" length negative and not -1")
#define resp_print_invalid(ndo)          ND_PRINT(" invalid")

/*
* Forward declarations of the file-local helpers.  Each printer takes
* the dissector options, a pointer to the data, and the number of bytes
* available.  resp_get_length() additionally returns, through its last
* argument, a pointer to where it stopped parsing.
*/
static int resp_parse(netdissect_options *, const u_char *, int);
static int resp_print_string_error_integer(netdissect_options *, const u_char *, int);
static int resp_print_simple_string(netdissect_options *, const u_char *, int);
static int resp_print_integer(netdissect_options *, const u_char *, int);
static int resp_print_error(netdissect_options *, const u_char *, int);
static int resp_print_bulk_string(netdissect_options *, const u_char *, int);
static int resp_print_bulk_array(netdissect_options *, const u_char *, int);
static int resp_print_inline(netdissect_options *, const u_char *, int);
static int resp_get_length(netdissect_options *, const u_char *, int, const u_char **);

/*
* LCHECK2
* Jump to the enclosing function's 'trunc' label when '_tot_len' (the
* number of bytes still available) is smaller than '_len' (the number
* of bytes about to be used).
*
* Both arguments are parenthesized so that expressions such as a
* conditional can be passed safely, and the body is wrapped in
* do { } while (0) so that "LCHECK2(a, b);" is a single statement that
* can be used in an unbraced if/else.
*/
#define LCHECK2(_tot_len, _len)         \
   do {                                \
       if ((_tot_len) < (_len))        \
           goto trunc;                 \
   } while (0)

/*
* LCHECK
* As LCHECK2, requiring that at least one byte is still available.
*/
#define LCHECK(_tot_len) LCHECK2(_tot_len, 1)

/*
* FIND_CRLF:
* Step '_ptr' forward one byte at a time, decrementing '_len' in step,
* until '_ptr' points at a '\r' that is immediately followed by '\n'.
* Before each comparison both the remaining length (at least two
* bytes) and the captured data (ND_TCHECK_2) are checked; if either
* check fails, control jumps to the caller's 'trunc' label.
* The CRLF itself is left unconsumed.
*/
#define FIND_CRLF(_ptr, _len)                   \
   for (;;) {                                  \
       LCHECK2(_len, 2);                       \
       ND_TCHECK_2(_ptr);                      \
       if (GET_U_1(_ptr) != '\r' ||            \
           GET_U_1((_ptr) + 1) != '\n') {      \
           /* not at a CRLF yet: advance */    \
           _ptr++;                             \
           _len--;                             \
           continue;                           \
       }                                       \
       break;                                  \
   }

/*
* CONSUME_CRLF
* Consume a CRLF that we've just found: advance '_ptr' by two bytes and
* reduce '_len' by two.  No checks are made; the caller must already
* have located the CRLF.
*
* Wrapped in do { } while (0) so that both updates stay together when
* the macro is used as the body of an unbraced if/else or loop.
*/
#define CONSUME_CRLF(_ptr, _len) \
   do {                         \
       (_ptr) += 2;             \
       (_len) -= 2;             \
   } while (0)

/*
* FIND_CR_OR_LF
* Step '_ptr' forward, decrementing '_len' in step, until it points at
* a '\r' or a '\n'.  If '_len' reaches zero first, control jumps to the
* caller's 'trunc' label.  The byte fetch itself goes through GET_U_1();
* running past the captured data is presumably caught there - confirm
* against netdissect.h.
* The terminator is left unconsumed.
*/
#define FIND_CR_OR_LF(_ptr, _len)           \
   for (;; _ptr++, _len--) {               \
       LCHECK(_len);                       \
       if (GET_U_1(_ptr) == '\r' ||        \
           GET_U_1(_ptr) == '\n')          \
           break;                          \
   }

/*
* CONSUME_CR_OR_LF
* Consume all consecutive \r and \n bytes starting at '_ptr',
* decrementing '_len' for each one.
*
* The scan stops at the first byte that is neither \r nor \n, or when
* the data runs out ('_len' is zero or ND_TTEST_1() fails).  Running
* out of data after at least one \r or \n has been consumed is fine;
* running out before any has been seen jumps to the caller's 'trunc'
* label.  A leading byte that is present but is not \r or \n simply
* ends the scan without consuming anything.
*/
#define CONSUME_CR_OR_LF(_ptr, _len)                      \
   {                                                     \
       int _eol_seen = 0;                                \
       /* eat terminators while data remains */          \
       while (_len != 0 && ND_TTEST_1(_ptr) &&           \
              (GET_U_1(_ptr) == '\r' ||                  \
               GET_U_1(_ptr) == '\n')) {                 \
           _eol_seen = 1;                                \
           _ptr++;                                       \
           _len--;                                       \
       }                                                 \
       /* out of data without seeing any terminator? */  \
       if (!_eol_seen &&                                 \
           (_len == 0 || !ND_TTEST_1(_ptr)))             \
           goto trunc;                                   \
   }

/*
* SKIP_OPCODE
* Skip over the opcode character: advance '_ptr' by one byte and
* reduce '_tot_len' by one.
* The opcode has already been fetched, so we know it's there, and don't
* need to do any checks.
*
* Wrapped in do { } while (0) so that both updates stay together when
* the macro is used as the body of an unbraced if/else or loop.
*/
#define SKIP_OPCODE(_ptr, _tot_len) \
   do {                            \
       (_ptr)++;                   \
       (_tot_len)--;               \
   } while (0)

/*
* GET_LENGTH
* Get a bulk string or array length.
* Calls resp_get_length() on the '_tot_len' bytes at '_ptr' and stores
* its result (a length, or a negative status code) in '_len'.  '_ptr'
* is then moved to the position resp_get_length() reported through its
* end pointer, and '_tot_len' is reduced by the distance moved.
*/
#define GET_LENGTH(_ndo, _tot_len, _ptr, _len)                        \
   {                                                                 \
       const u_char *_after_len;                                     \
       _len = resp_get_length(_ndo, _ptr, _tot_len, &_after_len);    \
       /* account for what was scanned before moving _ptr */         \
       _tot_len -= (_after_len - _ptr);                              \
       _ptr = _after_len;                                            \
   }

/*
* TEST_RET_LEN
* If 'rl' is < 0, jump to the caller's 'trunc' label (which returns -1
* and 'bubbles up' to printing the truncation indicator).  Otherwise,
* return 'rl' from the calling function.
*
* Wrapped in do { } while (0) so that "TEST_RET_LEN(x);" is a single
* statement and can be used inside an unbraced if/else.
*/
#define TEST_RET_LEN(rl)        \
   do {                        \
       if ((rl) < 0)           \
           goto trunc;         \
       return (rl);            \
   } while (0)

/*
* TEST_RET_LEN_NORETURN
* If 'rl' is < 0, jump to the caller's 'trunc' label (which returns -1
* and 'bubbles up' to printing the truncation indicator).  Otherwise,
* continue onward.
*
* Wrapped in do { } while (0) so that a following 'else' cannot bind to
* the macro's own 'if', and so that the caller's trailing ';' does not
* terminate an enclosing if/else early.
*/
#define TEST_RET_LEN_NORETURN(rl)   \
   do {                            \
       if ((rl) < 0)               \
           goto trunc;             \
   } while (0)

/*
* RESP_PRINT_SEGMENT
* Prints a segment as a space, an opening double quote, the '_len'
* bytes at '_bp' (via nd_printn(), bounded by the snapshot end), and a
* closing double quote.  No newline is printed.
* If nd_printn() returns non-zero, control jumps to the caller's
* 'trunc' label before the closing quote is printed.
* Assumes the data has already been verified as present.
*
* Wrapped in do { } while (0) so that the three steps stay together
* when the macro is used as the body of an unbraced if/else or loop.
*/
#define RESP_PRINT_SEGMENT(_ndo, _bp, _len)                    \
   do {                                                       \
       ND_PRINT(" \"");                                       \
       if (nd_printn(_ndo, _bp, _len, (_ndo)->ndo_snapend))   \
           goto trunc;                                        \
       fn_print_char(_ndo, '"');                              \
   } while (0)

/*
* resp_print
* Entry point of the RESP printer.  Prints the ": RESP" tag and then
* every message found in the 'length' bytes at 'bp'.
*
* Several messages may be pipelined back to back in one buffer, for
* example "PING\r\nPING\r\nPING\r\n", so resp_parse() is called
* repeatedly, each time advancing past the bytes it reports as
* consumed, until nothing is left.  A negative result from resp_parse()
* stops processing and prints the truncation indicator.
*/
void
resp_print(netdissect_options *ndo, const u_char *bp, u_int length)
{
   ndo->ndo_protocol = "resp";
   ND_PRINT(": RESP");

   for (; length > 0; ) {
       int consumed = resp_parse(ndo, bp, length);

       if (consumed < 0) {
           /* truncated or unparsable: give up on the rest */
           nd_print_trunc(ndo);
           return;
       }
       bp += consumed;
       length -= consumed;
   }
}

/*
* resp_parse
* Print one RESP message starting at 'bp', with 'length' bytes
* available.  The first byte selects the printer; any byte that is not
* a known type marker is treated as the start of an inline command.
* 'bp' is passed on still pointing at the type marker, so each printer
* skips it itself.
*
* Returns the number of bytes the selected printer consumed, or -1 if
* no byte is available or the printer reported a negative result.
* Every failure, including an invalid packet, is reported as -1,
* because parsing cannot continue past it.
*/
static int
resp_parse(netdissect_options *ndo, const u_char *bp, int length)
{
   int consumed;

   if (length < 1)
       return (-1);

   switch (GET_U_1(bp)) {
   case RESP_SIMPLE_STRING:
       consumed = resp_print_simple_string(ndo, bp, length);
       break;
   case RESP_INTEGER:
       consumed = resp_print_integer(ndo, bp, length);
       break;
   case RESP_ERROR:
       consumed = resp_print_error(ndo, bp, length);
       break;
   case RESP_BULK_STRING:
       consumed = resp_print_bulk_string(ndo, bp, length);
       break;
   case RESP_ARRAY:
       consumed = resp_print_bulk_array(ndo, bp, length);
       break;
   default:
       consumed = resp_print_inline(ndo, bp, length);
       break;
   }

   return (consumed < 0) ? (-1) : consumed;
}

/*
* resp_print_simple_string
* Print a simple string ('+' message).  Delegates to the printer shared
* with errors and integers and returns its result unchanged.
*/
static int
resp_print_simple_string(netdissect_options *ndo, const u_char *bp, int length)
{
   int consumed = resp_print_string_error_integer(ndo, bp, length);

   return consumed;
}

/*
* resp_print_integer
* Print an integer (':' message).  Delegates to the printer shared with
* simple strings and errors and returns its result unchanged.
*/
static int
resp_print_integer(netdissect_options *ndo, const u_char *bp, int length)
{
   int consumed = resp_print_string_error_integer(ndo, bp, length);

   return consumed;
}

/*
* resp_print_error
* Print an error ('-' message).  Delegates to the printer shared with
* simple strings and integers and returns its result unchanged.
*/
static int
resp_print_error(netdissect_options *ndo, const u_char *bp, int length)
{
   int consumed = resp_print_string_error_integer(ndo, bp, length);

   return consumed;
}

/*
* resp_print_string_error_integer
* Shared printer for the three CRLF-terminated one-line types:
*   +OK\r\n
*   -ERR ...\r\n
*   :02912309\r\n
* 'bp' points at the type marker and 'length' is the number of bytes
* available.  The text between the marker and the CRLF is printed in
* double quotes.
*
* Returns the number of bytes consumed (marker, text and CRLF), or -1
* if the CRLF cannot be found within the available data or printing
* fails.
*/
static int
resp_print_string_error_integer(netdissect_options *ndo, const u_char *bp, int length)
{
   const u_char *eol;
   int remaining = length;
   int text_len;
   int consumed;

   /* Step over the type marker; bp now points at the first text byte. */
   bp++;
   remaining--;

   /* Scan a second cursor forward to the terminating CRLF. */
   eol = bp;
   FIND_CRLF(eol, remaining);

   /*
    * bp is already past the marker, so the distance from bp to eol is
    * exactly the text that precedes the CRLF.
    */
   text_len = ND_BYTES_BETWEEN(bp, eol);
   RESP_PRINT_SEGMENT(ndo, bp, text_len);

   /* one marker byte, the text, and the two CRLF bytes */
   consumed = 1 + text_len + 2;
   if (consumed >= 0)
       return consumed;

trunc:
   return (-1);
}

/*
* resp_print_bulk_string
* Print a bulk string: the '$' marker, a decimal length terminated by
* CRLF, that many bytes of data, and a closing CRLF.
* 'bp' points at the marker and 'length' is the number of bytes
* available.
*
* The length field is read with GET_LENGTH.  A non-negative value is
* the byte count; negative values are status codes:
*   -1  null string             (prints " null")
*   -2  truncated               (returns -1)
*   -3  length too large        (prints " length too large")
*   -4  negative and not -1     (prints " length negative and not -1")
*   any other negative value    (prints " invalid")
*
* Returns the number of bytes consumed, or -1 on truncation.
*/
static int
resp_print_bulk_string(netdissect_options *ndo, const u_char *bp, int length)
{
   int remaining = length;
   int nbytes;

   /* Step over the marker, then read the length line. */
   SKIP_OPCODE(bp, remaining);
   GET_LENGTH(ndo, remaining, bp, nbytes);

   if (nbytes < 0) {
       /* null, truncated, or invalid for some reason */
       if (nbytes == -2) {
           goto trunc;
       } else if (nbytes == -1) {
           resp_print_null(ndo);
       } else if (nbytes == -3) {
           resp_print_length_too_large(ndo);
       } else if (nbytes == -4) {
           resp_print_length_negative(ndo);
       } else {
           resp_print_invalid(ndo);
       }
       return (length - remaining);
   }

   if (nbytes == 0) {
       resp_print_empty(ndo);
   } else {
       /* The payload must fit both the length budget and the capture. */
       LCHECK2(remaining, nbytes);
       ND_TCHECK_LEN(bp, nbytes);
       RESP_PRINT_SEGMENT(ndo, bp, nbytes);
       bp += nbytes;
       remaining -= nbytes;
   }

   /*
    * Locate the CRLF that ends the string and step past it.
    * XXX - report an error if the CRLF isn't immediately after
    * the item?
    */
   FIND_CRLF(bp, remaining);
   CONSUME_CRLF(bp, remaining);

   return (length - remaining);

trunc:
   return (-1);
}

/*
* resp_print_bulk_array
* Print an array: the '*' marker, a decimal element count terminated
* by CRLF, and then that many nested RESP values, each printed by
* calling resp_parse() on the remaining data.
* 'bp' points at the marker and 'length' is the number of bytes
* available.
*
* The count is read with GET_LENGTH.  A positive value is the number of
* elements; other values are handled as follows:
*    0  empty array             (prints " empty")
*   -1  null array              (prints " null")
*   -2  truncated               (returns -1)
*   -3  length too large        (prints " length too large")
*   -4  negative and not -1     (prints " length negative and not -1")
*   any other negative value    (prints " invalid")
*
* Returns the number of bytes consumed, or -1 on truncation or when a
* nested value fails to parse.
*/
static int
resp_print_bulk_array(netdissect_options *ndo, const u_char *bp, int length) {
   /*
    * Kept as a signed int, like 'length', the other printers in this
    * file, and the 'length' parameter of resp_parse(), so that the
    * call below and the return expression involve no signed/unsigned
    * conversions.
    */
   int length_cur = length;
   int array_len, i, ret_len;

   /* bp points to the op; skip it */
   SKIP_OPCODE(bp, length_cur);

   /* array length, followed by CRLF */
   GET_LENGTH(ndo, length_cur, bp, array_len);

   if (array_len > 0) {
       /* non empty array: print each element in turn */
       for (i = 0; i < array_len; i++) {
           ret_len = resp_parse(ndo, bp, length_cur);

           TEST_RET_LEN_NORETURN(ret_len);

           bp += ret_len;
           length_cur -= ret_len;
       }
   } else {
       /* empty, null, truncated, or invalid */
       switch(array_len) {
           case 0:     resp_print_empty(ndo);            break;
           case (-1):  resp_print_null(ndo);             break;
           case (-2):  goto trunc;
           case (-3):  resp_print_length_too_large(ndo); break;
           case (-4):  resp_print_length_negative(ndo);  break;
           default:    resp_print_invalid(ndo);          break;
       }
   }

   return (length - length_cur);

trunc:
   return (-1);
}

/*
* resp_print_inline
* Print one inline command.  Inline commands are plain text followed by
* \r, \n, or any run of the two; they exist so that commands can be
* typed from telnet/nc sessions.
* 'bp' points at the first byte and 'length' is the number of bytes
* available.
*
* Leading \r and \n bytes are skipped, the text up to the next \r or
* \n is printed in double quotes, and the \r and \n bytes that follow
* are consumed.
*
* Returns the number of bytes consumed, or -1 if the data ends before
* a complete line has been seen.
*/
static int
resp_print_inline(netdissect_options *ndo, const u_char *bp, int length)
{
   const u_char *eol;
   int remaining = length;
   int text_len;

   /* Drop any \r or \n bytes that precede the text. */
   CONSUME_CR_OR_LF(bp, remaining);

   /* Walk a second cursor forward to the first \r or \n. */
   eol = bp;
   FIND_CR_OR_LF(eol, remaining);

   /* Everything between bp and eol is the line text; print it. */
   text_len = ND_BYTES_BETWEEN(bp, eol);
   RESP_PRINT_SEGMENT(ndo, bp, text_len);

   /* Swallow the line terminator(s) that follow the text. */
   CONSUME_CR_OR_LF(eol, remaining);

   /* What has been taken off the budget is what was processed. */
   return (length - remaining);

trunc:
   return (-1);
}

/*
* resp_get_length
* Parse the length line of a bulk string or array: an optional '-',
* one or more decimal digits, then \r\n.  'bp' points at the first
* byte after the type marker and 'len' is the number of bytes
* available.
*
* On return '*endp' points to where parsing stopped: just past the
* \r\n on success, at the point where data ran out on truncation, or
* one byte past the offending byte for an invalid line.
*
* Returns:
*   >= 0  the parsed length
*   -1    the line was exactly "-1" (null)
*   -2    the data ran out before the line was complete
*   -3    the value does not fit in an int
*   -4    a negative value other than -1
*   -5    the line is malformed (no digits, or no \r\n after them)
*/
static int
resp_get_length(netdissect_options *ndo, const u_char *bp, int len, const u_char **endp)
{
   int value = 0;
   int negative = 0;
   int overflow = 0;
   int have_digit = 0;
   int status = -2;    /* "truncated" until shown otherwise */
   u_char ch;

   if (len == 0)
       goto done;

   if (GET_U_1(bp) == '-') {
       negative = 1;
       bp++;
       len--;
   }

   /* Accumulate digits, stopping at the first non-digit byte. */
   for (; len != 0; bp++, len--) {
       ch = GET_U_1(bp);
       if (ch < '0' || ch > '9')
           break;
       have_digit = 1;
       ch -= '0';
       if (value > (INT_MAX / 10)) {
           /* Multiplying by 10 would overflow an int. */
           overflow = 1;
           continue;
       }
       value *= 10;
       if (value == ((INT_MAX / 10) * 10) && ch > (INT_MAX % 10)) {
           /* Adding this digit would overflow an int. */
           overflow = 1;
       } else {
           value += ch;
       }
   }
   if (len == 0)
       goto done;      /* ran out of data inside the number */

   /*
    * bp points at a non-digit byte.  There must have been at least
    * one digit, and this byte must be \r followed by \n.
    */
   status = -5;
   if (!have_digit || GET_U_1(bp) != '\r') {
       bp++;
       goto done;
   }
   bp++;
   len--;
   if (len == 0) {
       status = -2;
       goto done;
   }
   if (GET_U_1(bp) != '\n') {
       bp++;
       goto done;
   }
   bp++;

   if (negative) {
       /* -1 means "null", anything else is invalid */
       status = (overflow || value != 1) ? (-4) : (-1);
   } else {
       status = overflow ? (-3) : value;
   }

done:
   *endp = bp;
   return (status);
}