tori29umai commited on
Commit
31877a7
1 Parent(s): 00c86c6
Files changed (2) hide show
  1. app.py +7 -7
  2. utils/dl_utils.py +6 -43
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
6
  import os
7
  import time
8
 
9
- from utils.dl_utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model
10
  from utils.image_utils import resize_image_aspect_ratio, base_generation
11
 
12
  from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
@@ -22,10 +22,10 @@ os.makedirs(cn_dir, exist_ok=True)
22
  os.makedirs(tagger_dir, exist_ok=True)
23
  os.makedirs(lora_dir, exist_ok=True)
24
 
25
- load_cn_model(cn_dir)
26
- load_cn_config(cn_dir)
27
- load_tagger_model(tagger_dir)
28
- load_lora_model(lora_dir)
29
 
30
  def load_model(lora_dir, cn_dir):
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -35,8 +35,8 @@ def load_model(lora_dir, cn_dir):
35
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
36
  "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
37
  )
38
- pipe.load_lora_weights(lora_dir, weight_name="sdxl_BW_bold_Line.safetensors")
39
- pipe.set_adapters(["sdxl_BW_bold_Line"], adapter_weights=[1.2])
40
  pipe.fuse_lora()
41
  pipe = pipe.to(device)
42
  return pipe
 
6
  import os
7
  import time
8
 
9
+ from utils.dl_utils import dl_cn_model, dl_cn_config, dl_tagger_model, dl_lora_model
10
  from utils.image_utils import resize_image_aspect_ratio, base_generation
11
 
12
  from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
 
22
  os.makedirs(tagger_dir, exist_ok=True)
23
  os.makedirs(lora_dir, exist_ok=True)
24
 
25
+ dl_cn_model(cn_dir)
26
+ dl_cn_config(cn_dir)
27
+ dl_tagger_model(tagger_dir)
28
+ dl_lora_model(lora_dir)
29
 
30
  def load_model(lora_dir, cn_dir):
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
35
  pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
36
  "cagliostrolab/animagine-xl-3.1", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
37
  )
38
+ pipe.load_lora_weights(lora_dir, weight_name="sdxl_BW_Line.safetensors")
39
+ pipe.set_adapters(["sdxl_BW_Line"], adapter_weights=[1.4])
40
  pipe.fuse_lora()
41
  pipe = pipe.to(device)
42
  return pipe
utils/dl_utils.py CHANGED
@@ -8,7 +8,7 @@ from PIL import Image, ImageOps
8
  import numpy as np
9
  import cv2
10
 
11
- def load_cn_model(model_dir):
12
  folder = model_dir
13
  file_name = 'diffusion_pytorch_model.safetensors'
14
  url = " https://huggingface.co/2vXpSwA7/iroiro-lora/resolve/main/test_controlnet2/CN-anytest_v3-50000_fp16.safetensors"
@@ -24,7 +24,7 @@ def load_cn_model(model_dir):
24
  else:
25
  print(f'{file_name} already exists.')
26
 
27
- def load_cn_config(model_dir):
28
  folder = model_dir
29
  file_name = 'config.json'
30
  file_path = os.path.join(folder, file_name)
@@ -32,7 +32,7 @@ def load_cn_config(model_dir):
32
  config_path = os.path.join(os.getcwd(), file_name)
33
  shutil.copy(config_path, file_path)
34
 
35
- def load_tagger_model(model_dir):
36
  model_id = 'SmilingWolf/wd-swinv2-tagger-v3'
37
  files = [
38
  'config.json', 'model.onnx', 'selected_tags.csv', 'sw_jax_cv_config.json'
@@ -56,11 +56,11 @@ def load_tagger_model(model_dir):
56
  print(f'{file} already exists.')
57
 
58
 
59
- def load_lora_model(model_dir):
60
- file_name = 'sdxl_BW_bold_Line.safetensors'
61
  file_path = os.path.join(model_dir, file_name)
62
  if not os.path.exists(file_path):
63
- url = "https://huggingface.co/tori29umai/lineart/resolve/main/sdxl_BW_bold_Line.safetensors"
64
  response = requests.get(url, allow_redirects=True)
65
  if response.status_code == 200:
66
  with open(file_path, 'wb') as f:
@@ -70,40 +70,3 @@ def load_lora_model(model_dir):
70
  print(f'Failed to download {file_name}')
71
  else:
72
  print(f'{file_name} already exists.')
73
-
74
-
75
- def resize_image_aspect_ratio(image):
76
- # 元の画像サイズを取得
77
- original_width, original_height = image.size
78
-
79
- # アスペクト比を計算
80
- aspect_ratio = original_width / original_height
81
-
82
- # 標準のアスペクト比サイズを定義
83
- sizes = {
84
- 1: (1024, 1024), # 正方形
85
- 4/3: (1152, 896), # 横長画像
86
- 3/2: (1216, 832),
87
- 16/9: (1344, 768),
88
- 21/9: (1568, 672),
89
- 3/1: (1728, 576),
90
- 1/4: (512, 2048), # 縦長画像
91
- 1/3: (576, 1728),
92
- 9/16: (768, 1344),
93
- 2/3: (832, 1216),
94
- 3/4: (896, 1152)
95
- }
96
-
97
- # 最も近いアスペクト比を見つける
98
- closest_aspect_ratio = min(sizes.keys(), key=lambda x: abs(x - aspect_ratio))
99
- target_width, target_height = sizes[closest_aspect_ratio]
100
-
101
- # リサイズ処理
102
- resized_image = image.resize((target_width, target_height), Image.LANCZOS)
103
-
104
- return resized_image
105
-
106
-
107
- def base_generation(size, color):
108
- canvas = Image.new("RGBA", size, color)
109
- return canvas
 
8
  import numpy as np
9
  import cv2
10
 
11
+ def dl_cn_model(model_dir):
12
  folder = model_dir
13
  file_name = 'diffusion_pytorch_model.safetensors'
14
  url = " https://huggingface.co/2vXpSwA7/iroiro-lora/resolve/main/test_controlnet2/CN-anytest_v3-50000_fp16.safetensors"
 
24
  else:
25
  print(f'{file_name} already exists.')
26
 
27
+ def dl_cn_config(model_dir):
28
  folder = model_dir
29
  file_name = 'config.json'
30
  file_path = os.path.join(folder, file_name)
 
32
  config_path = os.path.join(os.getcwd(), file_name)
33
  shutil.copy(config_path, file_path)
34
 
35
+ def dl_tagger_model(model_dir):
36
  model_id = 'SmilingWolf/wd-swinv2-tagger-v3'
37
  files = [
38
  'config.json', 'model.onnx', 'selected_tags.csv', 'sw_jax_cv_config.json'
 
56
  print(f'{file} already exists.')
57
 
58
 
59
+ def dl_lora_model(model_dir):
60
+ file_name = 'sdxl_BW_Line.safetensors'
61
  file_path = os.path.join(model_dir, file_name)
62
  if not os.path.exists(file_path):
63
+ url = "https://huggingface.co/tori29umai/lineart/resolve/main/sdxl_BW_Line.safetensors"
64
  response = requests.get(url, allow_redirects=True)
65
  if response.status_code == 200:
66
  with open(file_path, 'wb') as f:
 
70
  print(f'Failed to download {file_name}')
71
  else:
72
  print(f'{file_name} already exists.')