"""
Management command to import UKBMS Wales synthetic data into eventhub.

Usage: python manage.py import_ukbms_wales_to_eventhub [--data-dir=path/to/data]
"""
import csv
import hashlib
import os
from datetime import datetime
from pathlib import Path
from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Min, Max
from shapely import wkt as shapely_wkt
from shapely.geometry import Point, LineString
from shapely.ops import transform
from shapely import affinity
from pyproj import Transformer
from eventhub.models import Event, Occurrence, EventAttribute


# EPSG codes
WGS84 = "EPSG:4326"
OSGB = "EPSG:27700"  # British National Grid


# Lazily-built OSGB -> WGS84 coordinate function. It is shared across calls
# because this module converts one geometry per CSV row and building a pyproj
# Transformer is much more expensive than applying an existing one.
_OSGB_TO_WGS84_TRANSFORM = None


def transform_geometry_from_osgb_to_wgs84(wkt_str: str) -> str:
    """
    Transform a WKT geometry from OSGB (EPSG:27700) to WGS84 (EPSG:4326).

    The transformer is created with ``always_xy=True``, so input coordinates
    are read as (easting, northing) and output coordinates are written as
    (longitude, latitude).

    Args:
        wkt_str: WKT string in OSGB coordinates

    Returns:
        WKT string in WGS84 coordinates

    Raises:
        ValueError: If ``wkt_str`` cannot be parsed into a geometry.
    """
    global _OSGB_TO_WGS84_TRANSFORM

    try:
        geom = shapely_wkt.loads(wkt_str)
    except Exception as e:
        raise ValueError(f"Invalid WKT: {e}") from e
    if geom is None:
        # shapely may hand a missing input (None) straight back instead of
        # raising; report it the same way as any other unparseable WKT.
        raise ValueError("Invalid WKT: no geometry could be parsed")

    # Build the transformer on first use only, then reuse it for every call.
    if _OSGB_TO_WGS84_TRANSFORM is None:
        _OSGB_TO_WGS84_TRANSFORM = Transformer.from_crs(
            OSGB, WGS84, always_xy=True
        ).transform
    geom_wgs84 = transform(_OSGB_TO_WGS84_TRANSFORM, geom)

    return geom_wgs84.wkt


def create_point_from_easting_northing(easting: float, northing: float) -> str:
    """
    Build a WGS84 point WKT from an OSGB easting/northing pair.

    The coordinates are wrapped in a shapely ``Point``, serialised to WKT and
    then passed through ``transform_geometry_from_osgb_to_wgs84``.

    Args:
        easting: Easting coordinate in OSGB
        northing: Northing coordinate in OSGB

    Returns:
        WKT string for the point in WGS84 coordinates
    """
    osgb_point_wkt = Point(easting, northing).wkt
    wgs84_point_wkt = transform_geometry_from_osgb_to_wgs84(osgb_point_wkt)
    return wgs84_point_wkt


def rotate_transect_geometry(wkt_str: str, transect_id: str, rotation_range: float = 60.0) -> str:
    """
    Rotate a transect geometry about its centroid by an id-derived angle.

    The angle is computed from the MD5 digest of ``transect_id``, so a given
    transect always receives the same rotation while different transects end
    up with different orientations. The first four digest bytes are reduced
    modulo 360 and mapped linearly onto
    ``[-rotation_range / 2, +rotation_range / 2)``.

    Args:
        wkt_str: WKT string of the geometry (should be in WGS84)
        transect_id: Unique identifier for the transect (seeds the angle)
        rotation_range: Total width of the angle interval in degrees
            (default: 60, i.e. angles from -30 up to just under +30)

    Returns:
        WKT string for the rotated geometry

    Raises:
        ValueError: If ``wkt_str`` cannot be parsed as WKT.
    """
    try:
        parsed = shapely_wkt.loads(wkt_str)
    except Exception as e:
        raise ValueError(f"Invalid WKT: {e}") from e

    # MD5 is used purely as a stable pseudo-random source, not for security.
    digest = hashlib.md5(transect_id.encode()).digest()
    bucket = int.from_bytes(digest[:4], byteorder='big') % 360

    # Map bucket 0..359 linearly onto the requested angle interval.
    degrees_per_bucket = rotation_range / 360.0
    half_range = rotation_range / 2.0
    rotation_angle = bucket * degrees_per_bucket - half_range

    # Pivot on the geometry's own centroid so the shape stays in place.
    pivot = parsed.centroid
    return affinity.rotate(parsed, rotation_angle, origin=pivot).wkt


