# app.py
import io
import wave
import re
import streamlit as st
from transformers import pipeline, SpeechT5Processor, SpeechT5HifiGan
from datasets import load_dataset
from PIL import Image
import numpy as np
import torch
# ─────────────────────────────────────────────────────────────
# 1) LOAD PIPELINES
# ─────────────────────────────────────────────────────────────
@st.cache_resource
def load_captioner():
    """Load the BLIP image-to-text (captioning) pipeline on CPU.

    Decorated with st.cache_resource so the model is downloaded and
    instantiated once per server process; without it, Streamlit would
    reload the model on every script rerun (each widget interaction).

    Returns:
        A transformers image-to-text pipeline.
    """
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", device="cpu")
@st.cache_resource
def load_story_generator():
    """Load the Phi-4-mini text-generation pipeline on CPU.

    Cached with st.cache_resource so the (large) model is instantiated
    once per server process instead of on every Streamlit rerun.

    Returns:
        A transformers text-generation pipeline.
    """
    return pipeline("text-generation", model="microsoft/Phi-4-mini-reasoning", device="cpu")
@st.cache_resource
def load_tts_pipe():
    """Load the SpeechT5 text-to-speech components on CPU.

    Cached with st.cache_resource so the models and the speaker-embedding
    dataset are loaded once per server process, not on every rerun.

    Returns:
        Tuple of (processor, tts_pipeline, vocoder, speaker_embedding):
        the SpeechT5 text processor, a text-to-speech pipeline wrapping
        the SpeechT5 model, the HiFi-GAN vocoder, and a (1, D) xvector
        tensor selecting a fixed speaker voice.
    """
    processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    model = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
    # Fixed speaker voice: entry 7306 of the CMU Arctic xvectors validation split.
    speaker_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    speaker_embedding = torch.tensor(speaker_dataset[7306]["xvector"]).unsqueeze(0)
    return processor, model, vocoder, speaker_embedding
# ─────────────────────────────────────────────────────────────
# 2) PIPELINE FUNCTIONS
# ─────────────────────────────────────────────────────────────
def get_caption(image, captioner):
    """Run the captioning pipeline on *image* and return the caption text."""
    results = captioner(image)
    return results[0]["generated_text"]
def generate_story(caption, generator):
    """Generate a short children's story from an image caption.

    Builds a prompt around *caption*, samples up to 120 new tokens from
    the text-generation pipeline, and post-processes the raw output with
    clean_story_output (prompt-echo stripping and re-casing).
    """
    prompt = (
        "Write a short, magical story for children aged 3 to 10 based on "
        f"this scene: {caption}. Keep it under 100 words."
    )
    generation = generator(
        prompt,
        do_sample=True,
        temperature=0.8,
        top_p=0.95,
        max_new_tokens=120,
    )
    raw_story = generation[0]["generated_text"]
    return clean_story_output(raw_story, prompt)
def clean_story_output(story, prompt):
    """Strip the echoed prompt and any trailing partial sentence, then re-case.

    Text-generation pipelines typically echo the prompt at the start of
    their output; drop it when present, then cut at the last full stop so
    a truncated final sentence is not shown to the reader.
    """
    if story.startswith(prompt):
        story = story[len(prompt):].strip()
    last_period = story.rfind(".")
    if last_period != -1:
        story = story[: last_period + 1]
    return sentence_case(story)
def sentence_case(text):
    """Capitalize the first letter of each sentence in *text*.

    Splits on sentence-ending punctuation (``.``, ``!``, ``?``) while
    keeping the delimiters, upper-cases only the first character of each
    sentence, and rejoins the sentences with single spaces.

    Fix over the original: ``str.capitalize()`` lower-cases the rest of
    the sentence, destroying proper nouns and acronyms mid-sentence
    (e.g. "this is Bob" became "This is bob"); only the first character
    is changed now.
    """
    parts = re.split(r'([.!?])', text)
    out = []
    # parts alternates [sentence, delimiter, sentence, delimiter, ...].
    for i in range(0, len(parts) - 1, 2):
        sentence = parts[i].strip()
        sentence = sentence[:1].upper() + sentence[1:]
        out.append(f"{sentence}{parts[i + 1]}")
    if len(parts) % 2:
        # Trailing text with no terminating punctuation.
        last = parts[-1].strip()
        last = last[:1].upper() + last[1:]
        if last:
            out.append(last)
    return " ".join(out)
def convert_to_audio(text, processor, tts_pipe, vocoder, speaker_embedding):
    """Synthesize *text* into 16 kHz mono 16-bit WAV bytes.

    Tokenizes the text with the SpeechT5 processor, generates a float
    waveform with the pipeline's underlying model, scales it to int16
    PCM, and wraps it in an in-memory WAV container.
    """
    token_ids = processor(text=text, return_tensors="pt")["input_ids"]
    waveform = tts_pipe.model.generate_speech(token_ids, speaker_embedding, vocoder=vocoder)
    pcm_int16 = (waveform.numpy() * 32767).astype(np.int16)
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wav_file:
        wav_file.setnchannels(1)      # mono
        wav_file.setsampwidth(2)      # 16-bit samples
        wav_file.setframerate(16000)  # SpeechT5 output sample rate
        wav_file.writeframes(pcm_int16.tobytes())
    return wav_buffer.getvalue()
# ─────────────────────────────────────────────────────────────
# 3) STREAMLIT APP UI
# ─────────────────────────────────────────────────────────────
# Top-level Streamlit script: upload an image, caption it, generate a
# story from the caption, and synthesize the story to speech.
st.set_page_config(page_title="Magic Storyteller", layout="centered")
# NOTE(review): the emoji in the title/subheaders below appear mojibake'd
# (e.g. "π§") in this copy — confirm the intended emoji against the
# original file; kept byte-for-byte here to avoid guessing.
st.title("π§ Magic Storyteller")
st.markdown("Upload an image to generate a magical story and hear it read aloud!")

uploaded = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])

if uploaded:
    image = Image.open(uploaded)
    # use_container_width replaces the deprecated use_column_width argument.
    st.image(image, caption="Your uploaded image", use_container_width=True)

    st.subheader("πΌοΈ Step 1: Captioning")
    captioner = load_captioner()
    caption = get_caption(image, captioner)
    st.markdown(f"**Caption:** {sentence_case(caption)}")

    st.subheader("π Step 2: Story Generation")
    story_pipe = load_story_generator()
    story = generate_story(caption, story_pipe)
    st.write(story)

    st.subheader("π Step 3: Listen to the Story")
    processor, tts_pipe, vocoder, speaker_embedding = load_tts_pipe()
    audio_bytes = convert_to_audio(story, processor, tts_pipe, vocoder, speaker_embedding)
    st.audio(audio_bytes, format="audio/wav")
    st.balloons()
else:
    st.info("Please upload an image to begin.")