/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under both the BSD-style license (found in the
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
* in the COPYING file in the root directory of this source tree).
* You may select, at your option, one of the above-listed licenses.
*/

#include <assert.h>
#include <getopt.h>
#include <stdio.h>
#include <string.h>

#include "config.h"
#include "data.h"
#include "method.h"

static int g_max_name_len = 0;

/**
* Validate one name and record its length.
*
* A name is bad when it is NULL or contains a comma. Every non-NULL name,
* good or bad, raises g_max_name_len to its length if it is longer than any
* name seen so far.
*
* @return 1 if the name is bad, 0 otherwise.
*/
static int is_name_bad(char const* name) {
   if (name == NULL)
       return 1;

   /* Track the longest name before deciding whether it is acceptable. */
   int const name_len = (int)strlen(name);
   g_max_name_len = name_len > g_max_name_len ? name_len : g_max_name_len;

   return strchr(name, ',') != NULL;
}

/** Check if any of the names contain a comma. */
static int are_names_bad() {
   for (size_t method = 0; methods[method] != NULL; ++method)
       if (is_name_bad(methods[method]->name)) {
           fprintf(stderr, "method name %s is bad\n", methods[method]->name);
           return 1;
       }
   for (size_t datum = 0; data[datum] != NULL; ++datum)
       if (is_name_bad(data[datum]->name)) {
           fprintf(stderr, "data name %s is bad\n", data[datum]->name);
           return 1;
       }
   for (size_t config = 0; configs[config] != NULL; ++config)
       if (is_name_bad(configs[config]->name)) {
           fprintf(stderr, "config name %s is bad\n", configs[config]->name);
           return 1;
       }
   return 0;
}

/**
* Option parsing using getopt.
* When you add a new option update: long_options, long_extras, and
* short_options.
*/

/** Option variables filled by parse_args. */
static char const* g_output = NULL;
static char const* g_diff = NULL;
static char const* g_cache = NULL;
static char const* g_zstdcli = NULL;
static char const* g_config = NULL;
static char const* g_data = NULL;
static char const* g_method = NULL;

/** How an option is treated when checking and describing the arguments. */
typedef enum {
   required_option,
   optional_option,
   help_option,
} option_type;

/**
* Extra state that we need to keep per-option that we can't store in getopt.
*/
struct option_extra {
   int id; /**< The short option name, used as an id. */
   char const* help; /**< The help message. */
   option_type opt_type; /**< The option type: required, optional, or help. */
   char const** value; /**< The value to set or NULL if no_argument. */
};

/**
* The options.
*
* Ids below 128 double as the short option character; ids from 128 up belong
* to options that only have a long form. getopt_long() requires the array to
* end with an all-zero element, so the final entry is a terminator and not a
* real option.
*/
static struct option long_options[] = {
   {"cache", required_argument, NULL, 'c'},
   {"output", required_argument, NULL, 'o'},
   {"zstd", required_argument, NULL, 'z'},
   {"config", required_argument, NULL, 128},
   {"data", required_argument, NULL, 129},
   {"method", required_argument, NULL, 130},
   {"diff", required_argument, NULL, 'd'},
   {"help", no_argument, NULL, 'h'},
   {NULL, 0, NULL, 0},
};

/**
* The extra info for the options. Must be in the same order as the options,
* with one entry per real option (the terminator has no entry).
*/
static struct option_extra long_extras[] = {
   {'c', "the cache directory", required_option, &g_cache},
   {'o', "write the results here", required_option, &g_output},
   {'z', "zstd cli tool", required_option, &g_zstdcli},
   {128, "use this config", optional_option, &g_config},
   {129, "use this data", optional_option, &g_data},
   {130, "use this method", optional_option, &g_method},
   {'d', "compare the results to this file", optional_option, &g_diff},
   {'h', "display this message", help_option, NULL},
};

/** The number of real options, i.e. excluding the long_options terminator. */
static size_t const nargs = sizeof(long_extras) / sizeof(long_extras[0]);

/*
* Compile-time check: long_options must hold exactly one element more than
* long_extras (the terminator). A mismatch gives the array a negative size.
*/
typedef char long_options_must_match_long_extras
       [(sizeof(long_options) / sizeof(long_options[0])
         == sizeof(long_extras) / sizeof(long_extras[0]) + 1)
                ? 1
                : -1];

