Simulate the core gradient-averaging step of Distributed Data Parallel (DDP) using PyTorch.
Signature: def simulate_data_parallel(model_weights, data_shards, lr)
Parameters:
- model_weights: list of floats (n_weights,), the shared initial weights
- data_shards: list of lists, each (shard_size, n_weights), the input data per worker
- lr: learning rate (float)

Algorithm:
1. For each worker, create a local copy of the weights with requires_grad=True.
2. Compute out = (x @ w_local).sum() (the sum of dot products over the batch).
3. Call .backward() to compute w_local.grad.
4. Average the per-worker gradients: avg_grad = mean([g0, g1, ..., gK]).
5. Apply one SGD step: w_new = w - lr * avg_grad.

Why? In DDP, each GPU processes a different shard of data but holds identical model parameters. After each backward pass, gradients are averaged (all-reduce) across GPUs before the optimizer step, ensuring all replicas stay in sync.
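A minimal sketch of one way to implement this, assuming the inputs are plain Python lists as described above (the float64 dtype is an assumption, chosen for numerical clarity):

```python
import torch

def simulate_data_parallel(model_weights, data_shards, lr):
    grads = []
    for shard in data_shards:
        # Each "worker" gets its own leaf tensor so gradients don't mix.
        w_local = torch.tensor(model_weights, dtype=torch.float64, requires_grad=True)
        x = torch.tensor(shard, dtype=torch.float64)  # (shard_size, n_weights)
        out = (x @ w_local).sum()                     # sum of dot products over the batch
        out.backward()                                # populates w_local.grad
        grads.append(w_local.grad)

    # Simulated all-reduce: average gradients across workers, then one SGD step.
    avg_grad = torch.stack(grads).mean(dim=0)
    w = torch.tensor(model_weights, dtype=torch.float64)
    w_new = w - lr * avg_grad
    return w_new.tolist()
```

A quick check with two workers: for out = (x @ w).sum(), the gradient with respect to w is x.sum(dim=0), so the shards below give per-worker gradients [1, 1] and [2, 2], an averaged gradient of [1.5, 1.5], and an updated weight of [0.85, -2.15] at lr=0.1.

```python
weights = [1.0, -2.0]
shards = [[[1.0, 0.0], [0.0, 1.0]],  # worker 0: two samples
          [[2.0, 2.0]]]              # worker 1: one sample
print(simulate_data_parallel(weights, shards, lr=0.1))  # [0.85, -2.15]
```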