/* expr.c
  Copyright (C) 2005,2006,2007 Eugene K. Ressler, Jr.

This file is part of Sketch, a small, simple system for making
3d drawings with LaTeX and the PSTricks or TikZ package.

Sketch is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 3, or (at your option)
any later version.

Sketch is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with Sketch; see the file COPYING.txt.  If not, see
http://www.gnu.org/copyleft */

#include <stdio.h>
#include <math.h>
#include "expr.h"
#include "error.h"

// printf conversion used for every FLOAT this file prints: fixed-point
// notation with three digits after the decimal point.  It is spliced into
// larger format strings by string-literal concatenation, e.g. "(" F "," F ")".
#define F "%.3f"

char *expr_val_type_str[] = {
 "float",
 "point",
 "vector",
 "transform",
};

// set expression value to given type and value
// Store float `val` in r and tag r as E_FLOAT.
void
set_float (EXPR_VAL * r, FLOAT val)
{
  r->val.flt = val;
  r->tag = E_FLOAT;
}

// Write the float held in val to f using the F format.
void
print_float (FILE * f, EXPR_VAL * val)
{
  FLOAT x = val->val.flt;

  fprintf (f, F, x);
}

// Store a copy of point `val` in r and tag r as E_POINT.
void
set_point (EXPR_VAL * r, POINT_3D val)
{
  copy_pt_3d (r->val.pt, val);
  r->tag = E_POINT;
}

// Write the point held in val to f as "(x,y,z)".
void
print_point (FILE * f, EXPR_VAL * val)
{
  fprintf (f, "(" F "," F "," F ")",
           val->val.pt[X], val->val.pt[Y], val->val.pt[Z]);
}

// Store a copy of vector `val` in r and tag r as E_VECTOR.
void
set_vector (EXPR_VAL * r, VECTOR_3D val)
{
  copy_vec_3d (r->val.vec, val);
  r->tag = E_VECTOR;
}

// Write the vector held in val to f as "[x,y,z]".
void
print_vector (FILE * f, EXPR_VAL * val)
{
  fprintf (f, "[" F "," F "," F "]",
           val->val.vec[X], val->val.vec[Y], val->val.vec[Z]);
}

// Store a copy of transform `val` in r and tag r as E_TRANSFORM.
void
set_transform (EXPR_VAL * r, TRANSFORM val)
{
  copy_transform (r->val.xf, val);
  r->tag = E_TRANSFORM;
}

// Write the 16-element transform held in val to f as four bracketed
// groups inside an outer pair of brackets: "[[..][..][..][..]]", with no
// separator between groups.  Group g lists elements g, g+4, g+8, g+12,
// i.e. one matrix row if the values are stored column-major (presumably
// the TRANSFORM layout -- confirm against the header).
void
print_transform (FILE * f, EXPR_VAL * val)
{
  FLOAT *m = val->val.xf;
  int row, col;

  fputc ('[', f);
  for (row = 0; row < 4; row++)
    {
      fputc ('[', f);
      for (col = 0; col < 4; col++)
        {
          if (col > 0)
            fputc (',', f);
          fprintf (f, F, m[row + 4 * col]);
        }
      fputc (']', f);
    }
  fputc (']', f);
}

// coerce an expression value to given type
// generate error message if it can't be done
// Copy the float held in r into *val.  If r holds any other type, report
// an error at `line` and store 0 instead.
void
coerce_to_float (EXPR_VAL * r, FLOAT * val, SRC_LINE line)
{
  if (r->tag != E_FLOAT)
    {
      *val = 0;
      err (line, "expected float, found %s", expr_val_type_str[r->tag]);
      return;
    }
  *val = r->val.flt;
}

