multimodalart committed on
Commit 63eabb5
1 Parent(s): 73856e8

Update convert.py

Files changed (1):
  1. convert.py +9 -4
convert.py CHANGED
@@ -38,15 +38,20 @@ def convert_single(model_id: str, token:str, filename: str, model_type: str, sam
     config_url = (Path(model_id)/"resolve/main"/filename).with_suffix(".yaml")
     config_url = "https://huggingface.co/" + str(config_url)
 
-    config_file = BytesIO(requests.get(config_url).content)
-
+    #config_file = BytesIO(requests.get(config_url).content)
+
+    response = requests.get(config_url)
+    with tempfile.NamedTemporaryFile(delete=False, mode='wb') as tmp_file:
+        tmp_file.write(response.content)
+        temp_config_file_path = tmp_file.name
+
     if model_type == "ControlNet":
         progress(0.2, desc="Converting ControlNet Model")
-        pipeline = download_controlnet_from_original_ckpt(ckpt_file, config_file, image_size=sample_size, from_safetensors=from_safetensors, extract_ema=extract_ema)
+        pipeline = download_controlnet_from_original_ckpt(ckpt_file, temp_config_file_path, image_size=sample_size, from_safetensors=from_safetensors, extract_ema=extract_ema)
         to_args = {"dtype": torch.float16}
     else:
         progress(0.1, desc="Converting Model")
-        pipeline = download_from_original_stable_diffusion_ckpt(ckpt_file, config_file, image_size=sample_size, scheduler_type=scheduler_type, from_safetensors=from_safetensors, extract_ema=extract_ema)
+        pipeline = download_from_original_stable_diffusion_ckpt(ckpt_file, temp_config_file_path, image_size=sample_size, scheduler_type=scheduler_type, from_safetensors=from_safetensors, extract_ema=extract_ema)
         to_args = {"torch_dtype": torch.float16}
 
     pipeline.save_pretrained(folder)
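
For reference, a minimal standalone sketch of the pattern this commit switches to: fetching the YAML config with requests and writing it to a NamedTemporaryFile so that a filesystem path, rather than an in-memory BytesIO object, can be handed to the conversion helpers. The URL below is hypothetical, and the stated motivation (the helpers expecting a path on disk) is an assumption not spelled out in the commit.

import tempfile
import requests

# Hypothetical config location, standing in for the URL built from model_id/filename.
config_url = "https://huggingface.co/some-user/some-model/resolve/main/model.yaml"

# Download the YAML config and persist it to a temporary file on disk
# (assumption: the downstream converters take a file path, not a file-like object).
response = requests.get(config_url)
response.raise_for_status()
with tempfile.NamedTemporaryFile(delete=False, mode="wb", suffix=".yaml") as tmp_file:
    tmp_file.write(response.content)
    temp_config_file_path = tmp_file.name

# temp_config_file_path can then be passed where the old code passed the BytesIO
# object, e.g. as the config argument of download_from_original_stable_diffusion_ckpt.
print(temp_config_file_path)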