mart9992 commited on
Commit
79fc7ef
1 Parent(s): 5553910

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +13 -441
handler.py CHANGED
@@ -2,7 +2,7 @@ import string
2
  import warnings
3
  warnings.filterwarnings('ignore')
4
  import subprocess, io, os, sys, time
5
- from dw_pose.main import dwpose
6
 
7
  # os.environ["XFORMERS_DISABLE_FLASH_ATTN"] = "1"
8
  # result = subprocess.run(['pip', 'install', 'xformers'], check=True)
@@ -77,19 +77,6 @@ kosmos_enable = False
77
  # qwen_enable = True
78
  # from qwen_utils import *
79
 
80
- if os.environ.get('IS_MY_DEBUG') is not None:
81
- sam_enable = False
82
- ram_enable = False
83
- inpainting_enable = False
84
- kosmos_enable = False
85
-
86
- if lama_cleaner_enable:
87
- try:
88
- from lama_cleaner.model_manager import ModelManager
89
- from lama_cleaner.schema import Config as lama_Config
90
- except Exception as e:
91
- lama_cleaner_enable = False
92
-
93
  # segment anything
94
  from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
95
 
@@ -98,34 +85,8 @@ import PIL
98
  import requests
99
  import torch
100
  from io import BytesIO
101
- from diffusers import StableDiffusionInpaintPipeline
102
  from huggingface_hub import hf_hub_download
103
 
104
- from util_computer import computer_info
105
-
106
- # relate anything
107
- from ram_utils import iou, sort_and_deduplicate, relation_classes, MLP, show_anns, ram_show_mask
108
- from ram_train_eval import RamModel, RamPredictor
109
- from mmengine.config import Config as mmengine_Config
110
-
111
- if lama_cleaner_enable:
112
- from lama_cleaner.helper import (
113
- load_img,
114
- numpy_to_bytes,
115
- resize_max_size,
116
- )
117
-
118
- # from transformers import AutoProcessor, AutoModelForVision2Seq
119
- import ast
120
-
121
- if kosmos_enable and install_stuff:
122
- os.system("pip install transformers@git+https://github.com/huggingface/transformers.git@main")
123
- # os.system("pip install transformers==4.32.0")
124
-
125
- from kosmos_utils import *
126
-
127
- from util_tencent import getTextTrans
128
-
129
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
130
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
131
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
@@ -133,10 +94,7 @@ sam_checkpoint = './sam_hq_vit_h.pth'
133
  output_dir = "outputs"
134
 
135
  device = 'cpu'
136
- os.makedirs(output_dir, exist_ok=True)
137
- groundingdino_model = None
138
  sam_device = "cuda"
139
- sam_model = None
140
 
141
 
142
  def get_sam_vit_h_4b8939():
@@ -150,20 +108,20 @@ def get_sam_vit_h_4b8939():
150
  f.write(response.content)
151
  print('Downloaded sam_vit_h_4b8939.pth')
152
 
153
- get_sam_vit_h_4b8939()
154
  logger.info(f"initialize SAM model...")
155
  sam_device = "cuda"
156
- sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
157
- sam_predictor = SamPredictor(sam_model)
158
- sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
159
 
160
- sam_mask_generator = None
161
  sd_model = None
162
  lama_cleaner_model= None
163
  ram_model = None
164
  kosmos_model = None
165
  kosmos_processor = None
166
 
 
 
 
 
 
167
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
168
  args = SLConfig.fromfile(model_config_path)
169
  model = build_model(args)
@@ -324,167 +282,6 @@ def mix_masks(imgs):
324
  re_img = 1 - re_img
325
  return Image.fromarray(np.uint8(255*re_img))
326
 
