#!python
# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION
# cython: language_level=3
# cython: cpow=True
# cython: boundscheck=False
# cython: wraparound=False
# cython: initializedcheck=False
# cython: cdivision=True
from typing import Optional
import numpy
cimport numpy
from libc.math cimport exp, fabs, log, sin, cos, tan, tanh, asin, acos, atan, isnan, isinf
from libc.math cimport NAN as nan
from libc.math cimport INFINITY as inf
import cython
from cpython.mem cimport PyMem_Malloc
from cpython.mem cimport PyMem_Realloc
from cpython.mem cimport PyMem_Free
from hydpy.cythons.autogen cimport configutils
from hydpy.cythons.autogen cimport interfaceutils
from hydpy.cythons.autogen cimport interputils
from hydpy.cythons.autogen import pointerutils
from hydpy.cythons.autogen cimport pointerutils
from hydpy.cythons.autogen cimport quadutils
from hydpy.cythons.autogen cimport rootutils
from hydpy.cythons.autogen cimport smoothutils
from hydpy.cythons.autogen cimport masterinterface


cdef void do_nothing(Model model)  noexcept nogil:
    # Intentionally empty callback: it ignores `model` and returns at once.
    # `get_wrapper` installs it as the callback of a fresh CallbackWrapper.
    return

cpdef get_wrapper():
    # Create a CallbackWrapper whose callback is the no-op `do_nothing`.
    # NOTE(review): CallbackWrapper is not declared in this file chunk;
    # presumably it comes from the companion .pxd — confirm.
    cdef CallbackWrapper noop_wrapper
    noop_wrapper = CallbackWrapper()
    noop_wrapper.callback = do_nothing
    return noop_wrapper

@cython.final
cdef class Parameters:
    """Parameter container without methods of its own.

    `Model` reaches its parameter subgroups through attributes such as
    `self.parameters.control` and `self.parameters.solver`.  These typed
    attributes are not declared in this file chunk (presumably they live in
    the companion .pxd file).
    """
@cython.final
cdef class ControlParameters:
    """Control parameters read by the model's process equations.

    `Model.calc_qv_v1` and `Model.calc_sv_v1` read `n` (the number of
    entries processed) and `k` (the factor in `qv[i] = k * sv[i]`).  The
    attributes are presumably declared in the companion .pxd file.
    """
@cython.final
cdef class SolverParameters:
    """Settings of the adaptive solver implemented by `Model.solve`.

    `Model.solve` reads `abserrormax`, `relerrormax` (a NaN value disables
    the relative error criterion), `reldtmin` and `reldtmax`.  The
    attributes are presumably declared in the companion .pxd file.
    """
@cython.final
cdef class Sequences:
    """Sequence container without methods of its own.

    `Model` accesses `fluxes`, `states`, `old_states` and `new_states`
    through it.  These attributes are presumably declared in the companion
    .pxd file.
    """
@cython.final
cdef class FluxSequences:
    # Container of the 1-dimensional flux sequence `qv` and its I/O buffers.

    cpdef inline void load_data(self, numpy.int64_t idx)  noexcept nogil:
        # Fill `qv` from the disk buffer when reading from disk is active;
        # otherwise, when RAM storage is active, from row `idx` of the RAM
        # series.  For this 1-D sequence, the flat disk-buffer position
        # equals the element index.
        cdef numpy.int64_t pos
        if self._qv_diskflag_reading:
            for pos in range(self._qv_length_0):
                self.qv[pos] = self._qv_ncarray[pos]
            return
        if self._qv_ramflag:
            for pos in range(self._qv_length_0):
                self.qv[pos] = self._qv_array[idx, pos]
    cpdef inline void save_data(self, numpy.int64_t idx)  noexcept nogil:
        # Copy `qv` into the disk buffer and/or into row `idx` of the RAM
        # series.  Unlike `load_data`, both targets may be written in one call.
        cdef numpy.int64_t pos
        if self._qv_diskflag_writing:
            for pos in range(self._qv_length_0):
                self._qv_ncarray[pos] = self.qv[pos]
        if self._qv_ramflag:
            for pos in range(self._qv_length_0):
                self._qv_array[idx, pos] = self.qv[pos]
    cpdef inline set_pointeroutput(self, str name, pointerutils.PDouble value):
        # Empty implementation: `name` and `value` are ignored.
        return None
    cpdef inline void update_outputs(self) noexcept nogil:
        # Empty implementation: there is nothing to update.
        return
