#  sfzen/__init__.py
#
#  Copyright 2024 liyang <liyang@veronica>
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#  MA 02110-1301, USA.
#
"""
Simple object-oriented SFZ parsing and manipulation.
"""
import logging
import re
from collections import defaultdict
from functools import reduce
from math import ceil
from operator import or_
from os import makedirs, mkdir
from os.path import abspath, basename, dirname, join, splitext, exists

from appdirs import user_cache_dir
from lark import Lark, Transformer, v_args
from lark.tree import Meta
from log_soso import log_error

from sfzen.sfz_elems import (
	_SFZElement,
	_Modifier,
	Header,
	Global,
	Master,
	Group,
	Region,
	Control,
	Effect,
	Midi,
	Curve,
	Opcode,
	Sample,
	Define,
	Include
)
from sfzen.sort import midi_note_sort_key

__version__ = "1.2.2"

# Banner line written above and below the filename comment at the top of a saved .sfz.
COMMENT_DIVIDER = "// %s\n" % ('-' * 76)

# Values for the "samples_mode" argument of SFZ.save_as(), i.e. how "sample"
# opcodes are rendered when saving (0 .. 5, in this order).
(
	SAMPLES_ABSPATH,
	SAMPLES_RESOLVE,
	SAMPLES_COPY,
	SAMPLES_MOVE,
	SAMPLES_SYMLINK,
	SAMPLES_HARDLINK,
) = range(6)

# Values for the "globalize_mode" argument of SFZ.simplified() (0 .. 2, in this order).
GLOBALIZE_NONE, GLOBALIZE_NUMEROUS, GLOBALIZE_UNIVERSAL = range(3)

# Opcodes which SFZ.simplified() folds into a single "key" opcode when all
# three carry the same value.
KEY_OPCODES = ['lokey', 'hikey', 'pitch_keycenter']

# Opcodes that describe sample looping. Names containing "N" use the SFZ
# spec's placeholder for a numeric index (e.g. "loop_lengthccN").
LOOP_DEFINITION_OPCODES = [
	'egN_loop', 'egN_loop_count',
	'loop_count', 'loop_crossfade', 'loop_end',
	'loop_length_onccN', 'loop_lengthccN',
	'loop_mode',
	'loop_start', 'loop_start_onccN', 'loop_startccN',
	'loop_tune', 'loop_type',
	'loopcount', 'loopend', 'loopmode', 'loopstart', 'looptune', 'looptype',
]

# Opcodes whose values are positions/lengths expressed in sample frames.
# Not referenced in this module — presumably used elsewhere in the package.
SAMPLE_UNIT_OPCODES = [
	'delay_samples', 'delay_samples_onccN',
	'end',
	'loop_end', 'loop_length_onccN', 'loop_lengthccN',
	'loop_start', 'loop_start_onccN', 'loop_startccN',
	'loopend', 'loopstart',
	'offset', 'offset_ccN', 'offset_onccN', 'offset_random',
]

