RamAnanth1 commited on
Commit
c5c689b
1 Parent(s): 9c9406a

Update app.py

Browse files

Add pose model to check for memory

Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -9,6 +9,7 @@ from pytorch_lightning import seed_everything
9
  from util import resize_image, HWC3, apply_canny
10
  from ldm.models.diffusion.ddim import DDIMSampler
11
 
 
12
 
13
  from cldm.model import create_model, load_state_dict
14
 
@@ -17,14 +18,21 @@ from huggingface_hub import hf_hub_url, cached_download
17
  REPO_ID = "lllyasviel/ControlNet"
18
  canny_checkpoint = "models/control_sd15_canny.pth"
19
  scribble_checkpoint = "models/control_sd15_scribble.pth"
 
20
 
21
- canny_model = create_model('./models/cldm_v15.yaml')
22
  canny_model.load_state_dict(load_state_dict(cached_download(
23
  hf_hub_url(REPO_ID, canny_checkpoint)
24
- ), location='cpu'))
25
  canny_model = canny_model.cuda()
26
  ddim_sampler = DDIMSampler(canny_model)
27
 
 
 
 
 
 
 
28
 
29
  scribble_model = create_model('./models/cldm_v15.yaml')
30
  scribble_model.load_state_dict(load_state_dict(cached_download(
 
9
  from util import resize_image, HWC3, apply_canny
10
  from ldm.models.diffusion.ddim import DDIMSampler
11
 
12
+ from annotator.openpose import apply_openpose
13
 
14
  from cldm.model import create_model, load_state_dict
15
 
 
18
  REPO_ID = "lllyasviel/ControlNet"
19
  canny_checkpoint = "models/control_sd15_canny.pth"
20
  scribble_checkpoint = "models/control_sd15_scribble.pth"
21
+ pose_checkpoint = "models/control_sd15_openpose.pth"
22
 
23
+ canny_model = create_model('./models/cldm_v15.yaml').cpu()
24
  canny_model.load_state_dict(load_state_dict(cached_download(
25
  hf_hub_url(REPO_ID, canny_checkpoint)
26
+ ), location='cuda'))
27
  canny_model = canny_model.cuda()
28
  ddim_sampler = DDIMSampler(canny_model)
29
 
30
+ pose_model = create_model('./models/cldm_v15.yaml').cpu()
31
+ pose_model.load_state_dict(load_state_dict(cached_download(
32
+ hf_hub_url(REPO_ID, pose_checkpoint)
33
+ ), location='cuda'))
34
+ pose_model = pose_model.cuda()
35
+ ddim_sampler_pose = DDIMSampler(pose_model)
36
 
37
  scribble_model = create_model('./models/cldm_v15.yaml')
38
  scribble_model.load_state_dict(load_state_dict(cached_download(