Geraldine committed on
Commit
c265bc5
1 Parent(s): 47e792a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +242 -0
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from huggingface_hub import InferenceClient, AsyncInferenceClient
3
+ from PIL import Image
4
+ from pathlib import Path
5
+ import os, subprocess
6
+
7
st.set_page_config(page_title='HG Inference Client Demo', layout="wide")

# Cache the header of the app to prevent re-rendering on each load
@st.cache_resource
def display_app_header():
    """Display the title and subtitle of the Streamlit app."""
    st.title("1️⃣ HG Inference Client Demo 📊 ")
    # Fixed user-facing typo: "demontstrator" -> "demonstrator".
    st.subheader("Just a little demonstrator")

# Display the header of the app
display_app_header()
16
+
17
# UI sidebar parameters ####################################
# Fixed sidebar header typo: "Loging" -> "Login".
st.sidebar.header("Login")
if hg_token := st.sidebar.text_input('Enter your HG token'):
    try:
        # Log in to the Hugging Face Hub with the user-supplied token.
        # List form (shell=False) avoids shell injection via the token value.
        subprocess.check_call(["huggingface-cli", "login", "--token", hg_token])
        st.sidebar.info('Logged', icon="ℹ️")
    except subprocess.CalledProcessError:
        st.sidebar.error('Error with token, try again', icon="⚠️")
else:
    st.sidebar.warning("enter your token")
27
+
28
st.sidebar.header("Model")
selected_model = st.sidebar.radio(
    "Choose a model or let the client do it",
    ["Not choose", "Choose"],
)
# Only ask for a model name when the user opted to pick one; a None model
# lets the InferenceClient select a default model for the task.
model = (
    st.sidebar.text_input('Enter a model name. ex : facebook/fastspeech2-en-ljspeech')
    if selected_model == "Choose"
    else None
)
37
+
38
st.sidebar.header("Task")
# Map the human-readable task label shown in the sidebar to the
# corresponding InferenceClient method name.
dict_hg_tasks = {
    "Automatic Speech Recognition": "automatic_speech_recognition",
    "Text-to-Speech (choose model)": "text_to_speech",
    "Image Classification": "image_classification",
    "Image Segmentation": "image_segmentation",
    "Image-to-Text": "image_to_text",
    "Object Detection": "object_detection",
    "Text-to-Image": "text_to_image",
    "Visual Question Answering": "visual_question_answering",
    "Conversational": "conversational",
    "Feature Extraction": "feature_extraction",
    "Question Answering": "question_answering",
    "Summarization": "summarization",
    "Text Classification": "text_classification",
    "Text Generation": "text_generation",
    "Token Classification": "token_classification",
    "Translation (choose model)": "translation",
}
57
+
58
# Per-task UI configuration:
#   input   - which input widgets to render ("upload,url", "text", or None)
#   output  - how the response is rendered ("text", "audio", "image", "image,text")
#   prompt  - whether the task takes an extra question/prompt field
#   context - whether the task takes an extra context field
# NOTE: the original literal defined "text_classification" twice with
# identical values; the silently-overwriting duplicate has been removed.
dict_hg_tasks_params = {
    "automatic_speech_recognition": {
        "input": "upload,url",
        "output": "text",
        "prompt": False,
        "context": False,
    },
    "text_to_speech": {
        "input": "text",
        "output": "audio",
        "prompt": False,
        "context": False,
    },
    "image_classification": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": False,
        "context": False,
    },
    "image_segmentation": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": False,
        "context": False,
    },
    "image_to_text": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": False,
        "context": False,
    },
    "object_detection": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": False,
        "context": False,
    },
    "text_to_image": {
        "input": "text",
        "output": "image",
        "prompt": False,
        "context": False,
    },
    "visual_question_answering": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": True,
        "context": False,
    },
    "image_to_image": {
        "input": "upload,url",
        "output": "image,text",
        "prompt": True,
        "context": False,
    },
    "feature_extraction": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False,
    },
    "conversational": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False,
    },
    "question_answering": {
        "input": None,
        "output": "text",
        "prompt": True,
        "context": True,
    },
    "text_classification": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False,
    },
    "token_classification": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False,
    },
    "text_generation": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False,
    },
    "translation": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False,
    },
    "summarization": {
        "input": "text",
        "output": "text",
        "prompt": False,
        "context": False,
    },
}
168
# Let the user pick a task; see
# https://huggingface.co/docs/huggingface_hub/guides/inference
selected_task = st.sidebar.radio(
    "Choose the task you want to do",
    list(dict_hg_tasks),
)
st.write(f"The current selected task is : {dict_hg_tasks[selected_task]}")
with st.sidebar.expander("tasks documentation"):
    st.write("https://huggingface.co/docs/huggingface_hub/package_reference/inference_client")
