Model Governance

Part of Module 5: Security & Compliance

Model governance ensures AI models are developed, validated, and deployed responsibly. This module covers validation frameworks, bias testing methodologies, audit trails, model versioning and rollback, explainability, and compliance requirements for enterprise AI systems.

Model Validation Framework

Comprehensive Model Validation

Model validation goes beyond accuracy metrics to ensure models are robust, fair, and reliable in production environments.

Validation Dimensions:

  • Statistical Validation: Performance metrics, confidence intervals
  • Business Validation: Alignment with business objectives
  • Ethical Validation: Fairness and bias assessment
  • Technical Validation: Robustness and stability testing
  • Regulatory Validation: Compliance with standards
# Model validation framework implementation
import numpy as np
import pandas as pd
import json
import hashlib
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from datetime import datetime

@dataclass
class ValidationResult:
    """Container for validation results"""
    passed: bool
    score: float
    details: Dict[str, Any]
    timestamp: datetime
    validator_version: str

class ModelValidator:
    """Comprehensive model validation system"""
    
    def __init__(self, model_id: str, model_version: str):
        self.model_id = model_id
        self.model_version = model_version
        self.validation_results = {}
        self.validation_history = []
    
    def validate_performance(self, 
                             y_true: np.ndarray,
                             y_pred: np.ndarray,
                             metrics: Optional[List[str]] = None) -> ValidationResult:
        """Validate model performance metrics"""
        
        from sklearn.metrics import (
            accuracy_score, precision_score, recall_score, f1_score
        )
        
        if metrics is None:
            metrics = ['accuracy', 'precision', 'recall', 'f1']
        
        results = {}
        for metric in metrics:
            if metric == 'accuracy':
                results[metric] = accuracy_score(y_true, y_pred)
            elif metric == 'precision':
                results[metric] = precision_score(y_true, y_pred, average='weighted')
            elif metric == 'recall':
                results[metric] = recall_score(y_true, y_pred, average='weighted')
            elif metric == 'f1':
                results[metric] = f1_score(y_true, y_pred, average='weighted')
        
        # Check against thresholds
        thresholds = {
            'accuracy': 0.85,
            'precision': 0.80,
            'recall': 0.75,
            'f1': 0.78
        }
        
        passed = all(
            results.get(metric, 0) >= thresholds.get(metric, 0)
            for metric in metrics
        )
        
        return ValidationResult(
            passed=passed,
            score=np.mean(list(results.values())),
            details={
                'metrics': results,
                'thresholds': thresholds,
                'sample_size': len(y_true)
            },
            timestamp=datetime.now(),
            validator_version='1.0.0'
        )
    
    def validate_data_drift(self,
                           training_data: pd.DataFrame,
                           production_data: pd.DataFrame,
                           threshold: float = 0.1) -> ValidationResult:
        """Detect data drift between training and production"""
        
        from scipy import stats
        
        drift_scores = {}
        drifted_features = []
        
        # The KS test requires numeric samples, so restrict to numeric columns
        numeric_columns = training_data.select_dtypes(include=[np.number]).columns
        
        for column in numeric_columns:
            if column in production_data.columns:
                # Kolmogorov-Smirnov test for drift
                statistic, p_value = stats.ks_2samp(
                    training_data[column],
                    production_data[column]
                )
                
                drift_scores[column] = {
                    'statistic': statistic,
                    'p_value': p_value,
                    'drifted': p_value < threshold
                }
                
                if p_value < threshold:
                    drifted_features.append(column)
        
        passed = len(drifted_features) == 0
        
        return ValidationResult(
            passed=passed,
            score=1.0 - (len(drifted_features) / len(numeric_columns)),
            details={
                'drift_scores': drift_scores,
                'drifted_features': drifted_features,
                'threshold': threshold
            },
            timestamp=datetime.now(),
            validator_version='1.0.0'
        )
    
    def validate_robustness(self,
                           model,
                           X_test: np.ndarray,
                           noise_levels: List[float] = [0.01, 0.05, 0.1]) -> ValidationResult:
        """Test model robustness to input perturbations"""
        
        original_predictions = model.predict(X_test)
        robustness_scores = {}
        
        for noise_level in noise_levels:
            # Add Gaussian noise
            noise = np.random.normal(0, noise_level, X_test.shape)
            X_noisy = X_test + noise
            
            noisy_predictions = model.predict(X_noisy)
            
            # Calculate prediction stability
            agreement_rate = np.mean(original_predictions == noisy_predictions)
            robustness_scores[f'noise_{noise_level}'] = agreement_rate
        
        avg_robustness = np.mean(list(robustness_scores.values()))
        passed = avg_robustness > 0.9  # 90% stability threshold
        
        return ValidationResult(
            passed=passed,
            score=avg_robustness,
            details={
                'robustness_scores': robustness_scores,
                'noise_levels': noise_levels,
                'test_samples': len(X_test)
            },
            timestamp=datetime.now(),
            validator_version='1.0.0'
        )
    
    def generate_validation_report(self) -> Dict:
        """Generate comprehensive validation report"""
        
        report = {
            'model_id': self.model_id,
            'model_version': self.model_version,
            'validation_timestamp': datetime.now().isoformat(),
            'overall_status': 'PASSED' if all(
                r.passed for r in self.validation_results.values()
            ) else 'FAILED',
            'validations': {}
        }
        
        for name, result in self.validation_results.items():
            report['validations'][name] = {
                'passed': result.passed,
                'score': result.score,
                'details': result.details,
                'timestamp': result.timestamp.isoformat()
            }
        
        # Generate report hash for integrity (default=str handles numpy scalars)
        report_str = json.dumps(report, sort_keys=True, default=str)
        report['integrity_hash'] = hashlib.sha256(
            report_str.encode()
        ).hexdigest()
        
        return report

Model Validation Checklist

  • ✓ Performance metrics meet thresholds
  • ✓ No significant data drift detected
  • ✓ Model is robust to input noise
  • ✓ Bias metrics within acceptable range
  • ✓ Explainability requirements satisfied
  • ✓ Security vulnerabilities assessed
  • ✓ Compliance requirements met
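
The checklist maps directly onto the validator's individual checks. Below is a minimal usage sketch; `clf`, `X_test`, and `y_test` are assumed to be a fitted scikit-learn classifier and a held-out evaluation set (they are not defined by the framework above), and each result is stored in `validation_results` so that `generate_validation_report` can aggregate it.

# Minimal usage sketch (clf, X_test, y_test are assumed to exist)
validator = ModelValidator(model_id="credit-risk", model_version="2.1.0")

y_pred = clf.predict(X_test)

# Run individual checks and keep the results so the report can aggregate them
perf = validator.validate_performance(y_test, y_pred)
robust = validator.validate_robustness(clf, X_test)
validator.validation_results = {'performance': perf, 'robustness': robust}

report = validator.generate_validation_report()
print(report['overall_status'])        # PASSED or FAILED
print(report['integrity_hash'][:12])   # tamper-evidence hash of the report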

Bias Testing Methodologies

Comprehensive Bias Detection

Systematic testing for various types of bias ensures AI models treat all groups fairly and meet ethical standards.

# Bias testing framework
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List

class BiasDetector:
    """Detect and measure bias in AI models"""
    
    def __init__(self, protected_attributes: List[str]):
        self.protected_attributes = protected_attributes
        self.bias_metrics = {}
    
    def demographic_parity(self,
                          y_pred: np.ndarray,
                          sensitive_features: pd.DataFrame) -> Dict:
        """Measure demographic parity (statistical parity)"""
        
        results = {}
        
        for feature in self.protected_attributes:
            if feature in sensitive_features.columns:
                groups = sensitive_features[feature].unique()
                positive_rates = {}
                
                for group in groups:
                    mask = sensitive_features[feature] == group
                    positive_rate = np.mean(y_pred[mask])
                    positive_rates[group] = positive_rate
                
                # Calculate disparity
                max_rate = max(positive_rates.values())
                min_rate = min(positive_rates.values())
                disparity = max_rate - min_rate
                
                results[feature] = {
                    'positive_rates': positive_rates,
                    'disparity': disparity,
                    'ratio': min_rate / max_rate if max_rate > 0 else 0,
                    'fair': disparity < 0.1  # 10% threshold
                }
        
        return results
    
    def equalized_odds(self,
                       y_true: np.ndarray,
                       y_pred: np.ndarray,
                       sensitive_features: pd.DataFrame) -> Dict:
        """Measure equalized odds (TPR and FPR parity)"""
        
        results = {}
        
        for feature in self.protected_attributes:
            if feature in sensitive_features.columns:
                groups = sensitive_features[feature].unique()
                tpr_values = {}
                fpr_values = {}
                
                for group in groups:
                    mask = sensitive_features[feature] == group
                    
                    # True Positive Rate
                    positive_mask = y_true[mask] == 1
                    if positive_mask.sum() > 0:
                        tpr = np.mean(y_pred[mask][positive_mask] == 1)
                        tpr_values[group] = tpr
                    
                    # False Positive Rate
                    negative_mask = y_true[mask] == 0
                    if negative_mask.sum() > 0:
                        fpr = np.mean(y_pred[mask][negative_mask] == 1)
                        fpr_values[group] = fpr
                
                tpr_disparity = max(tpr_values.values()) - min(tpr_values.values())
                fpr_disparity = max(fpr_values.values()) - min(fpr_values.values())
                
                results[feature] = {
                    'tpr_values': tpr_values,
                    'fpr_values': fpr_values,
                    'tpr_disparity': tpr_disparity,
                    'fpr_disparity': fpr_disparity,
                    'fair': tpr_disparity < 0.1 and fpr_disparity < 0.1
                }
        
        return results
    
    def individual_fairness(self,
                           model,
                           X: np.ndarray,
                           epsilon: float = 0.1) -> Dict:
        """Test individual fairness (similar individuals, similar outcomes)"""
        
        from sklearn.metrics.pairwise import cosine_similarity
        
        # Get model predictions
        predictions = model.predict_proba(X)[:, 1] if hasattr(
            model, 'predict_proba'
        ) else model.predict(X)
        
        # Compute pairwise similarities
        similarities = cosine_similarity(X)
        
        # Check fairness for similar pairs
        violations = 0
        total_pairs = 0
        
        for i in range(len(X)):
            for j in range(i + 1, len(X)):
                if similarities[i, j] > 0.9:  # Very similar inputs
                    pred_diff = abs(predictions[i] - predictions[j])
                    if pred_diff > epsilon:
                        violations += 1
                    total_pairs += 1
        
        fairness_score = 1 - (violations / total_pairs) if total_pairs > 0 else 1
        
        return {
            'fairness_score': fairness_score,
            'violations': violations,
            'total_similar_pairs': total_pairs,
            'epsilon': epsilon,
            'fair': fairness_score > 0.95
        }
    
    def generate_bias_report(self) -> Dict:
        """Generate comprehensive bias report"""
        
        report = {
            'timestamp': datetime.now().isoformat(),
            'protected_attributes': self.protected_attributes,
            'metrics': self.bias_metrics,
            'overall_fair': all(
                metric.get('fair', False)
                for metric in self.bias_metrics.values()
            ),
            'recommendations': self._generate_recommendations()
        }
        
        return report
    
    def _generate_recommendations(self) -> List[str]:
        """Generate bias mitigation recommendations"""
        
        recommendations = []
        
        for metric_name, metric_data in self.bias_metrics.items():
            if not metric_data.get('fair', True):
                if 'demographic_parity' in metric_name:
                    recommendations.append(
                        "Consider reweighting training data or adjusting decision thresholds"
                    )
                elif 'equalized_odds' in metric_name:
                    recommendations.append(
                        "Apply post-processing calibration to equalize error rates"
                    )
                elif 'individual_fairness' in metric_name:
                    recommendations.append(
                        "Implement fairness constraints during model training"
                    )
        
        return recommendations

Common Bias Types to Test

  • Historical Bias: Training data reflects past discrimination
  • Representation Bias: Underrepresentation of certain groups
  • Measurement Bias: Features measured differently across groups
  • Aggregation Bias: Model assumes one-size-fits-all
  • Evaluation Bias: Benchmarks don't represent all users
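
To make these metrics concrete, the sketch below runs the demographic-parity check on synthetic data; the attribute names, sample size, and predictions are illustrative assumptions, not requirements of the framework.

# Illustrative sketch: demographic parity on synthetic data
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n = 1000

sensitive = pd.DataFrame({
    'gender': rng.choice(['female', 'male'], size=n),
    'age_group': rng.choice(['under_40', 'over_40'], size=n)
})
y_pred = rng.integers(0, 2, size=n)  # stand-in for model predictions

detector = BiasDetector(protected_attributes=['gender', 'age_group'])
parity = detector.demographic_parity(y_pred, sensitive)

for attribute, result in parity.items():
    print(attribute, f"disparity={result['disparity']:.3f}", result['fair'])
    # Store per-attribute results so generate_bias_report() can aggregate them
    detector.bias_metrics[f'demographic_parity_{attribute}'] = result

print(detector.generate_bias_report()['overall_fair'])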

Audit Trail Implementation

Comprehensive Model Auditing

Maintain detailed audit trails for all model activities to ensure accountability, reproducibility, and compliance.

# Audit trail system for model governance
import json
import hashlib
import sqlite3
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

class AuditEventType(Enum):
    MODEL_CREATED = "model_created"
    MODEL_TRAINED = "model_trained"
    MODEL_VALIDATED = "model_validated"
    MODEL_DEPLOYED = "model_deployed"
    MODEL_ACCESSED = "model_accessed"
    MODEL_UPDATED = "model_updated"
    MODEL_RETIRED = "model_retired"
    PREDICTION_MADE = "prediction_made"
    ANOMALY_DETECTED = "anomaly_detected"

class ModelAuditTrail:
    """Audit trail system for model governance"""
    
    def __init__(self, db_path: str = "model_audit.db"):
        self.conn = sqlite3.connect(db_path)
        self._create_tables()
    
    def _create_tables(self):
        """Create audit trail database tables"""
        
        cursor = self.conn.cursor()
        
        # Main audit log table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS audit_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                event_type TEXT NOT NULL,
                model_id TEXT NOT NULL,
                model_version TEXT,
                user_id TEXT,
                action_details TEXT,
                metadata TEXT,
                risk_level TEXT,
                ip_address TEXT,
                session_id TEXT
            )
        ''')
        
        # Model lineage table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS model_lineage (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                model_id TEXT NOT NULL,
                version TEXT NOT NULL,
                parent_version TEXT,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                created_by TEXT,
                training_data_hash TEXT,
                code_version TEXT,
                hyperparameters TEXT,
                performance_metrics TEXT
            )
        ''')
        
        # Prediction log table
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS prediction_log (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                model_id TEXT NOT NULL,
                model_version TEXT NOT NULL,
                input_hash TEXT,
                prediction TEXT,
                confidence REAL,
                latency_ms INTEGER,
                user_id TEXT
            )
        ''')
        
        self.conn.commit()
    
    def log_event(self,
                  event_type: AuditEventType,
                  model_id: str,
                  user_id: Optional[str] = None,
                  action_details: Optional[Dict] = None,
                  metadata: Optional[Dict] = None,
                  risk_level: str = "low") -> int:
        """Log an audit event"""
        
        cursor = self.conn.cursor()
        
        cursor.execute('''
            INSERT INTO audit_log 
            (event_type, model_id, user_id, action_details, metadata, risk_level)
            VALUES (?, ?, ?, ?, ?, ?)
        ''', (
            event_type.value,
            model_id,
            user_id,
            json.dumps(action_details) if action_details else None,
            json.dumps(metadata) if metadata else None,
            risk_level
        ))
        
        self.conn.commit()
        return cursor.lastrowid
    
    def log_model_lineage(self,
                          model_id: str,
                          version: str,
                          parent_version: Optional[str] = None,
                          training_data_hash: Optional[str] = None,
                          hyperparameters: Optional[Dict] = None) -> int:
        """Log model lineage information"""
        
        cursor = self.conn.cursor()
        
        cursor.execute('''
            INSERT INTO model_lineage
            (model_id, version, parent_version, training_data_hash, hyperparameters)
            VALUES (?, ?, ?, ?, ?)
        ''', (
            model_id,
            version,
            parent_version,
            training_data_hash,
            json.dumps(hyperparameters) if hyperparameters else None
        ))
        
        self.conn.commit()
        return cursor.lastrowid
    
    def log_prediction(self,
                      model_id: str,
                      model_version: str,
                      input_data: Any,
                      prediction: Any,
                      confidence: float,
                      latency_ms: int,
                      user_id: Optional[str] = None):
        """Log individual prediction for audit"""
        
        # Hash input for privacy
        input_str = json.dumps(input_data, sort_keys=True)
        input_hash = hashlib.sha256(input_str.encode()).hexdigest()[:16]
        
        cursor = self.conn.cursor()
        
        cursor.execute('''
            INSERT INTO prediction_log
            (model_id, model_version, input_hash, prediction, confidence, latency_ms, user_id)
            VALUES (?, ?, ?, ?, ?, ?, ?)
        ''', (
            model_id,
            model_version,
            input_hash,
            json.dumps(prediction),
            confidence,
            latency_ms,
            user_id
        ))
        
        self.conn.commit()
    
    def get_model_history(self, model_id: str) -> List[Dict]:
        """Retrieve complete history for a model"""
        
        cursor = self.conn.cursor()
        
        cursor.execute('''
            SELECT * FROM audit_log
            WHERE model_id = ?
            ORDER BY timestamp DESC
        ''', (model_id,))
        
        columns = [desc[0] for desc in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]
    
    def generate_compliance_report(self, 
                                   model_id: str,
                                   start_date: str,
                                   end_date: str) -> Dict:
        """Generate compliance report for auditors"""
        
        cursor = self.conn.cursor()
        
        # Get event counts
        cursor.execute('''
            SELECT event_type, COUNT(*) as count
            FROM audit_log
            WHERE model_id = ? AND timestamp BETWEEN ? AND ?
            GROUP BY event_type
        ''', (model_id, start_date, end_date))
        
        event_counts = dict(cursor.fetchall())
        
        # Get risk events
        cursor.execute('''
            SELECT COUNT(*) FROM audit_log
            WHERE model_id = ? AND risk_level IN ('high', 'critical')
            AND timestamp BETWEEN ? AND ?
        ''', (model_id, start_date, end_date))
        
        high_risk_events = cursor.fetchone()[0]
        
        return {
            'model_id': model_id,
            'period': {'start': start_date, 'end': end_date},
            'event_counts': event_counts,
            'high_risk_events': high_risk_events,
            'generated_at': datetime.now().isoformat()
        }
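
The sketch below shows the intended flow for the audit trail: record lineage when a version is created, log the deployment, then pull a compliance report for auditors. The identifiers, user, and date range are illustrative placeholders.

# Illustrative usage sketch for the audit trail
audit = ModelAuditTrail(db_path="model_audit.db")

audit.log_model_lineage(
    model_id="credit-risk",
    version="2.1.0",
    parent_version="2.0.3",
    hyperparameters={"max_depth": 6, "n_estimators": 200}
)

audit.log_event(
    event_type=AuditEventType.MODEL_DEPLOYED,
    model_id="credit-risk",
    user_id="mlops-bot",
    action_details={"environment": "production"},
    risk_level="medium"
)

report = audit.generate_compliance_report(
    model_id="credit-risk",
    start_date="2024-01-01",
    end_date="2024-12-31"
)
print(report['event_counts'])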

Model Versioning and Rollback

Version Control for ML Models

Implement robust versioning to track model evolution and enable quick rollbacks when issues are detected.

# Model versioning and rollback system
import json
import pickle
from datetime import datetime
from pathlib import Path
from typing import Dict

class ModelVersionControl:
    """Version control system for ML models"""
    
    def __init__(self, repository_path: str):
        self.repo_path = Path(repository_path)
        self.repo_path.mkdir(parents=True, exist_ok=True)
        self.current_version = None
        self.version_history = []
    
    def save_model_version(self,
                          model,
                          version: str,
                          metadata: Dict,
                          is_production: bool = False) -> str:
        """Save a new model version"""
        
        version_path = self.repo_path / version
        version_path.mkdir(exist_ok=True)
        
        # Save model artifact
        model_file = version_path / "model.pkl"
        with open(model_file, 'wb') as f:
            pickle.dump(model, f)
        
        # Save metadata
        metadata['version'] = version
        metadata['created_at'] = datetime.now().isoformat()
        metadata['is_production'] = is_production
        
        metadata_file = version_path / "metadata.json"
        with open(metadata_file, 'w') as f:
            json.dump(metadata, f, indent=2)
        
        # Update version history
        self.version_history.append(version)
        
        if is_production:
            self.current_version = version
            self._update_production_link(version)
        
        return str(version_path)
    
    def rollback_to_version(self, version: str) -> bool:
        """Rollback to a previous model version"""
        
        if version not in self.version_history:
            raise ValueError(f"Version {version} not found")
        
        # Verify version exists and is valid
        version_path = self.repo_path / version
        if not version_path.exists():
            raise ValueError(f"Version {version} files missing")
        
        # Update production link
        self._update_production_link(version)
        self.current_version = version
        
        # Log rollback event
        self._log_rollback(version)
        
        return True
    
    def _update_production_link(self, version: str):
        """Update production symlink to point to version"""
        
        production_link = self.repo_path / "production"
        # Remove the old link even if its target is missing (a broken symlink)
        if production_link.is_symlink() or production_link.exists():
            production_link.unlink()
        
        version_path = self.repo_path / version
        production_link.symlink_to(version_path)
    
    def _log_rollback(self, version: str):
        """Log rollback event for audit"""
        
        rollback_log = self.repo_path / "rollback.log"
        with open(rollback_log, 'a') as f:
            f.write(f"{datetime.now().isoformat()}: Rolled back to {version}\n")
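
The sketch below shows the intended flow: save candidate versions, promote one to production, and roll back when monitoring flags a regression. `old_model` and `new_model` are assumed to be trained model objects; the version numbers and metadata are illustrative.

# Illustrative sketch: promote a new version, then roll back
vc = ModelVersionControl(repository_path="./model_repo")

vc.save_model_version(old_model, version="1.4.0", metadata={"auc": 0.91}, is_production=True)
vc.save_model_version(new_model, version="1.5.0", metadata={"auc": 0.93}, is_production=True)

# Monitoring detects a regression, revert to the last known-good version
vc.rollback_to_version("1.4.0")
print(vc.current_version)  # 1.4.0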

Explainability Requirements

Model Explainability Framework

Ensure models can explain their decisions to stakeholders, regulators, and affected individuals.

# Explainability framework using SHAP and LIME
import numpy as np
import shap
import lime.lime_tabular
from typing import Dict, List

class ModelExplainer:
    """Generate explanations for model predictions"""
    
    def __init__(self, model, X_train: np.ndarray, feature_names: List[str]):
        self.model = model
        self.X_train = X_train
        self.feature_names = feature_names
        
        # Initialize SHAP explainer
        self.shap_explainer = shap.Explainer(model, X_train)
        
        # Initialize LIME explainer
        self.lime_explainer = lime.lime_tabular.LimeTabularExplainer(
            X_train,
            feature_names=feature_names,
            mode='classification'
        )
    
    def explain_prediction(self, 
                          X_instance: np.ndarray,
                          method: str = 'shap') -> Dict:
        """Explain a single prediction"""
        
        if method == 'shap':
            return self._shap_explanation(X_instance)
        elif method == 'lime':
            return self._lime_explanation(X_instance)
        else:
            raise ValueError(f"Unknown method: {method}")
    
    def _shap_explanation(self, X_instance: np.ndarray) -> Dict:
        """Generate SHAP explanation"""
        
        # SHAP expects a 2D array, so reshape the single instance
        shap_values = self.shap_explainer(X_instance.reshape(1, -1))
        
        # Get feature importance
        feature_importance = {}
        for i, feature in enumerate(self.feature_names):
            feature_importance[feature] = float(shap_values.values[0][i])
        
        return {
            'method': 'shap',
            'feature_importance': feature_importance,
            'base_value': float(shap_values.base_values[0]),
            'prediction': float(self.model.predict(X_instance.reshape(1, -1))[0])
        }
    
    def _lime_explanation(self, X_instance: np.ndarray) -> Dict:
        """Generate LIME explanation (completes the branch used in explain_prediction)"""
        
        explanation = self.lime_explainer.explain_instance(
            X_instance,
            self.model.predict_proba,
            num_features=len(self.feature_names)
        )
        
        return {
            'method': 'lime',
            'feature_importance': dict(explanation.as_list()),
            'prediction': float(self.model.predict(X_instance.reshape(1, -1))[0])
        }
    
    def generate_report(self, X_instance: np.ndarray) -> str:
        """Generate human-readable explanation report"""
        
        explanation = self.explain_prediction(X_instance)
        prediction = explanation['prediction']
        
        report = f"""
        Model Prediction Explanation Report
        ===================================
        
        Prediction: {prediction:.4f}
        Base Value: {explanation['base_value']:.4f}
        
        Top Contributing Features:
        """
        
        # Sort features by importance
        sorted_features = sorted(
            explanation['feature_importance'].items(),
            key=lambda x: abs(x[1]),
            reverse=True
        )[:5]
        
        for feature, importance in sorted_features:
            direction = "increases" if importance > 0 else "decreases"
            report += f"\n- {feature}: {direction} prediction by {abs(importance):.4f}"
        
        return report
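
A minimal usage sketch, assuming a fitted scikit-learn classifier `clf`, its training matrix `X_train`, a held-out row `X_test[0]`, and a list of `feature_names` (none of these are defined by the framework above):

# Illustrative sketch: explain a single prediction
explainer = ModelExplainer(clf, X_train, feature_names=feature_names)

instance = X_test[0]  # one row to explain
explanation = explainer.explain_prediction(instance, method='shap')
print(explanation['feature_importance'])
print(explainer.generate_report(instance))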

Model Governance Best Practices

  • Establish clear governance framework: Define roles, responsibilities, and processes
  • Implement continuous monitoring: Track model performance and fairness metrics
  • Maintain comprehensive documentation: Document all model decisions and changes
  • Regular bias audits: Test for bias across all protected attributes
  • Version everything: Track models, data, code, and configurations
  • Enable quick rollbacks: Maintain ability to revert to previous versions
  • Ensure explainability: Provide clear explanations for all decisions
  • Implement approval workflows: Require review before production deployment
  • Regular compliance checks: Ensure ongoing regulatory compliance
  • Incident response plan: Have procedures for handling model failures
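
As a concrete example of the approval-workflow practice above, the sketch below gates promotion to production on a passing validation report and records the decision in the audit trail; the function name and field choices are illustrative, not a prescribed implementation.

# Illustrative sketch: approval gate tying validation, audit, and deployment together
def approve_deployment(validation_report: dict, approver: str, audit: ModelAuditTrail) -> bool:
    """Allow production promotion only when the validation report passed."""
    approved = validation_report['overall_status'] == 'PASSED'
    audit.log_event(
        event_type=AuditEventType.MODEL_VALIDATED,
        model_id=validation_report['model_id'],
        user_id=approver,
        action_details={'approved': approved,
                        'integrity_hash': validation_report['integrity_hash']},
        risk_level='low' if approved else 'high'
    )
    return approved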
