|
from diffusers.utils.peft_utils import set_weights_and_activate_adapters |
|
from S2I.modules.models import PrimaryModel |
|
import re |
|
import gc |
|
import torch |
|
import warnings |
|
|
|
# Suppress every warning category for the whole process as soon as this
# module is imported (no category/module filter is given, so nothing is shown).
warnings.filterwarnings("ignore") |
|
|
|
|
|
class Sketch2ImagePipeline(PrimaryModel):
    """Single-step sketch-to-image generation on top of ``PrimaryModel``.

    The methods below use attributes that are not defined in this file and are
    presumably provided by ``PrimaryModel`` (confirm against that class):
    ``from_pretrained``, ``global_vae``, ``global_unet``, ``global_scheduler``,
    ``global_tokenizer``, ``global_text_encoder`` and ``global_medium_prompt``.
    """

    # Matches everything up to the first "of <whitespace>" and captures the
    # text that follows, up to the first period (or the end of the string).
    _PREAMBLE_PATTERN = re.compile(r'^.*?of\s+(.*?(?:\.|$))', re.IGNORECASE | re.DOTALL)

    def __init__(self):
        super().__init__()
        # Fixed timestep handed to both the UNet and the scheduler step.
        self.timestep = torch.tensor([999], device="cuda").long()

    def generate(self, c_t, prompt=None, prompt_quality=None, prompt_template=None,
                 prompt_tokens=None, r=1.0, noise_map=None, half_model=None, model_name=None):
        """Generate an image from the control input ``c_t``.

        Args:
            c_t: Control tensor that is passed to ``global_vae.encode``.
            prompt: Text prompt. Exactly one of ``prompt`` / ``prompt_tokens``
                must be given.
            prompt_quality: When truthy, the prompt is first rewritten by
                ``automatic_enhance_prompt``.
            prompt_template: Optional string whose ``"{prompt}"`` placeholder
                is replaced by the (possibly enhanced) prompt. When ``None``
                the prompt is used unchanged.
            prompt_tokens: Pre-tokenized prompt, used instead of ``prompt``.
            r: Weight applied to the adapters, to the control/noise blend and
                to the decoder ``gamma``.
            noise_map: Noise blended with the encoded control. When ``None``,
                standard normal noise shaped like the encoded control is drawn.
            half_model: ``'float16'`` runs the forward pass under CUDA
                autocast; any other value runs without autocast.
            model_name: Forwarded to ``from_pretrained``.

        Returns:
            The decoded image tensor, clamped to ``[-1, 1]``.

        Raises:
            ValueError: If both or neither of ``prompt`` and ``prompt_tokens``
                are provided.
        """
        # Validate up front: the prompt is dereferenced below, and an ``assert``
        # would silently disappear when Python runs with ``-O``.
        if (prompt is None) == (prompt_tokens is None):
            raise ValueError("Either prompt or prompt_tokens should be provided")

        self.from_pretrained(model_name=model_name, r=r)

        # Enhancement and templating only make sense for a text prompt; with
        # ``prompt_tokens`` the text stays ``None`` and the tokens are used.
        prompt_enhanced = None
        if prompt is not None:
            prompt_enhanced = self.automatic_enhance_prompt(prompt, prompt_quality)
            if prompt_template is not None:
                prompt_enhanced = prompt_template.replace("{prompt}", prompt_enhanced)

        if half_model == 'float16':
            return self._generate_fp16(c_t, prompt_enhanced, prompt_tokens, r, noise_map)
        return self._generate_full_precision(c_t, prompt_enhanced, prompt_tokens, r, noise_map)

    def _generate_fp16(self, c_t, prompt, prompt_tokens, r, noise_map):
        """Run the forward pass under ``torch.autocast`` with float16 on CUDA."""
        # NOTE(review): the source indentation was lost; this assumes the whole
        # forward pass (not only the caption encoding) ran under autocast.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            return self._run_single_step(c_t, prompt, prompt_tokens, r, noise_map)

    def _generate_full_precision(self, c_t, prompt, prompt_tokens, r, noise_map):
        """Run the forward pass without autocast."""
        return self._run_single_step(c_t, prompt, prompt_tokens, r, noise_map)

    def _run_single_step(self, c_t, prompt, prompt_tokens, r, noise_map):
        """Encode the control, run one UNet/scheduler step and decode the result."""
        caption_enc = self._get_caption_enc(prompt, prompt_tokens)
        self._set_weights_and_activate_adapters(r)

        vae = self.global_vae
        scaling_factor = vae.config.scaling_factor
        encoded_control = vae.encode(c_t).latent_dist.sample() * scaling_factor

        if noise_map is None:
            # The signature defaults to None; draw noise matching the latent
            # instead of failing on ``None * (1 - r)``.
            noise_map = torch.randn_like(encoded_control)

        # Blend control latent and noise according to r.
        unet_input = encoded_control * r + noise_map * (1 - r)
        unet_output = self.global_unet(
            unet_input, self.timestep, encoder_hidden_states=caption_enc
        ).sample
        x_denoise = self.global_scheduler.step(
            unet_output, self.timestep, unet_input, return_dict=True
        ).prev_sample

        # Hand the encoder's down-block activations to the decoder before decoding.
        vae.decoder.incoming_skip_acts = vae.encoder.current_down_blocks
        vae.decoder.gamma = r

        return vae.decode(x_denoise / scaling_factor).sample.clamp(-1, 1)

    def _get_caption_enc(self, prompt, prompt_tokens):
        """Return the text-encoder output for ``prompt`` or for ``prompt_tokens``.

        A non-``None`` ``prompt`` is tokenized (padded/truncated to the
        tokenizer's ``model_max_length``); otherwise ``prompt_tokens`` is used.
        """
        if prompt is not None:
            caption_tokens = self.global_tokenizer(
                prompt,
                max_length=self.global_tokenizer.model_max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            ).input_ids.cuda()
        else:
            caption_tokens = prompt_tokens.cuda()

        return self.global_text_encoder(caption_tokens)[0]

    def _set_weights_and_activate_adapters(self, r):
        """Activate the ``default`` UNet adapter and ``vae_skip`` VAE adapter with weight ``r``."""
        self.global_unet.set_adapters(["default"], weights=[r])
        set_weights_and_activate_adapters(self.global_vae, ["vae_skip"], [r])

    def automatic_enhance_prompt(self, input_prompt, prompt_quality):
        """Optionally rewrite ``input_prompt`` via ``global_medium_prompt``.

        Args:
            input_prompt: The prompt text.
            prompt_quality: When falsy, ``input_prompt`` is returned unchanged.

        Returns:
            The enhanced text. If the generated text contains an
            ``"... of "`` preamble, that preamble is dropped and the sentence
            following it is capitalized.
        """
        if not prompt_quality:
            return input_prompt

        result = self.global_medium_prompt("Enhance the description: " + input_prompt)
        enhanced_text = result[0]['summary_text']

        match = self._PREAMBLE_PATTERN.match(enhanced_text)
        if match is None:
            return enhanced_text

        modified_sentence = match.group(1).capitalize()
        remaining_text = enhanced_text[match.end():].strip()
        # Strip so an empty remainder does not leave a dangling space.
        return (modified_sentence + ' ' + remaining_text).strip()

    def _move_to_cpu(self, module):
        """Move ``module`` to the CPU."""
        module.to("cpu")

    def _move_to_gpu(self, module):
        """Move ``module`` to the CUDA device."""
        module.to("cuda")