175
+
176
# functions ########################################
cwd = os.getcwd()

def get_input(upload, url, text):
    """Return the effective user input.

    Precedence: an uploaded file wins over a URL, which wins over free
    text; returns None when none of the three was provided.
    """
    if upload is not None:
        return upload
    return url or text or None
187
+
188
def display_inputs(task):
    """Render the input widgets for *task*; return (upload, url, text)."""
    input_kind = dict_hg_tasks_params[task]["input"]
    if input_kind == "upload,url":
        return st.file_uploader("Choose a file"), st.text_input("or enter a file url"), ""
    if input_kind == "text":
        return None, "", st.text_input("Enter a text")
    return None, "", ""
195
+
196
def display_prompt(task):
    """Show a question field when the task accepts a prompt; else None."""
    if not dict_hg_tasks_params[task]["prompt"]:
        return None
    return st.text_input("Enter a question")
200
+
201
def display_context(task):
    """Show a context area when the task accepts a context; else None."""
    if not dict_hg_tasks_params[task]["context"]:
        return None
    return st.text_area("Enter a context")
205
+
206
# UI main client ####################################

if selected_task:
    response = None
    task = dict_hg_tasks[selected_task]
    # Pin the client to a model only if the user chose one; otherwise the
    # InferenceClient picks a default model for the task.
    client = InferenceClient(model=model) if model else InferenceClient()
    uploaded_input, url_input, text_input = display_inputs(task)
    prompt_input = display_prompt(task)
    context_input = display_context(task)
    # Resolve the effective input ONCE. The original called get_input()
    # twice and stored it in `input` (shadowing the builtin); worse, the
    # prompt branch referenced `input` while it was unbound -> NameError.
    user_input = get_input(uploaded_input, url_input, text_input)
    if user_input:
        # Plain single-input tasks (ASR, classification, generation, ...).
        response = getattr(client, task)(user_input)
    elif prompt_input:
        if context_input is not None:
            # e.g. question answering: needs a question and a context.
            response = getattr(client, task)(question=prompt_input, context=context_input)
        else:
            # e.g. visual question answering: input plus a prompt.
            response = getattr(client, task)(user_input, prompt=prompt_input)
    if response is not None:
        col1, col2 = st.columns(2)
        with col1:
            output_kind = dict_hg_tasks_params[task]["output"]
            if "text" in output_kind:
                st.write(response)
            elif "audio" in output_kind:
                # text_to_speech returns raw audio bytes; persist then play.
                audio_path = os.path.join(cwd, "audio.flac")
                Path(audio_path).write_bytes(response)
                st.audio(audio_path)
        with col2:
            if dict_hg_tasks_params[task]["output"] == "image,text":
                # NOTE(review): assumes user_input is an uploaded file or a
                # local path — Image.open cannot read a remote URL directly;
                # confirm URL inputs are handled upstream.
                st.image(Image.open(user_input))
            elif dict_hg_tasks_params[task]["output"] == "image":
                # text_to_image returns a PIL image; save then redisplay.
                image_path = os.path.join(cwd, "generated_image.png")
                response.save(image_path)
                st.image(Image.open(image_path))