from os import remove
from nut import Print
from time import sleep
from pathlib import Path
from traceback import format_exc
from zstandard import ZstdCompressor
from ThreadSafeCounter import Counter
from SectionFs import isNcaPacked, sortedFs
from multiprocessing import Process, Manager
from Fs import Pfs0, Hfs0, Nca, Type, Ticket, Xci, factory
from PathTools import *
import enlighten
#import sys

def compressBlockTask(in_queue, out_list, readyForWork, pleaseKillYourself):
	"""Worker-process loop: compress blocks taken from in_queue into out_list.

	Each work item is a 4-element sequence
	[buffer, compressionLevel, <unused>, chunkRelativeBlockID]. The zstd output
	is stored at out_list[chunkRelativeBlockID]; when it is not smaller than the
	input, the raw buffer is stored instead.

	readyForWork is incremented while this worker waits for an item and
	decremented once it has one, so the producer can tell when all workers are
	idle. The loop ends when pleaseKillYourself is positive after an item has
	been received, or when an item's buffer equals 0.
	"""
	# One compressor per level, reused across blocks instead of building a new
	# ZstdCompressor (and its compression context) for every block.
	compressors = {}
	while True:
		readyForWork.increment()
		item = in_queue.get()
		readyForWork.decrement()
		if pleaseKillYourself.value() > 0:
			break
		# The third field is not used by the worker.
		buffer, compressionLevel, _, chunkRelativeBlockID = item
		if buffer == 0:
			return
		compressor = compressors.get(compressionLevel)
		if compressor is None:
			compressor = ZstdCompressor(level=compressionLevel)
			compressors[compressionLevel] = compressor
		compressed = compressor.compress(buffer)
		out_list[chunkRelativeBlockID] = compressed if len(compressed) < len(buffer) else buffer

def blockCompress(filePath, compressionLevel, blockSizeExponent, outputDir, threads):
	"""Route filePath to the NSP or XCI block compressor by its suffix.

	Returns whatever the selected compressor returns (the output path), or None
	when the suffix is neither '.nsp' nor '.xci' (case-sensitive comparison).
	"""
	suffix = filePath.suffix
	if suffix not in ('.nsp', '.xci'):
		return None
	compressFile = blockCompressNsp if suffix == '.nsp' else blockCompressXci
	return compressFile(filePath, compressionLevel, blockSizeExponent, outputDir, threads)

def _nczSectionHeader(sections):
	"""Build the NCZSECTN table: magic, section count, then one entry per section.

	Each entry is offset (8), size (8), cryptoType (8), 8 zero bytes, cryptoKey
	and cryptoCounter; integers are little-endian.
	"""
	parts = [b'NCZSECTN', len(sections).to_bytes(8, 'little')]
	for fs in sections:
		parts.append(fs.offset.to_bytes(8, 'little'))
		parts.append(fs.size.to_bytes(8, 'little'))
		parts.append(fs.cryptoType.to_bytes(8, 'little'))
		parts.append(b'\x00' * 8)
		parts.append(fs.cryptoKey)
		parts.append(fs.cryptoCounter)
	return b''.join(parts)

def _nczBlockHeader(blockSizeExponent, blocksToCompress, bytesToCompress):
	"""Build the NCZBLOCK header followed by a zeroed per-block size table (patched later)."""
	return (b'NCZBLOCK' #Magic
		+ b'\x02' #Version
		+ b'\x01' #Type
		+ b'\x00' #Unused
		+ blockSizeExponent.to_bytes(1, 'little') #blockSizeExponent in bits: 2^x
		+ blocksToCompress.to_bytes(4, 'little') #Amount of Blocks
		+ bytesToCompress.to_bytes(8, 'little') #Decompressed Size
		+ b'\x00' * (blocksToCompress*4)) #Placeholder for the compressed size of every block

