"""
Tests the server compute capabilities.
"""

import pytest
import qcfractal.interface as portal
from qcfractal import testing
from qcfractal.testing import fractal_compute_server, reset_server_database, using_psi4, using_rdkit


@pytest.mark.parametrize(
    "data",
    [
        pytest.param(("psi4", "HF", "sto-3g"), id="psi4", marks=using_psi4),
        pytest.param(("rdkit", "UFF", None), id="rdkit", marks=using_rdkit),
    ],
)
def test_task_molecule_no_orientation(data, fractal_compute_server):
    """
    Running an energy computation must leave the submitted molecule untouched.

    The single resulting record has to point at the exact molecule id that was
    submitted, and no second H2 molecule may have been added to the database.
    """
    program, method, basis = data

    # Start every parametrized run from an empty database
    reset_server_database(fractal_compute_server)
    client = portal.FractalClient(fractal_compute_server)

    h2 = portal.Molecule(symbols=["H", "H"], geometry=[0, 0, 0, 0, 5, 0], connectivity=[(0, 1, 1)])
    h2_id = client.add_molecules([h2])[0]

    submission = client.add_compute(program, method, basis, "energy", None, [h2_id])

    # Manually handle the compute
    fractal_compute_server.await_results()

    # Exactly one completed record, referencing the molecule we submitted
    records = client.get_results(id=submission.submitted)
    assert len(records) == 1
    record = records[0]
    assert record.status == "COMPLETE"
    assert record.molecule == h2_id

    # The only H2 in the database is still the one we added
    h2_matches = client.get_molecules(molecular_formula=["H2"])
    assert len(h2_matches) == 1
    assert h2_matches[0].id == h2_id


@testing.using_rdkit
def test_task_error(fractal_compute_server):
    """
    A failing RDKit energy task is stored with status ERROR and an error message.
    """
    # Reset database each run: if an identical computation already existed,
    # add_compute would report it as existing and ``ret.submitted`` would be
    # empty, making the lookup below order-dependent.
    reset_server_database(fractal_compute_server)

    client = portal.FractalClient(fractal_compute_server)

    mol = portal.models.Molecule(geometry=[0, 0, 0], symbols=["He"])
    # Cookiemonster is an invalid method
    ret = client.add_compute("rdkit", "cookiemonster", "", "energy", None, [mol])

    # Manually handle the compute
    fractal_compute_server.await_results()

    # Check for error
    results = client.get_results(id=ret.submitted)
    assert len(results) == 1
    assert results[0].status == "ERROR"

    # NOTE(review): although the method is invalid, the asserted message is about
    # connectivity (this molecule defines none) -- presumably RDKit rejects the
    # molecule before the method is checked; confirm.
    assert "connectivity" in results[0].get_error().error_message


@testing.using_rdkit
def test_queue_error(fractal_compute_server):
    """
    A failed task is marked ERROR in both the queue and its result record, and
    forcing the queue entry complete flips the result to COMPLETE.
    """
    reset_server_database(fractal_compute_server)

    client = portal.FractalClient(fractal_compute_server)

    # Strip the connectivity so the RDKit computation errors out
    hooh = portal.data.get_molecule("hooh.json").json_dict()
    del hooh["connectivity"]

    compute_ret = client.add_compute("rdkit", "UFF", "", "energy", None, hooh)

    # Run a single queue-manager update; exactly one task should now be listed
    fractal_compute_server.update_tasks()
    assert len(fractal_compute_server.list_current_tasks()) == 1

    fractal_compute_server.await_results()
    assert len(fractal_compute_server.list_current_tasks()) == 0

    # Pull from database, raw JSON
    db = fractal_compute_server.objects["storage_socket"]
    queue_ret = db.get_queue(status="ERROR")["data"]
    result = db.get_results(id=compute_ret.ids)["data"][0]

    assert len(queue_ret) == 1
    assert "connectivity graph" in queue_ret[0]["error"]["error_message"]
    assert result["status"] == "ERROR"

    # Force a complete mark and test (reuse the storage socket bound above)
    db.queue_mark_complete([queue_ret[0]["id"]])
    result = db.get_results(id=compute_ret.ids)["data"][0]
    assert result["status"] == "COMPLETE"


