import io                   # for creating in-memory binary streams
import wave                 # for writing WAV audio files
import re                   # for regular expression utilities
import streamlit as st      # Streamlit UI library
from transformers import pipeline  # Hugging Face inference pipelines
from PIL import Image       # Python Imaging Library for image loading
import numpy as np          # numerical operations, especially array handling

# 1) CACHE & LOAD MODELS (CPU only)
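# st.cache_resource keeps each pipeline object alive across Streamlit reruns,
# so the models are downloaded and initialized only once per process.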
@st.cache_resource(show_spinner=False)
def load_captioner():
    return pipeline(
        "image-to-text",
        model="Salesforce/blip-image-captioning-base",
        device=-1  # force CPU
    )

@st.cache_resource(show_spinner=False)
def load_story_pipe():
    return pipeline(
        "text2text-generation",
        model="google/flan-t5-base",
        device=-1  # force CPU
    )

@st.cache_resource(show_spinner=False)
def load_tts_pipe():
    return pipeline(
        "text-to-speech",
        model="facebook/mms-tts-eng",
        device=-1  # force CPU
    )

# 2) HELPER FUNCTIONS
def sentence_case(text: str) -> str:
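    # Split on sentence-ending punctuation; the capturing group keeps each
    # delimiter so sentences can be re-joined with their own terminators.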
    parts = re.split(r'([.!?])', text)
    out = []
    for i in range(0, len(parts) - 1, 2):
        sentence = parts[i].strip()
        delimiter = parts[i + 1]
        if sentence:
            formatted = sentence[0].upper() + sentence[1:]
            out.append(f"{formatted}{delimiter}")
    if len(parts) % 2:
        last = parts[-1].strip()
        if last:
            formatted = last[0].upper() + last[1:]
            out.append(formatted)
    return " ".join(" ".join(out).split())

def caption_image(img: Image.Image, captioner) -> str:
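    # Convert palette/grayscale/alpha images to RGB before running the
    # captioning pipeline.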
    if img.mode != "RGB":
        img = img.convert("RGB")
    results = captioner(img)
    return (results[0].get("generated_text", "") if results else "")

def story_from_caption(caption: str, pipe) -> str:
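    # Wrap the caption in a prompt and sample a short passage, then strip any
    # echoed prompt text and trim the result to the last complete sentence.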
    if not caption:
        return "Could not generate a story without a caption."
    prompt = f"Write a vivid, imaginative ~100-word story about this scene: {caption}\n\nWrite a creative and descriptive short story."
    results = pipe(
        prompt,
        max_length=120,
        min_length=60,
        do_sample=True,
        top_k=100,
        top_p=0.9,
        temperature=0.8,
        repetition_penalty=1.1,
        no_repeat_ngram_size=4,
        early_stopping=False
    )
    raw = results[0]["generated_text"].strip()
    # Remove prompt echo if present
    raw = re.sub(re.escape(prompt), "", raw, flags=re.IGNORECASE).strip()
    # Trim to last full sentence
    idx = max(raw.rfind("."), raw.rfind("!"), raw.rfind("?"))
    if idx != -1:
        raw = raw[:idx+1]
    elif len(raw) > 80:
        raw = raw[:raw.rfind(" ") if raw.rfind(" ") > 60 else 80] + "..."
    return sentence_case(raw)

def tts_bytes(text: str, tts_pipe) -> bytes:
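    # Tidy the text for speech, run the TTS pipeline, and pack the returned
    # float waveform into an in-memory 16-bit PCM WAV file.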
    if not text:
        return b""
    cleaned = re.sub(r'^["\']|["\']$', '', text).strip()
    cleaned = re.sub(r'\.{2,}', '.', cleaned).replace('…', '...')
    if not cleaned:
        return b""
    if cleaned[-1] not in ".!?":
        cleaned += "."
    cleaned = " ".join(cleaned.split())
    output = tts_pipe(cleaned)
    result = output[0] if isinstance(output, list) else output
    audio_array = result.get("audio")
    rate = result.get("sampling_rate")
    if audio_array is None or rate is None:
        return b""
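    # Arrange samples as (frames, channels) for the WAV writer: 1-D output is
    # treated as mono, anything else as (channels, frames) and transposed.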
    if audio_array.ndim == 1:
        data = audio_array[:, np.newaxis]
    else:
        data = audio_array.T
    # Scale to 16-bit PCM; clip first so out-of-range samples do not wrap around.
    pcm = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        wf.setnchannels(data.shape[1])
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit audio
        wf.setframerate(rate)
        wf.writeframes(pcm.tobytes())
    buf.seek(0)
    return buf.read()

# 3) STREAMLIT USER INTERFACE
st.set_page_config(page_title="✨ Imagine & Narrate", page_icon="✨", layout="centered")

# Persist upload across reruns
if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None

new_upload = st.file_uploader(
    "Choose an image file",
    type=["jpg", "jpeg", "png"]
)
if new_upload is not None:
    st.session_state.uploaded_file = new_upload

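# Until an image has been provided, show only the intro and halt the script;
# st.stop() keeps the caption/story/audio steps below from running.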
if st.session_state.uploaded_file is None:
    st.title("✨ Imagine & Narrate")
    st.info("➑️ Upload an image above to start the magic!")
    st.stop()

uploaded = st.session_state.uploaded_file
try:
    img = Image.open(uploaded)
except Exception as e:
    st.error(f"Could not load the image: {e}")
    st.stop()

st.title("✨ Imagine & Narrate")
st.subheader("πŸ“Έ Your Visual Input")
st.image(img, caption=uploaded.name, use_container_width=True)
st.divider()

# Step 1: Generate Caption
st.subheader("🧠 Generating Caption")
with st.spinner("Analyzing image..."):
    captioner = load_captioner()
    raw_caption = caption_image(img, captioner)
    if not raw_caption:
        st.error("Failed to generate caption.")
        st.stop()
    caption = sentence_case(raw_caption)
st.markdown(f"**Identified Scene:** {caption}")
st.divider()

# Step 2: Generate Story
st.subheader("πŸ“– Crafting a Story")
with st.spinner("Writing story..."):
    story_pipe = load_story_pipe()
    story = story_from_caption(caption, story_pipe)
    if not story or story.strip() in {".", "..", "..."}:
        st.error("Failed to generate story.")
        st.stop()
st.write(story)
st.divider()

# Step 3: Synthesize Audio
st.subheader("πŸ‘‚ Hear the Story")
with st.spinner("Synthesizing audio..."):
    tts_pipe = load_tts_pipe()
    audio_bytes = tts_bytes(story, tts_pipe)
    if not audio_bytes:
        st.warning("Audio generation failed.")
    else:
        st.audio(audio_bytes, format="audio/wav")
st.balloons()