def blockCompressContainer(readContainer, writeContainer, compressionLevel, blockSizeExponent, threads):
	"""Copy every file of readContainer into writeContainer, block-compressing packed NCAs.

	PROGRAM and PUBLICDATA NCAs for which isNcaPacked() is true are written as
	'.ncz' entries: the first 0x4000 bytes verbatim, an NCZSECTN section table,
	an NCZBLOCK header with a per-block compressed-size table, then the blocks.
	DATA NCAs (delta fragments) are skipped; every other file is copied as-is.

	Compression runs in `threads` worker processes (compressBlockTask). Blocks
	go through a bounded manager queue and results are collected in a shared
	list of TasksPerChunk slots (about 200 MiB of input per chunk) that is
	flushed to the output in order.

	Args:
		readContainer: iterable of source files (Nca objects or plain files).
		writeContainer: destination providing add() and resize().
		compressionLevel: zstd level handed to the workers.
		blockSizeExponent: block size is 2**blockSizeExponent bytes (14..32).
		threads: number of worker processes.

	Raises:
		ValueError: blockSizeExponent outside 14..32.
		Exception: a packed NCA yields no encryption sections (keys likely missing).
	"""
	CHUNK_SZ = 0x100000
	ncaHeaderSize = 0x4000
	if blockSizeExponent < 14 or blockSizeExponent > 32:
		raise ValueError("Block size must be between 14 and 32")
	blockSize = 2**blockSizeExponent
	# At least one slot: for blockSizeExponent >= 28 the division yields 0, which
	# made workers write out of range and the flush loop wait forever.
	TasksPerChunk = max(1, 209715200//blockSize)
	manager = Manager()
	# Create all slots in one manager round-trip instead of one append per slot.
	results = manager.list([b""] * TasksPerChunk)
	readyForWork = Counter(0)
	pleaseKillYourself = Counter(0)
	pool = []
	work = manager.Queue(threads)

	try:
		for i in range(threads):
			p = Process(target=compressBlockTask, args=(work, results, readyForWork, pleaseKillYourself))
			p.start()
			pool.append(p)

		for nspf in readContainer:
			if isinstance(nspf, Nca.Nca) and nspf.header.contentType == Type.Content.DATA:
				Print.info('Skipping delta fragment {0}'.format(nspf._path))
				continue
			if isinstance(nspf, Nca.Nca) and (nspf.header.contentType == Type.Content.PROGRAM or nspf.header.contentType == Type.Content.PUBLICDATA):
				if isNcaPacked(nspf, ncaHeaderSize):
					newFileName = nspf._path[0:-1] + 'z'
					f = writeContainer.add(newFileName, nspf.size)
					startPos = f.tell()
					nspf.seek(0)
					f.write(nspf.read(ncaHeaderSize))
					sections = []

					for fs in sortedFs(nspf):
						sections += fs.getEncryptionSections()

					if len(sections) == 0:
						# Worker processes are terminated by the except clause below.
						raise Exception("NCA can't be decrypted. Outdated keys.txt?")

					f.write(_nczSectionHeader(sections))
					blockID = 0
					chunkRelativeBlockID = 0
					startChunkBlockID = 0
					blocksHeaderFilePos = f.tell()
					bytesToCompress = nspf.size - ncaHeaderSize
					blocksToCompress = bytesToCompress//blockSize + (bytesToCompress%blockSize > 0)
					compressedblockSizeList = [0]*blocksToCompress
					f.write(_nczBlockHeader(blockSizeExponent, blocksToCompress, bytesToCompress))
					decompressedBytes = ncaHeaderSize
					compressedBytes = f.tell()
					BAR_FMT = u'{desc}{desc_pad}{percentage:3.0f}%|{bar}| {count:{len_total}d}/{total:d} {unit} [{elapsed}<{eta}, {rate:.2f}{unit_pad}{unit}/s]'
					bar = enlighten.Counter(total=nspf.size//1048576, desc='Compressing', unit='MiB', color='cyan', bar_format=BAR_FMT)
					subBars = bar.add_subcounter('green', all_fields=True)
					partitions = [nspf.partition(offset = section.offset, size = section.size, n = None, cryptoType = section.cryptoType, cryptoKey = section.cryptoKey, cryptoCounter = bytearray(section.cryptoCounter), autoOpen = True) for section in sections]
					partNr = 0
					bar.count = nspf.tell()//1048576
					subBars.count = f.tell()//1048576
					bar.refresh()
					while True:
						buffer = partitions[partNr].read(blockSize)
						# A block may span several sections: keep filling from the next ones.
						while (len(buffer) < blockSize and partNr < len(partitions)-1):
							partitions[partNr].close()
							partitions[partNr] = None
							partNr += 1
							buffer += partitions[partNr].read(blockSize - len(buffer))
						if chunkRelativeBlockID >= TasksPerChunk or len(buffer) == 0:
							while readyForWork.value() < threads:
								sleep(0.02)

							for i in range(min(TasksPerChunk, blocksToCompress-startChunkBlockID)):
								block = results[i] # single proxy round-trip per block
								# readyForWork can still count a worker as idle right after it
								# dequeued a block (between get() and decrement()), so wait until
								# every slot submitted in this chunk has actually been filled.
								while i < chunkRelativeBlockID and len(block) == 0:
									sleep(0.02)
									block = results[i]
								compressedBytes += len(block)
								compressedblockSizeList[startChunkBlockID+i] = len(block)
								f.write(block)
								results[i] = b""

							if len(buffer) == 0:
								break
							chunkRelativeBlockID = 0
							startChunkBlockID = blockID
						# Sizes are recorded here, so the size list is not shipped to workers
						# (it was pickled into every work item); None keeps the item shape.
						work.put([buffer, compressionLevel, None, chunkRelativeBlockID])
						blockID += 1
						chunkRelativeBlockID += 1
						decompressedBytes += len(buffer)
						bar.count = decompressedBytes//1048576
						subBars.count = compressedBytes//1048576
						bar.refresh()
					partitions[partNr].close()
					partitions[partNr] = None
					endPos = f.tell()
					bar.count = decompressedBytes//1048576
					subBars.count = compressedBytes//1048576
					bar.close()
					written = endPos - startPos
					# The size table follows the 24 bytes of fixed NCZBLOCK fields.
					f.seek(blocksHeaderFilePos+24)
					f.write(b''.join(compressedblockSize.to_bytes(4, 'little') for compressedblockSize in compressedblockSizeList))
					f.seek(endPos) #Seek to end of file.
					Print.info('compressed %d%% %d -> %d  - %s' % (int(written * 100 / nspf.size), decompressedBytes, written, nspf._path))
					writeContainer.resize(newFileName, written)
					continue
				else:
					Print.info('Skipping not packed {0}'.format(nspf._path))
			f = writeContainer.add(nspf._path, nspf.size)
			nspf.seek(0)
			while not nspf.eof():
				buffer = nspf.read(CHUNK_SZ)
				f.write(buffer)

		# Ensure every worker is idle before asking them to quit.
		while readyForWork.value() < threads:
			sleep(0.02)
		pleaseKillYourself.increment()

		# One wake-up item per waiting worker; each sees the kill flag and exits.
		for i in range(readyForWork.value()):
			work.put(None)

		while readyForWork.value() > 0:
			sleep(0.02)
	except BaseException:
		# Do not leave workers blocked on the queue. Process.terminate() might corrupt
		# the shared data structures, which no longer matter at this point.
		for p in pool:
			p.terminate()
		raise
	finally:
		for p in pool:
			p.join()
		manager.shutdown()


def blockCompressNsp(filePath, compressionLevel, blockSizeExponent, outputDir, threads):
	"""Block-compress an NSP file into '<stem>.nsz' inside outputDir.

	Args:
		filePath: pathlib.Path of the source NSP.
		compressionLevel: zstd compression level.
		blockSizeExponent: block size is 2**blockSizeExponent bytes.
		outputDir: pathlib.Path of the output directory.
		threads: number of compression worker processes.

	Returns:
		The .nsz output path. On any failure the traceback is logged (except for
		KeyboardInterrupt), the partial output is deleted, and the same path is
		still returned, so callers should check that the file exists.
	"""
	filePath = filePath.resolve()
	container = factory(filePath)
	container.open(str(filePath), 'rb')
	try:
		nszPath = outputDir.joinpath(filePath.stem + '.nsz')

		Print.info('Block compressing (level {0}) {1} -> {2}'.format(compressionLevel, filePath, nszPath))

		try:
			with Pfs0.Pfs0Stream(str(nszPath)) as nsp:
				blockCompressContainer(container, nsp, compressionLevel, blockSizeExponent, threads)
		except BaseException as ex:
			# `ex` is an instance, so isinstance() is required; an identity check
			# against the class was never true and always printed the traceback.
			if not isinstance(ex, KeyboardInterrupt):
				Print.error(format_exc())
			if nszPath.is_file():
				nszPath.unlink()
	finally:
		container.close()
	return nszPath
	
def blockCompressXci(filePath, compressionLevel, blockSizeExponent, outputDir, threads):
	"""Block-compress the 'secure' partition of an XCI into '<stem>.xcz' inside outputDir.

	Args:
		filePath: pathlib.Path of the source XCI.
		compressionLevel: zstd compression level.
		blockSizeExponent: block size is 2**blockSizeExponent bytes.
		outputDir: pathlib.Path of the output directory.
		threads: number of compression worker processes.

	Returns:
		The .xcz output path. On any failure during compression the traceback is
		logged (except for KeyboardInterrupt), the partial output is deleted, and
		the same path is still returned, so callers should check that it exists.
	"""
	filePath = filePath.resolve()
	container = factory(filePath)
	container.open(str(filePath), 'rb')
	try:
		secureIn = container.hfs0['secure']
		xczPath = outputDir.joinpath(filePath.stem + '.xcz')

		Print.info('Block compressing (level {0}) {1} -> {2}'.format(compressionLevel, filePath, xczPath))

		try:
			with Xci.XciStream(str(xczPath), originalXciPath = filePath) as xci: # need filepath to copy XCI container settings
				with Hfs0.Hfs0Stream(xci.hfs0.add('secure', 0), xci.f.tell()) as secureOut:
					blockCompressContainer(secureIn, secureOut, compressionLevel, blockSizeExponent, threads)

				xci.hfs0.resize('secure', secureOut.actualSize)
		except BaseException as ex:
			# `ex` is an instance, so isinstance() is required; an identity check
			# against the class was never true and always printed the traceback.
			if not isinstance(ex, KeyboardInterrupt):
				Print.error(format_exc())
			if xczPath.is_file():
				xczPath.unlink()
	finally:
		# Also runs when the 'secure' lookup fails, which still propagates.
		container.close()
	return xczPath
