justREE committed
Commit 77c4802 · verified · 1 Parent(s): 936f674

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +208 -40
src/streamlit_app.py CHANGED
@@ -1,40 +1,208 @@
- import altair as alt
- import numpy as np
- import pandas as pd
- import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ # src/streamlit_app.py
+
+ import io                          # in-memory binary streams
+ import wave                        # writing WAV audio files
+ import re                          # regular-expression utilities
+ import streamlit as st             # Streamlit UI library
+ from transformers import pipeline  # Hugging Face inference pipelines
+ from PIL import Image              # Pillow, for image loading
+ import numpy as np                 # numerical operations, especially array handling
+
+ # 1) CACHE & LOAD MODELS
+ @st.cache_resource(show_spinner=False)
+ def load_captioner():
+     # Loads the BLIP image-to-text model; cached so it loads only once.
+     # Returns a captioner: (image: PIL.Image) -> List[Dict].
+     return pipeline(
+         "image-to-text",
+         model="Salesforce/blip-image-captioning-base",
+         device="cpu",  # change to "cuda" if a GPU is available
+     )
+
+ @st.cache_resource(show_spinner=False)
+ def load_story_pipe():
+     # Loads the FLAN-T5 text-to-text model for story generation; cached once.
+     # Returns a story_pipe: (prompt: str, **kwargs) -> List[Dict].
+     return pipeline(
+         "text2text-generation",
+         model="google/flan-t5-base",
+         device="cpu",  # change to "cuda" if a GPU is available
+     )
+
+ @st.cache_resource(show_spinner=False)
+ def load_tts_pipe():
+     # Loads the Meta MMS-TTS text-to-speech model; cached once.
+     # Returns a tts_pipe: (text: str) -> dict (or one-element list)
+     # with "audio" and "sampling_rate" keys.
+     return pipeline(
+         "text-to-speech",
+         model="facebook/mms-tts-eng",
+         device="cpu",  # change to "cuda" if a GPU is available
+     )
+
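Aside: the three loaders above pin device="cpu". A minimal sketch of selecting a GPU automatically when one is present (an assumption on top of this commit; it requires torch, and relies on transformers pipelines accepting an integer CUDA device index):

    import torch
    device = 0 if torch.cuda.is_available() else "cpu"  # 0 = first CUDA device
    captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", device=device)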
+ # 2) HELPER FUNCTIONS
+ def sentence_case(text: str) -> str:
+     # Splits text into sentences on .!? delimiters, capitalizes the first
+     # character of each sentence, then rejoins into a single string.
+     parts = re.split(r'([.!?])', text)  # e.g. "hi. there!" -> ["hi", ".", " there", "!", ""]
+     out = []
+     for i in range(0, len(parts) - 1, 2):
+         sentence = parts[i].strip().capitalize()  # capitalize first letter
+         delimiter = parts[i + 1]                  # sentence-ending punctuation
+         # Ensure a space before the sentence if it wasn't the very first part
+         if out and not sentence.startswith(' ') and out[-1][-1] not in '.!?':
+             out.append(f" {sentence}{delimiter}")
+         else:
+             out.append(f"{sentence}{delimiter}")
+
+     # If trailing text without punctuation exists, capitalize and append it.
+     if len(parts) % 2:
+         last = parts[-1].strip().capitalize()
+         if last:
+             # Ensure a space before it if needed
+             if out and not last.startswith(' ') and out[-1][-1] not in '.!?':
+                 out.append(f" {last}")
+             else:
+                 out.append(last)
+
+     # Collapse any multiple spaces introduced by the split/join
+     return " ".join(" ".join(out).split())
+
+
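A hand-worked example of the helper above (illustrative, not part of the commit):

    sentence_case("hello there. how are you?")  # -> "Hello there. How are you?"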
+ def caption_image(img: Image.Image, captioner) -> str:
+     # Given a PIL image and a captioning pipeline, returns a single-line caption.
+     results = captioner(img)  # run the model
+     if not results:
+         return ""
+     # Extract the "generated_text" field from the first result
+     return results[0].get("generated_text", "")
+
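For context: an image-to-text pipeline returns a list of dicts, which is why the helper reads results[0]. A representative return shape (the caption text here is invented):

    [{"generated_text": "a dog sitting on a wooden bench"}]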
+ def story_from_caption(caption: str, pipe) -> str:
+     # Given a caption string and a text2text pipeline, returns a ~100-word story.
+     prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}"
+     results = pipe(
+         prompt,
+         max_length=120,          # cap on generated length (tokens)
+         min_length=80,           # minimum generated tokens
+         do_sample=True,          # enable sampling
+         top_k=100,               # sample from the top-k tokens
+         top_p=0.9,               # nucleus-sampling threshold
+         temperature=0.7,         # sampling temperature
+         repetition_penalty=1.1,  # discourage repetition
+         no_repeat_ngram_size=4,  # block repeated n-grams
+         early_stopping=False,
+     )
+     raw = results[0]["generated_text"].strip()  # full generated text
+     # Strip the prompt if the model echoes it back (case-insensitive comparison)
+     if raw.lower().startswith(prompt.lower()):
+         raw = raw[len(prompt):].strip()
+
+     # Trim to the last complete sentence ending in . ! or ?
+     match = re.search(r'[.!?]', raw[::-1])  # first punctuation from the end
+     if match:
+         raw = raw[:len(raw) - match.start()]  # keep up to and including it
+     elif len(raw) > 80:
+         # No punctuation found but the story is long: trim to a reasonable length
+         raw = raw[:80] + "..."
+
+     return sentence_case(raw)
+
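To see the reverse-search trim at work (worked by hand): for raw = "A tale. And then", the reversed string is "neht dnA .elat A", the first sentence-ending match sits at offset 9 from the end, and slicing keeps "A tale." exactly:

    raw = "A tale. And then"
    raw[:len(raw) - re.search(r'[.!?]', raw[::-1]).start()]  # -> "A tale."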
+ def tts_bytes(text: str, tts_pipe) -> bytes:
+     # Given a text string and a TTS pipeline, returns WAV-format bytes.
+     # Clean up text for TTS: remove leading/trailing quotes, etc.
+     cleaned_text = re.sub(r'^["\']|["\']$', '', text).strip()
+     # Basic punctuation cleaning (optional, depending on the TTS model)
+     cleaned_text = re.sub(r'\.{2,}', '.', cleaned_text)  # collapse runs of periods
+     cleaned_text = cleaned_text.replace('…', '...')      # ellipsis char -> dots
+     # Add a period if the text doesn't end with punctuation (helps the TTS model finalize)
+     if cleaned_text and cleaned_text[-1] not in '.!?':
+         cleaned_text += '.'
+
+     output = tts_pipe(cleaned_text)
+     # The pipeline may return a list or a single dict
+     result = output[0] if isinstance(output, list) else output
+     audio_array = result["audio"]   # numpy array: (channels, samples) or (samples,)
+     rate = result["sampling_rate"]  # sampling rate in Hz
+
+     # Ensure audio_array is 2D (samples, channels) for consistent handling
+     if audio_array.ndim == 1:
+         data = audio_array[:, np.newaxis]  # add a channel dimension
+     else:
+         data = audio_array.T               # (channels, samples) -> (samples, channels)
+
+     # Convert float32 [-1..1] to int16 PCM [-32768..32767]; clip first so
+     # values slightly outside ±1.0 cannot overflow the int16 range
+     pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
+
+     buffer = io.BytesIO()
+     wf = wave.open(buffer, "wb")
+     wf.setnchannels(data.shape[1])  # number of channels
+     wf.setsampwidth(2)              # 16 bits = 2 bytes per sample
+     wf.setframerate(rate)           # samples per second
+     wf.writeframes(pcm.tobytes())   # write the PCM data
+     wf.close()
+     buffer.seek(0)
+     return buffer.read()  # raw WAV bytes
+
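A quick sanity check of the returned WAV bytes (illustrative; MMS-TTS typically emits mono audio at a 16 kHz sampling rate, but the exact values depend on the model):

    wav = tts_bytes("Hello world.", load_tts_pipe())
    with wave.open(io.BytesIO(wav)) as wf:
        print(wf.getnchannels(), wf.getframerate(), wf.getnframes())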
+ # 3) STREAMLIT USER INTERFACE
+ st.set_page_config(page_title="Imagine & Narrate", page_icon="✨", layout="centered")
+ st.title("✨ Imagine & Narrate")
+ st.write("Upload any image below to see AI imagine and narrate a story about it!")
+
+ # -- Upload-image widget --
+ uploaded = st.file_uploader(
+     "Choose an image file",
+     type=["jpg", "jpeg", "png"],
+ )
+ if not uploaded:
+     st.info("➡️ Upload an image above to start the magic!")
+     st.stop()
+
+ # Load the uploaded file into a PIL Image
+ try:
+     img = Image.open(uploaded)
+ except Exception as e:
+     st.error(f"Error loading image: {e}")
+     st.stop()
+
+ # -- Step 1: Display the image --
+ st.subheader("📸 Your Visual Input")
+ st.image(img, use_container_width=True)
+ st.divider()
+
+ # -- Step 2: Generate and display a caption --
+ st.subheader("🧠 Generating Insights")
+ with st.spinner("Scanning image for key elements…"):
+     captioner = load_captioner()
+     raw_caption = caption_image(img, captioner)
+     if not raw_caption:
+         st.warning("Could not generate a caption for the image.")
+         st.stop()
+     caption = sentence_case(raw_caption)
+     st.markdown(f"**Identified Scene:** {caption}")
+ st.divider()
+
+ # -- Step 3: Generate and display a story --
+ st.subheader("📖 Crafting a Narrative")
+ with st.spinner("Writing a compelling story…"):
+     story_pipe = load_story_pipe()
+     story = story_from_caption(caption, story_pipe)
+     if not story or story.strip() == '...':  # empty or minimal story
+         st.warning("Could not generate a meaningful story from the caption.")
+         st.stop()
+     st.write(story)
+ st.divider()
+
+ # -- Step 4: Synthesize and play the audio --
+ st.subheader("👂 Hear the Story")
+ with st.spinner("Synthesizing audio narration…"):
+     tts_pipe = load_tts_pipe()
+     try:
+         audio_bytes = tts_bytes(story, tts_pipe)
+         st.audio(audio_bytes, format="audio/wav")
+     except Exception as e:
+         st.error(f"Error generating audio: {e}")
+
+ # Celebration animation
+ st.balloons()
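To try this commit locally, install the app's assumed dependencies (streamlit, transformers, torch, pillow; no requirements file is part of this change) and launch it with Streamlit's CLI:

    streamlit run src/streamlit_app.py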