liuhaotian committed on
Commit
087de09
·
1 Parent(s): 01d67d8
app.py CHANGED
@@ -27,8 +27,8 @@ def parse_option():
27
  parser.add_argument("--guidance_scale", type=float, default=5, help="")
28
  parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
29
  parser.add_argument("--load-text-box-generation", type=arg_bool, default=True, help="Load text-box generation pipeline.")
30
- parser.add_argument("--load-text-box-inpainting", type=arg_bool, default=False, help="Load text-box inpainting pipeline.")
31
- parser.add_argument("--load-text-image-box-generation", type=arg_bool, default=False, help="Load text-image-box generation pipeline.")
32
  args = parser.parse_args()
33
  return args
34
  args = parse_option()
 
27
  parser.add_argument("--guidance_scale", type=float, default=5, help="")
28
  parser.add_argument("--alpha_scale", type=float, default=1, help="scale tanh(alpha). If 0, the behaviour is same as original model")
29
  parser.add_argument("--load-text-box-generation", type=arg_bool, default=True, help="Load text-box generation pipeline.")
30
+ parser.add_argument("--load-text-box-inpainting", type=arg_bool, default=True, help="Load text-box inpainting pipeline.")
31
+ parser.add_argument("--load-text-image-box-generation", type=arg_bool, default=True, help="Load text-image-box generation pipeline.")
32
  args = parser.parse_args()
33
  return args
34
  args = parse_option()
dataset/tsv_dataset.py CHANGED
@@ -190,7 +190,7 @@ class TSVDataset(BaseDataset):
190
  self.which_layer_image = which_layer[1]
191
 
192
  #self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') )
193
- self.projection_matrix = torch.load('projection_matrix')
194
 
195
  # Load tsv data
196
  self.tsv_file = TSVFile(self.tsv_path)
 
190
  self.which_layer_image = which_layer[1]
191
 
192
  #self.projection_matrix = torch.load(os.path.join(os.path.dirname(__file__), 'projection_matrix') )
193
+ self.projection_matrix = torch.load('projection_matrix.pth')
194
 
195
  # Load tsv data
196
  self.tsv_file = TSVFile(self.tsv_path)
gligen/projection_matrix.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:819d51fde084e16e5960323c8bafba07fa8ee727e5403e5e4bdced4333c68faa
3
+ size 2360043
gligen/task_grounded_generation.py CHANGED
@@ -107,7 +107,7 @@ def get_clip_feature(model, processor, input, is_image=False):
107
  if feature_type[1] == 'after_renorm':
108
  feature = feature*28.7
109
  if feature_type[1] == 'after_reproject':
110
- feature = project( feature, torch.load('gligen/projection_matrix').cuda().T ).squeeze(0)
111
  feature = ( feature / feature.norm() ) * 28.7
112
  feature = feature.unsqueeze(0)
113
  else:
@@ -249,16 +249,9 @@ def grounded_generation_box(loaded_model_list, instruction, *args, **kwargs):
249
 
250
 
251
  # ------------- other logistics ------------- #
252
- os.makedirs( os.path.join(save_folder, 'images'), exist_ok=True)
253
- os.makedirs( os.path.join(save_folder, 'layout'), exist_ok=True)
254
- os.makedirs( os.path.join(save_folder, 'overlay'), exist_ok=True)
255
-
256
- start = len( os.listdir(os.path.join(save_folder, 'images')) )
257
- image_ids = list(range(start,start+batch_size))
258
- print(image_ids)
259
 
260
  sample_list = []
261
- for image_id, sample in zip(image_ids, samples_fake):
262
  sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
263
  sample = sample.cpu().numpy().transpose(1,2,0) * 255
264
  sample = Image.fromarray(sample.astype(np.uint8))
 
107
  if feature_type[1] == 'after_renorm':
108
  feature = feature*28.7
109
  if feature_type[1] == 'after_reproject':
110
+ feature = project( feature, torch.load('gligen/projection_matrix.pth').cuda().T ).squeeze(0)
111
  feature = ( feature / feature.norm() ) * 28.7
112
  feature = feature.unsqueeze(0)
113
  else:
 
249
 
250
 
251
  # ------------- other logistics ------------- #
 
 
 
 
 
 
 
252
 
253
  sample_list = []
254
+ for sample in samples_fake:
255
  sample = torch.clamp(sample, min=-1, max=1) * 0.5 + 0.5
256
  sample = sample.cpu().numpy().transpose(1,2,0) * 255
257
  sample = Image.fromarray(sample.astype(np.uint8))