feishen29 commited on
Commit
6d0321f
·
verified ·
1 Parent(s): 1a8ca3d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -64,7 +64,7 @@ parser.add_argument('--model_ckpt',
64
  default="/home/sf/weights/sd_stage2/mp_rank_00_model_states_0628.pt",
65
  type=str)
66
  parser.add_argument('--output_path', type=str, default="./output_ipa_control_resampler")
67
- parser.add_argument('--device', type=str, default="cuda:0")
68
  args = parser.parse_args()
69
 
70
  # svae path
@@ -73,6 +73,8 @@ output_path = args.output_path
73
  if not os.path.exists(output_path):
74
  os.makedirs(output_path)
75
 
 
 
76
 
77
  generator = torch.Generator(device=args.device).manual_seed(42)
78
  vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_path).to(dtype=torch.float16, device=args.device)
 
64
  default="/home/sf/weights/sd_stage2/mp_rank_00_model_states_0628.pt",
65
  type=str)
66
  parser.add_argument('--output_path', type=str, default="./output_ipa_control_resampler")
67
+ # parser.add_argument('--device', type=str, default="cuda:0")
68
  args = parser.parse_args()
69
 
70
  # svae path
 
73
  if not os.path.exists(output_path):
74
  os.makedirs(output_path)
75
 
76
+ device = "cuda" if torch.cuda.is_available() else "cpu"
77
+ args.device = device
78
 
79
  generator = torch.Generator(device=args.device).manual_seed(42)
80
  vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_path).to(dtype=torch.float16, device=args.device)