liuyizhang commited on
Commit
6e3e561
1 Parent(s): 1c6fec7

update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -434,7 +434,9 @@ def concatenate_images_vertical(image1, image2):
434
 
435
  return new_image
436
 
437
- def relate_anything(input_image, k):
 
 
438
  w, h = input_image.size
439
  max_edge = 1500
440
  if w > max_edge or h > max_edge:
@@ -442,12 +444,14 @@ def relate_anything(input_image, k):
442
  new_size = (int(w / ratio), int(h / ratio))
443
  input_image.thumbnail(new_size)
444
 
 
445
  # load image
446
  pil_image = input_image.convert('RGBA')
447
  image = np.array(input_image)
448
  sam_masks = sam_mask_generator.generate(image)
449
  filtered_masks = sort_and_deduplicate(sam_masks)
450
 
 
451
  feat_list = []
452
  for fm in filtered_masks:
453
  feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
@@ -455,6 +459,7 @@ def relate_anything(input_image, k):
455
  feat = torch.cat(feat_list, dim=1).to(device)
456
  matrix_output, rel_triplets = ram_model.predict(feat)
457
 
 
458
  pil_image_list = []
459
  for i, rel in enumerate(rel_triplets[:k]):
460
  s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
@@ -473,6 +478,7 @@ def relate_anything(input_image, k):
473
  concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
474
  pil_image_list.append(concate_pil_image)
475
 
 
476
  yield pil_image_list
477
 
478
 
 
434
 
435
  return new_image
436
 
437
+ def relate_anything(input_image_mask, k):
438
+ logger.info(f'relate_anything_1_')
439
+ input_image = input_image_mask['image']
440
  w, h = input_image.size
441
  max_edge = 1500
442
  if w > max_edge or h > max_edge:
 
444
  new_size = (int(w / ratio), int(h / ratio))
445
  input_image.thumbnail(new_size)
446
 
447
+ logger.info(f'relate_anything_2_')
448
  # load image
449
  pil_image = input_image.convert('RGBA')
450
  image = np.array(input_image)
451
  sam_masks = sam_mask_generator.generate(image)
452
  filtered_masks = sort_and_deduplicate(sam_masks)
453
 
454
+ logger.info(f'relate_anything_3_')
455
  feat_list = []
456
  for fm in filtered_masks:
457
  feat = torch.Tensor(fm['feat']).unsqueeze(0).unsqueeze(0).to(device)
 
459
  feat = torch.cat(feat_list, dim=1).to(device)
460
  matrix_output, rel_triplets = ram_model.predict(feat)
461
 
462
+ logger.info(f'relate_anything_4_')
463
  pil_image_list = []
464
  for i, rel in enumerate(rel_triplets[:k]):
465
  s,o,r = int(rel[0]),int(rel[1]),int(rel[2])
 
478
  concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
479
  pil_image_list.append(concate_pil_image)
480
 
481
+ logger.info(f'relate_anything_5_')
482
  yield pil_image_list
483
 
484