--
-- Copyright (c) 2021-2025 Zeping Lee
-- Released under the MIT license.
-- Repository: https://github.com/zepinglee/citeproc-lua
--

local element = {}

local l = require("lpeg")
local ir_node
local output
local util

local using_luatex, kpse = pcall(require, "kpse")
if using_luatex then
  ir_node = require("citeproc-ir-node")
  output = require("citeproc-output")
  util = require("citeproc-util")
else
  ir_node = require("citeproc.ir-node")
  output = require("citeproc.output")
  util = require("citeproc.util")
end

local GroupVar = ir_node.GroupVar
local SeqIr = ir_node.SeqIr

local InlineElement = output.InlineElement
local Micro = output.Micro


---@class Element
---@field element_name string? Name of the CSL element this prototype or instance stands for
---@field children Element[]? Child elements; created on demand by `process_children_nodes`
---@field element_type_map table<string, Element> Registry filled by `Element:derive`, keyed by element name
local Element = {}

-- `element_name` and `children` stay unset on the base prototype; only the
-- registry shared by every derived type needs a concrete value here.
Element.element_type_map = {}

--- Create an instance whose missing fields are looked up on the receiver.
---@param element_name string? Name for the new instance; defaults to the receiver's own name
---@return Element
function Element:new(element_name)
  -- The receiver doubles as the metatable, so it must expose itself via
  -- `__index` for the fallback lookup to work.
  self.__index = self
  return setmetatable({
    element_name = element_name or self.element_name,
  }, self)
end

--- Create a prototype that inherits from the receiver and register it in
--- `Element.element_type_map` under `element_name`.
---@param element_name string Key used for the registry entry
---@param default_options table? Fields copied (shallowly) onto the new prototype
---@return Element
function Element:derive(element_name, default_options)
  local derived = {
    element_name = element_name or self.element_name,
  }

  -- Option values are copied by reference, not cloned.
  for key, value in pairs(default_options or {}) do
    derived[key] = value
  end

  Element.element_type_map[element_name] = derived

  -- Unknown fields on the derived prototype resolve through the receiver.
  self.__index = self
  return setmetatable(derived, self)
end

---@class Node

--- Instantiate the receiver for a parsed node.
---@param node Node Source node; asked for its element name when the receiver has none
---@param parent Element? Not used by this base implementation
---@return Element
function Element:from_node(node, parent)
  local instance = self:new()
  -- A prototype with a fixed name keeps it; otherwise take the node's name.
  instance.element_name = self.element_name or node:get_element_name()
  return instance
end

--- Copy a node attribute onto this element when the node provides it.
--- The field name is the attribute name with hyphens turned into underscores.
---@param node Node
---@param attribute string
function Element:set_attribute(node, attribute)
  local value = node:get_attribute(attribute)
  if not value then
    return
  end
  local field = string.gsub(attribute, "%-", "_")
  self[field] = value
end

--- Copy a boolean node attribute onto this element.
--- Only the literal strings "true" and "false" are recognized; anything else
--- (including a missing attribute) leaves the element untouched.
---@param node Node
---@param attribute string
function Element:set_bool_attribute(node, attribute)
  local value = node:get_attribute(attribute)
  local flag
  if value == "true" then
    flag = true
  elseif value == "false" then
    flag = false
  else
    return
  end
  local field = string.gsub(attribute, "%-", "_")
  self[field] = flag
end

--- Copy a numeric node attribute onto this element, converted with `tonumber`.
--- The field name is the attribute name with hyphens turned into underscores.
---@param node Node
---@param attribute string
function Element:set_number_attribute(node, attribute)
  local raw = node:get_attribute(attribute)
  if not raw then
    return
  end
  local field = string.gsub(attribute, "%-", "_")
  self[field] = tonumber(raw)
end

