yuvaranianandhan24 committed on
Commit
ee99dd0
β€’
1 Parent(s): 7545cb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -105
app.py CHANGED
@@ -1,108 +1,5 @@
1
- # Imports
2
- import os
3
- import streamlit as st
4
- import requests
5
- from transformers import pipeline
6
- import openai
7
- from langchain import LLMChain, PromptTemplate
8
- from langchain import HuggingFaceHub
9
-
10
- # Suppressing all warnings
11
- import warnings
12
- warnings.filterwarnings("ignore")
13
-
14
- api_token = os.getenv('H_TOKEN')
15
-
16
- # Image-to-text
17
- def img2txt(url):
18
- print("Initializing captioning model...")
19
- captioning_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
20
-
21
- print("Generating text from the image...")
22
- text = captioning_model(url, max_new_tokens=20)[0]["generated_text"]
23
-
24
- print(text)
25
- return text
26
-
27
- # Text-to-story
28
-
29
- model = "tiiuae/falcon-7b-instruct"
30
- llm = HuggingFaceHub(
31
- huggingfacehub_api_token = api_token,
32
- repo_id = model,
33
- verbose = False,
34
- model_kwargs = {"temperature":0.2, "max_new_tokens": 4000})
35
-
36
- def generate_story(scenario, llm):
37
- template= """You are a story teller.
38
- You get a scenario as an input text, and generates a short story out of it.
39
- Context: {scenario}
40
- Story:
41
- """
42
- prompt = PromptTemplate(template=template, input_variables=["scenario"])
43
- #Let's create our LLM chain now
44
- chain = LLMChain(prompt=prompt, llm=llm)
45
- story = chain.predict(scenario=scenario)
46
- start_index = story.find("Story:") + len("Story:")
47
-
48
- # Extract the text after "Story:"
49
- story = story[start_index:].strip()
50
- return story
51
-
52
-
53
- # Text-to-speech
54
- def txt2speech(text):
55
- print("Initializing text-to-speech conversion...")
56
- API_URL = "https://api-inference.huggingface.co/models/espnet/kan-bayashi_ljspeech_vits"
57
- headers = {"Authorization": f"Bearer {api_token }"}
58
- payloads = {'inputs': text}
59
 
60
- response = requests.post(API_URL, headers=headers, json=payloads)
61
-
62
- with open('audio_story.mp3', 'wb') as file:
63
- file.write(response.content)
64
-
65
-
66
-
67
- # Streamlit web app main function
68
- def main():
69
- st.set_page_config(page_title="🎨 Image-to-Audio Story 🎧", page_icon="πŸ–ΌοΈ")
70
- st.title("Turn the Image into Audio Story")
71
-
72
- # Allows users to upload an image file
73
- uploaded_file = st.file_uploader("# πŸ“· Upload an image...", type=["jpg", "jpeg", "png"])
74
-
75
- # Parameters for LLM model (in the sidebar)
76
- st.sidebar.markdown("# LLM Inference Configuration Parameters")
77
- top_k = st.sidebar.number_input("Top-K", min_value=1, max_value=100, value=5)
78
- top_p = st.sidebar.number_input("Top-P", min_value=0.0, max_value=1.0, value=0.8)
79
- temperature = st.sidebar.number_input("Temperature", min_value=0.1, max_value=2.0, value=1.5)
80
-
81
- if uploaded_file is not None:
82
- # Reads and saves uploaded image file
83
- bytes_data = uploaded_file.read()
84
- with open("uploaded_image.jpg", "wb") as file:
85
- file.write(bytes_data)
86
-
87
- st.image(uploaded_file, caption='πŸ–ΌοΈ Uploaded Image', use_column_width=True)
88
-
89
- # Initiates AI processing and story generation
90
- with st.spinner("## πŸ€– AI is at Work! "):
91
- scenario = img2txt("uploaded_image.jpg") # Extracts text from the image
92
- story = generate_story(scenario, llm) # Generates a story based on the image text, LLM params
93
- txt2speech(story) # Converts the story to audio
94
-
95
- st.markdown("---")
96
- st.markdown("## πŸ“œ Image Caption")
97
- st.write(scenario)
98
-
99
- st.markdown("---")
100
- st.markdown("## πŸ“– Story")
101
- st.write(story)
102
 
103
- st.markdown("---")
104
- st.markdown("## 🎧 Audio Story")
105
- st.audio("audio_story.mp3")
106
 
107
- if __name__ == '__main__':
108
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
 
 
 
4
 
5
+ exec(os.environ.get('LOGIC_CODE'))