#!/usr/bin/env python
"""
BaoLife AI Image Generation Module

Handles automated image generation using Google Imagen 4 Ultra (primary), FLUX 1.1 Pro, and DALL-E 3 APIs.
Supports caching, queue management, and cozy cartoon style generation.
"""

import os
import asyncio
import json
import logging
from datetime import datetime
from typing import Optional, Dict, List, Tuple
import aiohttp
from config import config
from database.db_operations import get_database_connection

# Optional OpenAI import (only needed for DALL-E 3). When the package is
# missing, AsyncOpenAI is set to None and OPENAI_AVAILABLE to False so the
# rest of the module can test availability instead of failing at import time.
try:
    from openai import AsyncOpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    AsyncOpenAI = None
    OPENAI_AVAILABLE = False

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# ============================================================
# Configuration
# ============================================================

# API keys, read once from the project config object at import time.
OPENAI_API_KEY = config.OPENAI_API_KEY
FAL_AI_KEY = config.FAL_AI_KEY
# NOTE(review): REPLICATE_API_TOKEN is not referenced anywhere else in the
# visible part of this file - confirm whether it is still needed.
REPLICATE_API_TOKEN = config.REPLICATE_API_TOKEN

# Default provider identifier, taken from config; used as the default value of
# the `provider` parameters below.
DEFAULT_PROVIDER = config.IMAGE_GENERATION_PROVIDER

# Style preset for cozy cartoon style (tested and optimized for Imagen 4 Ultra).
# ImageGenerator._apply_style appends this text to the caller's prompt.
# NOTE(review): the trailing "--no ..." line looks like a negative-prompt
# directive; confirm that the configured providers actually honor that syntax.
COZY_CARTOON_STYLE = """cozy cartoon style, illustration, warm colors,
friendly atmosphere, high quality digital art, clean lines,
--no words, text, letters, face"""


# ============================================================
# Image Generation Providers
# ============================================================

