OAuth2, OIDC & JWT Standards

Part of Module 5: Security & Compliance

Modern AI applications require robust authentication and authorization mechanisms. This module covers OAuth2 for delegated authorization, OpenID Connect (OIDC) for identity verification, and JSON Web Tokens (JWT) for secure information exchange in AI systems.

OAuth2 for AI Applications

Understanding OAuth2 Flows

OAuth2 enables AI applications to access resources on behalf of users without exposing credentials. This is crucial for AI agents that need to interact with multiple services.

Key OAuth2 Flows for AI:

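  • Authorization Code (with PKCE): for web-based AI dashboards and other apps where a user signs in interactively
  • Client Credentials: for service-to-service AI communication where no user is involved
  • Device Code (Device Authorization Grant): for AI IoT devices and CLI tools that lack a browser or have limited input
  • Refresh Token: obtains new access tokens without re-prompting the user, keeping long-running AI agent sessions alive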
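When the Authorization Code flow runs in a client that cannot keep a secret (a single-page dashboard, a CLI), PKCE (RFC 7636) binds the authorization code to the client instance that requested it, and current OAuth security guidance recommends it for confidential clients too. A minimal sketch of the S256 method, using only the standard library; the function names are illustrative:

```python
# PKCE helper (RFC 7636, S256 method) for the Authorization Code flow.
import base64
import hashlib
import secrets
from typing import Tuple


def s256_challenge(verifier: str) -> str:
    """code_challenge = BASE64URL(SHA256(ASCII(code_verifier))), padding stripped."""
    digest = hashlib.sha256(verifier.encode("ascii")).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


def make_pkce_pair() -> Tuple[str, str]:
    """Return (code_verifier, code_challenge)."""
    # 32 random bytes encode to a 43-character base64url string;
    # RFC 7636 allows verifiers of 43 to 128 characters.
    verifier = secrets.token_urlsafe(32)
    return verifier, s256_challenge(verifier)
```

The client sends `code_challenge` and `code_challenge_method=S256` in the authorization request, keeps the verifier, and sends it as `code_verifier` in the token request; the server recomputes the challenge and rejects the exchange if they differ.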
# OAuth2 client for AI services (uses the `requests` library)
import secrets
from datetime import datetime, timedelta, timezone
from typing import Dict, Optional, Tuple
from urllib.parse import urlencode

import requests


class OAuth2AIClient:
    """OAuth2 client for AI service authentication"""

    def __init__(self,
                 client_id: str,
                 client_secret: str,
                 auth_server: str,
                 redirect_uri: str,
                 timeout: float = 10.0):
        self.client_id = client_id
        self.client_secret = client_secret
        self.auth_server = auth_server.rstrip("/")
        self.redirect_uri = redirect_uri
        self.timeout = timeout
        self.tokens: Dict = {}

    def get_authorization_url(self,
                              scope: str = "ai.model.read ai.model.predict",
                              state: Optional[str] = None) -> Tuple[str, str]:
        """Build the authorization URL for user consent.

        Returns (url, state). Keep `state` in the user's session and compare it
        with the `state` returned on the redirect to block CSRF.
        """
        if not state:
            state = secrets.token_urlsafe(32)

        params = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": scope,
            "state": state,
            # Request a refresh token. access_type=offline is Google-specific;
            # OIDC providers use the `offline_access` scope instead.
            "access_type": "offline",
        }

        # urlencode escapes the spaces in `scope` and special characters in redirect_uri
        return f"{self.auth_server}/authorize?{urlencode(params)}", state

    def _token_request(self, data: Dict) -> Dict:
        """POST to the token endpoint; raise on HTTP/OAuth errors."""
        response = requests.post(
            f"{self.auth_server}/token", data=data, timeout=self.timeout
        )
        response.raise_for_status()
        return response.json()

    def _store_tokens(self, tokens: Dict) -> None:
        """Cache tokens with an absolute expiry time."""
        self.tokens["access_token"] = tokens["access_token"]
        # Servers may rotate refresh tokens: always keep the newest one
        if tokens.get("refresh_token"):
            self.tokens["refresh_token"] = tokens["refresh_token"]
        self.tokens["expires_at"] = datetime.now(timezone.utc) + timedelta(
            seconds=tokens.get("expires_in", 3600)
        )

    def exchange_code_for_token(self, code: str) -> Dict:
        """Exchange authorization code for access (and refresh) token"""
        tokens = self._token_request({
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        })
        self.tokens = {}  # a new login replaces any previous session
        self._store_tokens(tokens)
        return self.tokens

    def get_client_credentials_token(self, scope: str = "ai.service") -> str:
        """Get token using client credentials flow for service-to-service"""
        tokens = self._token_request({
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "scope": scope,
        })
        return tokens["access_token"]

    def refresh_access_token(self) -> str:
        """Refresh expired access token"""
        if not self.tokens.get("refresh_token"):
            raise ValueError("No refresh token available")

        tokens = self._token_request({
            "grant_type": "refresh_token",
            "refresh_token": self.tokens["refresh_token"],
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        })
        self._store_tokens(tokens)
        return self.tokens["access_token"]
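The client above covers the authorization code, client credentials and refresh token grants, but not the device code flow listed earlier. A minimal sketch of the client side of RFC 8628, assuming the caller supplies the endpoint URLs and a hypothetical `post(url, data)` callable that sends a form POST and returns `(status_code, parsed_json_body)`; injecting `post` and `sleep` keeps the sketch testable without a network:

```python
# Client side of the OAuth 2.0 Device Authorization Grant (RFC 8628), as a sketch.
import time
from typing import Callable, Dict, Tuple

DEVICE_GRANT = "urn:ietf:params:oauth:grant-type:device_code"
PostFn = Callable[[str, Dict], Tuple[int, Dict]]  # hypothetical HTTP adapter


def start_device_flow(device_endpoint: str, client_id: str, scope: str,
                      post: PostFn) -> Dict:
    """Request a device code; show `user_code` and `verification_uri` to the user."""
    status, body = post(device_endpoint, {"client_id": client_id, "scope": scope})
    if status != 200:
        raise RuntimeError(f"Device authorization failed: {body.get('error')}")
    return body


def poll_for_token(token_endpoint: str, client_id: str, device: Dict,
                   post: PostFn,
                   sleep: Callable[[float], None] = time.sleep) -> Dict:
    """Poll until the user approves, denies, or the device code expires."""
    interval = device.get("interval", 5)  # RFC 8628 default: 5 seconds
    # expires_in is required in the server response; 600 is only a fallback
    deadline = time.monotonic() + device.get("expires_in", 600)
    while time.monotonic() < deadline:
        sleep(interval)
        status, body = post(token_endpoint, {
            "grant_type": DEVICE_GRANT,
            "device_code": device["device_code"],
            "client_id": client_id,
        })
        if status == 200:
            return body  # access_token, possibly refresh_token
        error = body.get("error")
        if error == "authorization_pending":
            continue
        if error == "slow_down":
            interval += 5  # server asked the client to back off by 5 seconds
            continue
        raise RuntimeError(f"Device flow failed: {error}")  # e.g. access_denied
    raise TimeoutError("Device code expired before the user approved")
```

In practice `post` would be a two-line wrapper that calls `requests.post(url, data=data, timeout=10)` and returns `(response.status_code, response.json())`.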

OAuth2 Scopes for AI Systems

Define granular permissions for AI resources. The scope names below are illustrative, not standardized:

  • ai.model.read: View model metadata and configurations
  • ai.model.predict: Make predictions using models
  • ai.model.train: Train and update models
  • ai.data.read: Access training and validation data
  • ai.agent.execute: Run autonomous AI agents
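A resource server enforces these scopes by comparing what the token grants with what the operation needs. A minimal sketch, assuming a hypothetical mapping from platform operations to the illustrative scope names above; the scope claim is parsed as a space-delimited string, as RFC 6749 (section 3.3) defines it:

```python
# Scope enforcement sketch using the illustrative scope names above.
from typing import Iterable, Set

# Hypothetical mapping of platform operations to required scopes
REQUIRED_SCOPES = {
    "get_model_info": {"ai.model.read"},
    "predict": {"ai.model.predict"},
    "fine_tune": {"ai.model.train", "ai.data.read"},
    "run_agent": {"ai.agent.execute"},
}


def parse_scopes(scope_claim: str) -> Set[str]:
    """OAuth2 scope values are space-delimited, case-sensitive strings."""
    return set(scope_claim.split())


def is_allowed(operation: str, granted: Iterable[str]) -> bool:
    """Deny unknown operations; otherwise require every listed scope."""
    required = REQUIRED_SCOPES.get(operation)
    if required is None:
        return False
    return required.issubset(set(granted))
```

Denying operations that have no entry in the mapping keeps a forgotten endpoint from becoming accessible to every token.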

OpenID Connect (OIDC) for Identity

OIDC Identity Layer for AI

OIDC builds on OAuth2 to provide identity verification. The ID token tells an AI system who is making a request, which it needs for per-user authorization decisions and for audit trails.
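What turns an OAuth2 authorization request into an OIDC authentication request is the `openid` scope; a `nonce` is also sent and must come back inside the ID token. A minimal sketch of building that request from a discovery document's `authorization_endpoint`; the function name and the default extra scopes are assumptions:

```python
# Building an OIDC authentication request (authorization code flow), as a sketch.
import secrets
from typing import Dict, Tuple
from urllib.parse import urlencode


def build_oidc_auth_url(discovery: Dict, client_id: str, redirect_uri: str,
                        extra_scopes: str = "profile email") -> Tuple[str, str, str]:
    """Return (url, state, nonce); persist state and nonce in the user's session."""
    state = secrets.token_urlsafe(32)  # CSRF protection for the redirect
    nonce = secrets.token_urlsafe(32)  # echoed back inside the ID token
    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        # The `openid` scope makes this an OIDC request rather than plain OAuth2
        "scope": f"openid {extra_scopes}".strip(),
        "state": state,
        "nonce": nonce,
    }
    return f"{discovery['authorization_endpoint']}?{urlencode(params)}", state, nonce
```

After the callback, the stored nonce is what the ID token validation step compares against the token's `nonce` claim.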

# OIDC implementation for AI user identity (uses python-jose and requests)
from typing import Dict, Optional

import requests
from jose import JWTError, jwt


class OIDCProvider:
    """OpenID Connect relying-party helper for an AI platform"""

    def __init__(self, issuer: str, jwks_uri: Optional[str] = None,
                 timeout: float = 10.0):
        self.issuer = issuer  # must match the ID token's `iss` claim exactly
        self.timeout = timeout
        self.discovery_document = self._fetch_discovery()
        # Prefer the JWKS location published in the discovery document
        self.jwks_uri = jwks_uri or self.discovery_document["jwks_uri"]
        self.jwks = self._fetch_jwks()

    def _fetch_discovery(self) -> Dict:
        """Fetch OIDC discovery document"""
        response = requests.get(
            f"{self.issuer.rstrip('/')}/.well-known/openid-configuration",
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

    def _fetch_jwks(self) -> Dict:
        """Fetch JSON Web Key Set for token validation"""
        response = requests.get(self.jwks_uri, timeout=self.timeout)
        response.raise_for_status()
        return response.json()

    def _find_key(self, kid: str) -> Optional[Dict]:
        return next((k for k in self.jwks.get("keys", []) if k.get("kid") == kid), None)

    def validate_id_token(self,
                          id_token: str,
                          client_id: str,
                          nonce: Optional[str] = None,
                          access_token: Optional[str] = None) -> Dict:
        """Validate and decode OIDC ID token"""

        # Read the key ID from the (unverified) header to select the signing key
        kid = jwt.get_unverified_header(id_token).get("kid")
        if not kid:
            raise ValueError("ID token header has no key ID (kid)")

        key = self._find_key(kid)
        if key is None:
            # The provider may have rotated its keys: refresh the JWKS once
            self.jwks = self._fetch_jwks()
            key = self._find_key(kid)
        if key is None:
            raise ValueError("Key not found in JWKS")

        # python-jose verifies signature, exp, iat, aud and iss by default
        try:
            claims = jwt.decode(
                id_token,
                key,
                algorithms=["RS256"],  # pinned: never trust the header's `alg`
                audience=client_id,
                issuer=self.issuer,
                # Checked against the `at_hash` claim when the token contains one
                access_token=access_token,
            )
        except JWTError as e:
            raise ValueError(f"Token validation failed: {e}")

        # Verify nonce if one was sent in the authentication request
        if nonce is not None and claims.get("nonce") != nonce:
            raise ValueError("Nonce mismatch")

        return claims

    def get_user_info(self, access_token: str) -> Dict:
        """Fetch user information from UserInfo endpoint"""
        response = requests.get(
            self.discovery_document["userinfo_endpoint"],
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return response.json()

# Example: AI platform user context built from validated ID token claims.
# `roles`, `department` and `clearance_level` are custom claims that the
# identity provider must be configured to emit; they are not standard OIDC claims.
class AIUserContext:
    """User context for AI operations"""

    CLASSIFICATION_LEVELS = {
        "public": 0,
        "internal": 1,
        "confidential": 2,
        "restricted": 3,
    }

    def __init__(self, id_token_claims: Dict):
        self.user_id = id_token_claims["sub"]
        self.email = id_token_claims.get("email")
        self.name = id_token_claims.get("name")
        self.roles = id_token_claims.get("roles", [])
        self.department = id_token_claims.get("department")
        self.clearance_level = int(id_token_claims.get("clearance_level", 0))

    def can_access_model(self, model_classification: str) -> bool:
        """Check if user can access model based on classification"""
        required = self.CLASSIFICATION_LEVELS.get(model_classification)
        # Fail closed: an unknown classification is denied, not treated as public
        if required is None:
            return False
        return self.clearance_level >= required

JWT for Secure Information Exchange

JWT Structure and Implementation

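JSON Web Tokens (JWTs, RFC 7519) provide a compact, self-contained way to transmit claims between AI services. A signed JWT is integrity-protected but not encrypted: anyone holding it can read the payload, but nobody without the signing key can alter it undetected. JWTs are particularly useful for stateless authentication and authorization because the recipient can verify them without calling the issuer.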

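JWT Components (each part is Base64url-encoded, and the parts are joined with dots as header.payload.signature):

  • Header: Token type and signing algorithm (optionally a key ID, kid)
  • Payload: Claims about the entity (user, service, agent)
  • Signature: Computed over the encoded header and payload with the issuer's key; lets recipients verify integrity and origin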
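To see why the payload is readable by anyone, the three parts can be decoded with the standard library alone. This sketch only decodes; it performs no signature verification, so its output must never be trusted for authorization. The helper names are illustrative:

```python
# Decoding the three dot-separated parts of a compact JWT (no verification!).
import base64
import json
from typing import Dict, Tuple


def b64url_encode(data: bytes) -> str:
    """Base64url without '=' padding, as JWTs use it (RFC 7515)."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def b64url_decode(segment: str) -> bytes:
    """Restore the stripped padding, then Base64url-decode."""
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def split_jwt(token: str) -> Tuple[Dict, Dict, bytes]:
    """Return (header, payload, raw signature bytes) WITHOUT verifying anything."""
    header_b64, payload_b64, signature_b64 = token.split(".")
    header = json.loads(b64url_decode(header_b64))
    payload = json.loads(b64url_decode(payload_b64))
    return header, payload, b64url_decode(signature_b64)
```

Libraries such as PyJWT do the same decoding internally, but only return claims after checking the signature with a key and an allow-listed algorithm.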
# JWT implementation for AI service communication
# (uses PyJWT, not python-jose; RS256 needs the `cryptography` package)
import uuid
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional

import jwt


class JWTManager:
    """JWT management for AI services"""

    # Claims set by the manager itself; callers may not override them
    RESERVED_CLAIMS = {"iss", "sub", "iat", "exp", "nbf", "jti", "aud", "type"}

    def __init__(self,
                 private_key: str,
                 public_key: str,
                 issuer: str = "ai-platform"):
        self.private_key = private_key  # PEM-encoded RSA private key
        self.public_key = public_key    # PEM-encoded RSA public key
        self.issuer = issuer

    def _sign(self, subject: str, token_type: str, expires_in: int,
              claims: Dict) -> str:
        """Add the standard claims and sign with RS256."""
        now = datetime.now(timezone.utc)
        payload = {
            "iss": self.issuer,
            "sub": subject,
            "iat": now,
            "exp": now + timedelta(seconds=expires_in),
            "jti": str(uuid.uuid4()),  # unique ID for revocation and replay checks
            "type": token_type,
        }
        payload.update(claims)
        return jwt.encode(payload, self.private_key, algorithm="RS256")

    def create_access_token(self,
                            subject: str,
                            scopes: List[str],
                            additional_claims: Optional[Dict] = None,
                            expires_in: int = 3600) -> str:
        """Create JWT access token for AI service"""
        extra = dict(additional_claims or {})  # avoid a shared mutable default
        clash = self.RESERVED_CLAIMS.intersection(extra)
        if clash:
            raise ValueError(f"additional_claims may not override {sorted(clash)}")
        return self._sign(subject, "access", expires_in, {"scopes": scopes, **extra})

    def create_ai_agent_token(self,
                              agent_id: str,
                              capabilities: List[str],
                              resource_limits: Dict,
                              expires_in: int = 1800) -> str:
        """Create specialized JWT for AI agent"""
        return self._sign(f"agent:{agent_id}", "agent", expires_in, {
            "agent": {
                "id": agent_id,
                "capabilities": capabilities,
                "resource_limits": resource_limits,
                "max_requests": resource_limits.get("max_requests", 100),
                "allowed_models": resource_limits.get("allowed_models", []),
            }
        })

    def validate_token(self,
                       token: str,
                       verify_exp: bool = True) -> Dict:
        """Validate and decode JWT"""
        try:
            return jwt.decode(
                token,
                self.public_key,
                algorithms=["RS256"],  # pinned: rejects "none" and HS256 key confusion
                issuer=self.issuer,
                options={
                    "verify_signature": True,
                    "verify_exp": verify_exp,  # disable only for offline inspection
                    "verify_iss": True,
                    "verify_iat": True,
                    "require": ["exp", "iat", "iss", "sub", "jti"],
                },
            )
        except jwt.ExpiredSignatureError:
            raise ValueError("Token has expired")
        except jwt.InvalidTokenError as e:
            raise ValueError(f"Invalid token: {e}")

    def create_model_access_token(self,
                                  user_id: str,
                                  model_id: str,
                                  permissions: List[str],
                                  expires_in: int = 900) -> str:
        """Create fine-grained token for model access"""
        return self._sign(user_id, "model_access", expires_in, {
            "model": {
                "id": model_id,
                "permissions": permissions,
                "rate_limit": 100,  # requests per minute
                "max_batch_size": 32,
            }
        })

JWT Security Best Practices

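  • Prefer asymmetric algorithms such as RS256 or ES256 when tokens are verified by services other than the issuer; HS256 would require sharing the signing secret with every verifier
  • Keep token lifetimes short (15-60 minutes for access tokens)
  • Never put sensitive data in a JWT payload: it is only Base64url-encoded, not encrypted (use JWE if the contents must be confidential)
  • Implement token rotation and revocation mechanisms
  • Use the jti (JWT ID) claim against replay attacks: record the IDs you have accepted and reject repeats until the token expires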
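A `jti` only prevents replay if the receiver remembers which IDs it has already accepted. That suits one-shot tokens such as short-lived service call tokens; ordinary bearer access tokens are reused for many requests during their lifetime. A minimal in-memory sketch with a hypothetical class name and an injectable clock; across several replicas the same idea needs a shared store such as Redis:

```python
# Single-use enforcement via `jti`; entries expire together with their tokens.
import time
from typing import Callable, Dict


class JtiReplayCache:
    """Remembers accepted token IDs until the tokens themselves expire."""

    def __init__(self, clock: Callable[[], float] = time.time):
        self._seen: Dict[str, float] = {}  # jti -> token `exp` (epoch seconds)
        self._clock = clock

    def check_and_store(self, jti: str, exp: float) -> bool:
        """Return True the first time a jti is seen, False on replay."""
        now = self._clock()
        # Expired tokens already fail `exp` validation, so their IDs can be dropped
        self._seen = {k: v for k, v in self._seen.items() if v > now}
        if jti in self._seen:
            return False
        self._seen[jti] = exp
        return True
```

Call it after the signature and `exp` checks succeed, passing the validated token's `jti` and `exp` claims.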

Service-to-Service Authentication

Securing AI Microservices Communication

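AI systems often consist of multiple microservices that need to authenticate with each other. Combining mutual TLS (mTLS), which authenticates both ends of the connection, with JWTs, which carry the caller's identity and permissions at the application layer, provides defense in depth. The code below implements the JWT layer; mTLS is usually enforced by a service mesh or TLS-terminating proxy in front of the application.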
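One standardized way to tie the two layers together is certificate-bound access tokens (RFC 8705): the authorization server puts a SHA-256 thumbprint of the client's certificate into the token's `cnf` claim under `x5t#S256`, and the resource server compares it with the certificate presented during the mTLS handshake, so a stolen token is useless without the client's private key. This sketch covers only that comparison and assumes the DER-encoded client certificate is already available, for example from `ssl.SSLSocket.getpeercert(binary_form=True)`; how a Flask app receives it depends on the server or proxy in front of it:

```python
# RFC 8705 certificate-bound token check (sketch).
import base64
import hashlib
import hmac
from typing import Dict


def cert_thumbprint(der_cert: bytes) -> str:
    """x5t#S256: base64url(SHA-256(DER certificate)) without padding."""
    digest = hashlib.sha256(der_cert).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


def token_matches_client_cert(claims: Dict, der_cert: bytes) -> bool:
    """Compare the token's `cnf` confirmation claim with the mTLS client certificate."""
    expected = claims.get("cnf", {}).get("x5t#S256")
    if not expected:
        return False  # token is not certificate-bound
    # Constant-time comparison of the two ASCII thumbprints
    return hmac.compare_digest(expected, cert_thumbprint(der_cert))
```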

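# Service-to-service authentication for AI services (JWT layer only; mTLS is
# assumed to be enforced by the service mesh or proxy in front of this app)
import os
from datetime import datetime, timezone
from functools import wraps
from typing import Dict, Optional

from flask import Flask, g, jsonify, request


class ServiceAuthenticator:
    """Service-to-service authentication for AI platform"""

    def __init__(self, jwt_manager: JWTManager, service_id: str):
        self.jwt_manager = jwt_manager
        self.service_id = service_id  # this service's ID, the expected `target`
        self.service_registry: Dict[str, Dict] = {}

    def register_service(self,
                         service_id: str,
                         service_name: str,
                         public_key: str):
        """Register a trusted caller (allow-list)"""
        # public_key is stored but not used below: tokens are verified with
        # the platform key held by jwt_manager
        self.service_registry[service_id] = {
            "name": service_name,
            "public_key": public_key,
            "registered_at": datetime.now(timezone.utc),
        }

    def create_service_token(self,
                             source_service: str,
                             target_service: str,
                             purpose: str) -> str:
        """Create token for service-to-service communication"""
        return self.jwt_manager.create_access_token(
            subject=f"service:{source_service}",
            scopes=[f"{target_service}:access"],
            additional_claims={
                "source": source_service,
                "target": target_service,
                "purpose": purpose,
                "service_auth": True,
            },
            expires_in=300,  # 5 minutes for service calls
        )

    def validate_service_token(self, token: str) -> Dict:
        """Validate service-to-service token"""
        payload = self.jwt_manager.validate_token(token)

        if not payload.get("service_auth"):
            raise ValueError("Not a service authentication token")

        source = payload.get("source")
        if source not in self.service_registry:
            raise ValueError(f"Unknown service: {source}")

        # Reject tokens minted for another service (stops token forwarding)
        if payload.get("target") != self.service_id:
            raise ValueError("Token was not issued for this service")

        return payload


# Flask decorator for service authentication
def require_service_auth(authenticator: ServiceAuthenticator):
    """Decorator to require service authentication"""
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            auth_header = request.headers.get("Authorization", "")
            scheme, _, token = auth_header.partition(" ")
            if scheme.lower() != "bearer" or not token.strip():
                return jsonify({"error": "Missing or invalid authorization"}), 401

            try:
                payload = authenticator.validate_service_token(token.strip())
            except ValueError as e:
                # RFC 6750: an invalid, expired or misdirected token yields 401
                return jsonify({"error": str(e)}), 401

            # Call the view outside the try so its own ValueErrors are not
            # misreported as authentication failures
            g.service_context = payload  # Flask's per-request storage
            return f(*args, **kwargs)

        return decorated_function
    return decorator


# Example: AI Model Service
# Hypothetical: PEM keys supplied via environment variables (use your secret store)
private_key = os.environ["JWT_PRIVATE_KEY_PEM"]
public_key = os.environ["JWT_PUBLIC_KEY_PEM"]

app = Flask(__name__)
jwt_manager = JWTManager(private_key, public_key)
service_auth = ServiceAuthenticator(jwt_manager, service_id="model-service")


def log_service_access(source: str, model_id: str, purpose: str) -> None:
    """Hypothetical audit hook; replace with your audit logger."""
    app.logger.info("service=%s model=%s purpose=%s", source, model_id, purpose)


def perform_prediction(model_id: str, features: Optional[Dict]) -> Dict:
    """Hypothetical placeholder for the actual inference call."""
    raise NotImplementedError


@app.route("/api/v1/models/<model_id>/predict", methods=["POST"])
@require_service_auth(service_auth)
def predict(model_id: str):
    """Protected model prediction endpoint"""

    # Access service context from token
    source_service = g.service_context["source"]
    purpose = g.service_context["purpose"]

    # Log service access for audit
    log_service_access(source_service, model_id, purpose)

    # Perform prediction
    prediction = perform_prediction(model_id, request.get_json())

    return jsonify({
        "prediction": prediction,
        "model_id": model_id,
        "requested_by": source_service,
    })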

Token Management Best Practices

Token Lifecycle Management

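Token lifecycle management covers issuing, refreshing and revoking tokens. Short-lived access tokens limit the damage from a leaked token, while refresh tokens keep long-running sessions and agents working without re-prompting users.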

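# Token lifecycle management system (Redis-backed, with refresh token rotation)
import hashlib
import json
import secrets
from datetime import datetime, timedelta, timezone
from typing import Dict, List

import redis

ACCESS_TOKEN_TTL = 900                  # 15 minutes, in seconds
REFRESH_TOKEN_TTL = timedelta(days=30)


class TokenLifecycleManager:
    """Manages token lifecycle for AI platform"""

    def __init__(self,
                 redis_client: redis.Redis,
                 jwt_manager: JWTManager):
        self.redis = redis_client
        self.jwt_manager = jwt_manager

    @staticmethod
    def _refresh_key(refresh_token: str) -> str:
        # Store only a hash, so a Redis dump does not leak usable refresh tokens
        return "refresh:" + hashlib.sha256(refresh_token.encode()).hexdigest()

    def issue_token_pair(self,
                         user_id: str,
                         scopes: List[str],
                         ttl: timedelta = REFRESH_TOKEN_TTL) -> Dict:
        """Issue access and refresh token pair"""

        # Create access token (short-lived JWT)
        access_token = self.jwt_manager.create_access_token(
            subject=user_id, scopes=scopes, expires_in=ACCESS_TOKEN_TTL
        )

        # Create refresh token (long-lived, opaque, cryptographically random)
        refresh_token = secrets.token_urlsafe(32)
        refresh_data = {
            "user_id": user_id,
            "scopes": scopes,
            "created_at": datetime.now(timezone.utc).isoformat(),
        }

        # Store refresh token in Redis with expiry
        self.redis.setex(self._refresh_key(refresh_token), ttl, json.dumps(refresh_data))

        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "Bearer",
            "expires_in": ACCESS_TOKEN_TTL,
        }

    def refresh_tokens(self, refresh_token: str) -> Dict:
        """Rotate: redeem a refresh token exactly once and issue a new pair"""
        key = self._refresh_key(refresh_token)
        refresh_data = self.redis.get(key)
        remaining = self.redis.ttl(key)
        # delete() returns 0 if a concurrent request redeemed the token first
        # (a Lua script or GETDEL would make read-and-delete fully atomic)
        if refresh_data is None or self.redis.delete(key) == 0:
            raise ValueError("Invalid or already-used refresh token")

        data = json.loads(refresh_data)
        created = datetime.fromisoformat(data["created_at"]).timestamp()
        if self._issued_before_user_revocation(data["user_id"], created):
            raise ValueError("All tokens for this user have been revoked")

        # The replacement inherits the remaining lifetime, so rotation cannot
        # stretch a session past its original expiry
        return self.issue_token_pair(
            data["user_id"], data["scopes"], ttl=timedelta(seconds=max(remaining, 1))
        )

    def revoke_token(self, jti: str, ttl_seconds: int = ACCESS_TOKEN_TTL):
        """Revoke an access token by JTI"""
        # Keep the entry only as long as the token itself could still be valid
        self.redis.setex(f"revoked:{jti}", ttl_seconds, "1")

    def is_token_revoked(self, jti: str) -> bool:
        """Check if token has been revoked"""
        return self.redis.exists(f"revoked:{jti}") > 0

    def revoke_all_user_tokens(self, user_id: str):
        """Revoke all tokens issued to a user before now (emergency)"""
        # Must outlive the longest refresh token, or old refresh tokens revive
        self.redis.setex(
            f"user_revoked:{user_id}",
            REFRESH_TOKEN_TTL,
            str(datetime.now(timezone.utc).timestamp()),
        )

    def _issued_before_user_revocation(self, user_id: str, issued_at: float) -> bool:
        revoked_at = self.redis.get(f"user_revoked:{user_id}")
        return revoked_at is not None and issued_at <= float(revoked_at)

    def validate_access_token(self, token: str) -> Dict:
        """Signature and expiry checks plus both revocation lists"""
        payload = self.jwt_manager.validate_token(token)
        if self.is_token_revoked(payload["jti"]):
            raise ValueError("Token has been revoked")
        if self._issued_before_user_revocation(payload["sub"], payload["iat"]):
            raise ValueError("All tokens for this user have been revoked")
        return payload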

Standards Comparison

| Standard | Purpose | Use Case in AI | Key Features |
|---|---|---|---|
| OAuth2 | Delegated authorization | AI apps accessing user resources | Multiple flows, scopes, refresh tokens |
| OIDC | Identity verification | User authentication in AI platforms | ID tokens, UserInfo, discovery |
| JWT | Information exchange | Service-to-service auth, API tokens | Self-contained, stateless, claims |
| SAML | Enterprise SSO | Enterprise AI platform integration | XML-based, enterprise features |
| mTLS | Mutual authentication | Microservice communication | Certificate-based, bidirectional |

Best Practices for AI Authentication

  • Use OAuth2 + OIDC together: OAuth2 handles authorization and OIDC adds authentication
  • Short token lifetimes: 15-60 minutes for access tokens
  • Rotate refresh tokens: issue a new refresh token on every use and invalidate the old one
  • Audit all access: log authentication and authorization events
  • Use scopes effectively: define granular permissions for AI resources
  • Implement rate limiting: cap requests per client or token subject to curb token abuse and DoS attacks
  • Store tokens securely: in browsers, prefer HttpOnly, Secure, SameSite cookies over localStorage, which any injected script can read
  • Validate everything: check signature, expiry, issuer and audience on every token
  • Plan for revocation: maintain a denylist of revoked token IDs (jti) and support revoking all of a user's tokens
  • Monitor anomalies: detect unusual authentication patterns
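For the rate-limiting recommendation, a common technique is a token bucket keyed by the authenticated subject (the JWT `sub` claim): each subject may burst up to a fixed number of requests, and capacity refills at a steady rate. This is a single-process sketch with hypothetical names (`TokenBucketLimiter`, `rate_per_sec`, `burst`) and an injectable clock; behind several replicas the bucket state would have to live in a shared store such as Redis:

```python
# Per-subject token-bucket rate limiter (single-process sketch).
import time
from typing import Callable, Dict, Tuple


class TokenBucketLimiter:
    def __init__(self, rate_per_sec: float, burst: int,
                 clock: Callable[[], float] = time.monotonic):
        self.rate = rate_per_sec
        self.burst = burst
        self.clock = clock
        self._buckets: Dict[str, Tuple[float, float]] = {}  # sub -> (tokens, last_ts)

    def allow(self, subject: str) -> bool:
        """Consume one token for `subject`; return False when the bucket is empty."""
        now = self.clock()
        tokens, last = self._buckets.get(subject, (float(self.burst), now))
        # Refill in proportion to elapsed time, capped at the burst size
        tokens = min(self.burst, tokens + (now - last) * self.rate)
        if tokens >= 1:
            self._buckets[subject] = (tokens - 1, now)
            return True
        self._buckets[subject] = (tokens, now)
        return False
```

A request handler would call `allow(payload["sub"])` after token validation and answer HTTP 429 when it returns False.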
