# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_arrs.ipynb.

# %% auto 0
# Public names exported by `from <this module> import *`. This file is
# generated from the source notebook named in the header, so update the
# notebook and regenerate rather than editing this list by hand.
__all__ = ['chain', 'nparray', 'detach', 'clone', 'cpu', 'numpy', 'todense', 'tolist', 'toarray', 'arrchain', 'try2arr',
           'trc2arr', 'mtx2arr', 'spr2arr', 'dns2arr', 'ann2arr', 'mth2arr', 'asarr', 'as01']

# %% ../nbs/01_arrs.ipynb 6
#| export


# %% ../nbs/01_arrs.ipynb 8
from typing import Any, Union, TypeAlias, ParamSpec

# %% ../nbs/01_arrs.ipynb 10
try:
    import numpy as np
except ModuleNotFoundError:
    np = None

# %% ../nbs/01_arrs.ipynb 12
from atyp import StrQ, IntQ, NPArray, NPMatrix, Tensor, AnnData
from atyp import SPArray, SPMatrix, Index, Series, DataFrame

from chck import isad, isarr, ismtx, istens, issparse, istrc, isnone, notnone, isnpmatrix
from nlit import CPU, TOARRAY, TODENSE, TOLIST, DETACH, CLONE, NUMPY
from calr import attrfunc, applyfns
from achn import achn

# %% ../nbs/01_arrs.ipynb 14
#| export


# %% ../nbs/01_arrs.ipynb 17
class chain(achn):
    '''Module-local subclass of `achn` that adds nothing of its own.

    Its inherited `able` decorator marks the single-step helpers below
    (`nparray`, `detach`, `clone`, ...) as chainable.
    '''

@chain.able
def nparray(a):
    '''Build a new NumPy array from `a` with `np.array`.

    NOTE: if NumPy failed to import, `np` is `None` and this raises
    `AttributeError`.
    '''
    arr = np.array(a)
    return arr

# Chainable step wrapping `x.detach()`. The body is only a docstring, so the
# call behavior is supplied by the project-local `attrfunc(DETACH)` decorator
# (NOTE(review): presumably DETACH == 'detach' in `nlit` — confirm).
# Per the docstring, `x` is returned unchanged if the method call fails.
# Used by `trc2arr` as the first step of tensor -> ndarray conversion.
@chain.able
@attrfunc(DETACH)
def detach(x): '''try: x.detach() except: x'''

# Chainable step wrapping `x.clone()`. The body is only a docstring, so the
# call behavior is supplied by the project-local `attrfunc(CLONE)` decorator
# (NOTE(review): presumably CLONE == 'clone' in `nlit` — confirm).
# Per the docstring, `x` is returned unchanged if the method call fails.
# Used by `trc2arr` after `detach`.
@chain.able
@attrfunc(CLONE)
def clone(x): '''try: x.clone() except: x'''

# Chainable step wrapping `x.cpu()`. The body is only a docstring, so the
# call behavior is supplied by the project-local `attrfunc(CPU)` decorator
# (NOTE(review): presumably CPU == 'cpu' in `nlit` — confirm).
# Per the docstring, `x` is returned unchanged if the method call fails.
# Used by `trc2arr` after `clone`.
@chain.able
@attrfunc(CPU)
def cpu(x): '''try: x.cpu() except: x'''

# Chainable step wrapping `x.numpy()`. The body is only a docstring, so the
# call behavior is supplied by the project-local `attrfunc(NUMPY)` decorator
# (NOTE(review): presumably NUMPY == 'numpy' in `nlit` — confirm).
# Per the docstring, `x` is returned unchanged if the method call fails.
# Used by `trc2arr` as the final step. Note this module-level name shadows
# nothing imported here (NumPy itself is bound as `np`).
@chain.able
@attrfunc(NUMPY)
def numpy(x): '''try: x.numpy() except: x'''

