--
-- Copyright (c) 2021-2025 Zeping Lee
-- Released under the MIT license.
-- Repository: https://github.com/zepinglee/citeproc-lua
--

-- Table returned by this module; its fields are filled in at the end of the
-- file.
local date_module = {}

-- Forward declarations: assigned below once the module naming scheme is known.
local element
local ir_node
local output
local util

-- The sibling modules are required under hyphenated names
-- ("citeproc-element") when the `kpse` module can be loaded, which this file
-- treats as the sign of running under LuaTeX, and under dotted names
-- ("citeproc.element") otherwise.
local using_luatex, kpse = pcall(require, "kpse")
if using_luatex then
 element = require("citeproc-element")
 ir_node = require("citeproc-ir-node")
 output = require("citeproc-output")
 util = require("citeproc-util")
else
 element = require("citeproc.element")
 ir_node = require("citeproc.ir-node")
 output = require("citeproc.output")
 util = require("citeproc.util")
end

-- Local aliases for the classes used throughout this file.
local Element = element.Element
local Rendered = ir_node.Rendered
local SeqIr = ir_node.SeqIr
local GroupVar = ir_node.GroupVar
local PlainText = output.PlainText


-- [Date](https://docs.citationstyles.org/en/stable/specification.html#date)
-- Renders a date variable, either from this element's own date-part children
-- or, when `form` is set, from a localized date format.
---@class Date: Element
---@field children DatePart[]? date-part children parsed from the node
---@field variable string name of the date variable to look up
---@field form string? when set, a localized date format is used
---@field date_parts string? dash-separated parts to show; defaults to "year-month-day" when `form` is set
---@field delimiter string?
---@field prefix string?
---@field suffix string?
---@field text_case string?
local Date = Element:derive("date")

--- Build a `Date` element from a parsed style node.
-- Reads the date-specific attributes, the shared delimiter/formatting/
-- affixes/display/text-case attributes, and the child nodes.
-- `node` is the parsed node to read from; the new `Date` is returned.
function Date:from_node(node)
 local date = Date:new()

 for _, attribute in ipairs({"variable", "form", "date-parts"}) do
   date:set_attribute(node, attribute)
 end

 -- A date with a `form` but no explicit `date-parts` shows all three parts.
 if date.form and not date.date_parts then
   date.date_parts = "year-month-day"
 end

 date:get_delimiter_attribute(node)
 date:set_formatting_attributes(node)
 date:set_affixes_attributes(node)
 date:set_display_attribute(node)
 date:set_text_case_attribute(node)

 date.children = {}
 date:process_children_nodes(node)

 -- Index every child by its name as well, so that a given part can be looked
 -- up directly on the element.
 for _, child in ipairs(date.children) do
   date[child.name] = child
 end

 return date
end

--- Build the intermediate representation (IR) of this date element.
--
-- The date variable is looked up through `context` unless it is marked in
-- `state.suppressed`. Its `date-parts` are normalized in place and rendered
-- through the localized or the independent path depending on `self.form`.
-- When there are no date parts, the `literal` field (or else the `raw`
-- field) is rendered as text instead.
--
-- Returns an IR node. Its `group_var` is `GroupVar.Missing` when nothing
-- could be rendered. When `context.sort_key` is set, the node also carries
-- a `sort_key` (`false` when nothing could be rendered).
function Date:build_ir(engine, state, context)
 local variable
 if not state.suppressed[self.variable] then
   ---@type DateVariable?
   variable = context:get_variable(self.variable)
 end

 if not variable then
   local ir = Rendered:new({}, self)
   ir.group_var = GroupVar.Missing
   return ir
 end

 local ir
 local ranges = variable["date-parts"]

 if ranges and #ranges > 0 then

   -- TODO: move input normalization in one place
   -- Normalize up to two dates of three parts each, in place: zero and the
   -- empty string become nil, everything else goes through tonumber().
   for i = 1, 2 do
     local single_date = ranges[i]
     if single_date then
       for j = 1, 3 do
         local part = single_date[j]
         if part == 0 or part == "" then
           single_date[j] = nil
         else
           single_date[j] = tonumber(part)
         end
       end
     end
   end

   -- A season is stored in the empty month slot of the first date as
   -- 20 + season. A season that is not numeric is ignored here; adding it
   -- to 20 would raise an arithmetic error.
   if variable["season"] and not ranges[1][2] then
     local season = tonumber(variable["season"])
     if season then
       ranges[1][2] = 20 + season
     end
   end

   if self.form then
     ir = self:build_localized_date_ir(ranges, engine, state, context)
   else
     ir = self:build_independent_date_ir(ranges, engine, state, context)
   end
   ir.affixes = self.affixes

 else
   -- No date parts: fall back to the literal text, then to the raw text.
   local text = variable["literal"] or variable["raw"]
   if text then
     local inlines = self:render_text_inlines(text, context)
     ir = Rendered:new(inlines, self)
     ir.group_var = GroupVar.Important
   end
 end

 if not ir then
   -- date_LiteralFailGracefullyIfNoValue.txt
   ir = Rendered:new({}, self)
   if context.sort_key then
     ir.sort_key = false
   end
   ir.group_var = GroupVar.Missing
   return ir
 end

 if ir.group_var == GroupVar.Important then
   -- Suppress substituted name variable
   if state.name_override and not context.sort_key then
     state.suppressed[self.variable] = true
   end
 end

 if context.sort_key then
   ir.sort_key = self:render_sort_key(engine, state, context)
 end

 return ir
end

--- Build the IR of a date that has no `form`, using this element's own
-- date-part children and delimiter.
-- `variable` is the list of dates to render; the result of
-- `build_date_parts` is returned unchanged.
function Date:build_independent_date_ir(variable, engine, state, context)
 local own_parts, own_delimiter = self.children, self.delimiter
 return self:build_date_parts(own_parts, variable, own_delimiter, engine, state, context)
end

--- Build the IR of a date rendered with the localized format `self.form`.
-- The localized date parts are filtered by `self.date_parts`. Each kept
-- part is copied, and the copy is overridden by this element's child of the
-- same name when there is one. The localized format's delimiter is used.
function Date:build_localized_date_ir(variable, engine, state, context)
 -- Set of part names allowed by `date_parts`.
 local wanted = {}
 for _, name in ipairs(util.split(self.date_parts or "year-month-day", "%-")) do
   wanted[name] = true
 end

 local locale_date = context:get_localized_date(self.form)

 local merged_parts = {}
 for _, locale_part in ipairs(locale_date.children) do
   if wanted[locale_part.name] then
     -- Work on a copy so the localized definition itself is not modified.
     local merged = locale_part:copy()
     local own_part = self[merged.name]
     if own_part then
       own_part:override(merged)
     end
     merged_parts[#merged_parts + 1] = merged
   end
 end

 return self:build_date_parts(merged_parts, variable, locale_date.delimiter, engine, state, context)
end

--- Dispatch on the number of dates in `variable`: one date is rendered as a
-- single date, two or more as a range. Nothing is returned for an empty
-- list.
function Date:build_date_parts(date_parts, variable, delimiter, engine, state, context)
 local date_count = #variable
 if date_count == 1 then
   return self:build_single_date(date_parts, variable[1], delimiter, engine, state, context)
 end
 if date_count >= 2 then
   return self:build_date_range(date_parts, variable, delimiter, engine, state, context)
 end
end

--- Build the IR of a single (non-range) date.
-- Every entry of `date_parts` renders its part of `single_date`; the results
-- are collected into a sequence node owned by this element.
--
-- `delimiter` is the delimiter chosen by the caller (the localized format's
-- delimiter for localized dates). It used to be ignored in favour of
-- `self.delimiter`, which dropped the localized delimiter for single dates
-- although `build_date_range` honours it. `self.delimiter` is kept as the
-- fallback so that dates without a caller-supplied delimiter render as
-- before.
function Date:build_single_date(date_parts, single_date, delimiter, engine, state, context)
 local irs = {}
 for _, date_part in ipairs(date_parts) do
   local part_ir = date_part:build_ir(single_date, engine, state, context)
   table.insert(irs, part_ir)
 end

 local ir = SeqIr:new(irs, self)
 ir.delimiter = delimiter or self.delimiter

 return ir
end

-- Position of each date part inside a single date array. The same numbers
-- are compared against `diff_level` when rendering a range, where a smaller
-- index is the coarser part.
local date_part_index = {
 year = 1,
 month = 2,
 day = 3,
}

--- Build the IR of a date range from the first two dates in `variable`.
--
-- Parts coarser than the first differing part are rendered once, from the
-- first date. Consecutive parts at or below that level are collected and
-- rendered for both dates by `build_date_range_parts`, joined by the range
-- delimiter of the differing part (an en dash when it has none).
function Date:build_date_range(date_parts, variable, delimiter, engine, state, context)
 local range_start, range_end = variable[1], variable[2]

 -- Smallest index (year=1, month=2, day=3) among the rendered parts whose
 -- values differ between the two dates; 4 means that none differ.
 local diff_level = 4
 for _, date_part in ipairs(date_parts) do
   local index = date_part_index[date_part.name]
   local start_value = range_start[index]
   if start_value and start_value ~= range_end[index] and index < diff_level then
     diff_level = index
   end
 end

 local irs = {}
 local pending = {}
 local range_delimiter

 -- Render the collected parts as one "first<delimiter>second" group.
 local function flush_pending()
   if #pending == 0 then
     return
   end
   table.insert(irs, self:build_date_range_parts(pending, variable,
     delimiter, engine, state, context, range_delimiter))
   pending = {}
 end

 for _, date_part in ipairs(date_parts) do
   local index = date_part_index[date_part.name]
   if index == diff_level then
     range_delimiter = date_part.range_delimiter or util.unicode["en dash"]
   end
   -- Parts that the first date does not have are skipped entirely.
   if range_start[index] then
     if index < diff_level then
       -- Shared part: close any open group, then render it once.
       flush_pending()
       table.insert(irs, date_part:build_ir(range_start, engine, state, context))
     else
       table.insert(pending, date_part)
     end
   end
 end
 flush_pending()

 local ir = SeqIr:new(irs, self)
 ir.delimiter = delimiter

 return ir
end

--- Build the "first<range delimiter>second" IR for the parts in
-- `range_part_queue`.
--
-- The parts are rendered once for each of the first two dates in
-- `variable`. The suffix of the last part on the first side and the prefix
-- of the first part on the second side are suppressed, so that they do not
-- appear next to the range delimiter.
function Date:build_date_range_parts(range_part_queue, variable, delimiter, engine, state, context, range_delimiter)
 -- Render every queued part for `single_date`; the part at position
 -- `trimmed_position` is built with `trimmed_affix` suppressed.
 local function build_side(single_date, trimmed_position, trimmed_affix)
   local part_irs = {}
   for position, part in ipairs(range_part_queue) do
     local part_ir
     if position == trimmed_position then
       part_ir = part:build_ir(single_date, engine, state, context, trimmed_affix)
     else
       part_ir = part:build_ir(single_date, engine, state, context)
     end
     table.insert(part_irs, part_ir)
   end
   local side_ir = SeqIr:new(part_irs, self)
   side_ir.delimiter = delimiter
   return side_ir
 end

 local start_ir = build_side(variable[1], #range_part_queue, "suffix")
 local delimiter_ir = Rendered:new({PlainText:new(range_delimiter)}, self)
 local end_ir = build_side(variable[2], 1, "prefix")

 return SeqIr:new({start_ir, delimiter_ir, end_ir}, self)
end

--- Compute the sort key of this date.
--
-- Returns `false` when the variable is missing, or when it has neither date
-- parts nor a literal. A date with only a literal sorts as "1" followed by
-- the literal. Otherwise each date becomes "YYYYY-MM-DD", where the year is
-- offset by 10000 and parts that are absent or not shown are 0; the dates
-- of a range are joined with "/".
--
-- An empty `date-parts` list is treated like a missing one, as `build_ir`
-- does. Without that, a literal date carrying an empty list got the sort key
-- "".
function Date:render_sort_key(engine, state, context)
 local date = context:get_variable(self.variable)
 if not date then
   return false
 end

 local ranges = date["date-parts"]
 if not ranges or #ranges == 0 then
   if date.literal then
     return "1" .. date.literal
   end
   return false
 end

 -- Names of the parts that this element actually displays.
 local show_parts = {}
 if self.form then
   -- Same default as build_localized_date_ir when `date_parts` is unset.
   for _, dp_name in ipairs(util.split(self.date_parts or "year-month-day", "%-")) do
     show_parts[dp_name] = true
   end
 else
   for _, date_part in ipairs(self.children) do
     show_parts[date_part.name] = true
   end
 end

 local part_names = {"year", "month", "day"}
 local keys = {}
 for _, range_part in ipairs(ranges) do
   local fields = {}
   for i, dp_name in ipairs(part_names) do
     local value = 0
     if show_parts[dp_name] and range_part[i] then
       value = range_part[i]
     end
     if i == 1 then
       -- Offset the year by 10000 before zero-padding it to five digits.
       fields[i] = string.format("%05d", value + 10000)
     else
       fields[i] = string.format("%02d", value)
     end
   end
   keys[#keys + 1] = table.concat(fields, "-")
 end
 return table.concat(keys, "/")
end


-- [Date-part](https://docs.citationstyles.org/en/stable/specification.html#date-part)
-- Renders one part (year, month or day) of a date.
---@class DatePart: Element
---@field name string "year", "month" or "day"
---@field form string? rendering form; the default depends on the part
---@field text_case string?
---@field range_delimiter string? delimiter between the two dates of a range; an en dash is used when unset
local DatePart = Element:derive("date-part")

--- Build a `DatePart` element from a parsed style node.
-- `node` is the parsed node to read from; the new `DatePart` is returned.
function DatePart:from_node(node)
 local part = DatePart:new()

 part:set_attribute(node, "name")
 part:set_attribute(node, "form")

 -- The strip-periods attribute is only read for the month part.
 if part.name == "month" then
   part:set_strip_periods_attribute(node)
 end

 part:set_formatting_attributes(node)
 part:set_text_case_attribute(node)
 part:set_attribute(node, "range-delimiter")
 part:set_affixes_attributes(node)

 return part
end

--- Build the IR of this part for one date.
--
-- `single_date` is indexed as [1] year, [2] month, [3] day.
-- `suppressed_affix`, when given, names the affix ("prefix" or "suffix" in
-- the calls made from this file) that is removed from the result.
--
-- Returns an IR node with `GroupVar.Missing` when the part renders to
-- nothing. Otherwise the text is rendered with text case and formatting
-- applied; a year is additionally wrapped in a sequence node flagged
-- `is_year`.
function DatePart:build_ir(single_date, engine, state, context, suppressed_affix)
 local part_name = self.name

 local text
 if part_name == "day" then
   text = self:render_day(single_date[3], single_date[2], engine, state, context)
 elseif part_name == "month" then
   text = self:render_month(single_date[2], engine, state, context)
 elseif part_name == "year" then
   text = self:render_year(single_date[1], engine, state, context)
 end

 if not text then
   local missing = Rendered:new({}, self)
   missing.group_var = GroupVar.Missing
   return missing
 end

 local inlines = {PlainText:new(text)}
 local output_format = context.format
 local english = context:is_english()
 output_format:apply_text_case(inlines, self.text_case, english)
 inlines = output_format:with_format(inlines, self.formatting)

 local ir = Rendered:new(inlines, self)
 ir.group_var = GroupVar.Important

 if part_name == "year" then
   ir = SeqIr:new({ir}, self)
   ir.is_year = true
 end

 -- Work on a clone so that dropping an affix does not touch the element.
 local affixes = util.clone(self.affixes)
 if affixes and suppressed_affix then
   affixes[suppressed_affix] = nil
 end
 ir.affixes = affixes

 return ir
end

--- Render the day of a date as text.
--
-- `day` and `month` may be numbers or numeric strings. Returns nil when the
-- day is missing, empty, not numeric, or outside 1..31. The form defaults
-- to "numeric"; "numeric-leading-zeros" pads to two digits and "ordinal"
-- appends the locale's ordinal term. When the locale option
-- `limit_day_ordinals_to_day_1` is set, only day 1 is rendered as an
-- ordinal.
function DatePart:render_day(day, month, engine, state, context)
 if not day or day == "" then
   return nil
 end
 day = tonumber(day)
 -- tonumber() yields nil for non-numeric input; comparing nil would raise.
 if not day or day < 1 or day > 31 then
   return nil
 end
 local form = self.form or "numeric"
 if form == "ordinal" then
   local limit_day_1 = context.locale.style_options.limit_day_ordinals_to_day_1
   if limit_day_1 and day > 1 then
     form = "numeric"
   end
 end
 if form == "numeric-leading-zeros" then
   return string.format("%02d", day)
 elseif form == "ordinal" then
   -- When the “day” date-part is rendered in the “ordinal” form, the ordinal
   -- gender is matched against that of the month term.
   -- A date may have a day but no month; formatting a nil month would raise,
   -- so the gender is left unset in that case.
   -- NOTE(review): assumes get_ordinal_term accepts a nil gender - confirm.
   local gender
   local month_number = tonumber(month)
   if month_number then
     gender = context.locale:get_number_gender(string.format("month-%02d", month_number))
   end
   local suffix = context.locale:get_ordinal_term(day, gender)
   return tostring(day) .. suffix
 else  -- numeric
   return tostring(day)
 end
end

--- Render the month of a date as text.
--
-- Returns nil when `month` is missing, empty, not numeric, or outside 1..24.
-- With the "long" (default) or "short" form, values 1..12 use the
-- "month-NN" term and larger values the "season-NN" term, with the season
-- number taken modulo 4 (0 counting as 4). "numeric-leading-zeros" pads to
-- two digits and any other form prints the number. The result goes through
-- `apply_strip_periods`.
function DatePart:render_month(month, engine, state, context)
 if month == "" or not month then
   return nil
 end
 month = tonumber(month)
 if not month then
   return nil
 end
 if month < 1 or month > 24 then
   return nil
 end

 local form = self.form or "long"
 local text
 if form == "numeric-leading-zeros" then
   text = string.format("%02d", month)
 elseif form == "long" or form == "short" then
   if month >= 1 and month <= 12 then
     text = context:get_simple_term(string.format("month-%02d", month), form)
   else
     local season = month % 4
     if season == 0 then
       season = 4
     end
     text = context:get_simple_term(string.format("season-%02d", season))
   end
 else
   text = tostring(month)
 end

 return self:apply_strip_periods(text)
end

--- Render the year of a date as text.
--
-- Returns nil when `year` is missing, empty, not numeric, or zero. With the
-- "long" form (default), negative years are printed without the sign
-- followed by the "bc" term, years below 1000 are followed by the "ad"
-- term, and other years are printed as they are. The "short" form keeps the
-- last two characters of the year. Any other form renders nothing.
function DatePart:render_year(year, engine, state, context)
 if not year or year == "" then
   return nil
 end
 year = tonumber(year)
 -- tonumber() yields nil for non-numeric input. Without this check the
 -- comparisons below would raise, and the short form would render "il"
 -- from tostring(nil).
 if not year or year == 0 then
   return nil
 end
 local form = self.form or "long"
 if form == "long" then
   if year < 0 then
     return tostring(-year) .. context:get_simple_term("bc")
   elseif year < 1000 then
     return tostring(year) .. context:get_simple_term("ad")
   else
     return tostring(year)
   end
 elseif form == "short" then
   return string.sub(tostring(year), -2)
 end
end

--- Return a copy of this date part.
-- Table-valued fields are copied one level deep, so the copy gets its own
-- table for each of them; all other values are assigned as they are. The
-- copy gets `DatePart` as its metatable.
function DatePart:copy()
 local duplicate = {}
 for key, value in pairs(self) do
   if type(value) ~= "table" then
     duplicate[key] = value
   else
     local inner = {}
     for k, v in pairs(value) do
       inner[k] = v
     end
     duplicate[key] = inner
   end
 end
 return setmetatable(duplicate, DatePart)
end

--- Apply this part's fields on top of `localized_date_part`, in place.
-- A table-valued field is merged entry by entry into the target's field of
-- the same key when the target has one. Every other field, including a
-- table the target lacks, is assigned as it is.
function DatePart:override(localized_date_part)
 for key, value in pairs(self) do
   local existing = localized_date_part[key]
   if type(value) ~= "table" or not existing then
     localized_date_part[key] = value
   else
     for k, v in pairs(value) do
       existing[k] = v
     end
   end
 end
end


-- Public interface of the module: the two element classes defined above.
date_module.Date = Date
date_module.DatePart = DatePart

return date_module