PerryCheng614's picture
Update model card
2cc9e9f verified

HP BERT Intent Classification Model

This model is a fine-tuned BERT-style model for classifying different types of user queries in the context of HP documentation.

Model Details

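  • Base model: bert-base-uncased (note: the usage example below loads the FacebookAI/xlm-roberta-base tokenizer, and the model ID refers to XLM-RoBERTa — confirm which base model was actually used)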
  • Task: 3-class classification
  • Classes:
    • 0: Queries requiring PDF context (query_with_pdf)
    • 1: Summary-related queries (summarize_pdf)
    • 2: Metadata-related queries (query_metadata)

Usage

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

class BertInference:
    """Intent classifier wrapping a fine-tuned sequence-classification model.

    The model is loaded from ``model_path`` and the tokenizer from
    ``tokenizer_path``. The model's three output indices are mapped to the
    intent labels in ``label_map``.
    """

    def __init__(self, model_path, tokenizer_path="FacebookAI/xlm-roberta-base"):
        """Load the model and tokenizer, using CUDA when it is available.

        Args:
            model_path: Hugging Face Hub ID or local directory of the
                fine-tuned sequence-classification model.
            tokenizer_path: Hub ID or local directory of the tokenizer. It must
                match the vocabulary the model was trained with. The default
                keeps the tokenizer this example always used; pass
                ``model_path`` when the model repository ships its own tokenizer.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        # Output index -> intent label.
        self.label_map = {
            0: "query_with_pdf",
            1: "summarize_pdf",
            2: "query_metadata"
        }

    def predict(self, text, max_length=512):
        """Classify one query string.

        Args:
            text: The query to classify.
            max_length: Token limit. Longer input is truncated, and shorter
                input is padded up to this length.

        Returns:
            A dict with:
              - ``"predicted_class"``: label of the highest-probability class,
              - ``"confidence"``: that class's softmax probability (float),
              - ``"all_probabilities"``: label -> probability for every class.
        """
        # NOTE(review): padding a single input to max_length is not required
        # and costs extra compute; it is kept so outputs match the original.
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the class dimension; row 0 is the single input's scores.
        predictions = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(predictions, dim=1).item()
        scores = predictions[0]

        return {
            "predicted_class": self.label_map[predicted_class],
            "confidence": scores[predicted_class].item(),
            "all_probabilities": {
                self.label_map[i]: prob.item()
                for i, prob in enumerate(scores)
            },
        }

def main():
    """Classify a few sample questions and print each prediction."""
    # Hugging Face Hub ID of the fine-tuned model; a local saved-model
    # directory can be used instead.
    classifier = BertInference(
        "nexaai2b/Octopus-xlm-roberta-BERT-intent-classification"
    )

    samples = (
        "What are the new features in corolla cross?",
        "What is the summary of the provided pdf?",
        "The filesize of the pdf is?",
    )

    for sample in samples:
        outcome = classifier.predict(sample)
        report = [
            f"\nQuestion: {sample}",
            f"Predicted Class: {outcome['predicted_class']}",
            f"Confidence: {outcome['confidence']:.4f}",
            "All Probabilities:",
        ]
        report.extend(
            f"  {label}: {probability:.4f}"
            for label, probability in outcome["all_probabilities"].items()
        )
        print("\n".join(report))


if __name__ == "__main__":
    main()