1 / app.py
mayf's picture
Update app.py
e616e4e verified
raw
history blame
4.19 kB
import os
import time
import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi, login
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from gtts import gTTS
import tempfile
# ---- Page setup: browser-tab title + centered layout, then the page heading ----
st.set_page_config(
    layout="centered",
    page_title="Magic Story Generator (Qwen2.5)",
)
st.title(
    "📖✨ Turn Images into Children's Stories (Qwen2.5-Omni-7B)"
)
# —––––––– Load Clients & Pipelines (cached) —–––––––
@st.cache_resource(show_spinner=False)
def load_clients():
    """Authenticate with the Hugging Face Hub and build the two model clients.

    Decorated with ``st.cache_resource`` so the expensive model load runs once
    and later Streamlit reruns reuse the returned objects.

    Returns:
        tuple: ``(caption_client, storyteller)`` — an ``InferenceApi`` client
        for the hosted BLIP captioning model, and a transformers pipeline
        wrapping Qwen2.5-Omni-7B for story generation.

    Note:
        Reading ``st.secrets["HF_TOKEN"]`` fails if that secret is not
        configured for the app.
    """
    import importlib.util

    hf_token = st.secrets["HF_TOKEN"]
    # Authenticate for Hugging Face Hub: the env var for libraries that read
    # it, and login() for huggingface_hub / transformers downloads.
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # 1) BLIP captioning via HTTP API
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    # 2) Load Qwen2.5-Omni model & tokenizer.
    # flash_attention_2 needs a CUDA device *and* the optional `flash_attn`
    # package; requesting it without both makes from_pretrained raise. When
    # either is missing, pass None so transformers picks its default kernel.
    use_flash = (
        torch.cuda.is_available()
        and importlib.util.find_spec("flash_attn") is not None
    )
    attn_impl = "flash_attention_2" if use_flash else None

    t0 = time.perf_counter()
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        trust_remote_code=True,
    )
    # NOTE(review): AutoModelForSeq2SeqLM / "text2text-generation" are meant
    # for encoder-decoder models, while Qwen checkpoints are (to my knowledge)
    # decoder-only. Confirm this class actually loads this checkpoint with the
    # installed transformers version.
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "Qwen/Qwen2.5-Omni-7B",
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
    )

    # 3) Build text2text pipeline.
    # The model was already placed by device_map="auto" above, so no device
    # argument is given here (pipeline ignores device_map for a loaded model).
    # do_sample=True is required: with greedy decoding, temperature and top_p
    # have no effect.
    storyteller = pipeline(
        task="text2text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        max_new_tokens=120,
    )
    load_time = time.perf_counter() - t0
    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")
    return caption_client, storyteller


# Module-level handles used by the helpers below; cached across reruns.
caption_client, storyteller = load_clients()
# —––––––– Helpers —–––––––
def generate_caption(img: Image.Image) -> str:
    """Caption *img* using the hosted BLIP model behind ``caption_client``.

    Args:
        img: Image to describe. Images not already in RGB mode are converted
            first, because Pillow cannot write some modes (e.g. RGBA, P) as
            JPEG.

    Returns:
        The stripped caption, or ``""`` when the response is not in the
        expected shape (a non-empty list whose first item is a dict carrying
        ``"generated_text"``) — e.g. an error payload.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")
    with BytesIO() as buf:
        img.save(buf, format="JPEG")
        payload = buf.getvalue()

    resp = caption_client(data=payload)

    # Anything other than the expected list-of-dicts shape is treated as
    # "no caption" so the caller can show a friendly error instead of crashing.
    if not isinstance(resp, list) or not resp:
        return ""
    first = resp[0]
    if not isinstance(first, dict):
        return ""
    return (first.get("generated_text") or "").strip()
# Words that close a sentence: terminal punctuation, optionally followed by a
# straight or curly closing quote.
_SENTENCE_ENDINGS = (".", "!", "?", '."', '!"', '?"', ".”", "!”", "?”")
# Trailing punctuation that would look wrong directly before an added period.
_DANGLING_PUNCT = ",;:-–—"


def _trim_to_word_limit(text: str, max_words: int = 100) -> str:
    """Cap *text* at *max_words* whitespace-separated words.

    Text already within the limit is returned unchanged. Longer text keeps
    only its first *max_words* words. If any of those words ends a sentence,
    everything after the last such word is dropped, so the result does not
    stop mid-sentence. Otherwise, trailing commas, semicolons, colons and
    dashes are stripped and a period is appended. Truncated output is
    re-joined with single spaces.
    """
    words = text.split()
    if len(words) <= max_words:
        return text
    kept = words[:max_words]
    sentence_ends = [i for i, w in enumerate(kept) if w.endswith(_SENTENCE_ENDINGS)]
    if sentence_ends:
        return " ".join(kept[: sentence_ends[-1] + 1])
    return " ".join(kept).rstrip(_DANGLING_PUNCT) + "."


def generate_story(caption: str) -> str:
    """Ask the ``storyteller`` pipeline for a short children's story.

    Displays the generation time in the app, then caps the story at 100 words
    via ``_trim_to_word_limit``.

    Args:
        caption: Image description to base the story on.

    Returns:
        The stripped story text, at most 100 words. It may be ``""`` if the
        model produced no text.
    """
    prompt = (
        "You are a creative children's-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent 50–100 word story that:\n"
        "1. Introduces the main character.\n"
        "2. Shows a simple problem or discovery.\n"
        "3. Has a happy resolution.\n"
        "4. Uses clear language for ages 3–8.\n"
        "5. Keeps each sentence under 20 words.\n"
    )
    t0 = time.perf_counter()
    result = storyteller(prompt)
    gen_time = time.perf_counter() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")
    story = result[0]["generated_text"].strip()
    # Enforce ≤100 words, ending on a sentence boundary where possible.
    return _trim_to_word_limit(story, 100)
# —––––––– Main App —–––––––
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])
if uploaded:
    # Pillow raises UnidentifiedImageError (an OSError subclass) for files it
    # cannot decode, e.g. a corrupt upload or a mislabelled extension.
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        st.error("😢 Couldn't open this image. Try another one!")
        st.stop()
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)
    if not story:
        st.error("😢 The story came out empty. Please try again!")
        st.stop()
    st.subheader("📚 Your Magical Story")
    st.write(story)

    # Synthesize into memory rather than a NamedTemporaryFile(delete=False):
    # that version left one .mp3 on disk per run and never removed it.
    with st.spinner("🔊 Converting to audio..."):
        try:
            audio_buf = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            # Audio is optional: report the failure but keep the story visible.
            st.warning(f"⚠️ TTS failed: {e}")

# Footer
st.markdown("---\n*Made with ❤️ by your friendly story wizard*")