|
# HP BERT Intent Classification Model |
|
|
|
This model is a fine-tuned BERT model for classifying different types of queries in the HP documentation context. |
|
|
|
## Model Details |
|
- Base model: bert-base-uncased |
|
- Task: 3-class classification |
|
- Classes: |
|
- 0: Queries requiring PDF context |
|
- 1: Summary-related queries |
|
- 2: Metadata-related queries |
|
|
|
## Usage |
|
```python |
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
import torch |
|
|
|
class BertInference:
    """Intent classifier built on a fine-tuned BERT sequence-classification model.

    The model weights are loaded from ``model_path`` and the tokenizer from
    ``tokenizer_path``. :meth:`predict` maps one question to one of the three
    intent labels listed in ``label_map``.
    """

    def __init__(self, model_path, tokenizer_path="google-bert/bert-base-uncased"):
        """Load the model and tokenizer and pick the inference device.

        Args:
            model_path: Path to the saved fine-tuned model, passed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            tokenizer_path: Name or path passed to
                ``AutoTokenizer.from_pretrained``. Defaults to the base
                BERT tokenizer.
        """
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # Every question is wrapped in this prompt before tokenization.
        self.template = "Question: {} Response: "
        # Class index produced by the model -> human-readable intent name.
        self.label_map = {
            0: "query_with_pdf",
            1: "summarize_pdf",
            2: "query_metadata",
        }

    def predict(self, text):
        """Classify a single question.

        Args:
            text: The question to classify; it is inserted into
                ``self.template`` before tokenization.

        Returns:
            A dict with three keys: ``predicted_class`` (the intent name of
            the highest-probability class), ``confidence`` (the softmax
            probability of that class) and ``all_probabilities`` (intent
            name -> softmax probability for every class).
        """
        prompt = self.template.format(text)

        # Only one sequence is encoded per call, so there is nothing to pad
        # it against. Padding to 512 tokens would just make the forward pass
        # process padding. Over-long inputs are still truncated to 512 tokens.
        encoded = self.tokenizer(
            prompt,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(self.device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = self.model(**encoded).logits
            # Batch size is 1, so keep the single row of class probabilities.
            probs = torch.softmax(logits, dim=1)[0]
            best = int(torch.argmax(probs).item())

        return {
            "predicted_class": self.label_map[best],
            "confidence": probs[best].item(),
            "all_probabilities": {
                self.label_map[i]: p.item() for i, p in enumerate(probs)
            },
        }
|
|
|
def main():
    """Classify a few sample questions and print each prediction."""
    # Path to your saved model
    inferencer = BertInference("output_dir_decision")

    # Sample questions to run through the classifier.
    test_questions = (
        "What are the new features in corolla cross?",
        "What is the summary of the provided pdf?",
        "The filesize of the pdf is?",
    )

    for question in test_questions:
        result = inferencer.predict(question)
        # Build the whole report for this question, then emit it in one go.
        report = [
            f"\nQuestion: {question}",
            f"Predicted Class: {result['predicted_class']}",
            f"Confidence: {result['confidence']:.4f}",
            "All Probabilities:",
        ]
        report.extend(
            f" {class_name}: {prob:.4f}"
            for class_name, prob in result['all_probabilities'].items()
        )
        print("\n".join(report))


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
``` |