--
-- Copyright (c) 2021-2025 Zeping Lee
-- Released under the MIT license.
-- Repository: https://github.com/zepinglee/citeproc-lua
--

local element = {}

local l = require("lpeg")
local ir_node
local output
local util

local using_luatex, kpse = pcall(require, "kpse")
if using_luatex then
 ir_node = require("citeproc-ir-node")
 output = require("citeproc-output")
 util = require("citeproc-util")
else
 ir_node = require("citeproc.ir-node")
 output = require("citeproc.output")
 util = require("citeproc.util")
end

local GroupVar = ir_node.GroupVar
local SeqIr = ir_node.SeqIr

local InlineElement = output.InlineElement
local Micro = output.Micro


--- Base prototype for style elements. Instances and derived prototypes are
--- created with `Element:new` / `Element:derive` and fall back to this table
--- through their metatable chain.
---@class Element
---@field element_name string?
---@field children Element[]?
---@field element_type_map table<string, Element> Prototypes registered by `Element:derive`, keyed by element name
local Element = {
 element_name = nil,
 children = nil,
 -- Registry written by `Element:derive` and read (through `self`) by
 -- `Element:process_children_nodes` to pick the prototype for a child node.
 element_type_map = {},
}

--- Create an object whose metatable is the receiver, so that missing fields
--- and methods are looked up on the receiving prototype.
---@param element_name string? Name for the new object; defaults to the receiver's name
---@return Element
function Element:new(element_name)
 self.__index = self
 return setmetatable({
   element_name = element_name or self.element_name,
 }, self)
end

--- Create a prototype that inherits from the receiver and register it in
--- `Element.element_type_map` under `element_name`.
--- Entries of `default_options` are copied onto the new prototype after
--- `element_name` has been set, so they take precedence over it.
---@param element_name string
---@param default_options table?
---@return Element
function Element:derive(element_name, default_options)
 local prototype = {
   element_name = element_name or self.element_name,
 }
 for option, default in pairs(default_options or {}) do
   prototype[option] = default
 end

 Element.element_type_map[element_name] = prototype
 self.__index = self
 return setmetatable(prototype, self)
end

---@class Node

--- Build an element instance for a style node.
--- The instance keeps the receiver's `element_name` when it has one and
--- otherwise takes the name reported by the node. `parent` is accepted but
--- not used by this base implementation.
---@param node Node
---@param parent Element?
---@return Element
function Element:from_node(node, parent)
 local instance = self:new()
 local name = self.element_name
 if not name then
   name = node:get_element_name()
 end
 instance.element_name = name
 return instance
end

--- Copy an attribute value from `node` onto this element.
--- The field name is the attribute name with every hyphen replaced by an
--- underscore. Nothing is stored when the node returns a falsy value.
function Element:set_attribute(node, attribute)
 local value = node:get_attribute(attribute)
 if not value then
   return
 end
 local field = string.gsub(attribute, "%-", "_")
 self[field] = value
end

--- Copy a boolean attribute from `node` onto this element.
--- Only the exact strings "true" and "false" are recognised; any other value
--- (including a missing attribute) leaves the element untouched. The field
--- name is the attribute name with hyphens replaced by underscores.
function Element:set_bool_attribute(node, attribute)
 local text = node:get_attribute(attribute)
 local flag
 if text == "true" then
   flag = true
 elseif text == "false" then
   flag = false
 else
   return
 end
 local field = string.gsub(attribute, "%-", "_")
 self[field] = flag
end

--- Copy a numeric attribute from `node` onto this element.
--- The value is converted with `tonumber`, so a value that is not a valid
--- number stores nil. The field name is the attribute name with hyphens
--- replaced by underscores.
function Element:set_number_attribute(node, attribute)
 local text = node:get_attribute(attribute)
 if not text then
   return
 end
 local field = string.gsub(attribute, "%-", "_")
 self[field] = tonumber(text)
end

