LanHarmony commited on
Commit
8c482b3
1 Parent(s): b863a7a

support GroundingDINO and segment-anything

Browse files
Files changed (3) hide show
  1. app.py +5 -0
  2. requirements.txt +28 -28
  3. visual_foundation_models.py +395 -157
app.py CHANGED
@@ -87,6 +87,11 @@ VISUAL_CHATGPT_SUFFIX_CN = """你对文件名的正确性非常严格,而且
87
  Thought: Do I need to use a tool? {agent_scratchpad}
88
  """
89
 
 
 
 
 
 
90
  from visual_foundation_models import *
91
  from langchain.agents.initialize import initialize_agent
92
  from langchain.agents.tools import Tool
87
  Thought: Do I need to use a tool? {agent_scratchpad}
88
  """
89
 
90
+ import os
91
+
92
+ os.system('pip install git+https://github.com/IDEA-Research/GroundingDINO.git')
93
+ os.system('pip install git+https://github.com/facebookresearch/segment-anything.git')
94
+
95
  from visual_foundation_models import *
96
  from langchain.agents.initialize import initialize_agent
97
  from langchain.agents.tools import Tool
requirements.txt CHANGED
@@ -1,32 +1,32 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu113
2
- torch==1.12.1
3
- torchvision==0.13.1
4
- numpy==1.23.1
5
- transformers==4.26.1
6
- albumentations==1.3.0
7
- opencv-contrib-python==4.3.0.36
8
- imageio==2.9.0
9
- imageio-ffmpeg==0.4.2
10
- pytorch-lightning==1.5.0
11
- omegaconf==2.1.1
12
- test-tube>=0.7.5
13
- streamlit==1.12.1
14
- einops==0.3.0
15
- webdataset==0.2.5
16
- kornia==0.6
17
- open_clip_torch==2.0.2
18
- invisible-watermark>=0.1.5
19
- streamlit-drawable-canvas==0.8.0
20
- torchmetrics==0.6.0
21
- timm==0.6.12
22
- addict==2.4.0
23
- yapf==0.32.0
24
- prettytable==3.6.0
25
- safetensors==0.2.7
26
- basicsr==1.4.2
27
  langchain==0.0.101
28
- diffusers==0.14.0
 
 
 
 
 
 
 
 
 
29
  gradio
 
 
 
 
 
 
 
30
  openai
31
- accelerate
32
- controlnet-aux==0.0.1
 
 
 
 
 
 
 
 
 
1
  --extra-index-url https://download.pytorch.org/whl/cu113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  langchain==0.0.101
3
+ torch==1.13.1
4
+ torchvision==0.14.1
5
+ wget==3.2
6
+ accelerate
7
+ addict
8
+ albumentations
9
+ basicsr
10
+ controlnet-aux
11
+ diffusers
12
+ einops
13
  gradio
14
+ imageio
15
+ imageio-ffmpeg
16
+ invisible-watermark
17
+ kornia
18
+ numpy
19
+ omegaconf
20
+ open_clip_torch
21
  openai
22
+ opencv-python
23
+ prettytable
24
+ safetensors
25
+ streamlit
26
+ test-tube
27
+ timm
28
+ torchmetrics
29
+ transformers
30
+ webdataset
31
+ yapf
32
+
visual_foundation_models.py CHANGED
@@ -19,6 +19,18 @@ import math
19
 
20
  from langchain.llms.openai import OpenAI
21
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  def prompts(name, description):
23
  def decorator(func):
24
  func.name = name
@@ -101,76 +113,6 @@ def get_new_image_name(org_img_name, func_name="update"):
101
  return os.path.join(head, new_file_name)
102
 
103
 
104
- class MaskFormer:
105
- def __init__(self, device):
106
- print(f"Initializing MaskFormer to {device}")
107
- self.device = device
108
- self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
109
- self.model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
110
-
111
- def inference(self, image_path, text):
112
- threshold = 0.5
113
- min_area = 0.02
114
- padding = 20
115
- original_image = Image.open(image_path)
116
- image = original_image.resize((512, 512))
117
- inputs = self.processor(text=text, images=image, padding="max_length", return_tensors="pt").to(self.device)
118
- with torch.no_grad():
119
- outputs = self.model(**inputs)
120
- mask = torch.sigmoid(outputs[0]).squeeze().cpu().numpy() > threshold
121
- area_ratio = len(np.argwhere(mask)) / (mask.shape[0] * mask.shape[1])
122
- if area_ratio < min_area:
123
- return None
124
- true_indices = np.argwhere(mask)
125
- mask_array = np.zeros_like(mask, dtype=bool)
126
- for idx in true_indices:
127
- padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
128
- mask_array[padded_slice] = True
129
- visual_mask = (mask_array * 255).astype(np.uint8)
130
- image_mask = Image.fromarray(visual_mask)
131
- return image_mask.resize(original_image.size)
132
-
133
-
134
- class ImageEditing:
135
- def __init__(self, device):
136
- print(f"Initializing ImageEditing to {device}")
137
- self.device = device
138
- self.mask_former = MaskFormer(device=self.device)
139
- self.revision = 'fp16' if 'cuda' in device else None
140
- self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
141
- self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
142
- "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
143
-
144
- @prompts(name="Remove Something From The Photo",
145
- description="useful when you want to remove and object or something from the photo "
146
- "from its description or location. "
147
- "The input to this tool should be a comma separated string of two, "
148
- "representing the image_path and the object need to be removed. ")
149
- def inference_remove(self, inputs):
150
- image_path, to_be_removed_txt = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
151
- return self.inference_replace(f"{image_path},{to_be_removed_txt},background")
152
-
153
- @prompts(name="Replace Something From The Photo",
154
- description="useful when you want to replace an object from the object description or "
155
- "location with another object from its description. "
156
- "The input to this tool should be a comma separated string of three, "
157
- "representing the image_path, the object to be replaced, the object to be replaced with ")
158
- def inference_replace(self, inputs):
159
- image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
160
- original_image = Image.open(image_path)
161
- original_size = original_image.size
162
- mask_image = self.mask_former.inference(image_path, to_be_replaced_txt)
163
- updated_image = self.inpaint(prompt=replace_with_txt, image=original_image.resize((512, 512)),
164
- mask_image=mask_image.resize((512, 512))).images[0]
165
- updated_image_path = get_new_image_name(image_path, func_name="replace-something")
166
- updated_image = updated_image.resize(original_size)
167
- updated_image.save(updated_image_path)
168
- print(
169
- f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
170
- f"Output Image: {updated_image_path}")
171
- return updated_image_path
172
-
173
-
174
  class InstructPix2Pix:
175
  def __init__(self, device):
176
  print(f"Initializing InstructPix2Pix to {device}")
@@ -283,7 +225,7 @@ class CannyText2Image:
283
  self.seed = -1
284
  self.a_prompt = 'best quality, extremely detailed'
285
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
286
- 'fewer digits, cropped, worst quality, low quality'
287
 
288
  @prompts(name="Generate Image Condition On Canny Image",
289
  description="useful when you want to generate a new real image from both the user description and a canny image."
@@ -340,7 +282,7 @@ class LineText2Image:
340
  self.seed = -1
341
  self.a_prompt = 'best quality, extremely detailed'
342
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
343
- 'fewer digits, cropped, worst quality, low quality'
344
 
345
  @prompts(name="Generate Image Condition On Line Image",
346
  description="useful when you want to generate a new real image from both the user description "
@@ -398,7 +340,7 @@ class HedText2Image:
398
  self.seed = -1
399
  self.a_prompt = 'best quality, extremely detailed'
400
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
401
- 'fewer digits, cropped, worst quality, low quality'
402
 
403
  @prompts(name="Generate Image Condition On Soft Hed Boundary Image",
404
  description="useful when you want to generate a new real image from both the user description "
@@ -456,7 +398,7 @@ class ScribbleText2Image:
456
  self.seed = -1
457
  self.a_prompt = 'best quality, extremely detailed'
458
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
459
- 'fewer digits, cropped, worst quality, low quality'
460
 
461
  @prompts(name="Generate Image Condition On Sketch Image",
462
  description="useful when you want to generate a new real image from both the user description and "
@@ -512,7 +454,7 @@ class PoseText2Image:
512
  self.unconditional_guidance_scale = 9.0
513
  self.a_prompt = 'best quality, extremely detailed'
514
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
515
- ' fewer digits, cropped, worst quality, low quality'
516
 
517
  @prompts(name="Generate Image Condition On Pose Image",
518
  description="useful when you want to generate a new real image from both the user description "
@@ -536,73 +478,6 @@ class PoseText2Image:
536
  return updated_image_path
537
 
538
 
539
- class Image2Seg:
540
- def __init__(self, device):
541
- print("Initializing Image2Seg")
542
- self.image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
543
- self.image_segmentor = UperNetForSemanticSegmentation.from_pretrained("openmmlab/upernet-convnext-small")
544
- self.ade_palette = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
545
- [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
546
- [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
547
- [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
548
- [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
549
- [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
550
- [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
551
- [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
552
- [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
553
- [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
554
- [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
555
- [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
556
- [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
557
- [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
558
- [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
559
- [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
560
- [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
561
- [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
562
- [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
563
- [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
564
- [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
565
- [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
566
- [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
567
- [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
568
- [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
569
- [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
570
- [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
571
- [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
572
- [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
573
- [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
574
- [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
575
- [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
576
- [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
577
- [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
578
- [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
579
- [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
580
- [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
581
- [102, 255, 0], [92, 0, 255]]
582
-
583
- @prompts(name="Segmentation On Image",
584
- description="useful when you want to detect segmentations of the image. "
585
- "like: segment this image, or generate segmentations on this image, "
586
- "or perform segmentation on this image. "
587
- "The input to this tool should be a string, representing the image_path")
588
- def inference(self, inputs):
589
- image = Image.open(inputs)
590
- pixel_values = self.image_processor(image, return_tensors="pt").pixel_values
591
- with torch.no_grad():
592
- outputs = self.image_segmentor(pixel_values)
593
- seg = self.image_processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
594
- color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
595
- palette = np.array(self.ade_palette)
596
- for label, color in enumerate(palette):
597
- color_seg[seg == label, :] = color
598
- color_seg = color_seg.astype(np.uint8)
599
- segmentation = Image.fromarray(color_seg)
600
- updated_image_path = get_new_image_name(inputs, func_name="segmentation")
601
- segmentation.save(updated_image_path)
602
- print(f"\nProcessed Image2Seg, Input Image: {inputs}, Output Pose: {updated_image_path}")
603
- return updated_image_path
604
-
605
-
606
  class SegText2Image:
607
  def __init__(self, device):
608
  print(f"Initializing SegText2Image to {device}")
@@ -617,7 +492,7 @@ class SegText2Image:
617
  self.seed = -1
618
  self.a_prompt = 'best quality, extremely detailed'
619
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
620
- ' fewer digits, cropped, worst quality, low quality'
621
 
622
  @prompts(name="Generate Image Condition On Segmentations",
623
  description="useful when you want to generate a new real image from both the user description and segmentations. "
@@ -676,7 +551,7 @@ class DepthText2Image:
676
  self.seed = -1
677
  self.a_prompt = 'best quality, extremely detailed'
678
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
679
- ' fewer digits, cropped, worst quality, low quality'
680
 
681
  @prompts(name="Generate Image Condition On Depth",
682
  description="useful when you want to generate a new real image from both the user description and depth image. "
@@ -747,7 +622,7 @@ class NormalText2Image:
747
  self.seed = -1
748
  self.a_prompt = 'best quality, extremely detailed'
749
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
750
- ' fewer digits, cropped, worst quality, low quality'
751
 
752
  @prompts(name="Generate Image Condition On Normal Map",
753
  description="useful when you want to generate a new real image from both the user description and normal map. "
@@ -793,10 +668,284 @@ class VisualQuestionAnswering:
793
  f"Output Answer: {answer}")
794
  return answer
795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
  class InfinityOutPainting:
797
- template_model = True # Add this line to show this is a template model.
 
798
  def __init__(self, ImageCaptioning, ImageEditing, VisualQuestionAnswering):
799
- # self.llm = OpenAI(temperature=0)
800
  self.ImageCaption = ImageCaptioning
801
  self.ImageEditing = ImageEditing
802
  self.ImageVQA = VisualQuestionAnswering
@@ -814,16 +963,16 @@ class InfinityOutPainting:
814
 
815
  def get_BLIP_caption(self, image):
816
  inputs = self.ImageCaption.processor(image, return_tensors="pt").to(self.ImageCaption.device,
817
- self.ImageCaption.torch_dtype)
818
  out = self.ImageCaption.model.generate(**inputs)
819
  BLIP_caption = self.ImageCaption.processor.decode(out[0], skip_special_tokens=True)
820
  return BLIP_caption
821
 
822
- # def check_prompt(self, prompt):
823
- # check = f"Here is a paragraph with adjectives. " \
824
- # f"{prompt} " \
825
- # f"Please change all plural forms in the adjectives to singular forms. "
826
- # return self.llm(check)
827
 
828
  def get_imagine_caption(self, image, imagine):
829
  BLIP_caption = self.get_BLIP_caption(image)
@@ -836,14 +985,13 @@ class InfinityOutPainting:
836
  f"You should make the painting as vivid and realistic as possible" \
837
  f"You can not use words like painting or picture" \
838
  f"and you should use no more than 50 words to describe it"
839
- # caption = self.llm(imagine_prompt) if imagine else BLIP_caption
840
- caption = BLIP_caption
841
- # caption = self.check_prompt(caption)
842
  print(f'BLIP observation: {BLIP_caption}, ChatGPT imagine to {caption}') if imagine else print(
843
  f'Prompt: {caption}')
844
  return caption
845
 
846
- def resize_image(self, image, max_size=100000, multiple=8):
847
  aspect_ratio = image.size[0] / image.size[1]
848
  new_width = int(math.sqrt(max_size * aspect_ratio))
849
  new_height = int(new_width / aspect_ratio)
@@ -889,4 +1037,94 @@ class InfinityOutPainting:
889
  out_painted_image.save(updated_image_path)
890
  print(f"\nProcessed InfinityOutPainting, Input Image: {image_path}, Input Resolution: {resolution}, "
891
  f"Output Image: {updated_image_path}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
892
  return updated_image_path
19
 
20
  from langchain.llms.openai import OpenAI
21
 
22
+ # Grounding DINO
23
+ import groundingdino.datasets.transforms as T
24
+ from groundingdino.models import build_model
25
+ from groundingdino.util import box_ops
26
+ from groundingdino.util.slconfig import SLConfig
27
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
28
+
29
+ # segment anything
30
+ from segment_anything import build_sam, SamPredictor, SamAutomaticMaskGenerator
31
+ import matplotlib.pyplot as plt
32
+ import wget
33
+
34
  def prompts(name, description):
35
  def decorator(func):
36
  func.name = name
113
  return os.path.join(head, new_file_name)
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  class InstructPix2Pix:
117
  def __init__(self, device):
118
  print(f"Initializing InstructPix2Pix to {device}")
225
  self.seed = -1
226
  self.a_prompt = 'best quality, extremely detailed'
227
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
228
+ 'fewer digits, cropped, worst quality, low quality'
229
 
230
  @prompts(name="Generate Image Condition On Canny Image",
231
  description="useful when you want to generate a new real image from both the user description and a canny image."
282
  self.seed = -1
283
  self.a_prompt = 'best quality, extremely detailed'
284
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
285
+ 'fewer digits, cropped, worst quality, low quality'
286
 
287
  @prompts(name="Generate Image Condition On Line Image",
288
  description="useful when you want to generate a new real image from both the user description "
340
  self.seed = -1
341
  self.a_prompt = 'best quality, extremely detailed'
342
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
343
+ 'fewer digits, cropped, worst quality, low quality'
344
 
345
  @prompts(name="Generate Image Condition On Soft Hed Boundary Image",
346
  description="useful when you want to generate a new real image from both the user description "
398
  self.seed = -1
399
  self.a_prompt = 'best quality, extremely detailed'
400
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
401
+ 'fewer digits, cropped, worst quality, low quality'
402
 
403
  @prompts(name="Generate Image Condition On Sketch Image",
404
  description="useful when you want to generate a new real image from both the user description and "
454
  self.unconditional_guidance_scale = 9.0
455
  self.a_prompt = 'best quality, extremely detailed'
456
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
457
+ ' fewer digits, cropped, worst quality, low quality'
458
 
459
  @prompts(name="Generate Image Condition On Pose Image",
460
  description="useful when you want to generate a new real image from both the user description "
478
  return updated_image_path
479
 
480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  class SegText2Image:
482
  def __init__(self, device):
483
  print(f"Initializing SegText2Image to {device}")
492
  self.seed = -1
493
  self.a_prompt = 'best quality, extremely detailed'
494
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
495
+ ' fewer digits, cropped, worst quality, low quality'
496
 
497
  @prompts(name="Generate Image Condition On Segmentations",
498
  description="useful when you want to generate a new real image from both the user description and segmentations. "
551
  self.seed = -1
552
  self.a_prompt = 'best quality, extremely detailed'
553
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
554
+ ' fewer digits, cropped, worst quality, low quality'
555
 
556
  @prompts(name="Generate Image Condition On Depth",
557
  description="useful when you want to generate a new real image from both the user description and depth image. "
622
  self.seed = -1
623
  self.a_prompt = 'best quality, extremely detailed'
624
  self.n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit,' \
625
+ ' fewer digits, cropped, worst quality, low quality'
626
 
627
  @prompts(name="Generate Image Condition On Normal Map",
628
  description="useful when you want to generate a new real image from both the user description and normal map. "
668
  f"Output Answer: {answer}")
669
  return answer
670
 
671
+
672
+ class Segmenting:
673
+ def __init__(self, device):
674
+ print(f"Inintializing Segmentation to {device}")
675
+ self.device = device
676
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
677
+ self.model_checkpoint_path = os.path.join("checkpoints", "sam")
678
+
679
+ self.download_parameters()
680
+ self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
681
+ self.sam_predictor = SamPredictor(self.sam)
682
+ self.mask_generator = SamAutomaticMaskGenerator(self.sam)
683
+
684
+ def download_parameters(self):
685
+ url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
686
+ if not os.path.exists(self.model_checkpoint_path):
687
+ wget.download(url, out=self.model_checkpoint_path)
688
+
689
+ def show_mask(self, mask, ax, random_color=False):
690
+ if random_color:
691
+ color = np.concatenate([np.random.random(3), np.array([1])], axis=0)
692
+ else:
693
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 1])
694
+ h, w = mask.shape[-2:]
695
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
696
+ ax.imshow(mask_image)
697
+
698
+ def show_box(self, box, ax, label):
699
+ x0, y0 = box[0], box[1]
700
+ w, h = box[2] - box[0], box[3] - box[1]
701
+ ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
702
+ ax.text(x0, y0, label)
703
+
704
+ def get_mask_with_boxes(self, image_pil, image, boxes_filt):
705
+
706
+ size = image_pil.size
707
+ H, W = size[1], size[0]
708
+ for i in range(boxes_filt.size(0)):
709
+ boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
710
+ boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
711
+ boxes_filt[i][2:] += boxes_filt[i][:2]
712
+
713
+ boxes_filt = boxes_filt.cpu()
714
+ transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(self.device)
715
+
716
+ masks, _, _ = self.sam_predictor.predict_torch(
717
+ point_coords=None,
718
+ point_labels=None,
719
+ boxes=transformed_boxes.to(self.device),
720
+ multimask_output=False,
721
+ )
722
+ return masks
723
+
724
+ def segment_image_with_boxes(self, image_pil, image_path, boxes_filt, pred_phrases):
725
+
726
+ image = cv2.imread(image_path)
727
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
728
+ self.sam_predictor.set_image(image)
729
+
730
+ masks = self.get_mask_with_boxes(image_pil, image, boxes_filt)
731
+
732
+ # draw output image
733
+ plt.figure(figsize=(10, 10))
734
+ plt.imshow(image)
735
+ for mask in masks:
736
+ self.show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
737
+
738
+ updated_image_path = get_new_image_name(image_path, func_name="segmentation")
739
+ plt.axis('off')
740
+ plt.savefig(
741
+ updated_image_path,
742
+ bbox_inches="tight", dpi=300, pad_inches=0.0
743
+ )
744
+ return updated_image_path
745
+
746
+ @prompts(name="Segment the Image",
747
+ description="useful when you want to segment all the part of the image, but not segment a certain object."
748
+ "like: segment all the object in this image, or generate segmentations on this image, "
749
+ "or segment the image,"
750
+ "or perform segmentation on this image, "
751
+ "or segment all the object in this image."
752
+ "The input to this tool should be a string, representing the image_path")
753
+ def inference_all(self, image_path):
754
+ image = cv2.imread(image_path)
755
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
756
+ masks = self.mask_generator.generate(image)
757
+ plt.figure(figsize=(20, 20))
758
+ plt.imshow(image)
759
+ if len(masks) == 0:
760
+ return
761
+ sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
762
+ ax = plt.gca()
763
+ ax.set_autoscale_on(False)
764
+ polygons = []
765
+ color = []
766
+ for ann in sorted_anns:
767
+ m = ann['segmentation']
768
+ img = np.ones((m.shape[0], m.shape[1], 3))
769
+ color_mask = np.random.random((1, 3)).tolist()[0]
770
+ for i in range(3):
771
+ img[:, :, i] = color_mask[i]
772
+ ax.imshow(np.dstack((img, m)))
773
+
774
+ updated_image_path = get_new_image_name(image_path, func_name="segment-image")
775
+ plt.axis('off')
776
+ plt.savefig(
777
+ updated_image_path,
778
+ bbox_inches="tight", dpi=300, pad_inches=0.0
779
+ )
780
+ return updated_image_path
781
+
782
+
783
+ class Text2Box:
784
+ def __init__(self, device):
785
+ print(f"Initializing ObjectDetection to {device}")
786
+ self.device = device
787
+ self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
788
+ self.model_checkpoint_path = os.path.join("checkpoints", "groundingdino")
789
+ self.model_config_path = os.path.join("checkpoints", "grounding_config.py")
790
+ self.download_parameters()
791
+ self.box_threshold = 0.3
792
+ self.text_threshold = 0.25
793
+ self.grounding = (self.load_model()).to(self.device)
794
+
795
+ def download_parameters(self):
796
+ url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
797
+ if not os.path.exists(self.model_checkpoint_path):
798
+ wget.download(url, out=self.model_checkpoint_path)
799
+ config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
800
+ if not os.path.exists(self.model_config_path):
801
+ wget.download(config_url, out=self.model_config_path)
802
+
803
+ def load_image(self, image_path):
804
+ # load image
805
+ image_pil = Image.open(image_path).convert("RGB") # load image
806
+
807
+ transform = T.Compose(
808
+ [
809
+ T.RandomResize([512], max_size=1333),
810
+ T.ToTensor(),
811
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
812
+ ]
813
+ )
814
+ image, _ = transform(image_pil, None) # 3, h, w
815
+ return image_pil, image
816
+
817
+ def load_model(self):
818
+ args = SLConfig.fromfile(self.model_config_path)
819
+ args.device = self.device
820
+ model = build_model(args)
821
+ checkpoint = torch.load(self.model_checkpoint_path, map_location="cpu")
822
+ load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
823
+ print(load_res)
824
+ _ = model.eval()
825
+ return model
826
+
827
+ def get_grounding_boxes(self, image, caption, with_logits=True):
828
+ caption = caption.lower()
829
+ caption = caption.strip()
830
+ if not caption.endswith("."):
831
+ caption = caption + "."
832
+ image = image.to(self.device)
833
+ with torch.no_grad():
834
+ outputs = self.grounding(image[None], captions=[caption])
835
+ logits = outputs["pred_logits"].cpu().sigmoid()[0] # (nq, 256)
836
+ boxes = outputs["pred_boxes"].cpu()[0] # (nq, 4)
837
+ logits.shape[0]
838
+
839
+ # filter output
840
+ logits_filt = logits.clone()
841
+ boxes_filt = boxes.clone()
842
+ filt_mask = logits_filt.max(dim=1)[0] > self.box_threshold
843
+ logits_filt = logits_filt[filt_mask] # num_filt, 256
844
+ boxes_filt = boxes_filt[filt_mask] # num_filt, 4
845
+ logits_filt.shape[0]
846
+
847
+ # get phrase
848
+ tokenlizer = self.grounding.tokenizer
849
+ tokenized = tokenlizer(caption)
850
+ # build pred
851
+ pred_phrases = []
852
+ for logit, box in zip(logits_filt, boxes_filt):
853
+ pred_phrase = get_phrases_from_posmap(logit > self.text_threshold, tokenized, tokenlizer)
854
+ if with_logits:
855
+ pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
856
+ else:
857
+ pred_phrases.append(pred_phrase)
858
+
859
+ return boxes_filt, pred_phrases
860
+
861
+ def plot_boxes_to_image(self, image_pil, tgt):
862
+ H, W = tgt["size"]
863
+ boxes = tgt["boxes"]
864
+ labels = tgt["labels"]
865
+ assert len(boxes) == len(labels), "boxes and labels must have same length"
866
+
867
+ draw = ImageDraw.Draw(image_pil)
868
+ mask = Image.new("L", image_pil.size, 0)
869
+ mask_draw = ImageDraw.Draw(mask)
870
+
871
+ # draw boxes and masks
872
+ for box, label in zip(boxes, labels):
873
+ # from 0..1 to 0..W, 0..H
874
+ box = box * torch.Tensor([W, H, W, H])
875
+ # from xywh to xyxy
876
+ box[:2] -= box[2:] / 2
877
+ box[2:] += box[:2]
878
+ # random color
879
+ color = tuple(np.random.randint(0, 255, size=3).tolist())
880
+ # draw
881
+ x0, y0, x1, y1 = box
882
+ x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
883
+
884
+ draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
885
+ # draw.text((x0, y0), str(label), fill=color)
886
+
887
+ font = ImageFont.load_default()
888
+ if hasattr(font, "getbbox"):
889
+ bbox = draw.textbbox((x0, y0), str(label), font)
890
+ else:
891
+ w, h = draw.textsize(str(label), font)
892
+ bbox = (x0, y0, w + x0, y0 + h)
893
+ # bbox = draw.textbbox((x0, y0), str(label))
894
+ draw.rectangle(bbox, fill=color)
895
+ draw.text((x0, y0), str(label), fill="white")
896
+
897
+ mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=2)
898
+
899
+ return image_pil, mask
900
+
901
+ @prompts(name="Detect the Give Object",
902
+ description="useful when you only want to detect or find out given objects in the picture"
903
+ "The input to this tool should be a comma separated string of two, "
904
+ "representing the image_path, the text description of the object to be found")
905
+ def inference(self, inputs):
906
+ image_path, det_prompt = inputs.split(",")
907
+ print(f"image_path={image_path}, text_prompt={det_prompt}")
908
+ image_pil, image = self.load_image(image_path)
909
+
910
+ boxes_filt, pred_phrases = self.get_grounding_boxes(image, det_prompt)
911
+
912
+ size = image_pil.size
913
+ pred_dict = {
914
+ "boxes": boxes_filt,
915
+ "size": [size[1], size[0]], # H,W
916
+ "labels": pred_phrases, }
917
+
918
+ image_with_box = self.plot_boxes_to_image(image_pil, pred_dict)[0]
919
+
920
+ updated_image_path = get_new_image_name(image_path, func_name="detect-something")
921
+ updated_image = image_with_box.resize(size)
922
+ updated_image.save(updated_image_path)
923
+ print(
924
+ f"\nProcessed ObejectDetecting, Input Image: {image_path}, Object to be Detect {det_prompt}, "
925
+ f"Output Image: {updated_image_path}")
926
+ return updated_image_path
927
+
928
+
929
+ class Inpainting:
930
+ def __init__(self, device):
931
+ self.device = device
932
+ self.revision = 'fp16' if 'cuda' in self.device else None
933
+ self.torch_dtype = torch.float16 if 'cuda' in self.device else torch.float32
934
+
935
+ self.inpaint = StableDiffusionInpaintPipeline.from_pretrained(
936
+ "runwayml/stable-diffusion-inpainting", revision=self.revision, torch_dtype=self.torch_dtype).to(device)
937
+
938
+ def __call__(self, prompt, original_image, mask_image):
939
+ update_image = self.inpaint(prompt=prompt, image=original_image.resize((512, 512)),
940
+ mask_image=mask_image.resize((512, 512))).images[0]
941
+ return update_image
942
+
943
+
944
  class InfinityOutPainting:
945
+ template_model = True # Add this line to show this is a template model.
946
+
947
  def __init__(self, ImageCaptioning, ImageEditing, VisualQuestionAnswering):
948
+ self.llm = OpenAI(temperature=0)
949
  self.ImageCaption = ImageCaptioning
950
  self.ImageEditing = ImageEditing
951
  self.ImageVQA = VisualQuestionAnswering
963
 
964
  def get_BLIP_caption(self, image):
965
  inputs = self.ImageCaption.processor(image, return_tensors="pt").to(self.ImageCaption.device,
966
+ self.ImageCaption.torch_dtype)
967
  out = self.ImageCaption.model.generate(**inputs)
968
  BLIP_caption = self.ImageCaption.processor.decode(out[0], skip_special_tokens=True)
969
  return BLIP_caption
970
 
971
+ def check_prompt(self, prompt):
972
+ check = f"Here is a paragraph with adjectives. " \
973
+ f"{prompt} " \
974
+ f"Please change all plural forms in the adjectives to singular forms. "
975
+ return self.llm(check)
976
 
977
  def get_imagine_caption(self, image, imagine):
978
  BLIP_caption = self.get_BLIP_caption(image)
985
  f"You should make the painting as vivid and realistic as possible" \
986
  f"You can not use words like painting or picture" \
987
  f"and you should use no more than 50 words to describe it"
988
+ caption = self.llm(imagine_prompt) if imagine else BLIP_caption
989
+ caption = self.check_prompt(caption)
 
990
  print(f'BLIP observation: {BLIP_caption}, ChatGPT imagine to {caption}') if imagine else print(
991
  f'Prompt: {caption}')
992
  return caption
993
 
994
+ def resize_image(self, image, max_size=1000000, multiple=8):
995
  aspect_ratio = image.size[0] / image.size[1]
996
  new_width = int(math.sqrt(max_size * aspect_ratio))
997
  new_height = int(new_width / aspect_ratio)
1037
  out_painted_image.save(updated_image_path)
1038
  print(f"\nProcessed InfinityOutPainting, Input Image: {image_path}, Input Resolution: {resolution}, "
1039
  f"Output Image: {updated_image_path}")
1040
+ return updated_image_path
1041
+
1042
+
1043
+ class ObjectSegmenting:
1044
+ template_model = True # Add this line to show this is a template model.
1045
+
1046
+ def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting):
1047
+ # self.llm = OpenAI(temperature=0)
1048
+ self.grounding = Text2Box
1049
+ self.sam = Segmenting
1050
+
1051
+ @prompts(name="Segment the given object",
1052
+ description="useful when you only want to segment the certain objects in the picture"
1053
+ "according to the given text"
1054
+ "like: segment the cat,"
1055
+ "or can you segment an obeject for me"
1056
+ "The input to this tool should be a comma separated string of two, "
1057
+ "representing the image_path, the text description of the object to be found")
1058
+ def inference(self, inputs):
1059
+ image_path, det_prompt = inputs.split(",")
1060
+ print(f"image_path={image_path}, text_prompt={det_prompt}")
1061
+ image_pil, image = self.grounding.load_image(image_path)
1062
+ boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, det_prompt)
1063
+ updated_image_path = self.sam.segment_image_with_boxes(image_pil, image_path, boxes_filt, pred_phrases)
1064
+ print(
1065
+ f"\nProcessed ObejectSegmenting, Input Image: {image_path}, Object to be Segment {det_prompt}, "
1066
+ f"Output Image: {updated_image_path}")
1067
+ return updated_image_path
1068
+
1069
+
1070
+ class ImageEditing:
1071
+ template_model = True
1072
+
1073
+ def __init__(self, Text2Box: Text2Box, Segmenting: Segmenting, Inpainting: Inpainting):
1074
+ print(f"Initializing ImageEditing")
1075
+ self.sam = Segmenting
1076
+ self.grounding = Text2Box
1077
+ self.inpaint = Inpainting
1078
+
1079
+ def pad_edge(self, mask, padding):
1080
+ # mask Tensor [H,W]
1081
+ mask = mask.numpy()
1082
+ true_indices = np.argwhere(mask)
1083
+ mask_array = np.zeros_like(mask, dtype=bool)
1084
+ for idx in true_indices:
1085
+ padded_slice = tuple(slice(max(0, i - padding), i + padding + 1) for i in idx)
1086
+ mask_array[padded_slice] = True
1087
+ new_mask = (mask_array * 255).astype(np.uint8)
1088
+ # new_mask
1089
+ return new_mask
1090
+
1091
+ @prompts(name="Remove Something From The Photo",
1092
+ description="useful when you want to remove and object or something from the photo "
1093
+ "from its description or location. "
1094
+ "The input to this tool should be a comma separated string of two, "
1095
+ "representing the image_path and the object need to be removed. ")
1096
+ def inference_remove(self, inputs):
1097
+ image_path, to_be_removed_txt = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
1098
+ return self.inference_replace_sam(f"{image_path},{to_be_removed_txt},background")
1099
+
1100
+ @prompts(name="Replace Something From The Photo",
1101
+ description="useful when you want to replace an object from the object description or "
1102
+ "location with another object from its description. "
1103
+ "The input to this tool should be a comma separated string of three, "
1104
+ "representing the image_path, the object to be replaced, the object to be replaced with ")
1105
+ def inference_replace_sam(self, inputs):
1106
+ image_path, to_be_replaced_txt, replace_with_txt = inputs.split(",")
1107
+
1108
+ print(f"image_path={image_path}, to_be_replaced_txt={to_be_replaced_txt}")
1109
+ image_pil, image = self.grounding.load_image(image_path)
1110
+ boxes_filt, pred_phrases = self.grounding.get_grounding_boxes(image, to_be_replaced_txt)
1111
+ image = cv2.imread(image_path)
1112
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
1113
+ self.sam.sam_predictor.set_image(image)
1114
+ masks = self.sam.get_mask_with_boxes(image_pil, image, boxes_filt)
1115
+ mask = torch.sum(masks, dim=0).unsqueeze(0)
1116
+ mask = torch.where(mask > 0, True, False)
1117
+ mask = mask.squeeze(0).squeeze(0).cpu() # tensor
1118
+
1119
+ mask = self.pad_edge(mask, padding=20) # numpy
1120
+ mask_image = Image.fromarray(mask)
1121
+
1122
+ updated_image = self.inpaint(prompt=replace_with_txt, original_image=image_pil,
1123
+ mask_image=mask_image)
1124
+ updated_image_path = get_new_image_name(image_path, func_name="replace-something")
1125
+ updated_image = updated_image.resize(image_pil.size)
1126
+ updated_image.save(updated_image_path)
1127
+ print(
1128
+ f"\nProcessed ImageEditing, Input Image: {image_path}, Replace {to_be_replaced_txt} to {replace_with_txt}, "
1129
+ f"Output Image: {updated_image_path}")
1130
  return updated_image_path