use super::energy_model_ops::get_grade;
use super::energy_model_service::EnergyModelService;
use super::vehicle::vehicle_type::VehicleType;
use routee_compass_core::model::property::edge::Edge;
use routee_compass_core::model::property::vertex::Vertex;
use routee_compass_core::model::state::state_feature::StateFeature;
use routee_compass_core::model::state::state_model::StateModel;
use routee_compass_core::model::traversal::state::state_variable::StateVar;
use routee_compass_core::model::traversal::traversal_model::TraversalModel;
use routee_compass_core::model::traversal::traversal_model_error::TraversalModelError;
use routee_compass_core::model::unit::*;
use routee_compass_core::util::geo::haversine;
use std::sync::Arc;

/// Traversal model that runs a time-based traversal model and then applies a
/// vehicle's energy model to the same search state.
pub struct EnergyTraversalModel {
    /// Shared service supplying the distance/speed/grade units, the grade
    /// table, the vehicle library and the time model service used by this model.
    pub energy_model_service: Arc<EnergyModelService>,
    /// Wrapped traversal model, run on each edge before the vehicle's energy
    /// update; the change it makes to the "time" state feature is used to
    /// derive the speed handed to the vehicle.
    pub time_model: Arc<dyn TraversalModel>,
    /// Vehicle whose `consume_energy` and `best_case_energy_state` are called
    /// to update the state.
    pub vehicle: Arc<dyn VehicleType>,
}

impl TraversalModel for EnergyTraversalModel {
    /// Lists the state features this model needs: the features declared by
    /// the vehicle, followed by the features declared by the wrapped time model.
    fn state_features(&self) -> Vec<(String, StateFeature)> {
        let mut features = self.vehicle.state_features();
        features.extend(self.time_model.state_features());
        features
    }

    /// Traverses `edge` with the time model, then lets the vehicle consume
    /// energy for the same edge.
    ///
    /// The speed passed to the vehicle is the edge distance divided by the
    /// change in the "time" state feature produced by the time model, both
    /// expressed in the units associated with the service's
    /// `time_model_speed_unit`.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the "time" feature, from the time
    /// model's traversal, from the grade lookup, or from the vehicle's
    /// `consume_energy`.
    fn traverse_edge(
        &self,
        trajectory: (&Vertex, &Edge, &Vertex),
        state: &mut Vec<StateVar>,
        state_model: &StateModel,
    ) -> Result<(), TraversalModelError> {
        let (_, edge, _) = trajectory;
        let service = &self.energy_model_service;
        let time_unit = service.time_model_speed_unit.associated_time_unit();

        // Capture the time *before* the time model mutates the state. Only this
        // one value is needed afterwards, so there is no reason to clone the
        // whole state vector on every edge.
        let prev_time = state_model.get_time(state, "time", &time_unit)?;

        // perform time traversal
        self.time_model
            .traverse_edge(trajectory, state, state_model)?;
        let current_time = state_model.get_time(state, "time", &time_unit)?;
        let time_delta = current_time - prev_time;

        // perform vehicle energy traversal
        let grade = get_grade(&service.grade_table, edge.edge_id)?;

        // The same edge distance is needed in two units: the service's distance
        // unit (handed to the vehicle) and the distance unit that pairs with
        // the time model's speed unit (used to derive the speed).
        let distance = BASE_DISTANCE_UNIT.convert(&edge.distance, &service.distance_unit);
        let distance_in_time_model_unit = BASE_DISTANCE_UNIT.convert(
            &edge.distance,
            &service.time_model_speed_unit.associated_distance_unit(),
        );

        // NOTE(review): if the time model leaves "time" unchanged, `time_delta`
        // is zero here; confirm how `Speed::from` behaves for a zero time.
        let speed = Speed::from((distance_in_time_model_unit, time_delta));
        self.vehicle.consume_energy(
            (speed, service.time_model_speed_unit),
            (grade, service.grade_table_grade_unit),
            (distance, service.distance_unit),
            state,
            state_model,
        )?;

        Ok(())
    }

    /// Estimates the traversal between two vertices without following edges.
    ///
    /// The haversine distance between the two coordinates is computed in the
    /// service's distance unit. When that distance is zero the state is left
    /// untouched; otherwise the time model's estimate is applied and then the
    /// vehicle's `best_case_energy_state` for that distance.
    ///
    /// # Errors
    ///
    /// A failed distance computation is returned as
    /// `TraversalModelError::NumericError`; errors from the time model or the
    /// vehicle are propagated unchanged.
    fn estimate_traversal(
        &self,
        od: (&Vertex, &Vertex),
        state: &mut Vec<StateVar>,
        state_model: &StateModel,
    ) -> Result<(), TraversalModelError> {
        let (src, dst) = od;
        let distance_unit = self.energy_model_service.distance_unit;
        let distance = haversine::coord_distance(&src.coordinate, &dst.coordinate, distance_unit)
            .map_err(TraversalModelError::NumericError)?;

        // Same origin and destination: nothing to estimate.
        if distance == Distance::ZERO {
            return Ok(());
        }

        self.time_model.estimate_traversal(od, state, state_model)?;
        self.vehicle
            .best_case_energy_state((distance, distance_unit), state, state_model)?;

        Ok(())
    }
}

