# Copyright (c) 2025 Thomas Goirand <zigo@debian.org>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import threading

from flask import request, jsonify
from functools import wraps

from keystoneauth1 import session as ks_session
from keystoneauth1.exceptions import http as ks_http_exceptions

from oslo_policy import policy as oslo_policy
from oslo_context import context as oslo_context
from oslo_log import log as logging

# VMMS policy rule definitions; registered on the enforcer explicitly in init_enforcer()
from vmms.policy import rules as vmms_rules

LOG = logging.getLogger(__name__)

_ENFORCER = None

_token_cache = {}
_token_cache_lock = threading.Lock()

def _validate_keystone_token(CONF, token, ttl=60):
    """Validate token with Keystone using keystoneauth1 session + TTL cache"""
    if not token:
        return False

    now = time.time()
    with _token_cache_lock:
        cached = _token_cache.get(token)
        if cached and cached['expires'] > now:
            return cached.get('roles', [])

    try:
        sess = ks_session.Session()
        resp = sess.get(f"{CONF.identity.auth_url}/v3/auth/tokens",
                        headers={'X-Auth-Token': token},
                        authenticated=False,
                        raise_exc=True,
        )
        if resp.status_code in (200, 201, 204):
            # Extract roles from the token payload
            token_data = resp.json().get('token', {})
            keystone_roles = [role['name'] for role in token_data.get('roles', [])]

            with _token_cache_lock:
                _token_cache[token] = {
                    'expires': now + ttl,
                    'roles': keystone_roles
                }
            return keystone_roles

    except ks_http_exceptions.HttpError as e:
        LOG.warning(f"Token validation failed: {e}")
    except Exception as e:
        LOG.error(f"Unexpected error during token validation: {e}")
    # Return empty list instead of False to indicate validation failure
    return []


def init_enforcer(CONF):
    """Create, populate and return the module-wide policy enforcer.

    On the first call this builds an ``oslo_policy.Enforcer``, registers
    every rule from ``vmms_rules.list_rules()`` as a default, and loads the
    rules. Later calls return the already-built instance.

    :param CONF: configuration object handed to ``oslo_policy.Enforcer``.
    :returns: the shared ``oslo_policy.Enforcer``.
    """
    global _ENFORCER
    if _ENFORCER is not None:
        return _ENFORCER

    LOG.debug("🐛 Initializing policy enforcer")
    # Build into a local first; see the assignment at the end.
    enforcer = oslo_policy.Enforcer(CONF, policy_file='', rules={})

    # Explicitly register the VMMS policies as defaults.
    LOG.debug("🐛 Registering VMMS policy rules")
    for rule in vmms_rules.list_rules():
        enforcer.register_default(rule)

    # Load the rules after registering defaults.
    enforcer.load_rules(force_reload=True)

    LOG.debug("🐛 Registered %d policy rules", len(enforcer.registered_rules))

    # Publish only a fully configured enforcer. If registration or loading
    # raised above, the next call retries instead of silently reusing a
    # half-built enforcer that is missing rules.
    _ENFORCER = enforcer
    return _ENFORCER

def get_enforcer(CONF):
    """Return the shared policy enforcer, building it on first use.

    :param CONF: configuration object forwarded to init_enforcer() when no
        enforcer exists yet.
    """
    if _ENFORCER:
        return _ENFORCER
    return init_enforcer(CONF)