327
- def set_device():
328
- if os.environ.get('IS_MY_DEBUG') is None:
329
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
330
- else:
331
- device = 'cpu'
332
- print(f'device={device}')
333
- return device
334
-
335
- def load_groundingdino_model(device):
336
- # initialize groundingdino model
337
- logger.info(f"initialize groundingdino model...")
338
- groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device) #'cpu')
339
- return groundingdino_model
340
-
341
-
342
-
343
- def load_sam_model(device):
344
- # initialize SAM
345
- global sam_model, sam_predictor, sam_mask_generator, sam_device
346
- get_sam_vit_h_4b8939()
347
- logger.info(f"initialize SAM model...")
348
- sam_device = device
349
- sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
350
- sam_predictor = SamPredictor(sam_model)
351
- sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
352
-
353
- def load_sd_model(device):
354
- # initialize stable-diffusion-inpainting
355
- global sd_model
356
- logger.info(f"initialize stable-diffusion-inpainting...")
357
- sd_model = None
358
- if os.environ.get('IS_MY_DEBUG') is None:
359
- sd_model = StableDiffusionInpaintPipeline.from_pretrained(
360
- "runwayml/stable-diffusion-inpainting",
361
- revision="fp16",
362
- # "stabilityai/stable-diffusion-2-inpainting",
363
- torch_dtype=torch.float16,
364
- )
365
- sd_model = sd_model.to(device)
366
-
367
- def load_lama_cleaner_model(device):
368
- # initialize lama_cleaner
369
- global lama_cleaner_model
370
- logger.info(f"initialize lama_cleaner...")
371
-
372
- lama_cleaner_model = ModelManager(
373
- name='lama',
374
- device=device,
375
- )
376
-
377
- def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
378
- try:
379
- logger.info(f'_______lama_cleaner_process_______1____')
380
- ori_image = image
381
- if mask.shape[0] == image.shape[1] and mask.shape[1] == image.shape[0] and mask.shape[0] != mask.shape[1]:
382
- # rotate image
383
- logger.info(f'_______lama_cleaner_process_______2____')
384
- ori_image = np.transpose(image[::-1, ...][:, ::-1], axes=(1, 0, 2))[::-1, ...]
385
- logger.info(f'_______lama_cleaner_process_______3____')
386
- image = ori_image
387
-
388
- logger.info(f'_______lama_cleaner_process_______4____')
389
- original_shape = ori_image.shape
390
- logger.info(f'_______lama_cleaner_process_______5____')
391
- interpolation = cv2.INTER_CUBIC
392
-
393
- size_limit = cleaner_size_limit
394
- if size_limit == -1:
395
- logger.info(f'_______lama_cleaner_process_______6____')
396
- size_limit = max(image.shape)
397
- else:
398
- logger.info(f'_______lama_cleaner_process_______7____')
399
- size_limit = int(size_limit)
400
-
401
- logger.info(f'_______lama_cleaner_process_______8____')
402
- config = lama_Config(
403
- ldm_steps=25,
404
- ldm_sampler='plms',
405
- zits_wireframe=True,
406
- hd_strategy='Original',
407
- hd_strategy_crop_margin=196,
408
- hd_strategy_crop_trigger_size=1280,
409
- hd_strategy_resize_limit=2048,
410
- prompt='',
411
- use_croper=False,
412
- croper_x=0,
413
- croper_y=0,
414
- croper_height=512,
415
- croper_width=512,
416
- sd_mask_blur=5,
417
- sd_strength=0.75,
418
- sd_steps=50,
419
- sd_guidance_scale=7.5,
420
- sd_sampler='ddim',
421
- sd_seed=42,
422
- cv2_flag='INPAINT_NS',
423
- cv2_radius=5,
424
- )
425
-
426
- logger.info(f'_______lama_cleaner_process_______9____')
427
- if config.sd_seed == -1:
428
- config.sd_seed = random.randint(1, 999999999)
429
-
430
- # logger.info(f"Origin image shape_0_: {original_shape} / {size_limit}")
431
- logger.info(f'_______lama_cleaner_process_______10____')
432
- image = resize_max_size(image, size_limit=size_limit, interpolation=interpolation)
433
- # logger.info(f"Resized image shape_1_: {image.shape}")
434
-
435
- # logger.info(f"mask image shape_0_: {mask.shape} / {type(mask)}")
436
- logger.info(f'_______lama_cleaner_process_______11____')
437
- mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
438
- # logger.info(f"mask image shape_1_: {mask.shape} / {type(mask)}")
439
-
440
- logger.info(f'_______lama_cleaner_process_______12____')
441
- res_np_img = lama_cleaner_model(image, mask, config)
442
- logger.info(f'_______lama_cleaner_process_______13____')
443
- torch.cuda.empty_cache()
444
-
445
- logger.info(f'_______lama_cleaner_process_______14____')
446
- image = Image.open(io.BytesIO(numpy_to_bytes(res_np_img, 'png')))
447
- logger.info(f'_______lama_cleaner_process_______15____')
448
- except Exception as e:
449
- logger.info(f'lama_cleaner_process[Error]:' + str(e))
450
- image = None
451
- return image
452
-
453
- class Ram_Predictor(RamPredictor):
454
- def __init__(self, config, device='cpu'):
455
- self.config = config
456
- self.device = torch.device(device)
457
- self._build_model()
458
-
459
- def _build_model(self):
460
- self.model = RamModel(**self.config.model).to(self.device)
461
- if self.config.load_from is not None:
462
- self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
463
- self.model.train()
464
-
465
- def load_ram_model(device):
466
- # load ram model
467
- global ram_model
468
- if os.environ.get('IS_MY_DEBUG') is not None:
469
- return
470
- model_path = "./checkpoints/ram_epoch12.pth"
471
- ram_config = dict(
472
- model=dict(
473
- pretrained_model_name_or_path='bert-base-uncased',
474
- load_pretrained_weights=False,
475
- num_transformer_layer=2,
476
- input_feature_size=256,
477
- output_feature_size=768,
478
- cls_feature_size=512,
479
- num_relation_classes=56,
480
- pred_type='attention',
481
- loss_type='multi_label_ce',
482
- ),
483
- load_from=model_path,
484
- )
485
- ram_config = mmengine_Config(ram_config)
486
- ram_model = Ram_Predictor(ram_config, device)
487
-
488
  # visualization