impl EnergyTraversalModel {
    /// Builds an `EnergyTraversalModel` for a query.
    ///
    /// The time model is built from the service's `time_model_service` using
    /// `conf`. The vehicle is looked up in the service's `vehicle_library` by
    /// the string stored under the `"model_name"` key of `conf`, and then
    /// given the chance to specialize itself via `update_from_query(conf)`.
    ///
    /// # Errors
    ///
    /// Returns `TraversalModelError::BuildError` when `"model_name"` is
    /// missing, is not a string, or does not name a vehicle in the library
    /// (the message lists the available names). Errors from building the time
    /// model or from `update_from_query` are propagated.
    pub fn new(
        energy_model_service: Arc<EnergyModelService>,
        conf: &serde_json::Value,
    ) -> Result<EnergyTraversalModel, TraversalModelError> {
        let time_model = energy_model_service.time_model_service.build(conf)?;

        // Borrow the name straight out of the query; no owned copy is needed
        // for the lookup or for the error message.
        let prediction_model_name = conf
            .get("model_name")
            .ok_or_else(|| {
                TraversalModelError::BuildError("No 'model_name' key provided in query".to_string())
            })?
            .as_str()
            .ok_or_else(|| {
                TraversalModelError::BuildError(
                    "Expected 'model_name' value to be string".to_string(),
                )
            })?;

        let vehicle = energy_model_service
            .vehicle_library
            .get(prediction_model_name)
            .cloned()
            .ok_or_else(|| {
                // Only built on the failure path: list what the caller could have asked for.
                let model_names: Vec<&String> =
                    energy_model_service.vehicle_library.keys().collect();
                TraversalModelError::BuildError(format!(
                    "No vehicle found with model_name = '{}', try one of: {:?}",
                    prediction_model_name, model_names
                ))
            })?
            .update_from_query(conf)?;

        Ok(EnergyTraversalModel {
            energy_model_service,
            time_model,
            vehicle,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::routee::{
        prediction::load_prediction_model, prediction::model_type::ModelType,
        vehicle::default::ice::ICE,
    };
    use geo::coord;
    use routee_compass_core::{
        model::{
            property::{edge::Edge, vertex::Vertex},
            road_network::{edge_id::EdgeId, vertex_id::VertexId},
            traversal::default::{
                speed_traversal_engine::SpeedTraversalEngine,
                speed_traversal_service::SpeedLookupService,
            },
        },
        util::geo::coord::InternalCoord,
    };
    use std::{collections::HashMap, path::PathBuf};

    /// Resolves a fixture file stored under `src/routee/test` of this crate.
    fn fixture(file_name: &str) -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("src")
            .join("routee")
            .join("test")
            .join(file_name)
    }

    /// Builds an edge from vertex 0 to vertex 1 with a distance of 100.
    fn mock_edge(edge_id: usize) -> Edge {
        Edge {
            edge_id: EdgeId(edge_id),
            src_vertex_id: VertexId(0),
            dst_vertex_id: VertexId(1),
            distance: Distance::new(100.0),
        }
    }

    /// Smoke test: builds the model from the on-disk fixtures and traverses a
    /// single edge. The resulting state is only printed, not asserted.
    #[test]
    fn test_edge_cost_lookup_from_file() {
        let speed_file_path = fixture("velocities.txt");
        let grade_file_path = fixture("grades.txt");
        let model_file_path = fixture("Toyota_Camry.bin");

        // the same vertex serves as both ends of the trajectory
        let vertex = Vertex {
            vertex_id: VertexId(0),
            coordinate: InternalCoord(coord! {x: -86.67, y: 36.12}),
        };

        let model_record = load_prediction_model(
            "Toyota_Camry".to_string(),
            &model_file_path,
            ModelType::Smartcore,
            SpeedUnit::MilesPerHour,
            GradeUnit::Decimal,
            EnergyRateUnit::GallonsGasolinePerMile,
            None,
            None,
            None,
        )
        .unwrap();

        // base state model carrying distance and time before the energy
        // model contributes its own features
        let distance_feature = (
            String::from("distance"),
            StateFeature::Distance {
                distance_unit: DistanceUnit::Kilometers,
                initial: Distance::ZERO,
            },
        );
        let time_feature = (
            String::from("time"),
            StateFeature::Time {
                time_unit: TimeUnit::Minutes,
                initial: Time::ZERO,
            },
        );
        let state_model = Arc::new(
            StateModel::empty()
                .extend(vec![distance_feature, time_feature])
                .unwrap(),
        );

        let camry = ICE::new("Toyota_Camry".to_string(), model_record).unwrap();
        let model_library: HashMap<String, Arc<dyn VehicleType>> = HashMap::from([(
            "Toyota_Camry".to_string(),
            Arc::new(camry) as Arc<dyn VehicleType>,
        )]);

        let time_engine =
            SpeedTraversalEngine::new(&speed_file_path, SpeedUnit::KilometersPerHour, None, None)
                .unwrap();
        let time_service = SpeedLookupService {
            e: Arc::new(time_engine),
        };

        let service = EnergyModelService::new(
            Arc::new(time_service),
            SpeedUnit::MilesPerHour,
            &Some(grade_file_path),
            GradeUnit::Millis,
            None,
            None,
            model_library,
        )
        .unwrap();

        let conf = serde_json::json!({
            "model_name": "Toyota_Camry",
        });
        let model = EnergyTraversalModel::new(Arc::new(service), &conf).unwrap();

        let updated_state_model = state_model.extend(model.state_features()).unwrap();
        println!("{:?}", updated_state_model.to_vec());

        let mut state = updated_state_model.initial_state().unwrap();
        let first_edge = mock_edge(0);
        // 100 meters @ 10kph should take 36 seconds ((0.1/10) * 3600)
        model
            .traverse_edge(
                (&vertex, &first_edge, &vertex),
                &mut state,
                &updated_state_model,
            )
            .unwrap();
        println!("{:?}", state);
    }
}
