hskwon7 commited on
Commit
cdc9632
·
verified ·
1 Parent(s): 9f8fd3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -9,7 +9,7 @@ st.write("Upload an image and watch as it’s captioned, turned into a short sto
9
 
10
  @st.cache_resource
11
  def load_captioner():
12
- return pipeline("image-to-text", model="unography/blip-large-long-cap")
13
 
14
  @st.cache_resource
15
  def load_story_gen():
@@ -18,19 +18,19 @@ def load_story_gen():
18
  captioner = load_captioner()
19
  story_gen = load_story_gen()
20
 
21
- # 1) Upload (key='image' gives us st.session_state.image)
22
  uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"], key="image")
23
  if uploaded:
24
  img = Image.open(uploaded)
25
  st.image(img, use_column_width=True)
26
 
27
- # 2) Caption (once per upload)
28
  if "caption" not in st.session_state:
29
  with st.spinner("Generating caption…"):
30
  st.session_state.caption = captioner(img)[0]["generated_text"]
31
  st.write("**Caption:**", st.session_state.caption)
32
 
33
- # 3) Story (once per upload)
34
  if "story" not in st.session_state:
35
  with st.spinner("Spinning up a story…"):
36
  out = story_gen(
@@ -43,15 +43,15 @@ if uploaded:
43
  st.session_state.story = out[0]["generated_text"]
44
  st.write("**Story:**", st.session_state.story)
45
 
46
- # 4) Pre-generate audio buffer (once per upload)
47
- if "audio_buffer" not in st.session_state:
48
  with st.spinner("Generating audio…"):
49
  tts = gTTS(text=st.session_state.story, lang="en")
50
  buf = io.BytesIO()
51
  tts.write_to_fp(buf)
52
- buf.seek(0)
53
- st.session_state.audio_buffer = buf.read()
54
 
55
  # 5) Play on demand
56
  if st.button("🔊 Play Story Audio"):
57
- st.audio(st.session_state.audio_buffer, format="audio/mp3")
 
 
9
 
10
  @st.cache_resource
11
  def load_captioner():
12
+ return pipeline("image-captioning", model="nlpconnect/vit-gpt2-image-captioning")
13
 
14
  @st.cache_resource
15
  def load_story_gen():
 
18
  captioner = load_captioner()
19
  story_gen = load_story_gen()
20
 
21
+ # 1) Upload
22
  uploaded = st.file_uploader("Upload an image", type=["png","jpg","jpeg"], key="image")
23
  if uploaded:
24
  img = Image.open(uploaded)
25
  st.image(img, use_column_width=True)
26
 
27
+ # 2) Caption
28
  if "caption" not in st.session_state:
29
  with st.spinner("Generating caption…"):
30
  st.session_state.caption = captioner(img)[0]["generated_text"]
31
  st.write("**Caption:**", st.session_state.caption)
32
 
33
+ # 3) Story
34
  if "story" not in st.session_state:
35
  with st.spinner("Spinning up a story…"):
36
  out = story_gen(
 
43
  st.session_state.story = out[0]["generated_text"]
44
  st.write("**Story:**", st.session_state.story)
45
 
46
+ # 4) Pre-generate raw MP3 bytes
47
+ if "audio_bytes" not in st.session_state:
48
  with st.spinner("Generating audio…"):
49
  tts = gTTS(text=st.session_state.story, lang="en")
50
  buf = io.BytesIO()
51
  tts.write_to_fp(buf)
52
+ st.session_state.audio_bytes = buf.getvalue()
 
53
 
54
  # 5) Play on demand
55
  if st.button("🔊 Play Story Audio"):
56
+ audio_buffer = io.BytesIO(st.session_state.audio_bytes)
57
+ st.audio(audio_buffer, format="audio/mp3")