emotion / src /predict.py
robot4's picture
Upload folder using huggingface_hub
af9853e verified
import torch
import os
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from .config import Config
class SentimentPredictor:
def __init__(self, model_path=None):
# 1. 如果未指定路径,尝试自动寻找最新的模型
if model_path is None:
# 优先检查 Config.CHECKPOINT_DIR (如果训练完成,final_model 会在这里)
if os.path.exists(os.path.join(Config.CHECKPOINT_DIR, "config.json")):
model_path = Config.CHECKPOINT_DIR
else:
# 如果没有 final_model,尝试寻找最新的 checkpoint (在 results 目录)
import glob
ckpt_list = glob.glob(os.path.join(Config.RESULTS_DIR, "checkpoint-*"))
if ckpt_list:
# 按修改时间排序,取最新的
ckpt_list.sort(key=os.path.getmtime)
model_path = ckpt_list[-1]
print(f"Using latest checkpoint found: {model_path}")
else:
# 只有在真的找不到时才回退
model_path = Config.CHECKPOINT_DIR
print(f"Loading model from {model_path}...")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
except OSError:
print(f"Warning: Model not found at {model_path}. Loading base model for demo purpose.")
self.tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
self.model = AutoModelForSequenceClassification.from_pretrained(Config.BASE_MODEL, num_labels=Config.NUM_LABELS)
# Device selection
if torch.backends.mps.is_available():
self.device = torch.device("mps")
elif torch.cuda.is_available():
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
self.model.to(self.device)
self.model.eval()
def predict(self, text):
inputs = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=Config.MAX_LENGTH,
padding=True
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
prediction = torch.argmax(probabilities, dim=-1).item()
score = probabilities[0][prediction].item()
label = Config.ID2LABEL.get(prediction, "unknown")
return {
"text": text,
"sentiment": label,
"confidence": f"{score:.4f}"
}
if __name__ == "__main__":
# Demo
predictor = SentimentPredictor()
test_texts = [
"这家店的快递太慢了,而且东西味道很奇怪。",
"非常不错,包装很精美,下次还会来买。",
"感觉一般般吧,没有想象中那么好,但也还可以。"
]
print("\nPredicting...")
for text in test_texts:
result = predictor.predict(text)
print(f"Text: {result['text']}")
print(f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']})")
print("-" * 30)