In [1]:
import os
import numpy as np

# import transformers
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_metric

from dataset_loader import IntentDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# transformers.logging.set_verbosity_info()
# transformers.logging.set_verbosity_error() 
# We set the verbosity to error to avoid the annoying huggingface warnings 
# when loading models before training them. If you're having trouble getting things to work
# maybe comment that line (setting the verbosity to info also may lead to interesting outputs!)
# os.environ['TOKENIZERS_PARALLELISM'] = "false" # trainer (?) was complaining about parallel tokenization
# os.environ["WANDB_DISABLED"] = "true" # trainer was complaining about wandb

In [3]:
model_checkpoint_name = 'roberta-base' # try 'bert-base-uncased', 'bert-base-cased', 'bert-large-uncased'
dataset_name = 'twiz-data' # rename to your dataset dir
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint_name) # loads a tokenizer
tokenizer.save_pretrained("tokenizer")



In [4]:
train_dataset = IntentDataset(dataset_name, tokenizer, 'train') # check twiz_dataset.py for dataset loading code
val_dataset = IntentDataset(dataset_name, tokenizer, 'val')

model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint_name, num_labels=len(train_dataset.all_intents)) # Loads the BERT model weights

Loaded Intent detection dataset. 5916 examples. (train). 
Loaded Intent detection dataset. 819 examples. (val). 


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
inspect_index = 0
print('All data keys:', train_dataset[inspect_index].keys())
print(train_dataset[inspect_index]['input_ids'], train_dataset[inspect_index]['input_ids'].shape)
# you can check the correspondence of a label by checking the all_intents attribute, as such:
train_dataset[inspect_index]['label'], train_dataset.all_intents[train_dataset[inspect_index]['label']]

All data keys: dict_keys(['input_ids', 'attention_mask', 'label'])
tensor([    0,  6715,    28,  7316,    77,   634,   143,  3270,    50,  2104,
            4,  9427,     6,  1078,    78,   328,  1398,    16,   103,   335,
           59, 26157,     8, 42446, 11182,   102,     4,    85,    34,    10,
          204,     4,   398,   999,   691,     4,  1437,    85,    16,  2319,
            7,   185,    59,  1718,   728,   479,    85,  4542,   204,     4,
         3139,  9600,   672,    16, 18609,     4,  1437,   318,    42,    16,
           45,  1341,    99,    47,    32,   546,    13,   224,     6,   213,
          124,     4,   598,   535,     5,  3685,     6,    95,   224,     6,
          311,  7075,     4,     2,     2, 12005,  7075,     2,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     

(tensor(29), 'IngredientsConfirmationIntent')

In [6]:
acc = load_metric('accuracy')
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = acc.compute(predictions=predictions, references=labels)
    return accuracy

def get_trainer(model):
    return Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )

training_args = TrainingArguments(
    output_dir='roberta-based',
    do_train=True,
    do_eval=True,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='epoch',
    metric_for_best_model='accuracy',
    learning_rate=2e-5,
    num_train_epochs=5,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    load_best_model_at_end=True,
    disable_tqdm=False,
)

trainer = get_trainer(model)

  acc = load_metric('accuracy')
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
Using the latest cached version of the module from /user/home/dc.tavares/.cache/huggingface/modules/datasets_modules/metrics/accuracy/bbddc2dafac9b46b0aeeb39c145af710c55e03b223eae89dfe86388f40d9d157 (last modified on Wed May 18 17:06:59 2022) since it couldn't be found locally at accuracy, or remotely on the Hugging Face Hub.


In [7]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.7332,1.017632,0.799756
2,0.6767,0.734118,0.82906
3,0.4469,0.668322,0.847375
4,0.3435,0.640882,0.852259
5,0.2829,0.641061,0.857143


TrainOutput(global_step=925, training_loss=0.6966540857263513, metrics={'train_runtime': 515.0261, 'train_samples_per_second': 57.434, 'train_steps_per_second': 1.796, 'total_flos': 2736984690806400.0, 'train_loss': 0.6966540857263513, 'epoch': 5.0})

In [9]:
# run the next cell with the next line uncommented and fill your checkpoint directory to evaluate the model
# model = AutoModelForSequenceClassification.from_pretrained('./your-checkpoint-directory').eval()
test_dataset = IntentDataset(dataset_name, tokenizer, 'test')
trainer = get_trainer(model)
trainer.evaluate(eval_dataset=test_dataset)

Loaded Intent detection dataset. 842 examples. (test). 


ConnectionError: (MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /api/repos/create (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x7fd2023513c0>: Failed to resolve \'huggingface.co\' ([Errno -3] Temporary failure in name resolution)"))'), '(Request ID: 893f7cae-38f8-4513-ba1d-a7c8dd3db7c8)')