@cython.final
cdef class StateSequences:
    # Container of the 1-dimensional state sequence `sv` and its I/O buffers.

    cpdef inline void load_data(self, numpy.int64_t idx)  noexcept nogil:
        # Fill `sv` from the disk buffer when reading from disk is active;
        # otherwise, when RAM storage is active, from row `idx` of the RAM
        # series.  For this 1-D sequence, the flat disk-buffer position
        # equals the element index.
        cdef numpy.int64_t pos
        if self._sv_diskflag_reading:
            for pos in range(self._sv_length_0):
                self.sv[pos] = self._sv_ncarray[pos]
            return
        if self._sv_ramflag:
            for pos in range(self._sv_length_0):
                self.sv[pos] = self._sv_array[idx, pos]
    cpdef inline void save_data(self, numpy.int64_t idx)  noexcept nogil:
        # Copy `sv` into the disk buffer and/or into row `idx` of the RAM
        # series.  Unlike `load_data`, both targets may be written in one call.
        cdef numpy.int64_t pos
        if self._sv_diskflag_writing:
            for pos in range(self._sv_length_0):
                self._sv_ncarray[pos] = self.sv[pos]
        if self._sv_ramflag:
            for pos in range(self._sv_length_0):
                self._sv_array[idx, pos] = self.sv[pos]
    cpdef inline set_pointeroutput(self, str name, pointerutils.PDouble value):
        # Empty implementation: `name` and `value` are ignored.
        return None
    cpdef inline void update_outputs(self) noexcept nogil:
        # Empty implementation: there is nothing to update.
        return
@cython.final
cdef class NumConsts:
    """Numerical constants of the adaptive solver.

    `Model.solve` and its helpers read `nmb_methods`, `dt_increase`,
    `dt_decrease` and the weight array `a_coefs`.  The attributes are
    presumably declared in the companion .pxd file.
    """
@cython.final
cdef class NumVars:
    """Mutable working variables of the adaptive solver.

    Examples are `t0`, `t1`, `dt`, `dt_est`, `idx_method`, `idx_stage`,
    `f0_ready`, `use_relerror`, `nmb_calls` and the (extrapolated) error
    measures.  `Model.solve` and its helpers read and write them.  The
    attributes are presumably declared in the companion .pxd file.
    """
