JackAILab commited on
Commit
8eb54f6
1 Parent(s): 35e9299

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -0
app.py CHANGED
@@ -10,6 +10,10 @@ from diffusers.utils import load_image
10
  from diffusers import EulerDiscreteScheduler
11
  from pipline_StableDiffusion_ConsistentID import ConsistentIDStableDiffusionPipeline
12
  from huggingface_hub import hf_hub_download
 
 
 
 
13
 
14
  zero = torch.Tensor([0]).cuda()
15
  print(zero.device) # <-- 'cpu' 🤔
@@ -30,9 +34,17 @@ pipe = ConsistentIDStableDiffusionPipeline.from_pretrained(
30
  variant="fp16"
31
  ).to(device)
32
 
 
 
 
 
 
 
 
33
  ### Load consistentID_model checkpoint
34
  pipe.load_ConsistentID_model(
35
  os.path.dirname(consistentID_path),
 
36
  subfolder="",
37
  weight_name=os.path.basename(consistentID_path),
38
  trigger_word="img",
 
10
  from diffusers import EulerDiscreteScheduler
11
  from pipline_StableDiffusion_ConsistentID import ConsistentIDStableDiffusionPipeline
12
  from huggingface_hub import hf_hub_download
13
+ ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
14
+ ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
15
+ ### Thanks for the open source of face-parsing model.
16
+ from models.BiSeNet.model import BiSeNet
17
 
18
  zero = torch.Tensor([0]).cuda()
19
  print(zero.device) # <-- 'cpu' 🤔
 
34
  variant="fp16"
35
  ).to(device)
36
 
37
+ ### Load other pretrained models
38
+ ## BiSenet
39
+ bise_net = BiSeNet(n_classes = 19)
40
+ bise_net.cuda() # CUDA must not be initialized in the main process on Spaces with Stateless GPU environment
41
+ bise_net_cp_path = hf_hub_download(repo_id="JackAILab/ConsistentID", filename="face_parsing.pth", repo_type="model")
42
+ bise_net.load_state_dict(torch.load(bise_net_cp_path))
43
+
44
  ### Load consistentID_model checkpoint
45
  pipe.load_ConsistentID_model(
46
  os.path.dirname(consistentID_path),
47
+ bise_net,
48
  subfolder="",
49
  weight_name=os.path.basename(consistentID_path),
50
  trigger_word="img",