RamAnanth1 commited on
Commit
1eefa67
1 Parent(s): a976ca4

Make sure process function accesses model

Browse files
Files changed (1) hide show
  1. model.py +9 -9
model.py CHANGED
@@ -76,15 +76,15 @@ class Model:
76
  subprocess.run(shlex.split(f'wget {pidinet_file} -O models/table5_pidinet.pth'))
77
 
78
 
79
- model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
80
  current_base = 'sd-v1-4.ckpt'
81
- model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
82
- model_ad.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth"))
83
  net_G = pidinet()
84
  ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
85
  net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
86
  net_G.to(device)
87
- sampler = PLMSSampler(model)
88
  save_memory=True
89
 
90
  @torch.inference_mode()
@@ -121,17 +121,17 @@ class Model:
121
  im = im.float()
122
  im_edge = tensor2img(im)
123
 
124
- c = model.get_learned_conditioning([prompt])
125
- nc = model.get_learned_conditioning([neg_prompt])
126
 
127
  with torch.no_grad():
128
  # extract condition features
129
- features_adapter = model_ad(im.to(device))
130
 
131
  shape = [4, 64, 64]
132
 
133
  # sampling
134
- samples_ddim, _ = sampler.sample(S=50,
135
  conditioning=c,
136
  batch_size=1,
137
  shape=shape,
@@ -144,7 +144,7 @@ class Model:
144
  mode = 'sketch',
145
  con_strength = con_strength)
146
 
147
- x_samples_ddim = model.decode_first_stage(samples_ddim)
148
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
149
  x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
150
  x_samples_ddim = 255.*x_samples_ddim
 
76
  subprocess.run(shlex.split(f'wget {pidinet_file} -O models/table5_pidinet.pth'))
77
 
78
 
79
+ self.model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
80
  current_base = 'sd-v1-4.ckpt'
81
+ self.model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
82
+ self.model_ad.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth"))
83
  net_G = pidinet()
84
  ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
85
  net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
86
  net_G.to(device)
87
+ self.sampler = PLMSSampler(self.model)
88
  save_memory=True
89
 
90
  @torch.inference_mode()
 
121
  im = im.float()
122
  im_edge = tensor2img(im)
123
 
124
+ c = self.model.get_learned_conditioning([prompt])
125
+ nc = self.model.get_learned_conditioning([neg_prompt])
126
 
127
  with torch.no_grad():
128
  # extract condition features
129
+ features_adapter = self.model_ad(im.to(device))
130
 
131
  shape = [4, 64, 64]
132
 
133
  # sampling
134
+ samples_ddim, _ = self.sampler.sample(S=50,
135
  conditioning=c,
136
  batch_size=1,
137
  shape=shape,
 
144
  mode = 'sketch',
145
  con_strength = con_strength)
146
 
147
+ x_samples_ddim = self.model.decode_first_stage(samples_ddim)
148
  x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
149
  x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
150
  x_samples_ddim = 255.*x_samples_ddim