"""
Unit tests for database transaction management (no live DB required).

Uses mocks to test transaction logic without requiring MySQL.
"""

import pytest
from unittest.mock import Mock, MagicMock, patch, call
import mysql.connector
from mysql.connector import Error

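# Add ws/ (the directory one level above this test directory) to sys.path
# so that project packages can be imported by the tests below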
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))


class TestTransactionContextManager:
    """Exercise the ``transaction`` context manager against mocked connections.

    No test here opens a real connection: the connection and cursor are
    ``MagicMock`` objects, so commit, rollback and cleanup calls can be
    inspected after the ``with`` block exits.
    """

    @staticmethod
    def _make_connection():
        """Build a mock connection whose ``cursor()`` returns a mock cursor.

        Returns:
            tuple: ``(mock_conn, mock_cursor)``.
        """
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        return mock_conn, mock_cursor

    @patch('database.transactions.get_database_connection')
    def test_successful_transaction_commits(self, mock_get_conn):
        """A body that finishes normally is committed and resources are closed."""
        from database.transactions import transaction

        mock_conn, mock_cursor = self._make_connection()
        mock_conn.is_connected.return_value = False  # After close
        mock_get_conn.return_value = mock_conn

        with transaction() as (_conn, cursor):
            cursor.execute("SELECT 1")

        mock_conn.commit.assert_called()
        mock_cursor.close.assert_called()
        mock_conn.close.assert_called()

    @patch('database.transactions.get_database_connection')
    def test_failed_transaction_rolls_back(self, mock_get_conn):
        """An exception in the body triggers rollback, never commit."""
        from database.transactions import transaction

        mock_conn, mock_cursor = self._make_connection()
        mock_get_conn.return_value = mock_conn

        # The error raised inside the block must reach the caller unchanged.
        with pytest.raises(ValueError):
            with transaction() as (_conn, _cursor):
                raise ValueError("Test error")

        mock_conn.rollback.assert_called()
        mock_conn.commit.assert_not_called()
        # Cleanup is expected on the failure path as well.
        mock_cursor.close.assert_called()
        mock_conn.close.assert_called()

    @patch('database.transactions.get_database_connection')
    def test_transaction_disables_autocommit(self, mock_get_conn):
        """The connection's ``autocommit`` attribute is set to ``False``."""
        from database.transactions import transaction

        mock_conn, _mock_cursor = self._make_connection()
        mock_get_conn.return_value = mock_conn

        with transaction() as (_conn, _cursor):
            pass

        # Identity check: the attribute must have been assigned the literal
        # False, not left as the auto-created MagicMock child attribute.
        assert mock_conn.autocommit is False

    @patch('database.transactions.get_database_connection')
    def test_transaction_starts_explicitly(self, mock_get_conn):
        """``start_transaction`` is invoked on the connection."""
        from database.transactions import transaction

        mock_conn, _mock_cursor = self._make_connection()
        mock_get_conn.return_value = mock_conn

        with transaction() as (_conn, _cursor):
            pass

        mock_conn.start_transaction.assert_called()

    @patch('database.transactions.get_database_connection')
    def test_connection_cleanup_on_error(self, mock_get_conn):
        """Cursor and connection are closed even when rollback itself fails."""
        from database.transactions import transaction

        mock_conn, mock_cursor = self._make_connection()
        mock_conn.rollback.side_effect = Exception("Rollback error")
        mock_get_conn.return_value = mock_conn

        # The body's ValueError is what the caller should see.
        with pytest.raises(ValueError):
            with transaction() as (_conn, _cursor):
                raise ValueError("Test error")

        mock_cursor.close.assert_called()
        mock_conn.close.assert_called()

    def test_existing_connection_not_closed(self):
        """A caller-supplied connection is left open; only the cursor is closed."""
        from database.transactions import transaction

        mock_conn, mock_cursor = self._make_connection()

        with transaction(conn=mock_conn) as (_conn, _cursor):
            pass

        mock_conn.close.assert_not_called()
        mock_cursor.close.assert_called()

    def test_cursor_cleanup_on_error(self):
        """A failing ``cursor.close()`` does not escape the context manager."""
        from database.transactions import transaction

        mock_conn, mock_cursor = self._make_connection()
        mock_cursor.close.side_effect = Exception("Cursor close error")

        with transaction(conn=mock_conn) as (_conn, _cursor):
            pass

        # Reaching this point means the close() failure was suppressed
        # (presumably after being logged by the implementation).
        # Also confirm close() was really attempted; without this check the
        # test would pass even if the cursor were never closed at all.
        assert mock_cursor.close.called


class TestIsolationLevels:
    """Checks on how ``transaction`` applies a requested isolation level."""

    @staticmethod
    def _wire_connection(get_conn_patch):
        """Make the patched factory return a mock connection; return its cursor."""
        connection = MagicMock()
        db_cursor = MagicMock()
        connection.cursor.return_value = db_cursor
        get_conn_patch.return_value = connection
        return db_cursor

    @patch('database.transactions.get_database_connection')
    def test_read_committed_isolation(self, mock_get_conn):
        """READ COMMITTED reaches the cursor as a SET TRANSACTION statement."""
        from database.transactions import transaction

        db_cursor = self._wire_connection(mock_get_conn)

        with transaction(isolation_level='READ COMMITTED') as (_conn, _cursor):
            pass

        expected_sql = 'SET TRANSACTION ISOLATION LEVEL READ COMMITTED'
        db_cursor.execute.assert_any_call(expected_sql)

    @patch('database.transactions.get_database_connection')
    def test_repeatable_read_isolation(self, mock_get_conn):
        """REPEATABLE READ reaches the cursor as a SET TRANSACTION statement."""
        from database.transactions import transaction

        db_cursor = self._wire_connection(mock_get_conn)

        with transaction(isolation_level='REPEATABLE READ') as (_conn, _cursor):
            pass

        expected_sql = 'SET TRANSACTION ISOLATION LEVEL REPEATABLE READ'
        db_cursor.execute.assert_any_call(expected_sql)

    @patch('database.transactions.get_database_connection')
    def test_serializable_isolation(self, mock_get_conn):
        """SERIALIZABLE reaches the cursor as a SET TRANSACTION statement."""
        from database.transactions import transaction

        db_cursor = self._wire_connection(mock_get_conn)

        with transaction(isolation_level='SERIALIZABLE') as (_conn, _cursor):
            pass

        expected_sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE'
        db_cursor.execute.assert_any_call(expected_sql)

    @patch('database.transactions.get_database_connection')
    def test_invalid_isolation_level_raises_error(self, mock_get_conn):
        """An unrecognised level is rejected with a ``ValueError``."""
        from database.transactions import transaction

        self._wire_connection(mock_get_conn)

        # The message is matched so an unrelated ValueError cannot satisfy
        # the expectation.
        expect_failure = pytest.raises(ValueError, match="Invalid isolation level")
        with expect_failure:
            with transaction(isolation_level='INVALID_LEVEL') as (_conn, _cursor):
                pass

    @patch('database.transactions.get_database_connection')
    def test_case_insensitive_isolation_level(self, mock_get_conn):
        """A lowercase level is accepted and issued in uppercase."""
        from database.transactions import transaction

        db_cursor = self._wire_connection(mock_get_conn)

        # Lowercase spelling on the way in ...
        with transaction(isolation_level='read committed') as (_conn, _cursor):
            pass

        # ... uppercase spelling in the statement sent to the cursor.
        expected_sql = 'SET TRANSACTION ISOLATION LEVEL READ COMMITTED'
        db_cursor.execute.assert_any_call(expected_sql)


class TestExecuteInTransaction:
    """Checks for the ``execute_in_transaction`` convenience wrapper."""

    @staticmethod
    def _linked_mocks():
        """Return a mock connection whose ``cursor()`` yields a mock cursor."""
        connection = MagicMock()
        connection.cursor.return_value = MagicMock()
        return connection

    @patch('database.transactions.get_database_connection')
    def test_execute_in_transaction_success(self, mock_get_conn):
        """The operation's return value is passed back and the work is committed."""
        from database.transactions import execute_in_transaction

        connection = self._linked_mocks()
        mock_get_conn.return_value = connection

        def my_operation(conn, cursor):
            cursor.execute("SELECT 1")
            return "success"

        outcome = execute_in_transaction(my_operation)

        assert outcome == "success"
        connection.commit.assert_called()

    @patch('database.transactions.get_database_connection')
    def test_execute_in_transaction_error(self, mock_get_conn):
        """An operation that raises is rolled back and the error propagates."""
        from database.transactions import execute_in_transaction

        connection = self._linked_mocks()
        mock_get_conn.return_value = connection

        def failing_operation(conn, cursor):
            raise ValueError("Test error")

        with pytest.raises(ValueError):
            execute_in_transaction(failing_operation)

        connection.rollback.assert_called()
        connection.commit.assert_not_called()

    @patch('database.transactions.get_database_connection')
    def test_execute_in_transaction_with_existing_conn(self, mock_get_conn):
        """Supplying ``conn`` means the connection factory is never consulted."""
        from database.transactions import execute_in_transaction

        supplied_connection = self._linked_mocks()

        def my_operation(conn, cursor):
            return "done"

        outcome = execute_in_transaction(my_operation, conn=supplied_connection)

        assert outcome == "done"
        # The patched factory must stay untouched when a connection is given.
        mock_get_conn.assert_not_called()


class TestTransactionExamples:
    """Tests for the example functions shipped in ``database.transactions``.

    The connection factory is patched in every test, so the examples run
    against ``MagicMock`` connections and cursors only.
    """

    @patch('database.transactions.get_database_connection')
    def test_example_atomic_transfer_success(self, mock_get_conn):
        """A transfer with sufficient balance runs its queries and commits."""
        from database.transactions import example_atomic_transfer

        # Setup mocks
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_get_conn.return_value = mock_conn

        # Successive fetchone() results, consumed in call order.
        mock_cursor.fetchone.side_effect = [
            (1000,),  # from_balance
            (1,),     # to_account exists
        ]

        # Execute transfer
        example_atomic_transfer(1, 2, 100)

        # Render each recorded execute() call once as text. The loop variable
        # is deliberately not named ``call``, which would shadow the
        # ``unittest.mock.call`` helper imported at module level.
        executed = [str(recorded) for recorded in mock_cursor.execute.call_args_list]
        expected_fragments = (
            'SELECT balance FROM accounts WHERE id',
            'UPDATE accounts SET balance',
            'INSERT INTO transaction_log',
        )
        for fragment in expected_fragments:
            assert any(fragment in statement for statement in executed), fragment

        # Verify commit
        assert mock_conn.commit.called

    @patch('database.transactions.get_database_connection')
    def test_example_atomic_transfer_insufficient_funds(self, mock_get_conn):
        """A transfer larger than the balance raises and is rolled back."""
        from database.transactions import example_atomic_transfer

        # Setup mocks
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_get_conn.return_value = mock_conn

        # Every fetchone() reports a balance of 50, below the requested 100.
        mock_cursor.fetchone.return_value = (50,)

        with pytest.raises(ValueError, match="Insufficient funds"):
            example_atomic_transfer(1, 2, 100)

        # Verify rollback
        assert mock_conn.rollback.called
        assert not mock_conn.commit.called

    @patch('database.transactions.get_database_connection')
    def test_example_idempotent_operation_first_execution(self, mock_get_conn):
        """With no prior record found, the operation runs and returns True."""
        from database.transactions import example_idempotent_operation

        # Setup mocks
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_get_conn.return_value = mock_conn

        # fetchone() returning None stands for "not processed yet".
        mock_cursor.fetchone.return_value = None

        result = example_idempotent_operation('op_001', 123)

        assert result is True
        assert mock_conn.commit.called

    @patch('database.transactions.get_database_connection')
    def test_example_idempotent_operation_duplicate(self, mock_get_conn):
        """With an existing record found, the operation is skipped (False)."""
        from database.transactions import example_idempotent_operation

        # Setup mocks
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_get_conn.return_value = mock_conn

        # fetchone() returning a row stands for "already processed".
        mock_cursor.fetchone.return_value = (1,)

        result = example_idempotent_operation('op_001', 123)

        assert result is False


class TestTransactionBehavior:
    """Tests for statement ordering and the objects ``transaction`` yields."""

    @staticmethod
    def _track_order(mock_conn, mock_cursor):
        """Record ``cursor.execute`` and ``conn.commit`` calls in one list.

        Both mocks are attached to a fresh parent ``Mock`` so that the
        parent's ``mock_calls`` shows their calls in the order they happened.

        Returns:
            Mock: the parent whose ``mock_calls`` holds the ordered calls.
        """
        tracker = Mock()
        tracker.attach_mock(mock_cursor.execute, 'execute')
        tracker.attach_mock(mock_conn.commit, 'commit')
        return tracker

    @patch('database.transactions.get_database_connection')
    def test_multiple_operations_in_transaction(self, mock_get_conn):
        """Several statements run in one transaction, followed by a commit."""
        from database.transactions import transaction

        # Setup mocks
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_get_conn.return_value = mock_conn
        tracker = self._track_order(mock_conn, mock_cursor)

        with transaction() as (_conn, cursor):
            cursor.execute("UPDATE table1 SET x = 1")
            cursor.execute("UPDATE table2 SET y = 2")
            cursor.execute("INSERT INTO table3 VALUES (3)")

        # Exactly the three statements, in order, with the commit after them.
        assert tracker.mock_calls == [
            call.execute("UPDATE table1 SET x = 1"),
            call.execute("UPDATE table2 SET y = 2"),
            call.execute("INSERT INTO table3 VALUES (3)"),
            call.commit(),
        ]

    @patch('database.transactions.get_database_connection')
    def test_transaction_with_for_update(self, mock_get_conn):
        """A ``FOR UPDATE`` select and its update run before the commit."""
        from database.transactions import transaction

        # Setup mocks
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_get_conn.return_value = mock_conn
        tracker = self._track_order(mock_conn, mock_cursor)

        # Note: FOR UPDATE locking is specified in the query, not as a parameter
        with transaction() as (_conn, cursor):
            cursor.execute("SELECT * FROM accounts WHERE id = 1 FOR UPDATE")
            cursor.execute("UPDATE accounts SET balance = 100 WHERE id = 1")

        # Both statements, in order, with the commit after them.
        assert tracker.mock_calls == [
            call.execute("SELECT * FROM accounts WHERE id = 1 FOR UPDATE"),
            call.execute("UPDATE accounts SET balance = 100 WHERE id = 1"),
            call.commit(),
        ]

    @patch('database.transactions.get_database_connection')
    def test_yields_correct_objects(self, mock_get_conn):
        """The yielded pair is the factory's connection and its cursor."""
        from database.transactions import transaction

        # Setup mocks
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_conn.cursor.return_value = mock_cursor
        mock_get_conn.return_value = mock_conn

        with transaction() as (conn, cursor):
            # Identity, not equality: the very same objects must be handed out.
            assert conn is mock_conn
            assert cursor is mock_cursor


class TestErrorHandling:
    """Checks on how ``transaction`` behaves when something goes wrong."""

    @patch('database.transactions.get_database_connection')
    def test_connection_error_propagates(self, mock_get_conn):
        """A failure while obtaining the connection reaches the caller."""
        from database.transactions import transaction

        # The factory itself raises, so no connection ever exists.
        connect_failure = mysql.connector.Error("Connection failed")
        mock_get_conn.side_effect = connect_failure

        with pytest.raises(mysql.connector.Error):
            with transaction() as (_conn, _cursor):
                pass

    @patch('database.transactions.get_database_connection')
    def test_query_error_triggers_rollback(self, mock_get_conn):
        """A failing ``execute`` causes rollback and prevents commit."""
        from database.transactions import transaction

        failing_cursor = MagicMock()
        failing_cursor.execute.side_effect = mysql.connector.Error("Query error")
        connection = MagicMock()
        connection.cursor.return_value = failing_cursor
        mock_get_conn.return_value = connection

        with pytest.raises(mysql.connector.Error):
            with transaction() as (_conn, cursor):
                cursor.execute("BAD QUERY")

        connection.rollback.assert_called()
        connection.commit.assert_not_called()

    @patch('database.transactions.get_database_connection')
    def test_rollback_error_does_not_hide_original_error(self, mock_get_conn):
        """When rollback also fails, the body's own error is the one raised."""
        from database.transactions import transaction

        connection = MagicMock()
        connection.cursor.return_value = MagicMock()
        connection.rollback.side_effect = Exception("Rollback failed")
        mock_get_conn.return_value = connection

        # Matching on the message proves the ValueError from the body, not the
        # rollback failure, is what surfaces.
        with pytest.raises(ValueError, match="Original error"):
            with transaction() as (_conn, _cursor):
                raise ValueError("Original error")


if __name__ == '__main__':
    # pytest.main() returns the run's exit code; hand it to sys.exit so that
    # executing this file directly reports failures with a non-zero status.
    sys.exit(pytest.main([__file__, '-v']))
