|
import os |
|
import time |
|
import streamlit as st |
|
from PIL import Image |
|
from io import BytesIO |
|
from huggingface_hub import InferenceApi, login |
|
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer |
|
import torch |
|
from gtts import gTTS |
|
import tempfile |
|
|
|
|
|
# Page chrome. Kept as the first Streamlit call (older Streamlit versions
# require set_page_config to come before any other st.* command). Neither
# label names a specific checkpoint, so they do not go stale if the story
# model changes.
st.set_page_config(page_title="Magic Story Generator (Qwen2.5)", layout="centered")
st.title("📖✨ Turn Images into Children's Stories (Qwen2.5)")
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_clients(model_id: str = "Qwen/Qwen2.5-7B-Instruct"):
    """Create the image-captioning client and the story-generation pipeline.

    Wrapped in ``st.cache_resource``, so the body runs once per distinct
    *model_id* and later calls (including script reruns) reuse the returned
    objects.

    Args:
        model_id: Hugging Face repo id of a decoder-only (causal) language
            model. Qwen2.5 models are decoder-only, so they are loaded with
            ``AutoModelForCausalLM`` and served through the
            ``text-generation`` task. The multimodal Qwen2.5-Omni checkpoints
            need their own dedicated model classes and cannot be loaded here.

    Returns:
        ``(caption_client, storyteller)``:
        - ``caption_client`` is an ``InferenceApi`` client for the hosted BLIP
          captioning model.
        - ``storyteller`` is a ``text-generation`` pipeline whose
          ``generated_text`` contains only the newly generated text, not the
          prompt.

    Reading ``st.secrets["HF_TOKEN"]`` fails if that secret is not configured.
    """
    from importlib.util import find_spec

    from transformers import AutoModelForCausalLM

    hf_token = st.secrets["HF_TOKEN"]
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_token
    login(hf_token)

    # NOTE(review): InferenceApi is deprecated in huggingface_hub in favour of
    # InferenceClient; migrating also requires updating generate_caption,
    # which calls this client as caption_client(data=...).
    caption_client = InferenceApi(
        repo_id="Salesforce/blip-image-captioning-base",
        token=hf_token,
    )

    t0 = time.time()
    # Qwen2.5 is natively supported by transformers, so no remote code has to
    # be trusted/executed to load it.
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # FlashAttention-2 needs a CUDA GPU and the optional flash-attn package;
    # without them, let transformers pick its default attention backend
    # instead of failing at load time.
    use_flash = torch.cuda.is_available() and find_spec("flash_attn") is not None
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2" if use_flash else None,
    )

    # The model was already placed by device_map above, so the pipeline gets
    # no device argument. do_sample=True is required for temperature/top_p to
    # take effect (without it decoding is greedy and they are ignored).
    # return_full_text=False keeps the prompt out of "generated_text".
    storyteller = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        max_new_tokens=120,
        return_full_text=False,
    )

    load_time = time.time() - t0
    st.text(f"✅ Story model loaded in {load_time:.1f}s (cached)")

    return caption_client, storyteller
|
|
|
# load_clients() is wrapped in st.cache_resource, so script reruns get the
# same caption client and story pipeline back instead of reloading the model.
_clients = load_clients()
caption_client, storyteller = _clients
|
|
|
|
|
def generate_caption(img: Image.Image) -> str:
    """Describe *img* in one sentence using the hosted BLIP captioning model.

    The image is JPEG-encoded in memory and sent to the module-level
    ``caption_client``.

    Args:
        img: Image to caption. It is converted to RGB first when necessary,
            because JPEG cannot store modes such as RGBA or P.

    Returns:
        The stripped caption, or "" when the response is not a non-empty list
        whose first element is a dict with a non-empty ``generated_text``
        (e.g. an error payload).
    """
    if img.mode != "RGB":
        img = img.convert("RGB")

    buf = BytesIO()
    img.save(buf, format="JPEG")
    resp = caption_client(data=buf.getvalue())

    if isinstance(resp, list) and resp and isinstance(resp[0], dict):
        # `or ""` also covers an explicit null value in the payload.
        return (resp[0].get("generated_text") or "").strip()
    return ""
