#!python
#SPDX-License-Identifier: GPL-3.0-only
#
# Copyright (C) 2019 Alberto Pianon <pianon@array.eu>
#
# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, version 3.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>.

import panflute as pf

import re

from io import StringIO


def upperRoman(number):
    """Return ``number`` as an uppercase Roman numeral string.

    ``number`` is expected to be an integer. Values of zero or less yield an
    empty string. There is no upper bound: every extra thousand simply adds
    another 'M'.
    """
    # Values in descending order, subtractive pairs (CM, IX, ...) included,
    # so a single greedy pass produces the canonical numeral.
    numerals = (
        (1000, 'M'), (900, 'CM'), (500, 'D'), (400, 'CD'),
        (100, 'C'), (90, 'XC'), (50, 'L'), (40, 'XL'),
        (10, 'X'), (9, 'IX'), (5, 'V'), (4, 'IV'),
        (1, 'I'),
    )
    pieces = []
    for value, symbol in numerals:
        if number <= 0:
            break
        # How many times this symbol fits, and what is left to convert.
        repeats, number = divmod(number, value)
        pieces.append(symbol * repeats)
    return ''.join(pieces)

def lowerRoman(number):
    """Return ``number`` as a lowercase Roman numeral string.

    The numeral is produced by ``upperRoman`` and then lowercased.
    """
    uppercase_numeral = upperRoman(number)
    return uppercase_numeral.lower()

def lowerAlpha(number):
    """Return the lowercase alphabetic label for a 1-based position.

    Positions 1-26 map to 'a'-'z'. Larger positions continue as 'aa', 'ab',
    ..., 'az', 'ba', ... instead of raising IndexError, so lists with more
    than 26 items still get distinct labels.

    Values of 26 or less are looked up in the same table as before, so 0
    still yields the ' ' placeholder.
    """
    letters = 'abcdefghijklmnopqrstuvwxyz'
    if number <= 26:
        # Unchanged lookup for the range the table already covered.
        return (' ' + letters)[number]
    label = ''
    while number > 0:
        # Bijective base 26: there is no zero digit, hence the -1 shift.
        number, remainder = divmod(number - 1, 26)
        label = letters[remainder] + label
    return label

def upperAlpha(number):
    """Return the uppercase alphabetic label for a 1-based position.

    Positions 1-26 map to 'A'-'Z'. Larger positions continue as 'AA', 'AB',
    ..., 'AZ', 'BA', ... instead of raising IndexError, so lists with more
    than 26 items still get distinct labels.

    Values of 26 or less are looked up in the same table as before, so 0
    still yields the ' ' placeholder.
    """
    letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    if number <= 26:
        # Unchanged lookup for the range the table already covered.
        return (' ' + letters)[number]
    label = ''
    while number > 0:
        # Bijective base 26: there is no zero digit, hence the -1 shift.
        number, remainder = divmod(number - 1, 26)
        label = letters[remainder] + label
    return label

# Maps each supported ordered-list style name (compared against
# ``OrderedList.style`` in ``action``) to a function that turns a 1-based item
# position into its label string. Lists with any other style are left alone.
number_formats = {
    'Decimal': str,
    'LowerAlpha': lowerAlpha,
    'UpperAlpha': upperAlpha,
    'LowerRoman': lowerRoman,
    'UpperRoman': upperRoman,
}

# State shared across calls to ``action``: the level and identifier of the
# most recently visited Header. They keep these initial values until the
# first Header is seen.
curlevel = 0
curid = ''

