nightfury commited on
Commit
49dc097
1 Parent(s): 986ef15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -35,27 +35,31 @@ pipe = StableDiffusionInpaintingPipeline.from_pretrained(
35
  #model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
36
  model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
37
 
 
38
  model.eval()
39
  model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
40
 
 
 
 
41
  transform = transforms.Compose([
42
  transforms.ToTensor(),
43
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
44
- transforms.Resize((512, 512)),
45
  ])
46
 
47
  def predict(radio, dict, word_mask, prompt=""):
48
  if(radio == "draw a mask above"):
49
  #with autocast("cuda"):
50
- with autocast(enable=(False if device=='cpu' else True)):
51
- init_image = dict["image"].convert("RGB").resize((512, 512))
52
- mask = dict["mask"].convert("RGB").resize((512, 512))
53
  else:
54
  img = transform(dict["image"]).unsqueeze(0)
55
  word_masks = [word_mask]
56
  with torch.no_grad():
57
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
58
- init_image = dict['image'].convert('RGB').resize((512, 512))
59
  filename = f"{uuid.uuid4()}.png"
60
  plt.imsave(filename,torch.sigmoid(preds[0][0]))
61
  img2 = cv2.imread(filename)
@@ -65,7 +69,7 @@ def predict(radio, dict, word_mask, prompt=""):
65
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
66
  os.remove(filename)
67
  #with autocast("cuda"):
68
- with autocast(enable=(False if device=='cpu' else True)):
69
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
70
  return images[0]
71
 
35
  #model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
36
  model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, complex_trans_conv=True)
37
 
38
+ model = model.to(torch.device(device))
39
  model.eval()
40
  model.load_state_dict(torch.load('./clipseg/weights/rd64-uni.pth', map_location=torch.device(device)), strict=False)
41
 
42
+ print ("Torch load(model) : ", model)
43
+ imgRes = 256 #512
44
+
45
  transform = transforms.Compose([
46
  transforms.ToTensor(),
47
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
48
+ transforms.Resize((imgRes, imgRes)),
49
  ])
50
 
51
  def predict(radio, dict, word_mask, prompt=""):
52
  if(radio == "draw a mask above"):
53
  #with autocast("cuda"):
54
+ with autocast(device): #enable=(False if device=='cpu' else True)):
55
+ init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
56
+ mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
57
  else:
58
  img = transform(dict["image"]).unsqueeze(0)
59
  word_masks = [word_mask]
60
  with torch.no_grad():
61
  preds = model(img.repeat(len(word_masks),1,1,1), word_masks)[0]
62
+ init_image = dict['image'].convert('RGB').resize((imgRes, imgRes))
63
  filename = f"{uuid.uuid4()}.png"
64
  plt.imsave(filename,torch.sigmoid(preds[0][0]))
65
  img2 = cv2.imread(filename)
69
  mask = Image.fromarray(np.uint8(bw_image)).convert('RGB')
70
  os.remove(filename)
71
  #with autocast("cuda"):
72
+ with autocast(device): #enable=(False if device=='cpu' else True)):
73
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
74
  return images[0]
75