@testing.using_rdkit
def test_queue_duplicate_compute(fractal_compute_server):
    """
    Resubmitting the same computation is deduplicated.

    Program and method are matched case-insensitively and a basis of ``None``
    is treated like ``""``, both on submission and when querying results.
    """
    reset_server_database(fractal_compute_server)

    client = portal.FractalClient(fractal_compute_server)

    hooh = portal.data.get_molecule("hooh.json").json_dict()
    mol_ret = client.add_molecules([hooh])

    ret = client.add_compute("rdkit", "UFF", "", "energy", None, mol_ret)
    assert len(ret.ids) == 1
    assert len(ret.existing) == 0

    # Manually handle the compute
    fractal_compute_server.await_results()

    # Differently-cased program/method and a None basis hit the existing record
    ret = client.add_compute("RDKIT", "uff", None, "energy", None, mol_ret)
    assert len(ret.ids) == 1
    assert len(ret.existing) == 1

    ret = client.add_compute("rdkit", "uFf", "", "energy", None, mol_ret)
    assert len(ret.ids) == 1
    assert len(ret.existing) == 1

    # Queries normalize the same way
    assert len(client.get_results(program="RDKIT")) == 1
    assert len(client.get_results(program="RDKit")) == 1

    assert len(client.get_results(method="UFF")) == 1
    assert len(client.get_results(method="uff")) == 1

    assert len(client.get_results(basis=None)) == 1
    assert len(client.get_results(basis="")) == 1

    assert len(client.get_results(keywords=None)) == 1


@testing.using_rdkit
def test_queue_compute_mixed_molecule(fractal_compute_server):
    """
    add_compute accepts a mix of molecule objects, molecule ids and invalid ids.

    Invalid ids map to ``None`` in the returned ids; resubmitting an already
    computed molecule is reported as existing rather than submitted.
    """
    # Reset database each run so the "nothing existing yet" assertions hold
    # regardless of test order.
    reset_server_database(fractal_compute_server)

    client = portal.FractalClient(fractal_compute_server)

    mol1 = portal.Molecule.from_data("He 0 0 0\nHe 0 0 2.1")
    mol_ret = client.add_molecules([mol1])

    mol2 = portal.Molecule.from_data("He 0 0 0\nHe 0 0 2.2")

    ret = client.add_compute("rdkit", "UFF", "", "energy", None, [mol1, mol2, "bad_id"], full_return=True)
    assert len(ret.data.ids) == 3
    assert ret.data.ids[2] is None
    assert len(ret.data.submitted) == 2
    assert len(ret.data.existing) == 0

    # Manually handle the compute
    fractal_compute_server.await_results()

    ret = client.add_compute("rdkit", "UFF", "", "energy", None, [mol_ret[0], "bad_id2"])
    assert len(ret.ids) == 2
    assert ret.ids[1] is None
    assert len(ret.submitted) == 0
    assert len(ret.existing) == 1


@testing.using_rdkit
@testing.using_geometric
def test_queue_duplicate_procedure(fractal_compute_server):
    """
    Duplicate geometric optimizations are deduplicated.

    Submitting by molecule id and later by the molecule itself resolves to the
    same procedure id; invalid ids map to ``None``.
    """
    # Reset database each run so the first submission is always new
    reset_server_database(fractal_compute_server)

    client = portal.FractalClient(fractal_compute_server)

    hooh = portal.data.get_molecule("hooh.json").json_dict()
    mol_ret = client.add_molecules([hooh])

    geometric_options = {
        "keywords": None,
        "qc_spec": {
            "driver": "gradient",
            "method": "UFF",
            "basis": "",
            "keywords": None,
            "program": "rdkit"
        },
    }

    ret = client.add_procedure("optimization", "geometric", geometric_options, [mol_ret[0], "bad_id"])
    assert len(ret.ids) == 2
    assert ret.ids[1] is None
    assert len(ret.submitted) == 1
    assert len(ret.existing) == 0

    # Manually handle the compute
    fractal_compute_server.await_results()

    # Same molecule passed as data (not id) must match the existing procedure
    ret2 = client.add_procedure("optimization", "geometric", geometric_options, ["bad_id", hooh])
    assert len(ret2.ids) == 2
    assert ret2.ids[0] is None
    assert len(ret2.submitted) == 0
    assert len(ret2.existing) == 1

    assert ret.ids[0] == ret2.ids[1]
