--
-- Copyright (c) 2021-2025 Zeping Lee
-- Released under the MIT license.
-- Repository: https://github.com/zepinglee/citeproc-lua
--

local sort = {}

local uca_languages
local uca_ducet
local uca_collator

local element
local output
local util
local node_date

local using_luatex, kpse = pcall(require, "kpse")
if using_luatex then
 uca_languages = require("lua-uca-languages")
 uca_ducet = require("lua-uca-ducet")
 uca_collator = require("lua-uca-collator")
 element = require("citeproc-element")
 output = require("citeproc-output")
 util = require("citeproc-util")
 node_date = require("citeproc-node-date")
else
 uca_languages = require("citeproc.lua-uca.languages")
 uca_ducet = require("citeproc.lua-uca.ducet")
 uca_collator = require("citeproc.lua-uca.collator")
 element = require("citeproc.element")
 output = require("citeproc.output")
 util = require("citeproc.util")
 node_date = require("citeproc.node-date")
end

local Element = element.Element
local Date = node_date.Date
local InlineElement = output.InlineElement


-- [Sorting](https://docs.citationstyles.org/en/stable/specification.html#sorting)
-- The <sort> element: orders cites or bibliography entries by its <key> children.
---@class Sort: Element
---@field children Key[]  The <key> children, in priority order
---@field sort_directions boolean[]  One entry per key (true = ascending, false = descending), plus a trailing `true` for the original-position tie-breaker
local Sort = Element:derive("sort")

--- Build a Sort element from a parsed <sort> node.
-- Besides the <key> children, precomputes the direction of every key so the
-- comparator does not have to re-read the attributes.
---@param node table  Parsed XML node of the <sort> element
---@return Sort
function Sort:from_node(node)
  local o = Sort:new()
  o.children = {}
  o:process_children_nodes(node)

  local directions = {}
  for index, sort_key in ipairs(o.children) do
    -- Anything other than an explicit "descending" sorts ascending.
    directions[index] = sort_key.sort ~= "descending"
  end
  -- Direction of the extra "original position" key that Sort:sort appends
  -- to every item's key list.
  directions[#directions + 1] = true
  o.sort_directions = directions

  return o
end

-- UCA collators cached per two-letter language code. A single process-wide
-- collator would silently apply the first engine's language rules to every
-- later engine that uses a different locale.
Sort.collators = {}

--- Return the collator for a locale, building and caching it on first use.
-- English uses the plain DUCET collator; other languages are tailored by
-- lua-uca when it supports them, otherwise a warning is emitted (once per
-- language) and the DUCET collator is used as-is.
---@param lang string  Engine locale tag, e.g. "en-US" or "de-DE"
---@return table collator
function Sort.get_collator(lang)
  local language = string.sub(lang, 1, 2)
  local collator = Sort.collators[language]
  if collator then
    return collator
  end
  collator = uca_collator.new(uca_ducet)
  if language ~= "en" then
    local tailor = uca_languages[language]
    if tailor then
      collator = tailor(collator)
    else
      util.warning(string.format("Locale '%s' is not provided by lua-uca. The sorting order may be incorrect.", lang))
    end
  end
  Sort.collators[language] = collator
  return collator
end

--- Sort `items` in place by this element's keys and return the same table.
-- All key values are evaluated once up front; the comparator then only
-- looks them up.
---@param items table[]  Items to sort; each must have an `id`
---@param state table
---@param context table  Evaluation context; must provide `engine` (with `lang` and `registry`)
---@return table[] items
function Sort:sort(items, state, context)
  -- Use the collator for this engine's language. Sort.compare_strings reads
  -- it from Sort.collator.
  Sort.collator = Sort.get_collator(context.engine.lang)

  -- key_map = {
  --   id1 = {key1, key2, ..., original_position},
  --   id2 = {key1, key2, ..., original_position},
  --   ...
  -- }
  local key_map = {}

  -- TODO: optimize: use cached values for fixed keys
  for i, item in ipairs(items) do
    local keys = {}
    key_map[item.id] = keys

    context.id = item.id
    context.cite = item
    context.reference = context.engine.registry.registry[item.id]

    for j, key in ipairs(self.children) do
      if context.reference then
        context.sort_key = key
        keys[j] = key:eval(context.engine, state, context)
      else
        -- The entry is missing
        keys[j] = false
      end
    end
    -- Append the original position so that items with equal sort keys keep
    -- their input order (sort_NameImplicitSortOrderAndForm.txt).
    table.insert(keys, i)
  end

  -- Direction per key; true: ascending, false: descending
  local sort_directions = self.sort_directions
  local function compare_entry(item1, item2)
    return self.compare_entry(key_map, sort_directions, item1, item2)
  end
  table.sort(items, compare_entry)

  return items
end

--- Strict "less than" for two present (non-false) sort values.
-- Strings are compared with the current collator and numbers numerically.
-- A number variable can yield an integer for one item and text for another
-- (e.g. 12 vs "12a"). Such mixed pairs put the number first instead of
-- raising "attempt to compare number with string" inside table.sort.
---@param value1 string|number
---@param value2 string|number
---@return boolean
function Sort.compare(value1, value2)
  local type1, type2 = type(value1), type(value2)
  if type1 ~= type2 and (type1 == "number" or type2 == "number") then
    return type1 == "number"
  elseif type1 == "string" then
    return Sort.compare_strings(value1, value2)
  else
    return value1 < value2
  end
end

--- Strict "less than" for two strings.
-- Uses Sort.collator when one has been set (Sort:sort sets it); otherwise
-- falls back to Lua's built-in byte-wise string ordering.
---@param str1 string
---@param str2 string
---@return boolean
function Sort.compare_strings(str1, str2)
  local collator = Sort.collator
  if not collator then
    return str1 < str2
  end
  return collator:compare_strings(str1, str2)
end

--- Strict "less than" comparator over precomputed key lists (see Sort:sort).
-- Keys are compared in priority order:
-- * A missing value (`false`) sorts after any present value, regardless of
--   the key's direction.
-- * When neither value sorts before the other, the next key decides.
-- * The last key is the item's original position, which keeps items with
--   equal keys in input order.
---@param key_map table  item id -> list of key values
---@param sort_directions boolean[]  true = ascending, false = descending
---@param item1 table
---@param item2 table
---@return boolean
function Sort.compare_entry(key_map, sort_directions, item1, item2)
  local keys1 = key_map[item1.id]
  local keys2 = key_map[item2.id]
  for i, value1 in ipairs(keys1) do
    local value2 = keys2[i]
    if value1 and value2 then
      -- Identical values can never decide the order; skip the comparisons.
      if value1 ~= value2 then
        local first, second = value1, value2
        if not sort_directions[i] then
          first, second = value2, value1
        end
        if Sort.compare(first, second) then
          return true
        end
        -- Distinct strings may still collate as equal (e.g. under UCA when
        -- they differ only in ignorable characters). Stop only if the
        -- reverse order holds; otherwise fall through to the next key.
        if Sort.compare(second, first) then
          return false
        end
      end
    elseif value1 then
      return true
    elseif value2 then
      return false
    end
  end
  return false
end


-- A <key> child of <sort>: the sort value comes from either a variable or a macro.
---@class Key: Element
---@field sort string?  "ascending" (default) or "descending"
---@field variable string?
---@field macro string?
---@field names_min number?
---@field names_use_first number?
---@field names_use_last boolean?
local Key = Element:derive("key")

--- Create a new <key> element whose sort direction defaults to "ascending".
---@return Key
function Key:new()
  local o = Element.new(self)
  -- Set the default on the instance. Assigning to `Key.sort` would rewrite
  -- the shared class table on every construction.
  o.sort = "ascending"
  return o
end

--- Build a Key element from a parsed <key> node.
---@param node table  Parsed XML node of the <key> element
---@return Key
function Key:from_node(node)
  local key = Key:new()
  -- Plain string attributes
  for _, attribute in ipairs({ "sort", "variable", "macro" }) do
    key:set_attribute(node, attribute)
  end
  -- Name-list overrides applied while evaluating the key
  key:set_number_attribute(node, "names-min")
  key:set_number_attribute(node, "names-use-first")
  key:set_bool_attribute(node, "names-use-last")
  return key
end

--- Evaluate this key for the item currently in `context`.
-- Returns a string or number to compare, or `false` when the value is
-- missing (never nil, so it can be stored in the key list without holes).
---@return string|number|false
function Key:eval(engine, state, context)
  local res
  if self.variable then
    local variable_type = util.variable_types[self.variable]
    if variable_type == "name" then
      res = self:eval_name(engine, state, context)
    elseif variable_type == "date" then
      res = self:eval_date(context)
    elseif variable_type == "number" then
      local value = context:get_variable(self.variable)
      -- Purely numeric strings are compared as integers, so "9" sorts
      -- before "10". Anything else (e.g. "12a") stays text.
      if type(value) == "string" and string.match(value, "^%s*%d+%s*$") then
        value = tonumber(value)
      end
      res = value
    else
      res = context:get_variable(self.variable)
      if type(res) == "string" then
        -- Convert the raw string to inline elements and render it with the
        -- current output format.
        local inlines = InlineElement:parse(res, context)
        res = context.format:output(inlines)
      end
    end
  elseif self.macro then
    local macro = context:get_macro(self.macro)
    state:push_macro(self.macro)
    local ir = macro:build_ir(engine, state, context)
    state:pop_macro(self.macro)
    if ir.name_count then
      return ir.name_count
    elseif ir.sort_key ~= nil then
      return ir.sort_key
    end
    local output_format = context.format
    local inlines = ir:flatten(output_format)
    local str = output_format:output(inlines)
    return str
  end
  if res == nil then
    -- make table.insert(_, nil) work
    res = false
  end
  return res
end

--- Evaluate a name-variable key for the current item.
-- The name inheritance from the first context seen is cloned once and
-- reused for every later item.
---@return string|number|false  Rendered names, a name count, or false when the variable is empty
function Key:eval_name(engine, state, context)
  local inheritance = self.name_inheritance
  if not inheritance then
    inheritance = util.clone(context.name_inheritance)
    self.name_inheritance = inheritance
  end

  if not context:get_variable(self.variable) then
    return false
  end

  local ir = inheritance:build_ir(self.variable, nil, nil, engine, state, context)
  -- When the IR carries a name count, that number is the sort value.
  local count = ir.name_count
  if count then
    return count
  end

  local fmt = context.format
  local rendered = fmt:output(ir:flatten(fmt))
  return rendered
end

--- Evaluate a date-variable key for the current item.
-- A numeric year-month-day Date renderer is built on first use and cached
-- on the key.
function Key:eval_date(context)
  local date = self.date
  if not date then
    date = Date:new()
    date.variable = self.variable
    date.form = "numeric"
    date.date_parts = "year-month-day"
    self.date = date
  end
  return date:render_sort_key(context.engine, nil, context)
end


-- Public API of this module
sort.Key = Key
sort.Sort = Sort

return sort