Adding Metrics and updating requirements
- requirements.txt +8 -2
- src/models/evaluate_model.py +12 -0
- src/models/model.py +63 -2
- src/models/predict_model.py +1 -1
- src/models/train_model.py +2 -2
requirements.txt
CHANGED

@@ -1,5 +1,11 @@
-
-
+numpy==1.19.2
+datasets==1.8.0
+pytorch_lightning==1.3.5
+transformers==4.6.0
+torch==1.9.0+cu111
+dagshub==0.1.6
+pandas==1.2.4
+rouge_score
 
 # external requirements
 click
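All pins above are exact versions. Note that torch==1.9.0+cu111 is a CUDA 11.1 build that is not served from PyPI's default index; it is normally installed with PyTorch's extra wheel index (https://download.pytorch.org/whl/torch_stable.html). A quick sanity-check sketch, assuming the pinned stack installed cleanly (not part of this commit):

# Illustrative sanity check: confirm the pins above resolved as expected.
import datasets
import pytorch_lightning
import torch
import transformers

print(torch.__version__)              # expected: 1.9.0+cu111
print(transformers.__version__)       # expected: 4.6.0
print(pytorch_lightning.__version__)  # expected: 1.3.5
print(datasets.__version__)           # expected: 1.8.0
print(torch.cuda.is_available())      # True on a CUDA-capable machine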
src/models/evaluate_model.py
ADDED

@@ -0,0 +1,12 @@
+from src.models.model import Summarization
+from src.data.make_dataset import make_dataset
+
+def evaluate_model():
+    """
+    Evaluate model using rouge measure
+    """
+    test_df = make_dataset(split='test')
+    model = Summarization()
+    model.load_model()
+    results = model.evaluate(test_df=test_df)
+    return results
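The new entry point wires dataset loading, model loading, and metric computation together. A minimal driver sketch, assuming it is launched from the repository root so the src package is importable (the driver itself is not part of this commit):

# Illustrative driver: run evaluation and pretty-print the nested ROUGE dict.
from pprint import pprint

from src.models.evaluate_model import evaluate_model

if __name__ == '__main__':
    results = evaluate_model()  # nested dict built by Summarization.evaluate
    pprint(results)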
src/models/model.py
CHANGED

@@ -15,6 +15,7 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning import LightningDataModule
 from pytorch_lightning import LightningModule
+from datasets import load_metric
 
 torch.cuda.empty_cache()
 pl.seed_everything(42)

@@ -302,7 +303,7 @@ class Summarization:
             tokenizer=self.tokenizer, model=self.model, output=outputdir
         )
 
-        #logger = MLFlowLogger(experiment_name="Summarization",tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
+        # logger = MLFlowLogger(experiment_name="Summarization",tracking_uri="https://dagshub.com/gagan3012/summarization.mlflow")
 
         logger = DAGsHubLogger()
 

@@ -425,4 +426,64 @@ class Summarization:
             )
             for g in generated_ids
         ]
-        return preds
+        return preds[0]
+
+    def evaluate(
+        self,
+        test_df: pd.DataFrame,
+        metrics: str = "rouge"
+    ):
+        metric = load_metric(metrics)
+        input_text = test_df['input_text']
+        references = test_df['output_text']
+        predictions = [self.predict(x) for x in input_text]
+
+        results = metric.compute(predictions=predictions, references=references)
+
+        output = {
+            'Rouge 1': {
+                'Rouge_1 Low Precision': results["rouge1"].low.precision,
+                'Rouge_1 Low recall': results["rouge1"].low.recall,
+                'Rouge_1 Low F1': results["rouge1"].low.fmeasure,
+                'Rouge_1 Mid Precision': results["rouge1"].mid.precision,
+                'Rouge_1 Mid recall': results["rouge1"].mid.recall,
+                'Rouge_1 Mid F1': results["rouge1"].mid.fmeasure,
+                'Rouge_1 High Precision': results["rouge1"].high.precision,
+                'Rouge_1 High recall': results["rouge1"].high.recall,
+                'Rouge_1 High F1': results["rouge1"].high.fmeasure,
+            },
+            'Rouge 2': {
+                'Rouge_2 Low Precision': results["rouge2"].low.precision,
+                'Rouge_2 Low recall': results["rouge2"].low.recall,
+                'Rouge_2 Low F1': results["rouge2"].low.fmeasure,
+                'Rouge_2 Mid Precision': results["rouge2"].mid.precision,
+                'Rouge_2 Mid recall': results["rouge2"].mid.recall,
+                'Rouge_2 Mid F1': results["rouge2"].mid.fmeasure,
+                'Rouge_2 High Precision': results["rouge2"].high.precision,
+                'Rouge_2 High recall': results["rouge2"].high.recall,
+                'Rouge_2 High F1': results["rouge2"].high.fmeasure,
+            },
+            'Rouge L': {
+                'Rouge_L Low Precision': results["rougeL"].low.precision,
+                'Rouge_L Low recall': results["rougeL"].low.recall,
+                'Rouge_L Low F1': results["rougeL"].low.fmeasure,
+                'Rouge_L Mid Precision': results["rougeL"].mid.precision,
+                'Rouge_L Mid recall': results["rougeL"].mid.recall,
+                'Rouge_L Mid F1': results["rougeL"].mid.fmeasure,
+                'Rouge_L High Precision': results["rougeL"].high.precision,
+                'Rouge_L High recall': results["rougeL"].high.recall,
+                'Rouge_L High F1': results["rougeL"].high.fmeasure,
+            },
+            'rougeLsum': {
+                'rougeLsum Low Precision': results["rougeLsum"].low.precision,
+                'rougeLsum Low recall': results["rougeLsum"].low.recall,
+                'rougeLsum Low F1': results["rougeLsum"].low.fmeasure,
+                'rougeLsum Mid Precision': results["rougeLsum"].mid.precision,
+                'rougeLsum Mid recall': results["rougeLsum"].mid.recall,
+                'rougeLsum Mid F1': results["rougeLsum"].mid.fmeasure,
+                'rougeLsum High Precision': results["rougeLsum"].high.precision,
+                'rougeLsum High recall': results["rougeLsum"].high.recall,
+                'rougeLsum High F1': results["rougeLsum"].high.fmeasure,
+            }
+        }
+        return output
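The nested output dict works because datasets.load_metric("rouge") (backed by the rouge_score package pinned in requirements.txt) returns one AggregateScore per ROUGE type, with low/mid/high bootstrap confidence bounds, each a Score namedtuple of (precision, recall, fmeasure). A self-contained sketch of the same access pattern, with hypothetical example strings:

# Standalone illustration of the access pattern used in evaluate().
from datasets import load_metric

metric = load_metric("rouge")
results = metric.compute(
    predictions=["the cat sat on the mat"],
    references=["a cat was sitting on the mat"],
)
# results["rouge1"] is an AggregateScore(low=Score, mid=Score, high=Score).
print(results["rouge1"].mid.fmeasure)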
src/models/predict_model.py
CHANGED

@@ -12,6 +12,6 @@ def predict_model(text):
 
 
 if __name__ == '__main__':
-    text = make_dataset(split="test")['input_text']
+    text = make_dataset(split="test")['input_text'][0]
     pre_summary = predict_model(text)
     print(pre_summary)
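With predict now returning preds[0], predict_model yields a single summary string, so the script feeds it one test example instead of the whole column. Callers that still want batch output can loop over inputs, mirroring what evaluate does internally; a hedged sketch, where predict_batch is a hypothetical helper, not part of this commit:

# Hypothetical batch helper mirroring the loop inside Summarization.evaluate.
def predict_batch(model, texts):
    """Run the single-example predict over a list of input texts."""
    return [model.predict(text) for text in texts]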
src/models/train_model.py
CHANGED

@@ -1,5 +1,5 @@
-from .model import Summarization
-from data.make_dataset import make_dataset
+from src.models.model import Summarization
+from src.data.make_dataset import make_dataset
 
 def train_model():
     """
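Like evaluate_model.py, the script now uses absolute src.* imports, which resolve only when Python is launched from the repository root (e.g. as a module: python -m src.models.train_model). When a script must run from elsewhere, a common workaround is a small path shim; the one below is an assumption for illustration, not part of the commit:

# Hypothetical shim: make `src` importable when running outside the repo root.
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))  # repo root
from src.models.model import Summarization  # noqa: E402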