cocktailpeanut commited on
Commit
e9640d5
1 Parent(s): 3961dc5
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -24,8 +24,10 @@ import gradio as gr
24
  # global variable
25
  MAX_SEED = np.iinfo(np.int32).max
26
  #device = "cuda" if torch.cuda.is_available() else "cpu"
 
27
  if torch.backends.mps.is_available():
28
  device = "mps"
 
29
  elif torch.cuda.is_available():
30
  device = "cuda"
31
  else:
@@ -48,27 +50,30 @@ face_adapter = f'./checkpoints/ip-adapter.bin'
48
  controlnet_path = f'./checkpoints/ControlNetModel'
49
 
50
  # Load pipeline
51
- controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
 
52
 
53
  base_model_path = 'wangqixun/YamerMIX_v8'
54
 
55
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
56
  base_model_path,
57
  controlnet=controlnet,
58
- torch_dtype=torch.float16,
 
59
  safety_checker=None,
60
  feature_extractor=None,
61
  )
62
  if device == 'mps':
63
- pipe.to("mps")
 
64
  elif device == 'cuda':
65
- pipe.cuda()
66
  pipe.load_ip_adapter_instantid(face_adapter)
67
  #pipe.image_proj_model.to('cuda')
68
  #pipe.unet.to('cuda')
69
  if device == 'mps' or device == 'cuda':
70
- pipe.image_proj_model.to(device)
71
- pipe.unet.to(device)
72
 
73
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
74
  if randomize_seed:
 
24
  # global variable
25
  MAX_SEED = np.iinfo(np.int32).max
26
  #device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ torch_dtype = torch.float16
28
  if torch.backends.mps.is_available():
29
  device = "mps"
30
+ torch_dtype = torch.float32
31
  elif torch.cuda.is_available():
32
  device = "cuda"
33
  else:
 
50
  controlnet_path = f'./checkpoints/ControlNetModel'
51
 
52
  # Load pipeline
53
+ #controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
54
+ controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
55
 
56
  base_model_path = 'wangqixun/YamerMIX_v8'
57
 
58
  pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
59
  base_model_path,
60
  controlnet=controlnet,
61
+ #torch_dtype=torch.float16,
62
+ torch_dtype=torch_dtype,
63
  safety_checker=None,
64
  feature_extractor=None,
65
  )
66
  if device == 'mps':
67
+ pipe.to("mps", torch_dtype)
68
+ pipe.enable_attention_slicing()
69
  elif device == 'cuda':
70
+ pipe.cuda()
71
  pipe.load_ip_adapter_instantid(face_adapter)
72
  #pipe.image_proj_model.to('cuda')
73
  #pipe.unet.to('cuda')
74
  if device == 'mps' or device == 'cuda':
75
+ pipe.image_proj_model.to(device)
76
+ pipe.unet.to(device)
77
 
78
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
79
  if randomize_seed: