import pytest
import pytest_asyncio
import asyncio
from database_async import (
    initialize_pool,
    close_pool,
    get_connection,
    execute_query,
    execute_many,
    fetch_one,
    fetch_all,
)

# Module-level mark: pytest applies it to every test in this file, so the whole
# module is skipped unconditionally (the tests need a live MySQL database).
pytestmark = pytest.mark.skip(reason="Requires MySQL database - integration test")

@pytest_asyncio.fixture(scope="function")
async def db_pool():
    """Open the database pool before a test and close it once the test is done."""
    pool_size = 5
    await initialize_pool(pool_size=pool_size)
    # No value is handed to the test; it only relies on the pool being open.
    yield None
    await close_pool()

@pytest.mark.asyncio
async def test_execute_query(db_pool):
    """Round-trip a single row through ``execute_query`` and ``fetch_one``.

    A regular table is used instead of a TEMPORARY one: MySQL temporary tables
    are visible only to the session that created them, and the helpers
    presumably borrow a pooled connection per call (confirm in
    ``database_async``), so a later statement could land on a connection that
    cannot see the table.
    """
    # Clear any leftover from an earlier aborted run before creating the table.
    await execute_query("DROP TABLE IF EXISTS test_table")
    await execute_query("CREATE TABLE test_table (id INT, name VARCHAR(50))")
    try:
        await execute_query("INSERT INTO test_table VALUES (%s, %s)", (1, "test"))

        result = await fetch_one("SELECT * FROM test_table WHERE id = %s", (1,))
        assert result == (1, "test")
    finally:
        # A regular table outlives the connection, so remove it explicitly.
        await execute_query("DROP TABLE IF EXISTS test_table")

@pytest.mark.asyncio
async def test_execute_many(db_pool):
    """Insert several rows with ``execute_many`` and read them back in order.

    A regular table is used instead of a TEMPORARY one: MySQL temporary tables
    are visible only to the session that created them, and the helpers
    presumably borrow a pooled connection per call (confirm in
    ``database_async``), so the insert or select could run on a connection
    that cannot see the table.
    """
    # Clear any leftover from an earlier aborted run before creating the table.
    await execute_query("DROP TABLE IF EXISTS test_batch")
    await execute_query("CREATE TABLE test_batch (id INT, value VARCHAR(50))")
    try:
        data = [(1, "a"), (2, "b"), (3, "c")]
        await execute_many("INSERT INTO test_batch VALUES (%s, %s)", data)

        results = await fetch_all("SELECT * FROM test_batch ORDER BY id")
        assert len(results) == len(data)
        # Check every row, not only the first, so a partial batch is caught.
        assert list(results) == data
    finally:
        # A regular table outlives the connection, so remove it explicitly.
        await execute_query("DROP TABLE IF EXISTS test_batch")

@pytest.mark.asyncio
async def test_concurrent_queries(db_pool):
    """Run ten inserts concurrently through the pool and verify all rows land.

    A regular table is used instead of a TEMPORARY one: MySQL temporary tables
    are visible only to the session that created them, so inserts that run
    concurrently on other pooled connections would fail with "table doesn't
    exist". NOTE(review): the final count assumes ``execute_query`` commits
    (or the pool uses autocommit) so rows are visible across connections.
    """
    total = 10

    async def insert_value(val):
        """Insert one id and echo it back so gather() results can be checked."""
        await execute_query("INSERT INTO test_concurrent VALUES (%s)", (val,))
        return val

    # Clear any leftover from an earlier aborted run before creating the table.
    await execute_query("DROP TABLE IF EXISTS test_concurrent")
    await execute_query("CREATE TABLE test_concurrent (id INT PRIMARY KEY)")
    try:
        # gather() returns results in argument order, so the echoed values
        # must match the submitted sequence exactly.
        results = await asyncio.gather(*(insert_value(i) for i in range(total)))
        assert results == list(range(total))

        count = await fetch_one("SELECT COUNT(*) FROM test_concurrent")
        assert count[0] == total
    finally:
        # A regular table outlives the connection, so remove it explicitly.
        await execute_query("DROP TABLE IF EXISTS test_concurrent")
