yeq6x committed
Commit c8099eb
1 Parent(s): d086933

Narrow the decorators to minimal scope
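
Move the @spaces.GPU decorator out of the top-level Gradio handler in app.py and apply it only to the functions in scripts/process_utils.py that actually do GPU work (model initialization and image generation). Models and pipelines are now moved onto the allocated device explicitly with .to(device).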

Files changed (2)
  1. app.py +0 -2
  2. scripts/process_utils.py +16 -9
app.py CHANGED
@@ -1,5 +1,4 @@
 import gradio as gr
-import spaces
 import os
 import io
 from PIL import Image
@@ -23,7 +22,6 @@ def process_image(input_image, mode, weight1, weight2):

     return sotai_pil, sketch_pil

-@spaces.GPU
 def gradio_process_image(input_image, mode, weight1, weight2):
     sotai_image, sketch_image = process_image(input_image, mode, weight1, weight2)
     return sotai_image, sketch_image
scripts/process_utils.py CHANGED
@@ -14,6 +14,8 @@ from peft import PeftModel
 from dotenv import load_dotenv
 from scripts.hf_utils import download_file

+import spaces
+
 # Global variables
 use_local = False
 model = None
@@ -33,6 +35,7 @@ def ensure_rgb(image):
         return image.convert('RGB')
     return image

+@spaces.GPU
 def initialize(_use_local=False, use_gpu=False, use_dotenv=False):
     if use_dotenv:
         load_dotenv()
@@ -52,6 +55,7 @@ def load_lora(pipeline, lora_path, alpha=0.75):
     pipeline.load_lora_weights(lora_path)
     pipeline.fuse_lora(lora_scale=alpha)

+@spaces.GPU
 def initialize_sotai_model():
     global device, torch_dtype

@@ -65,19 +69,19 @@ def initialize_sotai_model():
         sotai_sd_model_path,
         torch_dtype=torch_dtype,
         use_safetensors=True
-    )
+    ).to(device)

     # Load the ControlNet model
     controlnet1 = ControlNetModel.from_single_file(
         controlnet_path1,
         torch_dtype=torch_dtype
-    )
+    ).to(device)

     # Load the ControlNet model
     controlnet2 = ControlNetModel.from_single_file(
         controlnet_path2,
         torch_dtype=torch_dtype
-    )
+    ).to(device)

     # Create the ControlNet pipeline
     sotai_gen_pipe = StableDiffusionControlNetPipeline(
@@ -89,7 +93,7 @@ def initialize_sotai_model():
         safety_checker=sd_pipe.safety_checker,
         feature_extractor=sd_pipe.feature_extractor,
         controlnet=[controlnet1, controlnet2]
-    )
+    ).to(device)

     # Apply LoRA
     lora_names = [
@@ -106,6 +110,7 @@ def initialize_sotai_model():

     return sotai_gen_pipe

+@spaces.GPU
 def initialize_refine_model():
     global device, torch_dtype

@@ -119,23 +124,23 @@ def initialize_refine_model():
         refine_sd_model_path,
         torch_dtype=torch_dtype,
         use_safetensors=True
-    )
+    ).to(device)

     # controlnet_path = "models/cn/control_v11p_sd15_canny.pth"
     controlnet1 = ControlNetModel.from_single_file(
         controlnet_path3,
         torch_dtype=torch_dtype
-    )
+    ).to(device)

     # Load the ControlNet model
     controlnet2 = ControlNetModel.from_single_file(
         controlnet_path4,
         torch_dtype=torch_dtype
-    )
+    ).to(device)

     # Create the ControlNet pipeline
     refine_gen_pipe = StableDiffusionControlNetPipeline(
-        vae=AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype),
+        vae=AutoencoderKL.from_single_file(vae_path, torch_dtype=torch_dtype).to(device),
         text_encoder=sd_pipe.text_encoder,
         tokenizer=sd_pipe.tokenizer,
         unet=sd_pipe.unet,
@@ -143,7 +148,7 @@ def initialize_refine_model():
         safety_checker=sd_pipe.safety_checker,
         feature_extractor=sd_pipe.feature_extractor,
         controlnet=[controlnet1, controlnet2],  # Specify multiple ControlNets
-    )
+    ).to(device)

     # Scheduler settings
     refine_gen_pipe.scheduler = UniPCMultistepScheduler.from_config(refine_gen_pipe.scheduler.config)
@@ -201,6 +206,7 @@ def create_rgba_image(binary_image: np.ndarray, color: list) -> Image.Image:
     rgba_image[:, :, 3] = binary_image
     return Image.fromarray(rgba_image, 'RGBA')

+@spaces.GPU
 def generate_sotai_image(input_image: Image.Image, output_width: int, output_height: int) -> Image.Image:
     input_image = ensure_rgb(input_image)
     global sotai_gen_pipe
@@ -245,6 +251,7 @@ def generate_sotai_image(input_image: Image.Image, output_width: int, output_height: int) -> Image.Image:
     torch.cuda.empty_cache()
     gc.collect()

+@spaces.GPU
 def generate_refined_image(prompt: str, original_image: Image.Image, output_width: int, output_height: int, weight1: float, weight2: float) -> Image.Image:
     original_image = ensure_rgb(original_image)
     global refine_gen_pipe
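
For context, a minimal sketch of the ZeroGPU pattern this commit adopts, assuming the spaces package available on Hugging Face Spaces. The pipeline class and .to(device) call are standard diffusers/PyTorch API; the model id and prompt below are illustrative placeholders, not taken from this repo:

import spaces
import torch
from diffusers import StableDiffusionPipeline

device = "cuda"  # on ZeroGPU, CUDA is usable inside @spaces.GPU functions

# Load weights once at startup; "placeholder/sd15-model" is hypothetical
pipe = StableDiffusionPipeline.from_pretrained(
    "placeholder/sd15-model", torch_dtype=torch.float16
)

@spaces.GPU  # a GPU is allocated only while this function runs
def generate(prompt: str):
    pipe.to(device)  # move weights onto the allocated GPU
    return pipe(prompt).images[0]

Decorating only the function that does GPU work, rather than the whole Gradio callback chain, keeps the GPU allocation window (and therefore queue time on shared hardware) as short as possible.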