# Chainable step wrapping `x.todense()`. The body is only a docstring, so the
# call behavior is supplied by the project-local `attrfunc(TODENSE)` decorator
# (NOTE(review): presumably TODENSE == 'todense' in `nlit` — confirm).
# Per the docstring, `x` is returned unchanged if the method call fails.
# Used by `dns2arr` before `tolist`.
@chain.able
@attrfunc(TODENSE)
def todense(x): '''try: x.todense() except: x'''

# Chainable step wrapping `x.tolist()`. The body is only a docstring, so the
# call behavior is supplied by the project-local `attrfunc(TOLIST)` decorator
# (NOTE(review): presumably TOLIST == 'tolist' in `nlit` — confirm).
# Per the docstring, `x` is returned unchanged if the method call fails.
# Used by `mtx2arr` and `dns2arr` before `nparray`.
@chain.able
@attrfunc(TOLIST)
def tolist(x): '''try: x.tolist() except: x'''

# Chainable step wrapping `x.toarray()`. The body is only a docstring, so the
# call behavior is supplied by the project-local `attrfunc(TOARRAY)` decorator
# (NOTE(review): presumably TOARRAY == 'toarray' in `nlit` — confirm).
# Per the docstring, `x` is returned unchanged if the method call fails.
# Used by `spr2arr` and `ann2arr`.
@chain.able
@attrfunc(TOARRAY)
def toarray(x): '''try: x.toarray() except: x'''

# %% ../nbs/01_arrs.ipynb 19
def arrchain(x, funcs, *args, **kwargs):
    '''Feed `x` through `funcs` using `applyfns`, with `isarr` as the check.

    Any extra positional/keyword arguments are forwarded to `applyfns`
    untouched (passing your own `check=` raises `TypeError`, since `check`
    is already supplied).
    NOTE(review): presumably `applyfns` stops once `check` is satisfied —
    confirm in `calr`.
    '''
    is_done = isarr
    return applyfns(x, funcs, *args, check=is_done, **kwargs)

# %% ../nbs/01_arrs.ipynb 20
def try2arr(a) -> NPArray:
    '''Best-effort conversion of `a` to a NumPy array.

    Returns `a` itself when it is already an array (per `isarr`). Otherwise
    returns `np.array(a)`. If either step raises an ordinary exception, `a`
    is returned unchanged. Examples are ragged nested sequences, which
    recent NumPy rejects with `ValueError`, and `np` being `None` because
    NumPy is not installed.

    Only `Exception` subclasses are swallowed. `KeyboardInterrupt` and
    `SystemExit` still propagate, which a bare `except:` would have
    suppressed.
    '''
    try:
        return a if isarr(a) else np.array(a)
    except Exception:
        return a

# %% ../nbs/01_arrs.ipynb 22
def trc2arr(a: Tensor) -> NPArray:
    '''Convert a torch tensor into a NumPy array.

    Mirrors `torch.tensor([...]).detach().clone().cpu().numpy()`, run
    step-by-step through `arrchain`. Values that are already arrays
    (`isarr`) or are not torch tensors (`istrc`) come back unchanged.
    '''
    if not isarr(a) and istrc(a):
        # `clone` presumably guarantees the result does not share storage with
        # the input: for a CPU tensor `.cpu()` is a no-op and `.numpy()` shares
        # memory with the tensor.
        steps = (detach, clone, cpu, numpy)
        return arrchain(a, steps)
    return a

# %% ../nbs/01_arrs.ipynb 24
def mtx2arr(a: NPMatrix) -> NPArray:
    '''Convert an `np.matrix` into a plain NumPy array.

    Mirrors `np.array(np.matrix([...]).tolist())`, run through `arrchain`.
    Values that are already arrays (`isarr`) or are not `np.matrix`
    instances (`isnpmatrix`) come back unchanged.
    '''
    if not isarr(a) and isnpmatrix(a):
        steps = (tolist, nparray)
        return arrchain(a, steps)
    return a

