[#1] logging the training loss
Browse files- idiomify/models.py +3 -0
idiomify/models.py
CHANGED
|
@@ -44,6 +44,9 @@ class Alpha(pl.LightningModule): # noqa
|
|
| 44 |
"loss": loss
|
| 45 |
}
|
| 46 |
|
|
|
|
|
|
|
|
|
|
| 47 |
def predict(self, srcs: torch.Tensor) -> torch.Tensor:
|
| 48 |
pred_ids = self.bart.generate(
|
| 49 |
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|
|
|
|
| 44 |
"loss": loss
|
| 45 |
}
|
| 46 |
|
| 47 |
+
def on_train_batch_end(self, outputs: dict, *args, **kwargs):
|
| 48 |
+
self.log("Train/Loss", outputs['loss'])
|
| 49 |
+
|
| 50 |
def predict(self, srcs: torch.Tensor) -> torch.Tensor:
|
| 51 |
pred_ids = self.bart.generate(
|
| 52 |
inputs=srcs[:, 0], # (N, 2, L) -> (N, L)
|