# BertForStorySkillClassification

## Model Overview

BertForStorySkillClassification is a BERT-based text classification model that categorizes story-related questions into one of the following 7 classes:
- Character
- Setting
- Feeling
- Action
- Causal Relationship
- Outcome Resolution
- Prediction
This model is suitable for applications in education, literary analysis, and story comprehension.
## Model Architecture

- Base Model: `bert-base-uncased`
- Classification Layer: A fully connected layer on top of BERT for 7-class classification.
- Input: A question alone (e.g., "Who is the main character in the story?"), a QA pair (e.g., "why could n't alice get a doll as a child ? <SEP> because her family was very poor "), or a QA pair plus context (e.g., "why could n't alice get a doll as a child ? <SEP> because her family was very poor <context> alice is ... "); see the sketch after this list.
- Output: Predicted label and confidence score.
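For illustration, all three input formats are plain strings passed to the same tokenizer; this minimal sketch just assembles them, with the `<SEP>` and `<context>` markers and spacing taken from the examples above:

```python
# Question only
question_only = "Who is the main character in the story?"

# QA pair: question <SEP> answer
qa_pair = (
    "why could n't alice get a doll as a child ? <SEP> "
    "because her family was very poor "
)

# QA pair plus story context: question <SEP> answer <context> story text
qa_pair_with_context = (
    "why could n't alice get a doll as a child ? <SEP> "
    "because her family was very poor <context> alice is ... "
)
```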
## Quick Start

### Install Dependencies

Ensure the `transformers` library is installed:

```bash
pip install transformers
```
### Load Model and Tokenizer

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("curious008/BertForStorySkillClassification")
tokenizer = AutoTokenizer.from_pretrained("curious008/BertForStorySkillClassification")
```
### Use the predict Method for Inference

```python
# Single text prediction
result = model.predict(
    texts="Where does this story take place?",
    tokenizer=tokenizer,
    return_probabilities=True
)
print(result)
# Output: [{'text': 'Where does this story take place?', 'label': 'setting', 'score': 0.93178}]

# Batch prediction
results = model.predict(
    texts=[
        "Why is the character sad?",
        "How does the story end?",
        "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
    ],
    tokenizer=tokenizer,
    batch_size=16,
    device="cuda"
)
print(results)
# Output:
# [{'text': 'Why is the character sad?', 'label': 'causal relationship'},
#  {'text': 'How does the story end?', 'label': 'action'},
#  {'text': "why could n't alice get a doll as a child ? <SEP> because her family was very poor ",
#   'label': 'causal relationship'}]
```
## Training Details

### Dataset

- Source: FairytaleQAData

### Training Parameters

- Learning Rate: 2e-5
- Batch Size: 32
- Epochs: 3
- Optimizer: AdamW
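For reference, here is a minimal fine-tuning sketch with these hyperparameters using the Hugging Face `Trainer` and the `datasets` library. The two-example dataset is only a stand-in for the FairytaleQA-derived training data, not the original training script:

```python
from datasets import Dataset
from transformers import (AutoModelForSequenceClassification, AutoTokenizer,
                          Trainer, TrainingArguments)

# Toy stand-in data; labels 0-6 correspond to the seven classes listed above.
raw = Dataset.from_dict({
    "text": ["Who is the main character in the story?",
             "Where does this story take place?"],
    "label": [0, 1],
})

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

train_ds = raw.map(tokenize, batched=True)

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=7
)

# Hyperparameters from the list above.
args = TrainingArguments(
    output_dir="story-skill-bert",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    num_train_epochs=3,
    optim="adamw_torch",  # AdamW optimizer
)

# Passing the tokenizer enables dynamic padding of each batch.
Trainer(model=model, args=args, train_dataset=train_ds,
        tokenizer=tokenizer).train()
```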
### Performance Metrics

- Accuracy: 97.3%
- Recall: 96.59%
- F1 Score: 96.96%
## Notes

- Input Length: The model supports a maximum input length of 512 tokens; longer texts are truncated.
- Device Support: Both CPU and GPU inference are supported; a GPU is recommended for faster inference.
- Tokenizer: Always use the matching tokenizer (loaded via `AutoTokenizer`) with this model.
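As a small illustration of these notes, the model and inputs can be moved to a GPU explicitly when using the standard API. This sketch assumes `model` and `tokenizer` are already loaded as in the Quick Start section:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Truncate to the 512-token limit and move the batch to the same device.
inputs = tokenizer(
    ["Why is the character sad?", "How does the story end?"],
    return_tensors="pt", padding=True, truncation=True, max_length=512,
).to(device)

with torch.no_grad():
    pred_ids = model(**inputs).logits.argmax(dim=-1).tolist()

print([model.config.id2label[i] for i in pred_ids])
```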
## Citation

If you use this model, please cite it as follows:

```bibtex
@misc{BertForStorySkillClassification,
  author       = {curious},
  title        = {BertForStorySkillClassification: A BERT-based Model for Story Question Classification},
  year         = {2025},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/curious008/BertForStorySkillClassification}}
}
```
## License

This model is released under the Apache 2.0 License. For more details, see the LICENSE file.