mart9992 committed on
Commit
a32aec9
1 Parent(s): 06ba6ce
Files changed (2)
  1. handler.py +1070 -0
  2. requirements.txt +5 -3
handler.py ADDED
@@ -0,0 +1,1070 @@
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ import subprocess, io, os, sys, time, random  # random is used by lama_cleaner_process below
+
+ is_production = True
+ os.environ['CUDA_HOME'] = '/usr/local/cuda-11.7/' if is_production else '/usr/local/cuda-12.1/'
+
+ run_gradio = False
+
+ if run_gradio:
+     os.system("pip install gradio==3.50.2")
+
+ import gradio as gr
+
+ from loguru import logger
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ if is_production:
+     os.chdir("/repository")
+     sys.path.insert(0, '/repository')
+
+ if os.environ.get('IS_MY_DEBUG') is None:
+     result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
+     print(f'pip install GroundingDINO = {result}')
+
+ # result = subprocess.run(['pip', 'list'], check=True)
+ # print(f'pip list = {result}')
+
+ sys.path.insert(0, '/repository/GroundingDINO' if is_production else "./GroundingDINO")
+
+ import argparse
+ import copy
+
+ import numpy as np
+ import torch
+ from PIL import Image, ImageDraw, ImageFont, ImageOps
+
+ # Grounding DINO
+ import GroundingDINO.groundingdino.datasets.transforms as T
+ from GroundingDINO.groundingdino.models import build_model
+ from GroundingDINO.groundingdino.util import box_ops
+ from GroundingDINO.groundingdino.util.slconfig import SLConfig
+ from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+
+ import cv2
+ import numpy as np
+ import matplotlib
+ matplotlib.use('AGG')
+ plt = matplotlib.pyplot
+ # import matplotlib.pyplot as plt
+
+ groundingdino_enable = True
+ sam_enable = True
+ inpainting_enable = True
+ ram_enable = True
+
+ lama_cleaner_enable = True
+
+ kosmos_enable = False
+
+ # qwen_enable = True
+ # from qwen_utils import *
+
+ if os.environ.get('IS_MY_DEBUG') is not None:
+     sam_enable = False
+     ram_enable = False
+     inpainting_enable = False
+     kosmos_enable = False
+
+ if lama_cleaner_enable:
+     try:
+         from lama_cleaner.model_manager import ModelManager
+         from lama_cleaner.schema import Config as lama_Config
+     except Exception as e:
+         lama_cleaner_enable = False
+
+ # segment anything
+ from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
+
+ # diffusers
+ import PIL
+ import requests
+ import torch
+ from io import BytesIO
+ from diffusers import StableDiffusionInpaintPipeline
+ from huggingface_hub import hf_hub_download
+
+ from util_computer import computer_info
+
+ # relate anything
+ from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
+ from ram_train_eval import RamModel, RamPredictor
+ from mmengine.config import Config as mmengine_Config
+
+ if lama_cleaner_enable:
+     from lama_cleaner.helper import (
+         load_img,
+         numpy_to_bytes,
+         resize_max_size,
+     )
+
+ # from transformers import AutoProcessor, AutoModelForVision2Seq
+ import ast
+
+ if kosmos_enable:
+     os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main")
+     # os.system("pip install transformers==4.32.0")
+
+     from kosmos_utils import *
+
+ from util_tencent import getTextTrans
+
+ config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
+ ckpt_repo_id = "ShilongLiu/GroundingDINO"
+ ckpt_filename = "groundingdino_swint_ogc.pth"
+ sam_checkpoint = './sam_vit_h_4b8939.pth'
+ output_dir = "outputs"
+
+ device = 'cpu'
+ os.makedirs(output_dir, exist_ok=True)
+ groundingdino_model = None
+ sam_device = "cuda"
+ sam_model = None
+
+
+ def get_sam_vit_h_4b8939():
+     if not os.path.exists('./sam_vit_h_4b8939.pth'):
+         logger.info(f"get sam_vit_h_4b8939.pth...")
+         result = subprocess.run(['wget', 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'], check=True)
+         print(f'wget sam_vit_h_4b8939.pth result = {result}')
+
+ get_sam_vit_h_4b8939()
+ logger.info(f"initialize SAM model...")
+ sam_device = "cuda"
+ sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
+ sam_predictor = SamPredictor(sam_model)
+ sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
+
+ sam_mask_generator = None
+ sd_model = None
+ lama_cleaner_model = None
+ ram_model = None
+ kosmos_model = None
+ kosmos_processor = None
+
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
+     args = SLConfig.fromfile(model_config_path)
+     model = build_model(args)
+     args.device = device
+
+     cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
+     checkpoint = torch.load(cache_file, map_location=device)
+     log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
+     print("Model loaded from {} \n => {}".format(cache_file, log))
+     _ = model.eval()
+     return model
+
+ def plot_boxes_to_image(image_pil, tgt):
+     H, W = tgt["size"]
+     boxes = tgt["boxes"]
+     labels = tgt["labels"]
+     assert len(boxes) == len(labels), "boxes and labels must have same length"
+
+     draw = ImageDraw.Draw(image_pil)
+     mask = Image.new("L", image_pil.size, 0)
+     mask_draw = ImageDraw.Draw(mask)
+
+     # draw boxes and masks
+     for box, label in zip(boxes, labels):
+         # from 0..1 to 0..W, 0..H
+         box = box * torch.Tensor([W, H, W, H])
+         # from xywh to xyxy
+         box[:2] -= box[2:] / 2
+         box[2:] += box[:2]
+         # random color
+         color = tuple(np.random.randint(0, 255, size=3).tolist())
+         # draw
+         x0, y0, x1, y1 = box
+         x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
+
+         draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
+         # draw.text((x0, y0), str(label), fill=color)
+
+         font = ImageFont.load_default()
+         if hasattr(font, "getbbox"):
+             bbox = draw.textbbox((x0, y0), str(label), font)
+         else:
+             w, h = draw.textsize(str(label), font)
+             bbox = (x0, y0, w + x0, y0 + h)
+         # bbox = draw.textbbox((x0, y0), str(label))
+         draw.rectangle(bbox, fill=color)
+
+         try:
+             font = os.path.join(cv2.__path__[0],'qt','fonts','DejaVuSans.ttf')
+             font_size = 36
+             new_font = ImageFont.truetype(font, font_size)
+
+             draw.text((x0+2, y0+2), str(label), font=new_font, fill="white")
+         except Exception as e:
+             pass
+
+         mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)
+
+
+     return image_pil, mask
+
+ def load_image(image_path):
+     # # load image
+     if isinstance(image_path, PIL.Image.Image):
+         image_pil = image_path
+     else:
+         image_pil = Image.open(image_path).convert("RGB")  # load image
+
+     transform = T.Compose(
+         [
+             T.RandomResize([800], max_size=1333),
+             T.ToTensor(),
+             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+     image, _ = transform(image_pil, None)  # 3, h, w
+     return image_pil, image
+
+ def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
+     caption = caption.lower()
+     caption = caption.strip()
+     if not caption.endswith("."):
+         caption = caption + "."
+     model = model.to(device)
+     image = image.to(device)
+     with torch.no_grad():
+         outputs = model(image[None], captions=[caption])
+     logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
+     boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)
+     logits.shape[0]
+
+     # filter output
+     logits_filt = logits.clone()
+     boxes_filt = boxes.clone()
+     filt_mask = logits_filt.max(dim=1)[0] > box_threshold
+     logits_filt = logits_filt[filt_mask]  # num_filt, 256
+     boxes_filt = boxes_filt[filt_mask]  # num_filt, 4
+     logits_filt.shape[0]
+
+     # get phrase
+     tokenizer = model.tokenizer
+     tokenized = tokenizer(caption)
+     # build pred
+     pred_phrases = []
+     for logit, box in zip(logits_filt, boxes_filt):
+         pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
+         if with_logits:
+             pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
+         else:
+             pred_phrases.append(pred_phrase)
+
+     return boxes_filt, pred_phrases
+
+ def show_mask(mask, ax, random_color=False):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30/255, 144/255, 255/255, 0.6])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+ def show_box(box, ax, label):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
+     ax.text(x0, y0, label)
+
+ def xywh_to_xyxy(box, sizeW, sizeH):
+     if isinstance(box, list):
+         box = torch.Tensor(box)
+     box = box * torch.Tensor([sizeW, sizeH, sizeW, sizeH])
+     box[:2] -= box[2:] / 2
+     box[2:] += box[:2]
+     box = box.numpy()
+     return box
+
+ def mask_extend(img, box, extend_pixels=10, useRectangle=True):
+     box[0] = int(box[0])
+     box[1] = int(box[1])
+     box[2] = int(box[2])
+     box[3] = int(box[3])
+     region = img.crop(tuple(box))
+     new_width = box[2] - box[0] + 2*extend_pixels
+     new_height = box[3] - box[1] + 2*extend_pixels
+
+     region_BILINEAR = region.resize((int(new_width), int(new_height)))
+     if useRectangle:
+         region_draw = ImageDraw.Draw(region_BILINEAR)
+         region_draw.rectangle((0, 0, new_width, new_height), fill=(255, 255, 255))
+     img.paste(region_BILINEAR, (int(box[0]-extend_pixels), int(box[1]-extend_pixels)))
+     return img
+
+ def mix_masks(imgs):
+     re_img = 1 - np.asarray(imgs[0].convert("1"))
+     for i in range(len(imgs)-1):
+         re_img = np.multiply(re_img, 1 - np.asarray(imgs[i+1].convert("1")))
+     re_img = 1 - re_img
+     return Image.fromarray(np.uint8(255*re_img))
+
+ def set_device():
+     if os.environ.get('IS_MY_DEBUG') is None:
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     else:
+         device = 'cpu'
+     print(f'device={device}')
+     return device
+
+ def load_groundingdino_model(device):
+     # initialize groundingdino model
+     logger.info(f"initialize groundingdino model...")
+     groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename, device=device) #'cpu')
+     return groundingdino_model
+
+
+
+ def load_sam_model(device):
+     # initialize SAM
+     global sam_model, sam_predictor, sam_mask_generator, sam_device
+     get_sam_vit_h_4b8939()
+     logger.info(f"initialize SAM model...")
+     sam_device = device
+     sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
+     sam_predictor = SamPredictor(sam_model)
+     sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
+
+ def load_sd_model(device):
+     # initialize stable-diffusion-inpainting
+     global sd_model
+     logger.info(f"initialize stable-diffusion-inpainting...")
+     sd_model = None
+     if os.environ.get('IS_MY_DEBUG') is None:
+         sd_model = StableDiffusionInpaintPipeline.from_pretrained(
+             "runwayml/stable-diffusion-inpainting",
+             revision="fp16",
+             # "stabilityai/stable-diffusion-2-inpainting",
+             torch_dtype=torch.float16,
+         )
+         sd_model = sd_model.to(device)
+
+ def load_lama_cleaner_model(device):
+     # initialize lama_cleaner
+     global lama_cleaner_model
+     logger.info(f"initialize lama_cleaner...")
+
+     lama_cleaner_model = ModelManager(
+         name='lama',
+         device=device,
+     )
+
+ def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
+     try:
+         logger.info(f'_______lama_cleaner_process_______1____')
+         ori_image = image
+         if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
+             # rotate image
+             logger.info(f'_______lama_cleaner_process_______2____')
+             ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
+             logger.info(f'_______lama_cleaner_process_______3____')
+             image = ori_image
+
+         logger.info(f'_______lama_cleaner_process_______4____')
+         original_shape = ori_image.shape
+         logger.info(f'_______lama_cleaner_process_______5____')
+         interpolation = cv2.INTER_CUBIC
+
+         size_limit = cleaner_size_limit
+         if size_limit == -1:
+             logger.info(f'_______lama_cleaner_process_______6____')
+             size_limit = max(image.shape)
+         else:
+             logger.info(f'_______lama_cleaner_process_______7____')
+             size_limit = int(size_limit)
+
+         logger.info(f'_______lama_cleaner_process_______8____')
+         config = lama_Config(
+             ldm_steps=25,
+             ldm_sampler='plms',
+             zits_wireframe=True,
+             hd_strategy='Original',
+             hd_strategy_crop_margin=196,
+             hd_strategy_crop_trigger_size=1280,
+             hd_strategy_resize_limit=2048,
+             prompt='',
+             use_croper=False,
+             croper_x=0,
+             croper_y=0,
+             croper_height=512,
+             croper_width=512,
+             sd_mask_blur=5,
+             sd_strength=0.75,
+             sd_steps=50,
+             sd_guidance_scale=7.5,
+             sd_sampler='ddim',
+             sd_seed=42,
+             cv2_flag='INPAINT_NS',
+             cv2_radius=5,
+         )
+
+         logger.info(f'_______lama_cleaner_process_______9____')
+         if config.sd_seed == -1:
+             config.sd_seed = random.randint(1, 999999999)
+
+         # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
+         logger.info(f'_______lama_cleaner_process_______10____')
+         image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
+         # logger.info(f"Resized image shape_1_: {image.shape}")
+
+         # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
+         logger.info(f'_______lama_cleaner_process_______11____')
+         mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
+         # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
+
+         logger.info(f'_______lama_cleaner_process_______12____')
+         res_np_img = lama_cleaner_model(image, mask, config)
+         logger.info(f'_______lama_cleaner_process_______13____')
+         torch.cuda.empty_cache()
+
+         logger.info(f'_______lama_cleaner_process_______14____')
+         image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
+         logger.info(f'_______lama_cleaner_process_______15____')
+     except Exception as e:
+         logger.info(f'lama_cleaner_process[Error]:' + str(e))
+         image = None
+     return image
+
+ class Ram_Predictor(RamPredictor):
+     def __init__(self, config, device='cpu'):
+         self.config = config
+         self.device = torch.device(device)
+         self._build_model()
+
+     def _build_model(self):
+         self.model = RamModel(**self.config.model).to(self.device)
+         if self.config.load_from is not None:
+             self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
+         self.model.train()
+
+ def load_ram_model(device):
+     # load ram model
+     global ram_model
+     if os.environ.get('IS_MY_DEBUG') is not None:
+         return
+     model_path = "./checkpoints/ram_epoch12.pth"
+     ram_config = dict(
+         model=dict(
+             pretrained_model_name_or_path='bert-base-uncased',
+             load_pretrained_weights=False,
+             num_transformer_layer=2,
+             input_feature_size=256,
+             output_feature_size=768,
+             cls_feature_size=512,
+             num_relation_classes=56,
+             pred_type='attention',
+             loss_type='multi_label_ce',
+         ),
+         load_from=model_path,
+     )
+     ram_config = mmengine_Config(ram_config)
+     ram_model = Ram_Predictor(ram_config, device)
+
+ # visualization
+ def draw_selected_mask(mask, draw):
+     color = (255, 0, 0, 153)
+     nonzero_coords = np.transpose(np.nonzero(mask))
+     for coord in nonzero_coords:
+         draw.point(coord[::-1], fill=color)
+
+ def draw_object_mask(mask, draw):
+     color = (0, 0, 255, 153)
+     nonzero_coords = np.transpose(np.nonzero(mask))
+     for coord in nonzero_coords:
+         draw.point(coord[::-1], fill=color)
+
+ def create_title_image(word1, word2, word3, width, font_path='./assets/OpenSans-Bold.ttf'):
+     # Define the colors to use for each word
+     color_red = (255, 0, 0)
+     color_black = (0, 0, 0)
+     color_blue = (0, 0, 255)
+
+     # Define the initial font size and spacing between words
+     font_size = 40
+
+     # Create a new image with the specified width and white background
+     image = Image.new('RGB', (width, 60), (255, 255, 255))
+
+     try:
+         # Load the specified font
+         font = ImageFont.truetype(font_path, font_size)
+
+         # Keep shrinking the font size until all words fit within the desired width
+         while True:
+             # Create a draw object for the image
+             draw = ImageDraw.Draw(image)
+
+             word_spacing = font_size / 2
+             # Draw each word in the appropriate color
+             x_offset = word_spacing
+             draw.text((x_offset, 0), word1, color_red, font=font)
+             x_offset += font.getsize(word1)[0] + word_spacing
+             draw.text((x_offset, 0), word2, color_black, font=font)
+             x_offset += font.getsize(word2)[0] + word_spacing
+             draw.text((x_offset, 0), word3, color_blue, font=font)
+
+             word_sizes = [font.getsize(word) for word in [word1, word2, word3]]
+             total_width = sum([size[0] for size in word_sizes]) + word_spacing * 3
+
+             # Stop once the words fit within the desired width
+             if total_width <= width:
+                 break
+
+             # Decrease the font size and reset the draw object
+             font_size -= 1
+             image = Image.new('RGB', (width, 50), (255, 255, 255))
+             font = ImageFont.truetype(font_path, font_size)
+             draw = None
+     except Exception as e:
+         pass
+
+     return image
+
+ def concatenate_images_vertical(image1, image2):
+     # Get the dimensions of the two images
+     width1, height1 = image1.size
+     width2, height2 = image2.size
+
+     # Create a new image with the combined height and the maximum width
+     new_image = Image.new('RGBA', (max(width1, width2), height1 + height2))
+
+     # Paste the first image at the top of the new image
+     new_image.paste(image1, (0, 0))
+
+     # Paste the second image below the first image
+     new_image.paste(image2, (0, height1))
+
+     return new_image
+
+ def relate_anything(input_image, k):
+     logger.info(f'relate_anything_1_{input_image.size}_')
+     w, h = input_image.size
+     max_edge = 1500
+     if w > max_edge or h > max_edge:
+         ratio = max(w, h) / max_edge
+         new_size = (int(w / ratio), int(h / ratio))
+         input_image.thumbnail(new_size)
+
+     logger.info(f'relate_anything_2_')
+     # load image
+     pil_image = input_image.convert('RGBA')
+     image = np.array(input_image)
+     sam_masks = sam_mask_generator.generate(image)
+     filtered_masks = sort_and_deduplicate(sam_masks)
+
+     logger.info(f'relate_anything_3_')
+     feat_list = []
+     for fm in filtered_masks:
+         feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
+         feat_list.append(feat)
+     feat = torch.cat(feat_list, dim=1).to(device)
+     matrix_output, rel_triplets = ram_model.predict(feat)
+
+     logger.info(f'relate_anything_4_')
+     pil_image_list = []
+     for i, rel in enumerate(rel_triplets[:k]):
+         s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
+         relation = relation_classes[r]
+
+         mask_image = Image.new('RGBA', pil_image.size, color=(0, 0, 0, 0))
+         mask_draw = ImageDraw.Draw(mask_image)
+
+         draw_selected_mask(filtered_masks[s]['segmentation'], mask_draw)
+         draw_object_mask(filtered_masks[o]['segmentation'], mask_draw)
+
+         current_pil_image = pil_image.copy()
+         current_pil_image.alpha_composite(mask_image)
+
+         title_image = create_title_image('Red', relation, 'Blue', current_pil_image.size[0])
+         concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
+         pil_image_list.append(concate_pil_image)
+
+     logger.info(f'relate_anything_5_{len(pil_image_list)}')
+     return pil_image_list
+
+ mask_source_draw = "draw a mask on input image"
+ mask_source_segment = "type what to detect below"
+
+ def get_time_cost(run_task_time, time_cost_str):
+     now_time = int(time.time()*1000)
+     if run_task_time == 0:
+         time_cost_str = 'start'
+     else:
+         if time_cost_str != '':
+             time_cost_str += f'-->'
+         time_cost_str += f'{now_time - run_task_time}'
+     run_task_time = now_time
+     return run_task_time, time_cost_str
+
+ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
+                       iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
+
+     text_prompt = getTextTrans(text_prompt, source='zh', target='en')
+     inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
+
+     run_task_time = 0
+     time_cost_str = ''
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+     text_prompt = text_prompt.strip()
+     if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
+         if text_prompt == '':
+             return [], gr.Gallery.update(label='Detection prompt is not found!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
+
+     if input_image is None:
+         return [], gr.Gallery.update(label='Please upload an image!😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
+
+     file_temp = int(time.time())
+     logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
+
+     output_images = []
+
+     image_pil, image = load_image(input_image.convert("RGB"))
+     input_img = input_image
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+     size = image_pil.size
+     H, W = size[1], size[0]
+
+     # run grounding dino model
+     if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
+         pass
+     else:
+         groundingdino_device = 'cpu'
+         if device != 'cpu':
+             try:
+                 from groundingdino import _C
+                 groundingdino_device = 'cuda:0'
+             except:
+                 warnings.warn("Failed to load custom C++ ops. Running on CPU mode Only in groundingdino!")
+
+         boxes_filt, pred_phrases = get_grounding_output(
+             groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
+         )
+         if boxes_filt.size(0) == 0:
+             logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
+             return [], gr.Gallery.update(label='No objects detected, please try others.😂😂😂😂'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
+         boxes_filt_ori = copy.deepcopy(boxes_filt)
+
+     logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
+     if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
+         image = np.array(input_img)
+         if sam_predictor:
+             sam_predictor.set_image(image)
+
+         for i in range(boxes_filt.size(0)):
+             boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
+             boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
+             boxes_filt[i][2:] += boxes_filt[i][:2]
+
+         if sam_predictor:
+             boxes_filt = boxes_filt.to(sam_device)
+             transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+
+             masks, _, _, _ = sam_predictor.predict_torch(
+                 point_coords = None,
+                 point_labels = None,
+                 boxes = transformed_boxes,
+                 multimask_output = False,
+             )
+             # masks: [9, 1, 512, 512]
+             assert sam_checkpoint, 'sam_checkpoint is not found!'
+         else:
+             masks = torch.zeros(len(boxes_filt), 1, H, W)
+             mask_count = 0
+             for box in boxes_filt:
+                 masks[mask_count, 0, int(box[1]):int(box[3]), int(box[0]):int(box[2])] = 1
+                 mask_count += 1
+             masks = torch.where(masks > 0, True, False)
+             run_mode = "rectangle"
+
+         # draw output image
+         plt.figure(figsize=(10, 10))
+         plt.imshow(image)
+         for mask in masks:
+             show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
+         for box, label in zip(boxes_filt, pred_phrases):
+             show_box(box.cpu().numpy(), plt.gca(), label)
+         plt.axis('off')
+         image_path = os.path.join(output_dir, f"grounding_seg_output_{file_temp}.jpg")
+         plt.savefig(image_path, bbox_inches="tight")
+         plt.clf()
+         plt.close('all')
+         segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
+         os.remove(image_path)
+         output_images.append(Image.fromarray(segment_image_result))
+         run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+     print(sam_predictor)
+
+     if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
+         task_type = 'remove'
+
+     logger.info(f'run_anything_task_[{file_temp}]_{task_type}_4_')
+     if mask_source_radio == mask_source_draw:
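+         # NOTE: input_mask_pil / input_mask are not defined anywhere in this file;
+         # this branch assumes a drawn sketch mask supplied by the Gradio UI.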
+         mask_pil = input_mask_pil
+         mask = input_mask
+     else:
+         masks_ori = copy.deepcopy(masks)
+         if inpaint_mode == 'merge':
+             masks = torch.sum(masks, dim=0).unsqueeze(0)
+             masks = torch.where(masks > 0, True, False)
+         mask = masks[0][0].cpu().numpy()
+         mask_pil = Image.fromarray(mask)
+     output_images.append(mask_pil.convert("RGB"))
+     return mask_pil
+
+ def change_radio_display(task_type, mask_source_radio):
+     text_prompt_visible = True
+     inpaint_prompt_visible = False
+     mask_source_radio_visible = False
+     num_relation_visible = False
+
+     image_gallery_visible = True
+     kosmos_input_visible = False
+     kosmos_output_visible = False
+     kosmos_text_output_visible = False
+
+     if task_type == "Kosmos-2":
+         if kosmos_enable:
+             text_prompt_visible = False
+             image_gallery_visible = False
+             kosmos_input_visible = True
+             kosmos_output_visible = True
+             kosmos_text_output_visible = True
+
+     if task_type == "inpainting":
+         inpaint_prompt_visible = True
+     if task_type == "inpainting" or task_type == "remove":
+         mask_source_radio_visible = True
+         if mask_source_radio == mask_source_draw:
+             text_prompt_visible = False
+     if task_type == "relate anything":
+         text_prompt_visible = False
+         num_relation_visible = True
+
+     return (gr.Textbox.update(visible=text_prompt_visible),
+             gr.Textbox.update(visible=inpaint_prompt_visible),
+             gr.Radio.update(visible=mask_source_radio_visible),
+             gr.Slider.update(visible=num_relation_visible),
+             gr.Gallery.update(visible=image_gallery_visible),
+             gr.Radio.update(visible=kosmos_input_visible),
+             gr.Image.update(visible=kosmos_output_visible),
+             gr.HighlightedText.update(visible=kosmos_text_output_visible))
+
+ def get_model_device(module):
+     try:
+         if module is None:
+             return 'None'
+         if isinstance(module, torch.nn.DataParallel):
+             module = module.module
+         for submodule in module.children():
+             if hasattr(submodule, "_parameters"):
+                 parameters = submodule._parameters
+                 if "weight" in parameters:
+                     return parameters["weight"].device
+         return 'Unknown'
+     except Exception as e:
+         return 'Error'
+
+ def main_gradio(args):
+     block = gr.Blocks().queue()
+     with block:
+         with gr.Row():
+             with gr.Column():
+                 task_types = ["detection"]
+                 if sam_enable:
+                     task_types.append("segment")
+                 if inpainting_enable:
+                     task_types.append("inpainting")
+                 if lama_cleaner_enable:
+                     task_types.append("remove")
+                 if ram_enable:
+                     task_types.append("relate anything")
+                 if kosmos_enable:
+                     task_types.append("Kosmos-2")
+
+                 input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
+                     height=512, brush_color='#00FFFF', mask_opacity=0.6)
+                 task_type = gr.Radio(task_types, value="detection",
+                     label='Task type', visible=True)
+                 mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
+                     value=mask_source_segment, label="Mask from",
+                     visible=False)
+                 text_prompt = gr.Textbox(label="Detection Prompt [to detect multiple objects, separate each with '.', like this: cat . dog . chair]", placeholder="Cannot be empty")
+                 inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
+                 num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
+
+                 kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
+
+                 run_button = gr.Button(label="Run", visible=True)
+                 with gr.Accordion("Advanced options", open=False) as advanced_options:
+                     box_threshold = gr.Slider(
+                         label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
+                     )
+                     text_threshold = gr.Slider(
+                         label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                     )
+                     iou_threshold = gr.Slider(
+                         label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
+                     )
+                     inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
+                     with gr.Row():
+                         with gr.Column(scale=1):
+                             remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
+                         with gr.Column(scale=1):
+                             remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
+
+             with gr.Column():
+                 image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
+                     ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
+                 time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
+
+                 kosmos_output = gr.Image(type="pil", label="result images", visible=False)
+                 kosmos_text_output = gr.HighlightedText(
+                     label="Generated Description",
+                     combine_adjacent=False,
+                     show_legend=True,
+                     visible=False,
+                 ).style(color_map=color_map)
+                 # record which text span (label) is selected
+                 selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
+
+                 # record the current `entities`
+                 entity_output = gr.Textbox(visible=False)
+
+                 # get the current selected span label
+                 def get_text_span_label(evt: gr.SelectData):
+                     if evt.value[-1] is None:
+                         return -1
+                     return int(evt.value[-1])
+                 # and set this information to `selected`
+                 kosmos_text_output.select(get_text_span_label, None, selected)
+
+                 # update output image when we change the span (entity) selection
+                 def update_output_image(img_input, image_output, entities, idx):
+                     entities = ast.literal_eval(entities)
+                     updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
+                     return updated_image
+                 selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
+
+         run_button.click(fn=run_anything_task, inputs=[
+                         input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
+                         iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
+                         outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
+
+         mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
+                         outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
+         task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
+                         outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
+                             image_gallery, kosmos_input, kosmos_output, kosmos_text_output
+                         ])
+
+         DESCRIPTION = f'### This demo is from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
+         if lama_cleaner_enable:
+             DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner). <br>'
+         if kosmos_enable:
+             DESCRIPTION += f'Kosmos-2 from [Kosmos-2](https://github.com/microsoft/unilm/tree/master/kosmos-2). <br>'
+         if ram_enable:
+             DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
+         DESCRIPTION += f'Thanks for their excellent work.'
+         DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \
+                         <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
+         gr.Markdown(DESCRIPTION)
+
+     print(f'device = {device}')
+     print(f'torch.cuda.is_available = {torch.cuda.is_available()}')
+     computer_info()
+     block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
+
+ import signal
+ import json
+ from datetime import date, datetime, timedelta
+ from gevent import pywsgi
+ import base64
+
+ def imgFile_to_base64(image_file):
+     with open(image_file, "rb") as f:
+         im_bytes = f.read()
+     im_b64_encode = base64.b64encode(im_bytes)
+     im_b64 = im_b64_encode.decode("utf8")
+     return im_b64
+
+ def base64_to_bytes(im_b64):
+     im_b64_encode = im_b64.encode("utf-8")
+     im_bytes = base64.b64decode(im_b64_encode)
+     return im_bytes
+
+ def base64_to_PILImage(im_b64):
+     im_bytes = base64_to_bytes(im_b64)
+     pil_img = Image.open(io.BytesIO(im_bytes))
+     return pil_img
+
+ class API_Starter:
+     def __init__(self):
+         from flask import Flask, request, jsonify, make_response
+         from flask_cors import CORS, cross_origin
+         import logging
+
+         app = Flask(__name__)
+         app.logger.setLevel(logging.ERROR)
+         CORS(app, supports_credentials=True, resources={r"/*": {"origins": "*"}})
+
+         @app.route('/imgCLeaner', methods=['GET', 'POST'])
+         @cross_origin()
+         def processAssist():
+             if request.method == 'GET':
+                 ret_json = {'code': -1, 'reason':'no support to get'}
+             elif request.method == 'POST':
+                 request_data = request.data.decode('utf-8')
+                 data = json.loads(request_data)
+                 result = self.handle_data(data)
+                 if result is None:
+                     ret_json = {'code': -2, 'reason':'handle error'}
+                 else:
+                     ret_json = {'code': 0, 'result':result}
+             return jsonify(ret_json)
+
+         self.app = app
+         now_time = datetime.now().strftime('%Y%m%d_%H%M%S')
+         logger.add(f'./logs/logger_[{args.port}]_{now_time}.log')
+         signal.signal(signal.SIGINT, self.signal_handler)
+
+     def handle_data(self, data):
+         im_b64 = data['img']
+         img = base64_to_PILImage(im_b64)
+         remove_texts = data['remove_texts']
+         remove_mask_extend = data['mask_extend']
+         results = run_anything_task(input_image = img,
+                                     text_prompt = f"{remove_texts}",
+                                     task_type = 'remove',
+                                     inpaint_prompt = '',
+                                     box_threshold = 0.3,
+                                     text_threshold = 0.25,
+                                     iou_threshold = 0.8,
+                                     inpaint_mode = "merge",
+                                     mask_source_radio = "type what to detect below",
+                                     remove_mode = "rectangle", # ["segment", "rectangle"]
+                                     remove_mask_extend = f"{remove_mask_extend}",
+                                     num_relation = 5,
+                                     kosmos_input = None,
+                                     cleaner_size_limit = -1,
+                                     )
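+         # NOTE: run_anything_task as modified above returns a single PIL mask image,
+         # so the tuple-style indexing below predates that change.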
+         output_images = results[0]
+         if output_images is None:
+             return None
+         ret_json_images = []
+         file_temp = int(time.time())
+         count = 0
+         output_images = output_images[-1:]
+         for image_pil in output_images:
+             try:
+                 img_format = image_pil.format.lower()
+             except Exception as e:
+                 img_format = 'png'
+             image_path = os.path.join(output_dir, f"api_images_{file_temp}_{count}.{img_format}")
+             count += 1
+             try:
+                 image_pil.save(image_path)
+             except Exception as e:
+                 Image.fromarray(image_pil).save(image_path)
+             im_b64 = imgFile_to_base64(image_path)
+             ret_json_images.append(im_b64)
+             os.remove(image_path)
+         data = {
+             'imgs': ret_json_images,
+         }
+         return data
+
+     def signal_handler(self, signal, frame):
+         print('\nSignal caught! You just typed Ctrl+C!')
+         sys.exit(0)
+
+     def run(self):
+         from gevent import pywsgi
+         logger.info(f'\nargs={args}\n')
+         computer_info()
+         print(f"Starting API server: http://0.0.0.0:{args.port}/imgCLeaner")
+         server = pywsgi.WSGIServer(('0.0.0.0', args.port), self.app)
+         server.serve_forever()
+
+ def main_api(args):
+     if args.port == 0:
+         print('Please give a valid port!')
+     else:
+         api_starter = API_Starter()
+         api_starter.run()
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
+     parser.add_argument("--debug", action="store_true", help="using debug mode")
+     parser.add_argument("--share", action="store_true", help="share the app")
+     parser.add_argument("--port", "-p", type=int, default=7860, help="port")
+     args, _ = parser.parse_known_args()
+     print(f'args = {args}')
+
+     if os.environ.get('IS_MY_DEBUG') is None:
+         os.system("pip list")
+
+     device = set_device()
+     if device == 'cpu':
+         kosmos_enable = False
+
+     if kosmos_enable:
+         kosmos_model, kosmos_processor = load_kosmos_model(device)
+
+     if groundingdino_enable:
+         groundingdino_model = load_groundingdino_model('cpu')
+
+     if sam_enable:
+         load_sam_model(device)
+
+     if inpainting_enable:
+         load_sd_model(device)
+
+     if lama_cleaner_enable:
+         load_lama_cleaner_model(device)
+
+     if ram_enable:
+         load_ram_model(device)
+
+     if os.environ.get('IS_MY_DEBUG') is None:
+         os.system("pip list")
+
+     def just_fucking_get_sd_mask(input_pil, prompt):
+         return run_anything_task(input_pil, prompt, "inpainting", "", 0.3, 0.25, 0.8, "merge", "type what to detect below", "segment", "10", 5, "Brief")
+
+     just_fucking_get_sd_mask(Image.open("chick.png"), "face . shoes").save("fucking.png")
+     just_fucking_get_sd_mask(Image.open("chick.png"), "face . shoes").save("fucking2.png")
+
+     class EndpointHandler():
+         def __init__(self, path=""):
+             pass
+
+         def __call__(self, data):
+             original_link = data.get("original_link")
+             response = requests.get(original_link, verify=False)
+             byte_arr = response.content
+             original_image = Image.open(io.BytesIO(byte_arr))
+
+             mask_pil = just_fucking_get_sd_mask(original_image, "person")
+
+             img_byte_arr = io.BytesIO()
+             mask_pil.save(img_byte_arr, format="PNG")
+             img_byte_arr = img_byte_arr.getvalue()
+
+             # Upload to file.io
+             response = requests.post('https://file.io', files={'file': img_byte_arr})
+             link = response.json()['link']
+
+             return link
+
+     print(EndpointHandler()({
+         "original_link": "https://cdn.karneval-megastore.de/images/rep_art/gra/310/6/310698/justice-league-wonder-woman-damenkostum-lizenzware-blau-gold-rot.jpg"
+     }))
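
For reference, a minimal client sketch for the /imgCLeaner Flask route defined in this commit. It is an illustration under assumptions: it presumes main_api(args) is actually invoked so the pywsgi server listens on the default port 7860 (the __main__ block above runs the EndpointHandler test instead), and input.png / result.png are placeholder file names. The payload keys and the response shape follow handle_data and processAssist.

import base64, json, requests

# encode the input image the way handle_data expects ('img' is base64)
with open('input.png', 'rb') as f:
    im_b64 = base64.b64encode(f.read()).decode('utf8')

# keys read by handle_data; processAssist returns {'code': 0, 'result': {'imgs': [...]}} on success
payload = {'img': im_b64, 'remove_texts': 'person', 'mask_extend': '10'}
resp = requests.post('http://0.0.0.0:7860/imgCLeaner', data=json.dumps(payload))
ret = resp.json()
if ret['code'] == 0:
    # each entry in 'imgs' is a base64-encoded result image
    with open('result.png', 'wb') as f:
        f.write(base64.b64decode(ret['result']['imgs'][0]))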
requirements.txt CHANGED
@@ -15,8 +15,10 @@ setuptools
  supervision
  termcolor
  timm
- torch==2.0.0
- torchvision==0.15.1
+ # torch==2.0.0 # is production
+ # torchvision==0.15.1 # is production
+ # torch
+ # torchvision
 
  gevent
  yapf
@@ -33,7 +35,7 @@ transformers==4.27.4
  # lama-cleaner==1.2.4
  lama-cleaner@git+https://github.com/yizhangliu/lama-cleaner.git@main
 
- mmcv==2.0.0
+ mmcv==1.7.1
  mmengine
  openmim==0.3.9
 