# ImageToSound: src/streamlit_app.py
import streamlit as st
from PIL import Image
import io
import soundfile as sf
import numpy as np
import torch
from transformers import pipeline
from diffusers import StableAudioPipeline
# --- Configuration ---
# Determine the optimal device for model inference
# Prioritize CUDA (NVIDIA GPUs), then MPS (Apple Silicon), fallback to CPU
DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
# Use float16 for reduced memory and faster inference on compatible hardware (GPU/MPS)
# Fallback to float32 for CPU for better stability
TORCH_DTYPE = torch.float16 if DEVICE in ["cuda", "mps"] else torch.float32
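# Note: float16 on MPS can be numerically fragile for some models; if you see
# NaNs or silent/garbled output on Apple Silicon, torch.float32 is a safe fallback.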
# --- Cached Model Loading Functions ---
@st.cache_resource(show_spinner="Loading Image Captioning Model (BLIP)...")
def load_blip_model():
"""
Loads the BLIP image captioning model using Hugging Face transformers pipeline.
The model is cached to prevent reloading on every Streamlit rerun.
"""
try:
captioner = pipeline(
"image-to-text",
model="Salesforce/blip-image-captioning-base",
torch_dtype=TORCH_DTYPE,
device=DEVICE
)
return captioner
except Exception as e:
st.error(f"Failed to load BLIP model: {e}")
return None
@st.cache_resource(show_spinner="Loading Audio Generation Model (Stable Audio Open Small)...")
def load_stable_audio_model():
"""
    Loads the Stable Audio Open 1.0 pipeline using Hugging Face diffusers.
The pipeline is cached to prevent reloading on every Streamlit rerun.
"""
try:
        audio_pipeline = StableAudioPipeline.from_pretrained(
            "stabilityai/stable-audio-open-1.0",
            torch_dtype=TORCH_DTYPE
        ).to(DEVICE)
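        # If GPU memory is tight, diffusers' model CPU offload is an alternative
        # to .to(DEVICE) (requires the accelerate package; not enabled here):
        # audio_pipeline.enable_model_cpu_offload()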
return audio_pipeline
except Exception as e:
st.error(f"Failed to load Stable Audio model: {e}")
return None
# --- Audio Conversion Utility ---
def convert_numpy_to_wav_bytes(audio_array: np.ndarray, sample_rate: int) -> bytes:
"""
Converts a NumPy audio array to an in-memory WAV byte stream.
This avoids writing temporary files to disk, which is efficient and
suitable for ephemeral environments like Hugging Face Spaces.
"""
byte_io = io.BytesIO()
    # Stable Audio Open's diffusers output is (channels, frames), while
    # soundfile expects (frames, channels) for multi-channel audio, so
    # transpose 2-channel (stereo) arrays before writing.
    if audio_array.ndim == 2 and audio_array.shape[0] == 2:
        audio_array = audio_array.T  # (channels, frames) -> (frames, channels)
    # Write the NumPy array to the in-memory BytesIO object as a WAV file
    sf.write(byte_io, audio_array, sample_rate, format='WAV', subtype='FLOAT')
    # IMPORTANT: reset the stream position to the beginning before reading
byte_io.seek(0)
return byte_io.read()
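# Illustrative example of the shapes involved: a 2-second stereo clip at 44.1 kHz
# arrives here as an array of shape (2, 88200) and is transposed to (88200, 2)
# before being written out as WAV bytes.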
# --- Streamlit App Layout ---
st.set_page_config(layout="centered", page_title="Image-to-Soundscape Generator")
st.title("🏞️ Image-to-Soundscape Generator 🎢")
st.markdown("Upload a landscape image, and let AI transform it into a unique soundscape!")
# Initialize session state for persistence across reruns
if "audio_bytes" not in st.session_state:
st.session_state.audio_bytes = None
if "image_uploaded" not in st.session_state:
st.session_state.image_uploaded = False
# --- UI Components ---
uploaded_file = st.file_uploader("Choose a landscape image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
st.session_state.image_uploaded = True
image = Image.open(uploaded_file).convert("RGB") # Ensure image is in RGB format
    st.image(image, caption="Uploaded Image", use_container_width=True)  # use_container_width replaces the deprecated use_column_width
# Button to trigger the generation pipeline
if st.button("Generate Soundscape"):
st.session_state.audio_bytes = None # Clear previous audio
        with st.spinner("Generating soundscape... This may take a moment."):
try:
# 1. Load BLIP model and generate caption (hidden from user)
captioner = load_blip_model()
if captioner is None:
st.error("Image captioning model could not be loaded. Please try again.")
st.session_state.image_uploaded = False # Reset to allow re-upload
st.stop()
# Generate caption
# The BLIP pipeline expects a PIL Image object directly
caption_results = captioner(image)
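                # The pipeline returns a list of dicts, e.g. [{'generated_text': '...'}]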
# Extract the generated text from the pipeline's output
generated_caption = caption_results[0]['generated_text']
# Optional: Enhance prompt for soundscape generation
# This helps guide the audio model towards environmental sounds
soundscape_prompt = f"A soundscape of {generated_caption}"
# 2. Load Stable Audio model and generate audio
audio_pipeline = load_stable_audio_model()
if audio_pipeline is None:
st.error("Audio generation model could not be loaded. Please try again.")
st.session_state.image_uploaded = False # Reset to allow re-upload
st.stop()
                # Generate audio with parameters tuned for speed over quality.
                # num_inference_steps: lower is faster, higher improves quality
                # audio_end_in_s: clip length in seconds (stable-audio-open-1.0
                #   supports clips up to ~47 s)
                # negative_prompt: steers generation away from low-quality output
audio_output = audio_pipeline(
prompt=soundscape_prompt,
                    num_inference_steps=10,  # low step count, tuned for faster generation
                    audio_end_in_s=5,  # 5 seconds of audio
                    negative_prompt="low quality, average quality, distorted"
)
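                # Optional (not enabled here): pass a seeded generator for
                # reproducible output, e.g. generator=torch.Generator(DEVICE).manual_seed(0)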
                # The pipeline returns a torch tensor of shape (batch, channels, frames);
                # take the first waveform and convert it to a float32 NumPy array
                audio_numpy_array = audio_output.audios[0].float().cpu().numpy()
                # The model's VAE operates at 44.1 kHz
                sample_rate = audio_pipeline.vae.config.sampling_rate
# 3. Convert NumPy array to WAV bytes and store in session state
st.session_state.audio_bytes = convert_numpy_to_wav_bytes(audio_numpy_array, sample_rate)
st.success("Soundscape generated successfully!")
except Exception as e:
st.error(f"An error occurred during generation: {e}") #
st.session_state.audio_bytes = None # Clear any partial audio
st.session_state.image_uploaded = False # Reset to allow re-upload
st.exception(e) # Display full traceback for debugging
# Display generated soundscape if available in session state
if st.session_state.audio_bytes:
st.subheader("Generated Soundscape:")
    st.audio(st.session_state.audio_bytes, format='audio/wav')
st.markdown("You can download the audio using the controls above.")
# Reset button for new image upload
if st.session_state.image_uploaded and st.button("Upload New Image"):
st.session_state.audio_bytes = None
st.session_state.image_uploaded = False
st.rerun() # Rerun the app to clear the file uploader