DataFrames & Transformations

Medium · 30 min read

Creating DataFrames

Why This Matters

The Problem: Raw data comes in many formats -- CSV, JSON, Parquet, databases -- and needs to be loaded, cleaned, and transformed before it can drive business decisions.

The Solution: PySpark DataFrames provide a distributed, schema-aware abstraction for working with structured data at scale. They look like pandas DataFrames but can process terabytes of data across hundreds of machines.

Real Impact: DataFrames are the foundation of every Databricks pipeline. Understanding transformations is essential for building performant ETL jobs that process billions of rows in minutes.

Real-World Analogy

Think of a DataFrame as a massive spreadsheet that is split across many machines:

  • Rows = Individual records, distributed across cluster nodes (partitions)
  • Columns = Named fields with specific data types (schema)
  • Transformations = Formulas applied to every row, executed in parallel
  • Actions = The "Enter" key that triggers actual computation

Creating DataFrames from Different Sources

Python - Creating DataFrames
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, TimestampType

# On Databricks `spark` is predefined; elsewhere, create a session first
spark = SparkSession.builder.getOrCreate()

# 1. From Python lists
data = [
    (1, "Alice", "Engineering", 95000.0),
    (2, "Bob", "Data Science", 105000.0),
    (3, "Carol", "Engineering", 98000.0),
    (4, "Dave", "Analytics", 87000.0),
    (5, "Eve", "Data Science", 112000.0),
]
columns = ["id", "name", "department", "salary"]
df = spark.createDataFrame(data, columns)

# 2. With explicit schema (recommended for production)
# Note: with DoubleType, the salary values must be Python floats; the third
# StructField argument is nullable (False = column may not contain nulls)
schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), False),
    StructField("department", StringType(), True),
    StructField("salary", DoubleType(), True),
])
df_typed = spark.createDataFrame(data, schema)

# 3. From CSV
df_csv = spark.read.format("csv") \
    .option("header", True) \
    .option("inferSchema", True) \
    .load("/mnt/data/employees.csv")

# 4. From JSON
df_json = spark.read.json("/mnt/data/events/*.json")

# 5. From Parquet (schema embedded)
df_parquet = spark.read.parquet("/mnt/data/transactions/")

# 6. From Delta table
df_delta = spark.read.format("delta").load("/mnt/data/bronze/events")

# 7. From SQL table
df_table = spark.table("main.default.customers")

# Inspect the DataFrame
df.printSchema()
df.show(5)
print(f"Rows: {df.count()}, Columns: {len(df.columns)}")

Narrow vs Wide Transformations

Understanding the difference between narrow and wide transformations is critical for writing performant Spark code. This determines when data must be shuffled across the cluster network.

Transformation Pipeline: Narrow vs Wide

Narrow transformations -- no shuffle; each partition is processed independently. For example, select()/filter()/map() turn Partition 1 [A, B, C], Partition 2 [D, E, F], Partition 3 [G, H, I] into Partition 1 [A', B'], Partition 2 [D', E'], Partition 3 [G'] in place.

  • Examples (no data movement): select(), filter(), where(), withColumn(), drop(), map(), flatMap(), union(), coalesce(), cast(), alias()
  • Performance: fast (no network I/O)

Wide transformations -- require a shuffle; rows move across the network so records with the same key land in the same partition. For example, groupBy() regroups Partition 1 [A:3, B:1], Partition 2 [A:2, C:4], Partition 3 [B:5, C:1] into Key A [A:3, A:2], Key B [B:1, B:5], Key C [C:4, C:1].

  • Examples (data moves across the network): groupBy(), agg(), join(), crossJoin(), orderBy(), sort(), distinct(), dropDuplicates(), repartition(), pivot()
  • Performance: expensive (network I/O)

Core Transformations

select, filter, withColumn

Python - Basic Transformations
from pyspark.sql import functions as F

# select -- choose specific columns
df_names = df.select("name", "department")
df_computed = df.select(
    F.col("name"),
    F.col("salary"),
    (F.col("salary") * 0.3).alias("tax"),
    F.upper(F.col("name")).alias("name_upper")
)

# filter / where -- keep rows that match condition
df_eng = df.filter(F.col("department") == "Engineering")
df_high_salary = df.where(F.col("salary") > 90000)
df_multi = df.filter(
    (F.col("department") == "Engineering") &
    (F.col("salary") > 90000)
)

# withColumn -- add or replace a column
df_bonus = df.withColumn("bonus", F.col("salary") * 0.1)
df_level = df.withColumn("level",
    F.when(F.col("salary") >= 100000, "Senior")
     .when(F.col("salary") >= 90000, "Mid")
     .otherwise("Junior")
)

# drop -- remove columns
df_slim = df.drop("id")

# distinct -- deduplicate
df_unique_depts = df.select("department").distinct()

groupBy and Aggregations

Python - GroupBy & Aggregations
# Basic groupBy with single aggregation
dept_avg = df.groupBy("department").avg("salary")

# Multiple aggregations
dept_stats = df.groupBy("department").agg(
    F.count("*").alias("employee_count"),
    F.avg("salary").alias("avg_salary"),
    F.min("salary").alias("min_salary"),
    F.max("salary").alias("max_salary"),
    F.sum("salary").alias("total_payroll"),
    F.stddev("salary").alias("salary_stddev")
)

# Group by multiple columns (df_level carries the derived "level" column from above)
multi_group = df_level.groupBy("department", "level").agg(
    F.count("*").alias("count"),
    F.round(F.avg("salary"), 2).alias("avg_salary")
)

# Pivot -- transform rows into columns
pivot_df = df_level.groupBy("department").pivot("level").avg("salary")

# Sorting
sorted_df = dept_stats.orderBy(F.col("avg_salary").desc())

Joins

Python - Join Operations
# Create department metadata table
dept_data = [
    ("Engineering", "Building A", "VP-Eng"),
    ("Data Science", "Building B", "VP-DS"),
    ("Analytics", "Building B", "VP-Analytics"),
    ("Marketing", "Building C", "VP-Mkt"),
]
dept_df = spark.createDataFrame(dept_data, ["dept_name", "location", "lead"])

# Inner join -- only matching rows from both sides
inner = df.join(dept_df, df.department == dept_df.dept_name, "inner")

# Left join -- all rows from left, matching from right
left = df.join(dept_df, df.department == dept_df.dept_name, "left")

# Right join -- all rows from right, matching from left
right = df.join(dept_df, df.department == dept_df.dept_name, "right")

# Full outer join -- all rows from both sides
full = df.join(dept_df, df.department == dept_df.dept_name, "outer")

# Anti join -- rows from left that have NO match in right
no_dept_info = df.join(dept_df, df.department == dept_df.dept_name, "anti")

# Semi join -- rows from left that HAVE a match in right (no right columns)
has_dept_info = df.join(dept_df, df.department == dept_df.dept_name, "semi")

# Join with multiple conditions (illustrative -- assumes existing orders/customers DataFrames)
complex_join = orders.join(
    customers,
    (orders.customer_id == customers.id) &
    (orders.region == customers.region),
    "inner"
)

# Broadcast join -- small table sent to all nodes (avoid shuffle)
optimized = df.join(F.broadcast(dept_df), df.department == dept_df.dept_name)

Join Performance Tip

Use F.broadcast(small_df) when one side of the join comfortably fits in memory on each executor. Broadcasting ships the small table to every node, avoiding the expensive shuffle of the large table, and can speed up joins by 10-100x. Spark also auto-broadcasts tables below spark.sql.autoBroadcastJoinThreshold (10MB by default); the explicit F.broadcast() hint lets you force broadcasting for larger tables you know will fit.

Window Functions

Python - Window Functions
from pyspark.sql.window import Window

# Define window specifications
dept_window = Window.partitionBy("department").orderBy(F.col("salary").desc())
dept_all = Window.partitionBy("department")

# Ranking functions
df_ranked = df.withColumn("rank", F.rank().over(dept_window)) \
              .withColumn("dense_rank", F.dense_rank().over(dept_window)) \
              .withColumn("row_number", F.row_number().over(dept_window))

# Aggregate within window (running totals, averages)
df_window_agg = df.withColumn("dept_avg_salary", F.avg("salary").over(dept_all)) \
                  .withColumn("dept_max_salary", F.max("salary").over(dept_all)) \
                  .withColumn("salary_pct_of_dept",
                      F.round(F.col("salary") / F.avg("salary").over(dept_all) * 100, 1))

# Lead / Lag -- access previous/next rows (illustrative -- assumes an `events` DataFrame)
time_window = Window.partitionBy("user_id").orderBy("event_time")
df_events = events.withColumn("prev_event", F.lag("event_type").over(time_window)) \
                  .withColumn("next_event", F.lead("event_type").over(time_window)) \
                  .withColumn("time_since_last",
                      F.col("event_time").cast("long") -
                      F.lag("event_time").over(time_window).cast("long"))

# Running sum with frame specification
running_window = Window.partitionBy("department") \
    .orderBy("hire_date") \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

df_running = df.withColumn("running_headcount", F.count("*").over(running_window)) \
               .withColumn("running_payroll", F.sum("salary").over(running_window))

# Top N per group using row_number
top_2_per_dept = df.withColumn(
    "rn", F.row_number().over(dept_window)
).filter(F.col("rn") <= 2).drop("rn")

UDFs & Schema Enforcement

User-Defined Functions

Python - UDFs
from pyspark.sql.functions import udf, pandas_udf
from pyspark.sql.types import StringType, DoubleType
import pandas as pd

# Standard UDF (row-at-a-time, Python overhead)
@udf(returnType=StringType())
def classify_salary(salary):
    if salary >= 120000: return "Executive"
    elif salary >= 100000: return "Senior"
    elif salary >= 80000: return "Mid"
    else: return "Junior"

df_classified = df.withColumn("level", classify_salary(F.col("salary")))

# Pandas UDF (vectorized, much faster -- uses Arrow)
# Caution: a scalar pandas UDF receives one Arrow batch at a time, so mean()
# and std() here are per-batch, not global; compute global stats with agg()
# or a window first if exact z-scores are needed
@pandas_udf(DoubleType())
def normalize_salary(salary: pd.Series) -> pd.Series:
    return (salary - salary.mean()) / salary.std()

df_normalized = df.withColumn("salary_zscore", normalize_salary(F.col("salary")))

# Register UDF for SQL use
spark.udf.register("classify_salary_sql", classify_salary)

Schema Enforcement

Python - Schema Enforcement Patterns
# Define expected schema
expected_schema = StructType([
    StructField("order_id", IntegerType(), False),
    StructField("customer_id", IntegerType(), False),
    StructField("product", StringType(), True),
    StructField("amount", DoubleType(), True),
    StructField("order_date", TimestampType(), True),
])

# Read with explicit schema (no schema inference overhead)
df_orders = spark.read.schema(expected_schema).json("/mnt/data/orders/")

# Validate schema at runtime
def validate_schema(df, expected):
    actual_fields = {f.name: f.dataType for f in df.schema.fields}
    expected_fields = {f.name: f.dataType for f in expected.fields}
    missing = set(expected_fields.keys()) - set(actual_fields.keys())
    type_mismatches = {
        k: (str(actual_fields[k]), str(expected_fields[k]))
        for k in actual_fields
        if k in expected_fields and actual_fields[k] != expected_fields[k]
    }
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    if type_mismatches:
        raise TypeError(f"Type mismatches: {type_mismatches}")

validate_schema(df_orders, expected_schema)

SQL Equivalent

SQL - DataFrame Operations in SQL
-- Register DataFrame as temp view
-- df.createOrReplaceTempView("employees")

-- Select and filter
SELECT name, department, salary,
       salary * 0.1 AS bonus
FROM employees
WHERE department = 'Engineering'
  AND salary > 90000;

-- Group by with aggregation
SELECT department,
       COUNT(*) AS headcount,
       AVG(salary) AS avg_salary,
       MAX(salary) AS max_salary
FROM employees
GROUP BY department
ORDER BY avg_salary DESC;

-- Window functions
SELECT name, department, salary,
       RANK() OVER (PARTITION BY department ORDER BY salary DESC) AS dept_rank,
       AVG(salary) OVER (PARTITION BY department) AS dept_avg,
       salary - AVG(salary) OVER (PARTITION BY department) AS diff_from_avg
FROM employees;

-- Top 2 per department
WITH ranked AS (
    SELECT *, ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary DESC) AS rn
    FROM employees
)
SELECT * FROM ranked WHERE rn <= 2;

Practice Problems

Problem 1: Sales Analytics Pipeline

Medium

Given an orders DataFrame with columns (order_id, customer_id, product, amount, order_date), write PySpark code to: (1) Filter orders from the last 30 days, (2) Calculate total revenue per product, (3) Rank products by revenue, (4) Show the top 5 products.

Problem 2: Customer Cohort Analysis

Medium

Write a PySpark transformation that: (1) Finds each customer's first order date, (2) Assigns each customer to a monthly cohort based on their first order, (3) Calculates the number of orders each customer made in each subsequent month.

Problem 3: Multi-Table Join Pipeline

Hard

You have three tables: orders, customers, and products. Write a PySpark pipeline that produces a report showing: each customer's name, their total spend, their most-purchased product, and their rank within their region by total spend. Use broadcast join for the products table (small).

Quick Reference

Operation            PySpark                            SQL Equivalent
Select columns       df.select("col1", "col2")          SELECT col1, col2
Filter rows          df.filter(F.col("x") > 10)         WHERE x > 10
Add column           df.withColumn("new", expr)         SELECT *, expr AS new
Group + aggregate    df.groupBy("g").agg(F.sum("v"))    GROUP BY g
Join                 df1.join(df2, "key")               JOIN df2 ON key
Sort                 df.orderBy(F.desc("col"))          ORDER BY col DESC
Window rank          F.rank().over(window)              RANK() OVER (...)
Deduplicate          df.dropDuplicates(["key"])         DISTINCT / ROW_NUMBER

Useful Resources