--- Convert the element children of `node` into `Element` instances and
--- append them to `self.children` (created on demand).
--- The prototype for each child is looked up by element name in
--- `self.element_type_map`; unknown names fall back to the base `Element`.
function Element:process_children_nodes(node)
 if not self.children then
   self.children = {}
 end
 for _, child_node in ipairs(node:get_children()) do
   if child_node:is_element() then
     local name = child_node:get_element_name()
     local prototype = self.element_type_map[name] or Element
     table.insert(self.children, prototype:from_node(child_node, self))
   end
 end
end

--- Copy the inheritable name options found on `node` onto `name`.
--- Most attributes are stored under their own (underscored) name through the
--- `set_*_attribute` helpers; `name-delimiter` and `name-form` are stored
--- without the `name-` prefix, and `names-delimiter` as `names_delimiter`.
--- Note that this is a plain function (dot syntax): `name` is the target.
function Element.make_name_inheritance(name, node)
 -- {setter method, attribute}; processed in this order.
 local same_name_options = {
   {"set_attribute", "and"},
   {"set_attribute", "delimiter-precedes-et-al"},
   {"set_attribute", "delimiter-precedes-last"},
   {"set_number_attribute", "et-al-min"},
   {"set_number_attribute", "et-al-use-first"},
   {"set_number_attribute", "et-al-subsequent-min"},
   {"set_number_attribute", "et-al-subsequent-use-first"},
   {"set_bool_attribute", "et-al-use-last"},
   {"set_bool_attribute", "initialize"},
   {"set_attribute", "initialize-with"},
   {"set_attribute", "name-as-sort-order"},
   {"set_attribute", "sort-separator"},
 }
 for _, option in ipairs(same_name_options) do
   local setter, attribute = option[1], option[2]
   name[setter](name, node, attribute)
 end

 -- {attribute on node, field on name}
 local renamed_options = {
   {"name-delimiter", "delimiter"},
   {"name-form", "form"},
   {"names-delimiter", "names_delimiter"},
 }
 for _, option in ipairs(renamed_options) do
   local value = node:get_attribute(option[1])
   if value then
     name[option[2]] = value
   end
 end
end


--- Build the intermediate representation for this element.
--- The base implementation simply delegates to `build_children_ir` and
--- returns whatever that returns.
function Element:build_ir(engine, state, context)
 return self:build_children_ir(engine, state, context)
end

--- Build a `SeqIr` from the IRs of all children.
--- Children whose `build_ir` returns a falsy value are skipped. The resulting
--- `sort_key` is that of the last child IR with a non-nil `sort_key`. The
--- resulting `group_var` is `Missing` when no child produced an IR,
--- `Important` when any child IR is `Important`, and `Plain` otherwise.
function Element:build_children_ir(engine, state, context)
 local irs = {}
 local sort_key
 local group_var = GroupVar.Plain

 for _, child in ipairs(self.children or {}) do
   local child_ir = child:build_ir(engine, state, context)
   if child_ir then
     if child_ir.sort_key ~= nil then
       sort_key = child_ir.sort_key
     end
     if child_ir.group_var == GroupVar.Important then
       group_var = GroupVar.Important
     end
     irs[#irs + 1] = child_ir
   end
 end

 local ir = SeqIr:new(irs, self)
 ir.sort_key = sort_key
 if #irs == 0 then
   ir.group_var = GroupVar.Missing
 else
   ir.group_var = group_var
 end
 return ir
end

-- Used in cs:group and cs:macro
--- Build a `SeqIr` from the children, tracking the group-suppression state.
--- Returns nil when the element has no `children` table. Otherwise the
--- result carries:
---   * `group_var`: starts as `UnresolvedPlain` and is updated per child IR
---     (see the loop below);
---   * `name_count`: the sum of the children's `name_count` values, or nil if
---     no child IR has one;
---   * `sort_key`: the last non-nil `sort_key` among the child IRs.
function Element:build_group_ir(engine, state, context)
 if not self.children then
   return nil
 end
 local irs = {}
 local name_count
 local ir_sort_key
 local group_var = GroupVar.UnresolvedPlain

 for _, child_element in ipairs(self.children) do
   local child_ir = child_element:build_ir(engine, state, context)

   if child_ir then
     -- cs:group and its child elements are suppressed if
     --   a) at least one rendering element in cs:group calls a variable (either
     --      directly or via a macro), and
     --   b) all variables that are called are empty. This accommodates
     --      descriptive cs:text and `cs:label` elements.
     local child_group_var = child_ir.group_var
     if child_group_var == GroupVar.Important then
       -- An Important child always wins, whatever was seen before.
       group_var = GroupVar.Important
     elseif child_group_var == GroupVar.Plain and group_var == GroupVar.UnresolvedPlain then
       group_var = GroupVar.Plain
     elseif child_group_var == GroupVar.Missing and child_ir._type ~= "YearSuffix" then
       -- A Missing child (other than a YearSuffix IR) downgrades the group
       -- unless it has already been marked Important.
       if group_var == GroupVar.Plain or group_var == GroupVar.UnresolvedPlain then
         group_var = GroupVar.Missing
       end
     end

     if child_ir.name_count then
       if not name_count then
         name_count = 0
       end
       name_count = name_count + child_ir.name_count
     end

     if child_ir.sort_key ~= nil then
       ir_sort_key = child_ir.sort_key
     end

     table.insert(irs, child_ir)
   end
 end

 -- A non-empty nested cs:group is treated as a non-empty variable for the
 -- purposes of determining suppression of the outer cs:group.
 if #irs > 0 and group_var == GroupVar.Plain then
   group_var = GroupVar.Important
 end

 local ir = SeqIr:new(irs, self)
 ir.name_count = name_count
 ir.sort_key = ir_sort_key
 ir.group_var = group_var

 return ir
