# NOTE: "Spaces: Sleeping" was page-header text captured together with this file; it is not part of the program.
import io
import logging
import re

import numpy as np
import soundfile as sf
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import pipeline, AutoTokenizer
# Configure logging
# INFO level lets the model-loading progress messages emitted further down
# in this file reach the application log.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ==================== Model Loading & Caching ====================
# show_spinner=False: the call site already wraps this call in its own spinner.
@st.cache_resource(show_spinner=False)
def load_models():
    """Load and cache the captioning, story-generation and text-to-speech models.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    built once and reused across Streamlit reruns instead of being
    re-instantiated on every button press. Clearing ``st.cache_resource``
    forces a fresh load on the next call.

    Returns:
        A 4-tuple ``(caption_model, story_model, tts_model, tts_tokenizer)``:
        an ``image-to-text`` pipeline, a ``text-generation`` pipeline, a
        ``text-to-audio`` pipeline, and the tokenizer loaded from the same
        checkpoint as the text-to-audio pipeline.
    """
    use_cuda = torch.cuda.is_available()
    # transformers pipelines take a CUDA device index, or -1 for CPU.
    device = 0 if use_cuda else -1
    tts_checkpoint = "Chan-Y/speecht5_finetuned_tr_commonvoice"

    logger.info("Loading image captioning model...")
    caption_model = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )

    logger.info("Loading story generation model...")
    story_model = pipeline(
        task="text-generation",
        model="Tincando/fiction_story_generator",
        device=device,
        # Reduced precision only when a GPU is present; full precision on CPU.
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
    )

    logger.info("Loading text-to-speech model...")
    tts_model = pipeline(
        task="text-to-audio",
        model=tts_checkpoint,
        device=device,
    )
    tts_tokenizer = AutoTokenizer.from_pretrained(tts_checkpoint)

    return caption_model, story_model, tts_model, tts_tokenizer
# ==================== Streamlit Page Configuration ====================
# NOTE(review): the leading characters in several UI strings in this file look
# like mis-encoded emoji; they are kept as found — confirm the intended glyphs.
st.set_page_config(
    page_title="π§Έ AI Story Generator Pro",
    page_icon="π",
    layout="wide",
    initial_sidebar_state="expanded"
)

# ==================== Sidebar Settings ====================
# These four values are read later by the generation pipeline.
with st.sidebar:
    st.title("βοΈ Generation Settings")
    # Sampling temperature passed to the story model.
    temperature = st.slider("Creativity Level", 0.5, 1.5, 0.85, step=0.05)
    # Passed to the story model as max_length.
    max_length = st.slider("Story Length", 100, 500, 200)
    # Interpolated into the story prompt.
    story_style = st.selectbox("Narrative Style", ["Fairy Tale", "Sci-Fi", "Adventure"])
    # Narration speed multiplier (1.0 = unchanged).
    voice_speed = st.slider("Speech Rate", 0.5, 2.0, 1.0)

# ==================== Main Interface ====================
st.title("πΌοΈ AI-Powered Story Generator")
st.write("Transform images into immersive stories with audio narration")

# ==================== File Upload ====================
uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
# Once an image is uploaded: show it, and on request run the
# caption -> story -> narration pipeline and present the results.
if uploaded_file:
    # ==================== Image Display ====================
    col1, col2 = st.columns([1, 2])
    with col1:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

    # ==================== Generation Pipeline ====================
    if st.button("Generate Story", type="primary"):
        try:
            progress_bar = st.progress(0)

            # Model Initialization
            with st.spinner("π Initializing AI models..."):
                # The tokenizer is not needed here: the text-to-audio pipeline
                # tokenizes raw text itself.
                caption_model, story_model, tts_model, _ = load_models()
                xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
                # Add a leading batch dimension to the speaker embedding.
                speaker_emb = torch.tensor(xvectors[7306]["xvector"]).unsqueeze(0)
            progress_bar.progress(20)

            # Image Captioning
            with st.spinner("π· Analyzing visual content..."):
                caption_result = caption_model(image)
                caption = caption_result[0]['generated_text']
            progress_bar.progress(40)

            # Story Generation
            with st.spinner("βοΈ Crafting narrative..."):
                prompt = f"Write a children's story in {story_style} style about: {caption}"
                generated = story_model(
                    prompt,
                    temperature=temperature,
                    max_length=max_length,
                    do_sample=True
                )[0]['generated_text']
                # Ensure proper punctuation: drop a trailing unfinished sentence.
                trimmed = re.sub(r'[^.!?]+$', '', generated)
                # If the text contains no sentence terminator at all, trimming
                # would erase everything; fall back to the untrimmed text.
                story = trimmed if trimmed.strip() else generated.strip()
                if not story:
                    raise ValueError("The story model returned no text.")
            progress_bar.progress(70)

            # Audio Synthesis
            with st.spinner("π Generating narration..."):
                # Synthesize sentence by sentence; skip blank fragments.
                sentences = [
                    part for part in re.split(r'(?<=[.!?]) +', story) if part.strip()
                ]
                audio_arrays = []
                sample_rate = 16000
                for sentence in sentences:
                    # The pipeline takes raw text and returns a dict whose
                    # "audio" entry is already a NumPy array.
                    speech = tts_model(
                        sentence,
                        forward_params={"speaker_embeddings": speaker_emb}
                    )
                    audio_arrays.append(np.asarray(speech["audio"]).reshape(-1))
                    sample_rate = speech["sampling_rate"]
                if not audio_arrays:
                    raise ValueError("The generated story contained no text to narrate.")
                combined = np.concatenate(audio_arrays)

                # The TTS model has no speed control, so the slider is applied
                # through the playback sample rate (this also shifts pitch;
                # 1.0 leaves the audio unchanged).
                playback_rate = int(round(sample_rate * voice_speed))

                # Build the WAV in memory so concurrent sessions do not share
                # (and overwrite) one file on disk.
                wav_buffer = io.BytesIO()
                sf.write(wav_buffer, combined, samplerate=playback_rate, format="WAV")
                audio_bytes = wav_buffer.getvalue()
            progress_bar.progress(100)

            # ==================== Results Display ====================
            with col2:
                st.subheader("π Generated Story")
                st.success(story)
                st.subheader("π Audio Narration")
                st.audio(audio_bytes, format="audio/wav")

                # Download Options
                st.download_button(
                    label="Download Story Text",
                    data=story,
                    file_name="generated_story.txt",
                    mime="text/plain"
                )
                st.download_button(
                    label="Download Audio File",
                    data=audio_bytes,
                    file_name="story_audio.wav",
                    mime="audio/wav"
                )
        except Exception as e:
            # Top-level UI boundary: record the traceback, then show the error
            # and offer a retry that drops the cached models.
            logger.exception("Story generation failed")
            st.error(f"Generation failed: {str(e)}")
            st.button("Retry", on_click=st.cache_resource.clear)