Implement gradient clipping by global norm.
Signature: def clip_gradients(grads: list, max_norm: float) -> list
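Compute the global L2 norm across all gradient arrays together: the square root of the sum of the squares of every element in every array. If this norm exceeds max_norm, multiply every gradient by the same factor, max_norm / global_norm. Otherwise, return the gradients unchanged.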
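A minimal sketch of the rule above. It assumes that each entry of grads is a NumPy array, or something np.asarray can convert such as a nested list of floats, and that max_norm is positive. Because one shared factor scales every array, the combined gradient keeps its direction while its global norm drops to exactly max_norm.

```python
import numpy as np


def clip_gradients(grads: list, max_norm: float) -> list:
    """Clip a list of gradient arrays by their global L2 norm.

    Assumption: each item in `grads` is a NumPy array or array-like,
    and `max_norm` is positive.
    """
    arrays = [np.asarray(g, dtype=float) for g in grads]

    # Global norm: sqrt of the sum of squared entries across ALL arrays,
    # not the norm of each array on its own.
    global_norm = np.sqrt(sum(float(np.sum(a ** 2)) for a in arrays))

    if global_norm <= max_norm:
        # Already within the budget: return the gradients unscaled.
        return arrays

    # global_norm > max_norm > 0 here, so this division is safe.
    scale = max_norm / global_norm
    return [a * scale for a in arrays]
```

For example, gradients [3, 4] and [12] have a global norm of sqrt(9 + 16 + 144) = 13. With max_norm = 6.5, both are scaled by 0.5, which gives [1.5, 2.0] and [6.0].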