end

--- Turn a plain string into formatted inline elements.
--- An empty string yields an empty list. Otherwise the string is passed
--- through `apply_strip_periods`, parsed into inlines, text-cased, wrapped in
--- a single `Micro`, and then decorated with this element's formatting,
--- affixes/quotes and display settings by the output format.
---@param str string
---@param context Context
---@return InlineElement[]
function Element:render_text_inlines(str, context)
 if str == "" then
   return {}
 end

 local text = self:apply_strip_periods(str)
 -- TODO: try links

 local formatter = context.format
 -- Localized quote marks are only fetched when this element is quoted.
 local quote_marks = nil
 if self.quotes then
   quote_marks = context:get_localized_quotes()
 end

 local parsed = InlineElement:parse(text, context)
 local is_english = context:is_english()
 formatter:apply_text_case(parsed, self.text_case, is_english)

 local wrapped = formatter:with_format({Micro:new(parsed)}, self.formatting)
 wrapped = formatter:affixed_quoted(wrapped, self.affixes, quote_marks)
 return formatter:with_display(wrapped, self.display)
end

--- Collect the formatting attributes present on `node` into
--- `self.formatting`, keyed by the original (hyphenated) attribute name.
--- The table is only created once at least one attribute is found.
function Element:set_formatting_attributes(node)
 local names = {
   "font-style",
   "font-variant",
   "font-weight",
   "text-decoration",
   "vertical-align",
 }
 for i = 1, #names do
   local name = names[i]
   local value = node:get_attribute(name)
   if value then
     if not self.formatting then
       self.formatting = {}
     end
     self.formatting[name] = value
   end
 end
end

--- Collect the `prefix` and `suffix` attributes present on `node` into
--- `self.affixes`. The table is only created once one of them is found.
function Element:set_affixes_attributes(node)
 local names = {"prefix", "suffix"}
 for i = 1, #names do
   local name = names[i]
   local value = node:get_attribute(name)
   if value then
     if not self.affixes then
       self.affixes = {}
     end
     self.affixes[name] = value
   end
 end
end

--- Store the node's `delimiter` attribute on this element as `self.delimiter`.
--- Despite the `get_` prefix, this sets a field via `set_attribute` and
--- returns nothing.
function Element:get_delimiter_attribute(node)
 self:set_attribute(node, "delimiter")
end

--- Store the node's `display` attribute on this element as `self.display`
--- (via `set_attribute`, so nothing is stored when the attribute is absent).
function Element:set_display_attribute(node)
 self:set_attribute(node, "display")
end

--- Store the node's `quotes` attribute on this element as the boolean
--- `self.quotes` (via `set_bool_attribute`, which only accepts "true" and
--- "false").
function Element:set_quotes_attribute(node)
 self:set_bool_attribute(node, "quotes")
