# Source: Hugging Face Hub file view (author: joko333).
# Commit ca5c473: "Implement sentence analysis functionality in Analysis page;
# add BiLSTM model and prediction utilities"
# File size: 4.27 kB
import torch
from transformers import AutoTokenizer
from sklearn.preprocessing import LabelEncoder
from BiLSTM import BiLSTMAttentionBERT
import numpy as np
def load_model_for_prediction(model_repo="joko333/BiLSTM_v01",
                              tokenizer_repo='dmis-lab/biobert-base-cased-v1.2'):
    """Load the BiLSTM classifier, its label encoder and its tokenizer on CPU.

    Args:
        model_repo: Hugging Face Hub repo id passed to
            ``BiLSTMAttentionBERT.from_pretrained``.
        tokenizer_repo: Repo id passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``(model, label_encoder, tokenizer)`` on success, with the model moved
        to CPU and switched to eval mode. If any exception occurs while loading,
        the error is printed and ``(None, None, None)`` is returned, so callers
        must check for ``None``.
    """
    # Force CPU
    device = torch.device('cpu')
    # NOTE(review): `enabled` does not appear in the documented
    # torch.backends.mps API, so this assignment probably has no effect.
    # CPU placement is actually enforced by the `.to(device)` call below.
    torch.backends.mps.enabled = False

    # Index i of this list names output unit i of the model. The list is
    # alphabetically sorted, which is the order LabelEncoder.fit produces;
    # presumably it mirrors the training-time encoder -- confirm against training.
    class_names = [
        'Addition', 'Causal', 'Cause and Effect',
        'Clarification', 'Comparison', 'Concession',
        'Conditional', 'Contrast', 'Contrastive Emphasis',
        'Definition', 'Elaboration', 'Emphasis',
        'Enumeration', 'Explanation', 'Generalization',
        'Illustration', 'Inference', 'Problem Solution',
        'Purpose', 'Sequential', 'Summary',
        'Temporal Sequence',
    ]

    try:
        # Load model from Hugging Face Hub
        model = BiLSTMAttentionBERT.from_pretrained(
            model_repo,
            hidden_dim=128,
            # Derived from the label list so head width and label set cannot drift apart
            num_classes=len(class_names),
            num_layers=2,
            dropout=0.5,
        ).to(device)
        model.eval()

        # Initialize label encoder with predefined classes (no fit needed)
        label_encoder = LabelEncoder()
        label_encoder.classes_ = np.array(class_names)

        # Initialize tokenizer
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo)

        return model, label_encoder, tokenizer
    except Exception as e:
        print(f"Error loading model components: {str(e)}")
        return None, None, None
def predict_sentence(model, sentence, tokenizer, label_encoder, device=None):
    """
    Make prediction for a single sentence with label validation.

    Args:
        model: torch module called as ``model(input_ids, attention_mask)``.
            Its output is softmaxed over dim 1, so presumably it returns
            ``(batch, num_classes)`` scores -- confirm against the model.
        sentence: Text passed to ``tokenizer``.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) whose
            output supports ``.to(device)`` and has ``input_ids`` and
            ``attention_mask`` entries.
        label_encoder: Object whose ``classes_`` maps index -> label name.
        device: Device to run on. ``None`` (the default) means CPU, as before.
            Any value accepted by ``torch.device`` is honoured.

    Returns:
        ``(label, probability)`` for the highest-scoring class.
        ``("Unknown", 0.0)`` if the predicted index has no label.
        ``("Error", 0.0)`` if inference raises.
        Tokenizer errors are not caught and propagate to the caller.
    """
    # Previously the `device` argument was silently overwritten; only default to CPU.
    device = torch.device('cpu') if device is None else torch.device(device)
    model = model.to(device)
    model.eval()

    # Tokenize (padded to a fixed 512 tokens, matching the original inference setup)
    encoding = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).to(device)

    try:
        with torch.no_grad():
            outputs = model(encoding['input_ids'], encoding['attention_mask'])
            probabilities = torch.softmax(outputs, dim=1)
            prob, pred_idx = torch.max(probabilities, dim=1)

        idx = pred_idx.item()
        # Guard against a model head wider than the label set
        if not 0 <= idx < len(label_encoder.classes_):
            print(f"Warning: Model predicted invalid label index {idx}")
            return "Unknown", 0.0
        return label_encoder.classes_[idx], prob.item()
    except Exception as e:
        print(f"Prediction error: {str(e)}")
        return "Error", 0.0
def print_labels(label_encoder, show_counts=False):
    """Print a numbered listing of every label held by *label_encoder*.

    The listing is framed by 40-dash rules and ends with a total-count line.

    Args:
        label_encoder: Object exposing a ``classes_`` sequence of label names.
        show_counts: Accepted for API compatibility; currently unused.
    """
    rule = "-" * 40
    print("\nAvailable labels:")
    print(rule)
    # Read classes_ only after the header, so output order on failure is unchanged
    names = label_encoder.classes_
    for position, name in enumerate(names):
        print(f"Index {position}: {name}")
    print(rule)
    print(f"Total number of classes: {len(names)}\n")
def predict_sentence2(sentence, model, tokenizer, label_encoder):
    """Classify *sentence* and return only its decoded label.

    Unlike ``predict_sentence``, the tokenizer output is passed to the model as
    keyword arguments, no probability is returned, and exceptions propagate.

    Args:
        sentence: Text to classify.
        model: torch module called as ``model(**tokenizer_output)``. It may
            return either an object with a ``logits`` attribute (e.g. a Hugging
            Face ``SequenceClassifierOutput``) or the score tensor itself.
            NOTE(review): if the model's forward only accepts
            ``(input_ids, attention_mask)``, extra tokenizer keys such as
            ``token_type_ids`` will be rejected -- confirm for the BiLSTM model.
        tokenizer: Callable returning a mapping of tensors.
        label_encoder: Object with ``inverse_transform`` (e.g. a fitted
            ``sklearn.preprocessing.LabelEncoder``).

    Returns:
        The decoded label for the first item of the batch.
    """
    # Tokenize the input
    inputs = tokenizer(sentence,
                       padding=True,
                       truncation=True,
                       return_tensors='pt',
                       max_length=512)

    # Move inputs to the model's device; parameter-less modules fall back to CPU
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Disable dropout etc. for inference, matching predict_sentence
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        # HF heads wrap scores in `.logits`; plain modules return the tensor directly
        logits = getattr(outputs, 'logits', outputs)
        predictions = torch.argmax(logits, dim=1)

    # Convert prediction to label
    predicted_label = label_encoder.inverse_transform(predictions.cpu().numpy())[0]
    return predicted_label