# Picture to Story Magic — Streamlit app (Hugging Face Space)
| import io | |
| import wave | |
| import streamlit as st | |
| from transformers import pipeline | |
| from PIL import Image | |
| import numpy as np | |
| import time | |
| import threading | |
# ─── 1) MODEL LOADING (cached) ─────────────────
@st.cache_resource
def get_image_captioner(model_name="Salesforce/blip-image-captioning-base"):
    """Return the image-to-text (captioning) pipeline, loaded once per session.

    The original comment claimed the loader was cached, but no decorator was
    applied — @st.cache_resource makes that true, so the model is not
    re-downloaded/re-instantiated on every Streamlit rerun.
    """
    return pipeline("image-to-text", model=model_name, device="cpu")
@st.cache_resource
def get_story_pipe(model_name="google/flan-t5-base"):
    """Return the text2text-generation pipeline used for story writing.

    Cached with @st.cache_resource (previously claimed in comments but never
    applied) so the model loads only once across Streamlit reruns.
    """
    return pipeline("text2text-generation", model=model_name, device="cpu")
@st.cache_resource
def get_tts_pipe(model_name="facebook/mms-tts-eng"):
    """Return the text-to-speech pipeline, loaded once and cached.

    @st.cache_resource was missing, so every rerun reloaded the TTS model;
    the decorator makes the documented caching actually happen.
    """
    return pipeline("text-to-speech", model=model_name, device="cpu")
| # βββ 2) TRANSFORM FUNCTIONS ββββββββββββββββ | |
# ─── 2) TRANSFORM FUNCTIONS ─────────────────
def part1_image_to_text(pil_img, captioner):
    """Return the model-generated caption for *pil_img*, or "" when the
    captioner produces no result."""
    predictions = captioner(pil_img)
    if not predictions:
        return ""
    # Pipelines yield a list of dicts; the caption lives under "generated_text"
    return predictions[0].get("generated_text", "")
def part2_text_to_story(
    caption: str,
    story_pipe,
    target_words: int = 100,
    max_length: int = 100,
    min_length: int = 80,
    do_sample: bool = True,
    top_k: int = 100,
    top_p: float = 0.9,
    temperature: float = 0.7,
    repetition_penalty: float = 1.1,
    no_repeat_ngram_size: int = 4
) -> str:
    """Turn an image caption into a short story via a text2text pipeline.

    The prompt asks for roughly ``target_words`` words; all remaining
    parameters are forwarded to the generation call. The output is stripped
    of any echoed prompt and cut at the final full stop so it ends on a
    complete sentence.
    """
    prompt = (
        f"Write a vivid, imaginative short story of about {target_words} words "
        f"describing this scene: {caption}"
    )
    generated = story_pipe(
        prompt,
        max_length=max_length,                      # hard cap on output length
        min_length=min_length,                      # force enough content
        do_sample=do_sample,                        # sampling => creative text
        top_k=top_k,                                # top-k candidate pool
        top_p=top_p,                                # nucleus sampling
        temperature=temperature,                    # randomness knob
        repetition_penalty=repetition_penalty,      # discourage repeats
        no_repeat_ngram_size=no_repeat_ngram_size,  # block repeated n-grams
        early_stopping=False                        # run to max_length
    )
    text = generated[0].get("generated_text", "").strip()
    if not text:
        return ""
    # Some models echo the instruction back — drop it when present
    if text.lower().startswith(prompt.lower()):
        story = text[len(prompt):].strip()
    else:
        story = text
    # End at the last complete sentence, if there is one
    cut = story.rfind(".")
    return story[:cut + 1] if cut != -1 else story
