
PyTorch: Simulated Data Parallel Gradient Averaging (Hard)

Simulate the core gradient-averaging step of Distributed Data Parallel (DDP) using PyTorch.

Signature: def simulate_data_parallel(model_weights, data_shards, lr)

  • model_weights: list of floats (n_weights,) — shared initial weights
  • data_shards: list of lists, each (shard_size, n_weights) — input data per worker
  • lr: learning rate (float)
  • Returns: updated weights after one DDP step as a list

Algorithm (a sketch follows this list):

  1. For each worker/shard: create a local copy of weights with requires_grad=True
  2. Forward pass: out = (x @ w_local).sum() (sum of dot products over the batch)
  3. Call .backward() to compute w_local.grad
  4. Average all workers' gradients: avg_grad = mean([g0, g1, ..., gK])
  5. SGD update: w_new = w - lr * avg_grad
  6. Return updated weights as a list
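The steps above map almost one-to-one onto PyTorch autograd calls. Here is a minimal sketch of one possible solution; the choice of float64 and the exact tensor conversions are assumptions, not requirements stated by the problem:

```python
import torch

def simulate_data_parallel(model_weights, data_shards, lr):
    """One simulated DDP step: per-worker backward, gradient averaging, SGD update."""
    grads = []
    for shard in data_shards:
        # Step 1: each "worker" gets its own copy of the shared weights.
        w_local = torch.tensor(model_weights, dtype=torch.float64, requires_grad=True)
        x = torch.tensor(shard, dtype=torch.float64)  # (shard_size, n_weights)
        # Step 2: forward pass, i.e. the sum of dot products over the shard.
        out = (x @ w_local).sum()
        # Step 3: backward pass populates w_local.grad.
        out.backward()
        grads.append(w_local.grad)
    # Step 4: simulated all-reduce, averaging gradients across workers.
    avg_grad = torch.stack(grads).mean(dim=0)
    # Step 5: SGD update on the shared weights.
    w = torch.tensor(model_weights, dtype=torch.float64)
    w_new = w - lr * avg_grad
    # Step 6: return as a plain Python list.
    return w_new.tolist()
```

For instance, starting from weights [1.0, 2.0] with two single-sample shards [[1.0, 0.0]] and [[0.0, 1.0]] and lr = 0.1, the per-worker gradients [1, 0] and [0, 1] average to [0.5, 0.5], giving [0.95, 1.95].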

Why? In DDP, each GPU processes a different shard of data but holds identical model parameters. After each backward pass, gradients are averaged (all-reduce) across GPUs before the optimizer step — ensuring all replicas stay in sync.
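To make the "stay in sync" claim concrete: because each shard's loss is a plain sum, averaging per-worker gradients equals taking the gradient of the mean of the per-shard losses, so every replica applies the identical update a single worker would compute on the pooled data. A quick numerical check (the shard values and initial weights below are made up for illustration):

```python
import torch

shards = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0]]]  # two workers, arbitrary data
w0 = [0.5, -0.25]

# DDP view: average the per-worker gradients of sum(x @ w).
per_worker = []
for s in shards:
    w = torch.tensor(w0, requires_grad=True)
    (torch.tensor(s) @ w).sum().backward()
    per_worker.append(w.grad)
avg_grad = torch.stack(per_worker).mean(dim=0)

# Single-worker view: gradient of the mean of the per-shard losses.
w = torch.tensor(w0, requires_grad=True)
loss = sum((torch.tensor(s) @ w).sum() for s in shards) / len(shards)
loss.backward()

print(torch.allclose(avg_grad, w.grad))  # True: both views give the same update
```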


Test Results

  • single worker, 2 samples
  • two workers with orthogonal gradients
  • three workers, single weight (Premium)
  • two workers same shard — identical to single-worker doubled (Premium)