"""
Comprehensive tests for database transaction management.

Tests cover:
- Successful transaction commits
- Transaction rollbacks on errors
- Nested transaction handling
- Concurrent update scenarios
- Connection cleanup
- Row-level locking
- Isolation levels
"""

import pytest
import mysql.connector
from mysql.connector import Error
import threading
import time
from contextlib import contextmanager

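# Put ws/ (the parent of this file's directory) on sys.path so the
# project-local `database` and `config` imports below resolve.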
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

from database.transactions import transaction, execute_in_transaction
from config import config

# Every test here talks to a real MySQL server through get_test_connection()
# and transaction(), so the whole module is skipped by default; remove this
# mark to run the suite against a configured database.
pytestmark = pytest.mark.skip(reason="Requires MySQL database - integration test")


# Test helper: opens its own connection, independent of transaction().
@contextmanager
def get_test_connection():
    """
    Yield a fresh MySQL connection built from the ``config`` DB settings.

    On exit the connection is closed if it is still connected, whether or
    not the ``with`` body raised.
    """
    settings = dict(
        host=config.DB_HOST,
        port=config.DB_PORT,
        user=config.DB_USER,
        password=config.DB_PASSWORD,
        database=config.DB_NAME,
    )
    connection = mysql.connector.connect(**settings)
    try:
        yield connection
    finally:
        if connection.is_connected():
            connection.close()


# Tables owned by the ``test_table`` fixture, in the order they are dropped.
_TEST_TABLES = ("test_accounts", "test_transaction_log", "test_processed_ops")


def _drop_test_tables(cursor):
    """Drop every fixture-owned table that exists."""
    for table in _TEST_TABLES:
        # Table names are fixed constants above; identifiers cannot be bound
        # as %s parameters, so interpolation is the only option here.
        cursor.execute(f"DROP TABLE IF EXISTS {table}")


@pytest.fixture(scope='function')
def test_table():
    """
    Create and seed the tables used by the transaction tests.

    Before each test: drops any leftover tables, recreates
    ``test_accounts``, ``test_transaction_log`` and ``test_processed_ops``,
    and seeds three accounts (Alice 1000, Bob 500, Charlie 250).
    After the test: drops the tables again.
    """
    with get_test_connection() as conn:
        cursor = conn.cursor()
        try:
            _drop_test_tables(cursor)

            cursor.execute("""
                CREATE TABLE test_accounts (
                    id INT PRIMARY KEY AUTO_INCREMENT,
                    name VARCHAR(100),
                    balance DECIMAL(10, 2) DEFAULT 0,
                    version INT DEFAULT 0
                )
            """)

            cursor.execute("""
                CREATE TABLE test_transaction_log (
                    id INT PRIMARY KEY AUTO_INCREMENT,
                    from_account INT,
                    to_account INT,
                    amount DECIMAL(10, 2),
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            cursor.execute("""
                CREATE TABLE test_processed_ops (
                    id INT PRIMARY KEY AUTO_INCREMENT,
                    operation_id VARCHAR(100) UNIQUE,
                    processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Inserted in this order so the ids are 1, 2, 3.
            for name, balance in (("Alice", 1000), ("Bob", 500), ("Charlie", 250)):
                cursor.execute(
                    "INSERT INTO test_accounts (name, balance) VALUES (%s, %s)",
                    (name, balance),
                )

            conn.commit()
        finally:
            # Close the cursor even if a DDL statement above failed.
            cursor.close()

    yield

    # Teardown runs after the test, whether it passed or failed.
    with get_test_connection() as conn:
        cursor = conn.cursor()
        try:
            _drop_test_tables(cursor)
            conn.commit()
        finally:
            cursor.close()


class TestBasicTransactions:
    """Commit, rollback and multi-statement behaviour of ``transaction()``."""

    @staticmethod
    def _scalar(cursor, sql):
        """Execute ``sql`` and return the first column of its first row."""
        cursor.execute(sql)
        return cursor.fetchone()[0]

    def test_successful_commit(self, test_table):
        """An update made inside a clean ``transaction()`` is committed."""
        with transaction() as (conn, cursor):
            cursor.execute("UPDATE test_accounts SET balance = 1500 WHERE name = 'Alice'")

        # Read back on an independent connection, which only sees committed data.
        with get_test_connection() as check_conn:
            check_cursor = check_conn.cursor()
            alice = self._scalar(
                check_cursor, "SELECT balance FROM test_accounts WHERE name = 'Alice'"
            )
            assert float(alice) == 1500.0
            check_cursor.close()

    def test_rollback_on_error(self, test_table):
        """An exception inside ``transaction()`` discards the pending update."""
        alice_sql = "SELECT balance FROM test_accounts WHERE name = 'Alice'"

        with get_test_connection() as before_conn:
            before_cursor = before_conn.cursor()
            before = self._scalar(before_cursor, alice_sql)
            before_cursor.close()

        with pytest.raises(ValueError):
            with transaction() as (conn, cursor):
                cursor.execute("UPDATE test_accounts SET balance = 2000 WHERE name = 'Alice'")
                raise ValueError("Simulated error")

        # The balance must be exactly what it was before the failed block.
        with get_test_connection() as after_conn:
            after_cursor = after_conn.cursor()
            after = self._scalar(after_cursor, alice_sql)
            assert float(after) == float(before)
            after_cursor.close()

    def test_multiple_operations(self, test_table):
        """Two updates in one transaction are committed together."""
        with transaction() as (conn, cursor):
            cursor.execute("UPDATE test_accounts SET balance = balance - 100 WHERE name = 'Alice'")
            cursor.execute("UPDATE test_accounts SET balance = balance + 100 WHERE name = 'Bob'")

        with get_test_connection() as check_conn:
            check_cursor = check_conn.cursor()
            alice = self._scalar(
                check_cursor, "SELECT balance FROM test_accounts WHERE name = 'Alice'"
            )
            bob = self._scalar(
                check_cursor, "SELECT balance FROM test_accounts WHERE name = 'Bob'"
            )
            assert float(alice) == 900.0
            assert float(bob) == 600.0
            check_cursor.close()

    def test_insert_and_select(self, test_table):
        """An INSERT is visible via LAST_INSERT_ID() in-transaction and persists."""
        with transaction() as (conn, cursor):
            cursor.execute("INSERT INTO test_accounts (name, balance) VALUES ('David', 750)")
            assert self._scalar(cursor, "SELECT LAST_INSERT_ID()") > 0

        with get_test_connection() as check_conn:
            check_cursor = check_conn.cursor()
            check_cursor.execute("SELECT balance FROM test_accounts WHERE name = 'David'")
            row = check_cursor.fetchone()
            assert row is not None
            assert float(row[0]) == 750.0
            check_cursor.close()


class TestConnectionManagement:
    """Which connections ``transaction()`` closes when its block ends."""

    def test_connection_cleanup_on_success(self, test_table):
        """The connection ``transaction()`` opened is released after commit."""
        leased = None
        with transaction() as (conn, cursor):
            leased = conn
            cursor.execute("SELECT 1")

        # Closed (presumably handed back to a pool) once the block exits.
        assert not leased.is_connected()

    def test_connection_cleanup_on_error(self, test_table):
        """The connection is released even when the block raises."""
        leased = None
        with pytest.raises(ValueError):
            with transaction() as (conn, cursor):
                leased = conn
                cursor.execute("SELECT 1")
                raise ValueError("Test error")

        assert not leased.is_connected()

    def test_existing_connection_not_closed(self, test_table):
        """A caller-supplied connection is left open for the caller."""
        with get_test_connection() as own_conn:
            with transaction(conn=own_conn) as (_, cursor):
                cursor.execute("SELECT 1")

            # The caller still owns the connection, so it must remain open.
            assert own_conn.is_connected()


class TestRowLevelLocking:
    """Test row-level locking with FOR UPDATE"""

    # Upper bound (seconds) for any worker thread. It keeps a locking bug
    # from hanging the whole test run.
    THREAD_TIMEOUT = 30

    def test_for_update_locks_row(self, test_table):
        """
        A second FOR UPDATE on a locked row waits for the holder to commit.

        thread1 locks Alice's row, signals ``row_locked``, holds the lock for
        0.5 s and debits 100. thread2 starts its transaction only after that
        signal. Its FOR UPDATE must therefore return after thread1's commit
        and see 900, then debit 50 (final balance 850).
        """
        results = []
        errors = []
        row_locked = threading.Event()

        def thread1():
            try:
                with transaction(lock=True) as (conn, cursor):
                    # Lock Alice's account
                    cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Alice' FOR UPDATE")
                    balance = cursor.fetchone()[0]
                    results.append(('thread1_start', float(balance)))
                    # The row lock is now held: let thread2 try to take it.
                    row_locked.set()

                    # Hold the lock long enough for thread2 to block on it.
                    time.sleep(0.5)

                    cursor.execute("UPDATE test_accounts SET balance = balance - 100 WHERE name = 'Alice'")
                    results.append(('thread1_update', 'done'))
            except Exception as e:
                errors.append(('thread1', str(e)))
            finally:
                # Never leave thread2 waiting if thread1 failed before locking.
                row_locked.set()

        def thread2():
            try:
                # Synchronize on the actual lock acquisition. A fixed sleep
                # did not guarantee that thread1 locked the row first.
                if not row_locked.wait(timeout=self.THREAD_TIMEOUT):
                    errors.append(('thread2', 'timed out waiting for thread1 to lock the row'))
                    return

                with transaction(lock=True) as (conn, cursor):
                    results.append(('thread2_waiting', 'started'))
                    # Should block until thread1's transaction commits.
                    cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Alice' FOR UPDATE")
                    balance = cursor.fetchone()[0]
                    results.append(('thread2_got_lock', float(balance)))

                    # Update balance
                    cursor.execute("UPDATE test_accounts SET balance = balance - 50 WHERE name = 'Alice'")
            except Exception as e:
                errors.append(('thread2', str(e)))

        # Daemon threads, so a hung worker cannot keep the interpreter alive.
        workers = [
            threading.Thread(target=thread1, daemon=True),
            threading.Thread(target=thread2, daemon=True),
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join(timeout=self.THREAD_TIMEOUT)
        assert not any(worker.is_alive() for worker in workers), "worker thread did not finish"

        # Check no errors occurred
        assert not errors, f"Errors: {errors}"

        # Verify final balance (1000 - 100 - 50 = 850)
        with get_test_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Alice'")
            final_balance = cursor.fetchone()[0]
            assert float(final_balance) == 850.0
            cursor.close()

        # thread2 must obtain the lock only after everything thread1 recorded.
        # Its locking read must also see thread1's committed debit.
        labels = [label for label, _ in results]
        assert 'thread2_got_lock' in labels, f"thread2 never acquired the lock: {results}"
        last_thread1 = max(i for i, label in enumerate(labels) if label.startswith('thread1'))
        assert last_thread1 < labels.index('thread2_got_lock')
        assert ('thread2_got_lock', 900.0) in results


class TestAtomicOperations:
    """Money transfers that must apply completely or not at all."""

    def test_atomic_transfer_success(self, test_table):
        """Debit, credit and log entry all commit together."""
        amount = 200
        with transaction(lock=True) as (conn, cursor):
            # Lock both account rows before modifying either. The balances
            # are read but not needed, because this transfer does not check
            # funds.
            for lock_sql in (
                "SELECT balance FROM test_accounts WHERE name = 'Alice' FOR UPDATE",
                "SELECT balance FROM test_accounts WHERE name = 'Bob' FOR UPDATE",
            ):
                cursor.execute(lock_sql)
                _ = cursor.fetchone()[0]

            cursor.execute("UPDATE test_accounts SET balance = balance - %s WHERE name = 'Alice'",
                           (amount,))
            cursor.execute("UPDATE test_accounts SET balance = balance + %s WHERE name = 'Bob'",
                           (amount,))

            # Record the transfer against both account ids.
            cursor.execute("""
                INSERT INTO test_transaction_log (from_account, to_account, amount)
                SELECT a1.id, a2.id, %s
                FROM test_accounts a1, test_accounts a2
                WHERE a1.name = 'Alice' AND a2.name = 'Bob'
            """, (amount,))

        with get_test_connection() as check_conn:
            check_cursor = check_conn.cursor()
            check_cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Alice'")
            assert float(check_cursor.fetchone()[0]) == 800.0

            check_cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Bob'")
            assert float(check_cursor.fetchone()[0]) == 700.0

            # Exactly one log row was written.
            check_cursor.execute("SELECT COUNT(*) FROM test_transaction_log")
            assert check_cursor.fetchone()[0] == 1
            check_cursor.close()

    def test_atomic_transfer_insufficient_funds(self, test_table):
        """An overdraft attempt raises and leaves Bob's balance untouched."""
        bob_sql = "SELECT balance FROM test_accounts WHERE name = 'Bob'"

        with get_test_connection() as before_conn:
            before_cursor = before_conn.cursor()
            before_cursor.execute(bob_sql)
            bob_before = before_cursor.fetchone()[0]
            before_cursor.close()

        with pytest.raises(ValueError):
            with transaction(lock=True) as (conn, cursor):
                cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Bob' FOR UPDATE")
                available = cursor.fetchone()[0]

                amount = 600  # more than Bob's seeded 500
                if available < amount:
                    raise ValueError("Insufficient funds")

                cursor.execute("UPDATE test_accounts SET balance = balance - %s WHERE name = 'Bob'",
                               (amount,))

        with get_test_connection() as after_conn:
            after_cursor = after_conn.cursor()
            after_cursor.execute(bob_sql)
            bob_after = after_cursor.fetchone()[0]
            assert float(bob_after) == float(bob_before)
            after_cursor.close()


class TestIsolationLevels:
    """Test transaction isolation levels"""

    # Alice's balance, read repeatedly inside one transaction.
    _ALICE_SQL = "SELECT balance FROM test_accounts WHERE name = 'Alice'"
    # Committed from a second connection while a transaction is open.
    _CONCURRENT_UPDATE = "UPDATE test_accounts SET balance = 1234 WHERE name = 'Alice'"

    @staticmethod
    def _commit_elsewhere(sql):
        """Execute ``sql`` on a separate connection and commit it."""
        with get_test_connection() as other_conn:
            other_cursor = other_conn.cursor()
            try:
                other_cursor.execute(sql)
                other_conn.commit()
            finally:
                other_cursor.close()

    def test_read_committed_isolation(self, test_table):
        """READ COMMITTED: a later read sees another session's committed update."""
        with transaction(isolation_level='READ COMMITTED') as (conn, cursor):
            cursor.execute(self._ALICE_SQL)
            first_read = cursor.fetchone()[0]
            assert float(first_read) == 1000.0

            # Plain SELECTs take no row locks in InnoDB, so this update does
            # not block. Assumes transaction() issues no locking reads itself.
            self._commit_elsewhere(self._CONCURRENT_UPDATE)

            cursor.execute(self._ALICE_SQL)
            second_read = cursor.fetchone()[0]
            # Under READ COMMITTED each consistent read takes a fresh snapshot.
            assert float(second_read) == 1234.0

    def test_repeatable_read_isolation(self, test_table):
        """REPEATABLE READ: reads stay on the first snapshot despite a concurrent commit."""
        with transaction(isolation_level='REPEATABLE READ') as (conn, cursor):
            cursor.execute(self._ALICE_SQL)
            # The first consistent read establishes the transaction's snapshot.
            first_read = cursor.fetchone()[0]

            self._commit_elsewhere(self._CONCURRENT_UPDATE)

            cursor.execute(self._ALICE_SQL)
            second_read = cursor.fetchone()[0]
            assert first_read == second_read

        # The concurrent update did commit; the snapshot merely hid it.
        with get_test_connection() as check_conn:
            check_cursor = check_conn.cursor()
            try:
                check_cursor.execute(self._ALICE_SQL)
                assert float(check_cursor.fetchone()[0]) == 1234.0
            finally:
                check_cursor.close()

    def test_serializable_isolation(self, test_table):
        """Test SERIALIZABLE isolation level"""
        with transaction(isolation_level='SERIALIZABLE') as (conn, cursor):
            cursor.execute("SELECT COUNT(*) FROM test_accounts")
            count = cursor.fetchone()[0]
            assert count == 3  # Alice, Bob, Charlie

    def test_invalid_isolation_level(self, test_table):
        """Test that invalid isolation level raises error"""
        with pytest.raises(ValueError, match="Invalid isolation level"):
            with transaction(isolation_level='INVALID_LEVEL') as (conn, cursor):
                pass


class TestExecuteInTransaction:
    """Commit and rollback through the ``execute_in_transaction`` wrapper."""

    @staticmethod
    def _alice_balance():
        """Read Alice's committed balance on a fresh connection."""
        with get_test_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Alice'")
            balance = cursor.fetchone()[0]
            cursor.close()
        return balance

    def test_execute_in_transaction_success(self, test_table):
        """The callback's return value is passed through and its work commits."""
        def set_alice_to_1200(conn, cursor):
            cursor.execute("UPDATE test_accounts SET balance = 1200 WHERE name = 'Alice'")
            return True

        assert execute_in_transaction(set_alice_to_1200) is True
        assert float(self._alice_balance()) == 1200.0

    def test_execute_in_transaction_error(self, test_table):
        """An exception from the callback propagates and its work is rolled back."""
        def update_then_fail(conn, cursor):
            cursor.execute("UPDATE test_accounts SET balance = 2000 WHERE name = 'Alice'")
            raise ValueError("Simulated error")

        with pytest.raises(ValueError):
            execute_in_transaction(update_then_fail)

        assert float(self._alice_balance()) == 1000.0  # seeded value, untouched


class TestIdempotency:
    """Idempotent processing guarded by a processed-operations table."""

    def test_idempotent_operation(self, test_table):
        """Running the same operation twice applies its effect only once."""
        operation_id = 'test_op_001'

        def apply_once():
            """
            Credit Alice 100 unless ``operation_id`` is already recorded.

            Returns True when the credit was applied, False when skipped.
            """
            with transaction(lock=True) as (conn, cursor):
                cursor.execute(
                    "SELECT id FROM test_processed_ops WHERE operation_id = %s FOR UPDATE",
                    (operation_id,)
                )
                seen_before = cursor.fetchone() is not None
                if seen_before:
                    return False

                cursor.execute("UPDATE test_accounts SET balance = balance + 100 WHERE name = 'Alice'")
                # Record the id in the same transaction as the credit.
                cursor.execute(
                    "INSERT INTO test_processed_ops (operation_id) VALUES (%s)",
                    (operation_id,)
                )
                return True

        assert apply_once() is True    # first call does the work
        assert apply_once() is False   # second call is a no-op

        with get_test_connection() as check_conn:
            check_cursor = check_conn.cursor()
            check_cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Alice'")
            balance = check_cursor.fetchone()[0]
            assert float(balance) == 1100.0  # 1000 + 100, applied exactly once
            check_cursor.close()


class TestEdgeCases:
    """Test edge cases and error conditions"""

    def test_empty_transaction(self, test_table):
        """A transaction with no statements completes without error."""
        with transaction() as (conn, cursor):
            pass  # No operations

    def test_transaction_with_select_only(self, test_table):
        """A read-only transaction sees all three seeded accounts."""
        with transaction() as (conn, cursor):
            cursor.execute("SELECT * FROM test_accounts")
            results = cursor.fetchall()
            assert len(results) == 3

    def test_rollback_after_partial_updates(self, test_table):
        """A driver error after two successful updates rolls both back."""
        # Expect the driver's error type rather than any Exception, so an
        # unrelated bug (e.g. a TypeError inside transaction()) cannot make
        # this pass. mysql.connector raises ProgrammingError, a subclass of
        # Error, for an unknown table. Assumes transaction() re-raises driver
        # errors unwrapped.
        with pytest.raises(Error):
            with transaction() as (conn, cursor):
                cursor.execute("UPDATE test_accounts SET balance = 1500 WHERE name = 'Alice'")
                cursor.execute("UPDATE test_accounts SET balance = 800 WHERE name = 'Bob'")
                # Force error
                cursor.execute("UPDATE nonexistent_table SET x = 1")

        # Verify all changes rolled back
        with get_test_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Alice'")
            assert float(cursor.fetchone()[0]) == 1000.0

            cursor.execute("SELECT balance FROM test_accounts WHERE name = 'Bob'")
            assert float(cursor.fetchone()[0]) == 500.0
            cursor.close()


if __name__ == '__main__':
    # pytest.main() returns the session exit code. Pass it on so a failing
    # run is visible to the shell or CI instead of always exiting 0.
    sys.exit(pytest.main([__file__, '-v']))
