#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
CVE Vulnerability Data Models

This module defines data structures for representing CVE vulnerabilities,
affected libraries, and version ranges.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any
from enum import Enum
from datetime import datetime


class CVESeverity(Enum):
    """Qualitative severity buckets derived from a CVSS score.

    Each member's value is its own upper-case name, so the value can be
    written directly into serialized output.
    """

    # CVSS 9.0 - 10.0
    CRITICAL = "CRITICAL"
    # CVSS 7.0 - 8.9
    HIGH = "HIGH"
    # CVSS 4.0 - 6.9
    MEDIUM = "MEDIUM"
    # CVSS 0.1 - 3.9
    LOW = "LOW"
    # No score available
    UNKNOWN = "UNKNOWN"


@dataclass
class VersionRange:
    """Represents a version range affected by a vulnerability.

    All bounds are optional version strings; a range with no bounds set
    describes every version.
    """
    introduced: Optional[str] = None  # Version where vulnerability was introduced (inclusive)
    fixed: Optional[str] = None       # Version where vulnerability was fixed (exclusive)
    last_affected: Optional[str] = None  # Last version known to be affected (inclusive)
    limit: Optional[str] = None       # Upper limit (exclusive)

    def __str__(self) -> str:
        """Return a human-readable description of the range.

        The upper bound shown is the first one that is set, in the order
        ``fixed``, ``last_affected``, ``limit``. ``limit`` is rendered as an
        exclusive bound so that a range carrying only a limit is no longer
        reported as "All versions".
        """
        if self.fixed:
            upper = f"< {self.fixed}"
        elif self.last_affected:
            upper = f"<= {self.last_affected}"
        elif self.limit:
            upper = f"< {self.limit}"
        else:
            upper = None

        if self.introduced and upper:
            return f"{self.introduced} <= version {upper}"
        if self.introduced:
            return f"version >= {self.introduced}"
        if upper:
            return f"version {upper}"
        return "All versions"


@dataclass
class AffectedLibrary:
    """A package named by a CVE, together with its vulnerable version ranges."""
    name: str
    ecosystem: Optional[str] = None  # package ecosystem, e.g. "Maven", "npm", "PyPI"
    purl: Optional[str] = None       # Package URL
    version_ranges: List[VersionRange] = field(default_factory=list)

    def is_version_affected(self, version: str) -> bool:
        """Return True when ``version`` lies inside any recorded range.

        A library with no recorded ranges is treated as affected in every
        version.
        """
        return not self.version_ranges or any(
            self._version_in_range(version, candidate)
            for candidate in self.version_ranges
        )

    def _version_in_range(self, version: str, range_info: VersionRange) -> bool:
        """Return True unless ``version`` violates one of the range's bounds.

        Bounds are compared with ``packaging.version``. If that package cannot
        be imported, or parsing/comparing any version fails, the range is
        reported as matching (the version is assumed affected).
        """
        try:
            from packaging import version as pkg_version

            parse = pkg_version.parse
            target = parse(version)

            # Inclusive lower bound.
            lower = range_info.introduced
            if lower and target < parse(lower):
                return False

            # Exclusive upper bound: first fixed release.
            fixed_in = range_info.fixed
            if fixed_in and target >= parse(fixed_in):
                return False

            # Inclusive upper bound: last release known to be affected.
            last_bad = range_info.last_affected
            if last_bad and target > parse(last_bad):
                return False

            # Exclusive upper bound: explicit limit.
            ceiling = range_info.limit
            if ceiling and target >= parse(ceiling):
                return False

            return True
        except Exception:
            # No further comparison is attempted; err on the side of
            # reporting the version as affected.
            return True


