import os

import deepwave
import matplotlib.pyplot as plt
import torch
from deepwave import scalar
from scipy.ndimage import gaussian_filter
from scipy.signal import butter
from torchaudio.functional import biquad


def main():
    """Run two full-waveform inversions of a velocity model and save results.

    Reads ``vp.pt`` (true velocity model) and ``obs_data.pt`` (observed
    receiver data) from ``out/base`` next to this script, crops both to a
    smaller sub-problem, and then:

    1. runs a plain SGD inversion and saves ``example_simple_fwi.jpg``;
    2. runs an LBFGS inversion with a sigmoid-bounded velocity model and a
       low-pass filter whose cutoff is raised in stages, saving
       ``example_increasing_freq_fwi.jpg`` plus ``loss_record.pt``,
       ``v_record.pt``, ``out_record.pt`` and ``out_filt_record.pt``.

    All outputs are written to ``out/serial`` next to this script; the
    directory is created if it does not exist.
    """

    def get_file(name, path='out/serial'):
        """Return the path of ``name`` under ``path``, relative to this file."""
        return os.path.join(os.path.dirname(__file__), path, name)

    def ensure_parent_dir(file_path):
        """Create the directory that will contain ``file_path`` if missing."""
        directory = os.path.dirname(file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def load(name, path='out/serial'):
        """Load a tensor previously stored with ``torch.save``."""
        return torch.load(get_file(name, path))

    def save(tensor, name, path='out/serial'):
        """Store ``tensor`` with ``torch.save``, creating the folder first."""
        target = get_file(name, path)
        # Without this a missing output folder would only be discovered
        # after the whole inversion has already been computed.
        ensure_parent_dir(target)
        torch.save(tensor, target)

    def savefig(name, path='out/serial'):
        """Save the current pyplot figure, creating the folder first."""
        target = get_file(name, path)
        ensure_parent_dir(target)
        plt.savefig(target)

    def plot_models(v_initial, v_inverted, v_reference, name):
        """Save a three-panel comparison: initial, inverted and true model.

        All panels share the colour range of ``v_reference`` so that they
        are directly comparable.
        """
        vmin = v_reference.min()
        vmax = v_reference.max()
        fig, ax = plt.subplots(
            3, figsize=(10.5, 10.5), sharex=True, sharey=True
        )
        panels = (
            ("Initial", v_initial),
            ("Out", v_inverted),
            ("True", v_reference),
        )
        for axis, (title, velocity) in zip(ax, panels):
            axis.imshow(
                velocity.detach().cpu().T,
                aspect='auto',
                cmap='gray',
                vmin=vmin,
                vmax=vmax,
            )
            axis.set_title(title)
        plt.tight_layout()
        savefig(name)
        # Release the figure; otherwise it stays registered with pyplot.
        plt.close(fig)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dx = 4.0
    # The stored model is the full one (originally noted as ny = 2301,
    # nx = 751); only a portion of it is inverted here.
    v_true = load('vp.pt', path='out/base')

    # Select portion of model for inversion
    ny = 600
    nx = 250
    v_true = v_true[:ny, :nx]

    # Smooth (in slowness, 1 / v) to use as starting model
    v_init = torch.tensor(1 / gaussian_filter(1 / v_true.numpy(), 40)).to(
        device
    )
    v = v_init.clone()
    v.requires_grad_()

    # Acquisition geometry. The stored data set was originally noted as
    # 115 shots x 384 receivers x 750 time samples; only a subset is used.
    n_sources_per_shot = 1
    d_source = 20  # 20 * 4m = 80m
    first_source = 10  # 10 * 4m = 40m
    source_depth = 2  # 2 * 4m = 8m

    d_receiver = 6  # 6 * 4m = 24m
    first_receiver = 0  # 0 * 4m = 0m
    receiver_depth = 2  # 2 * 4m = 8m

    freq = 25
    dt = 0.004
    peak_time = 1.5 / freq

    observed_data = load('obs_data.pt', path='out/base')

    # Select portion of data for inversion
    n_shots = 20
    n_receivers_per_shot = 100
    nt = 300
    observed_data = observed_data[:n_shots, :n_receivers_per_shot, :nt].to(
        device
    )

    # source_locations
    source_locations = torch.zeros(
        n_shots, n_sources_per_shot, 2, dtype=torch.long, device=device
    )
    source_locations[..., 1] = source_depth
    source_locations[:, 0, 0] = torch.arange(n_shots) * d_source + first_source

    # receiver_locations (the same receiver line is used for every shot)
    receiver_locations = torch.zeros(
        n_shots, n_receivers_per_shot, 2, dtype=torch.long, device=device
    )
    receiver_locations[..., 1] = receiver_depth
    receiver_locations[:, :, 0] = (
        torch.arange(n_receivers_per_shot) * d_receiver + first_receiver
    ).repeat(n_shots, 1)

    # source_amplitudes: one Ricker wavelet, repeated for every source
    source_amplitudes = (
        (deepwave.wavelets.ricker(freq, nt, dt, peak_time))
        .repeat(n_shots, n_sources_per_shot, 1)
        .to(device)
    )

    ## First attempt: simple inversion

    # Setup optimiser to perform inversion
    optimiser = torch.optim.SGD([v], lr=1e9, momentum=0.9)
    loss_fn = torch.nn.MSELoss()

    # Run optimisation/inversion
    n_epochs = 250

    for _ in range(n_epochs):
        optimiser.zero_grad()
        out = scalar(
            v,
            dx,
            dt,
            source_amplitudes=source_amplitudes,
            source_locations=source_locations,
            receiver_locations=receiver_locations,
            pml_freq=freq,
        )
        # The last element of the propagator output is compared with the
        # observed receiver data.
        loss = loss_fn(out[-1], observed_data)
        loss.backward()
        # Clip gradient entries at the 98th percentile of their absolute
        # values so that a few very large entries do not dominate the step.
        torch.nn.utils.clip_grad_value_(
            v, torch.quantile(v.grad.detach().abs(), 0.98)
        )
        optimiser.step()

    # Plot
    plot_models(v_init, v, v_true, 'example_simple_fwi.jpg')

    ## Second attempt: constrained velocity and frequency filtering

    # Define a function to taper the ends of traces
    def taper(x):
        return deepwave.common.cosine_taper_end(x, 100)

    # Generate a velocity model constrained to be within a desired range
    class Model(torch.nn.Module):
        """Velocity model kept inside ``(min_vel, max_vel)`` by a sigmoid.

        The trainable parameter is the logit of the initial model rescaled
        to ``(0, 1)``; ``forward`` maps it back to velocity.
        """

        def __init__(self, initial, min_vel, max_vel):
            super().__init__()
            self.min_vel = min_vel
            self.max_vel = max_vel
            self.model = torch.nn.Parameter(
                torch.logit((initial - min_vel) / (max_vel - min_vel))
            )

        def forward(self):
            return (
                torch.sigmoid(self.model) * (self.max_vel - self.min_vel)
                + self.min_vel
            )

    # NOTE: the first attempt above used the untapered data on purpose of
    # ordering; the taper is applied only from here on.
    observed_data = taper(observed_data)
    model = Model(v_init, 1000, 2500).to(device)

    # Run optimisation/inversion
    n_epochs = 2
    loss_record = []
    v_record = []
    out_record = []
    out_filt_record = []

    # Raise the low-pass cutoff in stages, restarting LBFGS at each stage.
    for cutoff_freq in [10, 15, 20, 25, 30]:
        # 6th-order Butterworth filter as second-order sections; each
        # section is applied below as one biquad.
        sos = butter(6, cutoff_freq, fs=1 / dt, output='sos')
        sos = [
            torch.tensor(sosi).to(observed_data.dtype).to(device)
            for sosi in sos
        ]

        def filt(x):
            return biquad(biquad(biquad(x, *sos[0]), *sos[1]), *sos[2])

        observed_data_filt = filt(observed_data)
        optimiser = torch.optim.LBFGS(
            model.parameters(), line_search_fn='strong_wolfe'
        )
        for _ in range(n_epochs):
            num_calls = 0

            def closure():
                """Evaluate loss and gradients for one LBFGS evaluation."""
                nonlocal num_calls
                num_calls += 1
                optimiser.zero_grad()
                v = model()
                out = scalar(
                    v,
                    dx,
                    dt,
                    source_amplitudes=source_amplitudes,
                    source_locations=source_locations,
                    receiver_locations=receiver_locations,
                    max_vel=2500,
                    pml_freq=freq,
                    time_pad_frac=0.2,
                )
                out_filt = filt(taper(out[-1]))
                # NOTE(review): the 1e6 factor presumably rescales the
                # loss for LBFGS -- confirm.
                loss = 1e6 * loss_fn(out_filt, observed_data_filt)
                loss.backward()
                # LBFGS may call the closure several times per step; only
                # the first evaluation of each epoch is recorded.
                if num_calls == 1:
                    loss_record.append(loss.item())
                    v_record.append(v.detach().cpu())
                    out_record.append(out[-1].detach().cpu())
                    out_filt_record.append(out_filt.detach().cpu())
                return loss

            optimiser.step(closure)

    v = model()
    plot_models(v_init, v, v_true, 'example_increasing_freq_fwi.jpg')

    save(torch.tensor(loss_record), 'loss_record.pt')
    save(torch.stack(v_record), 'v_record.pt')
    save(torch.stack(out_record), 'out_record.pt')
    save(torch.stack(out_filt_record), 'out_filt_record.pt')


# Run the inversion only when this file is executed as a script, so that
# importing the module does not start it.
if __name__ == '__main__':
    main()
