TorchedUp

108. PyTorch: Simulated Data Parallel Gradient Averaging

Hard

Simulate the core gradient-averaging step of Distributed Data Parallel (DDP) using PyTorch.

Signature: def simulate_data_parallel(model_weights, data_shards, lr)

  • model_weights: list of floats (n_weights,) — shared initial weights
  • data_shards: list of lists, each (shard_size, n_weights) — input data per worker
  • lr: learning rate (float)
  • Returns: updated weights after one DDP step as a list
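A minimal sketch of the input contract described above. It assumes shards may hold different numbers of rows, but every row must have exactly n_weights entries so that it can be dotted with the weights. The helper name `validate_inputs` is hypothetical and is not part of the problem's API.

```python
from typing import List


def validate_inputs(model_weights: List[float],
                    data_shards: List[List[List[float]]]) -> int:
    """Hypothetical helper: check shapes and return the number of workers K."""
    n_weights = len(model_weights)
    if n_weights == 0:
        raise ValueError("model_weights must be non-empty")
    if not data_shards:
        raise ValueError("need at least one data shard (one worker)")
    for k, shard in enumerate(data_shards):
        for row in shard:
            # Each row is one input vector x that gets dotted with the weights.
            if len(row) != n_weights:
                raise ValueError(
                    f"shard {k}: row has {len(row)} features, expected {n_weights}"
                )
    return len(data_shards)  # K, the number of simulated workers


# Two workers, each holding a (2, 2) shard for a 2-weight model.
K = validate_inputs([1.0, 2.0], [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [2.0, 0.0]]])
print(K)  # 2
```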

Algorithm:

  1. For each worker/shard: create a local copy of weights with requires_grad=True
  2. Forward pass: out = (x @ w_local).sum(), where x is the worker's shard as a (shard_size, n_weights) tensor, so out is the sum of the per-row dot products over the batch
  3. Call .backward() to compute w_local.grad
  4. Average all workers' gradients: avg_grad = mean([g1, g2, ..., gK]), where gk is w_local.grad from worker k of K
  5. SGD update: w_new = w - lr * avg_grad
  6. Return updated weights as a list
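The steps above can be sketched in PyTorch as follows. This is a minimal single-process sketch, not the site's reference solution: it assumes float64 tensors and non-empty shards, and the per-worker loop stands in for separate GPUs.

```python
import torch


def simulate_data_parallel(model_weights, data_shards, lr):
    """Sketch of one simulated DDP step, following steps 1-6."""
    w = torch.tensor(model_weights, dtype=torch.float64)
    grads = []
    for shard in data_shards:
        x = torch.tensor(shard, dtype=torch.float64)  # (shard_size, n_weights)
        # Step 1: each worker gets its own leaf copy of the shared weights.
        w_local = w.clone().requires_grad_(True)
        # Step 2: forward pass, the sum of per-row dot products.
        out = (x @ w_local).sum()
        # Step 3: backward fills w_local.grad for this worker only.
        out.backward()
        grads.append(w_local.grad)
    # Step 4: average gradients across workers (stands in for the all-reduce).
    avg_grad = torch.stack(grads).mean(dim=0)
    # Step 5: plain SGD update, with no graph tracking.
    with torch.no_grad():
        w_new = w - lr * avg_grad
    # Step 6: convert back to a Python list.
    return w_new.tolist()


shards = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [2.0, 0.0]]]
print(simulate_data_parallel([1.0, 2.0], shards, 0.1))  # approximately [0.7, 1.65]
```

Cloning per worker gives each replica an independent leaf tensor with its own `.grad`, so there is no need to zero gradients between workers. In real DDP, the averaging in step 4 is performed by an all-reduce across processes rather than over a Python list.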

Why? In DDP, each GPU processes a different shard of data but holds identical model parameters. After each backward pass, gradients are averaged (all-reduce) across GPUs before the optimizer step — ensuring all replicas stay in sync.
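To make the synchronization argument concrete, here is a pure-Python sketch (no PyTorch) that mimics an all-reduce mean over three workers' gradients and then applies the same SGD update on each replica. The helpers `all_reduce_mean` and `sgd_step` are hypothetical, and the gradient values are made up. Real DDP performs the reduction with a collective communication backend such as NCCL, not with a Python loop.

```python
from typing import List


def all_reduce_mean(worker_grads: List[List[float]]) -> List[List[float]]:
    """Hypothetical simulated all-reduce: every worker gets the element-wise mean."""
    k = len(worker_grads)
    n = len(worker_grads[0])
    # Sum across workers (the "reduce"), then divide by the world size K.
    mean = [sum(g[i] for g in worker_grads) / k for i in range(n)]
    # The "all" part: each worker receives its own copy of the same result.
    return [list(mean) for _ in range(k)]


def sgd_step(w: List[float], grad: List[float], lr: float) -> List[float]:
    """Hypothetical plain SGD update: w - lr * grad."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


# Three replicas start with identical weights but see different data,
# so their local gradients differ.
w0 = [1.0, 2.0]
local_grads = [[4.0, 6.0], [2.0, 1.0], [0.0, 2.0]]

synced = all_reduce_mean(local_grads)
replicas = [sgd_step(w0, g, lr=0.1) for g in synced]
print(replicas[0] == replicas[1] == replicas[2])  # True: replicas stay in sync

# Skipping the all-reduce lets the replicas drift apart after one step.
diverged = [sgd_step(w0, g, lr=0.1) for g in local_grads]
print(diverged[0] == diverged[1])  # False
```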

Math

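$$ w \leftarrow w - \eta \cdot \frac{1}{K} \sum_{k=1}^{K} \nabla_w L_k $$

where η is the learning rate lr, K is the number of workers (shards), and L_k is the scalar out computed on shard k.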
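Because the loss in step 2 is linear in the weights, its gradient has a closed form, which makes the update easy to check by hand. The derivation below introduces notation that is not in the original: $n_k$ is the number of rows in shard $k$ and $x_{k,i}$ is its $i$-th row. The numbers are a made-up two-worker example.

```latex
% Per-worker loss from step 2 and its gradient
L_k(w) = \sum_{i=1}^{n_k} x_{k,i}^{\top} w
\qquad\Longrightarrow\qquad
\nabla_w L_k = \sum_{i=1}^{n_k} x_{k,i}

% Substituting into the averaged SGD update
w \leftarrow w - \eta \cdot \frac{1}{K} \sum_{k=1}^{K} \sum_{i=1}^{n_k} x_{k,i}

% Worked example: K = 2, \eta = 0.1, w = (1, 2)^{\top}
% X_1 = [[1, 2], [3, 4]],  X_2 = [[0, 1], [2, 0]]
\nabla_w L_1 = (4, 6)^{\top}, \quad
\nabla_w L_2 = (2, 1)^{\top}, \quad
\tfrac{1}{2}\left(\nabla_w L_1 + \nabla_w L_2\right) = (3, 3.5)^{\top}

w \leftarrow (1, 2)^{\top} - 0.1 \cdot (3, 3.5)^{\top} = (0.7, 1.65)^{\top}
```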


Starter code (NumPy):

import numpy as np


def simulate_data_parallel(model_weights, data_shards, lr):
    pass
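NumPy has no autograd, so a NumPy version has to compute the gradient by hand. For out = (x @ w).sum(), the gradient with respect to w is the column sum of x. A minimal sketch of filling in the stub under that observation follows; it assumes float64 arithmetic is acceptable and is not the site's reference solution.

```python
import numpy as np


def simulate_data_parallel(model_weights, data_shards, lr):
    """NumPy sketch: no autograd, so use the closed-form gradient.

    For out = (x @ w).sum(), d(out)/dw is the column sum of x.
    """
    w = np.asarray(model_weights, dtype=np.float64)
    # One gradient per simulated worker, each of shape (n_weights,).
    # reshape(-1, w.size) also turns an empty shard into a (0, n_weights) array,
    # whose column sum is a zero gradient.
    grads = [
        np.asarray(shard, dtype=np.float64).reshape(-1, w.size).sum(axis=0)
        for shard in data_shards
    ]
    avg_grad = np.mean(grads, axis=0)  # stands in for the all-reduce
    return (w - lr * avg_grad).tolist()


shards = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, 1.0], [2.0, 0.0]]]
print(simulate_data_parallel([1.0, 2.0], shards, 0.1))  # approximately [0.7, 1.65]
```

Because it avoids autograd entirely, this version can also serve as a cross-check for an autograd-based implementation on small inputs.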