end

--- Store the node's `strip-periods` attribute on this element as the boolean
--- `self.strip_periods` (via `set_bool_attribute`, which replaces the hyphen
--- with an underscore in the field name).
function Element:set_strip_periods_attribute(node)
 self:set_bool_attribute(node, "strip-periods")
end

--- Store the node's `text-case` attribute on this element as
--- `self.text_case` (via `set_attribute`, which replaces the hyphen with an
--- underscore in the field name).
function Element:set_text_case_attribute(node)
 self:set_attribute(node, "text-case")
end

-- function Element:apply_formatting(ir)
--   local attributes = {
--     "font_style",
--     "font_variant",
--     "font_weight",
--     "text_decoration",
--     "vertical_align",
--   }
--   for _, attribute in ipairs(attributes) do
--     local value = self[attribute]
--     if value then
--       if not ir.formatting then
--         ir.formatting = {}
--       end
--       ir.formatting[attribute] = value
--     end
--   end
--   return ir
-- end

--- Copy this element's `prefix` and `suffix` fields onto `ir`.
--- A field that is unset on the element leaves the IR's value untouched.
--- A falsy `ir` is returned unchanged.
function Element:apply_affixes(ir)
 if not ir then
   return ir
 end
 local prefix, suffix = self.prefix, self.suffix
 if prefix then
   ir.prefix = prefix
 end
 if suffix then
   ir.suffix = suffix
 end
 return ir
end

--- Set `ir.delimiter` to this element's delimiter, but only for IRs that
--- have a `children` field. The (possibly falsy) `ir` is always returned.
function Element:apply_delimiter(ir)
 local has_children = ir and ir.children
 if has_children then
   ir.delimiter = self.delimiter
 end
 return ir
end

--- Copy this element's `display` setting onto `ir`.
--- Like the other `apply_*` helpers, a nil `ir` is tolerated and returned
--- unchanged instead of raising an "attempt to index a nil value" error.
function Element:apply_display(ir)
 if ir then
   ir.display = self.display
 end
 return ir
end

--- Mark `ir` as quoted when this element has `quotes` enabled.
--- The quote-mark fields are cleared and `punctuation_in_quote` is reset to
--- false. When `ir` is falsy or quotes are off, `ir` is returned untouched.
function Element:apply_quotes(ir)
 if not (ir and self.quotes) then
   return ir
 end
 ir.quotes = true
 -- NOTE(review): this makes the IR its own single child (a self-reference);
 -- confirm that callers never walk `children` recursively on such a node.
 ir.children = {ir}
 ir.open_quote = nil
 ir.close_quote = nil
 ir.open_inner_quote = nil
 ir.close_inner_quote = nil
 ir.punctuation_in_quote = false
 return ir
end

--- Remove every "." from `str` when this element has `strip_periods` set.
--- In all other cases (including a falsy `str`) the input is returned as is.
function Element:apply_strip_periods(str)
 if not (str and self.strip_periods) then
   return str
 end
 -- Assign to a local first so only the string (not gsub's count) is returned.
 local stripped = string.gsub(str, "%.", "")
 return stripped
end


--- Format a number-like value (single numbers, ranges and lists thereof).
--- For the `locator` variable the actual variable name is taken from the
--- context's `label` variable (falling back to "page" after reporting an
--- error). The value is split into `{start, stop, delimiter}` parts, each
--- part is formatted according to `form`, and the parts are joined again.
--- Ranges are joined with an en dash, or with the `page-range-delimiter`
--- term for pages when that term is available.
---@param number string Non-empty string
---@param variable string
---@param form string
---@param context Context
---@return string
function Element:format_number(number, variable, form, context)
 number = util.strip(number)
 if variable == "locator" then
   local label = context:get_variable("label")
   if not label or type(label) ~= "string" then
     util.error("Invalid locator label")
     label = "page"
   end
   variable = label
 end
 form = form or "numeric"

 -- part_list looks like
 -- {
 --   {"1", "",  " & "}
 --   {"5", "8", ", "}
 -- }
 local part_list = self:split_number_parts_lpeg(number, context)
 local is_ordinal = form == "ordinal" or form == "long-ordinal"

 for _, parts in ipairs(part_list) do
   if form == "roman" then
     self:format_roman_number_parts(parts)
   elseif is_ordinal then
     local gender = context.locale:get_number_gender(variable)
     self:format_ordinal_number_parts(parts, form, gender, context)
   elseif variable == "page" and parts[2] ~= "" then
     self:format_page_range(parts, context.style.page_range_format)
   else
     self:format_numeric_number_parts(parts)
   end
 end

 local range_delimiter = util.unicode["en dash"]
 if variable == "page" then
   range_delimiter = context:get_simple_term("page-range-delimiter") or range_delimiter
 end

 local res = ""
 for _, parts in ipairs(part_list) do
   local piece = parts[1]
   if parts[2] ~= "" then
     piece = piece .. range_delimiter .. parts[2]
   end
   res = res .. piece .. parts[3]
 end
 return res