489
  def draw_selected_mask(mask, draw):
490
  color = (255, 0, 0, 153)
@@ -623,10 +420,6 @@ def get_time_cost(run_task_time, time_cost_str):
623
 
624
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
625
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
626
-
627
- text_prompt = getTextTrans(text_prompt, source='zh', target='en')
628
- inpaint_prompt = getTextTrans(inpaint_prompt, source='zh', target='en')
629
-
630
  run_task_time = 0
631
  time_cost_str = ''
632
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
@@ -792,115 +585,6 @@ def get_model_device(module):
792
  except Exception as e:
793
  return 'Error'
794
 
795
- def main_gradio(args):
796
- block = gr.Blocks().queue()
797
- with block:
798
- with gr.Row():
799
- with gr.Column():
800
- task_types = ["detection"]
801
- if sam_enable:
802
- task_types.append("segment")
803
- if inpainting_enable:
804
- task_types.append("inpainting")
805
- if lama_cleaner_enable:
806
- task_types.append("remove")
807
- if ram_enable:
808
- task_types.append("relate anything")
809
- if kosmos_enable:
810
- task_types.append("Kosmos-2")
811
-
812
- input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload",
813
- height=512, brush_color='#00FFFF', mask_opacity=0.6)
814
- task_type = gr.Radio(task_types, value="detection",
815
- label='Task type', visible=True)
816
- mask_source_radio = gr.Radio([mask_source_draw, mask_source_segment],
817
- value=mask_source_segment, label="Mask from",
818
- visible=False)
819
- text_prompt = gr.Textbox(label="Detection Prompt[To detect multiple objects, seperating each with '.', like this: cat . dog . chair ]", placeholder="Cannot be empty")
820
- inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
821
- num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
822
-
823
- kosmos_input = gr.Radio(["Brief", "Detailed"], label="Kosmos Description Type", value="Brief", visible=False)
824
-
825
- run_button = gr.Button(label="Run", visible=True)
826
- with gr.Accordion("Advanced options", open=False) as advanced_options:
827
- box_threshold = gr.Slider(
828
- label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
829
- )
830
- text_threshold = gr.Slider(
831
- label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
832
- )
833
- iou_threshold = gr.Slider(
834
- label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
835
- )
836
- inpaint_mode = gr.Radio(["merge", "first"], value="merge", label="inpaint_mode")
837
- with gr.Row():
838
- with gr.Column(scale=1):
839
- remove_mode = gr.Radio(["segment", "rectangle"], value="segment", label='remove mode')
840
- with gr.Column(scale=1):
841
- remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
842
-
843
- with gr.Column():
844
- image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", height=512, visible=True
845
- ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
846
- time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
847
-
848
- kosmos_output = gr.Image(type="pil", label="result images", visible=False)
849
- kosmos_text_output = gr.HighlightedText(
850
- label="Generated Description",
851
- combine_adjacent=False,
852
- show_legend=True,
853
- visible=False,
854
- ).style(color_map=color_map)
855
- # record which text span (label) is selected
856
- selected = gr.Number(-1, show_label=False, placeholder="Selected", visible=False)
857
-
858
- # record the current `entities`
859
- entity_output = gr.Textbox(visible=False)
860
-
861
- # get the current selected span label
862
- def get_text_span_label(evt: gr.SelectData):
863
- if evt.value[-1] is None:
864
- return -1
865
- return int(evt.value[-1])
866
- # and set this information to `selected`
867
- kosmos_text_output.select(get_text_span_label, None, selected)
868
-
869
- # update output image when we change the span (enity) selection
870
- def update_output_image(img_input, image_output, entities, idx):
871
- entities = ast.literal_eval(entities)
872
- updated_image = draw_entity_boxes_on_image(img_input, entities, entity_index=idx)
873
- return updated_image
874
- selected.change(update_output_image, [kosmos_output, kosmos_output, entity_output, selected], [kosmos_output])
875
-
876
- run_button.click(fn=run_anything_task, inputs=[
877
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
878
- iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
879
- outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
880
-
881
- mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
882
- outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
883
- task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
884
- outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation,
885
- image_gallery, kosmos_input, kosmos_output, kosmos_text_output
886
- ])
887
-
888
- DESCRIPTION = f'### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
889
- if lama_cleaner_enable:
890
- DESCRIPTION += f'Remove(cleaner) from [lama-cleaner](https://github.com/Sanster/lama-cleaner). <br>'
891
- if kosmos_enable:
892
- DESCRIPTION += f'Kosmos-2 from [Kosmos-2](https://github.com/microsoft/unilm/tree/master/kosmos-2). <br>'
893
- if ram_enable:
894
- DESCRIPTION += f'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
895
- DESCRIPTION += f'Thanks for their excellent work.'
896
- DESCRIPTION += f'<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. \
897
- <a href="https://huggingface.co/spaces/yizhangliu/Grounded-Segment-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
898
- gr.Markdown(DESCRIPTION)
899
-
900
- print(f'device = {device}')
901
- print(f'torch.cuda.is_available = {torch.cuda.is_available()}')
902
- computer_info()
903
- block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
904
 