def enforce_policy(CONF, action, target=None):
    """Enforce policy ``action`` for the current Flask request.

    Before the rule is evaluated, the ``X-Auth-Token`` header is
    re-validated with Keystone. The roles Keystone reports must also equal
    the roles listed in the ``X-Roles`` header.

    :param CONF: configuration object (Keystone endpoint, enforcer config).
    :param action: name of the policy rule to check.
    :param target: policy target dict. Defaults to the caller's
        ``project_id``/``user_id`` taken from the request headers.
    :returns: the truthy result of ``enforcer.enforce`` on success, or a
        Flask ``(response, status)`` tuple on failure: 401 when the token is
        invalid or expired, 403 when the header roles do not match
        Keystone's. Callers must check for and return that tuple.
    :raises oslo_policy.PolicyNotAuthorized: when the policy denies access.
    """
    # Roles advertised by the upstream auth layer. Empty segments (e.g. from
    # "a,,b" or a trailing comma) are dropped so they cannot cause a
    # spurious role mismatch.
    roles_header = request.headers.get('X-Roles', '')
    header_roles = [role.strip() for role in roles_header.split(',')
                    if role.strip()]

    # Validate the token and fetch its roles from Keystone.
    token = request.headers.get('X-Auth-Token')
    keystone_roles = _validate_keystone_token(CONF, token)

    if not keystone_roles:
        return jsonify({'error': 'Invalid or expired Keystone token'}), 401

    # The roles in the headers must be exactly the roles Keystone returns
    # for this token.
    if set(header_roles) != set(keystone_roles):
        LOG.warning("Role mismatch: header roles %s != Keystone roles %s",
                    header_roles, keystone_roles)
        return jsonify({'error': 'Forbidden: Role mismatch'}), 403

    # NOTE(review): user_id/project_id still come from request headers, not
    # from the validated token. That is only safe if a trusted layer (e.g.
    # keystonemiddleware) sets/strips these headers -- confirm deployment.
    ctx = oslo_context.RequestContext(
        user_id=request.headers.get('X-User-Id'),
        project_id=request.headers.get('X-Project-Id'),
        roles=keystone_roles,
        is_admin='admin' in keystone_roles
    )

    enforcer = get_enforcer(CONF)

    # Default target: the caller's own project/user.
    if target is None:
        target = {
            'project_id': ctx.project_id or '',
            'user_id': ctx.user_id or ''
        }

    creds = ctx.to_policy_values()

    try:
        result = enforcer.enforce(action, target, creds)
    except oslo_policy.PolicyNotAuthorized:
        raise
    except Exception as e:
        LOG.error("⧱ Policy enforcement failed for %s: %s", action, e,
                  exc_info=True)
        raise

    if not result:
        raise oslo_policy.PolicyNotAuthorized(action, target, creds)
    return result

def require_policy_factory(get_config_func):
    """Build a ``require_policy`` decorator bound to a config getter.

    :param get_config_func: zero-argument callable returning the
        configuration object passed to enforce_policy().
    :returns: ``require_policy(action)``, a decorator factory that guards a
        Flask view with the policy rule ``action``.
    """
    def require_policy(action):
        """Decorator enforcing policy ``action`` before the view runs."""
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                # Reject requests the upstream auth layer has not confirmed.
                identity_status = request.headers.get('X-Identity-Status')
                if not identity_status or identity_status.upper() != 'CONFIRMED':
                    return jsonify({'error': 'Authentication required'}), 401

                try:
                    CONF = get_config_func()
                    outcome = enforce_policy(CONF, action)
                except KeyError as e:
                    LOG.error("⧱ Policy key error: %s", e)
                    return jsonify({'error': 'Policy configuration error'}), 500
                except oslo_policy.PolicyNotAuthorized:
                    LOG.debug("Policy %s denied the request", action)
                    return jsonify({'error': 'Forbidden'}), 403
                except Exception:
                    # Fail closed, but record why instead of swallowing it.
                    LOG.exception("Unexpected error while enforcing policy %s",
                                  action)
                    return jsonify({'error': 'Forbidden'}), 403

                # enforce_policy() reports token/role failures by *returning*
                # a (response, status) tuple instead of raising. Send it back,
                # otherwise the view would run for an unauthenticated caller.
                if isinstance(outcome, tuple):
                    return outcome

                return f(*args, **kwargs)
            return wrapper
        return decorator
    return require_policy