@dataclass
class CVEVulnerability:
    """Represents a CVE vulnerability and the libraries it affects."""
    cve_id: str
    summary: str
    description: Optional[str] = None
    severity: CVESeverity = CVESeverity.UNKNOWN
    cvss_score: Optional[float] = None
    cvss_vector: Optional[str] = None
    published_date: Optional[datetime] = None
    modified_date: Optional[datetime] = None
    affected_libraries: List[AffectedLibrary] = field(default_factory=list)
    references: List[str] = field(default_factory=list)
    source: Optional[str] = None  # Which CVE database this came from
    raw_data: Dict[str, Any] = field(default_factory=dict)  # Original API response

    @classmethod
    def from_cvss_score(cls, score: float) -> CVESeverity:
        """Convert a CVSS score to a severity level.

        Args:
            score: Numeric CVSS score.

        Returns:
            CRITICAL for scores >= 9.0, HIGH for >= 7.0, MEDIUM for >= 4.0,
            LOW for any other positive score, and UNKNOWN for scores of
            0.0 or below.
        """
        if score >= 9.0:
            return CVESeverity.CRITICAL
        elif score >= 7.0:
            return CVESeverity.HIGH
        elif score >= 4.0:
            return CVESeverity.MEDIUM
        elif score > 0.0:
            return CVESeverity.LOW
        else:
            return CVESeverity.UNKNOWN

    def affects_library(self, library_name: str, version: str) -> bool:
        """Check if this CVE affects a specific library and version.

        Library names are compared after normalization (see
        ``_normalize_library_name``). Every entry whose name matches is
        consulted, so a package listed more than once (for example with
        separate version ranges per entry) is reported as affected if any
        of its entries covers ``version``.

        Args:
            library_name: Name of the library to look up.
            version: Version string to test.

        Returns:
            True if at least one matching entry reports the version as
            affected; False if no entry matches or none covers the version.
        """
        wanted = self._normalize_library_name(library_name)
        return any(
            affected_lib.is_version_affected(version)
            for affected_lib in self.affected_libraries
            if self._normalize_library_name(affected_lib.name) == wanted
        )

    def _normalize_library_name(self, name: str) -> str:
        """Normalize a library name for comparison.

        Lower-cases the name and maps both ``-`` and ``.`` to ``_``.
        """
        return name.lower().replace("-", "_").replace(".", "_")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary for JSON serialization.

        Dates are emitted as ISO-8601 strings (or None), the severity as its
        enum value, and each affected library as a nested dict including its
        version ranges. ``raw_data`` is not included in the output.
        """
        return {
            'cve_id': self.cve_id,
            'summary': self.summary,
            'description': self.description,
            'severity': self.severity.value,
            'cvss_score': self.cvss_score,
            'cvss_vector': self.cvss_vector,
            'published_date': self.published_date.isoformat() if self.published_date else None,
            'modified_date': self.modified_date.isoformat() if self.modified_date else None,
            'affected_libraries': [
                {
                    'name': lib.name,
                    'ecosystem': lib.ecosystem,
                    # Serialized alongside the other library fields so the
                    # package URL is not lost on export.
                    'purl': lib.purl,
                    'version_ranges': [
                        {
                            'introduced': vr.introduced,
                            'fixed': vr.fixed,
                            'last_affected': vr.last_affected,
                            'limit': vr.limit,
                        }
                        for vr in lib.version_ranges
                    ],
                }
                for lib in self.affected_libraries
            ],
            'references': self.references,
            'source': self.source,
        }

    def get_affected_version_summary(self) -> str:
        """Get a human-readable summary of affected versions.

        Returns:
            One ``"<name>: <ranges>"`` segment per affected library, joined
            with ``"; "``. Libraries without ranges are shown as
            "All versions". If no libraries are recorded, returns
            "Unknown affected versions".
        """
        if not self.affected_libraries:
            return "Unknown affected versions"

        summaries = []
        for lib in self.affected_libraries:
            if lib.version_ranges:
                range_strs = [str(vr) for vr in lib.version_ranges]
                summaries.append(f"{lib.name}: {', '.join(range_strs)}")
            else:
                summaries.append(f"{lib.name}: All versions")

        return "; ".join(summaries)