Spaces:
Running
Running
| from typing import List | |
| import torch | |
| import logging | |
class BaseModel:
    """
    Base class for sentiment-analysis models.

    This class does not create ``self.tokenizer`` or ``self.model``; subclasses
    must set both before calling the prediction methods. ``self.tokenizer`` is
    called with the text(s) plus ``return_tensors="pt"``, ``padding``,
    ``truncation`` and ``max_length`` keywords and must return a mapping of
    tensors. ``self.model`` is called with those tensors as keyword arguments
    and must return an object with a ``logits`` attribute; column ``i`` of the
    logits corresponds to ``self.sentiment_labels[i]``.

    Attributes:
        max_length: Maximum token length passed to the tokenizer (longer
            inputs are truncated). Override on a subclass or instance to change.
    """

    max_length: int = 512

    def __init__(self, model_name: str):
        """
        Args:
            model_name: Identifier of the model; stored as ``self.model_name``.
        """
        self.model_name = model_name
        # NOTE(review): configuring the root logger from a library constructor
        # is normally the application's job; kept for backward compatibility.
        # basicConfig does nothing if the root logger already has handlers.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        self.logger = logging.getLogger(__name__)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Column i of the model's logits maps to sentiment_labels[i].
        self.sentiment_labels = ['NEGATIVE', 'NEUTRAL', 'POSITIVE']

    def _predict_indices(self, texts):
        """
        Tokenize *texts*, run the model without gradients and return the
        index of the most probable label for each row of the logits.

        Args:
            texts: A single string or a list of strings, passed unchanged to
                the tokenizer.
        Returns:
            A tensor holding one predicted label index per input row.
        """
        inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length)
        # Inputs must live on the same device as the model.
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=1)
        return torch.argmax(probabilities, dim=1)

    def predict_sentiment(self, text: str) -> str:
        """
        Predict sentiment for a given text.

        Args:
            text: Input text for sentiment analysis
        Returns:
            Sentiment label ('NEGATIVE', 'NEUTRAL', or 'POSITIVE'). On any
            error the exception is logged (with traceback) and 'NEUTRAL' is
            returned instead of raising.
        """
        try:
            sentiment_idx = self._predict_indices(text).item()
            return self.sentiment_labels[sentiment_idx]
        except Exception as e:
            # Best-effort API: never raise to the caller, but keep the traceback.
            self.logger.exception("Error during sentiment prediction: %s", e)
            return "NEUTRAL"  # Return neutral as default on error

    def batch_predict_sentiment(self, texts: List[str]) -> List[str]:
        """
        Predict sentiment for a batch of texts.

        Args:
            texts: List of input texts for sentiment analysis
        Returns:
            List of sentiment labels, one per input text. An empty (or None)
            input yields an empty list without invoking the model. On any
            error the exception is logged (with traceback) and 'NEUTRAL' is
            returned for every text instead of raising.
        """
        if not texts:
            # Nothing to classify; avoid sending an empty batch to the tokenizer.
            return []
        try:
            sentiment_indices = self._predict_indices(texts).tolist()
            return [self.sentiment_labels[idx] for idx in sentiment_indices]
        except Exception as e:
            # Best-effort API: never raise to the caller, but keep the traceback.
            self.logger.exception("Error during batch sentiment prediction: %s", e)
            return ["NEUTRAL"] * len(texts)  # Return neutral for all on error