PerryCheng614 committed
Commit 3ad5c44
1 Parent(s): b5815d0

Upload inference script

Files changed (1)
  1. bert_inference.py +69 -0
bert_inference.py ADDED
@@ -0,0 +1,69 @@
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ import torch
+
+ class BertInference:
+     def __init__(self, model_path):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+         # self.tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
+         self.label_map = {
+             0: "query_with_pdf",
+             1: "summarize_pdf",
+             2: "query_metadata"
+         }
+
+     def predict(self, text):
+         # Tokenize
+         inputs = self.tokenizer(
+             text,
+             return_tensors="pt"
+         ).to(self.device)
+
+         # Get prediction
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+             predictions = torch.softmax(outputs.logits, dim=1)
+             predicted_class = torch.argmax(predictions, dim=1).item()
+             confidence = predictions[0][predicted_class].item()
+
+         return {
+             "predicted_class": self.label_map[predicted_class],
+             "confidence": confidence,
+             "all_probabilities": {
+                 self.label_map[i]: prob.item()
+                 for i, prob in enumerate(predictions[0])
+             }
+         }
+
+ def main():
+     # Initialize the model
+     model_path = "output_dir_decision"  # Path to your saved model
+     # model_path = "output_xlm_roberta_bert"
+     inferencer = BertInference(model_path)
+
+     # Example usage
+     test_questions = [
+         "Tell me about the new features of chrome 120",
+         "What is the battery life",
+         "What is the file name?",
+         "What is the file size?",
+         "What is the upload time?",
+         "What is the last modified time?",
+         "What is the pdf about?",
+         "Could you give me a sketch?",
+         "How old is the monkey?",
+         "What is the game performance of the new GPU?"
+     ]
+
+     for question in test_questions:
+         result = inferencer.predict(question)
+         print(f"\nQuestion: {question}")
+         print(f"Predicted Class: {result['predicted_class']}")
+         print(f"Confidence: {result['confidence']:.4f}")
+         print("All Probabilities:")
+         for class_name, prob in result['all_probabilities'].items():
+             print(f" {class_name}: {prob:.4f}")
+
+ if __name__ == "__main__":
+     main()
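
Usage note: besides running the script directly (which executes main() over the built-in test questions), the classifier can be imported from other code. The following is a minimal sketch, assuming bert_inference.py is on the import path and a fine-tuned checkpoint exists at the output_dir_decision path used in main(); the class, method, and result keys come from the uploaded script.

    from bert_inference import BertInference

    # Load the fine-tuned intent classifier (example checkpoint path from main()).
    classifier = BertInference("output_dir_decision")

    result = classifier.predict("Can you summarize this PDF for me?")
    print(result["predicted_class"])       # one of: query_with_pdf, summarize_pdf, query_metadata
    print(f"{result['confidence']:.4f}")   # softmax probability of the predicted label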