905
  import signal
906
  import json
@@ -908,122 +592,13 @@ from datetime import date, datetime, timedelta
908
  from gevent import pywsgi
909
  import base64
910
 
911
- def imgFile_to_base64(image_file):
912
- with open(image_file, "rb") as f:
913
- im_bytes = f.read()
914
- im_b64_encode = base64.b64encode(im_bytes)
915
- im_b64 = im_b64_encode.decode("utf8")
916
- return im_b64
917
-
918
- def base64_to_bytes(im_b64):
919
- im_b64_encode = im_b64.encode("utf-8")
920
- im_bytes = base64.b64decode(im_b64_encode)
921
- return im_bytes
922
-
923
- def base64_to_PILImage(im_b64):
924
- im_bytes = base64_to_bytes(im_b64)
925
- pil_img = Image.open(io.BytesIO(im_bytes))
926
- return pil_img
927
-
928
- class API_Starter:
929
- def __init__(self):
930
- from flask import Flask, request, jsonify, make_response
931
- from flask_cors import CORS, cross_origin
932
- import logging
933
-
934
- app = Flask(__name__)
935
- app.logger.setLevel(logging.ERROR)
936
- CORS(app, supports_credentials=True, resources={r"/*": {"origins": "*"}})
937
-
938
- @app.route('/imgCLeaner', methods=['GET', 'POST'])
939
- @cross_origin()
940
- def processAssist():
941
- if request.method == 'GET':
942
- ret_json = {'code': -1, 'reason':'no support to get'}
943
- elif request.method == 'POST':
944
- request_data = request.data.decode('utf-8')
945
- data = json.loads(request_data)
946
- result = self.handle_data(data)
947
- if result is None:
948
- ret_json = {'code': -2, 'reason':'handle error'}
949
- else:
950
- ret_json = {'code': 0, 'result':result}
951
- return jsonify(ret_json)
952
-
953
- self.app = app
954
- now_time = datetime.now().strftime('%Y%m%d_%H%M%S')
955
- logger.add(f'./logs/logger_[{args.port}]_{now_time}.log')
956
- signal.signal(signal.SIGINT, self.signal_handler)
957
-
958
- def handle_data(self, data):
959
- im_b64 = data['img']
960
- img = base64_to_PILImage(im_b64)
961
- remove_texts = data['remove_texts']
962
- remove_mask_extend = data['mask_extend']
963
- results = run_anything_task(input_image = img,
964
- text_prompt = f"{remove_texts}",
965
- task_type = 'remove',
966
- inpaint_prompt = '',
967
- box_threshold = 0.3,
968
- text_threshold = 0.25,
969
- iou_threshold = 0.8,
970
- inpaint_mode = "merge",
971
- mask_source_radio = "type what to detect below",
972
- remove_mode = "rectangle", # ["segment", "rectangle"]
973
- remove_mask_extend = f"{remove_mask_extend}",
974
- num_relation = 5,
975
- kosmos_input = None,
976
- cleaner_size_limit = -1,
977
- )
978
- output_images = results[0]
979
- if output_images is None:
980
- return None
981
- ret_json_images = []
982
- file_temp = int(time.time())
983
- count = 0
984
- output_images = output_images[-1:]
985
- for image_pil in output_images:
986
- try:
987
- img_format = image_pil.format.lower()
988
- except Exception as e:
989
- img_format = 'png'
990
- image_path = os.path.join(output_dir, f"api_images_{file_temp}_{count}.{img_format}")
991
- count += 1
992
- try:
993
- image_pil.save(image_path)
994
- except Exception as e:
995
- Image.fromarray(image_pil).save(image_path)
996
- im_b64 = imgFile_to_base64(image_path)
997
- ret_json_images.append(im_b64)
998
- os.remove(image_path)
999
- data = {
1000
- 'imgs': ret_json_images,
1001
- }
1002
- return data
1003
-
1004
- def signal_handler(self, signal, frame):
1005
- print('\nSignal Catched! You have just type Ctrl+C!')
1006
- sys.exit(0)
1007
-
1008
- def run(self):
1009
- from gevent import pywsgi
1010
- logger.info(f'\nargs={args}\n')
1011
- computer_info()
1012
- print(f"Start a api server: http://0.0.0.0:{args.port}/imgCLeaner")
1013
- server = pywsgi.WSGIServer(('0.0.0.0', args.port), self.app)
1014
- server.serve_forever()
1015
-
1016
- device = set_device()
1017
-
1018
- groundingdino_model = load_groundingdino_model('cuda:0')
1019
- load_sam_model("cuda:0")
1020
-
1021
- load_sd_model("cuda:0")
1022
-
1023
- load_lama_cleaner_model("cuda:0")
1024
-
1025
- # load_ram_model("cuda:0")
1026
 
 
1027
 
