rikf committed on
Commit
27dffa4
·
verified ·
1 Parent(s): a718c05

Avoid using in-place torch operation for scatter_add

Browse files

Replace the in-place `scatter_add_` call with its out-of-place equivalent, `torch.scatter_add`.

Files changed (1) hide show
  1. modeling_grinmoe.py +2 -1
modeling_grinmoe.py CHANGED
@@ -786,7 +786,8 @@ class mp(torch.autograd.Function):
786
  grad_at_output = grad_at_output * multiplier
787
 
788
  grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
789
- grad_at_scores_expaned.scatter_add_(
 
790
  dim=-1,
791
  index=selected_experts,
792
  src=grad_at_output,
 
786
  grad_at_output = grad_at_output * multiplier
787
 
788
  grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
789
+ grad_at_scores_expaned = torch.scatter_add(
790
+ grad_at_scores_expaned,
791
  dim=-1,
792
  index=selected_experts,
793
  src=grad_at_output,