1 / app.py
mayf's picture
Update app.py
db1550f verified
raw
history blame
4.11 kB
import streamlit as st
from PIL import Image
from transformers import pipeline
from gtts import gTTS
import tempfile
import os
# —––––––– Page config —–––––––
# Page-level settings are applied first, ahead of any other Streamlit call
# in this script; the title is rendered right after.
st.set_page_config(
    page_title="Storyteller for Kids",
    page_icon="📚",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Main heading shown at the top of the page.
st.title("🖼️➡️📖 Interactive Storyteller")
# —––––––– Cache model loading —–––––––
@st.cache_resource
def load_pipelines():
    """Build the two Hugging Face pipelines the app relies on.

    The function is decorated with ``st.cache_resource``, so the returned
    objects are meant to be created once and reused rather than rebuilt on
    every call.

    Returns:
        tuple: ``(captioner, storyteller)`` where ``captioner`` is the
        ``"image-to-text"`` pipeline (BLIP base, at most 50 new tokens) and
        ``storyteller`` is the ``"text2text-generation"`` pipeline
        (FLAN-T5 XXL, ``device_map="auto"``, ``load_in_8bit=True`` passed
        through ``model_kwargs``).
    """
    # Image captioning: turns the uploaded picture into a short description.
    image_captioner = pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        max_new_tokens=50,
    )

    # Story generation: expands the caption-based prompt into a story.
    story_generator = pipeline(
        "text2text-generation",
        model="google/flan-t5-xxl",
        device_map="auto",
        model_kwargs={"load_in_8bit": True},
    )

    return image_captioner, story_generator
# —––––––– Main workflow —–––––––
def main():
    """Run the upload -> caption -> story -> audio flow of the app.

    Loads the cached pipelines, waits for an image upload, then:

    1. displays the image and captions it with the image-to-text pipeline,
    2. builds a children's-story prompt from the caption and generates a story,
    3. converts the story to speech with gTTS, plays it back and offers it
       as a download.

    Any exception raised while processing an upload is reported to the user
    through ``st.error`` instead of propagating. Returns ``None``.
    """
    captioner, storyteller = load_pipelines()

    # —––––––– Image upload —–––––––
    # NOTE: the help text mentions a 5MB limit, but nothing in this function
    # enforces it.
    uploaded = st.file_uploader(
        "Upload an image:",
        type=["jpg", "jpeg", "png"],
        help="Max size: 5MB"
    )
    if not uploaded:
        # Nothing to process until the user provides a file.
        return

    try:
        # —––––––– Display image —–––––––
        image = Image.open(uploaded).convert("RGB")
        st.image(image, caption="Your Image", use_column_width=True)

        # —––––––– Generate caption —–––––––
        with st.spinner("🔍 Analyzing image content..."):
            cap_outputs = captioner(image)
            cap = cap_outputs[0].get("generated_text", "").strip()
        st.subheader("Image Understanding")
        st.info(f"**Detected:** {cap}")

        # —––––––– Generate story —–––––––
        st.subheader("Story Creation")
        # The prompt text is kept flush-left so the string sent to the model
        # carries no extra leading whitespace.
        prompt = f"""Create a children's story (3-10 years old) based on this description:
{cap}
Requirements:
- 50-100 words
- Playful and imaginative
- Positive message
- Simple vocabulary
- Include animal characters
Story:"""
        with st.spinner("✍️ Crafting a magical story..."):
            story_output = storyteller(
                prompt,
                max_length=300,
                do_sample=True,
                top_p=0.95,
                temperature=0.85,
                num_beams=4,
                repetition_penalty=1.2
            )
        story = story_output[0]["generated_text"].strip()
        st.success("**Generated Story:**")
        st.write(story)

        # —––––––– Text-to-Speech —–––––––
        st.subheader("Audio Version")
        with st.spinner("🔊 Generating audio..."):
            # Bound before the try block so the cleanup in ``finally`` never
            # touches an unassigned name when gTTS or the write fails; that
            # would replace the real error with an UnboundLocalError.
            tmp_path = None
            try:
                tts = gTTS(text=story, lang="en", slow=False)
                with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
                    # Record the path before writing so a failed write still
                    # gets its temp file removed below.
                    tmp_path = tmp.name
                    tts.write_to_fp(tmp)
                st.audio(tmp_path, format="audio/mp3")
                # Add download button
                with open(tmp_path, "rb") as f:
                    st.download_button(
                        label="Download Audio Story",
                        data=f,
                        file_name="kids_story.mp3",
                        mime="audio/mpeg"
                    )
            finally:
                # delete=False means the temp file must be removed by hand.
                if tmp_path is not None and os.path.exists(tmp_path):
                    os.remove(tmp_path)
    except Exception as e:
        # Top-level UI boundary: surface the failure to the user.
        st.error(f"Error processing your request: {str(e)}")
# Launch the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()