// Copy the point held in r into val.  If r holds any other type, report
// an error at `line` and store the origin (0,0,0) instead.
void
coerce_to_point (EXPR_VAL * r, POINT_3D val, SRC_LINE line)
{
  if (r->tag != E_POINT)
    {
      val[X] = 0;
      val[Y] = 0;
      val[Z] = 0;
      err (line, "expected point, found %s", expr_val_type_str[r->tag]);
      return;
    }
  copy_pt_3d (val, r->val.pt);
}

// Copy the vector held in r into val.  If r holds any other type, report
// an error at `line` and store the zero vector instead.
void
coerce_to_vector (EXPR_VAL * r, VECTOR_3D val, SRC_LINE line)
{
  if (r->tag != E_VECTOR)
    {
      val[X] = 0;
      val[Y] = 0;
      val[Z] = 0;
      err (line, "expected vector, found %s", expr_val_type_str[r->tag]);
      return;
    }
  copy_vec_3d (val, r->val.vec);
}

// Copy the transform held in r into val.  If r holds any other type,
// report an error at `line` and store the identity transform instead.
void
coerce_to_transform (EXPR_VAL * r, TRANSFORM val, SRC_LINE line)
{
  if (r->tag != E_TRANSFORM)
    {
      set_ident (val);
      err (line, "expected transform, found %s", expr_val_type_str[r->tag]);
      return;
    }
  copy_transform (val, r->val.xf);
}

// Signature shared by the print_* routines above: write the value held in
// the EXPR_VAL to the given stream.
typedef void (*PRINT_FUNC) (FILE *, EXPR_VAL *);

static PRINT_FUNC print_expr_val_tbl[] = {
 print_float,
 print_point,
 print_vector,
 print_transform,
};

// Print r to f using the printer registered for its type tag.
void
print_expr_val (FILE * f, EXPR_VAL * r)
{
  PRINT_FUNC printer = print_expr_val_tbl[r->tag];

  printer (f, r);
}

// Combine two type tags into one integer switch key: B occupies the low two
// bits and A the bits above them.  Keys are distinct only while every tag
// fits in two bits (0..3) -- presumably true for the four E_* tags; confirm
// if a new value type is ever added.
#define HASH(A, B) (((A) << 2) | (B))

