tori29umai commited on
Commit
953a099
·
1 Parent(s): 5826348
Files changed (1) hide show
  1. app.py +27 -28
app.py CHANGED
@@ -9,6 +9,28 @@ import time
9
  from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
10
  from utils.prompt_analysis import PromptAnalysis
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  class Img2Img:
13
  def __init__(self):
14
  self.setup_paths()
@@ -24,27 +46,6 @@ class Img2Img:
24
  os.makedirs(self.tagger_dir, exist_ok=True)
25
  os.makedirs(self.lora_dir, exist_ok=True)
26
 
27
- def setup_models(self):
28
- load_cn_model(self.cn_dir)
29
- load_cn_config(self.cn_dir)
30
- load_tagger_model(self.tagger_dir)
31
- load_lora_model(self.lora_dir)
32
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
33
- self.dtype = torch.float16
34
- self.model = "cagliostrolab/animagine-xl-3.1"
35
- self.scheduler = DDIMScheduler.from_pretrained(self.model, subfolder="scheduler")
36
- self.controlnet = ControlNetModel.from_pretrained(self.cn_dir, torch_dtype=self.dtype, use_safetensors=True)
37
- self.pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
38
- self.model,
39
- controlnet=self.controlnet,
40
- torch_dtype=self.dtype,
41
- use_safetensors=True,
42
- scheduler=self.scheduler,
43
- )
44
- self.pipe.load_lora_weights(self.lora_dir, weight_name="sdxl_BWLine.safetensors")
45
- self.pipe = self.pipe.to(self.device)
46
-
47
-
48
  def layout(self):
49
  css = """
50
  #intro{
@@ -73,24 +74,22 @@ class Img2Img:
73
 
74
  @spaces.GPU
75
  def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale):
 
76
  input_image_pil = Image.open(input_image_path)
77
  base_size = input_image_pil.size
78
  resize_image = resize_image_aspect_ratio(input_image_pil)
79
  resize_image_size = resize_image.size
80
  width, height = resize_image_size
81
  white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
82
- conditioning, pooled = self.compel([prompt, negative_prompt])
83
  generator = torch.manual_seed(0)
84
  last_time = time.time()
85
 
86
- output_image = self.pipe(
87
  image=white_base_pil,
88
  control_image=resize_image,
89
  strength=1.0,
90
- prompt_embeds=conditioning[0:1],
91
- pooled_prompt_embeds=pooled[0:1],
92
- negative_prompt_embeds=conditioning[1:2],
93
- negative_pooled_prompt_embeds=pooled[1:2],
94
  width=width,
95
  height=height,
96
  controlnet_conditioning_scale=float(controlnet_scale),
@@ -100,7 +99,7 @@ class Img2Img:
100
  num_inference_steps=30,
101
  guidance_scale=8.5,
102
  eta=1.0,
103
- )
104
  print(f"Time taken: {time.time() - last_time}")
105
  output_image = output_image.resize(base_size, Image.LANCZOS)
106
  return output_image
 
9
  from utils.utils import load_cn_model, load_cn_config, load_tagger_model, load_lora_model, resize_image_aspect_ratio, base_generation
10
  from utils.prompt_analysis import PromptAnalysis
11
 
12
+
13
+ def load_model():
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ dtype = torch.float16
16
+ model = "cagliostrolab/animagine-xl-3.1"
17
+ scheduler = DDIMScheduler.from_pretrained(model, subfolder="scheduler")
18
+ controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)
19
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
20
+ model,
21
+ controlnet=controlnet,
22
+ torch_dtype=dtype,
23
+ use_safetensors=True,
24
+ scheduler=scheduler,
25
+ )
26
+ pipe.load_lora_weights(lora_dir, weight_name="sdxl_BWLine.safetensors")
27
+ pipe = pipe.to(device)
28
+ return pipe
29
+
30
+
31
+
32
+
33
+
34
  class Img2Img:
35
  def __init__(self):
36
  self.setup_paths()
 
46
  os.makedirs(self.tagger_dir, exist_ok=True)
47
  os.makedirs(self.lora_dir, exist_ok=True)
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def layout(self):
50
  css = """
51
  #intro{
 
74
 
75
  @spaces.GPU
76
  def predict(self, input_image_path, prompt, negative_prompt, controlnet_scale):
77
+ pipe = load_model()
78
  input_image_pil = Image.open(input_image_path)
79
  base_size = input_image_pil.size
80
  resize_image = resize_image_aspect_ratio(input_image_pil)
81
  resize_image_size = resize_image.size
82
  width, height = resize_image_size
83
  white_base_pil = base_generation(resize_image.size, (255, 255, 255, 255)).convert("RGB")
 
84
  generator = torch.manual_seed(0)
85
  last_time = time.time()
86
 
87
+ output_image = pipe(
88
  image=white_base_pil,
89
  control_image=resize_image,
90
  strength=1.0,
91
+ prompt=prompt,
92
+ negative_prompt = negative_prompt,
 
 
93
  width=width,
94
  height=height,
95
  controlnet_conditioning_scale=float(controlnet_scale),
 
99
  num_inference_steps=30,
100
  guidance_scale=8.5,
101
  eta=1.0,
102
+ ).images[0]
103
  print(f"Time taken: {time.time() - last_time}")
104
  output_image = output_image.resize(base_size, Image.LANCZOS)
105
  return output_image