#!/usr/bin/env python
"""
Batch Messaging System
Queues and batches WebSocket messages to reduce network overhead.
"""

import asyncio
import json
import time
from typing import Dict, List, Any
from collections import defaultdict
import logging

logger = logging.getLogger(__name__)


class MessageBatcher:
    """
    Batches WebSocket messages for efficient transmission.

    Messages are collected per player and released as a batch when the
    player's queue reaches ``batch_size``, when a flush is requested
    explicitly, or periodically by the auto-flush background task.

    A flushed batch is never silently discarded: it is returned to the caller
    and, when a ``send_callback`` is configured, also handed to that callback.
    """

    def __init__(self, batch_size=10, flush_interval=0.1, send_callback=None):
        """
        Initialize the message batcher.

        Args:
            batch_size: Maximum messages per batch
            flush_interval: Maximum seconds to wait before flushing (default: 100ms)
            send_callback: Optional async callable ``(player_id, messages)``
                that receives batches flushed automatically (batch size
                reached, or the auto-flush task). May also be assigned later
                through the ``send_callback`` attribute.
        """
        self.batch_size = batch_size
        self.flush_interval = flush_interval
        self.send_callback = send_callback
        self._queues = defaultdict(list)
        self._last_flush = defaultdict(float)
        self._lock = asyncio.Lock()
        self._flush_task = None

    async def queue_message(self, player_id: str, message: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Queue a message for batched delivery.

        Args:
            player_id: Target player's ID
            message: Message dictionary to send

        Returns:
            The batch that was flushed because this message filled the queue
            to ``batch_size``, or an empty list if nothing was flushed. If a
            ``send_callback`` is configured the batch has already been passed
            to it; otherwise the caller is responsible for sending it.
        """
        async with self._lock:
            queue = self._queues[player_id]
            queue.append({
                'timestamp': time.time(),
                'data': message
            })

            if len(queue) < self.batch_size:
                return []

            # Auto-flush: batch size reached
            batch = await self._flush_player_queue(player_id)

        # Deliver outside the lock so a slow send does not block other
        # producers from queueing.
        await self._deliver(player_id, batch)
        return batch

    async def _deliver(self, player_id, messages):
        """
        Hand a flushed batch to ``send_callback`` if one is configured.

        Delivery is best-effort: a failing callback is logged and does not
        propagate, so one bad connection cannot kill the auto-flush task.
        """
        if self.send_callback is None or not messages:
            return
        try:
            await self.send_callback(player_id, messages)
        except asyncio.CancelledError:
            # Never swallow task cancellation (it is an Exception subclass
            # on older Python versions).
            raise
        except Exception:
            logger.exception(
                "Failed to deliver %d batched messages for player %s",
                len(messages), player_id,
            )

    async def _flush_player_queue(self, player_id: str) -> List[Dict[str, Any]]:
        """
        Flush queued messages for a specific player.

        Callers inside this class hold ``self._lock`` while calling this.

        Args:
            player_id: Player's ID

        Returns:
            List of batched messages (empty if nothing was queued)
        """
        # Remove the entry instead of resetting it to [] so that players
        # with nothing pending do not linger in ``_queues``.
        messages = self._queues.pop(player_id, None)
        if not messages:
            return []

        self._last_flush[player_id] = time.time()

        logger.debug("Flushed %d messages for player %s", len(messages), player_id)
        return messages

    async def flush_all(self) -> Dict[str, List[Dict[str, Any]]]:
        """
        Flush all queued messages for all players.

        Returns:
            Dictionary mapping player_id to their batched messages
        """
        async with self._lock:
            batched = {}
            # Snapshot the keys: flushing removes entries from ``_queues``.
            for player_id in list(self._queues):
                messages = await self._flush_player_queue(player_id)
                if messages:
                    batched[player_id] = messages
            return batched

    async def flush_if_needed(self, player_id: str) -> List[Dict[str, Any]]:
        """
        Flush messages if flush interval has elapsed.

        Args:
            player_id: Player's ID

        Returns:
            List of flushed messages (empty if not flushed)
        """
        async with self._lock:
            # ``.get`` avoids creating a defaultdict entry for unknown
            # players; a never-flushed player counts as last flushed at 0.
            last_flush = self._last_flush.get(player_id, 0)
            if time.time() - last_flush >= self.flush_interval:
                return await self._flush_player_queue(player_id)
            return []

    async def start_auto_flush(self, send_callback=None):
        """
        Start background task to auto-flush messages at intervals.

        Every ``flush_interval`` seconds the task flushes all queues and
        passes each player's batch to ``send_callback``. While no callback is
        configured the task leaves the queues untouched, so messages stay
        available to explicit flushes instead of being dropped.

        Args:
            send_callback: Optional async callable ``(player_id, messages)``;
                when given it replaces the callback set on the instance.
        """
        if send_callback is not None:
            self.send_callback = send_callback

        async def auto_flush_loop():
            while True:
                await asyncio.sleep(self.flush_interval)
                if self.send_callback is None:
                    # Nowhere to deliver: draining now would lose messages.
                    continue
                batched = await self.flush_all()
                for player_id, messages in batched.items():
                    await self._deliver(player_id, messages)
                if batched:
                    logger.debug("Auto-flushed messages for %d players", len(batched))

        if not self._flush_task or self._flush_task.done():
            if self.send_callback is None:
                logger.warning(
                    "Auto-flush started without a send_callback; queued "
                    "messages are kept until a callback is set"
                )
            self._flush_task = asyncio.create_task(auto_flush_loop())
            logger.info("Started auto-flush background task")

    def stop_auto_flush(self):
        """Stop the auto-flush background task."""
        if self._flush_task and not self._flush_task.done():
            self._flush_task.cancel()
            logger.info("Stopped auto-flush background task")
        # A cancelled task is not ``done()`` until the event loop processes
        # the cancellation; drop the handle so an immediate
        # ``start_auto_flush`` creates a fresh task.
        self._flush_task = None

    def get_stats(self) -> Dict[str, Any]:
        """
        Get batcher statistics.

        Returns:
            Dictionary with queue stats: ``active_queues`` (players with
            pending messages), ``total_queued``, ``batch_size`` and
            ``flush_interval``
        """
        return {
            'active_queues': len(self._queues),
            'total_queued': sum(len(q) for q in self._queues.values()),
            'batch_size': self.batch_size,
            'flush_interval': self.flush_interval
        }


# Global batcher instance used by the module-level helpers below
# (send_batched_message / flush_player_messages): up to 10 messages per
# batch, 0.1 s flush interval.
message_batcher = MessageBatcher(batch_size=10, flush_interval=0.1)


async def send_batched_message(player_id: str, message_type: str, data: Dict[str, Any]):
    """
    Wrap a payload in a typed, timestamped envelope and queue it on the
    global batcher for the given player.

    Args:
        player_id: Target player's ID
        message_type: Type of message (e.g., 'stat_update', 'event')
        data: Message payload
    """
    envelope = dict(type=message_type, data=data, timestamp=time.time())
    await message_batcher.queue_message(player_id, envelope)


async def flush_player_messages(player_id: str, send_function):
    """
    Drain the global batcher's queue for one player and send the result as
    a single JSON-encoded 'batch' message. Nothing is sent when the queue
    is empty.

    Args:
        player_id: Player's ID
        send_function: Async function to send messages (websocket.send)
    """
    pending = await message_batcher._flush_player_queue(player_id)
    if not pending:
        return

    # One envelope carrying every pending message plus its count
    payload = json.dumps({
        'type': 'batch',
        'messages': pending,
        'count': len(pending),
    })
    await send_function(payload)
