;;;; Package for the arithmetic operations on ring elements.
;;;; It mixes in CL, mixes in and re-exports the polynomial package, and
;;;; exports the variadic RING operators together with the binary
;;;; helpers (ADD, SUB, MUL, DIV, ...) they are built on.
(uiop:define-package :st-buchberger/src/arithmetic
 (:mix :cl)
 (:mix-reexport :st-buchberger/src/polynomial)
 (:export #:ring+ #:ring- #:ring* #:ring/
  #:add #:sub #:mul #:divides-p #:div #:divmod
  #:ring-mod #:ring-lcm))

(in-package #:st-buchberger/src/arithmetic)

(defmethod ring+ ((poly polynomial) &rest more-polynomials)
  "Folds ADD over MORE-POLYNOMIALS from left to right, starting with POLY.
Returns POLY itself when no further operands are supplied."
  (let ((sum poly))
    (dolist (operand more-polynomials sum)
      (setf sum (add sum operand)))))

(defmethod ring- ((poly polynomial) &rest more-polynomials)
  "Folds SUB over MORE-POLYNOMIALS from left to right, starting with POLY.
Returns POLY itself when no further operands are supplied."
  (let ((difference poly))
    (dolist (operand more-polynomials difference)
      (setf difference (sub difference operand)))))

(defmethod ring* ((poly polynomial) &rest more-polynomials)
  "Folds MUL over MORE-POLYNOMIALS from left to right, starting with POLY.
Returns POLY itself when no further operands are supplied."
  (let ((product poly))
    (dolist (operand more-polynomials product)
      (setf product (mul product operand)))))

(defmethod ring/ ((poly polynomial) &rest more-polynomials)
  "Divides POLY by MORE-POLYNOMIALS and returns whatever DIVMOD returns.
Signals RING-DIVISION-BY-ZERO when any divisor satisfies RING-ZERO-P."
  (when (some #'ring-zero-p more-polynomials)
    (error 'ring-division-by-zero
           :operands (list poly more-polynomials)))
  ;; DIVMOD receives the divisors packed into a fresh array.
  (let ((divisors (make-array (length more-polynomials)
                              :initial-contents more-polynomials)))
    (divmod poly divisors)))

(defmethod add ((poly polynomial) (tm term))
  "Returns the object produced by (RING-COPY POLY) after adding TM to it.
The coefficient of TM is added to the entry stored under TM's monomial
(a missing entry counts as 0); an entry whose sum is zero is removed."
  (let* ((result (ring-copy poly))
         (table (slot-value result 'terms))
         (key (slot-value tm 'monomial))
         (sum (+ (gethash key table 0)
                 (slot-value tm 'coefficient))))
    (if (zerop sum)
        (remhash key table)
        (setf (gethash key table) sum))
    result))

(defmethod add ((p1 polynomial) (p2 polynomial))
  "Returns the sum of P1 and P2.
Starts from (RING-COPY P1) and adds the terms of P2 to it one at a time."
  (let ((accumulator (ring-copy p1)))
    (doterms (summand p2 accumulator)
      (setf accumulator (add accumulator summand)))))

(defmethod sub ((poly polynomial) (tm term))
  "Returns POLY minus the term TM.
Implemented as ADD of POLY and a new term with TM's monomial and the
negated coefficient."
  (let ((negated (make-instance 'term
                                :coefficient (- (slot-value tm 'coefficient))
                                :monomial (slot-value tm 'monomial))))
    (add poly negated)))

(defmethod sub ((p1 polynomial) (p2 polynomial))
  "Returns P1 minus P2, computed as P1 + (P2 * -1).
The -1 is a term over P2's base ring with an all-zero exponent vector,
one entry per ring variable."
  (let* ((ring (base-ring p2))
         (zero-exponents (make-array (length (variables ring))
                                     :initial-element 0))
         (minus-one (make-instance 'term
                                   :ring ring
                                   :coefficient -1
                                   :monomial zero-exponents)))
    (add p1 (mul p2 minus-one))))

(defmethod mul ((poly polynomial) (num number))
  "Returns the product of POLY by the number NUM.
Each term of POLY is multiplied by NUM and added to a new polynomial
over POLY's base ring."
  (let ((scaled (make-instance 'polynomial :ring (base-ring poly))))
    (doterms (each-term poly scaled)
      (let ((scaled-term (mul each-term num)))
        (setf scaled (add scaled scaled-term))))))

(defmethod mul ((t1 term) (num number))
  "Returns a new term with T1's monomial and T1's coefficient times NUM."
  (with-accessors ((old-coefficient coefficient)
                   (exponents monomial))
      t1
    (make-instance 'term
                   :coefficient (* num old-coefficient)
                   :monomial exponents)))

(defmethod mul ((t1 term) (t2 term))
  "Returns a new term that is the product of T1 and T2.
Coefficients are multiplied and exponent vectors are added element-wise.
Neither argument is modified."
  (let ((coefficient-product (* (slot-value t1 'coefficient)
                                (slot-value t2 'coefficient)))
        (exponent-sum (map 'vector #'+
                           (slot-value t1 'monomial)
                           (slot-value t2 'monomial))))
    (make-instance 'term
                   :coefficient coefficient-product
                   :monomial exponent-sum)))

(defmethod mul ((poly polynomial) (tm term))
  "Returns the product of POLY by the term TM.
Each term of POLY is multiplied by TM and added to a new polynomial
over POLY's base ring."
  (let ((product (make-instance 'polynomial :ring (base-ring poly))))
    (doterms (each-term poly product)
      (let ((partial (mul each-term tm)))
        (setf product (add product partial))))))

(defmethod mul ((p1 polynomial) (p2 polynomial))
  "Returns the product of P1 and P2.
P1 is multiplied by each term of P2 and the partial products are added
to a new polynomial over P1's base ring."
  (assert (and p1 p2))
  (let ((product (make-instance 'polynomial :ring (base-ring p1))))
    (doterms (factor p2 product)
      (let ((partial (mul p1 factor)))
        (setf product (add product partial))))))

(defmethod divides-p ((t1 term) (t2 term))
  "Returns true when the term T1 divides the term T2.
T1 must have a non-zero coefficient and no exponent of T1 may exceed
the corresponding exponent of T2."
  (assert (and t1 t2))
  (let ((c1 (slot-value t1 'coefficient))
        (m1 (slot-value t1 'monomial))
        (m2 (slot-value t2 'monomial)))
    ;; Coefficient divisibility is not checked: the polynomials are
    ;; over a *field*, so any non-zero coefficient divides any other.
    (unless (zerop c1)
      (every #'<= m1 m2))))

(defmethod divides-p ((t1 term) (p polynomial))
  "Returns true when the term T1 divides every term of the polynomial P."
  (assert (and t1 p))
  (flet ((divisible-p (candidate)
           (divides-p t1 candidate)))
    (every #'divisible-p (terms->list p))))

(defmethod div ((t1 term) (t2 term))
  "Returns the quotient T1 / T2 as a new term.
Asserts that T2 divides T1. The coefficient is the quotient of the
coefficients and the monomial is VECTOR- of the two monomials."
  (assert (divides-p t2 t1))
  (let ((coefficient-quotient (/ (slot-value t1 'coefficient)
                                 (slot-value t2 'coefficient)))
        (exponent-difference (vector- (slot-value t1 'monomial)
                                      (slot-value t2 'monomial))))
    (make-instance 'term
                   :coefficient coefficient-quotient
                   :monomial exponent-difference)))

(defmethod divmod ((f polynomial) fs)
  "Divides F by the polynomials in the sequence FS.
Returns two values: a vector of quotients, one per element of FS and in
the same order, and the remainder polynomial.

At each step the leading term (LT) of the running dividend is tested
against the leading term of each divisor in order. The first divisor
whose leading term divides it is used to cancel it. When no divisor
applies, the leading term is moved to the remainder."
  (let* ((ring (base-ring f))
         (divisor-count (length fs))
         (quotients (make-array divisor-count :fill-pointer 0))
         (remainder (make-instance 'polynomial :ring ring))
         (p (ring-copy f)))
    ;; Every quotient starts out as an empty polynomial over F's ring.
    (dotimes (i divisor-count)
      (vector-push (make-instance 'polynomial :ring ring) quotients))
    (loop :until (ring-zero-p p) :do
      ;; Compute the leading term once per step instead of on every use.
      (let* ((lead (lt p))
             (index (position-if (lambda (fi) (divides-p (lt fi) lead))
                                 fs)))
        (if index
            (let* ((fi (elt fs index))
                   (quotient-term (div lead (lt fi))))
              (setf (elt quotients index) (add (elt quotients index)
                                               quotient-term)
                    p (sub p (mul fi quotient-term))))
            ;; No divisor applies: the leading term goes to the remainder.
            (setf remainder (add remainder lead)
                  p (sub p lead)))))
    (values quotients remainder)))

(defmethod ring-mod ((f polynomial) &rest fs)
  "Returns the remainder (second value of DIVMOD) of dividing F by the
divisors in FS.

The divisors may be given either as a single sequence, as in
(ring-mod f (list g1 g2)), or spread as individual arguments, as in
(ring-mod f g1 g2). Signals RING-DIVISION-BY-ZERO when any divisor
satisfies RING-ZERO-P."
  ;; A lone sequence argument is taken as the divisor collection itself;
  ;; anything else means the &rest list already holds the divisors.
  (let ((divisors (if (and (null (rest fs))
                           (typep (first fs) 'sequence))
                      (first fs)
                      fs)))
    (when (some #'ring-zero-p divisors)
      (error 'ring-division-by-zero :operands (list f fs)))
    (nth-value 1 (divmod f divisors))))

(defmethod ring-lcm ((t1 term) (t2 term))
  "Returns LCM(T1, T2) as a new term.
The coefficient is the RING-LCM of the two coefficients and each
exponent is the larger of the two corresponding exponents."
  (let ((coefficient-lcm (ring-lcm (slot-value t1 'coefficient)
                                   (slot-value t2 'coefficient)))
        (exponent-max (map 'vector #'max
                           (slot-value t1 'monomial)
                           (slot-value t2 'monomial))))
    (make-instance 'term
                   :coefficient coefficient-lcm
                   :monomial exponent-max)))

(defmethod ring-lcm ((r1 rational) (r2 rational))
  "LCM over the rationals.
Both arguments are scaled to the common denominator
den(R1) * den(R2); the integer LCM of the scaled numerators is then
divided by that denominator."
  ;; Translated from SAGE.
  (let* ((den1 (denominator r1))
         (den2 (denominator r2))
         (scaled1 (* (numerator r1) den2))
         (scaled2 (* (numerator r2) den1)))
    (/ (lcm scaled1 scaled2)
       (* den1 den2))))