--- Build child elements from the element children of `node` and append them
--- to `self.children` (created first when absent).
--- Each child's type is looked up by element name in `element_type_map`,
--- falling back to the generic `Element`.
---@param node Node
function Element:process_children_nodes(node)
  if not self.children then
    self.children = {}
  end

  for _, child_node in ipairs(node:get_children()) do
    -- Non-element children are skipped.
    if child_node:is_element() then
      local name = child_node:get_element_name()
      local prototype = self.element_type_map[name] or Element
      local child_element = prototype:from_node(child_node, self)
      local siblings = self.children
      siblings[#siblings + 1] = child_element
    end
  end
end

--- Copy the inheritable name options found on `node` onto `name`.
--- This is a plain function (no `self`): `name` receives the values and
--- `node` supplies the attributes.
---@param name Element Target whose `set_*attribute` methods are called
---@param node Node Source of the attributes
function Element.make_name_inheritance(name, node)
  -- Attributes stored under their own (underscored) name, in processing order.
  local direct = {
    {"set_attribute", "and"},
    {"set_attribute", "delimiter-precedes-et-al"},
    {"set_attribute", "delimiter-precedes-last"},
    {"set_number_attribute", "et-al-min"},
    {"set_number_attribute", "et-al-use-first"},
    {"set_number_attribute", "et-al-subsequent-min"},
    {"set_number_attribute", "et-al-subsequent-use-first"},
    {"set_bool_attribute", "et-al-use-last"},
    {"set_bool_attribute", "initialize"},
    {"set_attribute", "initialize-with"},
    {"set_attribute", "name-as-sort-order"},
    {"set_attribute", "sort-separator"},
  }
  for _, entry in ipairs(direct) do
    local setter, attribute = entry[1], entry[2]
    name[setter](name, node, attribute)
  end

  -- Attributes stored under a field name that differs from the attribute name.
  local renamed = {
    {"name-delimiter", "delimiter"},
    {"name-form", "form"},
    {"names-delimiter", "names_delimiter"},
  }
  for _, entry in ipairs(renamed) do
    local value = node:get_attribute(entry[1])
    if value then
      name[entry[2]] = value
    end
  end
end


--- Build the intermediate representation for this element.
--- The base implementation delegates to `build_children_ir` and passes its
--- arguments through unchanged; `build_children_ir` and `build_group_ir`
--- call this method on each child element.
function Element:build_ir(engine, state, context)
  return self:build_children_ir(engine, state, context)
end

--- Build a `SeqIr` from the IRs of all children.
--- The result carries the sort key of the last child that has one. Its
--- `group_var` is `Missing` when no child produced an IR, `Important` when
--- any child IR is `Important`, and `Plain` otherwise.
function Element:build_children_ir(engine, state, context)
  local irs = {}
  local sort_key
  local group_var = GroupVar.Plain

  for _, child in ipairs(self.children or {}) do
    local child_ir = child:build_ir(engine, state, context)
    if child_ir then
      if child_ir.sort_key ~= nil then
        sort_key = child_ir.sort_key
      end
      if child_ir.group_var == GroupVar.Important then
        group_var = GroupVar.Important
      end
      irs[#irs + 1] = child_ir
    end
  end

  -- Nothing rendered at all counts as a missing variable.
  if #irs == 0 then
    group_var = GroupVar.Missing
  end

  local ir = SeqIr:new(irs, self)
  ir.sort_key = sort_key
  ir.group_var = group_var
  return ir
end

-- Used in cs:group and cs:macro
--- Build a `SeqIr` from the children and derive the group's `group_var`
--- from theirs. Returns nil when the element has no `children` table.
--- `name_count` is the sum over children that report one, and `sort_key`
--- is taken from the last child that has one.
function Element:build_group_ir(engine, state, context)
  if not self.children then
    return nil
  end

  local child_irs = {}
  local total_names
  local sort_key
  local verdict = GroupVar.UnresolvedPlain

  for _, child in ipairs(self.children) do
    local child_ir = child:build_ir(engine, state, context)

    if child_ir then
      -- cs:group and its child elements are suppressed if
      --   a) at least one rendering element in cs:group calls a variable (either
      --      directly or via a macro), and
      --   b) all variables that are called are empty. This accommodates
      --      descriptive cs:text and `cs:label` elements.
      local child_var = child_ir.group_var
      if child_var == GroupVar.Important then
        verdict = GroupVar.Important
      elseif child_var == GroupVar.Plain and verdict == GroupVar.UnresolvedPlain then
        verdict = GroupVar.Plain
      elseif child_var == GroupVar.Missing
          and child_ir._type ~= "YearSuffix"
          and (verdict == GroupVar.Plain or verdict == GroupVar.UnresolvedPlain) then
        -- A missing variable only downgrades a group that has not already
        -- seen an important child; year-suffix IRs are ignored here.
        verdict = GroupVar.Missing
      end

      if child_ir.name_count then
        total_names = (total_names or 0) + child_ir.name_count
      end

      if child_ir.sort_key ~= nil then
        sort_key = child_ir.sort_key
      end

      child_irs[#child_irs + 1] = child_ir
    end
  end

  -- A non-empty nested cs:group is treated as a non-empty variable for the
  -- purposes of determining suppression of the outer cs:group.
  if #child_irs > 0 and verdict == GroupVar.Plain then
    verdict = GroupVar.Important
  end

  local ir = SeqIr:new(child_irs, self)
  ir.name_count = total_names
  ir.sort_key = sort_key
  ir.group_var = verdict
  return ir
end

--- Turn a string into inline elements carrying this element's text case,
--- formatting, affixes, quotes and display settings.
--- An empty string produces an empty list without touching `context`.
---@param str string
---@param context Context
---@return InlineElement[]
function Element:render_text_inlines(str, context)
  if str == "" then
    return {}
  end

  local text = self:apply_strip_periods(str)
  -- TODO: try links

  local formatter = context.format
  local quote_marks = nil
  if self.quotes then
    quote_marks = context:get_localized_quotes()
  end

  local parsed = InlineElement:parse(text, context)
  local is_english = context:is_english()
  formatter:apply_text_case(parsed, self.text_case, is_english)

  -- The parsed inlines are wrapped in a single Micro before formatting,
  -- affixes/quotes and display are layered on, in that order.
  local wrapped = {Micro:new(parsed)}
  wrapped = formatter:with_format(wrapped, self.formatting)
  wrapped = formatter:affixed_quoted(wrapped, self.affixes, quote_marks)
  return formatter:with_display(wrapped, self.display)
end

--- Collect the font and decoration attributes present on `node` into
--- `self.formatting`, keyed by the original hyphenated attribute name.
--- The table is only created once an attribute is actually found.
---@param node Node
function Element:set_formatting_attributes(node)
  local names = {
    "font-style",
    "font-variant",
    "font-weight",
    "text-decoration",
    "vertical-align",
  }
  for i = 1, #names do
    local attribute = names[i]
    local value = node:get_attribute(attribute)
    if value then
      if not self.formatting then
        self.formatting = {}
      end
      self.formatting[attribute] = value
    end
  end
end

--- Store the `prefix` and `suffix` attributes of `node`, when present, in
--- `self.affixes`. The table is only created once an affix is found.
---@param node Node
function Element:set_affixes_attributes(node)
  local names = {"prefix", "suffix"}
  for i = 1, #names do
    local attribute = names[i]
    local value = node:get_attribute(attribute)
    if value then
      if not self.affixes then
        self.affixes = {}
      end
      self.affixes[attribute] = value
    end
  end
end

--- Store the node's `delimiter` attribute, when present, as `self.delimiter`.
--- Despite the `get_` prefix this is a setter like its `set_*_attribute`
--- siblings: it returns nothing and only writes to the element.
---@param node Node
function Element:get_delimiter_attribute(node)
  self:set_attribute(node, "delimiter")
end

--- Store the node's `display` attribute, when present, as `self.display`.
---@param node Node
function Element:set_display_attribute(node)
  self:set_attribute(node, "display")
end

--- Store the node's `quotes` attribute as the boolean `self.quotes` when it
--- is the literal string "true" or "false".
---@param node Node
function Element:set_quotes_attribute(node)
  self:set_bool_attribute(node, "quotes")
end

--- Store the node's `strip-periods` attribute as the boolean
--- `self.strip_periods` when it is the literal string "true" or "false".
---@param node Node
function Element:set_strip_periods_attribute(node)
  self:set_bool_attribute(node, "strip-periods")
end

--- Store the node's `text-case` attribute, when present, as `self.text_case`.
---@param node Node
function Element:set_text_case_attribute(node)
  self:set_attribute(node, "text-case")
end

-- function Element:apply_formatting(ir)
--   local attributes = {
--     "font_style",
--     "font_variant",
--     "font_weight",
--     "text_decoration",
--     "vertical_align",
--   }
--   for _, attribute in ipairs(attributes) do
--     local value = self[attribute]
--     if value then
--       if not ir.formatting then
--         ir.formatting = {}
--       end
--       ir.formatting[attribute] = value
--     end
--   end
--   return ir
-- end

--- Copy this element's `prefix` and `suffix` fields onto `ir`.
--- Fields that are unset on the element leave `ir` unchanged; a nil `ir`
--- is passed straight through.
---@return table? ir The same object that was passed in
function Element:apply_affixes(ir)
  if not ir then
    return ir
  end
  for _, side in ipairs({"prefix", "suffix"}) do
    local affix = self[side]
    if affix then
      ir[side] = affix
    end
  end
  return ir
end

--- Set `ir.delimiter` to this element's delimiter, but only for IRs that
--- have a `children` field. A nil `ir` is passed straight through.
---@return table? ir The same object that was passed in
function Element:apply_delimiter(ir)
  local has_children = ir and ir.children
  if has_children then
    ir.delimiter = self.delimiter
  end
  return ir
end

--- Copy this element's `display` setting onto `ir`.
--- A nil `ir` is passed straight through, matching `apply_affixes`,
--- `apply_delimiter` and `apply_quotes`; `build_group_ir` can return nil, so
--- indexing unconditionally would raise.
---@param ir table? IR node to update
---@return table? ir The same object that was passed in
function Element:apply_display(ir)
  if ir then
    ir.display = self.display
  end
  return ir
end

--- Mark `ir` as quoted when this element has `quotes` enabled.
--- The quote-mark fields are cleared and `punctuation_in_quote` is reset to
--- false. A nil `ir`, or an element without quotes, is returned unchanged.
---@return table? ir The same object that was passed in
function Element:apply_quotes(ir)
  if not (ir and self.quotes) then
    return ir
  end

  ir.quotes = true
  -- NOTE(review): this makes `ir` its own only child, i.e. a reference
  -- cycle. It looks like a wrapper node may have been intended; confirm
  -- against callers before relying on `ir.children` here.
  ir.children = {ir}
  for _, field in ipairs({
    "open_quote",
    "close_quote",
    "open_inner_quote",
    "close_inner_quote",
  }) do
    ir[field] = nil
  end
  ir.punctuation_in_quote = false
  return ir
end

--- Remove every "." from `str` when this element has `strip_periods` set;
--- otherwise return `str` unchanged (including nil).
---@param str string?
---@return string?
function Element:apply_strip_periods(str)
  if not (str and self.strip_periods) then
    return str
  end
  -- Assign to a local first so gsub's replacement count is not returned.
  local stripped = string.gsub(str, "%.", "")
  return stripped
end


--- Render a number variable as text.
--- The input is split into parts (single numbers or ranges, each with a
--- trailing delimiter), every part is formatted according to `form`, and
--- the parts are joined again. For the "locator" variable the effective
--- variable is taken from the context's "label" variable.
---@param number string Non-empty string
---@param variable string
---@param form string
---@param context Context
---@return string
function Element:format_number(number, variable, form, context)
  number = util.strip(number)

  if variable == "locator" then
    local label = context:get_variable("label")
    if not label or type(label) ~= "string" then
      util.error("Invalid locator label")
      label = "page"
    end
    variable = label
  end

  form = form or "numeric"

  -- Each entry is {start, stop, trailing_delimiter}, e.g.
  -- {
  --   {"1", "",  " & "}
  --   {"5", "8", ", "}
  -- }
  local parts_list = self:split_number_parts_lpeg(number, context)

  local is_ordinal = form == "ordinal" or form == "long-ordinal"
  for _, parts in ipairs(parts_list) do
    if form == "roman" then
      self:format_roman_number_parts(parts)
    elseif is_ordinal then
      local gender = context.locale:get_number_gender(variable)
      self:format_ordinal_number_parts(parts, form, gender, context)
    elseif parts[2] ~= "" and variable == "page" then
      -- Only real page ranges go through the style's page-range-format.
      self:format_page_range(parts, context.style.page_range_format)
    else
      self:format_numeric_number_parts(parts)
    end
  end

  -- Ranges use an en dash unless a page range has a locale-specific term.
  local range_delimiter = util.unicode["en dash"]
  if variable == "page" then
    range_delimiter = context:get_simple_term("page-range-delimiter") or range_delimiter
  end

  local result = ""
  for _, parts in ipairs(parts_list) do
    local rendered = parts[1]
    if parts[2] ~= "" then
      rendered = rendered .. range_delimiter .. parts[2]
    end
    result = result .. rendered .. parts[3]
  end
  return result
end

---@alias NumberToken {type: string, value: string, delimiter_type: string}

--- Split a number-like string into alternating content and delimiter tokens.
---
--- Content tokens start out with type "string" and are re-typed to "number"
--- when they contain a digit or consist only of Roman-numeral letters.
--- Delimiter tokens get a `delimiter_type` of "range" (hyphen or en dash) or
--- "and" (everything else). As soon as the input stops looking like a list
--- of numbers and ranges, the rest of it is folded into one trailing
--- "string" token.
---@param number string
---@param context Context
---@return NumberToken[]
function Element:parse_number_tokens(number, context)
  -- Locale spellings of "and"; the defaults are used when no context is
  -- given or the term lookup returns nothing.
  local and_text = "and"
  local and_symbol = "&"
  if context then
    and_text = context.locale:get_simple_term("and") or "and"
    and_symbol = context.locale:get_simple_term("and", "symbol") or "&"
  end
  -- util.debug(and_symbol)

  local space = l.S(" \t\r\n")
  -- LPeg's `+` is ordered choice, so the comma-plus-conjunction forms are
  -- tried before the bare conjunction, the bare comma, and the range dashes.
  -- The textual "and" must be followed by whitespace (and, without a comma,
  -- preceded by it); the symbol forms need no surrounding whitespace.
  local delimiter_patt = space^0 * l.P(",") * space^0 * l.P(and_text) * space^1
    + space^0 * l.P(",") * space^0 * l.P(and_symbol) * space^0
    + space^0 * l.P(",") * space^0 * l.P("&") * space^0
    + space^1 * l.P(and_text) * space^1
    + space^0 * l.P(and_symbol) * space^0
    + space^0 * l.P("&") * space^0
    + space^0 * l.P(",") * space^0
    + space^0 * l.P("-") * space^0
    + space^0 * l.P(util.unicode["en dash"]) * space^0
  -- A run of adjacent delimiters is captured as a single delimiter token.
  local delimiter = l.C(delimiter_patt^1) / function (delimiter)
    return {
      type = "delimiter",
      value = delimiter,
    }
  end
  -- A content token is everything up to the next delimiter. An escaped
  -- hyphen ("\-") is consumed as a unit so it does not start a delimiter.
  local token_patt = l.C((l.P("\\-") + 1 - delimiter_patt)^1) / function (token)
    return {
      type = "string",
      value = token,
    }
  end
  -- token (delimiter token)*, or nothing at all.
  local grammer = l.Ct((token_patt * (delimiter * token_patt)^0)^-1)
  local tokens = grammer:match(number)
  -- util.debug(tokens)

  if not tokens then
    return {}
  end
  ---@cast tokens NumberToken[]

  -- Normalize token text: unescape "\-" in content, and give commas and
  -- ampersands their canonical spacing / locale symbol in delimiters.
  for i, token in ipairs(tokens) do
    if token.type == "string" then
      token.value = string.gsub(token.value, "\\%-", "-")
    elseif token.type == "delimiter" then
      token.value = string.gsub(token.value, "%s*,%s*", ", ")
      token.value = string.gsub(token.value, "&", and_symbol)
      token.value = string.gsub(token.value, "%s*&%s*", " & ")
    end
  end

  -- Classify tokens from the left; `stop_index` marks where the input stops
  -- being a well-formed list of numbers and ranges (0 = never).
  local stop_index = 0
  for i, token in ipairs(tokens) do
    if token.type == "string" then
      -- Numbers: alphanumerics containing at least one digit, or a run of
      -- Roman-numeral letters in a single case.
      if string.match(token.value, "^%w*%d+%w*$")
          or string.match(token.value, "^[mdclxvi]+$")
          or string.match(token.value, "^[MDCLXVI]+$") then
        token.type = "number"
      else
        -- First non-number: stop here, pulling in the delimiter before it
        -- (if any) so it becomes part of the trailing text.
        stop_index = i
        if i > 1 and tokens[i-1].type == "delimiter" then
          stop_index = i - 1
        end
        break
      end
    elseif token.type == "delimiter" then
      token.delimiter_type = "and"
      if string.match(token.value, "^%s*-%s*$")
          or string.match(token.value, "^%s*–%s*$") then
        token.delimiter_type = "range"
        -- Two range delimiters in a row (as in "1-2-3") cannot form a
        -- range; treat the second one and everything after it as text.
        if i > 2 and tokens[i-2].delimiter_type == "range" then
          stop_index = i
          break
        end
      end
    end
  end

  -- Fold everything from `stop_index` to the end into one "string" token.
  if stop_index > 0 then
    local token = tokens[stop_index]
    token.type = "string"
    for i = stop_index + 1, #tokens do
      token.value = token.value .. tokens[i].value
    end
    -- Remove from the back so the remaining indices stay valid.
    for i = #tokens, stop_index + 1, -1 do
      table.remove(tokens, i)
    end
  end

  return tokens
end

--- Group the tokens from `parse_number_tokens` into number parts of the
--- form {start, stop, trailing_text}, where `stop` is "" for a single
--- number. Returns something like
-- {
--   {"1", "",  " & "}
--   {"5", "8", ", "}
-- }
function Element:split_number_parts_lpeg(number, context)
  local tokens = self:parse_number_tokens(number, context)
  local parts = {}

  for i, token in ipairs(tokens) do
    local kind = token.type
    local current = parts[#parts]

    if kind == "number" then
      -- A number opens a new part unless it follows a range delimiter, in
      -- which case it closes the current range.
      if i == 1 or tokens[i - 1].delimiter_type == "and" then
        parts[#parts + 1] = {token.value, "", ""}
      else
        current[2] = token.value
      end
    elseif kind == "delimiter" then
      -- Range delimiters are implied by a non-empty `stop`; only list
      -- delimiters are stored.
      if token.delimiter_type == "and" then
        current[3] = token.value
      end
    elseif current then
      -- Trailing free text is attached to the last part ...
      current[3] = token.value
    else
      -- ... or becomes a part of its own when there is none yet.
      parts[#parts + 1] = {token.value, "", ""}
    end
  end

  return parts
end


--- Pattern-based splitter: cut `number` at commas and ampersands, then split
--- each piece once at a hyphen into {start, stop, delimiter}.
--- A hyphen escaped with a backslash ("12\-3") is kept as part of `start`.
---@param number string
---@param context Context
---@return table[] number_part_list List of {start, stop, delimiter} triples
function Element:split_number_parts(number, context)
  -- number = string.gsub(number, util.unicode["en dash"], "-")
  local and_symbol = context.locale:get_simple_term("and", "symbol")
  if and_symbol then
    and_symbol = " " .. and_symbol .. " "
  end

  -- Canonical spelling for each recognized (stripped) delimiter; anything
  -- else is kept as returned by `util.strip`.
  local canonical = {
    [","] = ", ",
    ["&"] = and_symbol or " & ",
    ["and"] = " and ",
    ["et"] = " et ",
  }

  local number_part_list = {}
  for _, tuple in ipairs(util.split_multiple(number, "%s*[,&]%s*", true)) do
    local single_number, delim = table.unpack(tuple)
    delim = util.strip(delim)
    delim = canonical[delim] or delim

    local start, stop = single_number, ""
    local pieces = util.split(start, "%s*%-%s*")
    if #pieces == 2 then
      start, stop = table.unpack(pieces)
      if util.endswith(start, "\\") then
        -- Escaped hyphen: drop the backslash and rejoin as one number.
        start = string.sub(start, 1, -2) .. "-" .. stop
        stop = ""
      end
    end
    table.insert(number_part_list, {start, stop, delim})
  end
  return number_part_list
end

--- Convert the purely numeric ends of a number part to Roman numerals,
--- in place.
--- Only values made up entirely of digits are converted. Mixed values such
--- as "12a" are left alone, in line with `format_ordinal_number_parts`: with
--- an unanchored match they would reach `tonumber`, which yields nil for
--- them, and that nil would be handed to the converter.
---@param number_parts string[] {start, stop, delimiter}; entries 1 and 2 may be rewritten
function Element:format_roman_number_parts(number_parts)
  for i = 1, 2 do
    local part = number_parts[i]
    if string.match(part, "^%d+$") then
      local roman = util.convert_roman(tonumber(part))
      -- NOTE(review): assumes `util.convert_roman` returns nil when it cannot
      -- convert; in that case the Arabic form is kept rather than storing
      -- nil, which would break later string concatenation.
      if roman then
        number_parts[i] = roman
      end
    end
  end
end

--- Turn the purely numeric ends of a number part into ordinals, in place.
--- With `form == "long-ordinal"` the numbers 1 to 10 use the locale's
--- `long-ordinal-NN` term. Every other number, and any long ordinal whose
--- term is missing, gets the locale's ordinal suffix appended instead, so a
--- missing term never leaves nil in `number_parts`.
---@param number_parts string[] {start, stop, delimiter}; entries 1 and 2 may be rewritten
---@param form string "ordinal" or "long-ordinal"
---@param gender string? Passed through to the locale's ordinal lookup
---@param context Context
function Element:format_ordinal_number_parts(number_parts, form, gender, context)
  for i = 1, 2 do
    local part = number_parts[i]
    -- Values like "2nd" are kept in their original form.
    if string.match(part, "^%d+$") then
      local number = tonumber(part)

      local long_term
      if form == "long-ordinal" and number >= 1 and number <= 10 then
        long_term = context:get_simple_term(string.format("long-ordinal-%02d", number))
      end

      if long_term then
        number_parts[i] = long_term
      else
        local suffix = context.locale:get_ordinal_term(number, gender)
        if suffix then
          number_parts[i] = part .. suffix
        end
      end
    end
  end
end

--- Format a number part for the plain "numeric" form.
--- Currently a deliberate no-op: `number_parts` is left exactly as parsed.
--- The disabled code below would have merged ranges whose two ends share the
--- same non-digit prefix into a single hyphenated value.
function Element:format_numeric_number_parts(number_parts)
  -- if number_parts[2] ~= "" then
  --   local first_prefix = string.match(number_parts[1], "^(.-)%d+")
  --   local second_prefix = string.match(number_parts[2], "^(.-)%d+")
  --   if first_prefix == second_prefix then
  --     number_parts[1] = number_parts[1] .. "-" .. number_parts[2]
  --     number_parts[2] = ""
  --   end
  -- end
end

-- https://docs.citationstyles.org/en/stable/specification.html#appendix-v-page-range-formats
--- Rewrite a page range in place according to the style's page-range-format.
---
--- * Purely alphabetic ranges (e.g. Roman numerals) are left untouched.
--- * Ranges where neither end has trailing digits (e.g. "12a"-"15b") are
---   left untouched as well; there is no numeric part to abbreviate, and
---   passing the missing captures on would raise in the range formatters.
--- * Ends with different prefixes do not form a valid range and are merged
---   into a single hyphenated value.
--- * Otherwise the second number is abbreviated or expanded per
---   `page_range_format`; with no format the parts stay as they are.
---@param number_parts string[] {start, stop, delimiter}; entries 1 and 2 may be rewritten
---@param page_range_format string? "chicago-16", "chicago-15", "expanded", "minimal" or "minimal-two"
function Element:format_page_range(number_parts, page_range_format)
  local start = number_parts[1]
  local stop = number_parts[2]

  if string.match(start, "^%a+$") and string.match(stop, "^%a+$") then
    -- CMoS example: xxv–xxviii
    return stop
  end

  local start_prefix, start_num = string.match(start, "^(.-)(%d+)$")
  local stop_prefix, stop_num = string.match(stop, "^(.-)(%d+)$")

  if not start_num and not stop_num then
    -- Neither end finishes in digits, so both captures are nil. Without
    -- this guard the nil prefixes compare equal and nil numbers reach the
    -- formatters below.
    return
  end

  if start_prefix ~= stop_prefix then
    -- Not valid range: "n11564-1568" -> "n11564-1568"
    -- 110-N6
    -- N110-P5
    number_parts[1] = start .. "-" .. stop
    number_parts[2] = ""
    return
  end

  if not page_range_format then
    return
  end
  if page_range_format == "chicago-16" then
    stop = self:_format_range_chicago_16(start_num, stop_num)
  elseif page_range_format == "chicago-15" then
    stop = self:_format_range_chicago_15(start_num, stop_num)
  elseif page_range_format == "expanded" then
    stop = stop_prefix .. self:_format_range_expanded(start_num, stop_num)
  elseif page_range_format == "minimal" then
    stop = self:_format_range_minimal(start_num, stop_num)
  elseif page_range_format == "minimal-two" then
    stop = self:_format_range_minimal(start_num, stop_num, 2)
  end
  number_parts[2] = stop
end

--- Abbreviate the end of a page range following the "chicago-16" format.
--- * start below 100, or a multiple of 100: all digits of the end are used;
--- * start ending in 01-09: only the changed digits are kept;
--- * otherwise: at least two digits are kept, more when more have changed.
---@param start string Digits of the first page
---@param stop string Digits of the last page (possibly already abbreviated)
---@return string
function Element:_format_range_chicago_16(start, stop)
  -- Work from the full form of the end so abbreviated input behaves the same.
  local full_stop = self:_format_range_expanded(start, stop)
  if #start < 3 or string.sub(start, -2) == "00" then
    return full_stop
  elseif string.sub(start, -2, -2) == "0" then
    return self:_format_range_minimal(start, full_stop)
  else
    return self:_format_range_minimal(start, full_stop, 2)
  end
end

--- Abbreviate the end of a page range following the "chicago-15" format.
--- Differs from the "chicago-16" rules in one case: when the start has four
--- digits and three digits change, the end is written out in full.
---@param start string Digits of the first page
---@param stop string Digits of the last page (possibly already abbreviated)
---@return string
function Element:_format_range_chicago_15(start, stop)
  -- Below 100 or a multiple of 100: use all digits.
  if #start < 3 or string.sub(start, -2) == "00" then
    return self:_format_range_expanded(start, stop)
  end

  local full_stop = self:_format_range_expanded(start, stop)
  local changed_digits = self:_format_range_minimal(start, full_stop)

  -- Start ends in 01-09: changed part only.
  if string.sub(start, -2, -2) == "0" then
    return changed_digits
  end
  if #start == 4 and #changed_digits == 3 then
    return self:_format_range_expanded(start, full_stop)
  end
  return self:_format_range_minimal(start, full_stop, 2)
end

--- Write the end of a range in full by borrowing the leading digits of
--- `start` that an abbreviated `stop` leaves out: "1234–56" -> "1234–1256".
--- An end that is at least as long as the start is returned unchanged.
---@param start string
---@param stop string
---@return string
function Element:_format_range_expanded(start, stop)
  local missing = #start - #stop
  if missing <= 0 then
    return stop
  end
  return string.sub(start, 1, missing) .. stop
end

--- Shorten the end of a range to the digits that differ from the start,
--- keeping at least `threshold` trailing digits.
--- `stop` is aligned with the right-hand end of `start`; an end longer than
--- the start is returned unchanged.
---@param start string
---@param stop string
---@param threshold integer? Number of minimal digits (default 1)
---@return string
function Element:_format_range_minimal(start, stop, threshold)
  threshold = threshold or 1
  if #start < #stop then
    return stop
  end

  local offset = #start - #stop
  local last_checked = #stop - threshold
  local i = 1
  while i <= last_checked do
    -- The first digit that differs starts the part that must be kept.
    if string.byte(stop, i) ~= string.byte(start, i + offset) then
      return string.sub(stop, i)
    end
    i = i + 1
  end
  return string.sub(stop, -threshold)
end

element.Element = Element

return element