What is MLflow on Databricks?
Why MLflow Matters
The Problem: Data science teams struggle with experiment reproducibility, model versioning, and the gap between training a model in a notebook and deploying it to production.
The Solution: MLflow is an open-source platform for the complete ML lifecycle: experiment tracking, model packaging, registry, and serving. Databricks provides a fully managed, integrated MLflow experience.
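Real Impact: Because every run's parameters, metrics, and model artifacts are recorded, reproducing an experiment means looking up a run rather than reconstructing notebook state, and a registered model can move from training to a serving endpoint without being repackaged.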
Real-World Analogy
Think of MLflow and the Databricks ML tools around it like a professional kitchen management system:
- Experiment Tracking = Recipe notebook where you log every ingredient, quantity, and outcome
- Model Registry = The approved recipe book: only tested recipes make it in
- Model Serving = The actual restaurant kitchen serving dishes to customers
- Feature Store = Pre-prepped ingredients ready to use in any recipe
Experiment Tracking
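MLflow Tracking lets you log parameters, metrics, and artifacts for every training run. On Databricks Runtime for ML, Databricks Autologging is on by default for supported libraries such as scikit-learn, and runs started from a notebook without an explicitly set experiment are recorded in that notebook's own experiment. The example below sets an experiment explicitly and adds custom parameters and test metrics on top of autologging.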
import mlflow
import mlflow.sklearn
from mlflow.models import infer_signature
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score

# Set the experiment (placeholder path: use your workspace user name, usually your email)
mlflow.set_experiment("/Users/<your-user-name>/churn-prediction")

# Load data from the Silver layer (`spark` is predefined in Databricks notebooks)
df = spark.table("silver.customer_features").toPandas()
X = df.drop("churned", axis=1)
y = df["churned"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

# Autolog parameters and training metrics; the model is logged explicitly
# below, so skip autologged models to avoid storing a second copy
mlflow.sklearn.autolog(log_models=False)

with mlflow.start_run(run_name="rf-baseline") as run:
    # Log custom parameters that autologging cannot know about
    mlflow.log_param("dataset_version", "2024-03-15")
    mlflow.log_param("feature_count", X.shape[1])

    # Train model
    model = RandomForestClassifier(
        n_estimators=100, max_depth=10, random_state=42)
    model.fit(X_train, y_train)

    # Evaluate on the held-out set (assumes binary 0/1 labels) and log metrics
    predictions = model.predict(X_test)
    accuracy = accuracy_score(y_test, predictions)
    mlflow.log_metric("accuracy", accuracy)
    mlflow.log_metric("f1_score", f1_score(y_test, predictions))
    mlflow.log_metric("precision", precision_score(y_test, predictions))

    # Log the model with a signature (required to register it in Unity Catalog)
    signature = infer_signature(X_train, model.predict(X_train))
    mlflow.sklearn.log_model(model, "random_forest_model", signature=signature)

print(f"Run ID: {run.info.run_id}")
print(f"Accuracy: {accuracy:.4f}")
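To make the difference between parameters and metrics concrete, here is a toy, pure-Python stand-in for a run. It is not the MLflow API: `ToyRun` and its methods are hypothetical names. It mimics two behaviors of real MLflow runs as a sketch: a parameter is write-once (MLflow rejects changing a value already logged in a run), while a metric is a time series in which each call adds a (step, value) point and the run summary shows the latest one.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRun:
    """Hypothetical stand-in for an MLflow run; not the MLflow API."""
    run_id: str
    params: dict = field(default_factory=dict)
    metric_history: dict = field(default_factory=dict)

    def log_param(self, key, value):
        # MLflow stores parameter values as strings
        value = str(value)
        # Write-once: re-logging the same value is fine, changing it is an error
        if key in self.params and self.params[key] != value:
            raise ValueError(
                f"param {key!r} is already {self.params[key]!r}; cannot change it")
        self.params[key] = value

    def log_metric(self, key, value, step=0):
        # Every call appends a point, so the full history is kept
        self.metric_history.setdefault(key, []).append((step, float(value)))

    def latest_metrics(self):
        # Summary value per metric: the point with the highest step
        # (stable sort, so a later call wins ties at the same step)
        return {key: sorted(points, key=lambda p: p[0])[-1][1]
                for key, points in self.metric_history.items()}


run = ToyRun("demo-run")
run.log_param("max_depth", 10)
run.log_param("max_depth", 10)  # same value again: accepted
for epoch, loss in enumerate([0.9, 0.6, 0.4]):
    run.log_metric("loss", loss, step=epoch)
print(run.params)            # {'max_depth': '10'}
print(run.latest_metrics())  # {'loss': 0.4}
```

Keeping parameters immutable is what makes a run reproducible: the recorded configuration cannot drift after the fact, while metrics such as a per-epoch loss naturally accumulate.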
Model Registry
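The Model Registry provides a centralized store for managing model versions, aliases, and lineage. In Unity Catalog, models are governed like tables: each one lives under a three-level catalog.schema.model name and is secured with the same grants. Unity Catalog does not support the legacy Staging/Production stages; instead, you assign aliases such as staging or production to specific versions.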
import mlflow
from mlflow import MlflowClient

# Register to Unity Catalog rather than the legacy workspace registry
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

model_name = "prod_catalog.ml_models.churn_predictor"  # UC path: catalog.schema.model

# Register the model logged by the tracking run above (`run` comes from that cell)
model_uri = f"runs:/{run.info.run_id}/random_forest_model"
model_details = mlflow.register_model(model_uri, model_name)

# Add model description
client.update_registered_model(
    name=model_name,
    description="Predicts customer churn based on usage features"
)

# Unity Catalog has no stages; point the "staging" alias at the new version
client.set_registered_model_alias(
    name=model_name,
    alias="staging",
    version=model_details.version
)

# After validation, point the "production" alias at the same version
client.set_registered_model_alias(
    name=model_name,
    alias="production",
    version=model_details.version
)

# Load a model by alias for inference
model = mlflow.pyfunc.load_model(f"models:/{model_name}@production")
# `new_data` is assumed to be a pandas DataFrame with the training feature columns
predictions = model.predict(new_data)
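Aliases are easiest to reason about as movable pointers. The sketch below is a hypothetical in-memory registry, not `MlflowClient`: each registration creates a new numbered version, an alias names exactly one version at a time, and promotion or rollback is just reassigning the alias. The run URIs are made-up placeholders.

```python
class ToyRegistry:
    """Hypothetical in-memory registry illustrating versions and aliases."""

    def __init__(self):
        self.versions = {}  # model name -> list of source URIs (version n at index n-1)
        self.aliases = {}   # (model name, alias) -> version number

    def register(self, name, source_uri):
        # Each registration adds a new version; existing versions never change
        self.versions.setdefault(name, []).append(source_uri)
        return len(self.versions[name])

    def set_alias(self, name, alias, version):
        if not 1 <= version <= len(self.versions.get(name, [])):
            raise KeyError(f"{name} has no version {version}")
        # An alias names exactly one version; setting it again moves the pointer
        self.aliases[(name, alias)] = version

    def resolve(self, uri):
        # Resolve "models:/<name>@<alias>" to (version, source URI)
        name, alias = uri[len("models:/"):].split("@")
        version = self.aliases[(name, alias)]
        return version, self.versions[name][version - 1]


reg = ToyRegistry()
name = "prod_catalog.ml_models.churn_predictor"
v1 = reg.register(name, "runs:/run-a/random_forest_model")  # placeholder run IDs
v2 = reg.register(name, "runs:/run-b/random_forest_model")
reg.set_alias(name, "production", v1)
reg.set_alias(name, "production", v2)  # promote v2
print(reg.resolve(f"models:/{name}@production"))  # (2, 'runs:/run-b/random_forest_model')
reg.set_alias(name, "production", v1)  # roll back by moving the alias
print(reg.resolve(f"models:/{name}@production"))  # (1, 'runs:/run-a/random_forest_model')
```

Because inference code loads `models:/...@production` rather than a version number, a rollback needs no change to the consuming code, only an alias update.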
AutoML
Databricks AutoML automatically trains and tunes models, generating a leaderboard of results. It creates reproducible notebooks for each trial so you can inspect and customize the best approach.
import mlflow
from databricks import automl

# Run AutoML classification (needs a cluster running Databricks Runtime for ML).
# The search is bounded by timeout_minutes; max_trials is deprecated in recent
# Databricks Runtime ML versions, so it is omitted here.
summary = automl.classify(
    dataset=spark.table("silver.customer_features"),
    target_col="churned",
    primary_metric="f1",
    timeout_minutes=30
)

# View results
print(f"Best trial: {summary.best_trial}")
print(f"Best F1: {summary.best_trial.metrics['test_f1_score']:.4f}")

# Each trial's model is logged to MLflow; load the best one by its run ID
best_model = mlflow.pyfunc.load_model(
    f"runs:/{summary.best_trial.mlflow_run_id}/model"
)

# AutoML also supports regression (automl.regress) and forecasting (automl.forecast)
reg_summary = automl.regress(
    dataset=spark.table("silver.house_features"),
    target_col="price",
    primary_metric="rmse",
    timeout_minutes=20
)
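The leaderboard is ranked by primary_metric, and the sort direction depends on the metric: F1 is better when higher, RMSE when lower. A minimal sketch of that ranking, assuming trial results shaped as plain dicts; the trial names, metric keys, and values are hypothetical (AutoML's own metric keys carry prefixes such as test_, as in test_f1_score above).

```python
# Metrics where lower is better; everything else is ranked highest-first
LOWER_IS_BETTER = {"rmse", "mse", "mae"}


def leaderboard(trials, primary_metric):
    """Rank trial dicts best-first by primary_metric, skipping trials without it."""
    scored = [t for t in trials if primary_metric in t["metrics"]]
    descending = primary_metric not in LOWER_IS_BETTER
    return sorted(scored, key=lambda t: t["metrics"][primary_metric],
                  reverse=descending)


# Hypothetical trial results, for illustration only
clf_trials = [
    {"name": "trial-1", "metrics": {"f1": 0.81}},
    {"name": "trial-2", "metrics": {"f1": 0.86}},
    {"name": "trial-3", "metrics": {"f1": 0.78}},
]
reg_trials = [
    {"name": "trial-a", "metrics": {"rmse": 41200.0}},
    {"name": "trial-b", "metrics": {"rmse": 39800.0}},
    {"name": "trial-c", "metrics": {}},  # e.g. a trial that failed before scoring
]
print([t["name"] for t in leaderboard(clf_trials, "f1")])    # ['trial-2', 'trial-1', 'trial-3']
print([t["name"] for t in leaderboard(reg_trials, "rmse")])  # ['trial-b', 'trial-a']
```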
Feature Store
The Feature Store provides a centralized repository for ML features, ensuring consistency between training and serving. Features are stored as Delta tables governed by Unity Catalog.
from databricks.feature_engineering import FeatureEngineeringClient, FeatureLookup

fe = FeatureEngineeringClient()

# Compute one row of features per user from the Silver orders table
customer_features = spark.sql("""
    SELECT
        user_id,
        COUNT(*) AS total_orders,
        AVG(amount) AS avg_order_value,
        MAX(order_date) AS last_order_date,
        DATEDIFF(current_date(), MAX(order_date)) AS days_since_last_order
    FROM silver.orders
    GROUP BY user_id
""")

# Create a Unity Catalog feature table keyed by user_id
fe.create_table(
    name="prod_catalog.ml_features.customer_features",
    primary_keys=["user_id"],
    df=customer_features,
    description="Customer behavior features for churn prediction"
)

# Build a training set: join the labels to the feature table on user_id
training_set = fe.create_training_set(
    df=spark.table("silver.churn_labels"),  # columns: user_id, churned
    feature_lookups=[
        FeatureLookup(
            table_name="prod_catalog.ml_features.customer_features",
            lookup_key="user_id"
        )
    ],
    label="churned",
    # Drop the key and the raw date, which sklearn cannot use as numeric features
    exclude_columns=["user_id", "last_order_date"]
)

# Convert to pandas for sklearn training
training_df = training_set.load_df().toPandas()
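The consistency guarantee comes from computing features once and looking them up by the same key on both the training and serving paths. This plain-Python sketch mirrors the SQL aggregation above and the left-join behavior of a FeatureLookup (labeled rows without matching features get nulls). The function names and sample rows are hypothetical, and this is not the databricks.feature_engineering API.

```python
from collections import defaultdict
from datetime import date

FEATURE_COLUMNS = ["total_orders", "avg_order_value",
                   "last_order_date", "days_since_last_order"]


def compute_customer_features(orders, today):
    """Mirror of the SQL aggregation: one feature row per user_id."""
    by_user = defaultdict(list)
    for order in orders:
        by_user[order["user_id"]].append(order)
    table = {}
    for user_id, rows in by_user.items():
        last = max(r["order_date"] for r in rows)
        table[user_id] = {
            "total_orders": len(rows),
            "avg_order_value": sum(r["amount"] for r in rows) / len(rows),
            "last_order_date": last,
            "days_since_last_order": (today - last).days,
        }
    return table


def lookup_features(feature_table, user_id):
    # Missing keys yield None for every feature, like nulls from a left join
    row = feature_table.get(user_id, {})
    return {col: row.get(col) for col in FEATURE_COLUMNS}


def build_training_rows(labels, feature_table):
    # Training path: attach features to each labeled row by lookup key
    return [{**label, **lookup_features(feature_table, label["user_id"])}
            for label in labels]


orders = [  # hypothetical sample rows
    {"user_id": 1, "amount": 50.0, "order_date": date(2024, 3, 1)},
    {"user_id": 1, "amount": 100.0, "order_date": date(2024, 3, 10)},
    {"user_id": 2, "amount": 20.0, "order_date": date(2024, 2, 1)},
]
features = compute_customer_features(orders, today=date(2024, 3, 15))
training_rows = build_training_rows(
    [{"user_id": 1, "churned": 0}, {"user_id": 3, "churned": 1}], features)
# Serving path: the same table and the same lookup function
serving_row = lookup_features(features, 1)
print(serving_row["total_orders"], serving_row["days_since_last_order"])  # 2 5
```

If training and serving each reimplemented the aggregation, a small difference (say, in how days are counted) would silently skew predictions; sharing one table and one lookup removes that risk.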
Model Serving
Databricks Model Serving deploys MLflow models as auto-scaling REST API endpoints. It handles infrastructure, scaling, and monitoring automatically.
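import os
import requests

# Model Serving is configured via UI or REST API
# After deployment, query the endpoint (workspace URL and endpoint name are placeholders):
endpoint_url = (
    "https://my-workspace.databricks.com"
    "/serving-endpoints/churn-predictor/invocations"
)
# Assumes a personal access token stored in the DATABRICKS_TOKEN environment variable
token = os.environ["DATABRICKS_TOKEN"]
headers = {
    "Authorization": f"Bearer {token}",
    "Content-Type": "application/json"
}

# Single prediction
payload = {
    "dataframe_records": [{
        "total_orders": 15,
        "avg_order_value": 82.50,
        "days_since_last_order": 45
    }]
}
response = requests.post(endpoint_url, headers=headers,
                         json=payload, timeout=30)
response.raise_for_status()
print(response.json())
# Example response shape: {"predictions": [0]}, where 0 = not churning, 1 = churning

# Batch prediction using the model directly
import mlflow
model = mlflow.pyfunc.load_model(
    "models:/prod_catalog.ml_models.churn_predictor@production"
)
# Drop the label so only feature columns are scored. toPandas() collects the
# table to the driver; for large tables, mlflow.pyfunc.spark_udf scales better.
batch_predictions = model.predict(
    spark.table("silver.customer_features").drop("churned").toPandas()
)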
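The request above uses the dataframe_records format (one JSON object per row). Model Serving also accepts dataframe_split, which lists the column names once and sends each row as an array, so it is more compact for batches. A small stdlib helper to convert between the two, assuming every row is a flat dict; records_to_split and the second sample row are hypothetical.

```python
import json


def records_to_split(records):
    """Convert dataframe_records rows into an equivalent dataframe_split payload."""
    columns = []
    for record in records:
        for col in record:
            if col not in columns:  # keep first-seen column order
                columns.append(col)
    # Columns absent from a row are sent as null
    data = [[record.get(col) for col in columns] for record in records]
    return {"dataframe_split": {"columns": columns, "data": data}}


records = [
    {"total_orders": 15, "avg_order_value": 82.50, "days_since_last_order": 45},
    {"total_orders": 3, "avg_order_value": 40.00, "days_since_last_order": 120},
]
print(json.dumps(records_to_split(records)))
```

Either dict can be passed as the json= argument of requests.post in the snippet above. pandas users can get the same columns/data shape from DataFrame.to_dict(orient="split"), which also includes an index.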
Practice Problems
Problem 1: Design an ML Pipeline
Difficulty: Medium
Design an end-to-end ML pipeline for predicting delivery delays. Data comes from silver.orders and silver.logistics tables. The model needs to be retrained weekly and served via REST API.
Problem 2: A/B Testing Models
Difficulty: Hard
You have two candidate models for fraud detection: a gradient boosted tree (fast, 92% recall) and a deep learning model (slower, 97% recall). Design an A/B testing strategy using Model Serving to determine which performs better in production.
Problem 3: Feature Store Design
Difficulty: Medium
Design a Feature Store schema for a recommendation engine. Users browse products, add to cart, and make purchases. The model needs real-time features (current session) and historical features (past 30 days of behavior).