Spaces:
Runtime error
Runtime error
fixes
Browse files- src/models/model.py +2 -2
src/models/model.py
CHANGED
|
@@ -15,7 +15,6 @@ from pytorch_lightning import LightningModule
|
|
| 15 |
from datasets import load_metric
|
| 16 |
from tqdm.auto import tqdm
|
| 17 |
|
| 18 |
-
|
| 19 |
# from dagshub.pytorch_lightning import DAGsHubLogger
|
| 20 |
|
| 21 |
|
|
@@ -477,9 +476,10 @@ class Summarization:
|
|
| 477 |
metric = load_metric(metrics)
|
| 478 |
input_text = test_df['input_text'][:5]
|
| 479 |
references = test_df['output_text'][:5]
|
|
|
|
| 480 |
|
| 481 |
predictions = [self.predict(x) for x in input_text]
|
| 482 |
-
print(type(predictions),type(references))
|
| 483 |
|
| 484 |
results = metric.compute(predictions=predictions, references=references)
|
| 485 |
'''
|
|
|
|
| 15 |
from datasets import load_metric
|
| 16 |
from tqdm.auto import tqdm
|
| 17 |
|
|
|
|
| 18 |
# from dagshub.pytorch_lightning import DAGsHubLogger
|
| 19 |
|
| 20 |
|
|
|
|
| 476 |
metric = load_metric(metrics)
|
| 477 |
input_text = test_df['input_text'][:5]
|
| 478 |
references = test_df['output_text'][:5]
|
| 479 |
+
references = references.to_list()
|
| 480 |
|
| 481 |
predictions = [self.predict(x) for x in input_text]
|
| 482 |
+
print(type(predictions), type(references))
|
| 483 |
|
| 484 |
results = metric.compute(predictions=predictions, references=references)
|
| 485 |
'''
|