Spaces:
Sleeping
Sleeping
| import torch | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import torch.nn as nn | |
| import pickle | |
| import re | |
# Vocabulary mapping each token string to its integer id; used by
# convert_tokens_2_ids() and to size the model's embedding table.
# NOTE(review): pickle.load can execute arbitrary code while unpickling, so
# vocab.pkl must come from a trusted source and never from user input.
with open("vocab.pkl", "rb") as f:
    token_2_id = pickle.load(f)
# Print only the size: dumping the whole mapping floods the logs on startup.
print(f"Loaded vocabulary with {len(token_2_id)} tokens")
# Anything that is not an ASCII lowercase letter, an ASCII digit, or whitespace.
_DISALLOWED_CHARS = re.compile(r'[^a-z0-9\s]')


def normalize(text):
    """Canonicalize raw text before tokenization.

    The text is lower-cased and every character other than ``a-z``, ``0-9``
    and whitespace is removed. Runs of whitespace are then collapsed to a
    single space, with no leading or trailing space.
    """
    filtered = _DISALLOWED_CHARS.sub('', text.lower())
    return ' '.join(filtered.split())
def tokenize(text):
    """Split *text* on any run of whitespace and return the list of tokens."""
    return text.split()
def convert_tokens_2_ids(tokens):
    """Translate a sequence of tokens into vocabulary ids.

    Each token is looked up in the module-level ``token_2_id`` mapping.
    Tokens not present there map to the id stored under ``'<UNK>'``.
    """
    vocab = token_2_id
    return [vocab.get(tok, vocab['<UNK>']) for tok in tokens]
def process_text(text, aspect):
    """Encode a (text, aspect) pair as a batch containing one id sequence.

    The aspect is appended to the text with a single space. The combined
    string is normalized, whitespace-tokenized, and mapped to vocabulary ids.

    Args:
        text: Input text.
        aspect: Aspect term whose sentiment is requested.

    Returns:
        A ``torch.long`` tensor of shape ``(1, num_tokens)``.

    Raises:
        ValueError: If no tokens remain after normalization (e.g. empty or
            punctuation-only input). The model's forward pass reads the last
            time step, so it cannot score an empty sequence.
    """
    text_aspect_pair = text + ' ' + aspect
    tokens = tokenize(normalize(text_aspect_pair))
    if not tokens:
        raise ValueError(
            "text and aspect contain no usable tokens after normalization"
        )
    input_ids = convert_tokens_2_ids(tokens)
    # Explicit dtype: nn.Embedding needs integer indices, and torch.tensor([])
    # would otherwise default to float32.
    return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
class ABSA(nn.Module):
    """Embedding -> single-layer LSTM -> linear classifier over the last step.

    The attribute names ``embedding_layer``, ``lstm_layer`` and ``fc_layer``
    become the prefixes of the state-dict keys loaded from the saved weights,
    so they must not be renamed.
    """

    def __init__(self, vocab_size, num_labels=3):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels

        embed_dim = 256
        hidden_dim = 512
        # Construction order fixes the order in which parameter initialisation
        # draws from the global RNG, so keep embedding -> LSTM -> linear.
        self.embedding_layer = nn.Embedding(vocab_size, embed_dim)
        self.lstm_layer = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc_layer = nn.Linear(hidden_dim, num_labels)

    def forward(self, x):
        """Return unnormalised class scores of shape ``(batch, num_labels)``.

        ``x`` holds token ids. Because the LSTM is ``batch_first`` it is read
        as ``(batch, seq_len)``; presumably an integer tensor (confirm with
        callers).
        """
        hidden_seq, _ = self.lstm_layer(self.embedding_layer(x))
        final_step = hidden_seq[:, -1, :]
        return self.fc_layer(final_step)
# Build the network with one embedding row per vocabulary entry; the saved
# weights' embedding shape must match this or load_state_dict() will fail.
model = ABSA(vocab_size=len(token_2_id), num_labels=3)
# map_location="cpu" lets weights that were saved from GPU tensors load on a
# CPU-only host instead of raising at deserialization time.
model.load_state_dict(torch.load("model_weights.pth", map_location="cpu"))
model.eval()  # switch to inference mode
print("Model loaded successfully")
app = FastAPI()


# Root endpoint. Without the route decorator the handler is never registered
# with the app and the path would return 404.
@app.get("/")
def greet_json():
    """Return a fixed greeting payload at the root path."""
    return {"Hello": "World!"}
class TextAspectInput(BaseModel):
    """Validated request body for sentiment prediction.

    Both fields are required strings; pydantic rejects requests that omit
    them or supply values it cannot validate as ``str``.
    """

    text: str  # free-form input text
    aspect: str  # aspect term whose sentiment is requested
# Model output class index -> human-readable sentiment label.
SENTIMENT_LABELS = dict(enumerate(("Negative", "Neutral", "Positive")))
# Predict endpoint
@app.post("/predict")
async def predict_sentiment(input_data: TextAspectInput):
    """Predict the sentiment expressed toward an aspect within a text.

    Args:
        input_data: Request body holding ``text`` and ``aspect``.

    Returns:
        ``{"sentiment": label, "probabilities": [p_0, p_1, ...]}``, where the
        probabilities are the softmax over the model's output classes, indexed
        like ``SENTIMENT_LABELS``.

    Raises:
        HTTPException: 400 if the input cannot be encoded (``ValueError``
            from ``process_text``); 500 for any other encoding or inference
            failure, with the underlying exception chained.
    """
    print(input_data)

    try:
        input_ids = process_text(input_data.text, input_data.aspect)
    except ValueError as e:
        # Unusable client input, e.g. nothing left after normalization.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    print("Process text: ", input_ids)

    try:
        with torch.no_grad():
            logits = model(input_ids)
        probabilities = torch.softmax(logits, dim=-1)
        prediction = probabilities.argmax(dim=-1).item()
        sentiment = SENTIMENT_LABELS[prediction]
    except Exception as e:
        # Surface inference failures instead of swallowing them: continuing
        # would read unbound `sentiment`/`probabilities` and hide the cause.
        print(e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # The batch holds exactly one sequence; take its probability row.
    return {"sentiment": sentiment, "probabilities": probabilities[0].tolist()}