# First try to import with fallbacks
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json
except ImportError as e:
    print(f"Import error: {e}")
    # Try to install missing packages (this might not work in Spaces)
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "transformers", "joblib", "huggingface-hub"])
    import torch
    from transformers import AutoTokenizer, AutoModel
    import joblib
    from huggingface_hub import hf_hub_download
    import json


class DrugInteractionClassifier(torch.nn.Module):
    """Bio_ClinicalBERT encoder with a small feed-forward classification head."""

    def __init__(self, n_classes, bert_model_name="emilyalsentzer/Bio_ClinicalBERT"):
        super(DrugInteractionClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.bert.config.hidden_size, 256),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, n_classes)
        )

    def forward(self, input_ids, attention_mask):
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token embedding from the last hidden state as the pooled representation
        pooled_output = bert_output[0][:, 0, :]
        return self.classifier(pooled_output)


class DDIPredictor:
    def __init__(self, repo_id="Fredaaaaaa/drug_interaction_severity"):
        self.repo_id = repo_id
        print(f"Loading model from: {repo_id}")

        try:
            # Download model files from Hugging Face
            print("Downloading config.json...")
            self.config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
            print("Downloading pytorch_model.bin...")
            self.model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
            print("Downloading label_encoder.joblib...")
            self.label_encoder_path = hf_hub_download(repo_id=repo_id, filename="label_encoder.joblib")

            # Load config
            with open(self.config_path, "r") as f:
                self.config = json.load(f)

            # Load tokenizer from repo
            print("Loading tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(repo_id)

            # Load label encoder
            print("Loading label encoder...")
            self.label_encoder = joblib.load(self.label_encoder_path)

            # Initialize model
            print("Initializing model...")
            self.model = DrugInteractionClassifier(
                n_classes=self.config["num_labels"],
                bert_model_name=self.config["bert_model_name"]
            )

            # Load weights
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Loading weights on {device}...")
            self.model.load_state_dict(
                torch.load(self.model_path, map_location=device)
            )
            self.model.to(device)
            self.model.eval()
            self.device = device
            print(f"Model loaded successfully from {repo_id} on {device}")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def predict(self, text, confidence_threshold=0.0):
        """Predict drug interaction severity"""
        if not text or not text.strip():
            return {
                "prediction": "Invalid Input",
                "confidence": 0.0,
                "probabilities": {label: 0.0 for label in self.label_encoder.classes_}
            }

        try:
            # Tokenize
            inputs = self.tokenizer(
                text,
                max_length=self.config.get("max_length", 128),
                padding=True,
                truncation=True,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Predict
            with torch.no_grad():
                outputs = self.model(inputs["input_ids"], inputs["attention_mask"])
                probabilities = torch.softmax(outputs, dim=1)
                confidence, predicted_idx = torch.max(probabilities, dim=1)

            predicted_label = self.label_encoder.inverse_transform([predicted_idx.item()])[0]

            # Get all probabilities
            all_probs = {
                self.label_encoder.inverse_transform([i])[0]: prob.item()
                for i, prob in enumerate(probabilities[0])
            }

            return {
                "prediction": predicted_label,
                "confidence": confidence.item(),
                "probabilities": all_probs
            }

        except Exception as e:
            return {
                "prediction": f"Error: {str(e)}",
                "confidence": 0.0,
                "probabilities": {label: 0.0 for label in self.label_encoder.classes_}
            }


# Global predictor instance
try:
    predictor = DDIPredictor("Fredaaaaaa/drug_interaction_severity")
    MODEL_LOADED = True
except Exception as e:
    print(f"Failed to load model: {e}")
    predictor = None
    MODEL_LOADED = False
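

# A minimal smoke test, not part of the original Space code; it only exercises the
# DDIPredictor.predict API defined above. The sample sentence is a hypothetical input.
if __name__ == "__main__":
    if MODEL_LOADED:
        sample_text = "Aspirin may increase the anticoagulant effect of warfarin."  # hypothetical example
        result = predictor.predict(sample_text)
        print(f"Prediction: {result['prediction']} (confidence: {result['confidence']:.3f})")
        # Print the full probability distribution, highest first
        for label, prob in sorted(result["probabilities"].items(), key=lambda kv: kv[1], reverse=True):
            print(f"  {label}: {prob:.3f}")
    else:
        print("Model failed to load; skipping smoke test.")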