"""
Authentication and session management using JWT tokens.

This module verifies user passwords against stored hashes, issues and validates
signed JWT session tokens, and exposes `require_auth` for gating operations on a
valid token.
"""

import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import jwt

from config import config

class AuthError(Exception):
    """Signals an authentication failure, e.g. an expired or otherwise invalid JWT."""


class AuthManager:
    """Manage password verification and JWT-based sessions.

    Tokens are signed with ``config.JWT_SECRET`` using HS256 and expire
    ``config.SESSION_TIMEOUT`` seconds after they are issued.
    """

    def __init__(self):
        self.secret = config.JWT_SECRET
        # Session lifetime in seconds (used for the JWT ``exp`` claim).
        self.timeout = config.SESSION_TIMEOUT
        self.algorithm = 'HS256'

    def hash_password(self, password: str) -> str:
        """
        Hash password using SHA-256.

        NOTE(security): this is a single, unsalted SHA-256 round. It is fast
        to brute-force and maps identical passwords to identical hashes. The
        64-hex-character output format is kept unchanged because stored hashes
        depend on it. Moving to a salted KDF (``hashlib.scrypt`` /
        ``hashlib.pbkdf2_hmac``) needs a storage-format migration.

        Args:
            password: Plain text password

        Returns:
            Hashed password (64-character lowercase hex)
        """
        return hashlib.sha256(password.encode('utf-8')).hexdigest()

    def verify_password(self, password: str, hashed: Optional[str]) -> bool:
        """
        Verify password against hash.

        The comparison is constant-time, so response timing does not reveal
        how much of the stored hash a guess matched.

        Args:
            password: Plain text password
            hashed: Hashed password from database

        Returns:
            True if password matches. False otherwise, including when
            ``hashed`` is missing or not a string.
        """
        if not isinstance(hashed, str):
            return False
        try:
            candidate = self.hash_password(password)
            # compare_digest raises TypeError on non-ASCII str input. Such a
            # value cannot be a valid hex digest, so the handler below treats
            # it as a mismatch.
            return secrets.compare_digest(candidate, hashed)
        except Exception as e:
            print(f"Password verification error: {e}")
            return False

    def create_token(
        self, user_id: str, additional_claims: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Create JWT token for user.

        Args:
            user_id: User ID
            additional_claims: Optional additional JWT claims. They are merged
                last and therefore override the standard claims (including
                ``exp``), so never pass an untrusted dict here.

        Returns:
            JWT token string (as returned by ``jwt.encode``)
        """
        # Timezone-aware "now"; datetime.utcnow() is deprecated and naive.
        now = datetime.now(timezone.utc)
        payload = {
            'user_id': user_id,
            'iat': now,
            'exp': now + timedelta(seconds=self.timeout),
            'jti': secrets.token_urlsafe(16)  # Unique token ID
        }

        if additional_claims:
            payload.update(additional_claims)

        return jwt.encode(payload, self.secret, algorithm=self.algorithm)

    def verify_token(self, token: str) -> Dict[str, Any]:
        """
        Verify JWT token.

        Args:
            token: JWT token string

        Returns:
            Decoded token payload

        Raises:
            AuthError: If the token is invalid, expired, or has no ``exp`` claim
        """
        try:
            return jwt.decode(
                token,
                self.secret,
                algorithms=[self.algorithm],
                # create_token always sets ``exp``. Requiring it stops a
                # correctly signed token without an expiry from living forever.
                options={'require': ['exp']},
            )
        except jwt.ExpiredSignatureError as e:
            raise AuthError("Token expired") from e
        except jwt.InvalidTokenError as e:
            raise AuthError(f"Invalid token: {e}") from e

    def create_session(self, user_id: str) -> Dict[str, Any]:
        """
        Create authenticated session.

        Args:
            user_id: User ID

        Returns:
            Dict with ``token``, ``user_id`` and ``expires_in`` (seconds)
        """
        token = self.create_token(user_id)

        return {
            'token': token,
            'user_id': user_id,
            'expires_in': self.timeout
        }

    def authenticate_user(self, user_id: str, password: str) -> Optional[Dict[str, Any]]:
        """
        Authenticate user with credentials.

        Args:
            user_id: User ID
            password: Password

        Returns:
            Session dict if successful, None otherwise. Any error raised while
            talking to the database is printed and also yields None.
        """
        from database import get_database_connection

        # Bind both names before ``try`` so the ``finally`` clause never hits
        # an unbound local when the connection or cursor cannot be created.
        conn = None
        cursor = None
        try:
            conn = get_database_connection()
            cursor = conn.cursor()

            # Check if user exists and get password hash
            sql = "SELECT userID, passwordHash FROM users WHERE userID = %s"
            cursor.execute(sql, (user_id,))

            result = cursor.fetchone()
            if not result:
                print(f"Authentication failed: user {user_id} not found")
                return None

            stored_user_id, password_hash = result

            if not self.verify_password(password, password_hash):
                print(f"Authentication failed: invalid password for {user_id}")
                return None

            # Issue the session for the identifier as stored in the database.
            # It can differ from the caller's spelling if the DB comparison is
            # case- or padding-insensitive.
            session = self.create_session(stored_user_id)
            print(f"User {user_id} authenticated successfully")

            return session

        except Exception as e:
            print(f"Authentication error: {e}")
            return None
        finally:
            if cursor is not None:
                cursor.close()
            if conn is not None:
                conn.close()


# Shared module-level AuthManager. It is built at import time, so the
# ``config`` attributes read by AuthManager.__init__ must already exist.
auth_manager: AuthManager = AuthManager()


def require_auth(token: str) -> Optional[str]:
    """
    Require authentication for operation.

    Args:
        token: JWT token from client

    Returns:
        User ID if the token verifies and carries a ``user_id`` claim.
        None otherwise; the reason for each rejection is printed.
    """
    if not token:
        print("Authentication failed: missing token")
        return None

    try:
        payload = auth_manager.verify_token(token)
    except AuthError as e:
        print(f"Authentication failed: {e}")
        return None

    # A correctly signed token lacking the claim used to surface as an
    # uncaught KeyError; treat it as an authentication failure instead.
    user_id = payload.get('user_id')
    if user_id is None:
        print("Authentication failed: token has no user_id claim")
        return None

    print(f"Authenticated user: {user_id}")
    return user_id