def part3_text_to_speech_bytes(text: str, tts_pipe) -> bytes:
    """Synthesize *text* with the TTS pipeline and return 16-bit WAV bytes.

    Accepts pipeline outputs shaped as either a dict or a one-element list
    of dicts with "audio" (float np.ndarray, mono (samples,) or
    (channels, samples)) and "sampling_rate" (int, Hz).
    """
    out = tts_pipe(text)
    # Some pipelines wrap the result in a one-element list
    if isinstance(out, list):
        out = out[0]
    audio_array = out["audio"]
    rate = out["sampling_rate"]
    # WAV frames are sample-major (interleaved), so channel-major data
    # must be transposed to (samples, channels)
    data = audio_array.T if audio_array.ndim == 2 else audio_array
    # Clip to [-1, 1] BEFORE scaling: the old code let out-of-range floats
    # wrap around in the int16 cast, producing loud garbage instead of
    # saturating at full scale
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    buffer = io.BytesIO()
    # Context manager guarantees the WAV header is finalized even on error
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1 if data.ndim == 1 else data.shape[1])
        wf.setsampwidth(2)       # 2 bytes per sample = 16-bit PCM
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()
# ─── 3) STREAMLIT UI ────────────────────────────
# Kid-friendly page setup: centered layout, magic-themed title and icon
st.set_page_config(
    layout="centered",
    page_icon="β¨",
    page_title="Picture to Story Magic",
)
# Apply custom CSS for a colorful, engaging, and readable interface.
# Fix: the button rule used the nonexistent property "button-color", which
# browsers silently ignore — "color" is the real property for text color.
st.markdown("""
<style>
.main {
background-color: #e6f3ff; /* Light blue background for main area */
padding: 20px;
border-radius: 15px;
}
.stButton>button {
background-color: #ffcccb; /* Pink button background */
color: #000000;
border-radius: 10px;
border: 2px solid #ff9999; /* Red border */
font-size: 18px;
font-weight: bold;
padding: 10px 20px;
transition: all 0.3s; /* Smooth hover effect */
}
.stButton>button:hover {
background-color: #ff9999; /* Darker pink on hover */
color: #ffffff;
transform: scale(1.05); /* Slight zoom on hover */
}
.stFileUploader {
background-color: #ffb300; /* Orange uploader background */
border: 2px dashed #ff8c00; /* Dashed orange border */
border-radius: 10px;
padding: 10px;
}
.stFileUploader div[role="button"] {
background-color: #f0f0f0; /* Light gray button */
border-radius: 10px;
padding: 10px;
}
.stFileUploader div[role="button"] > div {
color: #000000 !important; /* Black text for readability */
font-size: 16px;
}
.stFileUploader button {
background-color: #ffca28 !important; /* Yellow button */
color: #000000 !important;
border-radius: 8px !important;
border: 2px solid #ffb300 !important; /* Orange border */
padding: 5px 15px !important;
font-weight: bold !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.2) !important; /* Subtle shadow */
}
.stFileUploader button:hover {
background-color: #ff8c00 !important; /* Orange on hover */
color: #000000 !important;
}
.stImage {
border: 3px solid #81c784; /* Green border for images */
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Soft shadow */
}
.section-header {
background-color: #b3e5fc; /* Light blue header background */
padding: 10px;
border-radius: 10px;
text-align: center;
font-size: 24px;
font-weight: bold;
color: #000000;
margin-bottom: 10px;
}
.caption-box, .story-box {
background-color: #f0f4c3; /* Light yellow for text boxes */
padding: 15px;
border-radius: 10px;
border: 2px solid #d4e157; /* Green-yellow border */
margin-bottom: 20px;
color: #000000;
}
.caption-box b, .story-box b {
color: #000000; /* Black for bold text */
}
.stProgress > div > div {
background-color: #81c784; /* Green progress bar */
}
</style>
""", unsafe_allow_html=True)
# Title banner with the app's magical theme
st.markdown("<div class='section-header'>Picture to Story Magic! β¨</div>", unsafe_allow_html=True)

# Step 1: let the user pick an image
with st.container():
    st.markdown("<div class='section-header'>1οΈβ£ Pick a Fun Picture! πΌοΈ</div>", unsafe_allow_html=True)
    uploaded = st.file_uploader(
        "Choose a picture to start the magic! π",
        type=["jpg", "jpeg", "png"],
    )
    # Nothing to do until a file arrives — nudge the user and halt this run
    if not uploaded:
        st.info("Upload a picture, and let's make a story! π")
        st.stop()
    with st.spinner("Looking at your picture..."):
        pil_img = Image.open(uploaded)
        # Scale the preview to the container width
        st.image(pil_img, use_container_width=True)
