"""Streamlit app: turn an uploaded image into a short story with audio narration.

Pipeline: image captioning (BLIP) -> story text generation -> text-to-speech
(SpeechT5-based). Models are loaded once per process via ``st.cache_resource``.
"""

import io
import logging
import re

import numpy as np
import soundfile as sf
import streamlit as st
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoTokenizer, pipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

TTS_MODEL_ID = "Chan-Y/speecht5_finetuned_tr_commonvoice"
# Row of the "Matthijs/cmu-arctic-xvectors" validation split used as the narrator voice.
SPEAKER_EMBEDDING_INDEX = 7306
# Sentence boundary = whitespace (spaces or newlines) following ".", "!" or "?".
_SENTENCE_SPLIT_RE = re.compile(r"(?<=[.!?])\s+")
# Any trailing text after the last ".", "!" or "?" (an unfinished sentence).
_TRAILING_FRAGMENT_RE = re.compile(r"[^.!?]+$")


# ==================== Model Loading & Caching ====================
@st.cache_resource(show_spinner=False)
def load_models():
    """Preload and cache all AI models.

    Returns:
        Tuple ``(caption_model, story_model, tts_model, tts_tokenizer)``.
        ``tts_tokenizer`` is still returned for backward compatibility; the
        text-to-audio pipeline tokenizes its raw-text input itself.
    """
    use_cuda = torch.cuda.is_available()
    device = 0 if use_cuda else -1

    logger.info("Loading image captioning model...")
    caption_model = pipeline(
        task="image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=device,
    )

    logger.info("Loading story generation model...")
    story_model = pipeline(
        task="text-generation",
        model="Tincando/fiction_story_generator",
        device=device,
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
    )

    logger.info("Loading text-to-speech model...")
    # NOTE(review): the model name suggests fine-tuning on Turkish Common Voice;
    # confirm it narrates English stories acceptably.
    tts_model = pipeline(task="text-to-audio", model=TTS_MODEL_ID, device=device)
    tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)

    return caption_model, story_model, tts_model, tts_tokenizer


@st.cache_resource(show_spinner=False)
def load_speaker_embedding(index=SPEAKER_EMBEDDING_INDEX):
    """Load one CMU Arctic x-vector as a speaker-embedding tensor.

    A leading batch dimension is added with ``unsqueeze(0)``. Cached so the
    dataset is fetched once per process rather than on every button click.
    """
    logger.info("Loading speaker embedding...")
    xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    return torch.tensor(xvectors[index]["xvector"]).unsqueeze(0)


# ==================== Text & Audio Helpers ====================
def _trim_incomplete_sentence(text):
    """Drop any trailing fragment after the last ".", "!" or "?".

    If the text contains no sentence terminator at all it is returned
    (stripped) unchanged, instead of being erased entirely.
    """
    trimmed = _TRAILING_FRAGMENT_RE.sub("", text).strip()
    return trimmed or text.strip()


def _split_sentences(text):
    """Split *text* into non-empty, stripped sentences for chunked synthesis."""
    return [part.strip() for part in _SENTENCE_SPLIT_RE.split(text) if part.strip()]


def _change_speed(audio, speed):
    """Time-scale a 1-D waveform by *speed* using linear resampling.

    SpeechT5's ``generate()`` has no speaking-rate argument, so the rate is
    applied to the synthesized waveform instead. This also shifts pitch
    (faster -> higher). ``speed == 1.0`` returns the input unchanged.
    """
    if speed == 1.0 or audio.size < 2:
        return audio
    n_out = max(1, int(round(audio.size / speed)))
    src_positions = np.arange(audio.size)
    dst_positions = np.linspace(0, audio.size - 1, n_out)
    return np.interp(dst_positions, src_positions, audio).astype(audio.dtype)


