Spaces:
Sleeping
Sleeping
| import torch | |
| from transformers import BertTokenizer, BertModel | |
| from huggingface_hub import PyTorchModelHubMixin | |
| import numpy as np | |
| import gradio as gr | |
| import nltk | |
| nltk.download('stopwords') | |
| from nltk.corpus import stopwords | |
| import re | |
# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
# (The stray bare `device` expression that followed was a notebook echo with
# no effect in a script and has been dropped.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class BERTClass(torch.nn.Module, PyTorchModelHubMixin):
    """BERT encoder followed by dropout and a single linear output layer.

    The encoder is loaded from the
    'digitalepidemiologylab/covid-twitter-bert-v2' checkpoint; the head maps
    a 1024-wide pooled vector to 11 raw (un-squashed) outputs.

    NOTE: the attribute names (`bert_model`, `dropout`, `linear`) are part of
    the saved parameter names, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(
            'digitalepidemiologylab/covid-twitter-bert-v2',
            return_dict=True,
        )
        self.dropout = torch.nn.Dropout(0.3)
        self.linear = torch.nn.Linear(1024, 11)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Encode the inputs and return the linear head's outputs.

        Args:
            input_ids: token ids, passed to the encoder positionally.
            attn_mask: forwarded to the encoder as `attention_mask`.
            token_type_ids: forwarded to the encoder as `token_type_ids`.

        Returns:
            The output of the linear layer applied to the dropped-out
            `pooler_output` of the encoder (no activation is applied here).
        """
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
# Load the fine-tuned classifier. `from_pretrained` is called on the class
# itself: the previous `BERTClass()` + `model.from_pretrained(...)` sequence
# built one full model only to discard it immediately.
model = BERTClass.from_pretrained("Asutosh2003/ct-bert-v2-vaccine-concern")
model.to(device)

# Tokenizer belonging to the same base checkpoint the encoder is built from.
tokenizer = BertTokenizer.from_pretrained('digitalepidemiologylab/covid-twitter-bert-v2')

# Every input is padded or truncated to exactly this many tokens.
MAX_LEN = 256
def rmTrash(raw_string, remuser, remstop, remurls):
    """Clean a tweet before it is handed to the tokenizer.

    Processing order:
      1. (remuser) drop every whitespace-separated token that contains '@';
      2. lower-case the text and delete every character that is neither a
         word character nor whitespace;
      3. (remurls) delete every run of non-whitespace starting with 'http';
      4. (remstop) drop NLTK English stop words.

    Args:
        raw_string: the text to clean.
        remuser: truthy to remove tokens containing '@'.
        remstop: truthy to remove English stop words.
        remurls: truthy to remove URLs.

    Returns:
        The cleaned string. Whenever the text is rebuilt token by token
        (remuser or remstop enabled) each kept token is preceded by a single
        space, so a non-empty result starts with a space; this matches the
        historical output of this function.
    """
    if remuser:
        text = "".join(" " + tok for tok in raw_string.split() if "@" not in tok)
    else:
        text = raw_string

    text = re.sub(r'[^\w\s]', '', text.lower())

    if remurls:
        # Punctuation is already stripped at this point, so a URL has been
        # squashed into one run of characters that still begins with "http".
        text = re.sub(r'http\S+', '', text)

    if not remstop:
        return text

    # Build the stop-word set once per call instead of re-reading the word
    # list for every token; set membership is also O(1) instead of O(n).
    stop_words = set(stopwords.words('english'))
    return "".join(" " + tok for tok in text.split() if tok not in stop_words)
def return_vec(text):
    """Return the model's per-output sigmoid scores for one piece of text.

    The text is cleaned with `rmTrash` (user mentions, stop words and URLs
    all removed), tokenized to exactly MAX_LEN tokens, and run through the
    module-level `model` in evaluation mode without gradient tracking.

    Args:
        text: the raw input string.

    Returns:
        A list of Python floats: the sigmoid of each model output for the
        single input row.
    """
    text = rmTrash(text, True, True, True)

    # Calling the tokenizer directly is the supported replacement for the
    # deprecated `encode_plus`; the arguments are unchanged.
    encodings = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    model.eval()  # make sure dropout is disabled for inference
    with torch.no_grad():
        input_ids = encodings['input_ids'].to(device, dtype=torch.long)
        attention_mask = encodings['attention_mask'].to(device, dtype=torch.long)
        token_type_ids = encodings['token_type_ids'].to(device, dtype=torch.long)
        logits = model(input_ids, attention_mask, token_type_ids)
        scores = torch.sigmoid(logits)

    # Only one text was encoded, so row 0 holds all of its scores.
    return scores[0].cpu().tolist()
