hysts commited on
Commit
3a3e71c
·
1 Parent(s): 385c783

Use the original models

Browse files
Files changed (1) hide show
  1. model.py +5 -42
model.py CHANGED
@@ -40,24 +40,11 @@ ORIGINAL_MODEL_NAMES = {
40
  }
41
  ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
42
 
43
- LIGHTWEIGHT_MODEL_NAMES = {
44
- 'canny': 'control_canny-fp16.safetensors',
45
- 'hough': 'control_mlsd-fp16.safetensors',
46
- 'hed': 'control_hed-fp16.safetensors',
47
- 'scribble': 'control_scribble-fp16.safetensors',
48
- 'pose': 'control_openpose-fp16.safetensors',
49
- 'seg': 'control_seg-fp16.safetensors',
50
- 'depth': 'control_depth-fp16.safetensors',
51
- 'normal': 'control_normal-fp16.safetensors',
52
- }
53
- LIGHTWEIGHT_WEIGHT_ROOT = 'https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/'
54
-
55
 
56
  class Model:
57
  def __init__(self,
58
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
59
- model_dir: str = 'models',
60
- use_lightweight: bool = True):
61
  self.device = torch.device(
62
  'cuda:0' if torch.cuda.is_available() else 'cpu')
63
  self.model = create_model(model_config_path).to(self.device)
@@ -67,40 +54,16 @@ class Model:
67
  self.model_dir = pathlib.Path(model_dir)
68
  self.model_dir.mkdir(exist_ok=True, parents=True)
69
 
70
- self.use_lightweight = use_lightweight
71
- if use_lightweight:
72
- self.model_names = LIGHTWEIGHT_MODEL_NAMES
73
- self.weight_root = LIGHTWEIGHT_WEIGHT_ROOT
74
- base_model_url = 'https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors'
75
- self.load_base_model(base_model_url)
76
- else:
77
- self.model_names = ORIGINAL_MODEL_NAMES
78
- self.weight_root = ORIGINAL_WEIGHT_ROOT
79
  self.download_models()
80
 
81
- def download_base_model(self, model_url: str) -> pathlib.Path:
82
- model_name = model_url.split('/')[-1]
83
- out_path = self.model_dir / model_name
84
- if not out_path.exists():
85
- subprocess.run(shlex.split(f'wget {model_url} -O {out_path}'))
86
- return out_path
87
-
88
- def load_base_model(self, model_url: str) -> None:
89
- model_path = self.download_base_model(model_url)
90
- self.model.load_state_dict(load_state_dict(model_path,
91
- location=self.device.type),
92
- strict=False)
93
-
94
  def load_weight(self, task_name: str) -> None:
95
  if task_name == self.task_name:
96
  return
97
  weight_path = self.get_weight_path(task_name)
98
- if not self.use_lightweight:
99
- self.model.load_state_dict(
100
- load_state_dict(weight_path, location=self.device))
101
- else:
102
- self.model.control_model.load_state_dict(
103
- load_state_dict(weight_path, location=self.device.type))
104
  self.task_name = task_name
105
 
106
  def get_weight_path(self, task_name: str) -> str:
 
40
  }
41
  ORIGINAL_WEIGHT_ROOT = 'https://huggingface.co/lllyasviel/ControlNet/resolve/main/models/'
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  class Model:
45
  def __init__(self,
46
  model_config_path: str = 'ControlNet/models/cldm_v15.yaml',
47
+ model_dir: str = 'models'):
 
48
  self.device = torch.device(
49
  'cuda:0' if torch.cuda.is_available() else 'cpu')
50
  self.model = create_model(model_config_path).to(self.device)
 
54
  self.model_dir = pathlib.Path(model_dir)
55
  self.model_dir.mkdir(exist_ok=True, parents=True)
56
 
57
+ self.model_names = ORIGINAL_MODEL_NAMES
58
+ self.weight_root = ORIGINAL_WEIGHT_ROOT
 
 
 
 
 
 
 
59
  self.download_models()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def load_weight(self, task_name: str) -> None:
62
  if task_name == self.task_name:
63
  return
64
  weight_path = self.get_weight_path(task_name)
65
+ self.model.load_state_dict(
66
+ load_state_dict(weight_path, location=self.device))
 
 
 
 
67
  self.task_name = task_name
68
 
69
  def get_weight_path(self, task_name: str) -> str: