"""
Management command to project occurrence coordinates onto their parent transect lines.

For each occurrence:
1. Gets the visit event (occurrence.event)
2. Gets the transect event (visit.parent)
3. Gets the transect line geometry
4. Projects the occurrence's current lat/lon onto the transect line
5. Updates the occurrence's lat/lon to be on or very close (<5 meters) to the line

Usage: python manage.py project_occurrences_to_transects [--dry-run] [--max-distance=5]
"""
import hashlib
import math

from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from shapely import wkt as shapely_wkt
from shapely.geometry import Point

from eventhub.models import Event, Occurrence


class Command(BaseCommand):
    """Generate coordinates for every occurrence along its parent transect line.

    The existing latitude/longitude of an occurrence is ignored. A new point is
    derived deterministically from the occurrence id (via an MD5 digest) and
    from the occurrence's rank among the occurrences of the same visit, so
    repeated runs produce the same coordinates.
    """

    help = 'Generate occurrence coordinates on their parent transect lines'

    # Human-readable labels for the skip-reason keys shown in the summary.
    SKIP_REASON_LABELS = {
        'no_parent': 'No parent transect',
        'not_transect': 'Parent is not a transect',
        'no_geometry': 'Transect has no geometry',
        'too_far': 'Too far from transect',
        'invalid_coords': 'Invalid coordinates',
    }

    def add_arguments(self, parser):
        """Register the command-line options of this command."""
        parser.add_argument(
            '--dry-run',
            action='store_true',
            help='Show what would be updated without actually updating the database',
        )
        parser.add_argument(
            '--batch-size',
            type=int,
            default=100,
            help='Number of occurrences to process in each batch (default: 100)',
        )
        parser.add_argument(
            '--verbose',
            action='store_true',
            help='Show detailed information about each occurrence',
        )

    def handle(self, *args, **options):
        """Process all occurrences in batches and print a summary.

        Raises:
            CommandError: if ``--batch-size`` is not a positive integer.
        """
        dry_run = options['dry_run']
        batch_size = options['batch_size']
        verbose = options['verbose']

        # range() below would raise on 0 and silently do nothing on a negative step.
        if batch_size < 1:
            raise CommandError('--batch-size must be a positive integer')

        # Start every run with a clean per-visit lookup cache.
        self._visit_index_cache = None

        # Slicing a queryset translates to LIMIT/OFFSET, which is only stable
        # with an explicit, unique ordering. Ordering by event first also keeps
        # occurrences of the same visit adjacent, so the per-visit lookup in
        # _visit_occurrence_index() is computed once per visit.
        occurrences = (
            Occurrence.objects
            .select_related('event', 'event__parent')
            .order_by('event_id', 'pk')
        )

        total_count = occurrences.count()

        self.stdout.write(f'Found {total_count} occurrence(s) to process.')

        if total_count == 0:
            self.stdout.write(self.style.WARNING('No occurrences found.'))
            return

        if dry_run:
            self.stdout.write(self.style.WARNING('\nDRY RUN MODE - No changes will be saved\n'))

        updated_count = 0
        skipped_count = 0
        error_count = 0
        skip_reasons = {
            'no_parent': 0,
            'not_transect': 0,
            'no_geometry': 0,
            'too_far': 0,
            'invalid_coords': 0,
        }

        for start in range(0, total_count, batch_size):
            batch = occurrences[start:start + batch_size]

            if dry_run:
                outcomes = self._process_batch(batch, dry_run, verbose)
            else:
                # One transaction (one commit) per batch.
                with transaction.atomic():
                    outcomes = self._process_batch(batch, dry_run, verbose)

            for result, reason in outcomes:
                if result == 'updated':
                    updated_count += 1
                elif result == 'skipped':
                    skipped_count += 1
                    if reason:
                        skip_reasons[reason] = skip_reasons.get(reason, 0) + 1
                elif result == 'error':
                    error_count += 1

            # Progress update
            processed = min(start + batch_size, total_count)
            if processed % 100 == 0 or processed == total_count:
                self.stdout.write(f'Processed {processed}/{total_count} occurrences...')

        # Summary
        self.stdout.write('')
        self.stdout.write('=' * 50)
        if dry_run:
            self.stdout.write(
                self.style.SUCCESS(
                    f'DRY RUN: Would update {updated_count} occurrence(s)'
                )
            )
        else:
            self.stdout.write(
                self.style.SUCCESS(
                    f'Successfully updated {updated_count} occurrence(s)'
                )
            )

        if skipped_count > 0:
            self.stdout.write(
                self.style.WARNING(
                    f'Skipped {skipped_count} occurrence(s)'
                )
            )
            if verbose or any(skip_reasons.values()):
                for reason, count in skip_reasons.items():
                    if count > 0:
                        reason_display = self.SKIP_REASON_LABELS.get(reason, reason)
                        self.stdout.write(f'  - {reason_display}: {count}')

        if error_count > 0:
            self.stdout.write(
                self.style.ERROR(f'Errors occurred for {error_count} occurrence(s)')
            )
        self.stdout.write('=' * 50)

    def _process_batch(self, batch, dry_run: bool, verbose: bool) -> list[tuple[str, str | None]]:
        """Process every occurrence of *batch* and return one (result, reason) pair per occurrence."""
        return [self._process_safely(occ, dry_run, verbose) for occ in batch]

    def _process_safely(self, occ: Occurrence, dry_run: bool, verbose: bool) -> tuple[str, str | None]:
        """Process one occurrence, reporting any exception as an 'error' result.

        When writing, the occurrence is wrapped in a nested ``atomic()`` block
        (a savepoint). A database error is then rolled back to that savepoint
        before the exception is caught, so the enclosing batch transaction
        stays usable for the remaining occurrences.
        """
        try:
            if dry_run:
                return self._process_occurrence(occ, dry_run, verbose)
            with transaction.atomic():
                return self._process_occurrence(occ, dry_run, verbose)
        except Exception as e:  # per-item boundary: report and carry on
            self.stdout.write(
                self.style.ERROR(
                    f'Error processing occurrence {occ.id} ({occ.taxon_name}): {e}'
                )
            )
            return 'error', None

    def _visit_occurrence_index(self, visit) -> dict:
        """Return ``{occurrence_id: rank}`` for the occurrences of *visit*, ranked by id.

        Only the most recently requested visit is memoized: handle() orders
        occurrences by event, so consecutive calls normally hit the cache
        while memory use stays bounded by the largest visit.
        """
        cached = getattr(self, '_visit_index_cache', None)
        if cached is not None and cached[0] == visit.id:
            return cached[1]

        ordered_ids = (
            Occurrence.objects.filter(event=visit)
            .order_by('id')
            .values_list('id', flat=True)
        )
        index_by_id = {occ_id: rank for rank, occ_id in enumerate(ordered_ids)}
        self._visit_index_cache = (visit.id, index_by_id)
        return index_by_id

    @staticmethod
    def _position_along_line(index: int, total: int, hash_val: int) -> float:
        """Return the normalized position (0.0-1.0) along the line for one occurrence.

        Occurrences are spread evenly by rank, then shifted by a hash-derived
        fraction of the average spacing and clamped away from the endpoints.
        """
        if total <= 1:
            return 0.5  # Middle of line if only one occurrence

        base_position = index / (total - 1)
        # hash_val % 30 is 0..29, giving a variation of -0.15 to +0.14
        spacing_variation = ((hash_val % 30) - 15) / 100.0
        position = base_position + spacing_variation / (total - 1)
        return max(0.05, min(0.95, position))  # Keep away from endpoints

    @staticmethod
    def _perpendicular_offset(line, point, position: float, total: int, hash_val: int):
        """Optionally move *point* sideways off *line*.

        Returns:
            Tuple of (point, offset_meters). offset_meters is None when no
            offset was applied: either the hash selects "stay on the line"
            (30% of hash values), or the line has no measurable direction at
            this position.
        """
        if (hash_val % 100) >= 70:
            return point, None

        # Offset distance: 0 to 5 meters, converted to degrees with the
        # approximation 1 degree ~ 111,000 m. A degree of longitude is
        # shorter than that away from the equator, so the true distance
        # never exceeds the nominal value.
        offset_meters = (hash_val % 500) / 100.0
        offset_degrees = offset_meters / 111000.0

        # Sample nearby points to estimate the line direction at this position.
        sample_dist = min(0.01, 1.0 / (total * 10))
        point_before = line.interpolate(max(0, position - sample_dist), normalized=True)
        point_after = line.interpolate(min(1, position + sample_dist), normalized=True)
        dx = point_after.x - point_before.x
        dy = point_after.y - point_before.y

        # Perpendicular to (dx, dy) is (-dy, dx) or (dy, -dx); the hash picks the side.
        if (hash_val % 2) == 0:
            perp_dx, perp_dy = -dy, dx
        else:
            perp_dx, perp_dy = dy, -dx

        perp_length = math.sqrt(perp_dx * perp_dx + perp_dy * perp_dy)
        if not perp_length > 0:
            return point, None

        perp_dx /= perp_length
        perp_dy /= perp_length
        shifted = Point(
            point.x + perp_dx * offset_degrees,
            point.y + perp_dy * offset_degrees
        )
        return shifted, offset_meters

    def _process_occurrence(self, occ: Occurrence, dry_run: bool, verbose: bool = False) -> tuple[str, str | None]:
        """
        Process a single occurrence.

        Resolves the occurrence's visit and parent transect, parses the
        transect's WKT geometry, computes a new point on or near the line and,
        unless dry_run is set, saves it as the occurrence's latitude/longitude.

        Returns:
            Tuple of (result, reason) where result is 'updated', 'skipped', or 'error'
            and reason is a string describing why it was skipped (or None)
        """
        # Get the visit event
        visit = occ.event
        if not visit:
            if verbose:
                self.stdout.write(
                    self.style.WARNING(f'Occurrence {occ.id}: No event associated')
                )
            return 'error', None

        # Get the transect event (parent of visit)
        transect = visit.parent
        if not transect:
            if verbose or not dry_run:
                self.stdout.write(
                    self.style.WARNING(
                        f'Occurrence {occ.id}: Visit {visit.id} ({visit.name}) has no parent (transect)'
                    )
                )
            return 'skipped', 'no_parent'

        if transect.event_type != 'transect':
            if verbose or not dry_run:
                self.stdout.write(
                    self.style.WARNING(
                        f'Occurrence {occ.id}: Parent of visit {visit.id} is not a transect (type: {transect.event_type}, id: {transect.id})'
                    )
                )
            return 'skipped', 'not_transect'

        # Get the transect line geometry
        if not transect.footprintWKT:
            if verbose or not dry_run:
                self.stdout.write(
                    self.style.WARNING(
                        f'Occurrence {occ.id}: Transect {transect.id} ({transect.name}) has no geometry'
                    )
                )
            return 'skipped', 'no_geometry'

        try:
            line = shapely_wkt.loads(transect.footprintWKT)
        except Exception as e:
            if verbose or not dry_run:
                self.stdout.write(
                    self.style.ERROR(
                        f'Occurrence {occ.id}: Error parsing transect geometry: {e}'
                    )
                )
            return 'error', None

        # Valid WKT such as "LINESTRING EMPTY" has no point to interpolate.
        if line.is_empty:
            if verbose or not dry_run:
                self.stdout.write(
                    self.style.WARNING(
                        f'Occurrence {occ.id}: Transect {transect.id} ({transect.name}) has empty geometry'
                    )
                )
            return 'skipped', 'no_geometry'

        # Rank of this occurrence among the occurrences of its visit (by id)
        # determines where along the line it is placed.
        index_by_id = self._visit_occurrence_index(visit)
        occurrence_index = index_by_id.get(occ.id, 0)
        total_occurrences = len(index_by_id)

        # One digest supplies two independent-looking values: the first 8 hex
        # digits vary the spacing, the next 8 drive the sideways offset.
        # The digest is used for determinism only, not for security.
        digest = hashlib.md5(str(occ.id).encode(), usedforsecurity=False).hexdigest()
        spacing_hash = int(digest[:8], 16)
        offset_hash = int(digest[8:16], 16)

        position = self._position_along_line(occurrence_index, total_occurrences, spacing_hash)
        new_point = line.interpolate(position, normalized=True)

        # A lone occurrence always stays exactly on the line.
        offset_meters = None
        if total_occurrences > 1:
            new_point, offset_meters = self._perpendicular_offset(
                line, new_point, position, total_occurrences, offset_hash
            )

        new_lat = float(new_point.y)
        new_lon = float(new_point.x)

        if verbose:
            # Compare against None so that a coordinate of exactly 0 is still shown.
            if occ.latitude is not None and occ.longitude is not None:
                old_coords = f'({occ.latitude:.6f}, {occ.longitude:.6f})'
            else:
                old_coords = 'N/A'
            offset_info = ''
            if offset_meters is not None:
                offset_info = f', offset: {offset_meters:.2f}m'
            self.stdout.write(
                f'Occurrence {occ.id}: Generating new coordinates on transect line '
                f'(old: {old_coords}, new: ({new_lat:.6f}, {new_lon:.6f}), position: {position:.3f}{offset_info})'
            )

        if not dry_run:
            occ.latitude = new_lat
            occ.longitude = new_lon
            occ.save(update_fields=['latitude', 'longitude'])

        return 'updated', None