def filter_threshold_lst(vector, threshold_list):
    """Binarise scores against per-position thresholds.

    Args:
        vector: iterable of scores.
        threshold_list: iterable of thresholds, paired with `vector` by
            position. If the two differ in length, the extra items of the
            longer one are ignored (plain `zip` semantics).

    Returns:
        A list with 1 where the score is greater than or equal to its
        threshold and 0 otherwise. (Earlier versions also appended the list
        to itself, producing a self-referential extra element; that was a
        bug and is gone.)
    """
    return [
        1 if val >= threshold else 0
        for val, threshold in zip(vector, threshold_list)
    ]
def predict(text, threshold_lst):
    """Classify one text into zero or more vaccine-concern labels.

    Args:
        text: the raw input string.
        threshold_lst: per-label thresholds, in the same order as the label
            tuple below; a label is selected when its score is greater than
            or equal to its threshold.

    Returns:
        A tuple `(pred_lbl_lst, prob_lst)`: the list of selected label names
        (or `['none']` when no label is selected) and the list of scores
        returned by `return_vec`.
    """
    labels = ('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious')

    prob_lst = return_vec(text)
    vec = filter_threshold_lst(prob_lst, threshold_lst)

    # Pairing with `labels` bounds the scan to the known labels, so any
    # extra trailing element in `vec` is ignored rather than indexed.
    pred_lbl_lst = [label for label, flag in zip(labels, vec) if flag == 1]
    if not pred_lbl_lst:
        pred_lbl_lst = ['none']
    return pred_lbl_lst, prob_lst
def gr_predict(text):
    """Gradio callback: return the predicted labels as one comma-separated string.

    Args:
        text: the raw input string from the UI.

    Returns:
        The label names returned by `predict`, joined with ',' (no spaces).
    """
    # One threshold per label, in the order of the label tuple in `predict`.
    # NOTE(review): how these values were chosen is not shown in this file.
    thres = [0.616, 0.212, 0.051, 0.131, 0.212, 0.111, 0.071, 0.566, 0.061, 0.02, 0.081]
    out_lst, _ = predict(text, thres)
    return ','.join(out_lst)
# Markdown shown under the app title. The fallback label is written as
# 'none' so that it matches the string the classifier actually outputs.
descr = """
This app uses [Covid-twitter-BERT-v2](https://huggingface.co/digitalepidemiologylab/covid-twitter-bert-v2)
fine tuned on a custom subset of [Caves dataset](https://arxiv.org/abs/2204.13746) sent by [FIRE 2023](http://fire.irsi.res.in/fire/2023/home)
conference to do multi-label classification of tweets expressing concerns towards vaccines. The different concerns/classes are
('side-effect', 'ineffective', 'rushed', 'pharma', 'mandatory', 'unnecessary', 'political', 'ingredients', 'conspiracy', 'country', 'religious').
Each tweet can be expressing multiple of these concerns. If a tweet is not expressing any concern falling into any of these categories
it will be classified as 'none'.\n
[Source files](https://github.com/Ranjit246/AISoME_FIRE_2023)\n
Try it out with some ridiculous statements about vaccines. You can use the examples below as a start.
"""
# Assemble the Gradio UI: one text box in, one label widget out.
iface = gr.Interface(
    title="Vaccine Concerns ML",
    description=descr,
    fn=gr_predict,
    inputs=gr.Textbox(),
    outputs=gr.Label(),  # receives the string returned by gr_predict
    examples=[
        "This vaccine gave me mumps",
        "Chinese vaccine will infect our brain",
        "Trump is gonna use these vaccines to control us and become the president",
    ],
)

# Start serving as soon as the module is executed, with debug mode on.
iface.launch(debug=True)