import atexit
import logging
import os
import platform
import sys
import time

import glfw
import imgui
import numpy as np

from . import event
from .input import KeyCode, MouseButton, KeyModifier
from ... import misc, gl
from ...camera import Camera
from ...graphics import VertexShader, PixelShader, ShaderProgram
from ...math import Vec2

# Default vertex shader (GLSL 3.30) used by `Window.default_shader`.
# - gl_Position: a_position transformed by proj_mat * view_mat * model_mat.
# - v_color: a_color passed through unchanged.
# - frag_pos: position transformed by view_mat * model_mat (view space).
# - world_pos: position transformed by model_mat only.
# - light_pos: the fixed point (800, 800, 800) transformed by view_mat.
# - v_normal: a_normal transformed by the inverse-transpose of view_mat * model_mat.
_default_vertex_shader_str = '''
#version 330

layout (location = 0) in vec3 a_position;
layout (location = 1) in vec4 a_color;
layout (location = 2) in vec3 a_normal;

out vec4 v_color;
out vec3 v_normal;
out vec3 light_pos;
out vec3 frag_pos;
out vec3 world_pos;

uniform mat4 model_mat;
uniform mat4 view_mat;
uniform mat4 proj_mat;

void main() {
    gl_Position = proj_mat * view_mat * model_mat * vec4(a_position, 1.0);
    v_color = a_color;
    
    frag_pos = vec3(view_mat * model_mat * vec4(a_position, 1.0));
    world_pos = (model_mat * vec4(a_position, 1.0)).xyz;
    light_pos = vec3(view_mat * vec4(800.0, 800.0, 800.0, 1.0));
    v_normal = mat3(transpose(inverse(view_mat * model_mat))) * a_normal;
}
'''

_default_pixel_shader_str = '''
# version 330

in vec4 v_color;
in vec3 v_normal;
in vec3 light_pos;
in vec3 frag_pos;
in vec3 world_pos;

out vec4 frag_color;

uniform bool do_shading;

vec4 simple_shading() {
    vec3 light_color = vec3(1.0, 1.0, 1.0);
    
    // face normal approximation
    vec3 x = dFdx(world_pos);
    vec3 y = dFdy(world_pos);
    vec3 face_normal = cross(x, y);
    
    // diffuse
    vec3 light_dir = normalize(light_pos - frag_pos);
    float diff = max(dot(normalize(face_normal), light_dir), 0.0);
    vec3 diffuse = 0.4 * (diff * light_color * v_color.xyz) + v_color.xyz * 0.5;
     
    return vec4(diffuse, 1.0); 
}

void main() {
    if (do_shading) {
        frag_color = simple_shading();
    } else {
        frag_color = v_color;
    }
}
'''


# TODO: detailed description of event listener parameters

