# EMOTIA/models/text.py
# Uploaded to the Hugging Face Hub via huggingface_hub (commit 25d0747).
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import re
class TextIntentModel(nn.Module):
    """
    Transformer-based model for text intent and sentiment analysis.

    Wraps a frozen BERT encoder ('bert-base-uncased') with three task heads:
      * intent classification  -> ``num_intents`` logits
      * sentiment/emotion      -> 7 logits (7 emotion classes)
      * confidence/hesitation  -> scalar per sample in [0, 1]
    """

    def __init__(self, num_intents=5, pretrained=True):
        """
        Args:
            num_intents: number of intent classes for the classifier head.
            pretrained: if True, load pre-trained 'bert-base-uncased' weights;
                otherwise build a randomly initialised BERT with the default
                config. The tokenizer is always the pre-trained one (it has
                no trainable weights).
        """
        super().__init__()
        self.num_intents = num_intents

        # Load the BERT encoder.
        if pretrained:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
        else:
            from transformers import BertConfig
            self.bert = BertModel(BertConfig())
        # Same tokenizer in both cases; hoisted out of the if/else.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        # Freeze the encoder: only the task heads are trainable.
        for param in self.bert.parameters():
            param.requires_grad = False

        hidden_size = self.bert.config.hidden_size

        # Intent classification head
        self.intent_classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_intents),
        )
        # Sentiment/emotion head
        self.sentiment_head = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Linear(128, 7),  # 7 emotions
        )
        # Confidence/hesitation detection head (Sigmoid -> [0, 1])
        self.confidence_head = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, input_ids, attention_mask):
        """
        Args:
            input_ids: tokenized text, shape (B, seq_len).
            attention_mask: attention mask, shape (B, seq_len).

        Returns:
            Tuple ``(intent_logits, sentiment_logits, confidence)`` with
            shapes (B, num_intents), (B, 7) and (B,) respectively.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # [CLS]-based pooled embedding

        intent_logits = self.intent_classifier(pooled_output)
        sentiment_logits = self.sentiment_head(pooled_output)
        confidence = self.confidence_head(pooled_output)

        # squeeze(-1), not squeeze(): a bare squeeze() would collapse a
        # batch of size 1 to a 0-D scalar instead of a (1,) tensor.
        return intent_logits, sentiment_logits, confidence.squeeze(-1)

    def preprocess_text(self, text):
        """
        Clean *text* and tokenize it to fixed-length (128) tensors.

        Returns:
            ``(input_ids, attention_mask)``, each a 1-D tensor of length 128
            (the batch dimension added by ``return_tensors='pt'`` is squeezed
            away; callers re-add it as needed).
        """
        text = self.clean_text(text)
        encoding = self.tokenizer(
            text,
            max_length=128,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return encoding['input_ids'].squeeze(), encoding['attention_mask'].squeeze()

    def clean_text(self, text):
        """
        Lower-case *text*, collapse runs of whitespace, and strip special
        characters while keeping basic punctuation (``. , ! ?``).
        """
        # Remove special characters but keep punctuation
        text = re.sub(r'[^\w\s.,!?]', '', text)
        # Normalize whitespace
        text = ' '.join(text.split())
        return text.lower()

    def detect_hesitation_phrases(self, text):
        """
        Score hesitation/confusion markers in *text*.

        Counts whole-word/phrase occurrences of hesitation keywords.
        Word boundaries are enforced so that e.g. 'um' does not falsely
        match inside 'number' or 'uh' inside 'huh'. The count is
        normalized to [0, 1]; five or more distinct keywords saturate at 1.0.
        """
        hesitation_keywords = [
            'um', 'uh', 'like', 'you know', 'sort of', 'kind of',
            'i think', 'maybe', 'perhaps', 'i\'m not sure'
        ]
        text_lower = text.lower()
        hesitation_score = sum(
            1 for keyword in hesitation_keywords
            if re.search(r'\b' + re.escape(keyword) + r'\b', text_lower)
        )
        return min(hesitation_score / 5.0, 1.0)  # Normalize to 0-1

    def extract_intent_features(self, text):
        """
        Run the full pipeline on a raw string and return all head outputs.

        Temporarily switches the module to eval mode so Dropout is disabled
        (``torch.no_grad`` alone does NOT disable dropout) and the extracted
        features are deterministic; the caller's train/eval mode is restored
        afterwards.

        Returns:
            Dict with 'intent_logits', 'sentiment_logits', 'confidence'
            (tensors for a batch of 1) and 'hesitation_score' (float).
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                input_ids, attention_mask = self.preprocess_text(text)
                if input_ids.dim() == 1:
                    # Single sample: add the batch dimension forward() expects.
                    input_ids = input_ids.unsqueeze(0)
                    attention_mask = attention_mask.unsqueeze(0)
                intent_logits, sentiment_logits, confidence = self(input_ids, attention_mask)
            return {
                'intent_logits': intent_logits,
                'sentiment_logits': sentiment_logits,
                'confidence': confidence,
                'hesitation_score': self.detect_hesitation_phrases(text),
            }
        finally:
            # Restore whatever mode the caller had the module in.
            self.train(was_training)