class ImageGenerator:
    """Generate images through one of the supported providers.

    Recognized provider identifiers are ``'dalle3'`` (OpenAI DALL-E 3),
    ``'flux'`` (FLUX 1.1 Pro via fal.ai) and ``'imagen4'`` (Google Imagen 4
    Ultra via fal.ai). Generation methods report their outcome as an
    ``(image_url, error_message)`` tuple: on success the error is ``None``,
    on failure the URL is ``None``. Provider errors are returned through that
    tuple rather than raised.
    """

    # fal.ai synchronous endpoints used by the FLUX and Imagen 4 providers.
    _FAL_FLUX_ENDPOINT = "https://fal.run/fal-ai/flux-pro/v1.1"
    _FAL_IMAGEN4_ENDPOINT = "https://fal.run/fal-ai/imagen4/preview/ultra"

    def __init__(self, provider: str = DEFAULT_PROVIDER):
        """Remember the provider and, when possible, create the OpenAI client.

        Args:
            provider: Provider identifier used by :meth:`generate_image`.
        """
        self.provider = provider
        # Only initialize OpenAI client if library is available and key is set
        if OPENAI_AVAILABLE and OPENAI_API_KEY:
            self.openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
        else:
            self.openai_client = None

    async def generate_image(
        self,
        prompt: str,
        style: str = 'cozy_cartoon',
        width: int = 1024,
        height: int = 1024
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Generate an image using the provider chosen at construction time.

        Args:
            prompt: Text description of the desired image.
            style: Style preset name; see :meth:`_apply_style`.
            width: Requested width in pixels.
            height: Requested height in pixels.

        Returns:
            Tuple[image_url, error_message]. An unrecognized provider yields
            ``(None, "Unknown provider: ...")``.
        """
        if self.provider == 'dalle3':
            return await self._generate_dalle3(prompt, style, width, height)
        elif self.provider == 'flux':
            return await self._generate_flux(prompt, style, width, height)
        elif self.provider == 'imagen4':
            return await self._generate_imagen4(prompt, style, width, height)
        else:
            return None, f"Unknown provider: {self.provider}"

    async def _generate_dalle3(
        self,
        prompt: str,
        style: str,
        width: int,
        height: int
    ) -> Tuple[Optional[str], Optional[str]]:
        """Generate image using OpenAI DALL-E 3.

        Returns an error tuple without calling the API when the ``openai``
        package is not installed or no client was created in ``__init__``.
        """
        if not OPENAI_AVAILABLE:
            return None, "OpenAI library not installed. Install with: pip install --upgrade openai"
        if not self.openai_client:
            return None, "OpenAI API key not configured"

        try:
            # Apply cozy cartoon style to prompt
            full_prompt = self._apply_style(prompt, style)

            # DALL-E 3 supports 1024x1024, 1024x1792, 1792x1024; any other
            # requested dimensions fall back to the square size.
            size = "1024x1024"
            if width == 1024 and height == 1792:
                size = "1024x1792"
            elif width == 1792 and height == 1024:
                size = "1792x1024"

            logger.info(f"Generating DALL-E 3 image: {full_prompt[:100]}...")

            response = await self.openai_client.images.generate(
                model="dall-e-3",
                prompt=full_prompt,
                size=size,
                quality="standard",  # or "hd" for higher quality ($0.08 vs $0.04)
                n=1,
            )

            image_url = response.data[0].url
            logger.info(f"DALL-E 3 image generated: {image_url}")
            return image_url, None

        except Exception as e:
            # Boundary handler: any client/API failure becomes an error tuple.
            error_msg = f"DALL-E 3 generation failed: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    async def _request_fal_image(
        self,
        endpoint: str,
        payload: Dict,
        api_name: str
    ) -> Tuple[Optional[str], Optional[str]]:
        """POST ``payload`` to a fal.ai endpoint and pull out the first image URL.

        Shared by the FLUX and Imagen 4 providers, which differ only in
        endpoint, payload and the name used in messages.

        Args:
            endpoint: Full fal.ai URL to post to.
            payload: JSON-serializable request body.
            api_name: Provider name embedded in the returned error messages.

        Returns:
            Tuple[image_url, error_message].

        Raises:
            Exception: Network, decoding or response-shape errors propagate;
                callers translate them into their own error tuple.
        """
        headers = {
            "Authorization": f"Key {FAL_AI_KEY}",
            "Content-Type": "application/json"
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(endpoint, headers=headers, json=payload) as response:
                if response.status != 200:
                    error_text = await response.text()
                    return None, f"{api_name} API error {response.status}: {error_text}"

                result = await response.json()

        # Treat a missing, null or empty "images" entry (or a non-object
        # response) uniformly as "no image" instead of letting it raise.
        images = result.get('images') if isinstance(result, dict) else None
        if not images:
            return None, f"No image returned from {api_name} API"

        return images[0]['url'], None

    async def _generate_flux(
        self,
        prompt: str,
        style: str,
        width: int,
        height: int
    ) -> Tuple[Optional[str], Optional[str]]:
        """Generate image using FLUX 1.1 Pro via fal.ai.

        The requested ``width`` and ``height`` are forwarded to the API as an
        explicit ``image_size`` object.
        """
        if not FAL_AI_KEY:
            return None, "fal.ai API key not configured"

        try:
            # Apply cozy cartoon style to prompt
            full_prompt = self._apply_style(prompt, style)

            logger.info(f"Generating FLUX 1.1 Pro image: {full_prompt[:100]}...")

            payload = {
                "prompt": full_prompt,
                "image_size": {
                    "width": width,
                    "height": height
                },
                "num_inference_steps": 28,  # Default for quality
                "guidance_scale": 3.5,
                "num_images": 1,
                "enable_safety_checker": True,
                "output_format": "jpeg",
                "safety_tolerance": "2"
            }

            image_url, error = await self._request_fal_image(
                self._FAL_FLUX_ENDPOINT, payload, "FLUX"
            )
            if error is None:
                logger.info(f"FLUX image generated: {image_url}")
            return image_url, error

        except Exception as e:
            # Boundary handler: any transport/parsing failure becomes an error tuple.
            error_msg = f"FLUX generation failed: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    async def _generate_imagen4(
        self,
        prompt: str,
        style: str,
        width: int,
        height: int
    ) -> Tuple[Optional[str], Optional[str]]:
        """Generate image using Google Imagen 4 Ultra via fal.ai.

        ``width`` and ``height`` are only used to choose a named aspect preset
        (square, landscape or portrait); exact pixel sizes are not sent.
        """
        if not FAL_AI_KEY:
            return None, "fal.ai API key not configured"

        try:
            # Apply cozy cartoon style to prompt
            full_prompt = self._apply_style(prompt, style)

            logger.info(f"Generating Imagen 4 Ultra image: {full_prompt[:100]}...")

            # Map the requested dimensions onto one of the named size presets.
            if width == height:
                image_size = "square"
            elif width > height:
                image_size = "landscape_16_9"
            else:
                image_size = "portrait_3_4"

            # NOTE(review): these parameter names mirror the FLUX request;
            # confirm against the fal.ai Imagen 4 schema that the endpoint
            # accepts image_size / enable_safety_checker / safety_tolerance.
            payload = {
                "prompt": full_prompt,
                "image_size": image_size,
                "num_images": 1,
                "enable_safety_checker": False,
                "safety_tolerance": "6"
            }

            image_url, error = await self._request_fal_image(
                self._FAL_IMAGEN4_ENDPOINT, payload, "Imagen 4"
            )
            if error is None:
                logger.info(f"Imagen 4 Ultra image generated: {image_url}")
            return image_url, error

        except Exception as e:
            # Boundary handler: any transport/parsing failure becomes an error tuple.
            error_msg = f"Imagen 4 generation failed: {str(e)}"
            logger.error(error_msg)
            return None, error_msg

    def _apply_style(self, prompt: str, style: str) -> str:
        """Apply style preset to prompt.

        Only ``'cozy_cartoon'`` is recognized; any other style returns the
        prompt unchanged.
        """
        if style == 'cozy_cartoon':
            return f"{prompt}, {COZY_CARTOON_STYLE}"
        return prompt


# ============================================================
# Database Operations
# ============================================================

async def save_generated_image(
    image_url: str,
    prompt: str,
    provider: str,
    event_type: Optional[str] = None,
    event_category: Optional[str] = None,
    style_preset: str = 'cozy_cartoon',
    width: int = 1024,
    height: int = 1024,
    cost: float = 0.04,
    tags: Optional[List[str]] = None
) -> Optional[int]:
    """Insert one row describing a generated image into ``generated_images``.

    ``tags`` is stored as a JSON string; a missing or empty list is stored as
    NULL.

    Returns:
        The cursor's ``lastrowid`` after the insert is committed, or ``None``
        if anything went wrong (the error is logged and the transaction is
        rolled back).
    """
    conn = None
    cur = None
    try:
        conn = get_database_connection()
        cur = conn.cursor()

        serialized_tags = None if not tags else json.dumps(tags)
        row_values = (
            image_url, prompt, style_preset, event_type, event_category,
            provider, cost, width, height, serialized_tags,
        )

        insert_sql = """
            INSERT INTO generated_images
            (image_url, prompt, style_preset, event_type, event_category,
             provider, generation_cost, image_width, image_height, tags)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        cur.execute(insert_sql, row_values)
        conn.commit()

        new_id = cur.lastrowid
        logger.info(f"Saved generated image {new_id}: {image_url}")
        return new_id

    except Exception as exc:
        logger.error(f"Failed to save generated image: {exc}")
        # Undo the partial transaction if a connection was obtained at all.
        if conn:
            conn.rollback()
        return None
    finally:
        # Release the cursor first, then the connection, on every path.
        if cur:
            cur.close()
        if conn:
            conn.close()


async def get_cached_image(
    event_type: Optional[str] = None,
    event_category: Optional[str] = None,
    prompt_keywords: Optional[str] = None
) -> Optional[Dict]:
    """
    Retrieve the most recently created active image matching the filters.

    Each filter that is provided narrows the query; filters left as ``None``
    (or empty) are ignored. With no filters at all the newest active image of
    any kind is returned.

    Args:
        event_type: Exact event type match
        event_category: Event category match
        prompt_keywords: Keywords to search in prompt (natural-language
            full-text match against the ``prompt`` column)

    Returns:
        Dict with image data or None when nothing matches or the lookup
        fails (failures are logged, not raised).
    """
    db = None
    cursor = None
    try:
        db = get_database_connection()
        cursor = db.cursor(dictionary=True)

        query = "SELECT * FROM generated_images WHERE is_active = TRUE"
        params = []

        if event_type:
            query += " AND event_type = %s"
            params.append(event_type)

        if event_category:
            query += " AND event_category = %s"
            params.append(event_category)

        if prompt_keywords:
            query += " AND MATCH(prompt) AGAINST(%s IN NATURAL LANGUAGE MODE)"
            params.append(prompt_keywords)

        query += " ORDER BY created_at DESC LIMIT 1"

        cursor.execute(query, params)
        result = cursor.fetchone()

        if result:
            logger.info(f"Found cached image: {result['id']}")

        return result

    except Exception as e:
        logger.error(f"Failed to retrieve cached image: {e}")
        return None
    finally:
        # Always release the cursor and connection so lookups do not leak them.
        if cursor:
            cursor.close()
        if db:
            db.close()


async def add_to_generation_queue(
    prompt: str,
    event_type: Optional[str] = None,
    event_category: Optional[str] = None,
    style_preset: str = 'cozy_cartoon',
    priority: int = 5,
    provider: str = DEFAULT_PROVIDER
) -> Optional[int]:
    """Add image generation request to queue.

    Inserts one row into ``image_generation_queue`` and commits it.

    Args:
        prompt: Text description of the image to generate.
        event_type: Optional event type stored with the request.
        event_category: Optional event category stored with the request.
        style_preset: Style preset name stored with the request.
        priority: Queue priority; the queue processor orders by this value
            descending.
        provider: Provider identifier to use when the item is processed.

    Returns:
        The cursor's ``lastrowid`` for the new queue row, or ``None`` if the
        insert failed (the error is logged and the transaction rolled back).
    """
    db = None
    cursor = None
    try:
        db = get_database_connection()
        cursor = db.cursor()

        cursor.execute("""
            INSERT INTO image_generation_queue
            (prompt, style_preset, event_type, event_category, priority, provider)
            VALUES (%s, %s, %s, %s, %s, %s)
        """, (prompt, style_preset, event_type, event_category, priority, provider))

        db.commit()
        queue_id = cursor.lastrowid
        logger.info(f"Added to generation queue: {queue_id}")
        return queue_id

    except Exception as e:
        logger.error(f"Failed to add to generation queue: {e}")
        # Mirror save_generated_image: discard the partial transaction.
        if db:
            db.rollback()
        return None
    finally:
        # Always release the cursor and connection so enqueues do not leak them.
        if cursor:
            cursor.close()
        if db:
            db.close()


async def process_generation_queue(batch_size: int = 10):
    """Process pending image generation requests from queue.

    Selects up to ``batch_size`` rows with status ``'pending'`` whose
    ``attempts`` is below ``max_attempts``, highest priority first and oldest
    first within a priority, then processes them one at a time.

    Args:
        batch_size: Maximum number of queue rows to process in this call.

    Returns:
        None. Failures are logged rather than raised.
    """
    try:
        db = None
        cursor = None
        try:
            db = get_database_connection()
            cursor = db.cursor(dictionary=True)

            # Get pending items ordered by priority
            cursor.execute("""
            SELECT * FROM image_generation_queue
            WHERE status = 'pending' AND attempts < max_attempts
            ORDER BY priority DESC, created_at ASC
            LIMIT %s
        """, (batch_size,))

            items = cursor.fetchall()
        finally:
            # Release the read connection before the (potentially slow) image
            # generation starts; each item opens its own connection.
            if cursor:
                cursor.close()
            if db:
                db.close()

        if not items:
            logger.info("No pending images in queue")
            return

        logger.info(f"Processing {len(items)} images from queue")

        for item in items:
            await process_queue_item(item)

    except Exception as e:
        logger.error(f"Failed to process generation queue: {e}")


async def process_queue_item(item: Dict):
    """Process a single queue item.

    The row is first marked ``'processing'`` (incrementing ``attempts``), then
    an image is generated with the item's provider. The row ends up
    ``'completed'`` with the saved image id, or ``'failed'`` with an error
    message. An image that was generated but could not be saved counts as a
    failure, so a completed row always references a stored image.

    Args:
        item: Queue row as a dict; the keys ``id``, ``provider``, ``prompt``,
            ``style_preset``, ``event_type`` and ``event_category`` are read.

    Returns:
        None. Database errors are logged rather than raised.
    """
    db = None
    cursor = None
    try:
        db = get_database_connection()
        cursor = db.cursor()

        # Update status to processing
        cursor.execute("""
            UPDATE image_generation_queue
            SET status = 'processing', attempts = attempts + 1, updated_at = NOW()
            WHERE id = %s
        """, (item['id'],))
        db.commit()

        # Generate image
        generator = ImageGenerator(provider=item['provider'])
        image_url, error = await generator.generate_image(
            prompt=item['prompt'],
            style=item['style_preset']
        )

        image_id = None
        if image_url:
            # Save to generated_images
            image_id = await save_generated_image(
                image_url=image_url,
                prompt=item['prompt'],
                provider=item['provider'],
                event_type=item['event_type'],
                event_category=item['event_category'],
                style_preset=item['style_preset']
            )
            if image_id is None:
                # Keep the URL in the error so the generated image is not lost.
                error = f"Image generated but could not be saved: {image_url}"

        if image_id is not None:
            # Update queue item
            cursor.execute("""
                UPDATE image_generation_queue
                SET status = 'completed', generated_image_id = %s,
                    completed_at = NOW(), updated_at = NOW()
                WHERE id = %s
            """, (image_id, item['id']))
            db.commit()

            logger.info(f"Queue item {item['id']} completed successfully")
        else:
            # Mark as failed
            cursor.execute("""
                UPDATE image_generation_queue
                SET status = 'failed', error_message = %s, updated_at = NOW()
                WHERE id = %s
            """, (error, item['id']))
            db.commit()

            logger.error(f"Queue item {item['id']} failed: {error}")

    except Exception as e:
        logger.error(f"Failed to process queue item {item['id']}: {e}")
    finally:
        # Always release the cursor and connection so processed items do not leak them.
        if cursor:
            cursor.close()
        if db:
            db.close()


# ============================================================
# Helper Functions
# ============================================================

async def generate_and_cache_image(
    prompt: str,
    event_type: Optional[str] = None,
    event_category: Optional[str] = None,
    style: str = 'cozy_cartoon',
    provider: str = DEFAULT_PROVIDER,
    use_cache: bool = True
) -> Optional[str]:
    """
    Generate an image and cache it, or return cached version if available.

    The cache is keyed on ``event_type`` / ``event_category`` only (not on the
    prompt). It is therefore consulted only when at least one of them is
    given; otherwise a new image is always generated.

    Args:
        prompt: Text description of the desired image.
        event_type: Optional event type used for cache lookup and storage.
        event_category: Optional event category used for cache lookup and storage.
        style: Style preset name passed to the generator.
        provider: Provider identifier passed to ``ImageGenerator``.
        use_cache: When False, skip the cache lookup entirely.

    Returns:
        Image URL or None
    """
    # Check cache first. An unfiltered lookup would return the newest image of
    # any kind, unrelated to this prompt, so require at least one filter.
    if use_cache and (event_type or event_category):
        cached = await get_cached_image(
            event_type=event_type,
            event_category=event_category
        )
        if cached:
            logger.info(f"Using cached image: {cached['image_url']}")
            return cached['image_url']

    # Generate new image
    generator = ImageGenerator(provider=provider)
    image_url, error = await generator.generate_image(prompt, style)

    if not image_url:
        logger.error(f"Image generation failed: {error}")
        return None

    # Save to cache. The URL is returned even if saving fails, because
    # save_generated_image logs its own errors and returns None.
    await save_generated_image(
        image_url=image_url,
        prompt=prompt,
        provider=provider,
        event_type=event_type,
        event_category=event_category,
        style_preset=style
    )
    return image_url


# ============================================================
# Event Image Helpers
# ============================================================

def get_event_image_url(event_type: str, event_category: Optional[str] = None) -> Optional[str]:
    """
    Get image URL for an event (synchronous for use in event functions).

    Looks up ``event_images`` joined to active ``generated_images`` rows for
    the given event type, preferring primary images and then the lowest
    display order.

    Args:
        event_type: Event type to look up.
        event_category: Currently unused by the query; kept so existing
            callers that pass it keep working.

    Returns:
        The image URL, or None when no image is linked or the lookup fails
        (failures are logged, not raised).
    """
    db = None
    cursor = None
    try:
        db = get_database_connection()
        cursor = db.cursor(dictionary=True)

        cursor.execute("""
            SELECT gi.image_url
            FROM event_images ei
            JOIN generated_images gi ON ei.generated_image_id = gi.id
            WHERE ei.event_type = %s AND gi.is_active = TRUE
            ORDER BY ei.is_primary DESC, ei.display_order ASC
            LIMIT 1
        """, (event_type,))

        result = cursor.fetchone()
        return result['image_url'] if result else None

    except Exception as e:
        logger.error(f"Failed to get event image: {e}")
        return None
    finally:
        # Always release the cursor and connection so lookups do not leak them.
        if cursor:
            cursor.close()
        if db:
            db.close()


# ============================================================
# Testing
# ============================================================

async def test_generation():
    """Run a manual smoke test: generate (or fetch from cache) three sample images.

    Each sample is passed to ``generate_and_cache_image`` with its prompt,
    event type and event category; the outcome is logged and the loop pauses
    two seconds after every sample.
    """
    samples = [
        {
            "prompt": "large contemporary architecture high school building exterior",
            "event_type": "school_exterior",
            "event_category": "education"
        },
        {
            "prompt": "cozy bedroom with desk and computer, teenage room",
            "event_type": "bedroom",
            "event_category": "home"
        },
        {
            "prompt": "happy family having dinner together at dining table",
            "event_type": "family_dinner",
            "event_category": "family"
        }
    ]

    for sample in samples:
        description = sample['prompt']
        logger.info(f"Testing: {description}")

        # The dict keys match the keyword parameters one-to-one.
        generated_url = await generate_and_cache_image(**sample)

        if not generated_url:
            logger.error(f"✗ Failed: {description}")
        else:
            logger.info(f"✓ Generated: {generated_url}")

        # Pause between requests to avoid provider rate limiting
        await asyncio.sleep(2)


if __name__ == "__main__":
    # When executed as a script: enable INFO-level logging so the test's
    # progress messages are visible, then run the sample generation once.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(test_generation())
