hysts HF staff commited on
Commit
6349538
1 Parent(s): c37294a

Use huggingface_hub to download models

Browse files
Files changed (1) hide show
  1. app.py +23 -78
app.py CHANGED
@@ -9,6 +9,7 @@ import torch
9
  import torchvision
10
  from diffusers import AutoencoderKL, DDIMScheduler
11
  from einops import rearrange
 
12
  from omegaconf import OmegaConf
13
  from PIL import Image
14
  from torchvision import transforms
@@ -87,72 +88,19 @@ If you have any questions, please feel free to reach me out at <b>ywl@stu.pku.ed
87
 
88
  # """
89
 
90
- os.makedirs("models/personalized")
91
- os.makedirs("models/sd1-5")
92
-
93
- if not os.path.exists("models/flow_controlnet.ckpt"):
94
- os.system(
95
- f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/flow_controlnet.ckpt?download=true -P models/"
96
- )
97
- os.system(f"mv models/flow_controlnet.ckpt?download=true models/flow_controlnet.ckpt")
98
- print(
99
- "flow_controlnet Download!",
100
- )
101
-
102
- if not os.path.exists("models/image_controlnet.ckpt"):
103
- os.system(
104
- f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/image_controlnet.ckpt?download=true -P models/"
105
- )
106
- os.system(f"mv models/image_controlnet.ckpt?download=true models/image_controlnet.ckpt")
107
- print(
108
- "image_controlnet Download!",
109
- )
110
-
111
- if not os.path.exists("models/unet.ckpt"):
112
- os.system(
113
- f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/unet.ckpt?download=true -P models/"
114
- )
115
- os.system(f"mv models/unet.ckpt?download=true models/unet.ckpt")
116
- print(
117
- "unet Download!",
118
- )
119
 
 
 
120
 
 
 
121
  if not os.path.exists("models/sd1-5/config.json"):
122
- os.system(
123
- f"wget -q https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/config.json?download=true -P models/sd1-5/"
124
- )
125
- os.system(f"mv models/sd1-5/config.json?download=true models/sd1-5/config.json")
126
- print(
127
- "config Download!",
128
- )
129
-
130
-
131
  if not os.path.exists("models/sd1-5/unet.ckpt"):
132
- os.system(f"cp -r models/unet.ckpt models/sd1-5/unet.ckpt")
133
-
134
- # os.system(f'wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin?download=true -P models/sd1-5/')
135
-
136
- if not os.path.exists("models/personalized/helloobjects_V12c.safetensors"):
137
- os.system(
138
- f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/helloobjects_V12c.safetensors?download=true -P models/personalized"
139
- )
140
- os.system(
141
- f"mv models/personalized/helloobjects_V12c.safetensors?download=true models/personalized/helloobjects_V12c.safetensors"
142
- )
143
- print(
144
- "helloobjects_V12c Download!",
145
- )
146
-
147
-
148
- if not os.path.exists("models/personalized/TUSUN.safetensors"):
149
- os.system(
150
- f"wget -q https://huggingface.co/TencentARC/ImageConductor/resolve/main/TUSUN.safetensors?download=true -P models/personalized"
151
- )
152
- os.system(f"mv models/personalized/TUSUN.safetensors?download=true models/personalized/TUSUN.safetensors")
153
- print(
154
- "TUSUN Download!",
155
- )
156
 
157
  # mv1 = os.system(f'mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
158
  # mv2 = os.system(f'mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')
