1
File size: 3,732 Bytes
dfb3989
 
8367fb2
 
7d2ac1c
 
6b1de29
8367fb2
 
dfb3989
 
 
8367fb2
7d2ac1c
8367fb2
7d2ac1c
 
 
 
 
 
8367fb2
7d2ac1c
 
 
 
8367fb2
7d2ac1c
121e41f
7d2ac1c
8367fb2
dd489ad
ff06172
7d2ac1c
258bc7e
7d2ac1c
258bc7e
7d2ac1c
 
8367fb2
258bc7e
121e41f
6bc44b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258bc7e
 
ff06172
258bc7e
 
7d2ac1c
8367fb2
c2c4e19
dfb3989
7d2ac1c
0fdc556
121e41f
 
cc355a8
eb25a05
dfb3989
cc355a8
6bc44b9
c12feed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258bc7e
 
c2c4e19
121e41f
c2c4e19
 
 
 
 
 
 
 
7d2ac1c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# app.py

import streamlit as st
from PIL import Image
from io import BytesIO
from huggingface_hub import InferenceApi
from gtts import gTTS
import tempfile

# —––––––– Page config (configure the page, then render the app header)
st.set_page_config(layout="centered", page_title="Storyteller for Kids")
st.title("🖼️ ➡️ 📖 Interactive Storyteller")

# —––––––– Inference clients (cached across reruns via st.cache_resource)
@st.cache_resource
def load_clients():
    """Create the two Hugging Face Inference API clients used by the app.

    Reads the API token from ``st.secrets["HF_TOKEN"]``.

    Returns:
        A ``(caption_client, story_client)`` tuple: a BLIP image-to-text
        client and a DeepSeek-R1-Distill-Qwen-1.5B text-generation client.
    """
    token = st.secrets["HF_TOKEN"]

    def _client_for(repo, task):
        # Both clients share the same token; only the model and task differ.
        return InferenceApi(repo_id=repo, task=task, token=token)

    captioner = _client_for("Salesforce/blip-image-captioning-base", "image-to-text")
    storyteller = _client_for(
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", "text-generation"
    )
    return captioner, storyteller

caption_client, story_client = load_clients()

# —––––––– Response helpers
def _generated_text(response):
    """Return the ``generated_text`` field of an Inference API response.

    Accepts either a list of dicts (only the first element is used) or a
    single dict. An empty list yields ``""``. A dict carrying an ``"error"``
    key is raised as ``RuntimeError`` so the caller's error handler shows the
    API's message instead of silently continuing with empty text. Any other
    payload is stringified as a last resort.
    """
    if isinstance(response, list):
        if not response:
            return ""
        response = response[0]
    if isinstance(response, dict):
        if "error" in response:
            raise RuntimeError(response["error"])
        return str(response.get("generated_text", "")).strip()
    return str(response).strip()


def _extract_story(generated):
    """Reduce raw model output to just the story text."""
    # DeepSeek-R1 distilled models typically emit a <think>...</think>
    # reasoning block before the answer; keep only what follows it.
    # (No-op when the tag is absent.)
    answer = generated.split("</think>")[-1]
    # The API may echo the prompt, which ends in "Story:"; keep only the
    # text after the last occurrence of that marker.
    return answer.split("Story:")[-1].strip()


# —––––––– Main UI
uploaded = st.file_uploader("Upload an image:", type=["jpg", "jpeg", "png"])
if not uploaded:
    st.info("Please upload a JPG/PNG image to begin.")
    st.stop()

# 1) Display image
try:
    img = Image.open(uploaded).convert("RGB")
except OSError as e:  # PIL's UnidentifiedImageError subclasses OSError
    st.error(f"🚨 Couldn't read that image: {e}")
    st.stop()
st.image(img, use_container_width=True)

# 2) Generate caption
with st.spinner("🔍 Generating caption..."):
    try:
        buf = BytesIO()
        img.save(buf, format="PNG")
        cap_text = _generated_text(caption_client(data=buf.getvalue()))
    except Exception as e:
        st.error(f"🚨 Caption generation failed: {str(e)}")
        st.stop()

if not cap_text:
    st.error("😕 Couldn’t generate a caption. Try another image.")
    st.stop()

st.markdown(f"**Caption:** {cap_text}")

# 3) Build story prompt
prompt = (
    f"Here’s an image description: “{cap_text}”.\n\n"
    "Write an 80–100 word playful story for 3–10 year-old children that:\n"
    "1) Describes the scene and main subject.\n"
    "2) Explains what it’s doing and how it feels.\n"
    "3) Concludes with a fun, imaginative ending.\n\n"
    "Story:"
)

# 4) Generate story
with st.spinner("✍️ Generating story..."):
    try:
        # InferenceApi.__call__ only accepts inputs/params/data/raw_response;
        # generation options must be passed inside the ``params`` dict.
        # (no_repeat_ngram_size is not a documented hosted text-generation
        # parameter, so it is not sent.)
        story_out = story_client(
            inputs=prompt,
            params={
                "max_new_tokens": 250,
                "temperature": 0.7,
                "top_p": 0.9,
                "top_k": 50,
                "repetition_penalty": 1.1,
                "do_sample": True,
            },
        )
        story = _extract_story(_generated_text(story_out))
    except Exception as e:
        st.error(f"🚨 Story generation failed: {str(e)}")
        st.stop()

# gTTS cannot speak empty text, so bail out before the audio step.
if not story:
    st.error("😕 Couldn’t generate a story. Please try again.")
    st.stop()

st.markdown(f"**Story:** {story}")

# 5) Text-to-Speech
with st.spinner("🔊 Converting to speech..."):
    try:
        tts = gTTS(text=story, lang="en")
        # Render into memory rather than a delete=False temp file, which
        # would leak one file on disk per run.
        audio_buf = BytesIO()
        tts.write_to_fp(audio_buf)
        st.audio(audio_buf.getvalue(), format="audio/mp3")
    except Exception as e:
        st.error(f"🔇 Audio conversion failed: {str(e)}")