Spaces:
Running
Running
hysts
commited on
Commit
•
3a3e71c
1
Parent(s):
385c783
Use the original models
Browse files
model.py
CHANGED
@@ -40,24 +40,11 @@ ORIGINAL_MODEL_NAMES = {
|
|
40 |
}
|
41 |
ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
|
42 |
|
43 |
-
LIGHTWEIGHT_MODEL_NAMES = {
|
44 |
-
'canny': 'control_canny-fp16.safetensors',
|
45 |
-
'hough': 'control_mlsd-fp16.safetensors',
|
46 |
-
'hed': 'control_hed-fp16.safetensors',
|
47 |
-
'scribble': 'control_scribble-fp16.safetensors',
|
48 |
-
'pose': 'control_openpose-fp16.safetensors',
|
49 |
-
'seg': 'control_seg-fp16.safetensors',
|
50 |
-
'depth': 'control_depth-fp16.safetensors',
|
51 |
-
'normal': 'control_normal-fp16.safetensors',
|
52 |
-
}
|
53 |
-
LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
|
54 |
-
|
55 |
|
56 |
class Model:
|
57 |
def __init__(self,
|
58 |
model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
|
59 |
-
model_dir: str = 'models'
|
60 |
-
use_lightweight: bool = True):
|
61 |
self.device = torch.device(
|
62 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
63 |
self.model = create_model(model_config_path).to(self.device)
|
@@ -67,40 +54,16 @@ class Model:
|
|
67 |
self.model_dir = pathlib.Path(model_dir)
|
68 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
69 |
|
70 |
-
self.
|
71 |
-
|
72 |
-
self.model_names = LIGHTWEIGHT_MODEL_NAMES
|
73 |
-
self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
|
74 |
-
base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
|
75 |
-
self.load_base_model(base_model_url)
|
76 |
-
else:
|
77 |
-
self.model_names = ORIGINAL_MODEL_NAMES
|
78 |
-
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
79 |
self.download_models()
|
80 |
|
81 |
-
def download_base_model(self, model_url: str) -> pathlib.Path:
|
82 |
-
model_name = model_url.split('/')[-1]
|
83 |
-
out_path = self.model_dir / model_name
|
84 |
-
if not out_path.exists():
|
85 |
-
subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
|
86 |
-
return out_path
|
87 |
-
|
88 |
-
def load_base_model(self, model_url: str) -> None:
|
89 |
-
model_path = self.download_base_model(model_url)
|
90 |
-
self.model.load_state_dict(load_state_dict(model_path,
|
91 |
-
location=self.device.type),
|
92 |
-
strict=False)
|
93 |
-
|
94 |
def load_weight(self, task_name: str) -> None:
|
95 |
if task_name == self.task_name:
|
96 |
return
|
97 |
weight_path = self.get_weight_path(task_name)
|
98 |
-
|
99 |
-
self.
|
100 |
-
load_state_dict(weight_path, location=self.device))
|
101 |
-
else:
|
102 |
-
self.model.control_model.load_state_dict(
|
103 |
-
load_state_dict(weight_path, location=self.device.type))
|
104 |
self.task_name = task_name
|
105 |
|
106 |
def get_weight_path(self, task_name: str) -> str:
|
|
|
40 |
}
|
41 |
ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
class Model:
|
45 |
def __init__(self,
|
46 |
model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
|
47 |
+
model_dir: str = 'models'):
|
|
|
48 |
self.device = torch.device(
|
49 |
'cuda:0' if torch.cuda.is_available() else 'cpu')
|
50 |
self.model = create_model(model_config_path).to(self.device)
|
|
|
54 |
self.model_dir = pathlib.Path(model_dir)
|
55 |
self.model_dir.mkdir(exist_ok=True, parents=True)
|
56 |
|
57 |
+
self.model_names = ORIGINAL_MODEL_NAMES
|
58 |
+
self.weight_root = ORIGINAL_WEIGHT_ROOT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
self.download_models()
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
def load_weight(self, task_name: str) -> None:
|
62 |
if task_name == self.task_name:
|
63 |
return
|
64 |
weight_path = self.get_weight_path(task_name)
|
65 |
+
self.model.load_state_dict(
|
66 |
+
load_state_dict(weight_path, location=self.device))
|
|
|
|
|
|
|
|
|
67 |
self.task_name = task_name
|
68 |
|
69 |
def get_weight_path(self, task_name: str) -> str:
|