Design an ML Training Pipeline

Hard · 35 min read

Problem Statement & Requirements

Why ML Training Pipelines Matter

At companies like Google, Meta, and OpenAI, models are retrained hundreds of times per day. A well-designed training pipeline is the backbone of any ML system — it determines how quickly you can iterate on models, reproduce experiments, and ship improvements to production.

Think of an ML training pipeline like a factory assembly line. Raw materials (data) enter one end, pass through quality control (validation), get processed (feature engineering), assembled (training), tested (evaluation), and finally shipped (deployed to production). Each step must be reliable, repeatable, and observable.

Functional Requirements

  • Orchestrate multi-step pipelines: ingestion, validation, feature engineering, training, evaluation, registration
  • Track every experiment run: hyperparameters, metrics, code version, data version, artifacts
  • Register, version, and stage trained models (staging → canary → production)
  • Support both scheduled (cron) and on-demand pipeline runs

Non-Functional Requirements

  • Reproducibility: any past run can be recreated from its logged code, data, and config versions
  • Fault tolerance: recover from node failures and spot preemptions via checkpointing
  • Scalability: 50-200 training runs per day on a shared GPU cluster
  • Observability: real-time metrics, logs, and alerts for every pipeline step

Back-of-Envelope Estimation

Parameter                          Estimate
Training data size                 10 TB (tabular) / 500 TB (vision/NLP)
Training runs per day              50-200 across teams
GPU cluster size                   256 A100 GPUs (32 nodes × 8 GPUs)
Single training run (large model)  4-72 hours
Checkpoint size                    1-50 GB per checkpoint
Experiment metadata                ~10 KB per run × 200 runs/day = 2 MB/day
Artifact storage                   ~500 GB/day (models, logs, plots)
GPU cost                           $2-3/hour per A100 ≈ $12K-18K/day for the cluster
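
A quick sanity check of the cost row, assuming a midpoint rate of $2.50 per A100-hour (actual rates vary by provider and commitment):

```python
# Back-of-envelope check for the cluster's daily GPU cost.
GPUS = 256
COST_PER_GPU_HOUR = 2.50  # assumed midpoint of the $2-3/hour range
HOURS_PER_DAY = 24

daily_cost = GPUS * COST_PER_GPU_HOUR * HOURS_PER_DAY
print(f"${daily_cost:,.0f}/day")  # $15,360/day at the midpoint rate
```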

System API Design

Pipeline & Experiment APIs
# Submit a training pipeline run
POST /api/v1/pipelines/runs
{
  "pipeline_id": "fraud_model_v2",
  "config": {
    "data_source": "s3://data/fraud/2024/",
    "model_type": "xgboost",
    "hyperparams": { "lr": 0.01, "depth": 8 },
    "resources": { "gpus": 4, "memory": "64Gi" }
  }
}

# Get experiment run details
GET /api/v1/experiments/{experiment_id}/runs/{run_id}

# Register a trained model
POST /api/v1/models/register
{
  "name": "fraud_detector",
  "version": "3.1",
  "artifact_uri": "s3://models/fraud/run_123/",
  "metrics": { "auc": 0.97, "f1": 0.89 },
  "stage": "staging"
}

# Promote model to production
PUT /api/v1/models/fraud_detector/versions/3.1/stage
{ "stage": "production" }
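
The submission endpoint above can be exercised with a small client. This is a sketch: the host name is hypothetical and `build_run_request` is a helper invented for illustration.

```python
import json
import urllib.request

BASE_URL = "https://ml-platform.example.com"  # hypothetical host

def build_run_request(pipeline_id, config):
    """Build the POST request that submits a pipeline run."""
    payload = {"pipeline_id": pipeline_id, "config": config}
    return urllib.request.Request(
        f"{BASE_URL}/api/v1/pipelines/runs",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

req = build_run_request("fraud_model_v2", {
    "data_source": "s3://data/fraud/2024/",
    "model_type": "xgboost",
    "hyperparams": {"lr": 0.01, "depth": 8},
    "resources": {"gpus": 4, "memory": "64Gi"},
})
# urllib.request.urlopen(req) would send it against a live deployment
```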

Data Model

Core Schema
-- Pipeline definitions
CREATE TABLE pipelines (
    pipeline_id   VARCHAR PRIMARY KEY,
    name          TEXT,
    dag_config    JSONB,    -- step definitions and dependencies
    schedule      VARCHAR,  -- cron expression
    created_by    VARCHAR
);

-- Experiment tracking
CREATE TABLE experiments (
    experiment_id VARCHAR PRIMARY KEY,
    name          TEXT,
    description   TEXT,
    owner         VARCHAR
);

CREATE TABLE runs (
    run_id        VARCHAR PRIMARY KEY,
    experiment_id VARCHAR REFERENCES experiments,
    status        VARCHAR,  -- running, completed, failed
    params        JSONB,    -- hyperparameters
    metrics       JSONB,    -- evaluation metrics
    artifact_uri  TEXT,     -- S3 path to artifacts
    started_at    TIMESTAMP,
    ended_at      TIMESTAMP,
    git_commit    VARCHAR   -- code version
);

-- Model registry
CREATE TABLE model_versions (
    model_name    VARCHAR,
    version       VARCHAR,
    run_id        VARCHAR REFERENCES runs,
    stage         VARCHAR,  -- staging, production, archived
    artifact_uri  TEXT,
    created_at    TIMESTAMP,
    PRIMARY KEY (model_name, version)
);
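
To make the schema concrete, here is the kind of query a registry or dashboard would run against the `runs` table, sketched with SQLite (JSONB approximated as JSON text), so details differ from a production Postgres deployment:

```python
import sqlite3

# In-memory stand-in for the runs table above.
conn = sqlite3.connect(":memory:")
conn.execute("""CREATE TABLE runs (
    run_id TEXT PRIMARY KEY, experiment_id TEXT, status TEXT,
    params TEXT, metrics TEXT, artifact_uri TEXT)""")
conn.executemany("INSERT INTO runs VALUES (?, ?, ?, ?, ?, ?)", [
    ("run_1", "exp_a", "completed", '{"lr": 0.01}', '{"auc": 0.95}', "s3://models/run_1/"),
    ("run_2", "exp_a", "completed", '{"lr": 0.05}', '{"auc": 0.97}', "s3://models/run_2/"),
    ("run_3", "exp_a", "failed",    '{"lr": 0.10}', '{}',            None),
])

# Typical registry query: best completed run by AUC.
best = conn.execute("""
    SELECT run_id, json_extract(metrics, '$.auc') AS auc
    FROM runs WHERE status = 'completed'
    ORDER BY auc DESC LIMIT 1""").fetchone()
print(best)  # ('run_2', 0.97)
```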

High-Level Architecture

The pipeline follows a DAG-based orchestration pattern where each step is an independent, containerized task:

Step 1: Data Ingestion

Pull data from sources (S3, databases, streaming). Validate schema against expected contract. Compute data quality metrics (completeness, freshness, distribution stats).

Step 2: Data Validation

Detect schema drift (new/removed columns, type changes). Check for data skew compared to training baseline. Alert on anomalies using Great Expectations or TensorFlow Data Validation.
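
A minimal sketch of the schema-drift check described above (real deployments would use Great Expectations or TFDV; `detect_schema_drift` is an illustrative helper):

```python
def detect_schema_drift(expected, actual):
    """Compare an actual column->type mapping against the expected
    contract. Returns a dict of findings; empty means no drift."""
    drift = {}
    added = set(actual) - set(expected)
    removed = set(expected) - set(actual)
    changed = {c for c in set(expected) & set(actual)
               if expected[c] != actual[c]}
    if added:
        drift["added_columns"] = sorted(added)
    if removed:
        drift["removed_columns"] = sorted(removed)
    if changed:
        drift["type_changes"] = sorted(changed)
    return drift

contract = {"user_id": "int", "amount": "float", "country": "str"}
batch    = {"user_id": "int", "amount": "str", "device": "str"}
print(detect_schema_drift(contract, batch))
```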

Step 3: Feature Engineering

Apply transformations: encoding, normalization, aggregations. Version transformations so training and serving use identical logic. Materialize features to a feature store for online serving.

Step 4: Model Training

Distributed training across GPU cluster. Checkpoint every N steps to object storage. Log metrics to experiment tracker in real-time.

Step 5: Evaluation & Validation

Run evaluation on holdout set. Compare against production baseline. Check for bias across slices (demographics, regions). Gate promotion on quality thresholds.
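
The promotion gate can be sketched as a pure function over metrics; `passes_promotion_gate` and its thresholds are illustrative, not a specific library's API:

```python
def passes_promotion_gate(candidate, baseline, min_lift=0.0, thresholds=None):
    """Gate promotion: the candidate must clear absolute metric floors
    and must not regress against the production baseline."""
    thresholds = thresholds or {}
    for metric, floor in thresholds.items():
        if candidate.get(metric, 0.0) < floor:
            return False  # fails an absolute quality threshold
    for metric, base_value in baseline.items():
        if candidate.get(metric, 0.0) < base_value + min_lift:
            return False  # regresses vs. production
    return True

baseline  = {"auc": 0.95, "f1": 0.85}
candidate = {"auc": 0.97, "f1": 0.89}
print(passes_promotion_gate(candidate, baseline, thresholds={"auc": 0.90}))  # True
```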

Step 6: Model Registry

Register validated model with metadata. Promote through stages: staging → canary → production. Trigger downstream serving pipeline.
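
The stage progression can be enforced as a small state machine; the transition table below is an assumption based on the stages named in this article (staging, canary, production, archived):

```python
# Allowed stage transitions for the registry's promotion flow (assumed).
ALLOWED = {
    "none": {"staging"},
    "staging": {"canary", "archived"},
    "canary": {"production", "archived"},
    "production": {"archived"},
}

def promote(current, target):
    """Validate and apply a stage transition."""
    if target not in ALLOWED.get(current, set()):
        raise ValueError(f"illegal transition {current} -> {target}")
    return target

stage = promote("staging", "canary")
stage = promote(stage, "production")
print(stage)  # production
```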

Deep Dive: Core Components

Distributed Training: Data vs. Model Parallelism

PyTorch Distributed Data Parallel
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def setup_distributed(rank, world_size):
    dist.init_process_group("nccl", rank=rank,
                            world_size=world_size)
    torch.cuda.set_device(rank)

def train(rank, world_size, dataset):
    setup_distributed(rank, world_size)
    model = MyModel().to(rank)  # MyModel: your nn.Module that returns a loss
    # Wrap model for distributed training
    model = DistributedDataParallel(model, device_ids=[rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # Sampler gives each rank a disjoint shard of the dataset
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = DataLoader(dataset, sampler=sampler,
                        batch_size=64)

    for epoch in range(100):
        sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for batch in loader:
            optimizer.zero_grad()
            loss = model(batch)  # forward pass computes the loss
            loss.backward()      # Gradients synced via AllReduce
            optimizer.step()

    dist.destroy_process_group()

Data Parallelism vs. Model Parallelism

Data parallelism: Each GPU gets a copy of the model but different data batches. Gradients are synchronized via AllReduce. Best for models that fit in one GPU.
Model parallelism: Model is split across GPUs (pipeline or tensor parallelism). Required for models too large for single GPU (LLMs with billions of parameters).
FSDP (Fully Sharded Data Parallel): Hybrid approach that shards model parameters, gradients, and optimizer states across GPUs. Best of both worlds.
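
To make the data-parallel gradient sync concrete, here is a toy simulation of the AllReduce-mean step that DDP performs after `backward()` (plain Python lists, not NCCL):

```python
def all_reduce_mean(per_worker_grads):
    """Each worker contributes its local gradient vector; all workers
    receive the element-wise average (the AllReduce-mean semantics)."""
    n = len(per_worker_grads)
    dim = len(per_worker_grads[0])
    avg = [sum(g[i] for g in per_worker_grads) / n for i in range(dim)]
    return [avg] * n  # every worker ends up with identical gradients

# Two workers, each with gradients from a different data shard.
grads = all_reduce_mean([[1.0, -2.0], [3.0, 0.0]])
print(grads[0])  # [2.0, -1.0], identical on every worker
```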

Experiment Tracking

Every training run must log: hyperparameters, metrics over time, code version (git SHA), data version, environment (GPU type, framework version), and artifacts (model weights, plots).
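
A minimal in-memory sketch of such a run record (real systems like MLflow or W&B persist this; `RunTracker` is invented for illustration):

```python
class RunTracker:
    """Captures everything the text above says a run must log:
    params, code/data versions, environment, metrics, artifacts."""
    def __init__(self, run_id, params, git_commit, data_version, env):
        self.record = {
            "run_id": run_id, "params": params,
            "git_commit": git_commit, "data_version": data_version,
            "env": env, "metrics": [], "artifacts": [],
        }
    def log_metric(self, name, value, step):
        self.record["metrics"].append(
            {"name": name, "value": value, "step": step})
    def log_artifact(self, uri):
        self.record["artifacts"].append(uri)

run = RunTracker("run_123", {"lr": 0.01}, git_commit="a1b2c3d",
                 data_version="s3://data/fraud/2024/",
                 env={"gpu": "A100", "torch": "2.3"})
run.log_metric("loss", 0.42, step=100)
run.log_artifact("s3://models/fraud/run_123/model.pt")
```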

Hyperparameter Tuning

Method              Pros                                   Cons
Grid Search         Exhaustive, reproducible               Exponential cost with dimensions
Random Search       Better coverage, parallelizable        No learning between trials
Bayesian (TPE)      Sample-efficient, learns from history  Sequential, harder to parallelize
Successive Halving  Early stopping saves compute           Assumes early performance predicts final
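
Successive halving from the table above can be sketched in a few lines; the objective function here is a toy stand-in for actual training:

```python
def successive_halving(configs, evaluate, budget=1, eta=2):
    """Train all configs on a small budget, keep the best 1/eta,
    multiply their budget by eta, and repeat until one survives."""
    while len(configs) > 1:
        scored = sorted(configs, key=lambda c: evaluate(c, budget),
                        reverse=True)
        configs = scored[: max(1, len(configs) // eta)]
        budget *= eta
    return configs[0]

# Toy objective: score grows with budget and peaks at lr=0.1.
def evaluate(cfg, budget):
    return budget * (1 - abs(cfg["lr"] - 0.1))

candidates = [{"lr": lr} for lr in (0.001, 0.01, 0.1, 0.5)]
print(successive_halving(candidates, evaluate))  # {'lr': 0.1}
```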

Pipeline Orchestration

Use DAG-based orchestrators (Airflow, Kubeflow, Prefect) to define step dependencies, handle retries, and manage scheduling. Each step runs in an isolated container with pinned dependencies.
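
A toy topological executor shows the core of what those orchestrators do (dependency resolution, scheduling, retries); `run_dag` is illustrative, not any framework's API:

```python
def run_dag(steps, deps, run_step, max_retries=2):
    """Execute steps in topological order with simple retry logic.

    steps: list of step names; deps: step -> set of upstream steps;
    run_step: callable executing one step (raises on failure).
    """
    done, order = set(), []
    while len(done) < len(steps):
        ready = [s for s in steps
                 if s not in done and deps.get(s, set()) <= done]
        if not ready:
            raise RuntimeError("cycle or unsatisfiable dependency")
        for step in ready:
            for attempt in range(max_retries + 1):
                try:
                    run_step(step)
                    break
                except Exception:
                    if attempt == max_retries:
                        raise  # exhausted retries: fail the pipeline
            done.add(step)
            order.append(step)
    return order

deps = {"validate": {"ingest"}, "features": {"validate"},
        "train": {"features"}, "evaluate": {"train"}}
steps = ["ingest", "validate", "features", "train", "evaluate"]
print(run_dag(steps, deps, run_step=lambda s: None))
```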

Scaling & Optimization

GPU Cluster Management

Run training jobs on Kubernetes with the GPU operator. Mix on-demand capacity for deadline-critical jobs with cheaper spot instances for fault-tolerant workloads, relying on checkpointing to absorb preemptions.

Checkpointing Strategy

Critical: Checkpoint Frequently

With spot instances, nodes can be preempted at any time. Checkpoint every 15-30 minutes to object storage (S3/GCS). Keep last 3 checkpoints and delete older ones. Resume training from latest checkpoint on new nodes.
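
The keep-last-3 rotation can be sketched as follows, using the local filesystem as a stand-in for S3/GCS:

```python
import os
import tempfile

def save_checkpoint(state_bytes, step, ckpt_dir, keep=3):
    """Write a checkpoint and prune old ones, keeping the newest `keep`.
    In production the directory would be an S3/GCS prefix."""
    path = os.path.join(ckpt_dir, f"ckpt_{step:08d}.bin")
    with open(path, "wb") as f:
        f.write(state_bytes)
    # Zero-padded step numbers make lexicographic order == step order.
    ckpts = sorted(p for p in os.listdir(ckpt_dir)
                   if p.startswith("ckpt_"))
    for old in ckpts[:-keep]:
        os.remove(os.path.join(ckpt_dir, old))
    return path

with tempfile.TemporaryDirectory() as d:
    for step in (100, 200, 300, 400):
        save_checkpoint(b"fake-weights", step, d)
    print(sorted(os.listdir(d)))  # only steps 200, 300, 400 remain
```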

Practice Problems

Practice 1: Spot Instance Recovery

Your 32-GPU training job loses 8 GPUs due to spot preemption mid-epoch. Design a recovery strategy that minimizes wasted compute and resumes training within 5 minutes.

Practice 2: Data Versioning

Your training data changes daily as new events stream in. How do you ensure that any past experiment can be exactly reproduced? Design a data versioning system.

Practice 3: Multi-Tenant Cluster

Three ML teams share a 128-GPU cluster. Design a fair scheduling system that handles priority jobs, prevents starvation, and maximizes GPU utilization.

Quick Reference

Component             Technology                  Purpose
Orchestration         Kubeflow / Airflow          DAG-based pipeline management
Distributed Training  PyTorch DDP / DeepSpeed     Multi-GPU/node training
Experiment Tracking   MLflow / W&B                Log params, metrics, artifacts
HPO                   Optuna / Ray Tune           Hyperparameter optimization
Data Validation       Great Expectations / TFDV   Schema & distribution checks
Model Registry        MLflow / Vertex AI          Version and stage models
Compute               Kubernetes + GPU operator   Cluster management

Key Takeaways

  • DAG orchestration ensures reproducible, fault-tolerant pipelines
  • Use distributed data parallel for most workloads; model parallelism for LLMs
  • Checkpoint aggressively when using spot instances
  • Track everything: code version, data version, hyperparams, environment
  • Gate model promotion on automated evaluation against baselines