Kunpeng Song committed on
Commit b5f6f82
1 Parent(s): 7c69fc1
Files changed (3)
  1. .DS_Store +0 -0
  2. app.py +0 -5
  3. dataset_lib/dataset_eval_MoMA.py +153 -2
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
 import torch
 from pytorch_lightning import seed_everything
 from model_lib.utils import parse_args
-# from llava.mm_utils import process_image
 
 os.environ["CUDA_VISIBLE_DEVICES"]="0"
 
@@ -18,10 +17,6 @@ args = parse_args()
 
 model = None
 
-def my_process_image(a, b, c):
-    # return process_image(a, b, c)
-    return (a, b, c)
-
 @spaces.GPU
 def inference(rgb, subject, prompt, strength, seed):
     seed = int(seed) if seed else 0
dataset_lib/dataset_eval_MoMA.py CHANGED
@@ -2,8 +2,159 @@ from PIL import Image
 import numpy as np
 import torch
 from torchvision import transforms
-from ..app import my_process_image
 from rembg import remove
+import ast
+import math
+
+def select_best_resolution(original_size, possible_resolutions):
+    """
+    Selects the best resolution from a list of possible resolutions based on the original size.
+
+    Args:
+        original_size (tuple): The original size of the image in the format (width, height).
+        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
+
+    Returns:
+        tuple: The best fit resolution in the format (width, height).
+    """
+    original_width, original_height = original_size
+    best_fit = None
+    max_effective_resolution = 0
+    min_wasted_resolution = float('inf')
+
+    for width, height in possible_resolutions:
+        scale = min(width / original_width, height / original_height)
+        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
+        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
+        wasted_resolution = (width * height) - effective_resolution
+
+        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
+            max_effective_resolution = effective_resolution
+            min_wasted_resolution = wasted_resolution
+            best_fit = (width, height)
+
+    return best_fit
+
+
+def resize_and_pad_image(image, target_resolution):
+    """
+    Resize and pad an image to a target resolution while maintaining aspect ratio.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        target_resolution (tuple): The target resolution (width, height) of the image.
+
+    Returns:
+        PIL.Image.Image: The resized and padded image.
+    """
+    original_width, original_height = image.size
+    target_width, target_height = target_resolution
+
+    scale_w = target_width / original_width
+    scale_h = target_height / original_height
+
+    if scale_w < scale_h:
+        new_width = target_width
+        new_height = min(math.ceil(original_height * scale_w), target_height)
+    else:
+        new_height = target_height
+        new_width = min(math.ceil(original_width * scale_h), target_width)
+
+    # Resize the image
+    resized_image = image.resize((new_width, new_height))
+
+    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
+    paste_x = (target_width - new_width) // 2
+    paste_y = (target_height - new_height) // 2
+    new_image.paste(resized_image, (paste_x, paste_y))
+
+    return new_image
+
+
+def divide_to_patches(image, patch_size):
+    """
+    Divides an image into patches of a specified size.
+
+    Args:
+        image (PIL.Image.Image): The input image.
+        patch_size (int): The size of each patch.
+
+    Returns:
+        list: A list of PIL.Image.Image objects representing the patches.
+    """
+    patches = []
+    width, height = image.size
+    for i in range(0, height, patch_size):
+        for j in range(0, width, patch_size):
+            box = (j, i, j + patch_size, i + patch_size)
+            patch = image.crop(box)
+            patches.append(patch)
+
+    return patches
+
+
+def process_anyres_image(image, processor, grid_pinpoints):
+    """
+    Process an image with variable resolutions.
+
+    Args:
+        image (PIL.Image.Image): The input image to be processed.
+        processor: The image processor object.
+        grid_pinpoints (str): A string representation of a list of possible resolutions.
+
+    Returns:
+        torch.Tensor: A tensor containing the processed image patches.
+    """
+    if type(grid_pinpoints) is list:
+        possible_resolutions = grid_pinpoints
+    else:
+        possible_resolutions = ast.literal_eval(grid_pinpoints)
+    best_resolution = select_best_resolution(image.size, possible_resolutions)
+    image_padded = resize_and_pad_image(image, best_resolution)
+
+    patches = divide_to_patches(image_padded, processor.crop_size['height'])
+
+    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
+
+    image_patches = [image_original_resize] + patches
+    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
+                     for image_patch in image_patches]
+    return torch.stack(image_patches, dim=0)
+
+
+def expand2square(pil_img, background_color):
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
+def process_images(images, image_processor, model_cfg):
+    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
+    new_images = []
+    if image_aspect_ratio == 'pad':
+        for image in images:
+            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            new_images.append(image)
+    elif image_aspect_ratio == "anyres":
+        for image in images:
+            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
+            new_images.append(image)
+    else:
+        return image_processor(images, return_tensors='pt')['pixel_values']
+    if all(x.shape == new_images[0].shape for x in new_images):
+        new_images = torch.stack(new_images, dim=0)
+    return new_images
+
 
 def create_binary_mask(image):
     grayscale = image.convert("L")
@@ -38,7 +189,7 @@ def Dataset_evaluate_MoMA(image_pil, prompt,subject, moMA_main_modal):
     image_wb = image * mask + torch.ones_like(image)* (1-mask)*255
     image_pil = Image.fromarray(image_wb.permute(1,2,0).numpy().astype(np.uint8))
 
-    res['llava_processed'] = my_process_image([image_pil], LLaVa_processor, llava_config)
+    res['llava_processed'] = process_images([image_pil], LLaVa_processor, llava_config)
    res['label'] = [subject]
    return res
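
Note: the helpers added above mirror LLaVA's mm_utils preprocessing. A minimal usage sketch follows, assuming the repository's dependencies are installed; the CLIP checkpoint name, the SimpleNamespace configs, and the grid pinpoints are illustrative assumptions, not values taken from this commit.

# Hypothetical usage sketch, not part of the commit.
from types import SimpleNamespace

from PIL import Image
from transformers import CLIPImageProcessor

from dataset_lib.dataset_eval_MoMA import process_images

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
image = Image.new("RGB", (1000, 500), (128, 128, 128))  # stand-in for a real photo

# 'pad' mode: expand2square letterboxes the image onto a square canvas filled
# with the processor's mean color before the standard preprocess call.
cfg_pad = SimpleNamespace(image_aspect_ratio="pad")
print(process_images([image], processor, cfg_pad).shape)  # torch.Size([1, 3, 336, 336])

# 'anyres' mode: select_best_resolution picks (672, 336) for this 2:1 image
# (zero wasted area), the padded image is tiled into two 336x336 patches, and a
# globally resized copy is prepended: (2 patches + 1 global) per image.
cfg_anyres = SimpleNamespace(
    image_aspect_ratio="anyres",
    image_grid_pinpoints="[(336, 672), (672, 336), (672, 672)]",
)
print(process_images([image], processor, cfg_anyres).shape)  # torch.Size([3, 3, 336, 336])

With no image_aspect_ratio set on the config, process_images falls back to the processor's default resize-and-center-crop path, so the 'pad' and 'anyres' branches are strictly opt-in.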