@@ -245,11 +193,11 @@ IMAGE_PATH = {
245
 
246
 
247
  DREAM_BOOTH = {
248
- "HelloObject": "models/personalized/helloobjects_V12c.safetensors",
249
  }
250
 
251
  LORA = {
252
- "TUSUN": "models/personalized/TUSUN.safetensors",
253
  }
254
 
255
  LORA_ALPHA = {
@@ -632,6 +580,17 @@ def delete_last_step(tracking_points, first_frame_path, drag_mode):
632
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
633
 
634
 
 
 
 
 
 
 
 
 
 
 
 
635
  block = gr.Blocks(theme=gr.themes.Soft(radius_size=gr.themes.sizes.radius_none, text_size=gr.themes.sizes.text_md))
636
  with block:
637
  with gr.Row():
@@ -644,20 +603,6 @@ with block:
644
  with gr.Row(equal_height=True):
645
  gr.Markdown(instructions)
646
 
647
- # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
648
- device = torch.device("cuda")
649
- unet_path = "models/unet.ckpt"
650
- image_controlnet_path = "models/image_controlnet.ckpt"
651
- flow_controlnet_path = "models/flow_controlnet.ckpt"
652
- ImageConductor_net = ImageConductor(
653
- device=device,
654
- unet_path=unet_path,
655
- image_controlnet_path=image_controlnet_path,
656
- flow_controlnet_path=flow_controlnet_path,
657
- height=256,
658
- width=384,
659
- model_length=16,
660
- )
661
  first_frame_path_var = gr.State(value=None)
662
  tracking_points_var = gr.State([])
663
 
 
9
  import torchvision
10
  from diffusers import AutoencoderKL, DDIMScheduler
11
  from einops import rearrange
12
+ from huggingface_hub import hf_hub_download
13
  from omegaconf import OmegaConf
14
  from PIL import Image
15
  from torchvision import transforms
 
88
 
89
  # """
90
 
91
+ flow_controlnet_path = hf_hub_download("TencentARC/ImageConductor", "flow_controlnet.ckpt")
92
+ image_controlnet_path = hf_hub_download("TencentARC/ImageConductor", "image_controlnet.ckpt")
93
+ unet_path = hf_hub_download("TencentARC/ImageConductor", "unet.ckpt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ helloobjects_path = hf_hub_download("TencentARC/ImageConductor", "helloobjects_V12c.safetensors")
96
+ tusun_path = hf_hub_download("TencentARC/ImageConductor", "TUSUN.safetensors")
97
 
98
+ os.makedirs("models/sd1-5", exist_ok=True)
99
+ sd15_config_path = hf_hub_download("runwayml/stable-diffusion-v1-5", "config.json", subfolder="unet")
100
  if not os.path.exists("models/sd1-5/config.json"):
101
+ os.symlink(sd15_config_path, "models/sd1-5/config.json")
 
 
 
 
 
 
 
 
102
  if not os.path.exists("models/sd1-5/unet.ckpt"):
103
+ os.symlink(unet_path, "models/sd1-5/unet.ckpt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # mv1 = os.system(f'mv /usr/local/lib/python3.10/site-packages/gradio/helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers_bkp.py')
106
  # mv2 = os.system(f'mv helpers.py /usr/local/lib/python3.10/site-packages/gradio/helpers.py')
 
193
 
194
 
195
  DREAM_BOOTH = {
196
+ "HelloObject": helloobjects_path,
197
  }
198
 
199
  LORA = {
200
+ "TUSUN": tusun_path,
201
  }
202
 
203
  LORA_ALPHA = {
 
580
  return {tracking_points_var: tracking_points, input_image: trajectory_map}
581
 
582
 
583
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
584
+ ImageConductor_net = ImageConductor(
585
+ device=device,
586
+ unet_path=unet_path,
587
+ image_controlnet_path=image_controlnet_path,
588
+ flow_controlnet_path=flow_controlnet_path,
589
+ height=256,
590
+ width=384,
591
+ model_length=16,
592
+ )
593
+
594
  block = gr.Blocks(theme=gr.themes.Soft(radius_size=gr.themes.sizes.radius_none, text_size=gr.themes.sizes.text_md))
595
  with block:
596
  with gr.Row():
 
603
  with gr.Row(equal_height=True):
604
  gr.Markdown(instructions)
605
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
606
  first_frame_path_var = gr.State(value=None)
607
  tracking_points_var = gr.State([])
608