"""
Integration tests for async database operations (database_async.py).

Tests pool management, query execution, fetch operations, transactions,
and error handling with a real test database.

Run with: pytest tests/integration/test_database_async.py -v
"""

import pytest
import asyncio
import aiomysql
from unittest.mock import patch, AsyncMock, MagicMock
import time

# Import the module under test
from ws.database_async import (
    initialize_pool,
    close_pool,
    get_connection,
    execute_query,
    execute_many,
    fetch_one,
    fetch_all,
    fetch_dict_one,
    fetch_dict_all,
    Transaction,
    _pool,
    _pool_lock
)
from ws.config import config


# ============================================================================
# POOL MANAGEMENT TESTS (6 tests)
# ============================================================================

@pytest.mark.asyncio
async def test_initialize_pool_creates_pool(cleanup_pool):
    """initialize_pool() should return an aiomysql.Pool with minsize 1 and maxsize pool_size."""
    created = await initialize_pool(pool_size=5)

    assert created is not None
    assert isinstance(created, aiomysql.Pool)
    # Lower bound is fixed at 1; upper bound follows the requested size.
    assert (created.minsize, created.maxsize) == (1, 5)

    await close_pool()


@pytest.mark.asyncio
async def test_initialize_pool_connects_to_database(cleanup_pool):
    """Connections handed out by the pool should be bound to config.DB_NAME."""
    pool = await initialize_pool(pool_size=3)

    async with pool.acquire() as connection, connection.cursor() as cur:
        await cur.execute("SELECT DATABASE()")
        row = await cur.fetchone()
        # The active schema must be the configured (test) database.
        assert row[0] == config.DB_NAME

    await close_pool()


@pytest.mark.asyncio
async def test_close_pool_closes_all_connections(cleanup_pool):
    """close_pool() should shut the pool down and reset the module-level handle to None."""
    import ws.database_async as db_async

    pool = await initialize_pool(pool_size=3)

    # Check out several connections so the pool actually opens them...
    checked_out = [await pool.acquire() for _ in range(3)]

    # ...then hand every one of them back before closing.
    for connection in checked_out:
        await pool.release(connection)

    await close_pool()

    # Read through the module so we observe the live global, not a stale import.
    assert db_async._pool is None


@pytest.mark.asyncio
async def test_pool_size_respected(db_pool):
    """Test that maximum pool size is respected.

    Checks out ``maxsize`` connections and verifies the pool reports itself as
    fully used with no free connections left. On ``aiomysql.Pool``, ``size``
    and ``freesize`` are read-only properties, not methods, so they are read
    without calling them.
    """
    max_size = db_pool.maxsize

    connections = []
    try:
        # Acquire max_size connections
        for _ in range(max_size):
            connections.append(await db_pool.acquire())

        # Pool should be exhausted: everything is checked out, nothing is free.
        assert db_pool.size == max_size
        assert db_pool.freesize == 0
    finally:
        # Always return connections, even when an assertion fails, so the
        # pool is not left drained (which could starve later tests if the
        # db_pool fixture is shared).
        for conn in connections:
            await db_pool.release(conn)


@pytest.mark.asyncio
async def test_pool_connection_reuse(db_pool):
    """A connection returned to the pool should be handed out again on the next acquire."""
    first = await db_pool.acquire()
    await db_pool.release(first)

    second = await db_pool.acquire()
    await db_pool.release(second)

    # Identity check: the very same connection object comes back.
    # NOTE(review): this presumably assumes exactly one idle connection is in
    # the pool at this point — confirm against the db_pool fixture's scope.
    assert second is first


@pytest.mark.asyncio
async def test_pool_exhaustion_waits(db_pool):
    """Test that acquire() waits for a free connection when the pool is full.

    All ``maxsize`` connections are checked out. A background task returns
    one of them after ``release_delay`` seconds, and the extra ``acquire()``
    must block until then.

    - The wait is measured with ``time.monotonic()``, which is immune to
      wall-clock adjustments.
    - It is compared with a small tolerance, because event-loop timers may
      fire marginally before their nominal deadline.
    - The extra ``acquire()`` is bounded by ``asyncio.wait_for`` so a
      regression fails fast instead of hanging the suite.
    - Cleanup runs in ``finally`` so a failed assertion never leaks
      checked-out connections.
    """
    release_delay = 0.1
    max_size = db_pool.maxsize

    # Acquire all connections
    connections = [await db_pool.acquire() for _ in range(max_size)]

    async def delayed_release():
        await asyncio.sleep(release_delay)
        await db_pool.release(connections[0])

    release_task = asyncio.create_task(delayed_release())
    conn = None
    try:
        # This should wait until a connection is released
        start_time = time.monotonic()
        conn = await asyncio.wait_for(db_pool.acquire(), timeout=5)
        wait_time = time.monotonic() - start_time

        # Should have waited (allowing a little slack for timer granularity)
        assert wait_time >= release_delay * 0.9
        assert conn is not None
    finally:
        # Ensure connections[0] has been handed back before releasing the rest.
        await release_task
        if conn is not None:
            await db_pool.release(conn)
        for c in connections[1:]:
            await db_pool.release(c)


# ============================================================================
# QUERY EXECUTION TESTS (5 tests)
# ============================================================================

@pytest.mark.asyncio
async def test_execute_query_insert(db_pool, test_table):
    """An INSERT through execute_query() should report one affected row and persist it."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"
    select_sql = f"SELECT name, value FROM {test_table} WHERE name = %s"

    assert await execute_query(insert_sql, ("test_user", 42)) == 1

    # Read the row back to confirm it was actually written.
    assert await fetch_one(select_sql, ("test_user",)) == ("test_user", 42)


@pytest.mark.asyncio
async def test_execute_query_update(db_pool, test_table):
    """An UPDATE through execute_query() should report one affected row and change the value."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"
    update_sql = f"UPDATE {test_table} SET value = %s WHERE name = %s"

    # Seed a single row, then modify it.
    await execute_query(insert_sql, ("user1", 10))
    updated = await execute_query(update_sql, (20, "user1"))
    assert updated == 1

    row = await fetch_one(f"SELECT value FROM {test_table} WHERE name = %s", ("user1",))
    assert row[0] == 20


@pytest.mark.asyncio
async def test_execute_query_delete(db_pool, test_table):
    """A DELETE through execute_query() should remove exactly the matching row."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"
    for seed in (("user1", 10), ("user2", 20)):
        await execute_query(insert_sql, seed)

    deleted = await execute_query(f"DELETE FROM {test_table} WHERE name = %s", ("user1",))
    assert deleted == 1

    # Only the untouched row should remain.
    remaining = await fetch_all(f"SELECT name FROM {test_table}")
    assert len(remaining) == 1
    assert remaining[0][0] == "user2"


@pytest.mark.asyncio
async def test_execute_query_with_parameters(db_pool, test_table):
    """Bound parameters should select exactly the row whose name matches."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"
    seed_rows = [("alice", 100), ("bob", 200), ("charlie", 300)]

    for params in seed_rows:
        await execute_query(insert_sql, params)

    # Look up one specific row by a bound parameter.
    row = await fetch_one(f"SELECT value FROM {test_table} WHERE name = %s", ("bob",))

    assert row[0] == 200


@pytest.mark.asyncio
async def test_execute_query_sql_injection_prevented(db_pool, test_table):
    """An injection payload passed as a parameter should be matched literally, not executed."""
    await execute_query(f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)", ("user1", 10))

    # Classic tautology payload; must be treated as a plain string value.
    payload = "user1' OR '1'='1"
    row = await fetch_one(f"SELECT value FROM {test_table} WHERE name = %s", (payload,))

    # No row has that literal name, so nothing comes back (rather than every row).
    assert row is None


# ============================================================================
# FETCH OPERATIONS TESTS (6 tests)
# ============================================================================

@pytest.mark.asyncio
async def test_fetch_one_returns_tuple(db_pool, test_table):
    """fetch_one() should yield a single row as a plain tuple."""
    await execute_query(f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)", ("user1", 42))

    row = await fetch_one(f"SELECT name, value FROM {test_table} WHERE name = %s", ("user1",))

    # Both the container type and the contents are part of the contract.
    assert isinstance(row, tuple)
    assert row == ("user1", 42)


@pytest.mark.asyncio
async def test_fetch_all_returns_list(db_pool, test_table):
    """fetch_all() should yield every row, as a list of tuples in query order."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"
    expected = [("user1", 10), ("user2", 20), ("user3", 30)]

    for params in expected:
        await execute_query(insert_sql, params)

    rows = await fetch_all(f"SELECT name, value FROM {test_table} ORDER BY value")

    assert isinstance(rows, list)
    assert len(rows) == 3
    # ORDER BY value makes the row order match insertion order here.
    for position, want in enumerate(expected):
        assert rows[position] == want


@pytest.mark.asyncio
async def test_fetch_dict_one_returns_dict(db_pool, test_table):
    """fetch_dict_one() should yield a single row keyed by column name."""
    await execute_query(f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)", ("user1", 42))

    record = await fetch_dict_one(f"SELECT name, value FROM {test_table} WHERE name = %s", ("user1",))

    assert isinstance(record, dict)
    # Columns are addressed by name, not position.
    assert record["name"] == "user1"
    assert record["value"] == 42


@pytest.mark.asyncio
async def test_fetch_dict_all_returns_list_of_dicts(db_pool, test_table):
    """fetch_dict_all() should yield every row as a dict, in query order."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"
    seeded = [("user1", 10), ("user2", 20)]

    for params in seeded:
        await execute_query(insert_sql, params)

    records = await fetch_dict_all(f"SELECT name, value FROM {test_table} ORDER BY value")

    assert isinstance(records, list)
    assert len(records) == 2
    assert isinstance(records[0], dict)
    # Compare each record against the row that was seeded in the same position.
    for record, (name, value) in zip(records, seeded):
        assert record["name"] == name
        assert record["value"] == value


@pytest.mark.asyncio
async def test_fetch_empty_result(db_pool, test_table):
    """Every fetch helper should report "no rows" cleanly on an empty result set."""
    missing = ("nonexistent",)
    one_sql = f"SELECT * FROM {test_table} WHERE name = %s"
    all_sql = f"SELECT * FROM {test_table}"

    result_one = await fetch_one(one_sql, missing)
    result_all = await fetch_all(all_sql)
    result_dict_one = await fetch_dict_one(one_sql, missing)
    result_dict_all = await fetch_dict_all(all_sql)

    # Single-row helpers give None; multi-row helpers give an empty list.
    assert result_one is None
    assert result_all == []
    assert result_dict_one is None
    assert result_dict_all == []


@pytest.mark.asyncio
async def test_execute_many_batch_insert(db_pool, test_table):
    """execute_many() should insert a whole batch and report the total affected rows."""
    # ("user1", 10) .. ("user5", 50)
    batch = [(f"user{i}", i * 10) for i in range(1, 6)]

    inserted = await execute_many(f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)", batch)
    assert inserted == len(batch)

    # Cross-check with a server-side count.
    counted = await fetch_all(f"SELECT COUNT(*) FROM {test_table}")
    assert counted[0][0] == 5


# ============================================================================
# TRANSACTION TESTS (5 tests)
# ============================================================================

@pytest.mark.asyncio
async def test_transaction_commits(db_pool, test_table):
    """Writes made inside a Transaction block should be committed when it exits cleanly."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"

    async with Transaction() as conn, conn.cursor() as cursor:
        for params in (("user1", 10), ("user2", 20)):
            await cursor.execute(insert_sql, params)

    # Visible from a separate query => the transaction was committed.
    counted = await fetch_all(f"SELECT COUNT(*) FROM {test_table}")
    assert counted[0][0] == 2


@pytest.mark.asyncio
async def test_transaction_rolls_back_on_error(db_pool, test_table):
    """Test that transaction rolls back on exception.

    The error raised inside the ``Transaction`` block must propagate to the
    caller. ``pytest.raises`` fails the test if it is swallowed; a bare
    ``try/except: pass`` would not. The INSERT performed before the error
    must not be visible afterwards.
    """
    with pytest.raises(ValueError, match="Test error"):
        async with Transaction() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)", ("user1", 10))
                # Raise an exception to trigger rollback
                raise ValueError("Test error")

    # Verify data was not committed
    result = await fetch_all(f"SELECT COUNT(*) FROM {test_table}")
    assert result[0][0] == 0


@pytest.mark.asyncio
async def test_transaction_context_manager(db_pool, test_table):
    """Transaction() should yield a usable connection and commit on a clean exit."""
    async with Transaction() as txn_conn:
        # The context manager must hand back a real connection.
        assert txn_conn is not None
        async with txn_conn.cursor() as cursor:
            await cursor.execute(f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)", ("user1", 10))

    # The row is readable afterwards, so the block committed.
    row = await fetch_one(f"SELECT value FROM {test_table} WHERE name = %s", ("user1",))
    assert row[0] == 10


@pytest.mark.asyncio
async def test_nested_transactions(db_pool, test_table):
    """Two back-to-back Transaction blocks should each commit independently.

    Despite the test name, no savepoints or true nesting are exercised.
    This runs two sequential transactions and checks that both rows end up
    committed.
    """
    # Note: aiomysql doesn't support nested transactions directly,
    # so sequential transactions are tested instead.
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"

    for label, amount in (("outer", 10), ("inner", 20)):
        async with Transaction() as conn, conn.cursor() as cursor:
            await cursor.execute(insert_sql, (label, amount))

    counted = await fetch_all(f"SELECT COUNT(*) FROM {test_table}")
    assert counted[0][0] == 2


@pytest.mark.asyncio
async def test_transaction_isolation(db_pool, test_table):
    """Test that uncommitted writes are invisible to other connections.

    Opens an explicit transaction on a connection taken straight from the
    pool and inserts a row. It then checks that ``fetch_one``, run while that
    connection is still held, cannot see the row until ``commit()``.

    The hand-acquired connection is always released in ``finally``, and is
    rolled back first if the commit was never reached. A failed assertion
    therefore cannot leak the connection or leave an open transaction behind.
    """
    conn1 = await db_pool.acquire()
    committed = False
    try:
        # Start a transaction but don't commit
        await conn1.begin()
        async with conn1.cursor() as cursor:
            await cursor.execute(f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)", ("user1", 10))

        # From another connection, data should not be visible yet
        result = await fetch_one(f"SELECT * FROM {test_table} WHERE name = %s", ("user1",))
        assert result is None

        # Commit first transaction
        await conn1.commit()
        committed = True
    finally:
        try:
            if not committed:
                await conn1.rollback()
        finally:
            await db_pool.release(conn1)

    # Now data should be visible
    result = await fetch_one(f"SELECT * FROM {test_table} WHERE name = %s", ("user1",))
    assert result is not None


# ============================================================================
# ERROR HANDLING TESTS (6 tests)
# ============================================================================

@pytest.mark.asyncio
async def test_pool_not_initialized_raises_error():
    """Test that operations fail gracefully when pool is not initialized.

    ``execute_query``, ``fetch_one`` and ``Transaction`` must each raise
    ``RuntimeError("Database pool not initialized...")`` once the pool has
    been closed. The pool is re-initialized in ``finally`` so a failing
    assertion here does not cascade into failures in later tests.
    """
    # Close pool if exists
    await close_pool()

    try:
        # Try to use database operations
        with pytest.raises(RuntimeError, match="Database pool not initialized"):
            await execute_query("SELECT 1")

        with pytest.raises(RuntimeError, match="Database pool not initialized"):
            await fetch_one("SELECT 1")

        with pytest.raises(RuntimeError, match="Database pool not initialized"):
            async with Transaction():
                pass
    finally:
        # Re-initialize for other tests, even if an assertion above failed
        await initialize_pool()


@pytest.mark.asyncio
async def test_query_timeout_handled(db_pool):
    """Test that a slow query completes and returns its result within a bounded wait.

    Runs ``SELECT SLEEP(0.1)`` through ``fetch_one`` under an
    ``asyncio.wait_for`` deadline. MySQL's ``SLEEP()`` returns 0 when it is
    not interrupted, so the row must be ``(0,)``. If the query stalls past
    the deadline, ``asyncio.TimeoutError`` propagates and fails the test
    rather than being swallowed.
    """
    result = await asyncio.wait_for(fetch_one("SELECT SLEEP(0.1)"), timeout=5)
    assert result == (0,)


@pytest.mark.asyncio
async def test_connection_lost_reconnects(db_pool, test_table):
    """A connection closed behind the pool's back should not break later queries."""
    broken = await db_pool.acquire()

    # Kill the connection, then return it as if nothing happened.
    broken.close()
    await db_pool.release(broken)

    # The next query must still succeed (the pool supplies a working connection).
    assert await fetch_one("SELECT 1") == (1,)


