"""
Management command to populate occurrence records for visit events from a CSV file.

Usage: python manage.py populate_occurrences_from_csv --csv-file=/path/to/records.csv
"""
import csv
import functools
import math
import random
from pathlib import Path

from django.core.management.base import BaseCommand
from django.db import transaction
from pyproj import Transformer
from shapely.geometry import Point
from shapely.ops import transform

from eventhub.models import Event, Occurrence, EventAttribute


# EPSG identifiers of the two coordinate reference systems handed to pyproj.
WGS84: str = "EPSG:4326"  # WGS84 geographic latitude/longitude
OSGB: str = "EPSG:27700"  # British National Grid eastings/northings


def parse_os_grid_reference(grid_ref: str) -> tuple[float, float]:
    """
    Parse an OS grid reference (e.g., 'ST469897') to easting/northing.

    The two letters select a 100 km square of the British National Grid.
    Each letter indexes a 5x5 grid lettered A-Z without 'I', numbered row by
    row from the north-west corner. The first letter picks a 500 km square
    (relative to 'S', the square containing the grid's false origin) and the
    second picks a 100 km square inside it. The digits are split evenly into
    an easting and a northing offset within that square. For example,
    'ST469897' means 469 / 897 in 100 m units inside square ST, i.e.
    easting 346900 m and northing 189700 m.

    The returned point is the centre of the cell the reference denotes. For
    example, +50 m is added on each axis for a 6-figure (100 m) reference.
    Letters outside the national grid (e.g. 'AA') are not rejected; they map
    mechanically to points off the grid.

    Args:
        grid_ref: OS grid reference string. Case and all whitespace are
            ignored ('st 469 897' == 'ST469897'). Digits per axis:
            1 = 10 km, 2 = 1 km, 3 = 100 m, 4 = 10 m, 5 = 1 m,
            more = sub-metre.

    Returns:
        Tuple of (easting, northing) in meters.

    Raises:
        ValueError: if the reference is empty or too short, contains a
            letter not used by the grid (including 'I'), or its numeric
            part is missing, odd-length, or not all ASCII digits.
    """
    if not grid_ref or len(grid_ref) < 2:
        raise ValueError(f"Invalid grid reference: {grid_ref}")

    # Remove all whitespace (spaces, tabs, newlines) and convert to uppercase
    grid_ref = ''.join(grid_ref.split()).upper()

    if len(grid_ref) < 2:
        raise ValueError(f"Grid reference too short: {grid_ref}")

    prefix = grid_ref[:2]
    numeric_part = grid_ref[2:]

    if not numeric_part or len(numeric_part) % 2 != 0:
        raise ValueError(f"Grid reference must have even number of digits: {grid_ref}")

    # isascii() guards against Unicode digits such as '²' that isdigit() accepts
    if not (numeric_part.isascii() and numeric_part.isdigit()):
        raise ValueError(f"Grid reference numeric part must be digits: {grid_ref}")

    # Split into easting and northing parts
    precision = len(numeric_part) // 2
    easting_str = numeric_part[:precision]
    northing_str = numeric_part[precision:]

    # 25-letter grid alphabet: A-Z with 'I' omitted
    letters = 'ABCDEFGHJKLMNOPQRSTUVWXYZ'
    first_idx = letters.find(prefix[0])
    second_idx = letters.find(prefix[1])
    if first_idx < 0 or second_idx < 0:
        raise ValueError(f"Invalid letter in grid reference {grid_ref}: {prefix}")

    # Row counts from the north, column from the west, within each 5x5 grid
    row1, col1 = divmod(first_idx, 5)
    row2, col2 = divmod(second_idx, 5)

    # 100 km square indices. The first letter is taken relative to 'S'
    # (row 3, column 2); Python's % keeps the column non-negative for letters
    # west of S. Rows are inverted because northings grow upwards.
    square_e = ((col1 - 2) % 5) * 5 + col2
    square_n = (3 - row1) * 5 + (4 - row2)

    # Each digit per axis refines the 100 km square tenfold:
    # 1 digit -> 10 km cells, ..., 5 digits -> 1 m, >5 -> float sub-metre
    cell_size = 10 ** (5 - precision)

    # Add half a cell so the result is the centre of the referenced cell
    easting = square_e * 100_000 + int(easting_str) * cell_size + cell_size / 2
    northing = square_n * 100_000 + int(northing_str) * cell_size + cell_size / 2

    return (easting, northing)


