severity / inference.py
Fredaaaaaa's picture
Update inference.py
338a546 verified
raw
history blame
5.45 kB
# First try to import with fallbacks
try:
import torch
from transformers import AutoTokenizer, AutoModel
import joblib
from huggingface_hub import hf_hub_download
import json
except ImportError as e:
print(f"Import error: {e}")
# Try to install missing packages (this might not work in Spaces)
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers", "joblib", "huggingface-hub"])
import torch
from transformers import AutoTokenizer, AutoModel
import joblib
from huggingface_hub import hf_hub_download
import json
class DrugInteractionClassifier(torch.nn.Module):
    """Pretrained transformer encoder topped with a small feed-forward head.

    The head is ``Linear(hidden_size, 256) -> ReLU -> Dropout(0.3) ->
    Linear(256, n_classes)`` and is applied to the encoder's hidden state at
    sequence position 0.
    """

    def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT"):
        """Build the encoder and the classification head.

        Args:
            n_classes: Number of output logits produced by the head.
            bert_model_name: Model identifier passed to
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)

        # The head's input width follows whatever encoder was loaded.
        encoder_width = self.bert.config.hidden_size
        head_layers = [
            torch.nn.Linear(encoder_width, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, n_classes),
        ]
        self.classifier = torch.nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return the head's logits for a tokenized batch.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Attention mask tensor forwarded to the encoder.

        Returns:
            The output of ``self.classifier`` applied to the position-0
            hidden state of every sequence in the batch.
        """
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Element 0 of the encoder output holds the per-token hidden states;
        # only the first token of each sequence feeds the head.
        token_states = encoder_out[0]
        first_token_state = token_states[:, 0, :]
        return self.classifier(first_token_state)
class DDIPredictor:
    """Load the severity classifier from the Hugging Face Hub and run it.

    Construction downloads ``config.json``, ``pytorch_model.bin`` and
    ``label_encoder.joblib`` from ``repo_id``, loads the tokenizer from the
    same repo, builds a ``DrugInteractionClassifier`` from the config and
    places it on CUDA when available, otherwise on CPU.
    """

    def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
        """Download all artifacts and prepare the model for inference.

        Args:
            repo_id: Hub repository that holds the config, weights, label
                encoder and tokenizer files.

        Raises:
            Exception: Any error raised while downloading or loading is
                printed and then re-raised unchanged.
        """
        self.repo_id = repo_id
        print(f"🚀 Loading model from: {repo_id}")
        try:
            # Download model files from Hugging Face
            print("📥 Downloading config.json...")
            self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            print("📥 Downloading pytorch_model.bin...")
            self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
            print("📥 Downloading label_encoder.joblib...")
            self.label_encoder_path = hf_hub_download(
                repo_id=repo_id, filename="label_encoder.joblib"
            )

            # Load config (explicit UTF-8 so the platform default encoding
            # cannot affect how the JSON file is read).
            with open(self.config_path, "r", encoding="utf-8") as f:
                self.config = json.load(f)

            # Load tokenizer from repo
            print("🔤 Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

            # Load label encoder
            # NOTE(security): joblib.load unpickles the file; only point
            # repo_id at a repository you trust.
            print("🏷️ Loading label encoder...")
            self.label_encoder = joblib.load(self.label_encoder_path)

            # Initialize model
            print("🧠 Initializing model...")
            self.model = DrugInteractionClassifier(
                n_classes=self.config["num_labels"],
                bert_model_name=self.config["bert_model_name"],
            )

            # Load weights
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"⚙️ Loading weights on {device}...")
            # NOTE(security): torch.load unpickles the checkpoint; consider
            # weights_only=True if the installed torch version supports it.
            self.model.load_state_dict(
                torch.load(self.model_path, map_location=device)
            )
            self.model.to(device)
            self.model.eval()
            self.device = device
            print(f"✅ Model loaded successfully from {repo_id} on {device}")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise

    def _zero_result(self, prediction):
        """Return a result dict with zero confidence and all-zero probabilities."""
        return {
            "prediction": prediction,
            "confidence": 0.0,
            "probabilities": {label: 0.0 for label in self.label_encoder.classes_},
        }

    def predict(self, text, confidence_threshold=0.0):
        """Predict drug interaction severity for a single text.

        Args:
            text: Input string to classify.
            confidence_threshold: Accepted for interface compatibility; it is
                not currently applied to the result.

        Returns:
            A dict with keys ``"prediction"`` (label), ``"confidence"``
            (probability of the predicted label) and ``"probabilities"``
            (label -> probability). Non-string, empty or whitespace-only
            input yields ``"Invalid Input"``; any exception raised during
            inference yields ``"Error: <message>"``. Both of those cases
            carry zero confidence and zero probabilities.
        """
        # Non-string input is rejected here as well, so that .strip() cannot
        # raise outside the try block below.
        if not isinstance(text, str) or not text.strip():
            return self._zero_result("Invalid Input")

        try:
            # Tokenize
            inputs = self.tokenizer(
                text,
                max_length=self.config.get("max_length", 128),
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Predict
            with torch.no_grad():
                outputs = self.model(inputs["input_ids"], inputs["attention_mask"])
                probabilities = torch.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, dim=1)

            # Decode every class index in one call, not one call per class.
            class_probs = probabilities[0].tolist()
            labels = self.label_encoder.inverse_transform(
                list(range(len(class_probs)))
            )
            return {
                "prediction": labels[predicted_idx.item()],
                "confidence": confidence.item(),
                "probabilities": dict(zip(labels, class_probs)),
            }
        except Exception as e:
            return self._zero_result(f"Error: {e}")
# Module-level predictor shared by importers. Loading is attempted once at
# import time; on failure the error is printed, `predictor` stays None and
# MODEL_LOADED stays False.
predictor = None
MODEL_LOADED = False
try:
    predictor = DDIPredictor("Fredaaaaaa/drug_interaction_severity")
except Exception as e:
    print(f"Failed to load model: {e}")
else:
    MODEL_LOADED = True