@pytest.mark.asyncio
async def test_database_down_raises_error():
    """Test that a database-level failure surfaces as a clear ``aiomysql.Error``.

    A real pool is initialized, and one of its connections is asked to
    switch to a schema that does not exist. The server's rejection must
    propagate as an ``aiomysql.Error``.

    The pool returned by ``initialize_pool`` is used directly. The ``_pool``
    name imported at the top of this file is bound once at import time and
    does not follow later ``initialize_pool`` calls. Relying on it (and on
    ``pytest.raises(Exception)``) lets an ``AttributeError`` on a stale
    ``None`` masquerade as a passing test.

    The pool is reset in ``finally`` so other tests start from a fresh pool
    even if this one fails.
    """
    await close_pool()

    try:
        pool = await initialize_pool(pool_size=1)

        # Try to switch to a non-existent database
        with pytest.raises(aiomysql.Error):
            async with pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("USE nonexistent_db")
    finally:
        # Re-initialize for other tests
        await close_pool()
        await initialize_pool()


@pytest.mark.asyncio
async def test_malformed_query_raises_error(db_pool):
    """Sending syntactically invalid SQL should raise rather than fail silently."""
    bogus_sql = "INVALID SQL SYNTAX"

    # Any exception type is accepted; presumably an aiomysql.Error subclass.
    with pytest.raises(Exception):
        await execute_query(bogus_sql)


@pytest.mark.asyncio
async def test_get_connection_requires_pool():
    """Test that get_connection fails if pool not initialized.

    Entering ``get_connection()`` after ``close_pool()`` must raise
    ``RuntimeError("Database pool not initialized...")``. The pool is
    re-initialized in ``finally`` so a failure here does not break later
    tests.
    """
    await close_pool()

    try:
        with pytest.raises(RuntimeError, match="Database pool not initialized"):
            async with get_connection():
                pass
    finally:
        # Re-initialize for other tests, even if the assertion above failed
        await initialize_pool()


# ============================================================================
# ADDITIONAL EDGE CASE TESTS (2 tests)
# ============================================================================

@pytest.mark.asyncio
async def test_concurrent_transactions(db_pool, test_table):
    """Ten Transaction blocks run concurrently should all commit their rows."""
    insert_sql = f"INSERT INTO {test_table} (name, value) VALUES (%s, %s)"

    async def insert_row(name, value):
        async with Transaction() as conn, conn.cursor() as cursor:
            await cursor.execute(insert_sql, (name, value))
            await asyncio.sleep(0.01)  # keep the transaction open briefly to overlap

    # Fan out all ten inserts at once and wait for every one of them.
    await asyncio.gather(*(insert_row(f"user{i}", i * 10) for i in range(10)))

    counted = await fetch_all(f"SELECT COUNT(*) FROM {test_table}")
    assert counted[0][0] == 10


@pytest.mark.asyncio
async def test_pool_initialization_is_idempotent(cleanup_pool):
    """Repeated initialize_pool() calls should hand back the existing pool, not a new one."""
    first = await initialize_pool(pool_size=3)
    # A second call — even with a different size — must reuse the first pool.
    second = await initialize_pool(pool_size=5)

    assert first is second

    await close_pool()
