Problem Statement & Requirements
Why ML Training Pipelines Matter
At companies like Google, Meta, and OpenAI, models are retrained hundreds of times per day. A well-designed training pipeline is the backbone of any ML system — it determines how quickly you can iterate on models, reproduce experiments, and ship improvements to production.
Think of an ML training pipeline like a factory assembly line. Raw materials (data) enter one end, pass through quality control (validation), get processed (feature engineering), assembled (training), tested (evaluation), and finally shipped (deployed to production). Each step must be reliable, repeatable, and observable.
Functional Requirements
- Data ingestion & validation — Ingest from multiple sources with schema validation and drift detection
- Feature engineering — Reproducible feature transformations with versioning
- Distributed training — Scale training across multiple GPUs/nodes
- Experiment tracking — Log hyperparameters, metrics, artifacts for every run
- Hyperparameter tuning — Automated search with early stopping
- Model registry — Version, stage, and promote models through environments
- Pipeline orchestration — DAG-based workflows with retries and scheduling
Non-Functional Requirements
- Reproducibility — Any experiment can be exactly reproduced from logged metadata
- Fault tolerance — Resume from checkpoints after node failures
- Scalability — Handle petabyte-scale datasets and 100+ GPU clusters
- Cost efficiency — Utilize spot/preemptible instances, auto-scale down when idle
Back-of-Envelope Estimation
| Parameter | Estimate |
|---|---|
| Training data size | 10 TB (tabular) / 500 TB (vision/NLP) |
| Training runs per day | 50-200 across teams |
| GPU cluster size | 256 A100 GPUs (32 nodes × 8 GPUs) |
| Single training run (large model) | 4-72 hours |
| Checkpoint size | 1-50 GB per checkpoint |
| Experiment metadata | ~10 KB per run × 200 runs/day = 2 MB/day |
| Artifact storage | ~500 GB/day (models, logs, plots) |
| GPU cost | $2-3/hour per A100 ≈ $12K-18K/day for a 256-GPU cluster |
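The cluster cost line can be sanity-checked with quick arithmetic (the $2-3/hour rate is an illustrative assumption; actual on-demand and spot prices vary by provider and region):

```python
# Back-of-envelope daily GPU cost for the cluster above
GPUS = 256
HOURS_PER_DAY = 24

def daily_cost(rate_per_gpu_hour):
    """Total daily cluster cost at a given per-GPU hourly rate."""
    return GPUS * HOURS_PER_DAY * rate_per_gpu_hour

low, high = daily_cost(2.0), daily_cost(3.0)
print(f"${low:,.0f} - ${high:,.0f} per day")  # $12,288 - $18,432 per day
```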
System API Design
# Submit a training pipeline run
POST /api/v1/pipelines/runs
{
"pipeline_id": "fraud_model_v2",
"config": {
"data_source": "s3://data/fraud/2024/",
"model_type": "xgboost",
"hyperparams": { "lr": 0.01, "depth": 8 },
"resources": { "gpus": 4, "memory": "64Gi" }
}
}
# Get experiment run details
GET /api/v1/experiments/{experiment_id}/runs/{run_id}
# Register a trained model
POST /api/v1/models/register
{
"name": "fraud_detector",
"version": "3.1",
"artifact_uri": "s3://models/fraud/run_123/",
"metrics": { "auc": 0.97, "f1": 0.89 },
"stage": "staging"
}
# Promote model to production
PUT /api/v1/models/fraud_detector/versions/3.1/stage
{ "stage": "production" }
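A client submitting a run would build the request body shown above; a minimal sketch that constructs and serializes the payload (the helper name and defaults are illustrative, not part of the API):

```python
import json

def build_run_request(pipeline_id, data_source, model_type,
                      hyperparams, gpus=4, memory="64Gi"):
    """Build the JSON body for POST /api/v1/pipelines/runs."""
    return {
        "pipeline_id": pipeline_id,
        "config": {
            "data_source": data_source,
            "model_type": model_type,
            "hyperparams": hyperparams,
            "resources": {"gpus": gpus, "memory": memory},
        },
    }

body = build_run_request("fraud_model_v2", "s3://data/fraud/2024/",
                         "xgboost", {"lr": 0.01, "depth": 8})
print(json.dumps(body, indent=2))
```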
Data Model
-- Pipeline definitions
CREATE TABLE pipelines (
pipeline_id VARCHAR PRIMARY KEY,
name TEXT,
dag_config JSONB, -- step definitions and dependencies
schedule VARCHAR, -- cron expression
created_by VARCHAR
);
-- Experiment tracking
CREATE TABLE experiments (
experiment_id VARCHAR PRIMARY KEY,
name TEXT,
description TEXT,
owner VARCHAR
);
CREATE TABLE runs (
run_id VARCHAR PRIMARY KEY,
experiment_id VARCHAR REFERENCES experiments,
status VARCHAR, -- running, completed, failed
params JSONB, -- hyperparameters
metrics JSONB, -- evaluation metrics
artifact_uri TEXT, -- S3 path to artifacts
started_at TIMESTAMP,
ended_at TIMESTAMP,
git_commit VARCHAR -- code version
);
-- Model registry
CREATE TABLE model_versions (
model_name VARCHAR,
version VARCHAR,
run_id VARCHAR REFERENCES runs,
stage VARCHAR, -- staging, production, archived
artifact_uri TEXT,
created_at TIMESTAMP,
PRIMARY KEY (model_name, version)
);
High-Level Architecture
The pipeline follows a DAG-based orchestration pattern where each step is an independent, containerized task:
Step 1: Data Ingestion
Pull data from sources (S3, databases, streaming). Validate schema against expected contract. Compute data quality metrics (completeness, freshness, distribution stats).
Step 2: Data Validation
Detect schema drift (new/removed columns, type changes). Check for data skew compared to training baseline. Alert on anomalies using Great Expectations or TensorFlow Data Validation.
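Distribution skew checks usually reduce to comparing binned histograms of a feature against the training baseline. A common score is the Population Stability Index (PSI); a framework-free sketch (bin counts and thresholds are illustrative):

```python
import math

def psi(baseline_counts, current_counts, eps=1e-6):
    """Population Stability Index between two binned distributions.
    Rule of thumb: < 0.1 stable, 0.1-0.25 moderate drift, > 0.25 alert."""
    b_total, c_total = sum(baseline_counts), sum(current_counts)
    score = 0.0
    for b, c in zip(baseline_counts, current_counts):
        p = b / b_total + eps  # baseline bin fraction
        q = c / c_total + eps  # current bin fraction
        score += (q - p) * math.log(q / p)
    return score

same = psi([100, 200, 300], [10, 20, 30])      # identical shape -> ~0
shifted = psi([100, 200, 300], [300, 200, 100])  # reversed shape -> large
assert same < 0.01 and shifted > 0.25
```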
Step 3: Feature Engineering
Apply transformations: encoding, normalization, aggregations. Version transformations so training and serving use identical logic. Materialize features to a feature store for online serving.
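One lightweight way to version a transformation is to derive a deterministic ID from its definition, so training and serving can assert they run identical logic. A sketch with a hypothetical helper (not a feature-store API):

```python
import hashlib
import json

def transform_version(transform_config):
    """Deterministic version ID for a feature transformation spec."""
    canonical = json.dumps(transform_config, sort_keys=True)
    return hashlib.sha256(canonical.encode()).hexdigest()[:12]

spec = {"column": "amount", "op": "log1p_then_zscore",
        "stats": {"mean": 42.0, "std": 7.5}}
v1 = transform_version(spec)

# Any change to the spec (even recomputed stats) yields a new version,
# forcing features to be re-materialized rather than silently reused.
spec2 = {**spec, "stats": {"mean": 43.1, "std": 7.5}}
assert v1 != transform_version(spec2)
```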
Step 4: Model Training
Distributed training across GPU cluster. Checkpoint every N steps to object storage. Log metrics to experiment tracker in real-time.
Step 5: Evaluation & Validation
Run evaluation on holdout set. Compare against production baseline. Check for bias across slices (demographics, regions). Gate promotion on quality thresholds.
Step 6: Model Registry
Register validated model with metadata. Promote through stages: staging → canary → production. Trigger downstream serving pipeline.
Deep Dive: Core Components
Distributed Training: Data vs. Model Parallelism
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup_distributed(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size, dataset):
    setup_distributed(rank, world_size)
    model = MyModel().to(rank)  # MyModel defined elsewhere
    # Wrap model so each backward pass syncs gradients across processes
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, sampler=sampler, batch_size=64)
    for epoch in range(100):
        sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for batch in loader:
            optimizer.zero_grad()
            loss = model(batch)  # assume forward returns the loss
            loss.backward()      # gradients synced via AllReduce
            optimizer.step()
    dist.destroy_process_group()
Data Parallelism vs. Model Parallelism
Data parallelism: Each GPU gets a copy of the model but different data batches. Gradients are synchronized via AllReduce. Best for models that fit in one GPU.
Model parallelism: Model is split across GPUs (pipeline or tensor parallelism). Required for models too large for single GPU (LLMs with billions of parameters).
FSDP (Fully Sharded Data Parallel): Hybrid approach that shards model parameters, gradients, and optimizer states across GPUs. Best of both worlds.
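The AllReduce step at the heart of data parallelism is just an element-wise average of per-worker gradients. A framework-free simulation of that synchronization (toy numbers, not a real NCCL collective):

```python
def allreduce_mean(per_worker_grads):
    """Average gradients element-wise across workers, as AllReduce does."""
    world_size = len(per_worker_grads)
    return [sum(g) / world_size for g in zip(*per_worker_grads)]

# Two workers computed gradients on different data shards
grads = [[1.0, 2.0, 3.0],
         [3.0, 4.0, 5.0]]
synced = allreduce_mean(grads)
assert synced == [2.0, 3.0, 4.0]
# Every worker applies the same averaged gradient,
# which keeps all model replicas bit-identical after each step.
```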
Experiment Tracking
Every training run must log: hyperparameters, metrics over time, code version (git SHA), data version, environment (GPU type, framework version), and artifacts (model weights, plots).
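All of that metadata fits in a small record per run. A sketch of what gets captured at run start (field names are illustrative, not the schema of MLflow or any particular tracker):

```python
import time
import uuid

def make_run_record(params, git_commit, data_version, env):
    """Capture everything needed to reproduce this run later."""
    return {
        "run_id": uuid.uuid4().hex,
        "started_at": time.time(),
        "params": params,            # hyperparameters
        "git_commit": git_commit,    # code version
        "data_version": data_version,
        "environment": env,          # GPU type, framework version
        "metrics": {},               # appended over training
        "artifacts": [],             # model weights, plots
    }

run = make_run_record({"lr": 0.01, "depth": 8}, "abc123",
                      "fraud/2024-06-01", {"gpu": "A100", "torch": "2.3"})
run["metrics"]["auc"] = 0.97
```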
Hyperparameter Tuning
| Method | Pros | Cons |
|---|---|---|
| Grid Search | Exhaustive, reproducible | Exponential cost with dimensions |
| Random Search | Better coverage, parallelizable | No learning between trials |
| Bayesian (TPE) | Sample-efficient, learns from history | Sequential, harder to parallelize |
| Successive Halving | Early stopping saves compute | Assumes early performance predicts final |
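Successive halving from the table can be sketched in a few lines: evaluate all configs on a small budget, keep the best fraction, multiply the budget, repeat. Here `score` is a stand-in for partial training, and the toy objective is purely illustrative:

```python
def successive_halving(configs, score, min_budget=1, eta=2):
    """Keep the best 1/eta of configs each rung, multiplying the budget."""
    budget = min_budget
    survivors = list(configs)
    while len(survivors) > 1:
        ranked = sorted(survivors, key=lambda c: score(c, budget),
                        reverse=True)
        survivors = ranked[:max(1, len(ranked) // eta)]
        budget *= eta
    return survivors[0]

# Toy objective: configs with lr closest to 0.1 score best at every budget
def score(cfg, budget):
    return -abs(cfg["lr"] - 0.1) * budget

best = successive_halving([{"lr": l} for l in (0.001, 0.01, 0.1, 1.0)], score)
assert best == {"lr": 0.1}
```

Note the table's caveat applies directly: this only works when early (low-budget) scores predict final performance.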
Pipeline Orchestration
Use DAG-based orchestrators (Airflow, Kubeflow, Prefect) to define step dependencies, handle retries, and manage scheduling. Each step runs in an isolated container with pinned dependencies.
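A minimal sketch of what the orchestrator does, assuming steps are plain callables (real orchestrators add scheduling, containers, and persisted state, which this deliberately omits):

```python
def run_dag(steps, deps, max_retries=2):
    """Run steps in dependency order; retry each up to max_retries times.
    steps: name -> callable; deps: name -> list of prerequisite names."""
    done, order = set(), []
    def visit(name):
        if name in done:
            return
        for d in deps.get(name, []):   # run prerequisites first
            visit(d)
        for attempt in range(max_retries + 1):
            try:
                steps[name]()
                break
            except Exception:
                if attempt == max_retries:
                    raise              # exhausted retries: fail the run
        done.add(name)
        order.append(name)
    for name in steps:
        visit(name)
    return order

log = []
steps = {"train": lambda: log.append("train"),
         "ingest": lambda: log.append("ingest"),
         "validate": lambda: log.append("validate")}
deps = {"validate": ["ingest"], "train": ["validate"]}
assert run_dag(steps, deps) == ["ingest", "validate", "train"]
```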
Scaling & Optimization
GPU Cluster Management
- Spot/preemptible instances: 60-80% cost savings. Requires robust checkpointing
- Mixed precision training: FP16/BF16 for forward/backward compute with FP32 master weights (plus loss scaling for FP16). Up to 2x throughput
- Gradient accumulation: Simulate larger batch sizes without more GPU memory
- Data loading optimization: Prefetch, parallel workers, memory-mapped files
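Gradient accumulation from the list above can be shown framework-free: sum micro-batch gradients and only apply an update every `accum_steps` batches, which is arithmetically one large-batch step (toy scalar "model" for illustration):

```python
def train_with_accumulation(micro_batch_grads, accum_steps, lr=0.1):
    """Apply an update only every accum_steps micro-batches, averaging
    their gradients -- equivalent to one large-batch step."""
    weight, accum = 1.0, 0.0
    for i, g in enumerate(micro_batch_grads, start=1):
        accum += g
        if i % accum_steps == 0:
            weight -= lr * (accum / accum_steps)  # one "big batch" update
            accum = 0.0
    return weight

# Four micro-batches with accum_steps=4 act like one batch of mean grad 2.5
w = train_with_accumulation([1.0, 2.0, 3.0, 4.0], accum_steps=4)
assert abs(w - (1.0 - 0.1 * 2.5)) < 1e-9
```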
Checkpointing Strategy
Critical: Checkpoint Frequently
With spot instances, nodes can be preempted at any time. Checkpoint every 15-30 minutes to object storage (S3/GCS). Keep last 3 checkpoints and delete older ones. Resume training from latest checkpoint on new nodes.
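The keep-last-3 rotation can be sketched with local paths (in practice the writes go to S3/GCS; file names and the directory layout are illustrative):

```python
import os
import tempfile

def save_checkpoint(ckpt_dir, step, blob, keep_last=3):
    """Write a checkpoint and prune all but the newest keep_last files."""
    path = os.path.join(ckpt_dir, f"ckpt_{step:08d}.bin")
    with open(path, "wb") as f:
        f.write(blob)
    # Zero-padded step numbers make lexical sort == chronological sort
    ckpts = sorted(p for p in os.listdir(ckpt_dir) if p.startswith("ckpt_"))
    for old in ckpts[:-keep_last]:
        os.remove(os.path.join(ckpt_dir, old))
    return sorted(p for p in os.listdir(ckpt_dir) if p.startswith("ckpt_"))

with tempfile.TemporaryDirectory() as d:
    for step in (100, 200, 300, 400):
        remaining = save_checkpoint(d, step, b"weights")
assert remaining == ["ckpt_00000200.bin", "ckpt_00000300.bin",
                     "ckpt_00000400.bin"]
```

On resume, a new node lists the checkpoint prefix and loads the lexically greatest file.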
Practice Problems
Practice 1: Spot Instance Recovery
Your 32-GPU training job loses 8 GPUs due to spot preemption mid-epoch. Design a recovery strategy that minimizes wasted compute and resumes training within 5 minutes.
Practice 2: Data Versioning
Your training data changes daily as new events stream in. How do you ensure that any past experiment can be exactly reproduced? Design a data versioning system.
Practice 3: Multi-Tenant Cluster
Three ML teams share a 128-GPU cluster. Design a fair scheduling system that handles priority jobs, prevents starvation, and maximizes GPU utilization.
Quick Reference
| Component | Technology | Purpose |
|---|---|---|
| Orchestration | Kubeflow / Airflow | DAG-based pipeline management |
| Distributed Training | PyTorch DDP / DeepSpeed | Multi-GPU/node training |
| Experiment Tracking | MLflow / W&B | Log params, metrics, artifacts |
| HPO | Optuna / Ray Tune | Hyperparameter optimization |
| Data Validation | Great Expectations / TFDV | Schema & distribution checks |
| Model Registry | MLflow / Vertex AI | Version and stage models |
| Compute | Kubernetes + GPU operator | Cluster management |
Key Takeaways
- DAG orchestration ensures reproducible, fault-tolerant pipelines
- Use distributed data parallel for most workloads; model parallelism for LLMs
- Checkpoint aggressively when using spot instances
- Track everything: code version, data version, hyperparams, environment
- Gate model promotion on automated evaluation against baselines