# %% ../nbs/01_arrs.ipynb 26
def spr2arr(a: Union[SPMatrix, SPArray]) -> NPArray:
    '''Densify a scipy.sparse matrix/array into a NumPy array.

    Mirrors `scipy.sparse.csr_matrix([...]).toarray()`, run through
    `arrchain`. Values that are already arrays (`isarr`) or are not sparse
    (`issparse`) come back unchanged.
    '''
    if not isarr(a) and issparse(a):
        steps = (toarray,)
        return arrchain(a, steps)
    return a


# %% ../nbs/01_arrs.ipynb 28
def dns2arr(a: Union[SPMatrix, SPArray]) -> NPArray:
    '''Densify a scipy.sparse matrix/array via `todense` + `tolist`.

    Mirrors `np.array(scipy.sparse.csr_matrix([...]).todense().tolist())`,
    run through `arrchain`. This is an alternative route to `spr2arr`.
    Values that are already arrays (`isarr`) or are not sparse (`issparse`)
    come back unchanged.
    '''
    if not isarr(a) and issparse(a):
        steps = (todense, tolist, nparray)
        return arrchain(a, steps)
    return a

# %% ../nbs/01_arrs.ipynb 30
def ann2arr(a: AnnData, layer: StrQ = None) -> NPArray:
    '''Pull a dense NumPy array out of an AnnData object.

    Uses `a.layers[layer]` when that key exists and otherwise falls back to
    `a.X`. With the default `layer=None` this is presumably always `a.X`,
    and an unknown layer name also silently yields `a.X`. The chosen matrix
    is then passed through `toarray` and `nparray` via `arrchain`.
    Values that are already arrays (`isarr`) or are not AnnData (`isad`)
    come back unchanged.
    '''
    if not isarr(a) and isad(a):
        matrix = a.layers.get(layer, a.X)
        return arrchain(matrix, (toarray, nparray))
    return a

# %% ../nbs/01_arrs.ipynb 32
def mth2arr(a, *args, **kwargs) -> NPArray:
    '''Convert `a` to a NumPy array by trying each known converter in turn.

    Existing arrays (`isarr`) are returned as-is. Anything else is handed to
    `arrchain` with the converter sequence below, and `*args`/`**kwargs` are
    forwarded to `arrchain`. Presumably this lets e.g. `layer=` reach
    `ann2arr`; confirm how `applyfns` distributes them.
    NOTE(review): presumably `applyfns` stops at the first converter whose
    result passes `isarr` — confirm in `calr`.
    '''
    if isarr(a):
        return a
    converters = (
        spr2arr,  # `scipy.sparse.csr_matrix([...]).toarray()`
        trc2arr,  # `torch.tensor([...]).detach().clone().cpu().numpy()`
        dns2arr,  # `np.array(scipy.sparse.csr_matrix([...]).todense().tolist())`
        mtx2arr,  # `np.array(np.matrix([...]).tolist())`
        ann2arr,  # `AnnData([...]).layers.get(layer, a.X).toarray()`
        try2arr,  # `np.array`, or the input unchanged if that fails
    )
    return arrchain(a, converters, *args, **kwargs)

# %% ../nbs/01_arrs.ipynb 34
def asarr(a, *args, **kwargs) -> NPArray:
    '''Public entry point: convert `a` to a NumPy array.

    Thin alias of `mth2arr`; all arguments are forwarded unchanged.
    '''
    result = mth2arr(a, *args, **kwargs)
    return result

# %% ../nbs/01_arrs.ipynb 39
def as01(a: Any, cutoff: float = 0.5) -> NPArray:
    '''Binarize `a` against `cutoff`.

    `a` may be anything `asarr` accepts, not only a `list`; the old `list`
    annotation made type checkers reject valid array/sparse/tensor inputs.

    Entries strictly greater than `cutoff` map to 1. All others map to 0,
    including entries equal to `cutoff` and NaN, since `NaN > cutoff` is
    False. The result is `.astype(int)`.

    NOTE: if `asarr` cannot produce an array (it can return its input
    unchanged), the `>` comparison may raise `TypeError`.
    '''
    return (asarr(a) > cutoff).astype(int)