class Window(event.EventDispatcher):
    """A thin wrapper around GLFW window.

    Attributes:
        _width (int): width of the window's framebuffer in pixels
        _height (int): height of the window's framebuffer in pixels
        _native_window (GLFWwindow): inner GLFWwindow

    Events:
        on_cursor_enter()
        on_cursor_leave()
        on_draw(dt)
        on_resize(width, height)
        on_mouse_motion(x, y, dx, dy)
        on_mouse_drag(x, y, dx, dy, button)
        on_mouse_press(x, y, button)
        on_mouse_release(x, y, button)
        on_mouse_scroll(x, y, x_offset, y_offset)
        on_key_press(key, mods)
        on_key_release(key, mods)
        on_update(dt)
        on_init()
        on_idle(dt)
        on_show()
        on_hide()
        on_close()
        on_gui()
    """

    def __init__(self, title, width=600, height=600, **kwargs):
        """Initialization of the GLFW window.

        Window creation hints are passed through `**kwargs`.

        Args:
            title (str): title of the window
            width (int): requested width of the window
            height (int): requested height of the window
            **kwargs:
                clear_color (array-like):
                    Specifies the color used to clear OpenGL color buffer.

                resizable (bool):
                    Specifies whether the window is resizable by the user. Defaults to True.

                decorated (bool):
                    Specifies whether the windowed mode window will have window decorations such as a border,
                    a close widget, etc. Defaults to True.

                floating (bool):
                    Specifies whether the windowed mode window will be floating above other regular windows.
                    Defaults to False.

                maximised (bool):
                    Specifies whether the windowed mode window will be maximized when created. Defaults to False.

                gl_version ((int, int)):
                    OpenGL context version in tuple (major, minor). Defaults to (3, 3).

                gl_forward_compat (bool):
                    Specifies whether the OpenGL context should be forward-compatible, i.e. one where
                    all functionality deprecated in the requested version of OpenGL is removed. Only
                    applied when the requested OpenGL version is 3.0 or above. Defaults to True on
                    macOS and False elsewhere.

                gl_profile (str):
                    Specifies which OpenGL profile to create the context for. Possible values are
                    ['core', 'compat', 'any']. Defaults to 'core'. The key `profile` is accepted as
                    an alias. Profiles only exist for OpenGL 3.2 and above; older versions always
                    use 'any'.

                macos_retina (bool):
                    Specifies whether to use full resolution framebuffer on Retina displays. Defaults to True.

                macos_gfx_switch (bool):
                    Specifies whether to participate in Automatic Graphics Switching, i.e. to allow the system
                    to choose the integrated GPU for the OpenGL context and move it between GPUs if
                    necessary, or whether to force it to always run on the discrete GPU. Defaults to False.

        Raises:
            RuntimeError: if GLFW fails to initialize or the requested profile is unknown.
        """
        super(Window, self).__init__()
        self._mouse_btn = MouseButton.NONE
        # Last known cursor position; the GLFW input callbacks keep it as a plain [x, y] list.
        self._cursor_pos = [0., 0.]
        self._width = width
        self._height = height
        self._title = title
        self._native_window = None
        self._clear_flags = gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT | gl.GL_STENCIL_BUFFER_BIT
        self._x = 0
        self._y = 0
        self._clear_color = kwargs.get('clear_color', misc.PaletteDefault.Background)
        self._camera = None

        atexit.register(self._delete)

        # Install the error callback before glfw.init() so initialization failures are reported too.
        glfw.set_error_callback(self._on_glfw_error)

        if not glfw.init():
            raise RuntimeError("Failed to initialize GLFW!")

        self._apply_window_hints(kwargs)

        self._native_window = glfw.create_window(width, height, title, None, None)

        if not self._native_window:
            print("Failed to create GLFW window", file=sys.stderr)
            glfw.terminate()
            # Non-zero status so the calling process can see that start-up failed.
            sys.exit(1)

        glfw.make_context_current(self._native_window)

        self._fit_framebuffer_to_requested_size(width, height)
        self._install_callbacks()

        self._width, self._height = glfw.get_framebuffer_size(self._native_window)
        self._x, self._y = glfw.get_window_pos(self._native_window)

        imgui.create_context()
        # TODO: UI rendering
        # self._gui = GlfwRenderer(self._native_window) if self.current_context_version[0] >= 3 else None
        self._gui = None

        self._default_shader = None

        if self.current_context_version >= (3, 3):
            self._default_shader = ShaderProgram(VertexShader(_default_vertex_shader_str),
                                                 PixelShader(_default_pixel_shader_str))
            logging.info("Default shader created.")

        self._start_time = time.time()
        self._previous_time = self._start_time
        self._current_time = self._start_time

        # Apply the configured clear color now, so the glClear in `main_loop` uses it even if
        # `on_init` is never dispatched.
        gl.glClearColor(*self._clear_color)
        gl.glEnable(gl.GL_DEPTH_TEST)

    @staticmethod
    def _apply_window_hints(options):
        """Set the GLFW window and context creation hints documented in `__init__`.

        Args:
            options (dict): the keyword arguments given to `__init__`.

        Raises:
            RuntimeError: if the requested OpenGL profile is unknown.
        """
        # by default the window is resizable, decorated, not floating and not maximised
        glfw.window_hint(glfw.RESIZABLE, options.get('resizable', True))
        glfw.window_hint(glfw.DECORATED, options.get('decorated', True))
        glfw.window_hint(glfw.FLOATING, options.get('floating', False))
        glfw.window_hint(glfw.MAXIMIZED, options.get('maximised', False))

        # by default the OpenGL version is 3.3
        gl_major, gl_minor = options.get('gl_version', (3, 3))
        glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, gl_major)
        glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, gl_minor)

        # Forward compatibility only applies to OpenGL 3.0+. It is on by default on macOS, off elsewhere.
        forward_compat = options.get('gl_forward_compat', platform.system() == 'Darwin')
        glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, bool(gl_major >= 3 and forward_compat))

        # `gl_profile` is the documented key; `profile` is still honoured for existing callers.
        profile = options.get('gl_profile', options.get('profile', 'core'))
        if profile not in ('core', 'compat', 'any'):
            raise RuntimeError(f'Unknown OpenGL profile: {profile}')
        # Context profiles are only defined for OpenGL version 3.2 and above; older versions
        # must request the 'any' profile. By default a core profile context is created.
        supports_profiles = (gl_major, gl_minor) >= (3, 2)
        if supports_profiles and profile == 'core':
            glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE)
        elif supports_profiles and profile == 'compat':
            glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_COMPAT_PROFILE)
        else:
            glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_ANY_PROFILE)

        # by default use full resolution on Retina display
        glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, options.get('macos_retina', True))

        # by default graphics switching on macos is disabled
        glfw.window_hint(glfw.COCOA_GRAPHICS_SWITCHING, options.get('macos_gfx_switch', False))

    def _fit_framebuffer_to_requested_size(self, width, height):
        """On macOS, shrink the window so its framebuffer is `width` x `height` pixels.

        On Retina displays the framebuffer (pixels) can be larger than the window (screen
        coordinates). The window is resized by the measured framebuffer/window ratio.
        """
        # platform.system() reports 'Darwin' (capitalised) on macOS.
        if platform.system() != 'Darwin':
            return
        win_w, win_h = glfw.get_window_size(self._native_window)
        fb_w, fb_h = glfw.get_framebuffer_size(self._native_window)
        if not (win_w and win_h and fb_w and fb_h) or (fb_w, fb_h) == (win_w, win_h):
            return
        glfw.set_window_size(self._native_window,
                             max(1, width * win_w // fb_w),
                             max(1, height * win_h // fb_h))

    def _install_callbacks(self):
        """Route the native GLFW window callbacks to the `_on_glfw_*` handlers."""
        win = self._native_window
        glfw.set_framebuffer_size_callback(win, self._on_glfw_framebuffer_resize)
        glfw.set_cursor_enter_callback(win, self._on_glfw_cursor_enter)
        glfw.set_window_close_callback(win, self._on_glfw_window_close)
        glfw.set_key_callback(win, self._on_glfw_key)
        glfw.set_mouse_button_callback(win, self._on_glfw_mouse_button)
        glfw.set_cursor_pos_callback(win, self._on_glfw_mouse_motion)
        glfw.set_scroll_callback(win, self._on_glfw_scroll)

    def shut_down(self):
        """Destroy the native GLFW window."""
        glfw.destroy_window(self._native_window)

    @staticmethod
    def _delete():
        """Terminate GLFW; registered with `atexit` by `__init__`."""
        logging.info("GLFW clean up.")
        glfw.terminate()

    def _on_glfw_framebuffer_resize(self, _window, width, height):
        # Cache the new framebuffer size (pixels) before notifying listeners.
        self._width = width
        self._height = height
        self.dispatch('on_resize', width, height)

    def _on_glfw_cursor_enter(self, _window, entered):
        if entered:
            self.dispatch('on_cursor_enter')
        else:
            self.dispatch('on_cursor_leave')

    def _on_glfw_window_close(self, _window):
        self.close()

    def _on_glfw_key(self, _window, key, scancode, action, mods):
        keycode = KeyCode.from_glfw_keycode(key)
        modifiers = KeyModifier.from_glfw_modifiers(mods)

        # Auto-repeat is reported as a press; anything else (RELEASE) as a release.
        if action in [glfw.PRESS, glfw.REPEAT]:
            self.dispatch('on_key_press', keycode, modifiers)
        else:
            self.dispatch('on_key_release', keycode, modifiers)

    def _on_glfw_mouse_button(self, window, button, action, mods):
        x, y = glfw.get_cursor_pos(window)
        button = MouseButton.from_glfw_mouse_btn_code(button)
        if action == glfw.RELEASE:
            self._mouse_btn = MouseButton.NONE
            self._cursor_pos = [x, y]
            self.dispatch('on_mouse_release', x, y, button)
        elif action == glfw.PRESS:
            # Remember the pressed button so cursor motion is reported as a drag.
            self._mouse_btn = button
            self._cursor_pos = [x, y]
            self.dispatch('on_mouse_press', x, y, button)

    def _on_glfw_mouse_motion(self, _window, x, y):
        dx = x - self._cursor_pos[0]
        dy = y - self._cursor_pos[1]
        self._cursor_pos = [x, y]
        if self._mouse_btn != MouseButton.NONE:
            self.dispatch('on_mouse_drag', x, y, dx, dy, self._mouse_btn)
        else:
            self.dispatch('on_mouse_motion', x, y, dx, dy)

    def _on_glfw_scroll(self, window, x_offset, y_offset):
        x, y = glfw.get_cursor_pos(window)
        self.dispatch('on_mouse_scroll', x, y, x_offset, y_offset)

    def _on_glfw_error(self, error, desc):
        # The description may arrive as bytes from the C callback; decode it for readability.
        if isinstance(desc, bytes):
            desc = desc.decode('utf-8', errors='replace')
        print(f'GLFW Error: {desc}', file=sys.stderr)

    @property
    def title(self):
        return self._title

    @property
    def width(self):
        """int: current framebuffer width in pixels."""
        self._width, self._height = glfw.get_framebuffer_size(self._native_window)
        return self._width

    @width.setter
    def width(self, value):
        # glfw.set_window_size takes window (screen-coordinate) sizes, so keep the current
        # window height rather than the cached framebuffer height.
        _, win_height = glfw.get_window_size(self._native_window)
        glfw.set_window_size(self._native_window, value, win_height)
        self._width, self._height = glfw.get_framebuffer_size(self._native_window)

    @property
    def height(self):
        """int: current framebuffer height in pixels."""
        self._width, self._height = glfw.get_framebuffer_size(self._native_window)
        return self._height

    @height.setter
    def height(self, value):
        # Keep the current window width and apply the requested height.
        win_width, _ = glfw.get_window_size(self._native_window)
        glfw.set_window_size(self._native_window, win_width, value)
        self._width, self._height = glfw.get_framebuffer_size(self._native_window)

    @property
    def aspect_ratio(self):
        """float: cached framebuffer width divided by height."""
        return self._width / self._height

    @property
    def size(self):
        """(int, int): current framebuffer size as (width, height) in pixels."""
        self._width, self._height = glfw.get_framebuffer_size(self._native_window)
        return self._width, self._height

    @size.setter
    def size(self, sz):
        glfw.set_window_size(self._native_window, sz[0], sz[1])
        self._width, self._height = glfw.get_framebuffer_size(self._native_window)

    @property
    def position(self):
        """Vec2: position of the window as reported by glfw.get_window_pos."""
        self._x, self._y = glfw.get_window_pos(self._native_window)
        return Vec2(self._x, self._y)

    @position.setter
    def position(self, pos):
        glfw.set_window_pos(self._native_window, pos[0], pos[1])
        self._x, self._y = glfw.get_window_pos(self._native_window)

    @property
    def current_context_version(self):
        """(int, int): (major, minor) version of the window's OpenGL context."""
        return glfw.get_window_attrib(self._native_window, glfw.CONTEXT_VERSION_MAJOR), \
               glfw.get_window_attrib(self._native_window, glfw.CONTEXT_VERSION_MINOR)

    @property
    def default_shader(self):
        """ShaderProgram or None: built-in shader, created only for OpenGL 3.3+ contexts."""
        return self._default_shader

    @property
    def delta_time(self):
        """float: seconds since the previous read; reading it resets the reference time."""
        self._current_time = time.time()
        elapsed = self._current_time - self._previous_time
        self._previous_time = self._current_time

        return elapsed

    @property
    def elapsed_time(self):
        """float: seconds since the window was created."""
        return time.time() - self._start_time

    def main_loop(self):
        """Start the main loop of the window.

        Each iteration polls input, dispatches `on_update`, clears the framebuffer, dispatches
        `on_gui` (when a GUI renderer exists), `on_draw` and `on_idle`, then swaps buffers.
        `on_close` is dispatched once the window is asked to close.
        """
        self._previous_time = time.time()

        while not glfw.window_should_close(self._native_window):
            # update time
            dt = self.delta_time

            # process inputs
            glfw.poll_events()

            # update
            self.dispatch('on_update', dt)

            # render
            gl.glClear(self._clear_flags)

            if self._gui:
                self.dispatch('on_gui')

            self.dispatch('on_draw', dt)

            self.dispatch('on_idle', dt)

            glfw.swap_buffers(self._native_window)

        self.dispatch('on_close')

    def clear(self, color=None):
        """Clear the framebuffer.

        Args:
            color (array-like, optional): color to clear with; it also becomes the current OpenGL
                clear color. Defaults to the window's `clear_color`.
        """
        if color is None:
            color = self._clear_color
        gl.glClearColor(*color)
        gl.glClear(self._clear_flags)

    def swap_buffers(self):
        glfw.swap_buffers(self._native_window)

    def activate(self):
        """Make this window's OpenGL context current."""
        glfw.make_context_current(self._native_window)

    def destroy(self):
        glfw.destroy_window(self._native_window)

    def close(self):
        """Request the main loop to stop by flagging the window for closing."""
        glfw.set_window_should_close(self._native_window, True)

    def show(self):
        glfw.show_window(self._native_window)
        self.dispatch('on_show')

    def hide(self):
        glfw.hide_window(self._native_window)
        self.dispatch('on_hide')

    def on_init(self):
        """Default init handler: apply the clear color and clear the color and depth buffers."""
        gl.glClearColor(*self._clear_color)
        gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)

    def on_close(self):
        """Default close handler: shut the GUI renderer down (if any) and terminate GLFW."""
        if self._gui:
            self._gui.shutdown()
        glfw.terminate()

    def on_resize(self, width, height):
        """Default resize handler: set the viewport, then clear, draw and present a frame.

        Note that reading `delta_time` here resets the frame timer used by `main_loop`.
        """
        self.activate()
        gl.glViewport(0, 0, width, height)
        gl.glClear(self._clear_flags)
        self.dispatch('on_draw', self.delta_time)
        self.swap_buffers()

    def on_key_press(self, key, mods):
        """Default key handler: Escape closes the window, F10 saves a screenshot."""
        if key == KeyCode.Escape:
            self.close()
        elif key == KeyCode.F10:
            self._save_screenshot()

        return True

    def _save_screenshot(self):
        """Save the current framebuffer as a timestamped PNG in the working directory.

        Returns:
            str: path of the written file.
        """
        import png
        import datetime

        w, h = self.size  # `size` is (width, height)
        # NOTE(review): rows are only tightly packed (w * 3 bytes) when GL_PACK_ALIGNMENT is 1;
        # confirm the pack alignment for widths that are not a multiple of 4.
        framebuffer = np.zeros((h, w * 3), dtype=np.uint8)
        gl.glReadPixels(0, 0, w, h, gl.GL_RGB, gl.GL_UNSIGNED_BYTE, framebuffer)
        # ':' is not a valid file-name character on Windows, so the time uses '-' separators.
        filename = f'{datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.png'
        filepath = os.path.join(os.getcwd(), filename)
        # glReadPixels starts at the bottom row; flip vertically so the PNG is upright.
        png.from_array(framebuffer[::-1], 'RGB').save(filepath)
        print(f'Screenshot saved to {filepath}')
        return filepath

    def mouse_position(self):
        """Return the cursor position reported by glfw.get_cursor_pos."""
        return glfw.get_cursor_pos(self._native_window)

    def create_camera(self, pos, look_at, up, fov_v=45.0, near=0.1, far=1000., degrees=True):
        """Create a `Camera` using the cached framebuffer aspect ratio and attach it as a listener."""
        self._camera = Camera(pos, look_at, up, self._width / self._height, fov_v, near, far, degrees)
        self.attach_listeners(self._camera)


# Register, in order, every event type listed in the Window docstring with the dispatcher.
for _event_type in (
    'on_cursor_enter',
    'on_cursor_leave',
    'on_draw',
    'on_resize',
    'on_mouse_motion',
    'on_mouse_drag',
    'on_mouse_press',
    'on_mouse_release',
    'on_mouse_scroll',
    'on_key_press',
    'on_key_release',
    'on_update',
    'on_init',
    'on_idle',
    'on_show',
    'on_hide',
    'on_close',
    'on_gui',
):
    Window.register_event_type(_event_type)
del _event_type


class WindowEventLogger:
    """Logs window events when triggered.

    Each handler writes one line describing the event and its arguments. Handler signatures
    accept the arguments `Window` dispatches. Extra trailing arguments that `Window` does not
    pass (`mods` for mouse events, `dt` for `on_draw`) are optional.
    """

    def __init__(self, log_file=None):
        """Creates a `WindowEventLogger` object.

        Args:
            log_file (file like object):
                Specifies the output of log messages. If not specified, stdout will be used.
        """
        self._output = log_file if log_file is not None else sys.stdout

    def _log(self, message):
        """Write a single log line to the configured output."""
        print(message, file=self._output)

    def on_key_press(self, key, mods):
        self._log(f'on_key_press({key}, mods={mods})')

    def on_key_release(self, key, mods):
        self._log(f'on_key_release({key}, mods={mods})')

    def on_mouse_motion(self, x, y, dx, dy):
        self._log(f'on_mouse_motion(x={x}, y={y}, dx={dx}, dy={dy})')

    def on_mouse_drag(self, x, y, dx, dy, btns, mods=None):
        self._log(f'on_mouse_drag(x={x}, y={y}, dx={dx}, dy={dy}, btns={btns}, mods={mods})')

    def on_mouse_press(self, x, y, btn, mods=None):
        self._log(f'on_mouse_press(x={x}, y={y}, btn={btn}, mods={mods})')

    def on_mouse_release(self, x, y, btn, mods=None):
        self._log(f'on_mouse_release(x={x}, y={y}, btn={btn}, mods={mods})')

    def on_mouse_scroll(self, x, y, dx, dy):
        self._log(f'on_mouse_scroll(x={x}, y={y}, dx={dx}, dy={dy})')

    def on_close(self):
        self._log('on_close()')

    def on_resize(self, w, h):
        self._log(f'on_resize(w={w}, h={h})')

    def on_draw(self, dt=None):
        # `Window` dispatches on_draw with the frame delta; the argument stays optional.
        if dt is None:
            self._log('on_draw()')
        else:
            self._log(f'on_draw(dt={dt})')
