"""
API Usage Tracker for BaoLife

Tracks API usage and costs for OpenAI calls.
Provides insights into spending patterns and enables budget enforcement.
"""

# Avoid circular import with functions.py by importing directly from source
from database.db_operations import get_database_connection
from datetime import datetime, timedelta


class APIUsageTracker:
    """
    Track API usage and costs for monitoring and budget management.

    Every tracked call becomes one row in the ``api_usage`` table. Methods that
    touch the database open their own connection with
    ``get_database_connection()`` and close it in a ``finally`` block. Errors
    raised while running a query are printed and turned into a neutral return
    value (None, [] or False), so a tracking/reporting failure does not reach
    the caller. Errors raised by ``get_database_connection()`` itself are not
    caught and still propagate.
    """

    # OpenAI pricing in USD per token (as of Jan 2025)
    PRICING = {
        'gpt-4o-mini': {
            'input': 0.15 / 1_000_000,   # $0.15 per 1M input tokens
            'output': 0.60 / 1_000_000   # $0.60 per 1M output tokens
        },
        'gpt-4o': {
            'input': 2.50 / 1_000_000,   # $2.50 per 1M input tokens
            'output': 10.00 / 1_000_000  # $10.00 per 1M output tokens
        }
    }

    # PRICING key used to estimate cost for models with no matching entry.
    DEFAULT_PRICING_MODEL = 'gpt-4o-mini'

    def __init__(self):
        """Initialize the tracker; no database I/O happens until first use."""
        # Becomes True once create_table_if_not_exists() has succeeded.
        self._table_created = False

    def create_table_if_not_exists(self):
        """
        Create the api_usage table if it doesn't exist.

        Returns:
            True if the CREATE TABLE statement succeeded, False if it raised
            (the error is printed, not re-raised).
        """
        mydb = get_database_connection()
        try:
            with mydb.cursor() as cursor:
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS api_usage (
                        id INT AUTO_INCREMENT PRIMARY KEY,
                        player_id VARCHAR(36) NOT NULL,
                        conversation_id VARCHAR(36),
                        endpoint VARCHAR(50) NOT NULL,
                        model VARCHAR(50) NOT NULL,
                        prompt_tokens INT NOT NULL,
                        completion_tokens INT NOT NULL,
                        total_tokens INT NOT NULL,
                        cost_usd DECIMAL(10, 6) NOT NULL,
                        created_date DATETIME NOT NULL,
                        purpose VARCHAR(100),
                        INDEX idx_player (player_id),
                        INDEX idx_created (created_date),
                        INDEX idx_conversation (conversation_id),
                        INDEX idx_model (model)
                    )
                """)
                mydb.commit()
                print("API usage table created/verified")
            return True

        except Exception as e:
            print(f"Failed to create api_usage table: {e}")
            return False

        finally:
            mydb.close()

    def _resolve_pricing(self, model):
        """
        Return the PRICING entry for ``model``, or None if nothing matches.

        An exact key match wins. Otherwise the longest PRICING key ``k`` such
        that ``model`` starts with ``k + '-'`` is used. Dated snapshot names
        such as 'gpt-4o-2024-08-06' then resolve to 'gpt-4o' instead of the
        default price. Any other variant sharing a base-model prefix is also
        priced as that base model.
        """
        if model in self.PRICING:
            return self.PRICING[model]
        if not isinstance(model, str):
            return None
        candidates = [key for key in self.PRICING if model.startswith(key + '-')]
        if not candidates:
            return None
        # Longest prefix wins, so 'gpt-4o-mini-...' is not priced as 'gpt-4o'.
        return self.PRICING[max(candidates, key=len)]

    def calculate_cost(self, model, prompt_tokens, completion_tokens):
        """
        Estimate the USD cost of one call from its token counts.

        Args:
            model: Model name, either a PRICING key or a dated variant of one.
            prompt_tokens: Number of input tokens.
            completion_tokens: Number of output tokens.

        Returns:
            Estimated cost in USD as a float. Models with no PRICING match
            are priced with DEFAULT_PRICING_MODEL, and a warning is printed.
        """
        pricing = self._resolve_pricing(model)
        if pricing is None:
            print(f"Warning: Unknown model '{model}', using "
                  f"{self.DEFAULT_PRICING_MODEL} pricing")
            pricing = self.PRICING[self.DEFAULT_PRICING_MODEL]
        return (
            prompt_tokens * pricing['input'] +
            completion_tokens * pricing['output']
        )

    @staticmethod
    def _usage_count(usage, key):
        """
        Read the token count ``key`` from ``usage`` as an int.

        ``usage`` may be a dict (or anything with a ``get`` method) or an
        object exposing the count as an attribute. The attribute form is meant
        for the OpenAI SDK usage object; this assumes its attribute names match
        the dict keys, which should be confirmed. A missing usage, missing key,
        or None value counts as 0.
        """
        if usage is None:
            return 0
        if hasattr(usage, 'get'):
            value = usage.get(key)
        else:
            value = getattr(usage, key, None)
        return int(value) if value is not None else 0

    def track_usage(self, player_id, conversation_id, model, usage, purpose='conversation'):
        """
        Log API usage to database.

        Args:
            player_id: Player ID
            conversation_id: Conversation ID (optional)
            model: Model name (e.g., 'gpt-4o-mini' or a dated variant)
            usage: Usage dict (or object with matching attributes) holding
                'prompt_tokens', 'completion_tokens', 'total_tokens'.
                Missing or None counts are treated as 0. A missing or zero
                total_tokens is replaced by prompt_tokens + completion_tokens.
            purpose: Purpose of the call (e.g., 'conversation', 'summarization', 'fact_extraction')

        Returns:
            The estimated cost in USD. It is returned even if the insert fails;
            insert errors are printed, not raised.
        """
        # Lazy table creation on first use. The flag is set only after a
        # successful CREATE, so a transient failure is retried on the next call.
        if not self._table_created:
            self._table_created = self.create_table_if_not_exists()

        prompt_tokens = self._usage_count(usage, 'prompt_tokens')
        completion_tokens = self._usage_count(usage, 'completion_tokens')
        total_tokens = (self._usage_count(usage, 'total_tokens')
                        or prompt_tokens + completion_tokens)

        cost = self.calculate_cost(model, prompt_tokens, completion_tokens)

        # Save to database
        mydb = get_database_connection()
        try:
            with mydb.cursor() as cursor:
                cursor.execute("""
                    INSERT INTO api_usage
                    (player_id, conversation_id, endpoint, model,
                     prompt_tokens, completion_tokens, total_tokens, cost_usd,
                     created_date, purpose)
                    VALUES (%s, %s, 'openai_chat', %s, %s, %s, %s, %s, NOW(), %s)
                """, (
                    player_id,
                    conversation_id,
                    model,
                    prompt_tokens,
                    completion_tokens,
                    total_tokens,
                    cost,
                    purpose
                ))
                mydb.commit()

        except Exception as e:
            print(f"Failed to track API usage: {e}")

        finally:
            mydb.close()

        return cost

    def get_player_usage(self, player_id, days=30):
        """
        Get player's API usage statistics for last N days.

        Args:
            player_id: Player ID
            days: Number of days to look back (relative to the database's NOW())

        Returns:
            Dict with total_cost, total_tokens, total_calls, avg_cost_per_call,
            first_call, last_call and period_days. All counts are zero and
            dates are None when the player has no calls in the window.
            Returns None if the query fails.
        """
        mydb = get_database_connection()
        try:
            with mydb.cursor(dictionary=True) as cursor:
                cursor.execute("""
                    SELECT
                        SUM(cost_usd) as total_cost,
                        SUM(total_tokens) as total_tokens,
                        COUNT(*) as total_calls,
                        AVG(cost_usd) as avg_cost_per_call,
                        MIN(created_date) as first_call,
                        MAX(created_date) as last_call
                    FROM api_usage
                    WHERE player_id = %s
                    AND created_date >= DATE_SUB(NOW(), INTERVAL %s DAY)
                """, (player_id, days))

                result = cursor.fetchone()

                if result and result['total_calls']:
                    return {
                        'total_cost': float(result['total_cost'] or 0),
                        'total_tokens': int(result['total_tokens'] or 0),
                        'total_calls': int(result['total_calls'] or 0),
                        'avg_cost_per_call': float(result['avg_cost_per_call'] or 0),
                        'first_call': result['first_call'],
                        'last_call': result['last_call'],
                        'period_days': days
                    }
                return {
                    'total_cost': 0,
                    'total_tokens': 0,
                    'total_calls': 0,
                    'avg_cost_per_call': 0,
                    'first_call': None,
                    'last_call': None,
                    'period_days': days
                }

        except Exception as e:
            print(f"Failed to get player usage: {e}")
            return None

        finally:
            mydb.close()

    def get_usage_by_purpose(self, player_id=None, days=30):
        """
        Get usage breakdown by purpose.

        Args:
            player_id: Optional player ID (a falsy value means all players)
            days: Number of days to look back

        Returns:
            List of dicts (purpose, call_count, total_tokens, total_cost,
            avg_cost), ordered by total_cost descending. 'purpose' may be None
            because the column is nullable. Returns [] if the query fails.
        """
        mydb = get_database_connection()
        try:
            with mydb.cursor(dictionary=True) as cursor:
                if player_id:
                    cursor.execute("""
                        SELECT
                            purpose,
                            COUNT(*) as call_count,
                            SUM(total_tokens) as total_tokens,
                            SUM(cost_usd) as total_cost,
                            AVG(cost_usd) as avg_cost
                        FROM api_usage
                        WHERE player_id = %s
                        AND created_date >= DATE_SUB(NOW(), INTERVAL %s DAY)
                        GROUP BY purpose
                        ORDER BY total_cost DESC
                    """, (player_id, days))
                else:
                    cursor.execute("""
                        SELECT
                            purpose,
                            COUNT(*) as call_count,
                            SUM(total_tokens) as total_tokens,
                            SUM(cost_usd) as total_cost,
                            AVG(cost_usd) as avg_cost
                        FROM api_usage
                        WHERE created_date >= DATE_SUB(NOW(), INTERVAL %s DAY)
                        GROUP BY purpose
                        ORDER BY total_cost DESC
                    """, (days,))

                results = cursor.fetchall()
                return [{
                    'purpose': row['purpose'],
                    'call_count': int(row['call_count'] or 0),
                    'total_tokens': int(row['total_tokens'] or 0),
                    'total_cost': float(row['total_cost'] or 0),
                    'avg_cost': float(row['avg_cost'] or 0)
                } for row in results]

        except Exception as e:
            print(f"Failed to get usage by purpose: {e}")
            return []

        finally:
            mydb.close()

    def get_total_usage(self, days=30):
        """
        Get total usage across all players.

        Args:
            days: Number of days to look back

        Returns:
            Dict with unique_players, total_cost, total_tokens, total_calls,
            avg_cost_per_call and period_days. Values are zero when there is
            no usage in the window. Returns None if the query fails.
        """
        mydb = get_database_connection()
        try:
            with mydb.cursor(dictionary=True) as cursor:
                cursor.execute("""
                    SELECT
                        COUNT(DISTINCT player_id) as unique_players,
                        SUM(cost_usd) as total_cost,
                        SUM(total_tokens) as total_tokens,
                        COUNT(*) as total_calls,
                        AVG(cost_usd) as avg_cost_per_call
                    FROM api_usage
                    WHERE created_date >= DATE_SUB(NOW(), INTERVAL %s DAY)
                """, (days,))

                # SUM/AVG are NULL over an empty window; a missing row is
                # treated the same way so callers always get a complete dict.
                result = cursor.fetchone() or {}
                return {
                    'unique_players': int(result.get('unique_players') or 0),
                    'total_cost': float(result.get('total_cost') or 0),
                    'total_tokens': int(result.get('total_tokens') or 0),
                    'total_calls': int(result.get('total_calls') or 0),
                    'avg_cost_per_call': float(result.get('avg_cost_per_call') or 0),
                    'period_days': days
                }

        except Exception as e:
            print(f"Failed to get total usage: {e}")
            return None

        finally:
            mydb.close()

    def check_player_budget(self, player_id, monthly_limit=5.00):
        """
        Check if player is within budget.

        "Monthly" here means the rolling last 30 days (see get_player_usage).

        Args:
            player_id: Player ID
            monthly_limit: Monthly budget limit in USD

        Returns:
            Dict with total_cost, monthly_limit, remaining, percentage_used,
            over_budget and warning (more than 80% used). Returns None if
            usage could not be fetched.
        """
        usage = self.get_player_usage(player_id, days=30)
        if usage is None:
            return None

        total_cost = usage['total_cost']
        remaining = monthly_limit - total_cost
        # Guard against division by zero for a zero/negative limit.
        percentage_used = (total_cost / monthly_limit) * 100 if monthly_limit > 0 else 0

        return {
            'total_cost': total_cost,
            'monthly_limit': monthly_limit,
            'remaining': remaining,
            'percentage_used': percentage_used,
            'over_budget': total_cost > monthly_limit,
            'warning': percentage_used > 80  # Warning at 80%
        }

    def print_usage_report(self, player_id=None, days=7):
        """
        Print usage report to console.

        Args:
            player_id: Optional player ID (None for total usage)
            days: Number of days to report
        """
        print(f"\n{'='*60}")
        if player_id:
            print(f"API Usage Report for Player {player_id} (Last {days} days)")
            usage = self.get_player_usage(player_id, days)
        else:
            print(f"Total API Usage Report (Last {days} days)")
            usage = self.get_total_usage(days)

        if usage:
            print(f"{'='*60}")
            print(f"Total Cost:      ${usage['total_cost']:.4f}")
            print(f"Total Tokens:    {usage['total_tokens']:,}")
            print(f"Total Calls:     {usage['total_calls']:,}")
            print(f"Avg Cost/Call:   ${usage['avg_cost_per_call']:.4f}")

            if not player_id and 'unique_players' in usage:
                print(f"Unique Players:  {usage['unique_players']}")

            # Show breakdown by purpose
            print(f"\n{'-'*60}")
            print("Usage by Purpose:")
            by_purpose = self.get_usage_by_purpose(player_id, days)
            for item in by_purpose:
                # purpose is a nullable column and format(None, '20s') raises
                # TypeError, so render a missing purpose explicitly.
                label = item['purpose'] if item['purpose'] is not None else '(none)'
                print(f"  {label:20s} | Calls: {item['call_count']:4d} | "
                      f"Cost: ${item['total_cost']:6.4f} | Avg: ${item['avg_cost']:6.4f}")

        print(f"{'='*60}\n")


# Shared module-level tracker. Constructing it performs no database I/O;
# the api_usage table is created lazily by track_usage().
tracker: APIUsageTracker = APIUsageTracker()