|
|
|
|
|
def generate_story(caption: str) -> str:
    """Write a short children's story inspired by an image caption.

    Sends a fixed instruction prompt containing *caption* to the module-level
    ``storyteller`` pipeline, shows the generation time on the page, and
    tidies the output with ``_tidy_story``.

    Args:
        caption: Text description of the uploaded image.

    Returns:
        The story, capped at 100 words and ending on sentence punctuation, or
        "" if the model returned only whitespace.
    """
    prompt = (
        "You are a creative children's-story author.\n"
        f"Image description: “{caption}”\n\n"
        "Write a coherent 50–100 word story that:\n"
        "1. Introduces the main character.\n"
        "2. Shows a simple problem or discovery.\n"
        "3. Has a happy resolution.\n"
        "4. Uses clear language for ages 3–8.\n"
        "5. Keeps each sentence under 20 words.\n"
    )

    t0 = time.time()
    result = storyteller(prompt)
    gen_time = time.time() - t0
    st.text(f"⏱ Generated in {gen_time:.1f}s on GPU/CPU")

    # The pipeline returns a list of dicts; only the first one is used.
    story = result[0]["generated_text"].strip()
    return _tidy_story(story, max_words=100)


def _tidy_story(story: str, max_words: int = 100) -> str:
    """Cap *story* at *max_words* words and make it end on a whole sentence.

    Generation can overshoot the word budget, or stop mid-sentence when it
    reaches its token limit.

    - If the possibly truncated text does not end with ``.``, ``!`` or ``?``
      (optionally followed by a closing quote or bracket), everything after
      the last such terminator is dropped.
    - If there is no terminator at all, trailing separators are stripped and
      a full stop is appended.
    - Whitespace is only collapsed when truncation happens.

    >>> _tidy_story("The cat sat. It was happy and")
    'The cat sat.'
    >>> _tidy_story("Hello there,")
    'Hello there.'
    >>> _tidy_story("Wow!")
    'Wow!'
    """
    sentence_endings = (".", "!", "?")
    closing_marks = "\"'”’)"

    words = story.split()
    if not words:
        return ""
    if len(words) > max_words:
        story = " ".join(words[:max_words])

    if story.rstrip(closing_marks).endswith(sentence_endings):
        return story

    last_end = max(story.rfind(mark) for mark in sentence_endings)
    if last_end > 0:
        cut = last_end + 1
        # Keep a closing quote/bracket that directly follows the terminator.
        while cut < len(story) and story[cut] in closing_marks:
            cut += 1
        return story[:cut]

    return story.rstrip(",;:-–— ") + "."
|
|
|
|
|
# Main flow: upload -> caption -> story -> narrated audio.
uploaded = st.file_uploader("Upload an image:", type=["jpg", "png", "jpeg"])

if uploaded is not None:
    # UnidentifiedImageError (unrecognised format) is an OSError subclass, as
    # are decode errors raised by convert() on truncated files; show a
    # friendly message instead of a traceback.
    try:
        img = Image.open(uploaded).convert("RGB")
    except OSError:
        st.error("😢 Couldn't read this file as an image. Try another one!")
        st.stop()

    # thumbnail() only shrinks and keeps the aspect ratio; cap the longest
    # side before the image is displayed and sent for captioning.
    if max(img.size) > 2048:
        img.thumbnail((2048, 2048))
    st.image(img, use_container_width=True)

    with st.spinner("🔍 Generating caption..."):
        caption = generate_caption(img)
    if not caption:
        st.error("😢 Couldn't understand this image. Try another one!")
        st.stop()
    st.success(f"**Caption:** {caption}")

    with st.spinner("📝 Writing story..."):
        story = generate_story(caption)
    if not story:
        # Nothing to show or narrate (gTTS also rejects empty text).
        st.error("😢 Couldn't write a story for this image. Try another one!")
        st.stop()

    st.subheader("📚 Your Magical Story")
    st.write(story)

    with st.spinner("🔊 Converting to audio..."):
        # Best-effort narration: any failure (e.g. network) only warns, since
        # the story text is already on the page.
        try:
            # Render the MP3 into memory so no temporary file is left on disk.
            audio_buf = BytesIO()
            gTTS(text=story, lang="en", slow=False).write_to_fp(audio_buf)
            st.audio(audio_buf.getvalue(), format="audio/mp3")
        except Exception as e:
            st.warning(f"⚠️ TTS failed: {e}")

st.markdown("---\n*Made with ❤️ by your friendly story wizard*")
|
|