Update handler.py
handler.py  CHANGED  +13 -441
@@ -2,7 +2,7 @@ import string
 import warnings
 warnings.filterwarnings('ignore')
 import subprocess, io, os, sys, time
-
+import random

 # os.environ["XFORMERS_DISABLE_FLASH_ATTN"] = "1"
 # result = subprocess.run(['pip', 'install', 'xformers'], check=True)
@@ -77,19 +77,6 @@ kosmos_enable = False
 # qwen_enable = True
 # from qwen_utils import *

-if os.environ.get('IS_MY_DEBUG') is not None:
-    sam_enable = False
-    ram_enable = False
-    inpainting_enable = False
-    kosmos_enable = False
-
-if lama_cleaner_enable:
-    try:
-        from lama_cleaner.model_manager import ModelManager
-        from lama_cleaner.schema import Config as lama_Config
-    except Exception as e:
-        lama_cleaner_enable = False
-
 # segment anything
 from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator

@@ -98,34 +85,8 @@ import PIL
 import requests
 import torch
 from io import BytesIO
-from diffusers import StableDiffusionInpaintPipeline
 from huggingface_hub import hf_hub_download

-from util_computer import computer_info
-
-# relate anything
-from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
-from ram_train_eval import RamModel, RamPredictor
-from mmengine.config import Config as mmengine_Config
-
-if lama_cleaner_enable:
-    from lama_cleaner.helper import (
-        load_img,
-        numpy_to_bytes,
-        resize_max_size,
-    )
-
-# from transformers import AutoProcessor, AutoModelForVision2Seq
-import ast
-
-if kosmos_enable and install_stuff:
-    os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main")
-    # os.system("pip install transformers==4.32.0")
-
-    from kosmos_utils import *
-
-from util_tencent import getTextTrans
-
 config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
 ckpt_repo_id = "ShilongLiu/GroundingDINO"
 ckpt_filenmae = "groundingdino_swint_ogc.pth"
@@ -133,10 +94,7 @@ sam_checkpoint = './sam_hq_vit_h.pth'
 output_dir = "outputs"

 device = 'cpu'
-os.makedirs(output_dir, exist_ok=True)
-groundingdino_model = None
 sam_device = "cuda"
-sam_model = None


 def get_sam_vit_h_4b8939():
@@ -150,20 +108,20 @@ def get_sam_vit_h_4b8939():
             f.write(response.content)
         print('Downloaded sam_vit_h_4b8939.pth')

-get_sam_vit_h_4b8939()
 logger.info(f"initialize SAM model...")
 sam_device = "cuda"
-sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
-sam_predictor = SamPredictor(sam_model)
-sam_mask_generator = SamAutomaticMaskGenerator(sam_model)

-sam_mask_generator = None
 sd_model = None
 lama_cleaner_model= None
 ram_model = None
 kosmos_model = None
 kosmos_processor = None

+get_sam_vit_h_4b8939()
+sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
+sam_predictor = SamPredictor(sam_model)
+sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
+
 def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
     args = SLConfig.fromfile(model_config_path)
     model = build_model(args)
@@ -324,167 +282,6 @@ def mix_masks(imgs):
     re_img = 1 - re_img
     return Image.fromarray(np.uint8(255*re_img))

-def set_device():
-    if os.environ.get('IS_MY_DEBUG') is None:
-        device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    else:
-        device = 'cpu'
-    print(f'device={device}')
-    return device
-
-def load_groundingdino_model(device):
-    # initialize groundingdino model
-    logger.info(f"initialize groundingdino model...")
-    groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device) #'cpu')
-    return groundingdino_model
-
-
-
-def load_sam_model(device):
-    # initialize SAM
-    global sam_model, sam_predictor, sam_mask_generator, sam_device
-    get_sam_vit_h_4b8939()
-    logger.info(f"initialize SAM model...")
-    sam_device = device
-    sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
-    sam_predictor = SamPredictor(sam_model)
-    sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
-
-def load_sd_model(device):
-    # initialize stable-diffusion-inpainting
-    global sd_model
-    logger.info(f"initialize stable-diffusion-inpainting...")
-    sd_model = None
-    if os.environ.get('IS_MY_DEBUG') is None:
-        sd_model = StableDiffusionInpaintPipeline.from_pretrained(
-                "runwayml/stable-diffusion-inpainting",
-                revision="fp16",
-                # "stabilityai/stable-diffusion-2-inpainting",
-                torch_dtype=torch.float16,
-                )
-        sd_model = sd_model.to(device)
-
-def load_lama_cleaner_model(device):
-    # initialize lama_cleaner
-    global lama_cleaner_model
-    logger.info(f"initialize lama_cleaner...")
-
-    lama_cleaner_model = ModelManager(
-            name='lama',
-            device=device,
-            )
-
-def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
-    try:
-        logger.info(f'_______lama_cleaner_process_______1____')
-        ori_image = image
-        if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
-            # rotate image
-            logger.info(f'_______lama_cleaner_process_______2____')
-            ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
-            logger.info(f'_______lama_cleaner_process_______3____')
-            image = ori_image
-
-        logger.info(f'_______lama_cleaner_process_______4____')
-        original_shape = ori_image.shape
-        logger.info(f'_______lama_cleaner_process_______5____')
-        interpolation = cv2.INTER_CUBIC
-
-        size_limit = cleaner_size_limit
-        if size_limit == -1:
-            logger.info(f'_______lama_cleaner_process_______6____')
-            size_limit = max(image.shape)
-        else:
-            logger.info(f'_______lama_cleaner_process_______7____')
-            size_limit = int(size_limit)
-
-        logger.info(f'_______lama_cleaner_process_______8____')
-        config = lama_Config(
-            ldm_steps=25,
-            ldm_sampler='plms',
-            zits_wireframe=True,
-            hd_strategy='Original',
-            hd_strategy_crop_margin=196,
-            hd_strategy_crop_trigger_size=1280,
-            hd_strategy_resize_limit=2048,
-            prompt='',
-            use_croper=False,
-            croper_x=0,
-            croper_y=0,
-            croper_height=512,
-            croper_width=512,
-            sd_mask_blur=5,
-            sd_strength=0.75,
-            sd_steps=50,
-            sd_guidance_scale=7.5,
-            sd_sampler='ddim',
-            sd_seed=42,
-            cv2_flag='INPAINT_NS',
-            cv2_radius=5,
-        )
-
-        logger.info(f'_______lama_cleaner_process_______9____')
-        if config.sd_seed == -1:
-            config.sd_seed = random.randint(1, 999999999)
-
-        # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
-        logger.info(f'_______lama_cleaner_process_______10____')
-        image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
-        # logger.info(f"Resized image shape_1_: {image.shape}")
-
-        # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
-        logger.info(f'_______lama_cleaner_process_______11____')
-        mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
-        # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
-
-        logger.info(f'_______lama_cleaner_process_______12____')
-        res_np_img = lama_cleaner_model(image, mask, config)
-        logger.info(f'_______lama_cleaner_process_______13____')
-        torch.cuda.empty_cache()
-
-        logger.info(f'_______lama_cleaner_process_______14____')
-        image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
-        logger.info(f'_______lama_cleaner_process_______15____')
-    except Exception as e:
-        logger.info(f'lama_cleaner_process[Error]:' + str(e))
-        image = None
-    return image
-
-class Ram_Predictor(RamPredictor):
-    def __init__(self, config, device='cpu'):
-        self.config = config
-        self.device = torch.device(device)
-        self._build_model()
-
-    def _build_model(self):
-        self.model = RamModel(**self.config.model).to(self.device)
-        if self.config.load_from is not None:
-            self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
-        self.model.train()
-
-def load_ram_model(device):
-    # load ram model
-    global ram_model
-    if os.environ.get('IS_MY_DEBUG') is not None:
-        return
-    model_path = "./checkpoints/ram_epoch12.pth"
-    ram_config = dict(
-        model=dict(
-            pretrained_model_name_or_path='bert-base-uncased',
-            load_pretrained_weights=False,
-            num_transformer_layer=2,
-            input_feature_size=256,
-            output_feature_size=768,
-            cls_feature_size=512,
-            num_relation_classes=56,
-            pred_type='attention',
-            loss_type='multi_label_ce',
-        ),
-        load_from=model_path,
-    )
-    ram_config = mmengine_Config(ram_config)
-    ram_model = Ram_Predictor(ram_config, device)
-
 # visualization
 def draw_selected_mask(mask, draw):
     color = (255, 0, 0, 153)
@@ -623,10 +420,6 @@ def get_time_cost(run_task_time, time_cost_str):

 def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
                       iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
-
-    text_prompt = getTextTrans(text_prompt, source='zh', target='en')
-    inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
-
     run_task_time = 0
     time_cost_str = ''
     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
@@ -792,115 +585,6 @@ def get_model_device(module):
     except Exception as e:
         return 'Error'

-def main_gradio(args):
-    block = gr.Blocks().queue()
-    with block:
-        with gr.Row():
-            with gr.Column():
-                task_types = ["detection"]
-                if sam_enable:
-                    task_types.append("segment")
-                if inpainting_enable:
-                    task_types.append("inpainting")
-                if lama_cleaner_enable:
-                    task_types.append("remove")
-                if ram_enable:
-                    task_types.append("relate anything")
-                if kosmos_enable:
-                    task_types.append("Kosmos-2")
-
-                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
-                                       height=512, brush_color='#00FFFF', mask_opacity=0.6)
-                task_type = gr.Radio(task_types, value="detection",
-                                     label='Task type', visible=True)
-                mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
-                                             value=mask_source_segment, label="Mask from",
-                                             visible=False)
-                text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
-                inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
-                num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
-
-                kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
-
-                run_button = gr.Button(label="Run", visible=True)
-                with gr.Accordion("Advanced options", open=False) as advanced_options:
-                    box_threshold = gr.Slider(
-                        label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
-                    )
-                    text_threshold = gr.Slider(
-                        label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
-                    )
-                    iou_threshold = gr.Slider(
-                        label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
-                    )
-                    inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
-                    with gr.Row():
-                        with gr.Column(scale=1):
-                            remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
-                        with gr.Column(scale=1):
-                            remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
-
-            with gr.Column():
-                image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
-                                           ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
-                time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
-
-                kosmos_output = gr.Image(type="pil", label="result images", visible=False)
-                kosmos_text_output = gr.HighlightedText(
-                    label="Generated Description",
-                    combine_adjacent=False,
-                    show_legend=True,
-                    visible=False,
-                ).style(color_map=color_map)
-                # record which text span (label) is selected
-                selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
-
-                # record the current `entities`
-                entity_output = gr.Textbox(visible=False)
-
-                # get the current selected span label
-                def get_text_span_label(evt: gr.SelectData):
-                    if evt.value[-1] is None:
-                        return -1
-                    return int(evt.value[-1])
-                # and set this information to `selected`
-                kosmos_text_output.select(get_text_span_label, None, selected)
-
-                # update output image when we change the span (enity) selection
-                def update_output_image(img_input, image_output, entities, idx):
-                    entities = ast.literal_eval(entities)
-                    updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
-                    return updated_image
-                selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])

-        run_button.click(fn=run_anything_task, inputs=[
-                input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
-                iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
-                outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
-
-        mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
-                                 outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
-        task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
-                         outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
-                                  image_gallery, kosmos_input, kosmos_output, kosmos_text_output
-                         ])
-
-        DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
-        if lama_cleaner_enable:
-            DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner). <br>'
-        if kosmos_enable:
-            DESCRIPTION += f'Kosmos-2 from [Kosmos-2](https://github.com/microsoft/unilm/tree/master/kosmos-2). <br>'
-        if ram_enable:
-            DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
-        DESCRIPTION += f'Thanks for their excellent work.'
-        DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \
-                        <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
-        gr.Markdown(DESCRIPTION)
-
-    print(f'device = {device}')
-    print(f'torch.cuda.is_available = {torch.cuda.is_available()}')
-    computer_info()
-    block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)

 import signal
 import json
@@ -908,122 +592,13 @@ from datetime import date, datetime, timedelta
 from gevent import pywsgi
 import base64

-def
-
-
-
-
-    return im_b64
-
-def base64_to_bytes(im_b64):
-    im_b64_encode = im_b64.encode("utf-8")
-    im_bytes = base64.b64decode(im_b64_encode)
-    return im_bytes
-
-def base64_to_PILImage(im_b64):
-    im_bytes = base64_to_bytes(im_b64)
-    pil_img = Image.open(io.BytesIO(im_bytes))
-    return pil_img
-
-class API_Starter:
-    def __init__(self):
-        from flask import Flask, request, jsonify, make_response
-        from flask_cors import CORS, cross_origin
-        import logging
-
-        app = Flask(__name__)
-        app.logger.setLevel(logging.ERROR)
-        CORS(app, supports_credentials=True, resources={r"/*": {"origins": "*"}})
-
-        @app.route('/imgCLeaner', methods=['GET', 'POST'])
-        @cross_origin()
-        def processAssist():
-            if request.method == 'GET':
-                ret_json = {'code': -1, 'reason':'no support to get'}
-            elif request.method == 'POST':
-                request_data = request.data.decode('utf-8')
-                data = json.loads(request_data)
-                result = self.handle_data(data)
-                if result is None:
-                    ret_json = {'code': -2, 'reason':'handle error'}
-                else:
-                    ret_json = {'code': 0, 'result':result}
-            return jsonify(ret_json)
-
-        self.app = app
-        now_time = datetime.now().strftime('%Y%m%d_%H%M%S')
-        logger.add(f'./logs/logger_[{args.port}]_{now_time}.log')
-        signal.signal(signal.SIGINT, self.signal_handler)
-
-    def handle_data(self, data):
-        im_b64 = data['img']
-        img = base64_to_PILImage(im_b64)
-        remove_texts = data['remove_texts']
-        remove_mask_extend = data['mask_extend']
-        results = run_anything_task(input_image = img,
-                                    text_prompt = f"{remove_texts}",
-                                    task_type = 'remove',
-                                    inpaint_prompt = '',
-                                    box_threshold = 0.3,
-                                    text_threshold = 0.25,
-                                    iou_threshold = 0.8,
-                                    inpaint_mode = "merge",
-                                    mask_source_radio = "type what to detect below",
-                                    remove_mode = "rectangle", # ["segment", "rectangle"]
-                                    remove_mask_extend = f"{remove_mask_extend}",
-                                    num_relation = 5,
-                                    kosmos_input = None,
-                                    cleaner_size_limit = -1,
-                                    )
-        output_images = results[0]
-        if output_images is None:
-            return None
-        ret_json_images = []
-        file_temp = int(time.time())
-        count = 0
-        output_images = output_images[-1:]
-        for image_pil in output_images:
-            try:
-                img_format = image_pil.format.lower()
-            except Exception as e:
-                img_format = 'png'
-            image_path = os.path.join(output_dir, f"api_images_{file_temp}_{count}.{img_format}")
-            count += 1
-            try:
-                image_pil.save(image_path)
-            except Exception as e:
-                Image.fromarray(image_pil).save(image_path)
-            im_b64 = imgFile_to_base64(image_path)
-            ret_json_images.append(im_b64)
-            os.remove(image_path)
-        data = {
-            'imgs': ret_json_images,
-        }
-        return data
-
-    def signal_handler(self, signal, frame):
-        print('\nSignal Catched! You have just type Ctrl+C!')
-        sys.exit(0)
-
-    def run(self):
-        from gevent import pywsgi
-        logger.info(f'\nargs={args}\n')
-        computer_info()
-        print(f"Start a api server: http://0.0.0.0:{args.port}/imgCLeaner")
-        server = pywsgi.WSGIServer(('0.0.0.0', args.port), self.app)
-        server.serve_forever()
-
-device = set_device()
-
-groundingdino_model = load_groundingdino_model('cuda:0')
-load_sam_model("cuda:0")
-
-load_sd_model("cuda:0")
-
-load_lama_cleaner_model("cuda:0")
-
-# load_ram_model("cuda:0")
+def get_groundingdino_model(device):
+    # initialize groundingdino model
+    logger.info(f"initialize groundingdino model...")
+    model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device)
+    return model

+groundingdino_model = get_groundingdino_model("cuda")

 def expand_white_pixels(input_pil, expand_by=1):
     # Convert the input image to grayscale
@@ -1063,6 +638,7 @@ s3 = s3_session.client(
     endpoint_url=S3_ENDPOINT_URL,
 )

+
 class EndpointHandler():
     def __init__(self, path=""):
         # get_nude(Image.open("girl.png"))
@@ -1105,7 +681,3 @@ class EndpointHandler():
         return {
             "filenames": filenames
         }
-
-print(EndpointHandler()({
-    "original_link": "https://www.shutterstock.com/image-photo/attractive-confident-young-woman-posing-600nw-2185228917.jpg"
-}))
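
One side effect of this commit is that the module-level smoke test at the bottom of handler.py is gone. If you still want a quick local check, the sketch below is an assumption rather than part of the commit: it presumes handler.py is importable from the working directory, that EndpointHandler()(payload) accepts a dict with an "original_link" image URL, and that it returns {"filenames": [...]} as in the code shown above; the URL is a placeholder.

# Hypothetical local smoke test (not part of this commit).
# Importing handler also runs the module-level model setup, so expect downloads on first run.
from handler import EndpointHandler

if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({
        # placeholder URL; point this at any reachable test image
        "original_link": "https://example.com/sample.jpg",
    })
    print(result["filenames"])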