HP BERT Intent Classification Model
This model is a fine-tuned BERT model for classifying different types of queries in the HP documentation context.
Model Details
- Base model: bert-base-uncased
- Task: 3-class classification
- Classes:
- 0: Queries requiring PDF context
- 1: Summary-related queries
- 2: Metadata-related queries
Usage
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
class BertInference:
    """Intent classifier built on a fine-tuned sequence-classification model.

    The model and tokenizer are loaded once in the constructor; ``predict``
    then maps a query to one of the three labels in ``label_map``.
    """

    def __init__(self, model_path, tokenizer_path="FacebookAI/xlm-roberta-base"):
        """Load the classification model and its tokenizer.

        Args:
            model_path: Checkpoint identifier or directory handed to
                ``AutoModelForSequenceClassification.from_pretrained``.
            tokenizer_path: Checkpoint identifier or directory handed to
                ``AutoTokenizer.from_pretrained``. The default is the
                checkpoint this class previously hard-coded, so callers that
                pass only ``model_path`` behave exactly as before.
        """
        # Prefer the GPU when one is visible; inputs are moved to the same
        # device inside predict().
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # Class index produced by the model -> human-readable intent label.
        self.label_map = {
            0: "query_with_pdf",
            1: "summarize_pdf",
            2: "query_metadata",
        }

    def predict(self, text):
        """Classify ``text`` and report the winning label with probabilities.

        Args:
            text: The query to classify (a single input; the result is read
                from row 0 of the model output).

        Returns:
            A dict with three keys:
            ``"predicted_class"`` (label string from ``label_map``),
            ``"confidence"`` (softmax probability of that label), and
            ``"all_probabilities"`` (label string -> softmax probability).
        """
        # Tokenize: truncate to 512 tokens and pad up to that same length.
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=512,
            padding='max_length',
            return_tensors="pt",
        ).to(self.device)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Normalise the logits across the class dimension.
        predictions = torch.softmax(outputs.logits, dim=1)
        predicted_class = torch.argmax(predictions, dim=1).item()
        confidence = predictions[0][predicted_class].item()

        return {
            "predicted_class": self.label_map[predicted_class],
            "confidence": confidence,
            "all_probabilities": {
                self.label_map[i]: prob.item()
                for i, prob in enumerate(predictions[0])
            },
        }
def main():
    """Classify a handful of sample questions and print each result."""
    # Path to your saved model
    classifier = BertInference("nexaai2b/Octopus-xlm-roberta-BERT-intent-classification")

    # One sample query per intent class.
    sample_questions = (
        "What are the new features in corolla cross?",
        "What is the summary of the provided pdf?",
        "The filesize of the pdf is?",
    )

    for question in sample_questions:
        outcome = classifier.predict(question)
        label = outcome['predicted_class']
        score = outcome['confidence']
        print(f"\nQuestion: {question}")
        print(f"Predicted Class: {label}")
        print(f"Confidence: {score:.4f}")
        print("All Probabilities:")
        for class_name, prob in outcome['all_probabilities'].items():
            print(f" {class_name}: {prob:.4f}")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()