Spaces:
Runtime error
Runtime error
Updates
Browse files- .gitignore +3 -0
- src/data/make_dataset.py +3 -0
- src/models/model.py +1 -1
- src/models/predict_model.py +9 -0
- src/models/train_model.py +15 -0
.gitignore
CHANGED
|
@@ -88,3 +88,6 @@ coverage.xml
|
|
| 88 |
|
| 89 |
# Mypy cache
|
| 90 |
.mypy_cache/
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
# Mypy cache
|
| 90 |
.mypy_cache/
|
| 91 |
+
|
| 92 |
+
.idea
|
| 93 |
+
.vscode
|
src/data/make_dataset.py
CHANGED
|
@@ -9,3 +9,6 @@ def make_dataset(dataset='cnn_dailymail', split='train', version="3.0.0"):
|
|
| 9 |
df['input_text'] = dataset['concepts']
|
| 10 |
df['output_text'] = dataset['target']
|
| 11 |
return df
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
df['input_text'] = dataset['concepts']
|
| 10 |
df['output_text'] = dataset['target']
|
| 11 |
return df
|
| 12 |
+
|
| 13 |
+
if __name__ == '__main__':
|
| 14 |
+
make_dataset(dataset='cnn_dailymail', split='train', version="3.0.0")
|
src/models/model.py
CHANGED
|
@@ -340,7 +340,7 @@ class Summarization:
|
|
| 340 |
trainer.fit(self.T5Model, self.data_module)
|
| 341 |
|
| 342 |
def load_model(
|
| 343 |
-
self, model_dir: str = "models", use_gpu: bool = False
|
| 344 |
):
|
| 345 |
"""
|
| 346 |
loads a checkpoint for inferencing/prediction
|
|
|
|
| 340 |
trainer.fit(self.T5Model, self.data_module)
|
| 341 |
|
| 342 |
def load_model(
|
| 343 |
+
self, model_dir: str = "../../models", use_gpu: bool = False
|
| 344 |
):
|
| 345 |
"""
|
| 346 |
loads a checkpoint for inferencing/prediction
|
src/models/predict_model.py
CHANGED
|
@@ -1,2 +1,11 @@
|
|
| 1 |
from .model import Summarization
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from .model import Summarization
|
| 2 |
|
| 3 |
+
def predict_model(text):
|
| 4 |
+
"""
|
| 5 |
+
Predict the summary of the given text.
|
| 6 |
+
"""
|
| 7 |
+
model = Summarization()
|
| 8 |
+
model.load_model()
|
| 9 |
+
pre_summary = model.predict(text)
|
| 10 |
+
return pre_summary
|
| 11 |
+
|
src/models/train_model.py
CHANGED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import Summarization
|
| 2 |
+
from data.make_dataset import make_dataset
|
| 3 |
+
|
| 4 |
+
def train_model():
|
| 5 |
+
"""
|
| 6 |
+
Train the model
|
| 7 |
+
"""
|
| 8 |
+
# Load the data
|
| 9 |
+
train_df = make_dataset(split = 'train')
|
| 10 |
+
eval_df = make_dataset(split = 'test')
|
| 11 |
+
|
| 12 |
+
model = Summarization()
|
| 13 |
+
model.from_pretrained('t5-base')
|
| 14 |
+
model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
|
| 15 |
+
model.save_model()
|