class SFZXformer(Transformer):
	"""
	Takes the parse tree generated by lark and generates an SFZ object.

	Headers, opcodes, #define and #include statements are appended to the SFZ
	passed to the constructor as the tree is transformed. Problems in #define,
	#include and opcode statements are recorded with SFZ.append_parse_error()
	instead of being raised.
	"""

	# Maps the name inside a <header> token to the class instantiated for it.
	_HEADER_CLASSES = {
		'region': Region,
		'group': Group,
		'control': Control,
		'global': Global,
		'curve': Curve,
		'effect': Effect,
		'master': Master,
		'midi': Midi,
	}

	# Matches a "$NAME" variable reference; group 1 is NAME (without the "$").
	_VARIABLE_REF = re.compile(r'\$(\w+)')

	def __init__(self, sfz):
		"""
		sfz: (SFZ) the object which receives the parsed headers and opcodes.
		"""
		super().__init__()
		self.sfz = sfz
		# Header which subsequent opcodes are appended to; starts at the root.
		self.current_header = self.sfz

	@v_args(meta=True)
	def header(self, arg1, arg2):
		"""
		Transformer function which handles <header> tokens.

		Creates the matching header object and attaches it beneath the nearest
		open header that may contain it.

		Raises ValueError when the header name is not one this module supports.
		"""
		meta, toks = self.wonky_lark_args(arg1, arg2)
		name = toks[0].value
		header_class = self._HEADER_CLASSES.get(name)
		if header_class is None:
			# Report unknown headers explicitly rather than failing obscurely.
			raise ValueError('Unsupported header: <%s>' % name)
		self._attach_header(header_class(meta))

	@v_args(meta=True)
	def define_macro(self, arg1, arg2):
		"""
		Transformer function which handles <define> tokens.

		Stores a Define in self.sfz.defines, keyed by the first token's value.
		"""
		# NOTE(review): replace_defs() looks names up WITHOUT the leading "$";
		# confirm the grammar's define-name token excludes it as well.
		meta, toks = self.wonky_lark_args(arg1, arg2)
		try:
			self.sfz.defines[toks[0].value] = Define(toks[0].value, toks[1].value, meta)
		except Exception as e:
			self.sfz.append_parse_error(e, meta)

	@v_args(meta=True)
	def include_macro(self, arg1, arg2):
		"""
		Transformer function which handles <include> tokens.

		Parses the included file as a separate SFZ (sharing this SFZ's defines
		and basedir) and grafts its headers into the current position. Defines,
		includes and parse errors collected from the included file are merged
		back into this SFZ.
		"""
		meta, toks = self.wonky_lark_args(arg1, arg2)
		try:
			include = Include(self.unquote(self.replace_defs(toks[0].value)), meta)
			self.sfz.includes.append(include)
			# Resolved relative to the directory of the file being parsed.
			path = join(dirname(self.sfz.filename), include.filename)
			if not exists(path):
				raise ValueError('Include not found: %s' % path)
			try:
				subsfz = SFZ(path, defines = self.sfz.defines, basedir = self.sfz.basedir, is_include = True)
				# Keep errors found inside the included file visible to the caller.
				for error, error_meta in subsfz.parse_errors:
					self.sfz.append_parse_error(error, error_meta)
				for header in subsfz.subheaders:
					self._attach_header(header)
				# The sub-SFZ may have added #defines; carry them forward.
				self.sfz.defines = subsfz.defines
				self.sfz.includes.extend(subsfz.includes)
			except Exception as e:
				log_error(e)
		except Exception as e:
			self.sfz.append_parse_error(e, meta)

	@v_args(meta=True)
	def opcode_exp(self, arg1, arg2):
		"""
		Transformer function which handles <opcode> tokens.

		Inside a <curve> header, "curve_index" and "vNNN" points are stored on the
		curve; elsewhere, $variables are expanded, the opcode name is lowercased,
		and an Opcode is appended to the current header.
		"""
		meta, toks = self.wonky_lark_args(arg1, arg2)
		try:
			if isinstance(self.current_header, Curve):
				name = toks[0].value
				if name == 'curve_index':
					self.current_header.curve_index = toks[1].value
				elif re.fullmatch(r'v\d+', name):
					self.current_header.points[name] = toks[1].value
				else:
					logging.error('Invalid opcode inside velocity curve definition')
			else:
				opname = self.replace_defs(toks[0].value).lower()
				self.current_header.append_opcode(Opcode(
					opname, self.replace_defs(toks[1].value),
					meta, self.sfz.basedir))
		except Exception as e:
			self.sfz.append_parse_error(e, meta)

	@v_args(meta=True)
	def start(self, arg1, arg2):
		"""
		Transformer function which handles the root of the sfz.
		"""

	def _attach_header(self, header):
		"""
		Append "header" to the nearest header (walking up from the current one)
		which may contain it, then make it the current header.
		"""
		while not self.current_header.may_contain(header):
			self.current_header = self.current_header.parent
		self.current_header.append_subheader(header)
		self.current_header = header

	def wonky_lark_args(self, arg1, arg2):
		"""
		When @v_args decorates a lark.Transformer function, lark passes parsed tokens
		and meta info to the function. Different versions of lark parser seem to pass
		the tokens and meta arguments in different order. This function fixes that.

		Returns (meta, tokens).
		"""
		if isinstance(arg1, Meta):
			return arg1, arg2
		return arg2, arg1

	def replace_defs(self, var):
		"""
		Replace each "$NAME" in "var" with the value NAME has been defined as in
		the sfz. Raises KeyError if NAME has not been defined.
		"""
		return self._VARIABLE_REF.sub(lambda m: self.sfz.defines[m.group(1)].value, var)

	def unquote(self, var):
		"""
		Remove a matching pair of single or double quotes around a parsed value.
		Values shorter than two characters are returned unchanged.
		"""
		if len(var) >= 2 and var[0] == var[-1] and var[0] in ("'", '"'):
			return var[1:-1]
		return var


