"""
Rate limiting for expensive operations.

Uses a sliding-window log (per-identifier request timestamps) to prevent API abuse.
"""

from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import Dict

from config import config


class RateLimiter:
    """
    Sliding-window rate limiter.

    For each identifier it keeps the timestamps of requests accepted within
    the last ``window_seconds`` and rejects a new request while
    ``max_requests`` of them are still inside the window.

    Usage:
        limiter = RateLimiter(max_requests=60, window_seconds=3600)
        if limiter.is_allowed('user_123'):
            ...  # Allow request
        else:
            ...  # Reject (rate limited)
    """

    def __init__(self, max_requests: int, window_seconds: int):
        """
        Initialize rate limiter.

        Args:
            max_requests: Maximum requests allowed in window
            window_seconds: Time window in seconds
        """
        self.max_requests = max_requests
        self.window = timedelta(seconds=window_seconds)
        # identifier -> timestamps (aware UTC) of accepted requests, in the
        # order they were accepted. Identifiers with no live timestamps are
        # removed so this mapping does not grow without bound.
        self.requests: Dict[str, list] = defaultdict(list)

    @staticmethod
    def _now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime.

        Naive local time (``datetime.now()``) jumps by an hour at DST
        transitions, which would make recorded requests expire early or
        linger well past the window.
        """
        return datetime.now(timezone.utc)

    def _prune(self, identifier: str, now: datetime) -> list:
        """Drop timestamps outside the window and return the live ones.

        The entry for ``identifier`` is removed when nothing is live, and no
        entry is ever created for an unknown identifier (``dict.get`` does not
        trigger the defaultdict factory).
        """
        cutoff = now - self.window
        live = [t for t in self.requests.get(identifier, ()) if t > cutoff]
        if live:
            self.requests[identifier] = live
        else:
            self.requests.pop(identifier, None)
        return live

    def is_allowed(self, identifier: str) -> bool:
        """
        Check if request is allowed, recording it when it is.

        Args:
            identifier: User ID or other identifier

        Returns:
            True if allowed, False if rate limited
        """
        now = self._now()
        live = self._prune(identifier, now)

        # Check limit
        if len(live) >= self.max_requests:
            print(f"Rate limit exceeded for {identifier}")
            return False

        # Record request (re-store: _prune may have removed an empty entry)
        live.append(now)
        self.requests[identifier] = live
        return True

    def get_remaining(self, identifier: str) -> int:
        """
        Get remaining requests in current window without consuming one.

        Args:
            identifier: User ID or other identifier

        Returns:
            Number of requests remaining
        """
        live = self._prune(identifier, self._now())
        return max(0, self.max_requests - len(live))

    def reset(self, identifier: str):
        """Reset rate limit for identifier (no-op if it has no history)."""
        self.requests.pop(identifier, None)


# Module-level limiter instances used by the check_* helpers below.
openai_limiter = RateLimiter(
    window_seconds=60 * 60,  # one hour
    max_requests=config.OPENAI_MAX_REQUESTS_PER_HOUR,
)

websocket_limiter = RateLimiter(
    window_seconds=60,  # one minute
    max_requests=config.WEBSOCKET_MAX_MESSAGES_PER_MINUTE,
)


def check_openai_rate_limit(user_id: str) -> bool:
    """
    Decide whether ``user_id`` may issue another OpenAI request.

    Delegates to the module-level ``openai_limiter``; an allowed call is
    recorded against the user's quota.

    Args:
        user_id: User ID

    Returns:
        True when the request may proceed, False when the user is throttled
    """
    if openai_limiter.is_allowed(user_id):
        return True

    left = openai_limiter.get_remaining(user_id)
    print(f"OpenAI rate limit exceeded for {user_id}. Remaining: {left}")
    return False


def check_websocket_rate_limit(user_id: str) -> bool:
    """
    Decide whether ``user_id`` may send another WebSocket message.

    Delegates to the module-level ``websocket_limiter``; an allowed message
    is recorded against the user's quota.

    Args:
        user_id: User ID

    Returns:
        True when the message may be sent, False when the user is throttled
    """
    if websocket_limiter.is_allowed(user_id):
        return True

    print(f"WebSocket rate limit exceeded for {user_id}")
    return False
