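"""Inference script for a fine-tuned BERT question-routing classifier.

Loads a sequence-classification checkpoint and maps each question to one of
three intents: query_with_pdf, summarize_pdf, or query_metadata.
"""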
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch


class BertInference:
    def __init__(self, model_path):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load the fine-tuned classifier and move it to the available device
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
        self.model.eval()  # make sure dropout is disabled for inference
        self.tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        # self.tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
        # Class indices produced by the model mapped to intent names
        self.label_map = {
            0: "query_with_pdf",
            1: "summarize_pdf",
            2: "query_metadata"
        }

    def predict(self, text):
        # Tokenize the question (truncate anything longer than the model's max length)
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True
        ).to(self.device)

        # Get prediction without tracking gradients
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Convert logits to probabilities and pick the most likely intent
        predictions = torch.softmax(outputs.logits, dim=1)
        predicted_class = torch.argmax(predictions, dim=1).item()
        confidence = predictions[0][predicted_class].item()

        return {
            "predicted_class": self.label_map[predicted_class],
            "confidence": confidence,
            "all_probabilities": {
                self.label_map[i]: prob.item()
                for i, prob in enumerate(predictions[0])
            }
        }


def main():
    # Initialize the model
    model_path = "output_dir_decision"  # Path to your saved model
    # model_path = "output_xlm_roberta_bert"
    inferencer = BertInference(model_path)

    # Example usage
    test_questions = [
        "Tell me about the new features of chrome 120",
        "What is the battery life",
        "What is the file name?",
        "What is the file size?",
        "What is the upload time?",
        "What is the last modified time?",
        "What is the pdf about?",
        "Could you give me a sketch?",
        "How old is the monkey?",
        "What is the game performance of the new GPU?"
    ]

    for question in test_questions:
        result = inferencer.predict(question)
        print(f"\nQuestion: {question}")
        print(f"Predicted Class: {result['predicted_class']}")
        print(f"Confidence: {result['confidence']:.4f}")
        print("All Probabilities:")
        for class_name, prob in result['all_probabilities'].items():
            print(f"  {class_name}: {prob:.4f}")


if __name__ == "__main__":
    main()