def action(elem, doc):
    """Replace ordered-list items with numbered headers below the current one.

    The filter is inactive unless the ``numberSections`` metadata value is
    truthy and at least one of ``inlineHeaderLevel`` / ``crossrefOrderedList``
    is truthy.

    * For a ``Header``, its level and identifier are remembered in the module
      globals ``curlevel`` and ``curid``.
    * For an ``OrderedList`` whose style is a key of ``number_formats``, every
      item is converted into a ``Header`` of level ``curlevel + 1`` (labelled
      with the formatted item number) optionally followed by a ``Para``. These
      are inserted after the list in its parent's content, and the list is
      then emptied.

    Only the first block of each list item is read; anything else inside the
    item is discarded when the list is emptied.

    Parameters:
        elem: the element currently visited by the filter.
        doc: the document, used here only to read metadata.

    Returns:
        Always ``None``; the document is modified in place.
    """
    global curlevel, curid
    numberSections = doc.get_metadata('numberSections')
    sectionsDepth = doc.get_metadata('sectionsDepth')
    sectionsDepth = int(sectionsDepth) if sectionsDepth else 0
    # Fall back to an empty delimiter: with the metadata key missing,
    # get_metadata returns None and ``number + None`` below would raise
    # TypeError.
    secHeaderDelim = doc.get_metadata('secHeaderDelim') or ""
    inlineHeaderLevel = doc.get_metadata('inlineHeaderLevel')
    crossrefOrderedList = doc.get_metadata('crossrefOrderedList')
    if not numberSections:
        return None
    if not inlineHeaderLevel and not crossrefOrderedList:
        return None
    if isinstance(elem, pf.Header):
        curlevel = elem.level
        curid = elem.identifier
    elif isinstance(elem, pf.OrderedList):
        ordered_list = elem
        if ordered_list.style not in number_formats:
            return None
        format_number = number_formats[ordered_list.style]
        # Whether the item number is written into the header text. This does
        # not change while the list is processed, so compute it once.
        show_number = curlevel >= sectionsDepth and sectionsDepth != -1
        # NOTE(review): numbering always starts at 1; the list's own start
        # number is not consulted -- confirm this is intended.
        for idx, list_item in enumerate(ordered_list.content):
            number = format_number(idx+1)
            metadata_start_el_idx = None
            metadata_end_el_idx = None
            if list_item.content and list_item.content[0].content:
                # list item content consists of a 'Plain' object containing in
                # turn the objects constituting the "actual" content

                # search for metadata within curly brackets; when several Str
                # elements qualify, the last opening and last closing one win
                content_length = len(list_item.content[0].content)
                for el_idx, el in enumerate(list_item.content[0].content):
                    if isinstance(el, pf.Str):
                        if el.text.startswith('{'):
                            metadata_start_el_idx = el_idx
                        if el.text.endswith('}'):
                            metadata_end_el_idx = el_idx
            if (
                    metadata_start_el_idx is not None
                    and
                    metadata_end_el_idx is not None
                    and
                    metadata_start_el_idx <= metadata_end_el_idx
            ):
                # 1) create Header object
                # Everything up to and including the closing brace becomes
                # the header line (title text, if any, plus the attributes).
                tmp_span = pf.Span()
                for m in range(metadata_end_el_idx + 1):
                    tmp_span.content.append(
                        list_item.content[0].content[m])
                if show_number:
                    # The delimiter is only added when some title text
                    # precedes the attribute block.
                    prefix = number + (
                        secHeaderDelim
                        if metadata_start_el_idx > 0
                        else "")
                else:
                    prefix = ""
                header_str = "{} {} {}\n".format(
                    "#" * (curlevel+1),
                    prefix,
                    pf.stringify(tmp_span)
                )
                # Raw string: "\{" is not a valid escape in a normal literal.
                metadata = re.search(r"\{.*\}", header_str)[0]
                identifier_set_by_user = any(
                    token.startswith("#")
                    for token in metadata[1:-1].split(" ")
                )
                # Build the header by converting the header line, so that the
                # attribute block ends up on the resulting Header element.
                header = pf.convert_text(
                    header_str.replace("\t", "&#x09;"),
                    extra_args=["--preserve-tabs",]
                    # replace and extra_args are necessary to preserve tabs
                )[0]
                header.attributes['label'] = number
                if not identifier_set_by_user:
                    # avoid possible duplicate auto-generated identifiers
                    header.identifier = curid+":"+number
                # 2) create Para object
                para = pf.Para()
                if content_length > metadata_end_el_idx + 1:
                    # Start two elements after the closing brace: the element
                    # right after it is skipped (presumably the separating
                    # Space -- TODO confirm).
                    for p in range(metadata_end_el_idx+2, content_length):
                        para.content.append(list_item.content[0].content[p])

            else:
                # 1) create Header object
                # No attribute block: the header holds just the number (or an
                # empty string) and all of the item's text goes into the Para.
                header = pf.Header(
                    pf.Str(number) if show_number else pf.Str(""),
                    level=curlevel+1,
                    identifier=curid+":"+number,
                    attributes={'label': number},
                )
                # 2) create Para object
                para = pf.Para()
                if list_item.content and list_item.content[0].content:
                    for el in list_item.content[0].content:
                        para.content.append(el)

            # Insert after the previously inserted element (initially the
            # list itself) so that the items keep their order.
            elem.parent.content.insert(elem.index+1, header)
            elem = elem.next
            if para.content:
                elem.parent.content.insert(elem.index+1, para)
                elem = elem.next

        while len(ordered_list.content) > 0: # delete ordered list
            ordered_list.content.pop(0)

def prepare(doc):
    """Merge default metadata from the ``crossrefYaml`` file into ``doc``.

    If the document metadata names a file under ``crossrefYaml``, that file
    is read, wrapped in ``---`` delimiters and converted to a document so
    that its metadata can be extracted. Every key that the document does not
    already define is copied over. Keys already present in the document are
    never overridden, even when their value is falsy (``false``, ``0``, an
    empty string).

    Parameters:
        doc: the document being filtered; its metadata is updated in place.

    Returns:
        ``doc`` when the file provided any metadata, otherwise ``None``.
    """
    yaml_path = doc.get_metadata('crossrefYaml')
    if not yaml_path:
        return None
    # Decode explicitly as UTF-8 rather than with the platform default
    # encoding, so non-ASCII values read the same on every system.
    with open(yaml_path, 'r', encoding='utf-8') as f:
        yaml_text = f.read()
    crossref_json = pf.convert_text(
        "---\n{}\n---\n".format(yaml_text),
        output_format="json"
    )
    crossrefDoc = pf.load(StringIO(crossref_json))
    if not crossrefDoc.metadata.content:
        return None
    for key, value in crossrefDoc.metadata.content.items():
        # Test for presence, not truthiness: an explicit falsy value set in
        # the document must win over the default from the file.
        if key not in doc.metadata.content: # do not override doc metadata
            doc.metadata[key] = value
    return doc

# Script entry point: hand ``action`` and the ``prepare`` hook to panflute's
# filter runner.
if __name__ == '__main__':
    pf.toJSONFilter(action, prepare=prepare)