end

---@alias NumberToken {type: string, value: string, delimiter_type: string}

--- Tokenize a number string into alternating value and delimiter tokens.
---
--- Recognised delimiters are a comma, the locale's "and" term and "and"
--- symbol (defaulting to "and" and "&" when no context is given), a literal
--- "&", a hyphen and an en dash, each with surrounding whitespace. An escaped
--- hyphen (`\-`) is consumed as part of a value and later unescaped.
---
--- Value tokens that look like numbers (alphanumerics containing a digit, or
--- all-lowercase / all-uppercase roman-numeral letters) get type "number".
--- Delimiter tokens get `delimiter_type` "range" (hyphen / en dash) or "and"
--- (everything else). Parsing stops at the first value that is not a number,
--- or at a range delimiter that follows another range delimiter two tokens
--- earlier; everything from the stop point on is merged into a single token
--- of type "string".
---
--- Returns an empty list when the grammar produces no match.
---@param number string
---@param context Context
---@return NumberToken[]
function Element:parse_number_tokens(number, context)
 local and_text = "and"
 local and_symbol = "&"
 if context then
   and_text = context.locale:get_simple_term("and") or "and"
   and_symbol = context.locale:get_simple_term("and", "symbol") or "&"
 end

 ---@diagnostic disable codestyle-check
 local space = l.S(" \t\r\n")
 -- Ordered choice: the longer ", and" / ", &" forms are tried before the
 -- bare word/symbol, comma and dash forms.
 local delimiter_patt = space^0 * l.P(",") * space^0 * l.P(and_text) * space^1
     + space^0 * l.P(",") * space^0 * l.P(and_symbol) * space^0
     + space^0 * l.P(",") * space^0 * l.P("&") * space^0
     + space^1 * l.P(and_text) * space^1
     + space^0 * l.P(and_symbol) * space^0
     + space^0 * l.P("&") * space^0
     + space^0 * l.P(",") * space^0
     + space^0 * l.P("-") * space^0
     + space^0 * l.P(util.unicode["en dash"]) * space^0
 -- Consecutive delimiters are captured together as one delimiter token.
 local delimiter = l.C(delimiter_patt^1) / function (delimiter)
   return {
     type = "delimiter",
     value = delimiter,
   }
 end
 -- A value is a run of characters at which no delimiter starts; "\-" is
 -- taken as a unit so that the escaped hyphen stays inside the value.
 local token_patt = l.C((l.P("\\-") + 1 - delimiter_patt)^1) / function (token)
   return {
     type = "string",
     value = token,
   }
 end
 -- value (delimiter value)*, or nothing at all. The match is not anchored
 -- at the end of the input.
 local grammer = l.Ct((token_patt * (delimiter * token_patt)^0)^-1)
 ---@diagnostic enable codestyle-check
 local tokens = grammer:match(number)

 if not tokens then
   return {}
 end
 ---@cast tokens NumberToken[]

 -- Normalise the raw text: unescape hyphens in values; canonical spacing
 -- for commas and ampersands in delimiters.
 for i, token in ipairs(tokens) do
   if token.type == "string" then
     token.value = string.gsub(token.value, "\\%-", "-")
   elseif token.type == "delimiter" then
     token.value = string.gsub(token.value, "%s*,%s*", ", ")
     -- NOTE(review): `and_symbol` is used as a gsub replacement string, so a
     -- "%" in the locale term would be treated as an escape -- confirm that
     -- this cannot occur.
     token.value = string.gsub(token.value, "&", and_symbol)
     token.value = string.gsub(token.value, "%s*&%s*", " & ")
   end
 end

 -- Classify tokens and find where (if anywhere) numeric parsing must stop.
 local stop_index = 0
 for i, token in ipairs(tokens) do
   if token.type == "string" then
     if string.match(token.value, "^%w*%d+%w*$")
         or string.match(token.value, "^[mdclxvi]+$")
         or string.match(token.value, "^[MDCLXVI]+$") then
       token.type = "number"
     else
       -- Not a number: stop here, and include the delimiter just before it
       -- (if any) in the trailing text.
       stop_index = i
       if i > 1 and tokens[i - 1].type == "delimiter" then
         stop_index = i - 1
       end
       break
     end
   elseif token.type == "delimiter" then
     token.delimiter_type = "and"
     if string.match(token.value, "^%s*-%s*$")
         or string.match(token.value, "^%s*–%s*$") then
       token.delimiter_type = "range"
       -- A second range delimiter right after a range (as in "1-2-3")
       -- ends numeric parsing at this delimiter.
       if i > 2 and tokens[i - 2].delimiter_type == "range" then
         stop_index = i
         break
       end
     end
   end
 end

 -- Merge the token at the stop point and everything after it into one
 -- "string" token, then drop the merged tokens (from the end backwards).
 if stop_index > 0 then
   local token = tokens[stop_index]
   token.type = "string"
   for i = stop_index + 1, #tokens do
     token.value = token.value .. tokens[i].value
   end
   for i = #tokens, stop_index + 1, -1 do
     table.remove(tokens, i)
   end
 end

 return tokens