// r = a + b.  Supported operand pairs:
//   float + float   -> float
//   point + vector,
//   vector + point  -> point (the point offset by the vector)
//   vector + vector -> vector
// Any other pair is reported at `line` and r is set to float 0.
void
do_add (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  int key = HASH (a->tag, b->tag);

  if (key == HASH (E_FLOAT, E_FLOAT))
    {
      set_float (r, a->val.flt + b->val.flt);
    }
  else if (key == HASH (E_POINT, E_VECTOR))
    {
      r->tag = E_POINT;
      add_vec_to_pt_3d (r->val.pt, a->val.pt, b->val.vec);
    }
  else if (key == HASH (E_VECTOR, E_POINT))
    {
      r->tag = E_POINT;
      add_vec_to_pt_3d (r->val.pt, b->val.pt, a->val.vec);
    }
  else if (key == HASH (E_VECTOR, E_VECTOR))
    {
      r->tag = E_VECTOR;
      add_vecs_3d (r->val.vec, a->val.vec, b->val.vec);
    }
  else
    {
      // Report before overwriting r: r may alias a or b.
      err (line, "operands of + (types %s and %s) cannot be added",
           expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
      set_float (r, 0);
    }
}

// r = a - b.  Supported operand pairs:
//   float - float   -> float
//   point - point   -> vector
//   point - vector  -> point (computed as a + (-1) * b)
//   vector - vector -> vector
// Any other pair is reported at `line` and r is set to float 0.
void
do_sub (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  int key = HASH (a->tag, b->tag);

  if (key == HASH (E_FLOAT, E_FLOAT))
    {
      set_float (r, a->val.flt - b->val.flt);
    }
  else if (key == HASH (E_POINT, E_POINT))
    {
      r->tag = E_VECTOR;
      sub_pts_3d (r->val.vec, a->val.pt, b->val.pt);
    }
  else if (key == HASH (E_POINT, E_VECTOR))
    {
      r->tag = E_POINT;
      add_scaled_vec_to_pt_3d (r->val.pt, a->val.pt, b->val.vec, -1);
    }
  else if (key == HASH (E_VECTOR, E_VECTOR))
    {
      r->tag = E_VECTOR;
      sub_vecs_3d (r->val.vec, a->val.vec, b->val.vec);
    }
  else
    {
      // Report before overwriting r: r may alias a or b.
      err (line, "operands of - (types %s and %s) cannot be subtracted",
           expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
      set_float (r, 0);
    }
}

// r = a * b.  Supported operand pairs:
//   float * float         -> float product
//   vector * float,
//   float * vector        -> vector scaled by the float
//   vector * vector       -> cross product (via cross)
//   transform * transform -> composition (via compose)
//   transform * point     -> transformed point
//   transform * vector    -> transformed vector
// Any other pair is reported at `line` and r is set to float 0.
void
do_mul (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  int key = HASH (a->tag, b->tag);

  if (key == HASH (E_FLOAT, E_FLOAT))
    {
      set_float (r, a->val.flt * b->val.flt);
    }
  else if (key == HASH (E_VECTOR, E_FLOAT))
    {
      r->tag = E_VECTOR;
      scale_vec_3d (r->val.vec, a->val.vec, b->val.flt);
    }
  else if (key == HASH (E_FLOAT, E_VECTOR))
    {
      r->tag = E_VECTOR;
      scale_vec_3d (r->val.vec, b->val.vec, a->val.flt);
    }
  else if (key == HASH (E_VECTOR, E_VECTOR))
    {
      r->tag = E_VECTOR;
      cross (r->val.vec, a->val.vec, b->val.vec);
    }
  else if (key == HASH (E_TRANSFORM, E_TRANSFORM))
    {
      r->tag = E_TRANSFORM;
      compose (r->val.xf, a->val.xf, b->val.xf);
    }
  else if (key == HASH (E_TRANSFORM, E_POINT))
    {
      r->tag = E_POINT;
      transform_pt_3d (r->val.pt, a->val.xf, b->val.pt);
    }
  else if (key == HASH (E_TRANSFORM, E_VECTOR))
    {
      r->tag = E_VECTOR;
      transform_vec_3d (r->val.vec, a->val.xf, b->val.vec);
    }
  else
    {
      // Report before overwriting r: r may alias a or b.
      err (line, "operands of * (types %s and %s) cannot be multiplied",
           expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
      set_float (r, 0);
    }
}

// r = a then b, where b must be a transform.  The result equals b * a as
// computed by do_mul:
//   transform then transform -> compose (b, a)
//   point then transform     -> b applied to the point
//   vector then transform    -> b applied to the vector
// Any other pair is reported at `line` and r is set to float 0.
void
do_thn (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  int key = HASH (a->tag, b->tag);

  if (key == HASH (E_TRANSFORM, E_TRANSFORM))
    {
      r->tag = E_TRANSFORM;
      compose (r->val.xf, b->val.xf, a->val.xf);
    }
  else if (key == HASH (E_POINT, E_TRANSFORM))
    {
      r->tag = E_POINT;
      transform_pt_3d (r->val.pt, b->val.xf, a->val.pt);
    }
  else if (key == HASH (E_VECTOR, E_TRANSFORM))
    {
      r->tag = E_VECTOR;
      transform_vec_3d (r->val.vec, b->val.xf, a->val.vec);
    }
  else
    {
      // Report before overwriting r: r may alias a or b.
      err (line,
           "operands of 'then' (types %s and %s) cannot be multiplied",
           expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
      set_float (r, 0);
    }
}

// Return a / b.  When b lies strictly between -FLOAT_EPS and FLOAT_EPS,
// report a division-by-zero error at `line` and return 0 instead.
static FLOAT
safe_dvd (FLOAT a, FLOAT b, SRC_LINE line)
{
  int b_is_tiny = b > -FLOAT_EPS && b < FLOAT_EPS;

  if (!b_is_tiny)
    return a / b;

  err (line, "attempt to divide " F " by zero", a);
  return 0;
}

// r = a / b.  Supported operand pairs:
//   float / float  -> float quotient
//   vector / float -> vector scaled by 1 / float
// A divisor within FLOAT_EPS of zero is reported by safe_dvd, which
// substitutes 0.  Any other pair is reported at `line` and r is set to
// float 0.
void
do_dvd (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  switch (HASH (a->tag, b->tag))
    {
    case HASH (E_FLOAT, E_FLOAT):
      set_float (r, safe_dvd (a->val.flt, b->val.flt, line));
      break;
    case HASH (E_VECTOR, E_FLOAT):
      r->tag = E_VECTOR;
      scale_vec_3d (r->val.vec, a->val.vec, safe_dvd (1, b->val.flt, line));
      break;
    // There is intentionally no (E_FLOAT, E_VECTOR) case.  A scalar
    // divided by a vector is undefined.  Scaling the vector by 1/scalar
    // would silently compute vector / scalar with the operands swapped,
    // so that pair is reported by the default branch instead.
    default:
      err (line, "operands of / (types %s and %s) cannot be divided",
           expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
      set_float (r, 0);
      break;
    }
}

// r = a dot b.  For two vectors this is the scalar dot product.  For the
// float/vector/transform pairs that * accepts, dot means the same as *
// and the work is delegated to do_mul.  Any other pair is reported at
// `line` and r is set to float 0.
void
do_dot (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  int key = HASH (a->tag, b->tag);

  switch (key)
    {
    case HASH (E_VECTOR, E_VECTOR):
      set_float (r, dot_3d (a->val.vec, b->val.vec));
      return;
    case HASH (E_FLOAT, E_FLOAT):
    case HASH (E_FLOAT, E_VECTOR):
    case HASH (E_VECTOR, E_FLOAT):
    case HASH (E_TRANSFORM, E_POINT):
    case HASH (E_TRANSFORM, E_VECTOR):
    case HASH (E_TRANSFORM, E_TRANSFORM):
      do_mul (r, a, b, line);
      return;
    default:
      break;
    }

  err (line, "operands of dot (types %s and %s) cannot be multiplied",
       expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
  set_float (r, 0);
}

// Set r to the float component `index` of the point or vector a.  index
// is expected to be one of X, Y or Z.  Any other index is reported and
// yields float 0 rather than reading past the coordinate array.
// A non-point, non-vector operand is reported and also yields float 0.
void
do_index (EXPR_VAL * r, EXPR_VAL * a, int index, SRC_LINE line)
{
  switch (a->tag)
    {
    case E_VECTOR:
    case E_POINT:
      if (index != X && index != Y && index != Z)
        {
          err (line, "component index %d is out of range", index);
          set_float (r, 0);
        }
      else if (a->tag == E_VECTOR)
        set_float (r, a->val.vec[index]);
      else
        set_float (r, a->val.pt[index]);
      break;
    default:
      err (line,
           "operand of 'index is a %s and should be a point or a vector",
           expr_val_type_str[a->tag]);
      set_float (r, 0);
      break;
    }
}

// Store the inverse of xf in inv.  invert() is called with the constant
// 1e-4 (presumably its singularity tolerance -- confirm in invert()).  If
// invert() reports a determinant of exactly 0, report an error at `line`
// and fall back to the identity transform.
void
do_inverse (TRANSFORM inv, TRANSFORM xf, SRC_LINE line)
{
  FLOAT determinant;

  invert (inv, &determinant, xf, 1e-4);
  if (determinant != 0)
    return;

  err (line, "inverse of singular transform");
  set_ident (inv);
}

// Put a^n into r by binary exponentiation, using a^(2k) = (a^k)^2 to cut
// the number of compositions.  The bits of n are scanned from the most
// significant set bit downward.  At each lower bit r is squared, and it is
// then multiplied by a when that bit is set.
//   n == 0 -> identity
//   n < 0  -> (inverse of a)^(-n)
// r and a must NOT be the same storage: r is overwritten with a copy of a
// and then updated while a is still being read.
void
do_transform_power (TRANSFORM r, TRANSFORM a, int n, SRC_LINE line)
{
  int mask;

  if (n == 0)
    {
      set_ident (r);
      return;
    }
  if (n < 0)
    {
      TRANSFORM a_inv;

      do_inverse (a_inv, a, line);
      do_transform_power (r, a_inv, -n, line);
      return;
    }

  // Locate the highest set bit of n (n > 0 here, so this terminates).
  // bit(30) is presumably 1 << 30, the top bit a positive int may hold.
  for (mask = (int) bit (30); !(mask & n); mask >>= 1)
    ;

  // r = a accounts for that highest bit; process the remaining bits.
  copy_transform (r, a);
  while ((mask >>= 1) != 0)
    {
      compose (r, r, r);
      if (n & mask)
        compose (r, r, a);
    }
}

// Convert x to an int by truncating toward zero (any fractional part is
// discarded, e.g. 2.7 -> 2 and -2.7 -> -2).  On success store the result
// in *n and return 1.  If the integer part lies outside [-1e9, 1e9], or x
// is NaN or infinite, return 0 and leave *n untouched.
int
to_integer (FLOAT x, int *n)
{
  double int_part;

  // Only the integer part is needed; modf's fractional result is unused.
  (void) modf (x, &int_part);
  if (-1e9 <= int_part && int_part <= 1e9)
    {
      *n = (int) int_part;
      return 1;
    }
  return 0;
}

// r = a ^ b.  Supported operand pairs:
//   float ^ float     -> pow (a, b)
//   transform ^ float -> the transform raised to the integer power obtained
//                        by truncating b (see to_integer).  An exponent
//                        outside -1e9..1e9 is reported and yields the
//                        identity.
// Any other pair is reported at `line` and r is set to float 0.
void
do_pwr (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  int key = HASH (a->tag, b->tag);

  if (key == HASH (E_FLOAT, E_FLOAT))
    {
      set_float (r, pow (a->val.flt, b->val.flt));
    }
  else if (key == HASH (E_TRANSFORM, E_FLOAT))
    {
      TRANSFORM result;
      int exponent;

      if (!to_integer (b->val.flt, &exponent))
        {
          err (line, "transform power out of domain (integer -1e9..1e9)");
          set_ident (result);
        }
      else
        {
          do_transform_power (result, a->val.xf, exponent, line);
        }
      set_transform (r, result);
    }
  else
    {
      err (line, "operands of ^ (types %s and %s) must be type float",
           expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
      set_float (r, 0);
    }
}

// r = |a|: the absolute value of a float or the length of a vector.  Any
// other operand is reported at `line` and copied to r unchanged.
void
do_mag (EXPR_VAL * r, EXPR_VAL * a, SRC_LINE line)
{
  if (a->tag == E_VECTOR)
    {
      set_float (r, length_vec_3d (a->val.vec));
      return;
    }
  if (a->tag == E_FLOAT)
    {
      FLOAT x = a->val.flt;

      // Deliberately not fabs(): keeps the original sign behavior for
      // -0.0 and NaN.
      set_float (r, x >= 0 ? x : -x);
      return;
    }
  err (line, "operand of magnitude operator (type %s) must be vector",
       expr_val_type_str[a->tag]);
  *r = *a;
}

// r = -a for a float or vector operand.  Any other operand is reported at
// `line` and copied to r unchanged.
void
do_neg (EXPR_VAL * r, EXPR_VAL * a, SRC_LINE line)
{
  if (a->tag == E_VECTOR)
    {
      r->tag = E_VECTOR;
      negate_vec_3d (r->val.vec, a->val.vec);
      return;
    }
  if (a->tag == E_FLOAT)
    {
      set_float (r, -a->val.flt);
      return;
    }
  err (line, "operand of unary minus (type %s) cannot be negated",
       expr_val_type_str[a->tag]);
  *r = *a;
}

// r = the unit vector in the direction of vector a (via find_unit_vec_3d).
// A non-vector operand is reported at `line` and r is set to the vector
// [0,0,1].
void
do_unit (EXPR_VAL * r, EXPR_VAL * a, SRC_LINE line)
{
  static VECTOR_3D fallback = { 0, 0, 1 };

  if (a->tag != E_VECTOR)
    {
      err (line, "operand of unit operator (type %s) must be vector",
           expr_val_type_str[a->tag]);
      set_vector (r, fallback);
      return;
    }
  r->tag = E_VECTOR;
  find_unit_vec_3d (r->val.vec, a->val.vec);
}

// r = square root of a float operand.
// Errors are reported at `line` and r always receives a valid float:
//   - a negative operand yields 0 rather than NaN, matching the
//     report-and-substitute-0 convention of safe_dvd;
//   - a non-float operand yields float 0, as the other operators do, so
//     callers never read an unset result.
void
do_sqrt (EXPR_VAL * r, EXPR_VAL * a, SRC_LINE line)
{
  switch (a->tag)
    {
    case E_FLOAT:
      if (a->val.flt < 0)
        {
          err (line, "square root of negative number");
          set_float (r, 0);
        }
      else
        set_float (r, sqrt (a->val.flt));
      break;
    default:
      err (line, "operand of sqrt (type %s) must be float",
           expr_val_type_str[a->tag]);
      set_float (r, 0);
      break;
    }
}

// r = sine of a float operand measured in degrees (converted to radians
// with PI / 180).  A non-float operand is reported at `line` and r is set
// to float 0, as the other operators do, so callers never read an unset
// result.
void
do_sin (EXPR_VAL * r, EXPR_VAL * a, SRC_LINE line)
{
  switch (a->tag)
    {
    case E_FLOAT:
      set_float (r, sin ((PI / 180) * a->val.flt));
      break;
    default:
      err (line, "operand of sin (type %s) must be float",
           expr_val_type_str[a->tag]);
      set_float (r, 0);
      break;
    }
}

// r = cosine of a float operand measured in degrees (converted to radians
// with PI / 180).  A non-float operand is reported at `line` and r is set
// to float 0, as the other operators do, so callers never read an unset
// result.
void
do_cos (EXPR_VAL * r, EXPR_VAL * a, SRC_LINE line)
{
  switch (a->tag)
    {
    case E_FLOAT:
      set_float (r, cos ((PI / 180) * a->val.flt));
      break;
    default:
      err (line, "operand of cos (type %s) must be float",
           expr_val_type_str[a->tag]);
      set_float (r, 0);
      break;
    }
}

// r = atan2 (a, b) converted to degrees (scaled by 180 / PI).  Both
// operands must be floats.  Any other pair is reported at `line` and r is
// set to float 0, as the other operators do, so callers never read an
// unset result.
void
do_atan2 (EXPR_VAL * r, EXPR_VAL * a, EXPR_VAL * b, SRC_LINE line)
{
  switch (HASH (a->tag, b->tag))
    {
    case HASH (E_FLOAT, E_FLOAT):
      set_float (r, (180 / PI) * atan2 (a->val.flt, b->val.flt));
      break;
    default:
      err (line, "operands of atan2 (types %s, %s) must be float",
           expr_val_type_str[a->tag], expr_val_type_str[b->tag]);
      set_float (r, 0);
      break;
    }
}