Avoid using in-place torch operation for scatter_add
Browse filesReplace the in-place scatter add with the out of place equivalent
- modeling_grinmoe.py +2 -1
modeling_grinmoe.py
CHANGED
|
@@ -786,7 +786,8 @@ class mp(torch.autograd.Function):
|
|
| 786 |
grad_at_output = grad_at_output * multiplier
|
| 787 |
|
| 788 |
grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
|
| 789 |
-
grad_at_scores_expaned.
|
|
|
|
| 790 |
dim=-1,
|
| 791 |
index=selected_experts,
|
| 792 |
src=grad_at_output,
|
|
|
|
| 786 |
grad_at_output = grad_at_output * multiplier
|
| 787 |
|
| 788 |
grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1)
|
| 789 |
+
grad_at_scores_expaned = torch.scatter_add(
|
| 790 |
+
grad_at_scores_expaned,
|
| 791 |
dim=-1,
|
| 792 |
index=selected_experts,
|
| 793 |
src=grad_at_output,
|