# story-image-api / story_generator.py
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import random
import re
# Set Hugging Face cache directory
CACHE_DIR = "/tmp/huggingface"
# ------------------------
# Load Story Generation Model
# ------------------------
STORY_MODEL_NAME = "abdalraheemdmd/story-api"
story_tokenizer = GPT2Tokenizer.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
story_model = GPT2LMHeadModel.from_pretrained(STORY_MODEL_NAME, cache_dir=CACHE_DIR)
# ------------------------
# Load Question Generation Model
# ------------------------
QUESTION_MODEL_NAME = "abdalraheemdmd/question-gene"
question_tokenizer = GPT2Tokenizer.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
question_model = GPT2LMHeadModel.from_pretrained(QUESTION_MODEL_NAME, cache_dir=CACHE_DIR)
# Ensure tokenizers have a pad token
if story_tokenizer.pad_token_id is None:
    story_tokenizer.pad_token_id = story_tokenizer.eos_token_id
if question_tokenizer.pad_token_id is None:
    question_tokenizer.pad_token_id = question_tokenizer.eos_token_id

def generate_story(theme, reading_level, max_new_tokens=400, temperature=0.7):
    """Generates a story based on the provided theme and reading level."""
    prompt = f"A {reading_level} story about {theme}:"
    # Tokenize the prompt and reuse the attention mask the tokenizer produces.
    inputs = story_tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        output = story_model.generate(
            inputs.input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=20,
            top_p=0.7,
            do_sample=True,
            pad_token_id=story_tokenizer.pad_token_id,
            eos_token_id=story_tokenizer.eos_token_id,
            attention_mask=inputs.attention_mask,
        )
    return story_tokenizer.decode(output[0], skip_special_tokens=True)
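
# A minimal usage sketch (the theme and reading-level strings below are
# illustrative placeholders; generation assumes the checkpoints above were
# downloaded successfully):
#
#     story = generate_story("space adventure", "second grade")
#     print(story)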

def extract_protagonist(story):
    """
    Attempts to extract the protagonist from the first sentence by searching for the pattern "named <Name>".
    Returns the first matched name, if available.
    """
    sentences = re.split(r'\.|\n', story)
    if sentences:
        m = re.search(r"named\s+([A-Z][a-z]+)", sentences[0])
        if m:
            return m.group(1)
    return None

def extract_characters(story):
    """
    Extracts potential character names from the story using a frequency count on capitalized words.
    Filters out common stopwords so that the most frequently mentioned name is likely the main character.
    """
    words = re.findall(r'\b[A-Z][a-zA-Z]+\b', story)
    stopwords = {"The", "A", "An", "And", "But", "Suddenly", "Quickly", "However", "Well",
                 "They", "I", "He", "She", "It", "When", "Where", "Dr", "Mr"}
    filtered = [w for w in words if w not in stopwords and len(w) > 2]
    if not filtered:
        return []
    freq = {}
    for word in filtered:
        freq[word] = freq.get(word, 0) + 1
    sorted_chars = sorted(freq.items(), key=lambda x: x[1], reverse=True)
    return [item[0] for item in sorted_chars]
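
# Illustrative behaviour on a hypothetical snippet (not taken from a generated
# story): repeated names rank first because of the frequency count.
#
#     extract_characters("Mia met Leo. Mia smiled at Leo. Mia laughed.")
#     # -> ["Mia", "Leo"]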

def extract_themes(story):
    """Extracts themes from the story based on keyword matching."""
    themes = []
    story_lower = story.lower()
    if "space" in story_lower:
        themes.append("space")
    if "adventure" in story_lower:
        themes.append("adventure")
    if "friend" in story_lower:
        themes.append("friendship")
    if "learn" in story_lower or "lesson" in story_lower:
        themes.append("learning")
    return themes

def extract_lesson(story):
    """
    Attempts to extract a lesson or moral from the story by finding sentences
    containing keywords like "learn" or "lesson". Returns the last matching sentence.
    """
    sentences = re.split(r'\.|\n', story)
    lesson_sentences = [
        s.strip() for s in sentences
        if ("learn" in s.lower() or "lesson" in s.lower()) and len(s.strip()) > 20
    ]
    if lesson_sentences:
        return lesson_sentences[-1]
    else:
        return "No explicit lesson found."

def format_question(question_prompt, correct_answer, distractors):
    """
    Combines the correct answer with three distractors, shuffles the options,
    and formats the question as a multiple-choice question.
    """
    # Work on a copy so the caller's list is not modified.
    distractors = list(distractors)
    # Ensure exactly 3 distractors are available
    if len(distractors) < 3:
        default_distractors = ["Option X", "Option Y", "Option Z"]
        while len(distractors) < 3:
            distractors.append(default_distractors[len(distractors) % len(default_distractors)])
    else:
        distractors = random.sample(distractors, 3)
    options = distractors + [correct_answer]
    random.shuffle(options)
    letters = ["A", "B", "C", "D"]
    correct_letter = letters[options.index(correct_answer)]
    options_text = "\n".join(f"{letters[i]}) {option}" for i, option in enumerate(options))
    question_text = f"{question_prompt}\n{options_text}\nCorrect Answer: {correct_letter}"
    return question_text
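
# Example of the output shape (hypothetical arguments; option order changes on
# every call because the options are shuffled):
#
#     format_question("What is the main theme of the story?", "space",
#                     ["cooking and baking", "sports and competition",
#                      "weather and seasons", "technology and innovation"])
#
# might return:
#
#     What is the main theme of the story?
#     A) weather and seasons
#     B) space
#     C) sports and competition
#     D) technology and innovation
#     Correct Answer: B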

def dynamic_fallback_questions(story):
    """
    Generates three multiple-choice questions based on dynamic story content.
    Each question uses a randomly chosen template and shuffles its options.
    """
    protagonist = extract_protagonist(story)
    characters = extract_characters(story)
    themes = extract_themes(story)
    lesson = extract_lesson(story)

    # --- Question 1: Theme ---
    theme_templates = [
        "What is the main theme of the story?",
        "Which theme best represents the narrative?",
        "What subject is central to the story?"
    ]
    q1_prompt = random.choice(theme_templates)
    correct_theme = " and ".join(themes) if themes else "learning"
    q1_distractors = ["sports and competition", "cooking and baking", "weather and seasons", "technology and innovation"]
    q1 = format_question(q1_prompt, correct_theme, q1_distractors)

    # --- Question 2: Primary Character ---
    character_templates = [
        "Who is the primary character in the story?",
        "Which character drives the main action in the narrative?",
        "Who is the central figure in the story?"
    ]
    q2_prompt = random.choice(character_templates)
    if protagonist:
        correct_character = protagonist
    elif characters:
        correct_character = characters[0]
    else:
        correct_character = "The main character"
    q2_distractors = ["a mysterious stranger", "an unknown visitor", "a supporting character", "a sidekick"]
    q2 = format_question(q2_prompt, correct_character, q2_distractors)

    # --- Question 3: Lesson/Moral ---
    lesson_templates = [
        "What lesson did the characters learn by the end of the story?",
        "What moral can be inferred from the narrative?",
        "What is the key takeaway from the story?"
    ]
    q3_prompt = random.choice(lesson_templates)
    if lesson and lesson != "No explicit lesson found.":
        correct_lesson = lesson  # full sentence without truncation
    else:
        correct_lesson = "understanding and growth"
    q3_distractors = ["always be silent", "never try new things", "do nothing", "ignore opportunities"]
    q3 = format_question(q3_prompt, correct_lesson, q3_distractors)

    return f"{q1}\n\n{q2}\n\n{q3}"

def generate_story_and_questions(theme, reading_level):
    """
    Generates a story using the story generation model and then creates dynamic,
    multiple-choice questions based on that story.
    """
    story = generate_story(theme, reading_level)
    questions = dynamic_fallback_questions(story)
    return {"story": story, "questions": questions}

# Alias for backward compatibility
create_fallback_questions = dynamic_fallback_questions
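
# A minimal manual-test sketch for running this module directly. The theme and
# reading-level values below are illustrative placeholders, not values required
# by the models, and the demo assumes the checkpoints above loaded successfully.
if __name__ == "__main__":
    result = generate_story_and_questions(theme="space", reading_level="kindergarten")
    print(result["story"])
    print()
    print(result["questions"])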