/** The short options. Must correspond to the options. */
static char const short_options[] = "c:d:ho:z:";

/**
* Return the help string for the option type.
*
* @return "[required]", "[optional]", or "" for the help option. An unknown
*         value trips an assert and yields NULL when asserts are disabled.
*/
static char const* required_message(option_type opt_type) {
   switch (opt_type) {
       case required_option:
           return "[required]";
       case optional_option:
           return "[optional]";
       case help_option:
           return "";
       default:
           assert(0);
           return NULL;
   }
}

/**
* Print the help for the program to stderr: a title line followed by one line
* per option.
*/
static void print_help(void) {
   fprintf(stderr, "regression test runner\n");
   for (size_t i = 0; i < nargs; ++i) {
       char const* const kind = required_message(long_extras[i].opt_type);
       if (long_options[i].val < 128) {
           /* Long / short  - help [option type] */
           fprintf(
               stderr,
               "--%s / -%c \t- %s %s\n",
               long_options[i].name,
               long_options[i].val,
               long_extras[i].help,
               kind);
       } else {
           /* Long only     - help [option type] */
           fprintf(
               stderr,
               "--%s      \t- %s %s\n",
               long_options[i].name,
               long_extras[i].help,
               kind);
       }
   }
}

/** Parse the arguments. Return 0 on success. Print help on failure. */
static int parse_args(int argc, char** argv) {
   int option_index = 0;
   int c;

   while (1) {
       c = getopt_long(argc, argv, short_options, long_options, &option_index);
       if (c == -1)
           break;

       int found = 0;
       for (size_t i = 0; i < nargs; ++i) {
           if (c == long_extras[i].id && long_extras[i].value != NULL) {
               *long_extras[i].value = optarg;
               found = 1;
               break;
           }
       }
       if (found)
           continue;

       switch (c) {
           case 'h':
           case '?':
           default:
               print_help();
               return 1;
       }
   }

   int bad = 0;
   for (size_t i = 0; i < nargs; ++i) {
       if (long_extras[i].opt_type != required_option)
           continue;
       if (long_extras[i].value == NULL)
           continue;
       if (*long_extras[i].value != NULL)
           continue;
       fprintf(
           stderr,
           "--%s is a required argument but is not set\n",
           long_options[i].name);
       bad = 1;
   }
   if (bad) {
       fprintf(stderr, "\n");
       print_help();
       return 1;
   }

   return 0;
}

/**
* Helper macro to print to stderr and a file.
*
* Takes a FILE* followed by a printf-style format and arguments, and issues
* the same fprintf() to the file and then to stderr. The format and arguments
* are expanded twice, so any argument with side effects is evaluated twice.
*/
#define tprintf(file, ...)            \
   do {                              \
       fprintf(file, __VA_ARGS__);   \
       fprintf(stderr, __VA_ARGS__); \
   } while (0)
/**
* Helper macro to flush stderr and a file.
*
* Flushes the file first, then stderr. The return values of fflush() are
* ignored.
*/
#define tflush(file)    \
   do {                \
       fflush(file);   \
       fflush(stderr); \
   } while (0)

/**
* Number of spaces needed to pad name out to g_max_name_len columns.
*
* Never negative: a name longer than g_max_name_len gets no padding. A
* negative width passed to "%*s" would be taken as left-justify with the
* absolute width, printing spaces that were not intended.
*/
static int name_padding(char const* name) {
   size_t const len = strlen(name);
   size_t const max_len = g_max_name_len > 0 ? (size_t)g_max_name_len : 0;
   return len < max_len ? (int)(max_len - len) : 0;
}

/**
* Print the three name columns of one results row to results and stderr.
*
* Each name is followed by ", " and then by enough spaces to bring the column
* up to g_max_name_len characters. No newline is printed; the caller appends
* the rest of the row.
*
* @param results      The results file; the row is also echoed to stderr.
* @param data_name    Text for the data column.
* @param config_name  Text for the config column.
* @param method_name  Text for the method column.
*/
void tprint_names(
   FILE* results,
   char const* data_name,
   char const* config_name,
   char const* method_name) {
   int const data_padding = name_padding(data_name);
   int const config_padding = name_padding(config_name);
   int const method_padding = name_padding(method_name);

   tprintf(
       results,
       "%s, %*s%s, %*s%s, %*s",
       data_name,
       data_padding,
       "",
       config_name,
       config_padding,
       "",
       method_name,
       method_padding,
       "");
}

