language-model commited on
Commit
36b1569
·
verified ·
1 Parent(s): 02a31c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from requests.exceptions import ConnectTimeout
3
+ import time
4
+ import requests
5
+ import base64
6
+
7
+ global headers
8
+ global cancel_url
9
+ global path
10
+ global output_image
11
+ global property_name_array
12
+ property_name_array =[]
13
+ output_image = ''
14
+ path = ''
15
+ cancel_url =''
16
+ headers = {
17
+ 'Content-Type': 'application/json',
18
+ 'Authorization': 'Token r8_ZGZlzThfRkPZVDMygVclY1XZ9AuxmIQ2qwwPP',
19
+ "Access-Control-Allow-Headers": "Content-Type",
20
+ "Access-Control-Allow-Origin": '**',
21
+ "Access-Control-Allow-Methods": "OPTIONS,POST,GET,PATCH"}
22
+
23
+ with gr.Blocks() as demo:
24
+ owner = "adirik"
25
+ name = "mamba-130m"
26
+ max_retries = 3
27
+ retry_delay = 2
28
+ for retry in range(max_retries):
29
+ try:
30
+ url = f'https://api.replicate.com/v1/models/{owner}/{name}'
31
+ response = requests.get(url, headers=headers, timeout=10)
32
+ # Process the response
33
+ break # Break out of the loop if the request is successful
34
+ except ConnectTimeout:
35
+ if retry < max_retries - 1:
36
+ print(f"Connection timed out. Retrying in {retry_delay} seconds...")
37
+ time.sleep(retry_delay)
38
+ else:
39
+ print("Max retries exceeded. Unable to establish connection.")
40
+
41
+ data = response.json()
42
+ description =data.get("description", '')
43
+ title = data.get("default_example",'').get("model",'')
44
+ version = data.get("default_example",'').get("version",'')
45
+
46
+ gr.Markdown(
47
+ f"""
48
+ # {title}
49
+ {description}
50
+ """)
51
+
52
+ with gr.Row():
53
+ with gr.Column():
54
+ inputs =[]
55
+ schema = data.get("latest_version", {}).get("openapi_schema", {}).get("components", {}).get("schemas", {})
56
+ ordered_properties = sorted(schema.get("Input", {}).get("properties", {}).items(), key=lambda x: x[1].get("x-order", 0))
57
+ required = schema.get("Input", '').get('required', [])
58
+ for property_name, property_info in ordered_properties :
59
+ property_name_array.append(property_name)
60
+ if required:
61
+ for item in required:
62
+ if item == property_name:
63
+ label = "*"+ property_info.get('title', '')
64
+ description = property_info.get('description','')
65
+ break
66
+ else:
67
+ label = property_info.get('title', '')
68
+ description = property_info.get('description','')
69
+ else:
70
+ label = property_info.get('title', '')
71
+ description = property_info.get('description','')
72
+
73
+ if "x-order" in property_info:
74
+ order = int(property_info.get('x-order',''))
75
+ if property_info.get("type", {}) == "integer":
76
+ value= data.get('default_example', '').get('input','').get(property_name,0)
77
+ if "minimum" and "maximum" in property_info:
78
+ if value == 0:
79
+ inputs.insert(order, gr.Slider(label=label, info= description, value=property_info.get('default', value), minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', ''), step=1))
80
+ else:
81
+ inputs.insert(order, gr.Slider(label=label, info= description, value=value, minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', ''), step=1))
82
+ else:
83
+ if value == 0:
84
+ inputs.insert(order, gr.Number(label=label, info= description, value=property_info.get('default', value)))
85
+ else:
86
+ inputs.insert(order, gr.Number(label=label, info= description, value=value))
87
+
88
+ elif property_info.get("type", {}) == "string":
89
+ value= data.get('default_example', '').get('input','').get(property_name,'')
90
+ if property_info.get('format','') == 'uri':
91
+
92
+ if value :
93
+ inputs.insert(order, gr.Image(label=label, value=value, type="filepath"))
94
+ else :
95
+ inputs.insert(order, gr.Image(label=label, type="filepath"))
96
+
97
+ else:
98
+ if value == '':
99
+ inputs.insert(order, gr.Textbox(label=label,info= description, value=property_info.get('default', value)))
100
+ else:
101
+ inputs.insert(order, gr.Textbox(label=label,info= description, value=value))
102
+
103
+ elif property_info.get("type", {}) == "number":
104
+ value= data.get('default_example', '').get('input','').get(property_name, 0)
105
+ if "minimum" and "maximum" in property_info:
106
+ if value == 0:
107
+ inputs.insert(order, gr.Slider(label=label,info= description, value=property_info.get('default', value), minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', '')))
108
+ else:
109
+ inputs.insert(order, gr.Slider(label=label,info= description, value=value, minimum=property_info.get('minimum', ''), maximum=property_info.get('maximum', '')))
110
+ else:
111
+ if value == 0:
112
+ inputs.insert(order, gr.Number(label=label,info= description, value=property_info.get('default', value)))
113
+ else:
114
+ inputs.insert(order, gr.Number(label=label,info= description, value=value))
115
+ elif property_info.get("type", {}) == "boolean":
116
+ value= data.get('default_example', '').get('input','').get(property_name,'')
117
+ if value == '':
118
+ inputs.insert(order, gr.Checkbox(label=label,info= description, value=property_info.get('default', value)))
119
+ else:
120
+ inputs.insert(order, gr.Checkbox(label=label,info= description, value=value))
121
+ else:
122
+ value= data.get('default_example', '').get('input','').get(property_name,'')
123
+ options=schema.get(property_name,'').get('enum',[])
124
+ if value:
125
+ inputs.insert(order, gr.Dropdown(label=property_name,info= description,choices=options, value=property_info.get("default", value)))
126
+ else:
127
+ inputs.insert(order, gr.Dropdown(label=property_name,info= description,choices=options, value=value))
128
+
129
+ with gr.Row():
130
+ cancel_btn = gr.Button("Cancel")
131
+ run_btn = gr.Button("Run")
132
+
133
+ with gr.Column():
134
+
135
+ outputs = []
136
+
137
+ output_result = data.get("default_example", '').get("output")
138
+ output_type= schema.get("Output", '').get("type", '')
139
+ if output_type == 'array':
140
+ output_image = ''.join(output_result)
141
+ else:
142
+ output_image = output_result
143
+ outputs.append(gr.TextArea(value=output_image))
144
+ outputs.append(gr.Image(visible=False))
145
+ outputs.append(gr.Image(visible=False))
146
+ outputs.append(gr.Image(visible=False))
147
+
148
+
149
+
150
+ def run_process(input1, input2, input3, input4, input5, input6, input7):
151
+ global cancel_url
152
+ global property_name_array
153
+ print(len(property_name_array))
154
+ cancel_url=''
155
+ url = 'https://replicate.com/api/predictions'
156
+
157
+ body = {
158
+ "version": version,
159
+ "input": {
160
+ property_name_array[0]: input1,
161
+ property_name_array[1]: input2,
162
+ property_name_array[2]: input3,
163
+ property_name_array[3]: input4,
164
+ }
165
+ }
166
+
167
+
168
+ response = requests.post(url, json=body)
169
+ print(response.status_code)
170
+ if response.status_code == 201:
171
+ response_data = response.json()
172
+ get_url = response_data.get('urls','').get('get','')
173
+ identifier = 'https://replicate.com/api/predictions/'+get_url.split("/")[-1]
174
+
175
+ print(identifier,'')
176
+ time.sleep(3)
177
+ output =verify_image(identifier)
178
+ if output:
179
+ return gr.TextArea(value=''.join(output)), gr.Image(),gr.Image(),gr.Image()
180
+
181
+ return gr.Image(),gr.Image(visible=False),gr.Image(visible=False),gr.Image(visible=False)
182
+
183
+ def cancel_process(input1, input2, input3, input4, input5, input6, input7):
184
+ global cancel_url
185
+ cancel_url = '123'
186
+ global output_image
187
+ return gr.TextArea(value=output_image), gr.Image(visible=False),gr.Image(visible=False),gr.Image(visible=False)
188
+
189
+ def verify_image(get_url):
190
+ res = requests.get(get_url)
191
+ if res.status_code == 200:
192
+ res_data = res.json()
193
+ if res_data.get('error',''):
194
+ return
195
+ else:
196
+ if cancel_url:
197
+ return
198
+ else:
199
+ output = res_data.get('output', [])
200
+ if output:
201
+ return output
202
+
203
+ else:
204
+ time.sleep(1)
205
+ val = verify_image(get_url)
206
+ return val
207
+ else:
208
+ return []
209
+
210
+ run_btn.click(run_process, inputs=inputs, outputs=outputs, api_name="run")
211
+ cancel_btn.click(cancel_process, inputs=inputs, outputs=outputs, api_name="cancel")
212
+
213
+ demo.launch()
214
+
215
+
216
+
217
+
218
+