BertForStorySkillClassification

Model Overview

BertForStorySkillClassification is a BERT-based text classification model designed to categorize story-related questions into one of the following 7 classes:

  1. Character
  2. Setting
  3. Feeling
  4. Action
  5. Causal Relationship
  6. Outcome Resolution
  7. Prediction

This model is suitable for applications in education, literary analysis, and story comprehension.


Model Architecture

  • Base Model: bert-base-uncased
  • Classification Layer: A fully connected layer on top of BERT for 7-class classification.
  • Input: a question (e.g., "Who is the main character in the story?"), a question-answer pair separated by <SEP> (e.g., "why could n't alice get a doll as a child ? <SEP> because her family was very poor "), or a QA pair followed by story context after a <context> marker (e.g., "why could n't alice get a doll as a child ? <SEP> because her family was very poor <context> alice is ... ").
  • Output: The predicted label, plus a confidence score when return_probabilities=True is passed to predict.
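The three input formats differ only in whether an answer and a story context are appended to the question. The sketch below assembles them. `build_input` is a hypothetical helper, not part of the model's API. It copies the "<SEP>" and "<context>" markers and their spacing from the examples and strips the trailing space those examples contain, on the assumption that surrounding whitespace does not change the tokenization.

```python
# Hypothetical helper that assembles the three input formats listed above.
# The "<SEP>" and "<context>" markers and their spacing are copied from the
# examples; how the tokenizer handles these markers is not documented here.
from typing import Optional


def build_input(question: str, answer: Optional[str] = None,
                context: Optional[str] = None) -> str:
    """Return a question, a QA pair, or a QA pair plus story context."""
    if context is not None and answer is None:
        # The documented context format always follows a QA pair.
        raise ValueError("context requires an answer")
    text = question.strip()
    if answer is not None:
        text += " <SEP> " + answer.strip()
    if context is not None:
        text += " <context> " + context.strip()
    return text
```

For instance, `build_input("why could n't alice get a doll as a child ?", "because her family was very poor")` yields the QA example above without its trailing space.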

Quick Start

Install Dependencies

Ensure you have the transformers library and PyTorch installed:

pip install transformers torch

Load Model and Tokenizer

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")

Use the predict Method for Inference

# Single text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=["Why is the character sad?", "How does the story end?", "why could n't alice get a doll as a child ? <SEP> because her family was very poor "],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda"  # use "cpu" if no GPU is available
)
print(results)
"""
output:
[{'text': 'Why is the character sad?', 'label': 'causal relationship'},
 {'text': 'How does the story end?', 'label': 'action'},
 {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
  'label': 'causal relationship'}]
"""

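`predict` is not a method of the standard transformers sequence-classification classes (such as BertForSequenceClassification), so it presumably comes from custom code shipped with this checkpoint. If it is unavailable in your setup, the sketch below produces a similar output using only standard PyTorch and transformers calls. The function name `classify` is hypothetical. Labels are read from `model.config.id2label`, which shows generic names such as LABEL_0 if the checkpoint does not store class names, and the score is assumed (not confirmed) to be the softmax probability of the predicted class.

```python
# Sketch of batched inference with standard transformers/PyTorch calls only,
# for environments where the custom `predict` method is not available.
from typing import Dict, List, Union

import torch


def classify(texts: Union[str, List[str]], model, tokenizer,
             batch_size: int = 16, max_length: int = 512,
             device: str = "cpu") -> List[Dict[str, object]]:
    """Return one {'text', 'label', 'score'} dict per input text."""
    if isinstance(texts, str):
        texts = [texts]
    model.to(device)
    model.eval()
    results = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # Pad to the longest text in the batch; truncate beyond max_length tokens.
        enc = tokenizer(batch, padding=True, truncation=True,
                        max_length=max_length, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        with torch.no_grad():
            logits = model(**enc).logits
        # Softmax over the 7 classes; keep the top class and its probability.
        probs = torch.softmax(logits, dim=-1)
        scores, ids = probs.max(dim=-1)
        for text, idx, score in zip(batch, ids.tolist(), scores.tolist()):
            results.append({"text": text,
                            "label": model.config.id2label[idx],
                            "score": round(score, 5)})
    return results
```

With the model and tokenizer loaded as shown earlier, `classify(["How does the story end?"], model, tokenizer, device="cuda")` returns one dictionary per input text.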
Training Details

Dataset

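Source: FairytaleQAData (the FairytaleQA dataset, whose questions are annotated with the 7 narrative-element categories listed above)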

Training Parameters

  • Learning Rate: 2e-5
  • Batch Size: 32
  • Epochs: 3
  • Optimizer: AdamW
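The card lists only these four hyperparameters. A minimal fine-tuning loop using them might look like the sketch below. The function name `fine_tune` and its input format (pre-tokenized batches of 32 examples holding `input_ids`, `attention_mask` and integer `labels`) are assumptions. The learning-rate schedule, warmup, weight decay and maximum sequence length used in the actual training are not documented, so the sketch keeps a constant learning rate and PyTorch's other AdamW defaults.

```python
# Minimal fine-tuning loop with the reported settings: AdamW, lr 2e-5, 3 epochs.
# Assumption: `batches` is a re-iterable collection (e.g. a DataLoader) of dicts
# holding tokenized tensors "input_ids", "attention_mask" and integer "labels",
# 32 examples per batch in the reported setup. No scheduler or warmup is used,
# because the card does not specify one.
from typing import Dict, Iterable, List

import torch


def fine_tune(model, batches: Iterable[Dict[str, torch.Tensor]],
              epochs: int = 3, lr: float = 2e-5,
              device: str = "cpu") -> List[float]:
    """Train `model` in place and return the mean training loss of each epoch."""
    model.to(device)
    model.train()
    # Other AdamW hyperparameters (betas, eps, weight decay) use PyTorch defaults.
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    epoch_losses = []
    for _ in range(epochs):
        total, steps = 0.0, 0
        for batch in batches:
            batch = {k: v.to(device) for k, v in batch.items()}
            # For BERT sequence classification with integer labels, passing
            # "labels" makes the model also return a cross-entropy loss.
            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total += loss.item()
            steps += 1
        epoch_losses.append(total / max(steps, 1))
    return epoch_losses
```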

Performance Metrics

Accuracy: 97.3%

Recall: 96.59%

F1 Score: 96.96%
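The card does not state how recall and F1 were averaged across the 7 classes. Micro- and weighted-averaged recall both equal accuracy in single-label classification, so the gap between the reported recall and accuracy suggests a macro average, although this is not confirmed. Assuming macro averaging, the metrics could be computed as in this pure-Python sketch. `classification_metrics` is a hypothetical name, and F1 is taken as the mean of per-class F1 scores, the same convention as scikit-learn's `average="macro"`.

```python
# Accuracy plus macro-averaged precision, recall and F1 (averaging is assumed).
from collections import Counter
from typing import Dict, Sequence


def classification_metrics(y_true: Sequence[str],
                           y_pred: Sequence[str]) -> Dict[str, float]:
    """Each class weighs equally; classes with no support/predictions score 0."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    labels = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_count, true_count = Counter(y_pred), Counter(y_true)
    precisions, recalls, f1s = [], [], []
    for label in labels:
        p = tp[label] / pred_count[label] if pred_count[label] else 0.0
        r = tp[label] / true_count[label] if true_count[label] else 0.0
        f1s.append(2 * p * r / (p + r) if p + r else 0.0)
        precisions.append(p)
        recalls.append(r)
    n = len(labels)
    return {
        "accuracy": sum(tp.values()) / len(y_true),
        "precision": sum(precisions) / n,
        "recall": sum(recalls) / n,
        "f1": sum(f1s) / n,
    }
```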

Notes

  1. Input Length: The model supports a maximum input length of 512 tokens. Longer texts will be truncated.
  2. Device Support: The model supports both CPU and GPU inference. GPU is recommended for faster performance.
  3. Tokenizer: Always load the tokenizer from the same repository as the model (with AutoTokenizer, as in the Quick Start).

Citation

If you use this model, please cite the following:

@misc{BertForStorySkillClassification,
  author = {curious},
  title = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}

License

This model is open-sourced under the Apache 2.0 License. For more details, see the LICENSE file.
