#!/usr/bin/env python
"""
Health Check Endpoint
Provides system health status for monitoring and load balancers.
"""

import time
import asyncio
from typing import Dict, Any
import logging

logger = logging.getLogger(__name__)


class HealthChecker:
    """Performs health checks on various system components.

    Every ``check_*`` method returns a dictionary containing a ``'status'``
    key whose value is one of ``'healthy'``, ``'degraded'``, ``'unhealthy'``
    or ``'unknown'``. ``perform_full_check`` runs all of them and folds the
    individual statuses into one overall status.
    """

    # Database round-trip time (in ms) at or above which the database check
    # is reported as 'degraded' instead of 'healthy'.
    DB_LATENCY_THRESHOLD_MS = 100

    def __init__(self, db_connection=None):
        """
        Initialize the health checker.

        Args:
            db_connection: Database connection to check
        """
        self.db = db_connection
        self.start_time = time.time()
        self._last_check = {}

    @staticmethod
    def _status_for_usage(percent: float, degraded_above: float, unhealthy_above: float) -> str:
        """
        Map a usage percentage onto a status string.

        Args:
            percent: Current usage in percent
            degraded_above: Usage strictly above this value is 'degraded'
            unhealthy_above: Usage strictly above this value is 'unhealthy'

        Returns:
            'unhealthy', 'degraded' or 'healthy'
        """
        # The stricter threshold must be tested first, otherwise the
        # 'unhealthy' branch can never be reached.
        if percent > unhealthy_above:
            return 'unhealthy'
        if percent > degraded_above:
            return 'degraded'
        return 'healthy'

    def get_uptime(self) -> float:
        """
        Get server uptime in seconds.

        Returns:
            Seconds elapsed since this checker was created
        """
        return time.time() - self.start_time

    def check_database(self) -> Dict[str, Any]:
        """
        Check database connectivity and performance.

        Runs ``SELECT 1`` through a cursor of the configured connection and
        measures how long the whole check takes.

        Returns:
            Dictionary with 'status' and 'latency_ms'. On success it also
            has 'warning' ('High latency' when degraded, otherwise None);
            on failure it has 'error' and 'latency_ms' is None.
        """
        # perf_counter is monotonic, so the measured latency cannot be
        # distorted by wall-clock adjustments.
        start = time.perf_counter()
        try:
            if not self.db or not self.db.is_connected():
                return {
                    'status': 'unhealthy',
                    'error': 'Database not connected',
                    'latency_ms': None
                }

            # Execute simple query to test connectivity
            cursor = self.db.cursor()
            try:
                cursor.execute("SELECT 1")
                cursor.fetchone()
            finally:
                # Release the cursor even when the query fails.
                cursor.close()

            latency = (time.perf_counter() - start) * 1000
            # A single flag keeps 'status' and 'warning' in agreement,
            # including at exactly the threshold value.
            degraded = latency >= self.DB_LATENCY_THRESHOLD_MS

            return {
                'status': 'degraded' if degraded else 'healthy',
                'latency_ms': round(latency, 2),
                'warning': 'High latency' if degraded else None
            }
        except Exception as e:
            logger.error("Database health check failed: %s", e)
            return {
                'status': 'unhealthy',
                'error': str(e),
                'latency_ms': None
            }

    def check_memory(self) -> Dict[str, Any]:
        """
        Check memory usage.

        Returns:
            Dictionary with 'status' plus used/available/total figures in
            MB, or 'status': 'unknown' with an 'error' message when psutil
            is missing or the query fails.
        """
        try:
            import psutil
            memory = psutil.virtual_memory()

            return {
                'status': self._status_for_usage(memory.percent, 80, 90),
                'used_percent': memory.percent,
                'used_mb': round(memory.used / 1024 / 1024, 2),
                'available_mb': round(memory.available / 1024 / 1024, 2),
                'total_mb': round(memory.total / 1024 / 1024, 2)
            }
        except ImportError:
            return {
                'status': 'unknown',
                'error': 'psutil not installed'
            }
        except Exception as e:
            logger.error("Memory check failed: %s", e)
            return {
                'status': 'unknown',
                'error': str(e)
            }

    def check_disk(self, path: str = '/') -> Dict[str, Any]:
        """
        Check disk usage.

        Args:
            path: Filesystem path whose disk usage is inspected

        Returns:
            Dictionary with 'status' plus used/free/total figures in GB,
            or 'status': 'unknown' with an 'error' message when psutil is
            missing or the query fails.
        """
        try:
            import psutil
            disk = psutil.disk_usage(path)

            return {
                'status': self._status_for_usage(disk.percent, 80, 90),
                'used_percent': disk.percent,
                'used_gb': round(disk.used / 1024 / 1024 / 1024, 2),
                'free_gb': round(disk.free / 1024 / 1024 / 1024, 2),
                'total_gb': round(disk.total / 1024 / 1024 / 1024, 2)
            }
        except ImportError:
            return {
                'status': 'unknown',
                'error': 'psutil not installed'
            }
        except Exception as e:
            logger.error("Disk check failed: %s", e)
            return {
                'status': 'unknown',
                'error': str(e)
            }

    def check_websocket_connections(self, active_connections: int, max_connections: int = 1000) -> Dict[str, Any]:
        """
        Check WebSocket connection health.

        Args:
            active_connections: Number of active connections
            max_connections: Maximum allowed connections

        Returns:
            Dictionary with 'status', 'active', 'max' and 'usage_percent'.
            Usage above 90% is 'degraded' and above 95% is 'unhealthy'.
            When max_connections is not positive the usage cannot be
            computed: 'status' is 'unknown', 'usage_percent' is None and
            an 'error' message is included.
        """
        if max_connections <= 0:
            # Avoid a ZeroDivisionError (or a meaningless negative ratio)
            # taking down the whole health check.
            return {
                'status': 'unknown',
                'error': 'max_connections must be positive',
                'active': active_connections,
                'max': max_connections,
                'usage_percent': None
            }

        usage_percent = (active_connections / max_connections) * 100

        return {
            'status': self._status_for_usage(usage_percent, 90, 95),
            'active': active_connections,
            'max': max_connections,
            'usage_percent': round(usage_percent, 2)
        }

    def perform_full_check(self, active_connections: int = 0, max_connections: int = 1000) -> Dict[str, Any]:
        """
        Perform a comprehensive health check.

        Args:
            active_connections: Number of active WebSocket connections
            max_connections: Maximum allowed WebSocket connections

        Returns:
            Dictionary with the overall 'status', a 'timestamp', the
            'uptime_seconds' and the individual results under 'checks'
        """
        checks = {
            'database': self.check_database(),
            'memory': self.check_memory(),
            'disk': self.check_disk(),
            'websockets': self.check_websocket_connections(active_connections, max_connections)
        }

        # The worst individual status wins, in this order of severity.
        statuses = {check['status'] for check in checks.values()}
        overall_status = 'healthy'
        for candidate in ('unhealthy', 'degraded', 'unknown'):
            if candidate in statuses:
                overall_status = candidate
                break

        return {
            'status': overall_status,
            'timestamp': time.time(),
            'uptime_seconds': round(self.get_uptime(), 2),
            'checks': checks
        }


# Global health checker instance (created lazily by get_health_checker)
_health_checker = None


def get_health_checker(db_connection=None) -> HealthChecker:
    """
    Get or create the global health checker instance.

    The instance is created on the first call. If it was created without a
    database connection and a later call supplies one, that connection is
    attached to the existing instance. A connection that is already set is
    never replaced.

    Args:
        db_connection: Optional database connection

    Returns:
        HealthChecker instance
    """
    global _health_checker
    if _health_checker is None:
        _health_checker = HealthChecker(db_connection)
    elif db_connection is not None and _health_checker.db is None:
        # The singleton may have been created first by a caller that had no
        # connection to offer; without this, the connection passed here
        # would be silently dropped and the database check would keep
        # reporting 'Database not connected'.
        _health_checker.db = db_connection
    return _health_checker


# Integration with app.py WebSocket server
"""
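Add this to your app.py to enable the health check endpoint. The snippet
assumes app.py already imports json, asyncio and websockets, and defines
mydb, USERS and consumer: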

from deployment.health_check import get_health_checker

# Initialize health checker with database connection
health_checker = get_health_checker(mydb)

# Add health check handler
async def health_check_handler(websocket, path):
    '''Handle health check requests.'''
    if path == '/health':
        # Get active connection count
        active_connections = len(USERS)

        # Perform health check
        health_status = health_checker.perform_full_check(active_connections)

        # Send response
        response = json.dumps(health_status)
        await websocket.send(response)
        await websocket.close()

# Modify your WebSocket server start
async def main():
    async with websockets.serve(
        lambda ws, path: health_check_handler(ws, path) if path == '/health' else consumer(ws, path),
        "0.0.0.0",
        8001
    ):
        await asyncio.Future()  # run forever

if __name__ == "__main__":
    asyncio.run(main())
"""


# Simple HTTP health check endpoint (for use with standard web servers)
def create_http_health_endpoint():
    """
    Placeholder documenting how to expose the health check over HTTP.

    The function itself performs no work and returns None; it exists to
    carry the usage examples below for frameworks like Flask or FastAPI.
    Note that, unlike the Flask example, the FastAPI example does not set
    a status code based on the health result.

    Example with Flask:
        from flask import Flask, jsonify
        from deployment.health_check import get_health_checker

        app = Flask(__name__)
        health_checker = get_health_checker(db_connection)

        @app.route('/health')
        def health():
            status = health_checker.perform_full_check()
            code = 200 if status['status'] == 'healthy' else 503
            return jsonify(status), code

    Example with FastAPI:
        from fastapi import FastAPI
        from deployment.health_check import get_health_checker

        app = FastAPI()
        health_checker = get_health_checker(db_connection)

        @app.get('/health')
        async def health():
            status = health_checker.perform_full_check()
            return status
    """
    return None


# Liveness and Readiness probes (Kubernetes-style)
def liveness_probe(db_connection=None) -> bool:
    """
    Report whether the service process is alive.

    Obtains the shared health checker and reads its uptime. Any exception
    raised while doing so is logged and reported as "not alive".

    Args:
        db_connection: Database connection forwarded to get_health_checker()

    Returns:
        True if the uptime could be read, False if an exception was raised
    """
    try:
        # Being able to fetch the checker and read its uptime is the whole test.
        get_health_checker(db_connection).get_uptime()
    except Exception as exc:
        logger.error(f"Liveness probe failed: {exc}")
        return False
    else:
        return True


def readiness_probe(db_connection=None, active_connections: int = 0) -> bool:
    """
    Report whether the service is ready to accept traffic.

    Runs the full health check on the shared health checker. Any exception
    raised along the way is logged and reported as "not ready".

    Args:
        db_connection: Database connection forwarded to get_health_checker()
        active_connections: Number of active connections

    Returns:
        True when the overall status is 'healthy' or 'degraded', False for
        any other status or when an exception was raised
    """
    try:
        report = get_health_checker(db_connection).perform_full_check(active_connections)
        overall = report['status']
        # A degraded service still takes traffic; every other status does not.
        return overall == 'healthy' or overall == 'degraded'
    except Exception as exc:
        logger.error(f"Readiness probe failed: {exc}")
        return False
