MLflow & Machine Learning on Databricks

Hard 35 min read

What is MLflow on Databricks?

Why MLflow Matters

The Problem: Data science teams struggle with experiment reproducibility, model versioning, and the gap between training a model in a notebook and deploying it to production.

The Solution: MLflow is an open-source platform for the complete ML lifecycle: experiment tracking, model packaging, registry, and serving. Databricks provides a fully managed, integrated MLflow experience.

Real Impact: Teams using MLflow on Databricks deploy models to production 3x faster and reduce the time to reproduce experiments from days to seconds.

Real-World Analogy

Think of MLflow like a professional kitchen management system:

  • Experiment Tracking = Recipe notebook where you log every ingredient, quantity, and outcome
  • Model Registry = The approved recipe book -- only tested recipes make it in
  • Model Serving = The actual restaurant kitchen serving dishes to customers
  • Feature Store = Pre-prepped ingredients ready to use in any recipe

MLflow Architecture on Databricks

[Architecture diagram] Main flow: Experiment Tracking --> (Register) --> Model Registry --> (Deploy) --> Model Serving

  • Experiment Tracking = Parameters & hyperparameters; metrics (accuracy, loss, F1); artifacts (model files, plots)
  • Model Registry = Version management; stage transitions; approval workflows
  • Model Serving = REST API endpoints; auto-scaling; A/B testing
  • AutoML = Automated model selection; hyperparameter tuning; feature importance
  • Feature Store = Centralized features; point-in-time lookups; online & offline serving
  • Registry Lifecycle = None --> Staging --> Production --> Archived

Experiment Tracking

MLflow Tracking lets you log parameters, metrics, and artifacts for every training run. On Databricks, each notebook gets a default experiment, and Databricks Autologging can record runs for supported ML libraries without explicit logging calls.

PySpark/Python - MLflow Experiment Tracking
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score

# Set the experiment name
mlflow.set_experiment("/Users/[email protected]/churn-prediction")

# Load data from Silver layer
df = spark.table("silver.customer_features").toPandas()
X = df.drop("churned", axis=1)
y = df["churned"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

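# Enable scikit-learn autologging for parameters and training metrics.
# log_models=False avoids storing the model twice, since it is logged
# explicitly inside the run below.
mlflow.sklearn.autolog(log_models=False)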

with mlflow.start_run(run_name="rf-baseline") as run:
    # Log custom parameters
    mlflow.log_param("dataset_version", "2024-03-15")
    mlflow.log_param("feature_count", X.shape[1])

    # Train model
    model = RandomForestClassifier(
        n_estimators=100, max_depth=10, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate on the held-out test set and log metrics
    predictions = model.predict(X_test)
    accuracy = accuracy_score(y_test, predictions)
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("f1_score", f1_score(y_test, predictions))
    mlflow.log_metric("precision", precision_score(y_test, predictions))

    # Log the model; the input example lets MLflow infer a signature,
    # which Unity Catalog requires when the model is registered
    mlflow.sklearn.log_model(
        model, "random_forest_model", input_example=X_train.head(5))

    print(f"Run ID: {run.info.run_id}")
    print(f"Accuracy: {accuracy:.4f}")
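
The metrics logged in this run all derive from the same confusion-matrix counts. As a minimal sketch, independent of MLflow and scikit-learn, the function below computes them for binary 0/1 labels. The function name `binary_metrics` and the toy labels are illustrative assumptions, not part of the original example.

```python
def binary_metrics(y_true, y_pred):
    """Compute accuracy, precision, recall, and F1 for binary 0/1 labels."""
    pairs = list(zip(y_true, y_pred))
    if not pairs:
        raise ValueError("at least one label is required")

    # Confusion-matrix counts
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)

    accuracy = (tp + tn) / len(pairs)
    # Fall back to 0.0 when a denominator would be zero
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1_score": f1}


# Toy labels (hypothetical): 3 true positives, 1 false negative,
# 1 false positive, 5 true negatives
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]

metrics = binary_metrics(y_true, y_pred)
print(metrics)
```

A dictionary in this shape could be passed to `mlflow.log_metrics` inside a run instead of logging each value separately.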

Model Registry

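The Model Registry provides a centralized store for managing model versions, promotion between environments, and approval workflows. In Unity Catalog, models are governed just like tables, and the fixed stages of the workspace registry (Staging, Production, Archived) are replaced by aliases such as "staging" and "production" that point at a specific model version.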

Python - Model Registry Operations
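import mlflow
from mlflow import MlflowClient

# Target the Unity Catalog registry
# (a no-op where Unity Catalog is already the default registry)
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

# Register the model logged by an experiment run
# (`run` is the run object from the experiment tracking example)
model_uri = f"runs:/{run.info.run_id}/random_forest_model"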
model_details = mlflow.register_model(
    model_uri,
    "prod_catalog.ml_models.churn_predictor"  # UC model path
)

# Add model description
client.update_registered_model(
    name="prod_catalog.ml_models.churn_predictor",
    description="Predicts customer churn based on usage features"
)

# Point the "staging" alias at the new version
# (Unity Catalog uses aliases instead of stage transitions)
client.set_registered_model_alias(
    name="prod_catalog.ml_models.churn_predictor",
    alias="staging",
    version=model_details.version
)

# After validation, promote by moving the "production" alias to this version
client.set_registered_model_alias(
    name="prod_catalog.ml_models.churn_predictor",
    alias="production",
    version=model_details.version
)

# Load a model by alias for inference
model = mlflow.pyfunc.load_model(
    "models:/prod_catalog.ml_models.churn_predictor@production"
)
# new_data: a pandas DataFrame with the same feature columns used in training
predictions = model.predict(new_data)
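
Aliases behave like named, movable pointers to model versions, which is why promotion and rollback are both a single alias reassignment. The sketch below is a toy in-memory model of that behavior, not MLflow's implementation; the class `AliasRegistry` and its methods are hypothetical names used only for illustration.

```python
class AliasRegistry:
    """Toy model of a registry where aliases point at version numbers."""

    def __init__(self):
        self._next_version = 1
        self._aliases = {}

    def register(self):
        """Create a new version and return its number."""
        version = self._next_version
        self._next_version += 1
        return version

    def set_alias(self, alias, version):
        """Point an alias at an existing version, replacing any old target."""
        if not 1 <= version < self._next_version:
            raise ValueError(f"unknown version: {version}")
        self._aliases[alias] = version

    def resolve(self, alias):
        """Return the version an alias currently points at."""
        return self._aliases[alias]


registry = AliasRegistry()

v1 = registry.register()
registry.set_alias("production", v1)

v2 = registry.register()
registry.set_alias("staging", v2)

# Promotion: move the "production" alias to the validated version
registry.set_alias("production", v2)
promoted = registry.resolve("production")

# Rollback: move the alias back to the previous version
registry.set_alias("production", v1)
rolled_back = registry.resolve("production")

print(promoted, rolled_back)
```

Because consumers load the model through the alias URI rather than a version number, moving the alias changes which version they load next without any change to their code.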

AutoML

Databricks AutoML automatically trains and tunes models, generating a leaderboard of results. It creates reproducible notebooks for each trial so you can inspect and customize the best approach.

Python - Databricks AutoML
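import mlflow
from databricks import automl

# Run AutoML classification; the experiment stops when the timeout is reached.
# max_trials is omitted because recent Databricks ML runtimes
# (11.0 ML and above) do not support it.
summary = automl.classify(
    dataset=spark.table("silver.customer_features"),
    target_col="churned",
    primary_metric="f1",
    timeout_minutes=30
)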

# View results
print(f"Best trial: {summary.best_trial}")
print(f"Best F1: {summary.best_trial.metrics['test_f1_score']:.4f}")

# The best model is automatically logged to MLflow
best_model = mlflow.pyfunc.load_model(
    f"runs:/{summary.best_trial.mlflow_run_id}/model"
)

# AutoML also supports regression and forecasting
reg_summary = automl.regress(
    dataset=spark.table("silver.house_features"),
    target_col="price",
    primary_metric="rmse",
    timeout_minutes=20
)
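
The leaderboard can also be rebuilt in code by sorting trials on a chosen metric. The sketch below assumes the summary exposes a list of trials (for example `summary.trials`) whose items carry a `metrics` dictionary and an `mlflow_run_id`, mirroring `summary.best_trial` above. The `leaderboard` helper and the stand-in trials are hypothetical.

```python
from types import SimpleNamespace


def leaderboard(trials, metric, higher_is_better=True):
    """Return trials that report `metric`, sorted from best to worst."""
    scored = [t for t in trials if metric in t.metrics]
    return sorted(scored, key=lambda t: t.metrics[metric],
                  reverse=higher_is_better)


# Stand-in trial objects (hypothetical values) so the sketch runs anywhere;
# on Databricks these would come from the AutoML summary instead
trials = [
    SimpleNamespace(mlflow_run_id="run-a", metrics={"test_f1_score": 0.81}),
    SimpleNamespace(mlflow_run_id="run-b", metrics={"test_f1_score": 0.87}),
    SimpleNamespace(mlflow_run_id="run-c", metrics={"test_f1_score": 0.79}),
]

ranked = leaderboard(trials, "test_f1_score")
for rank, trial in enumerate(ranked, start=1):
    print(rank, trial.mlflow_run_id, trial.metrics["test_f1_score"])
```

For error metrics such as RMSE, pass `higher_is_better=False` so the lowest value ranks first.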

Feature Store

The Feature Store provides a centralized repository for ML features, ensuring consistency between training and serving. Features are stored as Delta tables governed by Unity Catalog.

Python - Feature Store Operations
from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup

fe = FeatureEngineeringClient()

# Create a feature table
customer_features = spark.sql("""
    SELECT
        user_id,
        COUNT(*) as total_orders,
        AVG(amount) as avg_order_value,
        MAX(order_date) as last_order_date,
        DATEDIFF(current_date(), MAX(order_date)) as days_since_last_order
    FROM silver.orders
    GROUP BY user_id
""")

fe.create_table(
    name="prod_catalog.ml_features.customer_features",
    primary_keys=["user_id"],
    df=customer_features,
    description="Customer behavior features for churn prediction"
)

# Train a model with feature lookups
training_set = fe.create_training_set(
    df=spark.table("silver.churn_labels"),  # user_id + label
    feature_lookups=[
        FeatureLookup(
            table_name="prod_catalog.ml_features.customer_features",
            lookup_key="user_id"
        )
    ],
    label="churned"
)

# Convert to pandas for sklearn training
training_df = training_set.load_df().toPandas()
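
Consistency between training and serving also depends on joining each label to the feature values that were known at the label's timestamp, which the architecture overview calls a point-in-time lookup. The sketch below illustrates the idea in plain Python. It is not the Feature Engineering client's implementation, and the function name and feature history are hypothetical.

```python
from bisect import bisect_right
from datetime import date


def point_in_time_lookup(history, as_of):
    """Return the latest feature value recorded on or before `as_of`.

    `history` is a list of (timestamp, value) pairs sorted by timestamp.
    Returns None when no value existed yet at `as_of`.
    """
    timestamps = [ts for ts, _ in history]
    idx = bisect_right(timestamps, as_of)
    return history[idx - 1][1] if idx else None


# Hypothetical snapshots of total_orders for one user
total_orders_history = [
    (date(2024, 1, 1), 3),
    (date(2024, 2, 1), 5),
    (date(2024, 3, 1), 9),
]

# A churn label observed on 2024-02-15 must only see the 2024-02-01 snapshot;
# using the 2024-03-01 value would leak future information into training
label_date = date(2024, 2, 15)
feature_at_label = point_in_time_lookup(total_orders_history, label_date)
print(feature_at_label)  # 5
```

On Databricks, this behavior is understood to be requested through a time series feature table and a `timestamp_lookup_key` argument on `FeatureLookup`; verify the exact parameters against the client version in use.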

Model Serving

Databricks Model Serving deploys MLflow models as auto-scaling REST API endpoints. It handles infrastructure, scaling, and monitoring automatically.

Python - Deploy and Query a Model Endpoint
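import os
import requests

# Model Serving is configured via UI or REST API
# After deployment, query the endpoint:

endpoint_url = (
    "https://my-workspace.databricks.com"
    "/serving-endpoints/churn-predictor/invocations"
)

# Read the access token from the environment rather than hard-coding it
# (DATABRICKS_TOKEN is an assumed variable name)
token = os.environ["DATABRICKS_TOKEN"]

headers = {
    "Authorization": f"Bearer {token}",
    "Content-Type": "application/json"
}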

# Single prediction
payload = {
    "dataframe_records": [{
        "total_orders": 15,
        "avg_order_value": 82.50,
        "days_since_last_order": 45
    }]
}

response = requests.post(endpoint_url, headers=headers,
                         json=payload, timeout=30)
response.raise_for_status()  # Surface HTTP errors instead of parsing them
print(response.json())
# {"predictions": [0]}  -- 0 = not churning, 1 = churning

# Batch prediction using the model directly
import mlflow

model = mlflow.pyfunc.load_model(
    "models:/prod_catalog.ml_models.churn_predictor@production"
)
batch_predictions = model.predict(
    spark.table("silver.customer_features").toPandas()
)
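
The request above uses the row-oriented `dataframe_records` format. MLflow scoring endpoints also accept a column-oriented `dataframe_split` format, which avoids repeating column names for every row. The helpers below are a minimal sketch of building both payloads from the same rows; the function names and the second sample row are hypothetical.

```python
import json


def to_dataframe_records(rows):
    """Build a row-oriented payload: one dict per record."""
    return {"dataframe_records": list(rows)}


def to_dataframe_split(rows):
    """Build a column-oriented payload: column names once, then row values."""
    rows = list(rows)
    if not rows:
        raise ValueError("at least one row is required")
    columns = list(rows[0].keys())
    # Order every row's values by the shared column list
    data = [[row[col] for col in columns] for row in rows]
    return {"dataframe_split": {"columns": columns, "data": data}}


rows = [
    {"total_orders": 15, "avg_order_value": 82.50,
     "days_since_last_order": 45},
    {"total_orders": 2, "avg_order_value": 19.99,
     "days_since_last_order": 210},
]

records_payload = to_dataframe_records(rows)
split_payload = to_dataframe_split(rows)
print(json.dumps(split_payload))
```

Either dictionary can be sent as the `json` argument of `requests.post` in the endpoint example.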

Practice Problems

Problem 1: Design an ML Pipeline

Medium

Design an end-to-end ML pipeline for predicting delivery delays. Data comes from silver.orders and silver.logistics tables. The model needs to be retrained weekly and served via REST API.

Problem 2: A/B Testing Models

Hard

You have two candidate models for fraud detection: a gradient boosted tree (fast, 92% recall) and a deep learning model (slower, 97% recall). Design an A/B testing strategy using Model Serving to determine which performs better in production.

Problem 3: Feature Store Design

Medium

Design a Feature Store schema for a recommendation engine. Users browse products, add to cart, and make purchases. The model needs real-time features (current session) and historical features (past 30 days of behavior).