def _synthesize(tts_model, text, speaker_embedding, speed):
    """Narrate *text* sentence by sentence.

    Returns:
        ``(waveform, sampling_rate)`` with a 1-D float32 waveform.

    Raises:
        ValueError: if *text* contains nothing to narrate.
    """
    sentences = _split_sentences(text)
    if not sentences:
        raise ValueError("The generated story is empty; nothing to narrate.")

    pieces = []
    sampling_rate = 16000  # fallback; overwritten by the pipeline's reported rate
    for sentence in sentences:
        # The text-to-audio pipeline expects raw text and tokenizes it itself.
        speech = tts_model(
            sentence,
            forward_params={"speaker_embeddings": speaker_embedding},
        )
        # Flatten so per-sentence arrays concatenate along the time axis.
        pieces.append(np.asarray(speech["audio"], dtype=np.float32).reshape(-1))
        sampling_rate = speech["sampling_rate"]

    waveform = _change_speed(np.concatenate(pieces), speed)
    return waveform, sampling_rate


def _to_wav_bytes(waveform, sampling_rate):
    """Encode *waveform* as WAV bytes in memory (no shared file on disk)."""
    buffer = io.BytesIO()
    sf.write(buffer, waveform, samplerate=sampling_rate, format="WAV")
    return buffer.getvalue()


# ==================== Streamlit Page Configuration ====================
st.set_page_config(
    page_title="🧸 AI Story Generator Pro",
    page_icon="📖",
    layout="wide",
    initial_sidebar_state="expanded",
)

# ==================== Sidebar Settings ====================
with st.sidebar:
    st.title("⚙️ Generation Settings")
    temperature = st.slider("Creativity Level", 0.5, 1.5, 0.85, step=0.05)
    max_length = st.slider("Story Length", 100, 500, 200)
    story_style = st.selectbox("Narrative Style", ["Fairy Tale", "Sci-Fi", "Adventure"])
    voice_speed = st.slider("Speech Rate", 0.5, 2.0, 1.0)

# ==================== Main Interface ====================
st.title("🖼️ AI-Powered Story Generator")
st.write("Transform images into immersive stories with audio narration")

# ==================== File Upload ====================
uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])

if uploaded_file:
    # ==================== Image Display ====================
    col1, col2 = st.columns([1, 2])

    with col1:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)

    # ==================== Generation Pipeline ====================
    if st.button("Generate Story", type="primary"):
        try:
            progress_bar = st.progress(0)

            # Model Initialization
            with st.spinner("🔄 Initializing AI models..."):
                caption_model, story_model, tts_model, _ = load_models()
                speaker_emb = load_speaker_embedding()
            progress_bar.progress(20)

            # Image Captioning (PNG uploads may be RGBA/palette; captioner wants RGB)
            with st.spinner("📷 Analyzing visual content..."):
                caption_result = caption_model(image.convert("RGB"))
                caption = caption_result[0]["generated_text"]
            progress_bar.progress(40)

            # Story Generation
            with st.spinner("✍️ Crafting narrative..."):
                prompt = f"Write a children's story in {story_style} style about: {caption}"
                generated = story_model(
                    prompt,
                    temperature=temperature,
                    # max_new_tokens bounds only the story; max_length also
                    # counted the prompt tokens.
                    max_new_tokens=max_length,
                    do_sample=True,
                    # Return only the continuation, not the instruction prompt.
                    return_full_text=False,
                )[0]["generated_text"]
                # Ensure the story ends on complete sentence punctuation.
                story = _trim_incomplete_sentence(generated)
            progress_bar.progress(70)

            # Audio Synthesis
            with st.spinner("🔊 Generating narration..."):
                waveform, sampling_rate = _synthesize(
                    tts_model, story, speaker_emb, voice_speed
                )
                audio_bytes = _to_wav_bytes(waveform, sampling_rate)
            progress_bar.progress(100)

            # ==================== Results Display ====================
            with col2:
                st.subheader("📖 Generated Story")
                st.success(story)

                st.subheader("🔊 Audio Narration")
                st.audio(audio_bytes, format="audio/wav")

                # Download Options
                st.download_button(
                    label="Download Story Text",
                    data=story,
                    file_name="generated_story.txt",
                    mime="text/plain",
                )
                st.download_button(
                    label="Download Audio File",
                    data=audio_bytes,
                    file_name="story_audio.wav",
                    mime="audio/wav",
                )

        except Exception as e:
            # Top-level UI boundary: log the traceback, show the user a message.
            logger.exception("Story generation failed")
            st.error(f"Generation failed: {str(e)}")
            st.button("Retry", on_click=st.cache_resource.clear)