1028
  def expand_white_pixels(input_pil, expand_by=1):
1029
  # Convert the input image to grayscale
@@ -1063,6 +638,7 @@ s3 = s3_session.client(
1063
  endpoint_url=S3_ENDPOINT_URL,
1064
  )
1065
 
 
1066
  class EndpointHandler():
1067
  def __init__(self, path=""):
1068
  # get_nude(Image.open("girl.png"))
@@ -1105,7 +681,3 @@ class EndpointHandler():
1105
  return {
1106
  "filenames": filenames
1107
  }
1108
-
1109
- print(EndpointHandler()({
1110
- "original_link": "https://www.shutterstock.com/image-photo/attractive-confident-young-woman-posing-600nw-2185228917.jpg"
1111
- }))
 
2
  import warnings
3
  warnings.filterwarnings('ignore')
4
  import subprocess, io, os, sys, time
5
+ import random
6
 
7
  # os.environ["XFORMERS_DISABLE_FLASH_ATTN"] = "1"
8
  # result = subprocess.run(['pip', 'install', 'xformers'], check=True)
 
77
  # qwen_enable = True
78
  # from qwen_utils import *
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # segment anything
81
  from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
82
 
 
85
  import requests
86
  import torch
87
  from io import BytesIO
 
88
  from huggingface_hub import hf_hub_download
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
91
  ckpt_repo_id = "ShilongLiu/GroundingDINO"