@functools.cache
def _osgb_to_wgs84_transformer() -> Transformer:
    """
    Return a shared British National Grid -> WGS84 pyproj Transformer.

    Building a Transformer is comparatively expensive, so one instance is
    created on first use and reused for every conversion. ``always_xy=True``
    fixes the axis order to (easting, northing) in and (longitude, latitude) out.
    """
    return Transformer.from_crs(OSGB, WGS84, always_xy=True)


def grid_ref_to_lat_lon(grid_ref: str) -> tuple[float, float]:
    """
    Convert OS grid reference to WGS84 latitude/longitude.

    Args:
        grid_ref: OS grid reference string (e.g., 'ST469897')

    Returns:
        Tuple of (latitude, longitude) in WGS84

    Raises:
        ValueError: propagated from parse_os_grid_reference for malformed references.
    """
    easting, northing = parse_os_grid_reference(grid_ref)

    # always_xy=True => output order is (longitude, latitude)
    lon, lat = _osgb_to_wgs84_transformer().transform(easting, northing)

    return (lat, lon)


def add_variance_to_coords(lat: float, lon: float, max_variance_m: float = 100.0) -> tuple[float, float]:
    """
    Move a coordinate to a random point within ``max_variance_m`` meters of it.

    The new point is spread evenly over the disc of that radius. The distance
    is ``max_variance_m * sqrt(U)`` for ``U`` uniform in [0, 1). Drawing the
    distance itself uniformly would crowd points near the centre, because the
    area at radius r grows with r. The destination is then computed with the
    spherical great-circle "destination point" formula.

    Args:
        lat: Latitude in degrees
        lon: Longitude in degrees
        max_variance_m: Maximum displacement in meters (default 100m).
            0 returns the original point (up to float rounding).

    Returns:
        Tuple of (new_latitude, new_longitude) in degrees, with longitude
        normalised to [-180, 180).

    Raises:
        ValueError: if max_variance_m is negative.
    """
    if max_variance_m < 0:
        raise ValueError(f"max_variance_m must be non-negative, got {max_variance_m}")

    # WGS84 equatorial radius in meters; a spherical model is ample for
    # offsets of this size.
    earth_radius = 6378137.0

    # sqrt() makes the sample uniform over the disc's area, not its radius
    distance = max_variance_m * math.sqrt(random.random())
    # Random bearing in radians
    bearing = random.uniform(0, 2 * math.pi)

    lat_rad = math.radians(lat)
    lon_rad = math.radians(lon)
    angular_distance = distance / earth_radius

    # Great-circle destination given start point, bearing and angular distance
    new_lat_rad = math.asin(
        math.sin(lat_rad) * math.cos(angular_distance) +
        math.cos(lat_rad) * math.sin(angular_distance) * math.cos(bearing)
    )

    new_lon_rad = lon_rad + math.atan2(
        math.sin(bearing) * math.sin(angular_distance) * math.cos(lat_rad),
        math.cos(angular_distance) - math.sin(lat_rad) * math.sin(new_lat_rad)
    )

    # Wrap longitude so points near the antimeridian stay in [-180, 180)
    new_lon = (math.degrees(new_lon_rad) + 540.0) % 360.0 - 180.0

    return (math.degrees(new_lat_rad), new_lon)


