---
language: en
license: mit
datasets:
- toxic_comment_classification
tags:
- text-classification
- toxicity-detection
- sentiment-analysis
- multi-task-learning
pipeline_tag: text-classification
---
|
|
|
|
|
# Comment MTL BERT Model

This is a BERT-based multi-task learning model that performs sentiment analysis and toxicity detection simultaneously from a single shared encoder.

## Model Architecture

The model is built on the `bert-base-uncased` pre-trained encoder with two separate classification heads:

- **Sentiment Analysis Head**: 3-class single-label classification (Negative, Neutral, Positive)
- **Toxicity Detection Head**: multi-label classification over 6 categories (toxic, severe_toxic, obscene, threat, insult, identity_hate)
|
|
|
|
|
### Technical Parameters

- Hidden Size: 768
- Number of Attention Heads: 12
- Number of Hidden Layers: 12
- Vocabulary Size: 30522
- Maximum Position Embeddings: 512
- Hidden Activation Function: gelu
- Dropout Probability: 0.1
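
For illustration, a minimal sketch of what such a dual-head model might look like is shown below. This is an assumption about the design, not the repository's code; the actual `CommentMTLModel` in `src/model.py` may differ in pooling and layer details.

```python
import torch.nn as nn
from transformers import AutoModel

class DualHeadBERT(nn.Module):
    """Illustrative sketch of a shared BERT encoder with two task heads.
    The repository's actual CommentMTLModel (src/model.py) may differ."""

    def __init__(self, model_name="bert-base-uncased",
                 num_sentiment_labels=3, num_toxicity_labels=6):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden = self.encoder.config.hidden_size  # 768 for bert-base
        self.dropout = nn.Dropout(0.1)
        self.sentiment_head = nn.Linear(hidden, num_sentiment_labels)
        self.toxicity_head = nn.Linear(hidden, num_toxicity_labels)

    def forward(self, input_ids, attention_mask=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(out.pooler_output)  # [CLS]-based pooled output
        return {
            "sentiment_logits": self.sentiment_head(pooled),
            "toxicity_logits": self.toxicity_head(pooled),
        }
```

Both heads read the same pooled representation, so a single forward pass serves both tasks.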
|
|
|
|
|
## Usage

### Loading the Model

```python
import torch
from transformers import AutoTokenizer
from src.model import CommentMTLModel  # defined in this repository's src/model.py

# Load the tokenizer that matches the base encoder
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Instantiate the model with both task heads
model = CommentMTLModel(
    model_name="bert-base-uncased",
    num_sentiment_labels=3,
    num_toxicity_labels=6
)

# Load the pre-trained weights and switch to inference mode
state_dict = torch.load("model.bin", map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()
```
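
If you are working from the Hub rather than a local clone, `huggingface_hub` can fetch the checkpoint first. This sketch assumes the weights file in the `Aseemks07/comment_mtl_bert_best` repo is named `model.bin`, matching the snippet above.

```python
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hub (filename assumed from the snippet above)
weights_path = hf_hub_download(
    repo_id="Aseemks07/comment_mtl_bert_best",
    filename="model.bin",
)
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
```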
|
|
|
|
|
### Model Inference

```python
# Prepare input
text = "This is a test comment."
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)

# Run the model without tracking gradients
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

# Get the raw logits for each task
sentiment_logits = outputs["sentiment_logits"]
toxicity_logits = outputs["toxicity_logits"]

# Sentiment is single-label: softmax over the 3 classes, then argmax
sentiment_probs = torch.softmax(sentiment_logits, dim=1)
sentiment_labels = {0: "Negative", 1: "Neutral", 2: "Positive"}
sentiment_prediction = sentiment_labels[sentiment_probs.argmax().item()]

# Toxicity is multi-label: an independent sigmoid per category
toxicity_probs = torch.sigmoid(toxicity_logits)
toxicity_cols = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
toxicity_results = {label: prob.item() for label, prob in zip(toxicity_cols, toxicity_probs[0])}

print(f"Sentiment: {sentiment_prediction}")
print(f"Toxicity probabilities: {toxicity_results}")
```
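
Since the toxicity head is multi-label, each probability can be thresholded independently to obtain binary flags. The 0.5 cutoff below is a common default, not a value prescribed by this repository; the snippet continues from the one above.

```python
# Threshold each per-category probability independently.
# 0.5 is an assumed default cutoff, not one specified by the repository.
THRESHOLD = 0.5
flagged = [label for label, prob in toxicity_results.items() if prob >= THRESHOLD]
print(f"Flagged categories: {flagged if flagged else 'none'}")
```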
|
|
|
|
|
## Limitations

- The model was trained only on English data and is not suitable for other languages.
- Toxicity detection may produce false positives on some edge cases.
- Inputs are truncated to the 128-token maximum used above, so long texts may lose information; the sketch below shows one possible workaround.
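
For texts longer than 128 tokens, one possible workaround (an illustration, not part of the repository) is to run the model over overlapping token windows and average the logits. `stride` here is the token overlap between consecutive windows, and the helper reuses the `tokenizer` and `model` loaded above.

```python
# Hypothetical sliding-window workaround for texts longer than 128 tokens;
# not part of the repository. Averages logits across overlapping windows.
def long_text_logits(text, tokenizer, model, max_length=128, stride=64):
    enc = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,  # one row of input_ids per window
    )
    with torch.no_grad():
        out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
    # Average each task's logits over the window dimension
    return {k: v.mean(dim=0, keepdim=True) for k, v in out.items()}
```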
|
|
|
|
|
## Citation

If you use this model, please cite our repository:

```
@misc{comment-mtl-bert,
  author = {Aseem},
  title = {Comment MTL BERT: Multi-Task Learning for Comment Analysis},
  year = {2023},
  publisher = {Hugging Face},
  url = {https://huggingface.co/Aseemks07/comment_mtl_bert_best}
}
```