RamAnanth1 commited on
Commit
448a2e3
1 Parent(s): e2e2967

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -2
model.py CHANGED
@@ -211,8 +211,10 @@ class Model:
211
  device = 'cuda'
212
  if base_model == 'sd-v1-4.ckpt':
213
  model = self.model
 
214
  else:
215
  model = self.model_anything
 
216
  # if current_base != base_model:
217
  # ckpt = os.path.join("models", base_model)
218
  # pl_sd = torch.load(ckpt, map_location="cpu")
@@ -254,7 +256,7 @@ class Model:
254
  shape = [4, 64, 64]
255
 
256
  # sampling
257
- samples_ddim, _ = self.sampler.sample(S=50,
258
  conditioning=c,
259
  batch_size=1,
260
  shape=shape,
@@ -283,8 +285,10 @@ class Model:
283
  device = 'cuda'
284
  if base_model == 'sd-v1-4.ckpt':
285
  model = self.model
 
286
  else:
287
  model = self.model_anything
 
288
  # if current_base != base_model:
289
  # ckpt = os.path.join("models", base_model)
290
  # pl_sd = torch.load(ckpt, map_location="cpu")
@@ -347,7 +351,7 @@ class Model:
347
  shape = [4, 64, 64]
348
 
349
  # sampling
350
- samples_ddim, _ = self.sampler.sample(S=50,
351
  conditioning=c,
352
  batch_size=1,
353
  shape=shape,
 
211
  device = 'cuda'
212
  if base_model == 'sd-v1-4.ckpt':
213
  model = self.model
214
+ sampler = self.sampler
215
  else:
216
  model = self.model_anything
217
+ sampler = self.sampler_anything
218
  # if current_base != base_model:
219
  # ckpt = os.path.join("models", base_model)
220
  # pl_sd = torch.load(ckpt, map_location="cpu")
 
256
  shape = [4, 64, 64]
257
 
258
  # sampling
259
+ samples_ddim, _ = sampler.sample(S=50,
260
  conditioning=c,
261
  batch_size=1,
262
  shape=shape,
 
285
  device = 'cuda'
286
  if base_model == 'sd-v1-4.ckpt':
287
  model = self.model
288
+ sampler = self.sampler
289
  else:
290
  model = self.model_anything
291
+ sampler = self.sampler_anything
292
  # if current_base != base_model:
293
  # ckpt = os.path.join("models", base_model)
294
  # pl_sd = torch.load(ckpt, map_location="cpu")
 
351
  shape = [4, 64, 64]
352
 
353
  # sampling
354
+ samples_ddim, _ = sampler.sample(S=50,
355
  conditioning=c,
356
  batch_size=1,
357
  shape=shape,