Yemin Shi commited on
Commit
4417ee8
1 Parent(s): fa7b0e2

update APIs

Browse files
Files changed (1) hide show
  1. app.py +38 -25
app.py CHANGED
@@ -57,7 +57,6 @@ def upload_image(img):
57
  def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
58
  data = {
59
  "type": "gen-image",
60
- "gen_image_num": image_num,
61
  "parameters": {
62
  "width": width, # output height width
63
  "height": height, # output image height
@@ -86,35 +85,43 @@ def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
86
  "height": mask.height,
87
  }
88
  headers = {"token": token}
 
89
  # Send create task request
90
- # url = "http://flagstudio.baai.ac.cn/api/v1/task/create"
91
  url = url_host+"/api/v1/task/create"
92
- r = requests.post(url, json=data, headers=headers)
93
- if r.status_code != 200:
94
- raise gr.Error(r.reason)
95
- create_res = r.json()
96
- task_id = create_res["data"]["task_id"]
 
97
 
98
  # Get result
99
  url = url_host+"/api/v1/task/status"
 
100
  while True:
101
- r = requests.post(url, json=create_res["data"], headers=headers)
102
- if r.status_code != 200:
103
- raise gr.Error(r.reason)
104
- res = r.json()
105
- if res["code"] == 6002:
106
- # Running
107
- time.sleep(1)
108
- continue
109
- elif res["code"] == 0:
110
- # Finished
111
- images = []
112
- for img_info in res["data"]["images"]:
113
- img_res = requests.get(img_info["url"])
114
- images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
115
  return images
116
- else:
117
- raise gr.Error(f"Error code: {res['code']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
120
  if filter_content(class_draw) != "国画":
@@ -133,7 +140,13 @@ def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
133
 
134
 
135
  def img2img(prompt, image_and_mask):
136
- return post_reqest(0, prompt, 512, 512, 1, image_and_mask["image"], image_and_mask["mask"])
 
 
 
 
 
 
137
 
138
 
139
  examples = [
@@ -310,4 +323,4 @@ if __name__ == "__main__":
310
  gr.HTML(read_content("footer.html"))
311
  # gr.Image('./contributors.png')
312
 
313
- block.queue(max_size=50, concurrency_count=20).launch()
 
57
  def post_reqest(seed, prompt, width, height, image_num, img=None, mask=None):
58
  data = {
59
  "type": "gen-image",
 
60
  "parameters": {
61
  "width": width, # output height width
62
  "height": height, # output image height
 
85
  "height": mask.height,
86
  }
87
  headers = {"token": token}
88
+
89
  # Send create task request
90
+ all_task_data = []
91
  url = url_host+"/api/v1/task/create"
92
+ for _ in range(image_num):
93
+ r = requests.post(url, json=data, headers=headers)
94
+ if r.status_code != 200:
95
+ raise gr.Error(r.reason)
96
+ create_res = r.json()
97
+ all_task_data.append(create_res["data"])
98
 
99
  # Get result
100
  url = url_host+"/api/v1/task/status"
101
+ images = []
102
  while True:
103
+ if len(all_task_data) <= 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  return images
105
+ for i in range(len(all_task_data)-1, -1, -1):
106
+ data = all_task_data[i]
107
+ r = requests.post(url, json=data, headers=headers)
108
+ if r.status_code != 200:
109
+ raise gr.Error(r.reason)
110
+ res = r.json()
111
+ if res["code"] == 6002:
112
+ # Running
113
+ continue
114
+ if res["code"] == 6005:
115
+ raise gr.Error("NSFW image detected.")
116
+ elif res["code"] == 0:
117
+ # Finished
118
+ for img_info in res["data"]["images"]:
119
+ img_res = requests.get(img_info["url"])
120
+ images.append(Image.open(io.BytesIO(img_res.content)).convert("RGB"))
121
+ del all_task_data[i]
122
+ else:
123
+ raise gr.Error(f"Error code: {res['code']}")
124
+ time.sleep(1)
125
 
126
  def request_images(raw_text, class_draw, style_draw, batch_size, w, h, seed):
127
  if filter_content(class_draw) != "国画":
 
140
 
141
 
142
  def img2img(prompt, image_and_mask):
143
+ if image_and_mask["image"].width <= image_and_mask["image"].height:
144
+ width = 512
145
+ height = int((width/image_and_mask["image"].width)*image_and_mask["image"].height)
146
+ else:
147
+ height = 512
148
+ width = int((height/image_and_mask["image"].height)*image_and_mask["image"].width)
149
+ return post_reqest(0, prompt, width, height, 1, image_and_mask["image"], image_and_mask["mask"])
150
 
151
 
152
  examples = [
 
323
  gr.HTML(read_content("footer.html"))
324
  # gr.Image('./contributors.png')
325
 
326
+ block.queue(max_size=100, concurrency_count=50).launch()