#
# Copyright (c) 2023 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause#
import plotly
import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots
from dash import dash_table, dcc
from ..settings import *


def _table_records(df):
    """Convert a DataFrame into the ``columns`` and ``data`` arguments of a Dash DataTable.

    The DataFrame index is exposed as an extra leading column whose name and id are ``"Index"``.

    :param df: The DataFrame to convert.
    :return: A ``(columns, data)`` tuple. ``columns`` is a list of ``{"name": ..., "id": ...}``
        dicts (the ``"Index"`` column first, then the DataFrame columns in order). ``data`` is a
        list with one dict per row, mapping each column name to its value plus ``"Index"`` to
        the row's index value.
    """
    columns = [{"name": "Index", "id": "Index"}] + [{"name": c, "id": c} for c in df.columns]
    data = []
    # ``df.values`` may build a brand-new array on every access (e.g. for mixed-dtype frames),
    # so it is evaluated once here rather than re-evaluated for each row inside the loop.
    for row, index_value in zip(df.values, df.index.values):
        record = dict(zip(df.columns, row))
        # NOTE(review): a data column literally named "Index" is overwritten by the index value
        # here, and ``columns`` would then contain two entries with id "Index".
        record["Index"] = index_value
        data.append(record)
    return columns, data


def data_table(df, n=1000, page_size=10):
    """Build a paginated, read-only Dash DataTable showing the first rows of a DataFrame.

    :param df: The DataFrame to display, or ``None``.
    :param n: Maximum number of rows to include, taken from the top via ``df.head(n)``.
    :param page_size: Number of rows shown per page.
    :return: A ``dash_table.DataTable``. When ``df`` is ``None``, an empty ``DataTable()``.
    """
    if df is None:
        return dash_table.DataTable()

    columns, data = _table_records(df.head(n))
    return dash_table.DataTable(
        id="table",
        columns=columns,
        data=data,
        style_cell_conditional=[{"textAlign": "center"}],
        style_table={"overflowX": "scroll"},
        editable=False,
        column_selectable="single",
        page_action="native",
        page_size=page_size,
        page_current=0,
        style_header=dict(backgroundColor=TABLE_HEADER_COLOR),
        style_data=dict(backgroundColor=TABLE_DATA_COLOR),
    )


def _iter_columns(ts):
    """Yield ``(name, index, values)`` for every column of every time series in ``ts``.

    :param ts: A DataFrame or Series, or an iterable of DataFrames/Series. A Series is treated
        as a one-column DataFrame.
    :return: A generator of ``(column name, row index, 1-D numpy array of values)`` tuples, in
        input order and then column order.
    """
    frames = [ts] if isinstance(ts, (pd.DataFrame, pd.Series)) else ts
    for frame in frames:
        if isinstance(frame, pd.Series):
            frame = frame.to_frame()
        for position, name in enumerate(frame.columns):
            # Select by position, not by label: with duplicate column names, ``frame[[name]]``
            # returns every matching column and the flattened values would not line up with x.
            column = frame.iloc[:, position]
            yield name, column.index, column.to_numpy()


def plot_timeseries(ts, figure_height=750):
    """Render one or more time series as line traces on a single chart with a date range selector.

    Every column becomes its own line trace. Colors cycle through Plotly's ``Dark24`` qualitative
    palette across all columns of all inputs.

    :param ts: A DataFrame (or Series), or a list of them, indexed by time.
    :param figure_height: Height of the figure, passed to ``fig.update_layout(height=...)``.
    :return: A ``dcc.Graph`` wrapping the figure.
    """
    color_list = plotly.colors.qualitative.Dark24
    traces = [
        go.Scatter(
            name=name,
            x=x,
            y=y,
            mode="lines",
            line=dict(color=color_list[k % len(color_list)]),
        )
        for k, (name, x, y) in enumerate(_iter_columns(ts))
    ]

    layout = dict(
        showlegend=True,
        xaxis=dict(
            title="Time",
            type="date",
            rangeselector=dict(
                buttons=[
                    dict(count=7, label="1w", step="day", stepmode="backward"),
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(count=1, label="1y", step="year", stepmode="backward"),
                    dict(step="all"),
                ]
            ),
        ),
    )
    fig = make_subplots(figure=go.Figure(layout=layout))
    fig.update_yaxes(title_text="Timeseries")
    for trace in traces:
        fig.add_trace(trace)
    fig.update_layout(height=figure_height)
    return dcc.Graph(figure=fig)
