---
license: apache-2.0
datasets:
  - cnmoro/QuestionClassification
tags:
  - classification
  - questioning
  - directed
  - generic
language:
  - en
  - pt
library_name: transformers
pipeline_tag: text-classification
widget:
  - text: What is the summary of the text?
---

A fine-tuned version of `prajjwal1/bert-tiny`.

The goal is to classify questions into "Directed" or "Generic".

If a question is not directed (i.e. it is generic), the actions performed in a RAG pipeline should change: for a generic question such as a request for a summary, semantic search over the question is not directly useful, because no single passage matches it.
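The routing idea above can be sketched as follows. This is a minimal illustration, not part of the model: `build_context`, `is_generic` and `retrieve` are hypothetical names, with `is_generic` standing in for a classifier such as `is_question_generic` and `retrieve` standing in for whatever semantic search the pipeline already uses.

```python
from typing import Callable, List


def build_context(
    question: str,
    chunks: List[str],
    is_generic: Callable[[str], bool],
    retrieve: Callable[[str, List[str]], List[str]],
) -> List[str]:
    """Choose which document chunks to pass to the LLM for `question`.

    Hypothetical helper: `is_generic` and `retrieve` are supplied by the caller.
    """
    if is_generic(question):
        # Generic question (class 0), e.g. "What is the summary of the text?":
        # similarity search would not find one relevant passage, so keep
        # every chunk (or hand them to a summarization step).
        return chunks
    # Directed question (class 1): let semantic search narrow the context.
    return retrieve(question, chunks)
```

Passing the whole document is only one option for generic questions; a map-reduce summary over the chunks is another, depending on the context window available.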

(Class 0 is Generic; Class 1 is Directed)
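To turn the model's two raw logits into a label plus a confidence score using this mapping, a pure-Python sketch could look like this. `ID2LABEL` and `label_from_logits` are hypothetical names defined here, not read from the model's config; the input would be one row of logits, e.g. `outputs.logits[0].tolist()` from the usage code.

```python
import math
from typing import Dict, List, Tuple

# Mapping as stated in this card: class 0 = Generic, class 1 = Directed.
ID2LABEL: Dict[int, str] = {0: "Generic", 1: "Directed"}


def label_from_logits(logits: List[float]) -> Tuple[str, float]:
    """Return (label, softmax probability) for one row of 2-class logits."""
    # Numerically stable softmax: subtract the max logit before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the most probable class.
    best = max(range(len(probs)), key=probs.__getitem__)
    return ID2LABEL[best], probs[best]
```

A probability close to 0.5 indicates the model is unsure; a pipeline could fall back to its default behaviour in that case.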

Accuracy on the training dataset is around 87.5%.
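Since this figure comes from the training data, it may be worth measuring accuracy on your own held-out questions. A minimal sketch, assuming a hypothetical `predict` callable that returns 0 or 1, for example `lambda q: 0 if is_question_generic(q) else 1`:

```python
from typing import Callable, Iterable, Tuple


def accuracy(
    examples: Iterable[Tuple[str, int]],
    predict: Callable[[str], int],
) -> float:
    """Fraction of (question, label) pairs where predict(question) == label.

    Labels follow this card's convention: 0 = Generic, 1 = Directed.
    """
    total = correct = 0
    for question, label in examples:
        total += 1
        correct += int(predict(question) == label)
    if total == 0:
        raise ValueError("no examples to evaluate")
    return correct / total
```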

```python
from transformers import BertForSequenceClassification, BertTokenizerFast
import torch

# Load the model and tokenizer
model = BertForSequenceClassification.from_pretrained("cnmoro/bert-tiny-question-classifier")
tokenizer = BertTokenizerFast.from_pretrained("cnmoro/bert-tiny-question-classifier")
model.eval()  # inference mode (disables dropout)

def is_question_generic(question: str) -> bool:
    """Return True if the model classifies `question` as Generic (class 0)."""
    # Tokenize the lowercased question and convert to PyTorch tensors
    inputs = tokenizer(
        question.lower(),
        truncation=True,
        padding=True,
        return_tensors="pt",
        max_length=512,
    )

    # Get the model's predictions without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)

    # logits has shape (1, 2); take the highest-scoring class of the single input
    predicted_class = torch.argmax(outputs.logits, dim=-1).item()

    # Class 0 is Generic, class 1 is Directed
    return predicted_class == 0
```