@cython.final
cdef class Model:
    # Compiled model with three groups of methods:
    #   * per-step entry points (`simulate`, `save_data`, `new2old`, ...);
    #   * an adaptive explicit solver (`solve` and the helpers it calls);
    #   * the process equations `calc_qv_v1` / `calc_sv_v1`.
    # NOTE(review): the attributes used below (`parameters`, `sequences`,
    # `numconsts`, `numvars`, `idx_sim`) are not declared in this chunk;
    # presumably they are declared in the companion .pxd file.

    # Process simulation step `idx`: record the index, solve the equations,
    # then update the outputs.
    cpdef inline void simulate(self, numpy.int64_t idx)  noexcept nogil:
        self.idx_sim = idx
        self.solve()
        self.update_outputs()
    # Empty body: nothing to reset here.
    cpdef void reset_reuseflags(self) noexcept nogil:
        pass
    # Store the current flux and state values for simulation step `idx`.
    cpdef void save_data(self, numpy.int64_t idx) noexcept nogil:
        self.idx_sim = idx
        self.sequences.fluxes.save_data(idx)
        self.sequences.states.save_data(idx)
    # Copy the new state values over the old ones, so that the next
    # (sub)step starts from the current end state (`calc_sv_v1` reads
    # `old_states`).
    cpdef void new2old(self) noexcept nogil:
        cdef numpy.int64_t jdx0
        for jdx0 in range(self.sequences.states._sv_length_0):
            self.sequences.old_states.sv[jdx0] = self.sequences.new_states.sv[jdx0]
    # Empty body: no inlet values to pull.
    cpdef inline void update_inlets(self) noexcept nogil:
        pass
    # Empty body: no outlet values to push.
    cpdef inline void update_outlets(self) noexcept nogil:
        pass
    # Only records the step index; there is nothing else to receive.
    cpdef inline void update_receivers(self, numpy.int64_t idx) noexcept nogil:
        self.idx_sim = idx
        pass
    # Only records the step index; there is nothing else to send.
    cpdef inline void update_senders(self, numpy.int64_t idx) noexcept nogil:
        self.idx_sim = idx
        pass
    # Empty body: no output sequences to update.
    cpdef inline void update_outputs(self) noexcept nogil:
        pass
    # Integrate the model equations over one simulation step, i.e. the
    # normalised interval [t0, t1] = [0, 1], using adaptive substeps.
    #
    # Within a substep, methods 1..nmb_methods are tried in turn.  Method `m`
    # combines the stage fluxes 0..m-1 with weights from `numconsts.a_coefs`.
    # The difference between the flux results of two successive methods
    # serves as the error estimate.
    #
    # Acceptance and retry rules:
    #   * A substep is accepted as soon as the absolute or the relative
    #     error meets its tolerance.
    #   * It is retried with a smaller `dt` when `dt` can still shrink and
    #     even the error extrapolated to the highest method exceeds the
    #     tolerance(s).
    #   * If all methods fail, the substep is accepted anyway once `dt` has
    #     reached `reldtmin`; otherwise it is retried with a smaller `dt`.
    # Fluxes of accepted substeps are summed up and finally written back to
    # the flux sequences.
    cpdef inline void solve(self) noexcept nogil:
        # Boolean flag stored as an integer.
        cdef numpy.int64_t decrease_dt
        # A NaN tolerance disables the relative error criterion.
        self.numvars.use_relerror = not isnan(            self.parameters.solver.relerrormax        )
        self.numvars.t0, self.numvars.t1 = 0.0, 1.0
        self.numvars.dt_est = 1.0 * self.parameters.solver.reldtmax
        self.numvars.f0_ready = False
        self.reset_sum_fluxes()
        # The 1e-14 margin guards against round-off in the accumulated t0.
        while self.numvars.t0 < self.numvars.t1 - 1e-14:
            # No previous method exists yet for this substep attempt.
            self.numvars.last_abserror = inf
            self.numvars.last_relerror = inf
            # Clip the estimate to [reldtmin, reldtmax], but never step past
            # t1 (so the last substep may be shorter than reldtmin).
            self.numvars.dt = min(                self.numvars.t1 - self.numvars.t0,                1.0 * self.parameters.solver.reldtmax,                max(self.numvars.dt_est, self.parameters.solver.reldtmin),            )
            # Evaluate the fluxes at the substep's initial state (stage 0),
            # unless they are still valid from a rejected attempt.
            if not self.numvars.f0_ready:
                self.calculate_single_terms()
                self.numvars.idx_method = 0
                self.numvars.idx_stage = 0
                self.set_point_fluxes()
                self.set_point_states()
                self.set_result_states()
            for self.numvars.idx_method in range(1, self.numconsts.nmb_methods + 1):
                # Fluxes at the intermediate stages 1..idx_method-1.
                for self.numvars.idx_stage in range(1, self.numvars.idx_method):
                    self.get_point_states()
                    self.calculate_single_terms()
                    self.set_point_fluxes()
                # Integrate the fluxes and update the states for stages
                # 1..idx_method.
                for self.numvars.idx_stage in range(1, self.numvars.idx_method + 1):
                    self.integrate_fluxes()
                    self.calculate_full_terms()
                    self.set_point_states()
                self.set_result_fluxes()
                self.set_result_states()
                self.calculate_error()
                self.extrapolate_error()
                # The first method has no predecessor to compare against.
                if self.numvars.idx_method == 1:
                    continue
                # Accept the substep: enlarge the next estimate, accumulate
                # the fluxes, advance time, and shift new states to old.
                if (self.numvars.abserror <= self.parameters.solver.abserrormax) or (                    self.numvars.relerror <= self.parameters.solver.relerrormax                ):
                    self.numvars.dt_est = self.numconsts.dt_increase * self.numvars.dt
                    self.numvars.f0_ready = False
                    self.addup_fluxes()
                    self.numvars.t0 = self.numvars.t0 + self.numvars.dt
                    self.new2old()
                    break
                # Shrink dt early only if it can still shrink and even the
                # extrapolated error(s) of the highest method would miss the
                # tolerance(s).
                decrease_dt = self.numvars.dt > self.parameters.solver.reldtmin
                decrease_dt = decrease_dt and (                    self.numvars.extrapolated_abserror                    > self.parameters.solver.abserrormax                )
                if self.numvars.use_relerror:
                    decrease_dt = decrease_dt and (                        self.numvars.extrapolated_relerror                        > self.parameters.solver.relerrormax                    )
                # Retry with a smaller dt; the stage-0 fluxes stay valid.
                if decrease_dt:
                    self.numvars.f0_ready = True
                    self.numvars.dt_est = self.numvars.dt / self.numconsts.dt_decrease
                    break
                # Keep this method's errors for the next extrapolation.
                self.numvars.last_abserror = self.numvars.abserror
                self.numvars.last_relerror = self.numvars.relerror
                self.numvars.f0_ready = True
            else:
                # Reached only when no method broke out of the loop (neither
                # accepted nor early dt decrease).
                if self.numvars.dt <= self.parameters.solver.reldtmin:
                    # dt cannot shrink further: accept the substep anyway.
                    self.numvars.f0_ready = False
                    self.addup_fluxes()
                    self.numvars.t0 = self.numvars.t0 + self.numvars.dt
                    self.new2old()
                else:
                    self.numvars.f0_ready = True
                    self.numvars.dt_est = self.numvars.dt / self.numconsts.dt_decrease
        # Publish the fluxes accumulated over all accepted substeps.
        self.get_sum_fluxes()
    # Count the call and evaluate the flux equation from the current states.
    cpdef inline void calculate_single_terms(self) noexcept nogil:
        self.numvars.nmb_calls = self.numvars.nmb_calls + 1
        self.calc_qv_v1()
    # Evaluate the state equation from the old states and integrated fluxes.
    cpdef inline void calculate_full_terms(self) noexcept nogil:
        self.calc_sv_v1()
    # Restore the states stored for stage `idx_stage`.
    cpdef inline void get_point_states(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.states._sv_length):
            self.sequences.states.sv[idx0] = self.sequences.states._sv_points[self.numvars.idx_stage][idx0]
    # Store the current states as those of stage `idx_stage`.
    cpdef inline void set_point_states(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.states._sv_length):
            self.sequences.states._sv_points[self.numvars.idx_stage][idx0] = self.sequences.states.sv[idx0]
    # Store the current states as the result of method `idx_method`.
    cpdef inline void set_result_states(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.states._sv_length):
            self.sequences.states._sv_results[self.numvars.idx_method][idx0] = self.sequences.states.sv[idx0]
    # Write the accumulated fluxes back to `qv`.
    cpdef inline void get_sum_fluxes(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.fluxes._qv_length):
            self.sequences.fluxes.qv[idx0] = self.sequences.fluxes._qv_sum[idx0]
    # Store the current fluxes as those of stage `idx_stage`.
    cpdef inline void set_point_fluxes(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.fluxes._qv_length):
            self.sequences.fluxes._qv_points[self.numvars.idx_stage][idx0] = self.sequences.fluxes.qv[idx0]
    # Store the current fluxes as the result of method `idx_method`.
    cpdef inline void set_result_fluxes(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.fluxes._qv_length):
            self.sequences.fluxes._qv_results[self.numvars.idx_method][idx0] = self.sequences.fluxes.qv[idx0]
    # Overwrite `qv` with dt times the weighted sum of the stage fluxes
    # 0..idx_method-1.  The weights are
    # a_coefs[idx_method-1, idx_stage, jdx], so `qv` then holds an
    # integrated amount over dt rather than the raw flux.
    cpdef inline void integrate_fluxes(self) noexcept nogil:
        cdef numpy.int64_t jdx, idx0
        for idx0 in range(self.sequences.fluxes._qv_length):
            self.sequences.fluxes.qv[idx0] = 0.
            for jdx in range(self.numvars.idx_method):
                self.sequences.fluxes.qv[idx0] = self.sequences.fluxes.qv[idx0] + self.numvars.dt * self.numconsts.a_coefs[self.numvars.idx_method-1, self.numvars.idx_stage, jdx]*self.sequences.fluxes._qv_points[jdx, idx0]
    # Zero the flux accumulator at the start of `solve`.
    cpdef inline void reset_sum_fluxes(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.fluxes._qv_length):
            self.sequences.fluxes._qv_sum[idx0] = 0.
    # Add the (integrated) fluxes of an accepted substep to the accumulator.
    cpdef inline void addup_fluxes(self) noexcept nogil:
        cdef numpy.int64_t idx0
        for idx0 in range(self.sequences.fluxes._qv_length):
            self.sequences.fluxes._qv_sum[idx0] = self.sequences.fluxes._qv_sum[idx0] + self.sequences.fluxes.qv[idx0]
    # Compare the flux results of method `idx_method` and `idx_method - 1`:
    #   * abserror: the largest absolute difference over all entries;
    #   * relerror: the largest |difference / current result|.  It becomes
    #     inf as soon as any current result is exactly zero, and is inf from
    #     the start (never satisfied) when the relative criterion is disabled.
    cpdef inline void calculate_error(self) noexcept nogil:
        cdef numpy.int64_t idx0
        # Per-entry absolute difference (shadows nothing on numvars).
        cdef double abserror
        self.numvars.abserror = 0.
        if self.numvars.use_relerror:
            self.numvars.relerror = 0.
        else:
            self.numvars.relerror = inf
        for idx0 in range(self.sequences.fluxes._qv_length):
            abserror = fabs(self.sequences.fluxes._qv_results[self.numvars.idx_method, idx0]-self.sequences.fluxes._qv_results[self.numvars.idx_method-1, idx0])
            self.numvars.abserror = max(self.numvars.abserror, abserror)
            if self.numvars.use_relerror:
                if self.sequences.fluxes._qv_results[self.numvars.idx_method, idx0] == 0.:
                    self.numvars.relerror = inf
                else:
                    self.numvars.relerror = max(self.numvars.relerror, fabs(abserror/self.sequences.fluxes._qv_results[self.numvars.idx_method, idx0]))
    # Extrapolate the current errors log-linearly to the highest method
    # (`nmb_methods`), from this method's and the previous method's errors:
    #     log(e_ext) = log(e) + (log(e) - log(e_last)) * (nmb_methods - m)
    # Special cases:
    #   * For methods 1 and 2, -999.9 marks "no extrapolation available";
    #     assuming non-negative tolerances, it never triggers the early dt
    #     decrease in `solve`.
    #   * A zero absolute error yields zero for both extrapolations.
    #   * A disabled relative criterion yields inf.
    cpdef inline void extrapolate_error(self) noexcept nogil:
        if self.numvars.abserror <= 0.0:
            self.numvars.extrapolated_abserror = 0.0
            self.numvars.extrapolated_relerror = 0.0
        else:
            if self.numvars.idx_method > 2:
                self.numvars.extrapolated_abserror = exp(                    log(self.numvars.abserror)                    + (                        log(self.numvars.abserror)                        - log(self.numvars.last_abserror)                    )                    * (self.numconsts.nmb_methods - self.numvars.idx_method)                )
            else:
                self.numvars.extrapolated_abserror = -999.9
            if self.numvars.use_relerror:
                if self.numvars.idx_method > 2:
                    if isinf(self.numvars.relerror):
                        self.numvars.extrapolated_relerror = inf
                    else:
                        self.numvars.extrapolated_relerror = exp(                            log(self.numvars.relerror)                            + (                                log(self.numvars.relerror)                                - log(self.numvars.last_relerror)                            )                            * (self.numconsts.nmb_methods - self.numvars.idx_method)                        )
                else:
                    self.numvars.extrapolated_relerror = -999.9
            else:
                self.numvars.extrapolated_relerror = inf
    # Flux equation: qv[i] = k * sv[i] for i in 0..n-1 (current states).
    cpdef inline void calc_qv_v1(self) noexcept nogil:
        cdef numpy.int64_t i
        for i in range(self.parameters.control.n):
            self.sequences.fluxes.qv[i] = self.parameters.control.k * self.sequences.states.sv[i]
    # State equation: new sv[i] = old sv[i] - qv[i] for i in 0..n-1.  Within
    # `solve`, `qv` holds the integrated amount produced by
    # `integrate_fluxes`.
    cpdef inline void calc_sv_v1(self) noexcept nogil:
        cdef numpy.int64_t i
        for i in range(self.parameters.control.n):
            self.sequences.new_states.sv[i] = self.sequences.old_states.sv[i] - self.sequences.fluxes.qv[i]
    # Same computation as `calc_qv_v1` (presumably an interface alias).
    cpdef inline void calc_qv(self) noexcept nogil:
        cdef numpy.int64_t i
        for i in range(self.parameters.control.n):
            self.sequences.fluxes.qv[i] = self.parameters.control.k * self.sequences.states.sv[i]
    # Same computation as `calc_sv_v1` (presumably an interface alias).
    cpdef inline void calc_sv(self) noexcept nogil:
        cdef numpy.int64_t i
        for i in range(self.parameters.control.n):
            self.sequences.new_states.sv[i] = self.sequences.old_states.sv[i] - self.sequences.fluxes.qv[i]