/**
* Run all the regression tests and record the results table to results and
* stderr progressively.
*/
static int run_all(FILE* results) {
   tprint_names(results, "Data", "Config", "Method");
   tprintf(results, "Total compressed size\n");
   for (size_t method = 0; methods[method] != NULL; ++method) {
       if (g_method != NULL && strcmp(methods[method]->name, g_method))
           continue;
       for (size_t datum = 0; data[datum] != NULL; ++datum) {
           if (g_data != NULL && strcmp(data[datum]->name, g_data))
               continue;
           /* Create the state common to all configs */
           method_state_t* state = methods[method]->create(data[datum]);
           for (size_t config = 0; configs[config] != NULL; ++config) {
               if (g_config != NULL && strcmp(configs[config]->name, g_config))
                   continue;
               if (config_skip_data(configs[config], data[datum]))
                   continue;
               /* Print the result for the (method, data, config) tuple. */
               result_t const result =
                   methods[method]->compress(state, configs[config]);
               if (result_is_skip(result))
                   continue;
               tprint_names(
                   results,
                   data[datum]->name,
                   configs[config]->name,
                   methods[method]->name);
               if (result_is_error(result)) {
                   tprintf(results, "%s\n", result_get_error_string(result));
               } else {
                   tprintf(
                       results,
                       "%llu\n",
                       (unsigned long long)result_get_data(result).total_size);
               }
               tflush(results);
           }
           methods[method]->destroy(state);
       }
   }
   return 0;
}

/**
* Compare the new results file against a previous results file.
*
* Both files are read with data_buffer_read() and compared with
* data_buffer_compare(); the outcome is reported on stderr. Both buffers are
* freed on every path.
*
* @param actual_file    Path of the results just produced.
* @param expected_file  Path of the previous results.
* @return 0 if the files match; 1 if either file could not be read; otherwise
*         the non-zero value from data_buffer_compare().
*/
static int diff_results(char const* actual_file, char const* expected_file) {
   data_buffer_t const actual = data_buffer_read(actual_file);
   data_buffer_t const expected = data_buffer_read(expected_file);
   int status = 1;

   if (actual.data == NULL) {
       fprintf(stderr, "failed to open results '%s' for diff\n", actual_file);
   } else if (expected.data == NULL) {
       fprintf(
           stderr,
           "failed to open previous results '%s' for diff\n",
           expected_file);
   } else {
       status = data_buffer_compare(actual, expected);
       if (status == 0) {
           fprintf(stderr, "actual results match expected results\n");
       } else {
           fprintf(
               stderr,
               "actual results '%s' does not match expected results '%s'\n",
               actual_file,
               expected_file);
       }
   }

   data_buffer_free(actual);
   data_buffer_free(expected);
   return status;
}

/**
* Entry point: parse the arguments, validate the names, run the regression
* tests into the output file, and optionally diff against previous results.
*
* @return 0 on success, non-zero on any failure.
*/
int main(int argc, char** argv) {
   /* Parse args and validate modules. */
   int ret = parse_args(argc, argv);
   if (ret != 0)
       return ret;

   if (are_names_bad())
       return 1;

   /* Initialize modules. */
   method_set_zstdcli(g_zstdcli);
   ret = data_init(g_cache);
   if (ret != 0) {
       /* NOTE(review): assumes data_init() returns an errno value. */
       fprintf(stderr, "data_init() failed with error=%s\n", strerror(ret));
       return 1;
   }

   /* Run the regression tests. */
   ret = 1;
   FILE* results = fopen(g_output, "w");
   if (results == NULL) {
       fprintf(stderr, "Failed to open the output file\n");
       goto out;
   }
   ret = run_all(results);
   /*
    * fclose() flushes buffered output. If it fails the results file may be
    * incomplete, so the run must not be reported as a success.
    */
   if (fclose(results) != 0) {
       fprintf(stderr, "Failed to write the output file\n");
       ret = 1;
   }

   if (ret != 0)
       goto out;

   if (g_diff != NULL) {
       /* Diff the new results with the previous results. */
       ret = diff_results(g_output, g_diff);
   }

out:
   data_finish();
   return ret;
}