| from transformers import BartTokenizer | |
| from idiomify.datamodules import IdiomifyDataModule | |
# Settings handed to IdiomifyDataModule in main().
# NOTE(review): key meanings below are inferred from names; confirm against
# IdiomifyDataModule before relying on them.
CONFIG = dict(
    literal2idiomatic_ver="d-1-2",  # presumably the literal->idiomatic dataset version tag
    batch_size=20,                  # presumably samples per batch
    num_workers=4,                  # presumably DataLoader worker processes
    shuffle=True,                   # presumably whether to shuffle training data
)
def _print_first_batch_shapes(dataloader):
    """Print the ``.shape`` of each of the three items in the first batch.

    Each batch must unpack into exactly three items (srcs, tgts_r, tgts);
    otherwise unpacking raises ``ValueError``. Prints nothing if
    ``dataloader`` yields no batches.
    """
    # Only the first batch is inspected; `break` stops after it so the rest of
    # the (possibly large) loader is never iterated.
    for srcs, tgts_r, tgts in dataloader:
        print(srcs.shape)
        print(tgts_r.shape)
        print(tgts.shape)
        break


def main():
    """Smoke-test IdiomifyDataModule by printing first-batch shapes.

    Builds the data module from ``CONFIG`` and the ``facebook/bart-base``
    tokenizer, runs ``prepare_data()`` and ``setup()``, then prints the shapes
    of the first train batch followed by the first test batch.
    """
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    datamodule = IdiomifyDataModule(CONFIG, tokenizer)
    datamodule.prepare_data()
    # NOTE(review): setup() is called without a stage argument, so it is
    # presumably expected to prepare every split; confirm against the datamodule.
    datamodule.setup()
    _print_first_batch_shapes(datamodule.train_dataloader())
    _print_first_batch_shapes(datamodule.test_dataloader())
if __name__ == "__main__":
    # Run the smoke test only when executed as a script, never on import.
    main()