class SFZ(Header):
	"""
	Provides an object-oriented interface to an .sfz file.

	An SFZ is the root of a tree of headers (<global>, <group>, <region>, ...).
	It may also be constructed without a filename to build one in memory.
	"""

	# Lark parser shared by all instances; built on first use from res/sfz.lark.
	_parser = None

	# LOOP_DEFINITION_OPCODES folded into one pattern. The "N" in names such as
	# "loop_lengthccN" is a numeric placeholder, so it matches one or more digits.
	_LOOP_DEFINITION_RE = re.compile('|'.join(
		re.escape(name).replace('N', r'\d+') for name in LOOP_DEFINITION_OPCODES
	))

	def __init__(self, filename = None, defines = None, basedir = None, is_include = False):
		"""
		filename: (str) Path to an .sfz file. When None, an empty "[unnamed SFZ]"
		is created and nothing is parsed.

		Passing "defines" allows us to construct an SFZ which is an import, and use the
		defined variables from the parent SFZ.

		Passing "basedir" allows included SFZ parts to use the directory of their
		parent when parsing sample paths.

		Passing "is_include" = True marks this SFZ as an included fragment, which
		may hold opcodes outside of any header.
		"""
		self.filename = filename
		self.defines = defines or {}
		self.is_include = is_include
		self._parent = None
		self._subheaders = []
		self._opcodes = {}
		self.includes = []
		self.parse_errors = []
		if filename is None:
			self.name = '[unnamed SFZ]'
			self.basedir = None
		else:
			self.name = basename(filename)
			self.basedir = dirname(self.filename) if basedir is None else basedir
			if SFZ._parser is None:
				# The compiled grammar is cached under the user cache directory.
				cache_file = join(user_cache_dir(), 'sfzen')
				grammar = join(dirname(__file__), 'res', 'sfz.lark')
				SFZ._parser = Lark.open(grammar, parser='lalr', propagate_positions=True, cache=cache_file)
			with open(filename) as f:
				# A trailing newline is appended — presumably the grammar expects
				# the final statement to be newline-terminated.
				tree = SFZ._parser.parse(f.read() + "\n")
			xformer = SFZXformer(self)
			xformer.transform(tree)

	def may_contain(self, _):
		"""The root SFZ accepts any header as a direct child."""
		return True

	def append_opcode(self, opcode):
		"""
		Only included fragments may hold opcodes outside of a header; raises
		RuntimeError otherwise.
		"""
		if not self.is_include:
			raise RuntimeError("Opcode outside of header")
		super().append_opcode(opcode)

	def append_parse_error(self, error, meta):
		"""Record a (exception, lark position meta) pair in self.parse_errors."""
		self.parse_errors.append((error, meta))

	def inherited_opcodes(self):
		"""Returns an empty dict: the root contributes no inherited opcodes."""
		return {}

	def global_header(self):
		"""
		Returns the Global header from this SFZ; if one does not exist, creates it at
		the beginning of the header list.
		"""
		for sub in self._subheaders:
			if isinstance(sub, Global):
				return sub
		global_header = Global(None, None)
		self._subheaders.insert(0, global_header)
		global_header.parent = self
		return global_header

	def samples(self):
		"""
		Returns all <sample> opcodes contained in this SFZ.
		Generator which yields a Sample object on each iteration.
		"""
		for sub in self._subheaders:
			yield from sub.samples()

	def opcodes_used(self):
		"""
		Returns a set of the keys of all the opcodes used in this SFZ.
		"""
		return reduce(or_, [sub.opcodes_used() for sub in self._subheaders], set())

	def regions_for(self, key=None, lokey=None, hikey=None, lovel=None, hivel=None):
		"""
		Generator function which yields each region contained in this SFZ which matches
		the given criteria.
		For example, to get every region which plays Middle C at any velocity:
			sfz.regions_for(lokey = 60, hikey = 60)
		"""
		for region in self.regions():
			if region.is_triggerd_by(key, lokey, hikey, lovel, hivel):
				yield region

	def save(self):
		"""
		Save to original filename (UTF-8) without any further modification.
		Returns self.
		"""
		with open(self.filename, 'w', encoding = 'utf-8') as fob:
			self.write(fob)
		return self

	def save_as(self, filename, samples_mode = SAMPLES_ABSPATH):
		"""
		Save to the given filename.

		"samples_mode" is a constant which defines how to render "sample"
		opcodes. May be one of:

			SAMPLES_ABSPATH		SAMPLES_RESOLVE		SAMPLES_COPY
			SAMPLES_MOVE		SAMPLES_SYMLINK		SAMPLES_HARDLINK

		The copy/move/link modes place samples in a "<title>-samples" directory
		next to the new file. Missing directories (including parents) are
		created. If "filename" has no extension, ".sfz" is appended.
		Returns self.
		"""
		filename = abspath(filename)
		target_sfz_dir = dirname(filename)
		filetitle, ext = splitext(basename(filename))
		makedirs(target_sfz_dir, exist_ok = True)
		if samples_mode == SAMPLES_ABSPATH:
			for sample in self.samples():
				sample.use_abspath()
		elif samples_mode == SAMPLES_RESOLVE:
			for sample in self.samples():
				sample.resolve_from(target_sfz_dir)
		else:
			samples_path = filetitle + '-samples'
			makedirs(join(target_sfz_dir, samples_path), exist_ok = True)
			for sample in self.samples():
				if samples_mode == SAMPLES_COPY:
					sample.copy_to(target_sfz_dir, samples_path)
				elif samples_mode == SAMPLES_MOVE:
					sample.move_to(target_sfz_dir, samples_path)
				elif samples_mode == SAMPLES_SYMLINK:
					sample.symlink_to(target_sfz_dir, samples_path)
				elif samples_mode == SAMPLES_HARDLINK:
					sample.hardlink_to(target_sfz_dir, samples_path)
		self.filename = filename + '.sfz' if ext == '' else filename
		return self.save()

	def write(self, stream):
		"""
		Exports this SFZ to .sfz format.
		"stream" may be any file-like object, including sys.stdout.
		"""
		stream.write(f'{COMMENT_DIVIDER}// {self.filename}\n{COMMENT_DIVIDER}\n')
		for sub in self._subheaders:
			sub.write(stream)

	def simplified(self, globalize_mode = GLOBALIZE_NUMEROUS):
		"""
		Returns an equivalent SFZ with common opcodes grouped and redundant
		opcodes removed.

		Every sample becomes a standalone Region carrying all the opcodes it
		inherits. Those regions are condensed, grouped by key, and shared
		opcodes are hoisted into a <global> header according to
		"globalize_mode" (GLOBALIZE_NONE, GLOBALIZE_NUMEROUS or
		GLOBALIZE_UNIVERSAL).
		"""
		simplified_sfz = SFZ()
		simplified_sfz.basedir = self.basedir
		global_header = Global(None, None)

		regions = [ self._clone_sample_region(sample) for sample in self.samples() ]
		for region in regions:
			self._condense_key_opcodes(region)
			self._strip_unlooped_opcodes(region)

		if globalize_mode == GLOBALIZE_NUMEROUS and len(regions) > 1:
			self._globalize_numerous(regions, global_header)

		self._append_key_groups(simplified_sfz, regions)

		if globalize_mode == GLOBALIZE_UNIVERSAL and len(list(simplified_sfz.regions())) > 1:
			# Hoist the opstrings reported by common_opstrings() (except "sample"):
			common_opstrings = simplified_sfz.common_opstrings()
			if len(common_opstrings):
				opstring_tuples = [
					tup for tup in (opstring.split('=', 1) for opstring in common_opstrings)
					if tup[0] != 'sample'
				]
				for name, value in opstring_tuples:
					global_header.append_opcode(Opcode(name, value))
				simplified_sfz.remove_opcodes([ name for name, _ in opstring_tuples ])

		if len(global_header._opcodes):
			simplified_sfz._subheaders.insert(0, global_header)
			global_header.parent = simplified_sfz

		# Give "Sample" opcodes a basedir for abspath function:
		for elem, _ in simplified_sfz.walk():
			if isinstance(elem, Sample):
				elem.basedir = self.basedir

		return simplified_sfz

	@staticmethod
	def _condense_key_opcodes(region):
		"""Replace lokey/hikey/pitch_keycenter with one "key" opcode when all three are equal."""
		key_related_values = [
			opcode.value for opcode in region.opcodes.values()
			if opcode.name in KEY_OPCODES
		]
		if len(key_related_values) == 3 and len(set(key_related_values)) == 1:
			for opcode_name in KEY_OPCODES:
				del region._opcodes[opcode_name]
			region.append_opcode(Opcode('key', key_related_values[0], None))

	@classmethod
	def _strip_unlooped_opcodes(cls, region):
		"""
		Remove loop-definition opcodes from a region whose loop mode is "no_loop".

		The loop_mode / loopmode opcode itself is kept: per the SFZ spec, a sample
		with embedded loop points defaults to looping when no mode is given, so
		dropping an explicit "no_loop" could change playback.
		"""
		if region.loop_mode != 'no_loop' and region.loopmode != 'no_loop':
			return
		for opcode_name in list(region._opcodes):
			if opcode_name in ('loop_mode', 'loopmode'):
				continue
			if cls._LOOP_DEFINITION_RE.fullmatch(opcode_name):
				del region._opcodes[opcode_name]

	@staticmethod
	def _globalize_numerous(regions, global_header):
		"""
		Move each "opcode=value" pair shared by a strict majority of "regions" into
		"global_header", deleting it from the regions which had it.

		Only opcodes defined by every region are candidates. Regions with a
		different value keep it as an override, whereas a region lacking the
		opcode would otherwise silently inherit the global value. A strict
		majority guarantees at most one value per opcode name is moved.
		"""
		names_in_all = set.intersection(*(set(region._opcodes) for region in regions))
		names_in_all.discard('sample')
		opstring_counts = defaultdict(int)
		for region in regions:
			for opstring in region.opstrings():
				if opstring.split('=', 1)[0] in names_in_all:
					opstring_counts[opstring] += 1
		logging.debug('Opstring counts: %s', dict(opstring_counts))
		min_count = len(regions) // 2 + 1
		for opstring, count in opstring_counts.items():
			if count < min_count:
				continue
			name, value = opstring.split('=', 1)
			for region in regions:
				if name in region._opcodes and str(region._opcodes[name]) == opstring:
					del region._opcodes[name]
			global_header.append_opcode(Opcode(name, value))

	@staticmethod
	def _append_key_groups(simplified_sfz, regions):
		"""
		Append "regions" to "simplified_sfz" in midi_note_sort_key order. Regions
		sharing a sort key are wrapped in a <group> holding their common opcodes.
		"""
		key_grouped_regions = defaultdict(list)
		for region in sorted(regions, key = midi_note_sort_key):
			key_grouped_regions[midi_note_sort_key(region)].append(region)
		for key_regions in key_grouped_regions.values():
			if len(key_regions) == 1:
				simplified_sfz.append_subheader(key_regions[0])
				continue
			group = Group(None, None)
			for region in key_regions:
				group.append_subheader(region)
			common_opstrings = group.common_opstrings()
			if any(len(region._opcodes) > len(common_opstrings) for region in key_regions):
				for name, value in [ opstring.split('=', 1) for opstring in common_opstrings ]:
					group.append_opcode(Opcode(name, value))
					for region in key_regions:
						del region._opcodes[name]
				simplified_sfz.append_subheader(group)
			else:
				logging.warning('All region opstrings are common')
				simplified_sfz.append_subheader(key_regions[0])

	def _clone_sample_region(self, sample):
		"""
		Returns a new Region holding every opcode reported by
		sample.parent.inherited_opcodes().
		"""
		# NOTE(review): the Opcode objects are shared with this SFZ, not copied;
		# confirm Header.append_opcode does not re-parent them.
		region = Region(None, None)
		for opcode in sample.parent.inherited_opcodes().values():
			region.append_opcode(opcode)
		return region

	def dump(self):
		"""
		Print (to stdout) a concise outline of this SFZ.
		"""
		for elem, depth in self.walk():
			print('  ' * depth + repr(elem))

	def __repr__(self):
		return f'SFZ {self.filename}'


#  end sfzen/__init__.py