end

--- Group the tokens of `parse_number_tokens` into `{start, stop, delimiter}`
--- triples, e.g.
--- {
---   {"1", "",  " & "}
---   {"5", "8", ", "}
--- }
--- A number opens a new triple when it is the first token or follows an
--- "and" delimiter; otherwise it becomes the `stop` of the current triple.
--- "and" delimiters and trailing non-number text are stored in the third
--- slot of the current triple (range delimiters are not stored).
function Element:split_number_parts_lpeg(number, context)
 local tokens = self:parse_number_tokens(number, context)
 local parts_list = {}
 for index, token in ipairs(tokens) do
   local kind = token.type
   if kind == "number" then
     local opens_new_part = index == 1 or tokens[index - 1].delimiter_type == "and"
     if opens_new_part then
       parts_list[#parts_list + 1] = {token.value, "", ""}
     else
       parts_list[#parts_list][2] = token.value
     end
   elseif kind == "delimiter" then
     if token.delimiter_type == "and" then
       parts_list[#parts_list][3] = token.value
     end
   elseif #parts_list > 0 then
     -- Trailing free text is attached to the last triple.
     parts_list[#parts_list][3] = token.value
   else
     parts_list[#parts_list + 1] = {token.value, "", ""}
   end
 end
 return parts_list
end


--- Split a number string into `{start, stop, delimiter}` triples using plain
--- Lua patterns.
--- The input is split at commas and ampersands; each delimiter is normalised
--- ("," -> ", ", "&" -> the locale's "and" symbol padded with spaces, or
--- " & "). A piece containing exactly one hyphen is split into start and
--- stop, unless the hyphen was escaped with a backslash, in which case the
--- piece is kept as a single value with a literal hyphen.
function Element:split_number_parts(number, context)
 local localized_and = context.locale:get_simple_term("and", "symbol")
 if localized_and then
   localized_and = " " .. localized_and .. " "
 end
 local normalized_delimiters = {
   [","] = ", ",
   ["&"] = localized_and or " & ",
   ["and"] = " and ",
   ["et"] = " et ",
 }

 local part_list = {}
 for _, tuple in ipairs(util.split_multiple(number, "%s*[,&]%s*", true)) do
   local single_number, delim = table.unpack(tuple)
   delim = util.strip(delim)
   delim = normalized_delimiters[delim] or delim

   local start, stop = single_number, ""
   local pieces = util.split(start, "%s*%-%s*")
   if #pieces == 2 then
     start, stop = table.unpack(pieces)
     if util.endswith(start, "\\") then
       -- Escaped hyphen: drop the backslash and rejoin as one value.
       start = string.sub(start, 1, -2) .. "-" .. stop
       stop = ""
     end
   end
   table.insert(part_list, {start, stop, delim})
 end
 return part_list
end

--- Convert the start and stop of a number part to roman numerals in place.
--- Only values consisting solely of decimal digits are converted. Anything
--- else (e.g. "12a" or an empty stop) is kept as is, because `tonumber`
--- would return nil or misread such strings (for instance "0x10" as
--- hexadecimal).
---@param number_parts string[] `{start, stop, delimiter}`; modified in place
function Element:format_roman_number_parts(number_parts)
 for i = 1, 2 do
   local part = number_parts[i]
   if string.match(part, "^%d+$") then
     number_parts[i] = util.convert_roman(tonumber(part))
   end
 end
end

--- Turn the start and stop of a number part into ordinals in place.
--- With `form == "long-ordinal"` the numbers 1 to 10 use the locale's
--- `long-ordinal-NN` term. Every other number, and any number whose long
--- term is not available, gets the locale's ordinal suffix appended if one
--- exists.
---@param number_parts string[] `{start, stop, delimiter}`; modified in place
---@param form string "ordinal" or "long-ordinal"
---@param gender string? Passed through to the locale's ordinal lookup
---@param context Context
function Element:format_ordinal_number_parts(number_parts, form, gender, context)
 for i = 1, 2 do
   local part = number_parts[i]
   -- Values like "2nd" are kept in their original form.
   if string.match(part, "^%d+$") then
     local number = tonumber(part)
     local long_term
     if form == "long-ordinal" and number >= 1 and number <= 10 then
       long_term = context:get_simple_term(string.format("long-ordinal-%02d", number))
     end
     if long_term then
       number_parts[i] = long_term
     else
       -- Never store nil here: the caller concatenates the parts later.
       local suffix = context.locale:get_ordinal_term(number, gender)
       if suffix then
         number_parts[i] = part .. suffix
       end
     end
   end
 end
end

--- Format a number part in the plain "numeric" form.
--- Currently a no-op: `number_parts` is left unchanged. The commented-out
--- code below is an earlier approach that collapsed a range into a single
--- hyphenated value when both ends shared the same prefix.
function Element:format_numeric_number_parts(number_parts)
 -- if number_parts[2] ~= "" then
 --   local first_prefix = string.match(number_parts[1], "^(.-)%d+")
 --   local second_prefix = string.match(number_parts[2], "^(.-)%d+")
 --   if first_prefix == second_prefix then
 --     number_parts[1] = number_parts[1] .. "-" .. number_parts[2]
 --     number_parts[2] = ""
 --   end
 -- end
end

-- https://docs.citationstyles.org/en/stable/specification.html#appendix-v-page-range-formats
--- Abbreviate or expand the stop page of a page range in place.
---
--- * Purely alphabetic bounds (e.g. roman numerals) are left untouched.
--- * Bounds with different non-numeric prefixes are not a valid range: they
---   are joined with a hyphen into the start slot and the stop is cleared.
--- * Bounds that do not end in digits cannot be abbreviated and are left
---   untouched.
--- * Otherwise the stop is rewritten according to `page_range_format`; a nil
---   or unknown format leaves the range as it is.
---@param number_parts string[] `{start, stop, delimiter}`; modified in place
---@param page_range_format string? One of "chicago-16", "chicago-15", "expanded", "minimal", "minimal-two"
function Element:format_page_range(number_parts, page_range_format)
 local start = number_parts[1]
 local stop = number_parts[2]

 if string.match(start, "^%a+$") and string.match(stop, "^%a+$") then
   -- CMoS example: xxv–xxviii
   return stop
 end

 local start_prefix, start_num = string.match(start, "^(.-)(%d+)$")
 local stop_prefix, stop_num = string.match(stop, "^(.-)(%d+)$")
 if start_prefix ~= stop_prefix then
   -- Not valid range: "n11564-1568" -> "n11564-1568"
   -- 110-N6
   -- N110-P5
   number_parts[1] = start .. "-" .. stop
   number_parts[2] = ""
   return
 end

 if not start_num or not stop_num then
   -- Neither bound ends in digits (e.g. "12a"–"15b"): both matches failed,
   -- so the prefixes compared equal as nil. There is no numeric part to
   -- abbreviate, and passing nil on would raise an error.
   return
 end

 if not page_range_format then
   return
 end
 if page_range_format == "chicago-16" then
   stop = self:_format_range_chicago_16(start_num, stop_num)
 elseif page_range_format == "chicago-15" then
   stop = self:_format_range_chicago_15(start_num, stop_num)
 elseif page_range_format == "expanded" then
   -- Only the expanded form repeats the prefix on the stop page.
   stop = stop_prefix .. self:_format_range_expanded(start_num, stop_num)
 elseif page_range_format == "minimal" then
   stop = self:_format_range_minimal(start_num, stop_num)
 elseif page_range_format == "minimal-two" then
   stop = self:_format_range_minimal(start_num, stop_num, 2)
 end
 number_parts[2] = stop
end

--- Abbreviate the stop page following the Chicago Manual of Style, 16th ed.
---   * start below 100 or a multiple of 100: all digits of the stop;
---   * start ending in 01..09: only the changed digits;
---   * otherwise: at least two digits, more if more have changed.
---@param start string Digits of the first page
---@param stop string Digits of the last page (possibly already abbreviated)
---@return string
function Element:_format_range_chicago_16(start, stop)
 -- Work on the fully expanded stop so that the digit comparison lines up.
 stop = self:_format_range_expanded(start, stop)
 if #start < 3 or string.sub(start, -2) == "00" then
   return stop
 elseif string.sub(start, -2, -2) == "0" then
   return self:_format_range_minimal(start, stop)
 else
   return self:_format_range_minimal(start, stop, 2)
 end
end

--- Abbreviate the stop page following the Chicago Manual of Style, 15th ed.
--- Same as the 16th-edition rules, except that a four-digit start whose
--- stop differs in exactly three digits keeps all digits.
---@param start string Digits of the first page
---@param stop string Digits of the last page (possibly already abbreviated)
---@return string
function Element:_format_range_chicago_15(start, stop)
 if #start < 3 or string.sub(start, -2) == "00" then
   return self:_format_range_expanded(start, stop)
 end

 local expanded = self:_format_range_expanded(start, stop)
 local changed = self:_format_range_minimal(start, expanded)
 if string.sub(start, -2, -2) == "0" then
   return changed
 end
 if #start == 4 and #changed == 3 then
   return self:_format_range_expanded(start, expanded)
 end
 return self:_format_range_minimal(start, expanded, 2)
end

--- Expand an abbreviated stop page to the full number of digits, taking the
--- missing leading digits from `start`: "1234", "56" -> "1256".
--- A stop that is at least as long as the start is returned unchanged.
---@param start string
---@param stop string
---@return string
function Element:_format_range_expanded(start, stop)
 local missing = #start - #stop
 if missing <= 0 then
   return stop
 end
 return string.sub(start, 1, missing) .. stop
end

--- Abbreviate a stop page to the digits that differ from `start`, keeping at
--- least `threshold` trailing digits.
--- The stop is right-aligned against the start and compared digit by digit;
--- the result begins at the first differing digit. A stop longer than the
--- start is returned unchanged.
---@param start string
---@param stop string
---@param threshold integer? Number of minimal digits
---@return string
function Element:_format_range_minimal(start, stop, threshold)
 threshold = threshold or 1
 if #start < #stop then
   return stop
 end
 -- Offset that right-aligns `stop` against `start`.
 local shift = #start - #stop
 local last_checked = #stop - threshold
 local pos = 1
 while pos <= last_checked do
   local stop_digit = string.sub(stop, pos, pos)
   local start_digit = string.sub(start, pos + shift, pos + shift)
   if stop_digit ~= start_digit then
     return string.sub(stop, pos)
   end
   pos = pos + 1
 end
 return string.sub(stop, -threshold)
end

element.Element = Element

return element