# Caption generation section
with st.container():
    st.markdown("<div class='section-header'>2οΈβ£ What's in the Picture? π§</div>", unsafe_allow_html=True)
    captioner = get_image_captioner()  # Captioning model
    progress_bar = st.progress(0)
    result = [None]  # Caption written by the worker (threads can't return values)
    error = [None]   # Exception captured in the worker, surfaced after join
    def run_caption():
        # Run the model in a worker thread so the UI thread can animate progress;
        # capture failures instead of losing them (the old code silently left
        # result[0] as None and rendered "None" in the caption box)
        try:
            result[0] = part1_image_to_text(pil_img, captioner)
        except Exception as exc:
            error[0] = exc
    with st.spinner("Figuring out what's in your picture..."):
        thread = threading.Thread(target=run_caption)
        thread.start()
        # Advance the bar while the model actually runs, holding at 99% until
        # it finishes (the old fixed loop always slept the full ~5 s even when
        # the model was already done)
        pct = 0
        while thread.is_alive():
            pct = min(pct + 1, 99)
            progress_bar.progress(pct)
            time.sleep(0.05)
        thread.join()
        progress_bar.progress(100)
    progress_bar.empty()  # Clear progress bar
    if error[0] is not None:
        st.error(f"Could not describe the picture: {error[0]}")
        st.stop()
    caption = result[0]
    # Display the generated caption in a styled box
    st.markdown(f"<div class='caption-box'><b>Picture Description:</b><br>{caption}</div>", unsafe_allow_html=True)
# Story and audio generation section
with st.container():
    st.markdown("<div class='section-header'>3οΈβ£ Your Story and Audio! π΅</div>", unsafe_allow_html=True)
    # --- Story generation ---
    story_pipe = get_story_pipe()  # Story model
    progress_bar = st.progress(0)
    result = [None]  # Story written by the worker thread
    error = [None]   # Exception captured in the worker, surfaced after join
    def run_story():
        # Worker thread keeps the UI responsive; capture failures instead of
        # silently rendering "None" as the story
        try:
            result[0] = part2_text_to_story(caption, story_pipe)
        except Exception as exc:
            error[0] = exc
    with st.spinner("Writing a super cool story..."):
        thread = threading.Thread(target=run_story)
        thread.start()
        # Track the real work: hold at 99% until the thread finishes (the old
        # fixed loop always slept ~7 s regardless of actual generation time)
        pct = 0
        while thread.is_alive():
            pct = min(pct + 1, 99)
            progress_bar.progress(pct)
            time.sleep(0.07)
        thread.join()
        progress_bar.progress(100)
    progress_bar.empty()
    if error[0] is not None:
        st.error(f"Could not write the story: {error[0]}")
        st.stop()
    story = result[0]
    # Display the generated story in a styled box
    st.markdown(f"<div class='story-box'><b>Your Cool Story! π</b><br>{story}</div>", unsafe_allow_html=True)
    # --- Text-to-speech conversion ---
    tts_pipe = get_tts_pipe()  # TTS model
    progress_bar = st.progress(0)
    result = [None]  # WAV bytes produced by the worker thread
    error = [None]
    def run_tts():
        # Same worker-thread pattern for audio synthesis
        try:
            result[0] = part3_text_to_speech_bytes(story, tts_pipe)
        except Exception as exc:
            error[0] = exc
    with st.spinner("Turning your story into sound..."):
        thread = threading.Thread(target=run_tts)
        thread.start()
        # Hold at 99% until synthesis completes (was a fixed ~10 s sleep)
        pct = 0
        while thread.is_alive():
            pct = min(pct + 1, 99)
            progress_bar.progress(pct)
            time.sleep(0.10)
        thread.join()
        progress_bar.progress(100)
    progress_bar.empty()
    if error[0] is not None:
        st.error(f"Could not create the audio: {error[0]}")
        st.stop()
    audio_bytes = result[0]
    # Play the generated audio in the UI
    st.audio(audio_bytes, format="audio/wav")