commune-ai commited on
Commit
fd6a563
1 Parent(s): acbf84b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from requests.exceptions import ConnectTimeout
3
+ import time
4
+ import requests
5
+ import base64
6
+
7
+ global headers
8
+ global cancel_url
9
+ global path
10
+ path = ''
11
+ cancel_url =''
12
+ headers = {
13
+ 'Content-Type': 'application/json',
14
+ 'Authorization': 'Token r8_ZGZlzThfRkPZVDMygVclY1XZ9AuxmIQ2qwwPP',
15
+ "Access-Control-Allow-Headers": "Content-Type",
16
+ "Access-Control-Allow-Origin": '**',
17
+ "Access-Control-Allow-Methods": "OPTIONS,POST,GET,PATCH"}
18
+
19
+ with gr.Blocks() as demo:
20
+ owner = "adirik"
21
+ name = "realvisxl-v4.0"
22
+ max_retries = 3
23
+ retry_delay = 2
24
+ for retry in range(max_retries):
25
+ try:
26
+ url = f'https://api.replicate.com/v1/models/{owner}/{name}'
27
+ response = requests.get(url, headers=headers, timeout=10)
28
+ # Process the response
29
+ break # Break out of the loop if the request is successful
30
+ except ConnectTimeout:
31
+ if retry < max_retries - 1:
32
+ print(f"Connection timed out. Retrying in {retry_delay} seconds...")
33
+ time.sleep(retry_delay)
34
+ else:
35
+ print("Max retries exceeded. Unable to establish connection.")
36
+
37
+ data = response.json()
38
+ description =data.get("description", '')
39
+ title = data.get("default_example",'').get("model",'')
40
+ version = data.get("default_example",'').get("version",'')
41
+
42
+ gr.Markdown(
43
+ f"""
44
+ # {title}
45
+ {description}
46
+ """)
47
+
48
+ with gr.Row():
49
+ with gr.Column():
50
+ inputs =[]
51
+ schema = data.get("latest_version", {}).get("openapi_schema", {}).get("components", {}).get("schemas", {})
52
+ ordered_properties = sorted(schema.get("Input", {}).get("properties", {}).items(), key=lambda x: x[1].get("x-order", 0))
53
+ required = schema.get("Input", '').get('required', [])
54
+ print(required,"required")
55
+ for property_name, property_info in ordered_properties :
56
+ if required:
57
+ for item in required:
58
+ if item == property_name:
59
+ label = "*"+ property_info.get('title', '')
60
+ description = property_info.get('description','')
61
+ break
62
+ else:
63
+ label = property_info.get('title', '')
64
+ description = property_info.get('description','')
65
+ else:
66
+ label = property_info.get('title', '')
67
+ description = property_info.get('description','')
68
+
69
+ if "x-order" in property_info:
70
+ order = int(property_info.get('x-order',''))
71
+ if property_info.get("type", {}) == "integer":
72
+ value= data.get('default_example', '').get('input','').get(property_name,0)
73
+ if "minimum" and "maximum" in property_info:
74
+ if value == 0:
75
+ inputs.insert(order, gr.Slider(label=label, info= description, value=property_info.get('default', value), minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', ''), step=1))
76
+ else:
77
+ inputs.insert(order, gr.Slider(label=label, info= description, value=value, minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', ''), step=1))
78
+ else:
79
+ if value == 0:
80
+ inputs.insert(order, gr.Number(label=label, info= description, value=property_info.get('default', value)))
81
+ else:
82
+ inputs.insert(order, gr.Number(label=label, info= description, value=value))
83
+
84
+ elif property_info.get("type", {}) == "string":
85
+ value= data.get('default_example', '').get('input','').get(property_name,'')
86
+ if property_info.get('format','') == 'uri':
87
+
88
+ if value :
89
+ inputs.insert(order, gr.Image(label=label, value=value, type="filepath"))
90
+ else :
91
+ inputs.insert(order, gr.Image(label=label, type="filepath"))
92
+
93
+ else:
94
+ if value == '':
95
+ inputs.insert(order, gr.Textbox(label=label,info= description, value=property_info.get('default', value)))
96
+ else:
97
+ inputs.insert(order, gr.Textbox(label=label,info= description, value=value))
98
+
99
+ elif property_info.get("type", {}) == "number":
100
+ value= data.get('default_example', '').get('input','').get(property_name, 0)
101
+ if "minimum" and "maximum" in property_info:
102
+ if value == 0:
103
+ inputs.insert(order, gr.Slider(label=label,info= description, value=property_info.get('default', value), minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', '')))
104
+ else:
105
+ inputs.insert(order, gr.Slider(label=label,info= description, value=value, minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', '')))
106
+ else:
107
+ if value == 0:
108
+ inputs.insert(order, gr.Number(label=label,info= description, value=property_info.get('default', value)))
109
+ else:
110
+ inputs.insert(order, gr.Number(label=label,info= description, value=value))
111
+ elif property_info.get("type", {}) == "boolean":
112
+ value= data.get('default_example', '').get('input','').get(property_name,'')
113
+ if value == '':
114
+ inputs.insert(order, gr.Checkbox(label=label,info= description, value=property_info.get('default', value)))
115
+ else:
116
+ inputs.insert(order, gr.Checkbox(label=label,info= description, value=value))
117
+ else:
118
+ value= data.get('default_example', '').get('input','').get(property_name,'')
119
+ options=schema.get(property_name,'').get('enum',[])
120
+ if value == '':
121
+ inputs.insert(order, gr.Dropdown(label=property_name,info= description,choices=options, value=property_info.get("default", value)))
122
+ else:
123
+ inputs.insert(order, gr.Dropdown(label=property_name,info= description,choices=options, value=value))
124
+
125
+ with gr.Row():
126
+ cancel_btn = gr.Button("Cancel")
127
+ run_btn = gr.Button("Run")
128
+
129
+ with gr.Column():
130
+
131
+ outputs = []
132
+ outputs.append(gr.Image(value='https://replicate.delivery/pbxt/koQLfGV4o8yWGi4reeIvJQwCxmxrD3S7iQFGre8IfISrpnCTC/out-0.png'))
133
+ outputs.append(gr.Image(visible=False))
134
+ outputs.append(gr.Image(visible=False))
135
+ outputs.append(gr.Image(visible=False))
136
+
137
+
138
+
139
+ def run_process(input1,input2,input3,input4,input5,input6,input7, input8, input9,input10,input11,input12, input13,input14, input15, input16, input17):
140
+ global cancel_url
141
+ cancel_url=''
142
+ url = 'https://replicate.com/api/predictions'
143
+ if input3:
144
+ with open(input3, "rb") as file:
145
+ data = file.read()
146
+
147
+ base64_data = base64.b64encode(data).decode("utf-8")
148
+ mimetype = "image/jpg"
149
+ data_uri_image = f"data:{mimetype};base64,{base64_data}"
150
+ else:
151
+ data_uri_image=None
152
+
153
+ if input4:
154
+ with open(input4, "rb") as file:
155
+ data = file.read()
156
+
157
+ base64_data = base64.b64encode(data).decode("utf-8")
158
+ mimetype = "image/jpg"
159
+ data_uri_mask = f"data:{mimetype};base64,{base64_data}"
160
+ else:
161
+ data_uri_mask=None
162
+
163
+ if input3:
164
+ if input4:
165
+ body = {
166
+ "version": version,
167
+ "input": {
168
+ "prompt": input1,
169
+ "negative_prompt": input2,
170
+ "image": data_uri_image,
171
+ "mask": data_uri_mask,
172
+ "width": input5,
173
+ "height": input6,
174
+ "num_outputs": input7,
175
+ "scheduler": input8,
176
+ "num_inference_steps": input9,
177
+ "guidance_scale": input10,
178
+ "prompt_strength":input11,
179
+ "seed": input12,
180
+ "refine": input13,
181
+ "high_noise_frac": input14,
182
+ "refine_steps": input15
183
+ }
184
+ }
185
+ else:
186
+ body = {
187
+ "version": version,
188
+ "input": {
189
+ "prompt": input1,
190
+ "negative_prompt": input2,
191
+ "image": data_uri_image,
192
+ "width": input5,
193
+ "height": input6,
194
+ "num_outputs": input7,
195
+ "scheduler": input8,
196
+ "num_inference_steps": input9,
197
+ "guidance_scale": input10,
198
+ "prompt_strength":input11,
199
+ "seed": input12,
200
+ "refine": input13,
201
+ "high_noise_frac": input14,
202
+ "refine_steps": input15
203
+ }
204
+
205
+ }
206
+
207
+ else:
208
+ if input4:
209
+ body = {
210
+ "version": version,
211
+ "input": {
212
+ "prompt": input1,
213
+ "negative_prompt": input2,
214
+ "mask": data_uri_mask,
215
+ "width": input5,
216
+ "height": input6,
217
+ "num_outputs": input7,
218
+ "scheduler": input8,
219
+ "num_inference_steps": input9,
220
+ "guidance_scale": input10,
221
+ "prompt_strength":input11,
222
+ "seed": input12,
223
+ "refine": input13,
224
+ "high_noise_frac": input14,
225
+ "refine_steps": input15
226
+
227
+ }
228
+ }
229
+ else:
230
+ body = {
231
+ "version": version,
232
+ "input": {
233
+ "prompt": input1,
234
+ "negative_prompt": input2,
235
+ "width": input5,
236
+ "height": input6,
237
+ "num_outputs": input7,
238
+ "scheduler": input8,
239
+ "num_inference_steps": input9,
240
+ "guidance_scale": input10,
241
+ "prompt_strength":input11,
242
+ "seed": input12,
243
+ "refine": input13,
244
+ "high_noise_frac": input14,
245
+ "refine_steps": input15
246
+
247
+ }
248
+ }
249
+
250
+
251
+
252
+ response = requests.post(url, json=body)
253
+ print(response.status_code)
254
+ if response.status_code == 201:
255
+ response_data = response.json()
256
+ get_url = response_data.get('urls','').get('get','')
257
+ identifier = 'https://replicate.com/api/predictions/'+get_url.split("/")[-1]
258
+
259
+ print(identifier,'')
260
+ time.sleep(3)
261
+ output =verify_image(identifier)
262
+ print(output,'333')
263
+ if output:
264
+ if len(output) == 1:
265
+ return gr.Image(value=output[0]), gr.Image(),gr.Image(),gr.Image()
266
+ elif len(output) == 2:
267
+ return gr.Image(value=output[0]), gr.Image(value=output[1],visible= True),gr.Image(),gr.Image()
268
+ elif len(output) == 3:
269
+ return gr.Image(value=output[0]), gr.Image(value=output[1],visible= True),gr.Image(value=output[2],visible= True),gr.Image()
270
+ elif len(output) == 3:
271
+ return gr.Image(value=output[0]), gr.Image(value=output[1],visible= True),gr.Image(value=output[2],visible= True),gr.Image(value=output[2],visible= True)
272
+
273
+ return gr.Image(),gr.Image(visible=False),gr.Image(visible=False),gr.Image(visible=False)
274
+
275
+ def cancel_process(input1,input2,input3,input4,input5,input6,input7, input8, input9,input10,input11,input12, input13,input14, input15, input16, input17):
276
+ global cancel_url
277
+ cancel_url = '123'
278
+ return gr.Image(value='https://replicate.delivery/pbxt/koQLfGV4o8yWGi4reeIvJQwCxmxrD3S7iQFGre8IfISrpnCTC/out-0.png'), gr.Image(visible=False),gr.Image(visible=False),gr.Image(visible=False)
279
+
280
+ def verify_image(get_url):
281
+ res = requests.get(get_url)
282
+ if res.status_code == 200:
283
+ res_data = res.json()
284
+ if res_data.get('error',''):
285
+ return
286
+ else:
287
+ if cancel_url:
288
+ return
289
+ else:
290
+ output = res_data.get('output', [])
291
+ print(output,'111')
292
+ if output:
293
+ print(output,'222')
294
+ return output
295
+
296
+ else:
297
+ time.sleep(1)
298
+ val = verify_image(get_url)
299
+ return val
300
+ else:
301
+ return []
302
+
303
+ run_btn.click(run_process, inputs=inputs, outputs=outputs, api_name="run")
304
+ cancel_btn.click(cancel_process, inputs=inputs, outputs=outputs, api_name="cancel")
305
+
306
+ demo.launch()
307
+
308
+
309
+
310
+
311
+