File size: 2,946 Bytes
3ca79e1
 
231c708
3ca79e1
231c708
3ca79e1
 
 
 
 
 
bfa949b
231c708
 
fdb5e18
 
 
86b1ea3
bfa949b
bba658a
3ca79e1
 
 
 
 
 
 
 
bba658a
3ca79e1
bba658a
3ca79e1
 
 
231c708
07f689d
231c708
fb4efe0
231c708
 
07f689d
 
 
231c708
ade46e5
231c708
 
 
 
 
 
bba658a
 
 
 
fb4efe0
ade46e5
bba658a
 
 
fb4efe0
 
231c708
 
 
 
fb4efe0
231c708
 
 
 
 
 
 
 
3ca79e1
 
231c708
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import hashlib
import tempfile
from pathlib import Path

import streamlit as st
from transformers import pipeline

# Function definitions
@st.cache_resource(show_spinner="Loading image-captioning model...")
def _load_caption_pipeline():
    """Build the BLIP image-captioning pipeline once and reuse it.

    ``st.cache_resource`` keeps the loaded model across Streamlit reruns and
    sessions, so it is not re-instantiated for every uploaded image.
    """
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def img2text(url):
    """Return a caption for the image at *url*.

    Args:
        url: Image location passed straight to the transformers
            ``image-to-text`` pipeline (in this app, a local file path).

    Returns:
        The ``generated_text`` of the first result the pipeline produces.
    """
    image_to_text_model = _load_caption_pipeline()
    text = image_to_text_model(url)[0]["generated_text"]
    return text

@st.cache_resource(show_spinner="Loading story-generation model...")
def _load_story_pipeline():
    """Build the GPT-2 story-generation pipeline once and reuse it.

    ``st.cache_resource`` keeps the loaded model across Streamlit reruns and
    sessions instead of reloading it on every call.
    """
    return pipeline("text-generation", model="aspis/gpt2-genre-story-generation")


def text2story(text):
    """Generate a short story that continues the prompt *text*.

    Sampling is enabled (nucleus sampling with ``top_p=0.6``), so repeated
    calls with the same prompt can return different stories.

    Args:
        text: Prompt to continue (in this app, the image caption).

    Returns:
        The first result's ``generated_text``. For the ``text-generation``
        pipeline this includes the prompt itself by default
        (``return_full_text=True``).
    """
    text_generation_model = _load_story_pipeline()
    story_text = text_generation_model(
        text,
        # Lengths are in tokens; for a decoder-only model like GPT-2 they
        # count the prompt tokens as well as the newly generated ones.
        min_length=50,
        max_length=100,
        do_sample=True,
        top_p=0.6,
    )[0]["generated_text"]
    return story_text

@st.cache_resource(show_spinner="Loading text-to-speech model...")
def _load_tts_pipeline():
    """Build the MMS text-to-speech pipeline once and reuse it.

    ``st.cache_resource`` keeps the loaded model across Streamlit reruns and
    sessions instead of reloading it on every call.
    """
    return pipeline("text-to-speech", model="Matthijs/mms-tts-eng")


def text2audio(story_text):
    """Synthesize speech for *story_text*.

    Args:
        story_text: Text to be spoken.

    Returns:
        The raw pipeline output. ``main`` reads its ``"audio"`` samples and
        ``"sampling_rate"`` entries to play it back.
    """
    text2audio_model = _load_tts_pipeline()
    gen_audio = text2audio_model(story_text)
    return gen_audio

def main():
    """Run the Streamlit app: uploaded image -> caption -> story -> speech.

    The caption, story, and audio are stored in ``st.session_state``. They are
    reset only when the SHA-256 of the uploaded bytes changes, so reruns
    triggered by widgets (e.g. clicking "Play Audio") reuse the stored results
    instead of regenerating them.
    """
    st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
    st.header("Turn Your Image to Audio Story")

    uploaded_file = st.file_uploader("Select an Image...")
    if uploaded_file is None:
        return

    # Identify the upload by its content, not by its (client-supplied) name.
    bytes_data = uploaded_file.getvalue()
    file_hash = hashlib.sha256(bytes_data).hexdigest()

    # Reset session state only if the file content has changed; this prevents
    # regeneration after clicking "Play Audio".
    if st.session_state.get("last_uploaded_hash") != file_hash:
        st.session_state.scenario = None
        st.session_state.story = None
        st.session_state.audio_data = None
        st.session_state.last_uploaded_hash = file_hash

    st.image(uploaded_file, caption="Uploaded Image", use_container_width=True)

    # Stage 1: Image to Text
    if st.session_state.scenario is None:
        st.text("Processing img2text...")
        # img2text is given a filesystem path. Write the upload into a private
        # temporary directory rather than the shared working directory under
        # the client-supplied filename: concurrent sessions uploading
        # same-named files would otherwise overwrite each other. The copy is
        # written only when a caption is needed and is removed right after.
        suffix = Path(uploaded_file.name).suffix
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = Path(tmp_dir) / f"upload{suffix}"
            image_path.write_bytes(bytes_data)
            st.session_state.scenario = img2text(str(image_path))
    st.write(st.session_state.scenario)

    # Stage 2: Text to Story
    if st.session_state.story is None:
        st.text("Generating a story...")
        st.session_state.story = text2story(st.session_state.scenario)
    st.write(st.session_state.story)

    # Stage 3: Story to Audio data
    if st.session_state.audio_data is None:
        st.text("Generating audio data...")
        st.session_state.audio_data = text2audio(st.session_state.story)

    # Play Audio button – uses the stored audio_data, so clicking it (which
    # reruns the script) does not regenerate anything.
    if st.button("Play Audio"):
        st.audio(
            st.session_state.audio_data["audio"],
            format="audio/wav",
            start_time=0,
            sample_rate=st.session_state.audio_data["sampling_rate"],
        )

# Entry point: build the app only when this file is executed as the main
# script, not when it is imported.
if __name__ == "__main__":
    main()