class Command(BaseCommand):
    help = 'Populate occurrence records for visit events from a CSV file'

    # Record key -> CSV column header it is read from
    CSV_COLUMNS = {
        'occurrence_id': 'Occurrence ID',
        'scientific_name': 'Scientific name',
        'common_name': 'Common name',
        'taxon_id': 'Species ID (TVK)',
        'individual_count': 'Individual count',
    }

    def add_arguments(self, parser):
        """Register ``--csv-file`` (required) and ``--delete-existing`` (flag)."""
        parser.add_argument(
            '--csv-file',
            type=str,
            required=True,
            help='Path to the CSV file containing occurrence records',
        )
        parser.add_argument(
            '--delete-existing',
            action='store_true',
            help='Delete existing occurrence records before populating',
        )

    def handle(self, *args, **options):
        """
        Attach a random sample of CSV occurrence records to every visit event.

        The CSV is read and validated before the database is touched, so a
        missing or empty file never causes existing records to be deleted.
        Deletion (if requested) and all creation then run in one transaction.
        A failure part-way rolls back everything, including the deletion.
        """
        csv_file = Path(options['csv_file'])

        # is_file() rather than exists(): a directory would pass exists()
        # and then fail inside open().
        if not csv_file.is_file():
            self.stdout.write(
                self.style.ERROR(f'CSV file not found: {csv_file}')
            )
            return

        self.stdout.write(f'Reading CSV file: {csv_file}')
        csv_records = self._read_csv_records(csv_file)

        if not csv_records:
            self.stdout.write(
                self.style.ERROR('No valid records found in CSV file')
            )
            return

        self.stdout.write(f'Found {len(csv_records)} valid records in CSV')

        total_created = 0

        with transaction.atomic():
            if options['delete_existing']:
                count = Occurrence.objects.count()
                Occurrence.objects.all().delete()
                self.stdout.write(
                    self.style.WARNING(f'Deleted {count} existing occurrence records')
                )

            # Visit -> Transect -> Site hierarchy, fetched in one query
            visit_events = Event.objects.filter(event_type='visit').select_related('parent', 'parent__parent')

            if not visit_events.exists():
                self._warn('No visit events found')
                return

            self.stdout.write(f'Found {visit_events.count()} visit events')

            # IDs already taken, seeded from the database and shared across
            # all visits. A CSV record sampled for several visits therefore
            # still gets a distinct occurrence_id each time.
            used_occurrence_ids = set(
                Occurrence.objects.filter(occurrence_id__isnull=False)
                .values_list('occurrence_id', flat=True)
            )
            # Many visits share a site: look up each site's grid reference
            # and each grid reference's coordinates only once.
            grid_refs_by_site = {}
            centers_by_grid_ref = {}

            for visit_event in visit_events:
                site_event = self._get_site(visit_event)
                if site_event is None:
                    continue

                if site_event.pk not in grid_refs_by_site:
                    grid_ref_attr = EventAttribute.objects.filter(
                        event=site_event,
                        attribute_type='gridReference'
                    ).first()
                    grid_refs_by_site[site_event.pk] = (
                        grid_ref_attr.attribute_value if grid_ref_attr else None
                    )
                grid_ref = grid_refs_by_site[site_event.pk]

                if not grid_ref:
                    self._warn(
                        f'No grid reference found for site {site_event.name} (visit {visit_event.name}), skipping...'
                    )
                    continue

                visit_date = visit_event.start_date
                if not visit_date:
                    self._warn(f'No date found for visit {visit_event.name}, skipping...')
                    continue

                center = centers_by_grid_ref.get(grid_ref)
                if center is None:
                    try:
                        center = grid_ref_to_lat_lon(grid_ref)
                    except Exception as e:
                        # Failure may come from the grid-reference parser
                        # (ValueError) or from pyproj; either way skip the visit.
                        self._warn(
                            f'Could not convert grid reference {grid_ref} to coordinates: {e}, skipping...'
                        )
                        continue
                    centers_by_grid_ref[grid_ref] = center

                # Randomly select 5-20 distinct records from the CSV
                num_occurrences = random.randint(5, 20)
                selected_records = random.sample(csv_records, min(num_occurrences, len(csv_records)))

                occurrences_to_create = self._build_occurrences(
                    visit_event, selected_records, grid_ref, visit_date, center, used_occurrence_ids,
                )
                Occurrence.objects.bulk_create(occurrences_to_create)
                total_created += len(occurrences_to_create)

                self.stdout.write(
                    f'Created {len(occurrences_to_create)} occurrences for visit {visit_event.name}'
                )

        self.stdout.write(
            self.style.SUCCESS(f'\nSuccessfully created {total_created} occurrence records')
        )

    def _warn(self, message):
        """Write ``message`` to stdout in the WARNING style."""
        self.stdout.write(self.style.WARNING(message))

    def _read_csv_records(self, csv_file):
        """Return CSV rows that have a scientific name, reduced to ``CSV_COLUMNS``."""
        records = []
        # utf-8-sig drops a leading BOM, which would otherwise be glued onto
        # the first header name and hide that column. newline='' is what the
        # csv module requires for correct quoting of embedded newlines.
        with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
            for row in csv.DictReader(f):
                # DictReader fills missing trailing fields with None, hence `or ''`
                record = {
                    key: (row.get(column) or '').strip()
                    for key, column in self.CSV_COLUMNS.items()
                }
                # Only keep rows with at least a scientific name
                if record['scientific_name']:
                    records.append(record)
        return records

    def _get_site(self, visit_event):
        """
        Return the site above ``visit_event`` (visit -> transect -> site), or None.

        Warns and returns None when a level is missing or a parent has an
        unexpected ``event_type``.
        """
        transect_event = visit_event.parent
        if not transect_event:
            self._warn(f'Visit {visit_event.name} has no parent transect, skipping...')
            return None

        if transect_event.event_type != 'transect':
            self._warn(
                f'Visit {visit_event.name} parent is not a transect ({transect_event.event_type}), skipping...'
            )
            return None

        site_event = transect_event.parent
        if not site_event:
            self._warn(
                f'Transect {transect_event.name} has no parent site, skipping visit {visit_event.name}...'
            )
            return None

        if site_event.event_type != 'site':
            self._warn(
                f'Transect {transect_event.name} parent is not a site ({site_event.event_type}), skipping visit {visit_event.name}...'
            )
            return None

        return site_event

    def _build_occurrences(self, visit_event, records, grid_ref, visit_date, center, used_occurrence_ids):
        """Build unsaved Occurrence objects, each placed up to 100 m from ``center``."""
        center_lat, center_lon = center
        occurrences = []
        for record in records:
            lat, lon = add_variance_to_coords(center_lat, center_lon, max_variance_m=100.0)
            occurrences.append(
                Occurrence(
                    event=visit_event,
                    occurrence_id=self._unique_occurrence_id(
                        record['occurrence_id'], visit_event.id, used_occurrence_ids
                    ),
                    taxon_name=record['scientific_name'],
                    taxon_id=record['taxon_id'] or None,
                    common_name=record['common_name'] or None,
                    grid_reference=grid_ref,
                    latitude=lat,
                    longitude=lon,
                    abundance=record['individual_count'] or None,
                    occurrence_date=visit_date,
                )
            )
        return occurrences

    @staticmethod
    def _unique_occurrence_id(base_id, visit_id, used_ids):
        """
        Return an occurrence_id not already in ``used_ids`` and record it there.

        An empty ``base_id`` yields None. A taken ID first gets the visit ID
        appended (``<id>_<visit>``). If that is also taken, a counter is added
        (``<id>_<visit>_2``, ``_3``, ...).
        """
        if not base_id:
            return None
        candidate = base_id
        if candidate in used_ids:
            candidate = f"{base_id}_{visit_id}"
            suffix = 2
            while candidate in used_ids:
                candidate = f"{base_id}_{visit_id}_{suffix}"
                suffix += 1
        used_ids.add(candidate)
        return candidate