92
  ckpt_filenmae = "groundingdino_swint_ogc.pth"
 
94
  output_dir = "outputs"
95
 
96
  device = 'cpu'
 
 
97
  sam_device = "cuda"
 
98
 
99
 
100
  def get_sam_vit_h_4b8939():
 
108
  f.write(response.content)
109
  print('Downloaded sam_vit_h_4b8939.pth')
110
 
 
111
  logger.info(f"initialize SAM model...")
112
  sam_device = "cuda"
 
 
 
113
 
 
114
  sd_model = None
115
  lama_cleaner_model= None
116
  ram_model = None
117
  kosmos_model = None
118
  kosmos_processor = None
119
 
120
+ get_sam_vit_h_4b8939()
121
+ sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
122
+ sam_predictor = SamPredictor(sam_model)
123
+ sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
124
+
125
  def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
126
  args = SLConfig.fromfile(model_config_path)
127
  model = build_model(args)
 
282
  re_img = 1 - re_img
283
  return Image.fromarray(np.uint8(255*re_img))
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  # visualization
286
  def draw_selected_mask(mask, draw):
287
  color = (255, 0, 0, 153)
 
420
 
421
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
422
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
 
 
 
 
423
  run_task_time = 0
424
  time_cost_str = ''
425
  run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
 
585
  except Exception as e:
586
  return 'Error'
587
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
588
 
589
  import signal
590
  import json
 
592
  from gevent import pywsgi
593
  import base64
594
 
595
+ def get_groundingdino_model(device):
596
+ # initialize groundingdino model
597
+ logger.info(f"initialize groundingdino model...")
598
+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device=device)
599
+ return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600
 
601
+ groundingdino_model = get_groundingdino_model("cuda")
602
 
603
  def expand_white_pixels(input_pil, expand_by=1):
604
  # Convert the input image to grayscale
 
638
  endpoint_url=S3_ENDPOINT_URL,
639
  )
640
 
641
+
642
  class EndpointHandler():
643
  def __init__(self, path=""):
644
  # get_nude(Image.open("girl.png"))
 
681
  return {
682
  "filenames": filenames
683
  }