Update src/utils.py
Browse files- src/utils.py +1 -51
src/utils.py
CHANGED
@@ -2,7 +2,6 @@
|
|
2 |
|
3 |
import imp
|
4 |
import numpy as np
|
5 |
-
import cv2
|
6 |
import torch
|
7 |
import random
|
8 |
from PIL import Image, ImageDraw, ImageFont
|
@@ -207,18 +206,6 @@ def load_512(image_path, left=0, right=0, top=0, bottom=0):
|
|
207 |
image = np.array(Image.fromarray(image).resize((512, 512)))
|
208 |
return image
|
209 |
|
210 |
-
def get_canny(image_path):
|
211 |
-
image = load_512(
|
212 |
-
image_path
|
213 |
-
)
|
214 |
-
image = np.array(image)
|
215 |
-
|
216 |
-
# get canny image
|
217 |
-
image = cv2.Canny(image, 100, 200)
|
218 |
-
image = image[:, :, None]
|
219 |
-
image = np.concatenate([image, image, image], axis=2)
|
220 |
-
canny_image = Image.fromarray(image)
|
221 |
-
return canny_image
|
222 |
|
223 |
|
224 |
def get_scribble(image_path, hed):
|
@@ -229,44 +216,7 @@ def get_scribble(image_path, hed):
|
|
229 |
|
230 |
return image
|
231 |
|
232 |
-
|
233 |
-
data_ls = []
|
234 |
-
with open(prompt_path) as f:
|
235 |
-
prompt_ls = json.load(f)
|
236 |
-
img_path = 'COCO2017-val/val2017'
|
237 |
-
for prompt in tqdm(prompt_ls):
|
238 |
-
caption = prompt['caption'].replace('/','_')
|
239 |
-
image_id = str(prompt['image_id'])
|
240 |
-
image_id = (12-len(image_id))*'0' + image_id+'.jpg'
|
241 |
-
image_path = os.path.join(img_path, image_id)
|
242 |
-
try:
|
243 |
-
image = get_canny(image_path)
|
244 |
-
except:
|
245 |
-
continue
|
246 |
-
curr_data = {'image':image, 'prompt':caption}
|
247 |
-
data_ls.append(curr_data)
|
248 |
-
return data_ls
|
249 |
-
|
250 |
-
def get_cocoimages2(prompt_path):
|
251 |
-
"""scribble condition
|
252 |
-
"""
|
253 |
-
data_ls = []
|
254 |
-
with open(prompt_path) as f:
|
255 |
-
prompt_ls = json.load(f)
|
256 |
-
img_path = 'COCO2017-val/val2017'
|
257 |
-
hed = HEDdetector.from_pretrained('ControlNet/detector_weights/annotator', filename='network-bsds500.pth')
|
258 |
-
for prompt in tqdm(prompt_ls):
|
259 |
-
caption = prompt['caption'].replace('/','_')
|
260 |
-
image_id = str(prompt['image_id'])
|
261 |
-
image_id = (12-len(image_id))*'0' + image_id+'.jpg'
|
262 |
-
image_path = os.path.join(img_path, image_id)
|
263 |
-
try:
|
264 |
-
image = get_scribble(image_path,hed)
|
265 |
-
except:
|
266 |
-
continue
|
267 |
-
curr_data = {'image':image, 'prompt':caption}
|
268 |
-
data_ls.append(curr_data)
|
269 |
-
return data_ls
|
270 |
|
271 |
def warpped_feature(sample, step):
|
272 |
"""
|
|
|
2 |
|
3 |
import imp
|
4 |
import numpy as np
|
|
|
5 |
import torch
|
6 |
import random
|
7 |
from PIL import Image, ImageDraw, ImageFont
|
|
|
206 |
image = np.array(Image.fromarray(image).resize((512, 512)))
|
207 |
return image
|
208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
|
210 |
|
211 |
def get_scribble(image_path, hed):
|
|
|
216 |
|
217 |
return image
|
218 |
|
219 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
def warpped_feature(sample, step):
|
222 |
"""
|