class Command(BaseCommand):
    """
    Import the UKBMS Wales synthetic CSV files into eventhub.

    The import builds a four-level event hierarchy
    (project -> site -> transect -> visit) and attaches occurrences to the
    visit events. Every level is written with ``get_or_create`` so the command
    can be run repeatedly against the same data; the attributes of every
    event/occurrence that is touched are deleted and re-created from the CSV
    row. All database work happens inside one ``transaction.atomic()`` block.
    """

    help = 'Import UKBMS Wales synthetic data from CSV files into eventhub'

    # (CSV column, attribute_type, unit) triples. A unit of None means the
    # EventAttribute is created without passing ``unit`` at all.
    _SITE_ATTRIBUTES = (
        ('gridref', 'gridReference', None),
        ('length_m', 'length', 'm'),
        ('n_sections', 'n_sections', None),
        ('survey_type', 'survey_type', None),
    )
    _TRANSECT_ATTRIBUTES = (
        ('length_m', 'length', 'm'),
        ('n_sections', 'n_sections', None),
    )
    _VISIT_ATTRIBUTES = (
        ('start_time', 'start_time', None),
        ('duration_min', 'duration', 'min'),
        ('temperature_c', 'temperature', '°C'),
        ('wind_beaufort', 'wind_beaufort', None),
        ('cloud_oktas', 'cloud_oktas', None),
        ('recorder_id', 'recorder_id', None),
    )

    def add_arguments(self, parser):
        """Register the ``--data-dir`` option (default: ``context_data``)."""
        parser.add_argument(
            '--data-dir',
            type=str,
            default='context_data',
            help='Directory containing the UKBMS CSV files (default: context_data)',
        )

    def handle(self, *args, **options):
        """
        Run the import.

        Writes an error message and returns without touching the database
        when the data directory or any of the five CSV files is missing.
        Otherwise runs all import steps in one transaction and prints a
        summary of created/reused counts.
        """
        data_dir = Path(options['data_dir'])

        if not data_dir.exists():
            self.stdout.write(
                self.style.ERROR(f'Data directory not found: {data_dir}')
            )
            return

        sites_file = data_dir / 'ukbms_sites_wales_synth.csv'
        transects_file = data_dir / 'ukbms_transects_wales_synth.csv'
        transect_geometries_file = data_dir / 'ukbms_transect_geometries_wales_synth.csv'
        visits_file = data_dir / 'ukbms_visits_wales_synth.csv'
        occurrences_file = data_dir / 'ukbms_occurrences_wales_synth.csv'

        required_files = [sites_file, transects_file, transect_geometries_file, visits_file, occurrences_file]
        missing_files = [f for f in required_files if not f.exists()]
        if missing_files:
            self.stdout.write(
                self.style.ERROR(f'Missing required files: {", ".join(str(f) for f in missing_files)}')
            )
            return

        with transaction.atomic():
            # Steps 1-6: build the hierarchy top-down; each step needs the
            # lookup dict produced by the previous one.
            project_event = self._get_or_create_project()
            site_events, sites_created, sites_reused = self._import_sites(
                sites_file, project_event
            )
            transect_geometries = self._load_transect_geometries(transect_geometries_file)
            transect_events, transects_created, transects_reused = self._import_transects(
                transects_file, site_events, transect_geometries
            )
            visit_events, visits_created, visits_reused = self._import_visits(
                visits_file, transect_events
            )
            occurrences_created, occurrences_reused = self._import_occurrences(
                occurrences_file, visit_events
            )

            # Steps 7-8: roll date ranges up the hierarchy, bottom-up.
            self._refresh_site_dates(site_events)
            self._refresh_project_dates(project_event)

            # Print summary
            self.stdout.write(self.style.SUCCESS('\n' + '='*50))
            self.stdout.write(self.style.SUCCESS('Import Summary:'))
            self.stdout.write(self.style.SUCCESS('='*50))
            self.stdout.write(f'Site events created: {sites_created}')
            self.stdout.write(f'Site events reused: {sites_reused}')
            self.stdout.write(f'Transect events created: {transects_created}')
            self.stdout.write(f'Transect events reused: {transects_reused}')
            self.stdout.write(f'Visit events created: {visits_created}')
            self.stdout.write(f'Visit events reused: {visits_reused}')
            self.stdout.write(f'Occurrences created: {occurrences_created}')
            self.stdout.write(f'Occurrences reused: {occurrences_reused}')
            self.stdout.write(self.style.SUCCESS('='*50))

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    def _warn(self, message):
        """Write *message* to stdout in the WARNING style."""
        self.stdout.write(self.style.WARNING(message))

    @staticmethod
    def _read_rows(path):
        """Yield each row of the CSV file at *path* as a dict keyed by header."""
        # newline='' lets the csv module do its own newline handling, which is
        # required for quoted fields that contain line breaks.
        with open(path, 'r', encoding='utf-8', newline='') as f:
            yield from csv.DictReader(f)

    @staticmethod
    def _replace_event_attributes(event, row, spec):
        """
        Replace all attributes of *event* with the non-empty columns in *spec*.

        Existing EventAttribute rows for the event are always deleted, even
        when the CSV row yields no replacement attributes.
        """
        EventAttribute.objects.filter(event=event).delete()

        new_attributes = []
        for column, attribute_type, unit in spec:
            value = row.get(column)
            if not value:
                continue
            extra = {} if unit is None else {'unit': unit}
            new_attributes.append(
                EventAttribute(
                    event=event,
                    attribute_type=attribute_type,
                    attribute_value=value,
                    **extra
                )
            )

        if new_attributes:
            EventAttribute.objects.bulk_create(new_attributes)

    @staticmethod
    def _apply_date_span(event, start_date, end_date):
        """Set *event*'s start/end to the given dates, saving only on change."""
        changed = False
        if start_date and event.start_date != start_date:
            event.start_date = start_date
            changed = True
        if end_date and event.end_date != end_date:
            event.end_date = end_date
            changed = True
        if changed:
            event.save()

    # ------------------------------------------------------------------
    # Step 1: project
    # ------------------------------------------------------------------

    def _get_or_create_project(self):
        """Return the root project event, creating it if it does not exist."""
        self.stdout.write('Creating Project event...')
        project_event, created = Event.objects.get_or_create(
            name="UKBMS Wales (synthetic)",
            event_type='project',
            defaults={
                'parent': None,
            }
        )

        if created:
            self.stdout.write(self.style.SUCCESS(f'Created project event: {project_event.name}'))
        else:
            self.stdout.write(f'Using existing project event: {project_event.name}')
        return project_event

    # ------------------------------------------------------------------
    # Step 2: sites
    # ------------------------------------------------------------------

    def _site_footprint(self, row, site_number):
        """Return the site's WGS84 point WKT, or None if it cannot be built."""
        try:
            easting = float(row['easting']) if row['easting'] else None
            northing = float(row['northing']) if row['northing'] else None
            # Test against None explicitly: 0.0 is a valid grid coordinate and
            # must not be mistaken for a missing value.
            if easting is not None and northing is not None:
                return create_point_from_easting_northing(easting, northing)
        except (ValueError, TypeError) as e:
            self._warn(f'Could not create geometry for site {site_number}: {e}')
        return None

    def _site_date_range(self, row, site_number):
        """
        Return ``(start_date, end_date)`` from ``first_year``/``last_year``.

        ``first_year`` maps to 1 January and ``last_year`` to 31 December.
        Each bound is parsed independently; an empty value yields None and an
        unparseable value yields None plus a warning.
        """
        bounds = []
        for column, month, day in (('first_year', 1, 1), ('last_year', 12, 31)):
            raw_year = row[column]
            parsed = None
            if raw_year:
                try:
                    parsed = datetime(int(raw_year), month, day).date()
                except (ValueError, TypeError) as e:
                    self._warn(
                        f'Could not parse {column} {raw_year!r} for site {site_number}: {e}'
                    )
            bounds.append(parsed)
        return bounds[0], bounds[1]

    def _import_sites(self, sites_file, project_event):
        """
        Create or update one site event per row of the sites CSV.

        Returns:
            ``(site_events, created, reused)`` where ``site_events`` maps
            ``site_number`` to the site Event.
        """
        self.stdout.write('Importing Sites...')
        site_events = {}
        created = 0
        reused = 0

        for row in self._read_rows(sites_file):
            site_number = row['site_number']
            site_name = row['site_name']

            footprint_wkt = self._site_footprint(row, site_number)
            start_date, end_date = self._site_date_range(row, site_number)

            location_parts = [
                f"{key}={row[key]}" for key in ('gridref', 'country') if row.get(key)
            ]
            location_note = ", ".join(location_parts) if location_parts else None

            site_event, site_created = Event.objects.get_or_create(
                name=site_name,
                event_type='site',
                parent=project_event,
                defaults={
                    'start_date': start_date,
                    'end_date': end_date,
                    'location_note': location_note,
                    'footprintWKT': footprint_wkt,
                }
            )

            if site_created:
                created += 1
            else:
                reused += 1
                # Overwrite stored values only where the CSV supplies one.
                # (The parent cannot differ: it is part of the lookup above.)
                if start_date and site_event.start_date != start_date:
                    site_event.start_date = start_date
                if end_date and site_event.end_date != end_date:
                    site_event.end_date = end_date
                if location_note and site_event.location_note != location_note:
                    site_event.location_note = location_note
                if footprint_wkt and site_event.footprintWKT != footprint_wkt:
                    site_event.footprintWKT = footprint_wkt
                site_event.save()

            self._replace_event_attributes(site_event, row, self._SITE_ATTRIBUTES)
            site_events[site_number] = site_event

        self.stdout.write(f'Created {created} new Site events, reused {reused} existing')
        return site_events, created, reused

    # ------------------------------------------------------------------
    # Steps 3-4: transect geometries and transects
    # ------------------------------------------------------------------

    def _load_transect_geometries(self, transect_geometries_file):
        """
        Return a dict mapping ``transect_id`` to its rotated WGS84 WKT.

        Rows whose geometry cannot be converted are reported with a warning
        and left out of the dict.
        """
        transect_geometries = {}
        for row in self._read_rows(transect_geometries_file):
            transect_id = row['transect_id']
            wkt_geom_osgb = row['wkt_geom']
            try:
                wkt_geom_wgs84 = transform_geometry_from_osgb_to_wgs84(wkt_geom_osgb)
                # Vary the orientation so transects are not all West to East.
                # NOTE(review): the rotation is applied to lon/lat degrees,
                # not metres, so it does not preserve ground distances exactly.
                transect_geometries[transect_id] = rotate_transect_geometry(
                    wkt_geom_wgs84, transect_id
                )
            except Exception as e:
                self._warn(f'Could not transform geometry for transect {transect_id}: {e}')
        return transect_geometries

    def _import_transects(self, transects_file, site_events, transect_geometries):
        """
        Create or update one transect event per row of the transects CSV.

        Rows whose ``site_number`` has no imported site are skipped with a
        warning.

        Returns:
            ``(transect_events, created, reused)`` where ``transect_events``
            maps ``transect_id`` to the transect Event.
        """
        self.stdout.write('Importing Transects...')
        transect_events = {}
        created = 0
        reused = 0

        for row in self._read_rows(transects_file):
            transect_id = row['transect_id']
            site_number = row['site_number']
            transect_name = row['transect_name']

            parent_site = site_events.get(site_number)
            if not parent_site:
                self._warn(
                    f'No Site event found for site_number {site_number}, skipping transect {transect_id}...'
                )
                continue

            footprint_wkt = transect_geometries.get(transect_id)

            transect_event, transect_created = Event.objects.get_or_create(
                name=transect_name,
                event_type='transect',
                parent=parent_site,
                defaults={
                    'footprintWKT': footprint_wkt,
                }
            )

            if transect_created:
                created += 1
            else:
                reused += 1
                if footprint_wkt and transect_event.footprintWKT != footprint_wkt:
                    transect_event.footprintWKT = footprint_wkt
                transect_event.save()

            self._replace_event_attributes(transect_event, row, self._TRANSECT_ATTRIBUTES)
            transect_events[transect_id] = transect_event

        self.stdout.write(f'Created {created} new Transect events, reused {reused} existing')
        return transect_events, created, reused

    # ------------------------------------------------------------------
    # Step 5: visits
    # ------------------------------------------------------------------

    def _import_visits(self, visits_file, transect_events):
        """
        Create or update one visit event per row of the visits CSV.

        Visits are identified by their name (``"Visit <visit_date>"``, or the
        ``visit_id`` when the date column is empty) under the parent transect.
        Rows whose ``transect_id`` has no imported transect are skipped with a
        warning.

        Returns:
            ``(visit_events, created, reused)`` where ``visit_events`` maps
            ``visit_id`` to the visit Event.
        """
        self.stdout.write('Importing Visits...')
        visit_events = {}
        created = 0
        reused = 0

        for row in self._read_rows(visits_file):
            visit_id = row['visit_id']
            transect_id = row['transect_id']
            visit_date_str = row['visit_date']

            parent_transect = transect_events.get(transect_id)
            if not parent_transect:
                self._warn(
                    f'No Transect event found for transect_id {transect_id}, skipping visit {visit_id}...'
                )
                continue

            visit_date = None
            try:
                visit_date = datetime.strptime(visit_date_str, '%Y-%m-%d').date()
            except (ValueError, TypeError) as e:
                # Still import the visit, but make the missing date visible.
                self._warn(
                    f'Could not parse visit_date {visit_date_str!r} for visit {visit_id}: {e}'
                )

            visit_name = f"Visit {visit_date_str}" if visit_date_str else visit_id
            notes = row.get('notes')

            visit_event, visit_created = Event.objects.get_or_create(
                name=visit_name,
                event_type='visit',
                parent=parent_transect,
                defaults={
                    'start_date': visit_date,
                    'end_date': visit_date,
                    'fieldNotes': notes,
                }
            )

            if visit_created:
                created += 1
            else:
                reused += 1
                if visit_date and visit_event.start_date != visit_date:
                    visit_event.start_date = visit_date
                    visit_event.end_date = visit_date
                if notes and visit_event.fieldNotes != notes:
                    visit_event.fieldNotes = notes
                visit_event.save()

            self._replace_event_attributes(visit_event, row, self._VISIT_ATTRIBUTES)
            visit_events[visit_id] = visit_event

        self.stdout.write(f'Created {created} new Visit events, reused {reused} existing')
        return visit_events, created, reused

    # ------------------------------------------------------------------
    # Step 6: occurrences
    # ------------------------------------------------------------------

    def _import_occurrences(self, occurrences_file, visit_events):
        """
        Create or update one occurrence per (visit, scientific name) pair.

        Rows whose ``visit_id`` has no imported visit are skipped with a
        warning. The occurrence's attributes are replaced by a single
        ``common_name`` attribute when the CSV provides one.

        Returns:
            ``(created, reused)`` counts.
        """
        self.stdout.write('Importing Occurrences...')
        created = 0
        reused = 0

        for row in self._read_rows(occurrences_file):
            visit_id = row['visit_id']
            scientific_name = row['scientific_name']
            common_name = row.get('common_name', '')
            individual_count = row.get('individual_count', '')

            visit_event = visit_events.get(visit_id)
            if not visit_event:
                self._warn(f'No Visit event found for visit_id {visit_id}, skipping occurrence...')
                continue

            occurrence, occ_created = Occurrence.objects.get_or_create(
                event=visit_event,
                taxon_name=scientific_name,
                defaults={
                    'abundance': individual_count if individual_count else None,
                    'occurrence_date': visit_event.start_date,
                }
            )

            if occ_created:
                created += 1
            else:
                reused += 1
                if individual_count and occurrence.abundance != individual_count:
                    occurrence.abundance = individual_count
                if visit_event.start_date and occurrence.occurrence_date != visit_event.start_date:
                    occurrence.occurrence_date = visit_event.start_date
                occurrence.save()

            EventAttribute.objects.filter(occurrence=occurrence).delete()
            if common_name:
                EventAttribute.objects.create(
                    occurrence=occurrence,
                    attribute_type='common_name',
                    attribute_value=common_name
                )

        self.stdout.write(f'Created {created} new Occurrences, reused {reused} existing')
        return created, reused

    # ------------------------------------------------------------------
    # Steps 7-8: date roll-up
    # ------------------------------------------------------------------

    def _refresh_site_dates(self, site_events):
        """
        Set each site's date range from its transects and their visits.

        The start date becomes the earliest ``start_date`` and the end date
        the latest ``end_date`` found among the site's transect events and
        the visit events below them; a bound with no dated descendants is
        left unchanged.
        """
        self.stdout.write('Updating Site event dates from child transects...')
        for site_event in site_events.values():
            transect_span = Event.objects.filter(
                parent=site_event,
                event_type='transect'
            ).aggregate(
                min_start=Min('start_date'),
                max_end=Max('end_date')
            )
            visit_span = Event.objects.filter(
                parent__parent=site_event,
                event_type='visit'
            ).aggregate(
                min_start=Min('start_date'),
                max_end=Max('end_date')
            )

            starts = [d for d in (transect_span['min_start'], visit_span['min_start']) if d]
            ends = [d for d in (transect_span['max_end'], visit_span['max_end']) if d]
            self._apply_date_span(
                site_event,
                min(starts) if starts else None,
                max(ends) if ends else None,
            )

    def _refresh_project_dates(self, project_event):
        """Set the project's date range to span all of its site events."""
        self.stdout.write('Updating project event dates from child Site events...')
        site_dates = Event.objects.filter(
            parent=project_event,
            event_type='site'
        ).aggregate(
            min_start=Min('start_date'),
            max_end=Max('end_date')
        )
        self._apply_date_span(project_event, site_dates['min_start'], site_dates['max_end'])



