import os
from pathlib import Path

import requests
from tqdm import tqdm

from S2I.logger import logger


def sc_vae_encoder_fwd(self, sample):
    """Encoder forward pass that records the input of every down block.

    Meant to be bound as a replacement ``forward`` on a VAE encoder module
    (presumably a diffusers-style ``Encoder`` -- TODO confirm at the call site).
    Before each down block runs, its input is appended to
    ``self.current_down_blocks``; the caller presumably forwards that list to
    the decoder's ``incoming_skip_acts`` for skip connections (verify).

    Args:
        self: the encoder module; must expose ``conv_in``, ``down_blocks``,
            ``mid_block``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        sample: the input tensor fed to ``conv_in``.

    Returns:
        The output of ``conv_out``.
    """
    sample = self.conv_in(sample)
    # Reset per call so activations from a previous forward pass never leak in.
    self.current_down_blocks = []
    for down_block in self.down_blocks:
        self.current_down_blocks.append(sample)
        sample = down_block(sample)
    sample = self.mid_block(sample)
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample


def sc_vae_decoder_fwd(self, sample, latent_embeds=None):
    """Decoder forward pass with optional skip connections from the encoder.

    Meant to be bound as a replacement ``forward`` on a VAE decoder module.
    Unless ``self.ignore_skip`` is set, the activations in
    ``self.incoming_skip_acts`` are consumed in reverse order: each one is
    scaled by ``self.gamma``, passed through the matching ``skip_conv_N``
    (N = 1..4), and added to the features before the corresponding up block.

    Args:
        self: the decoder module; must expose ``conv_in``, ``mid_block``,
            ``up_blocks``, ``conv_norm_out``, ``conv_act``, ``conv_out``,
            ``ignore_skip`` and, when skips are used, ``skip_conv_1`` ..
            ``skip_conv_4``, ``incoming_skip_acts`` and ``gamma``.
        sample: the latent tensor fed to ``conv_in``.
        latent_embeds: optional conditioning passed to the mid/up blocks and,
            when not ``None``, to ``conv_norm_out``.

    Returns:
        The output of ``conv_out``.
    """
    sample = self.conv_in(sample)
    # The up blocks may hold parameters in a different dtype than the mid
    # block; cast once so the rest of the pass runs in the up blocks' dtype.
    upscale_dtype = next(self.up_blocks.parameters()).dtype
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # Reverse so the last-recorded activation pairs with the first up block.
        reversed_skip_acts = self.incoming_skip_acts[::-1]
        for idx, (up_block, skip_conv) in enumerate(zip(self.up_blocks, skip_convs)):
            skip_in = skip_conv(reversed_skip_acts[idx] * self.gamma)
            sample += skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # Compare against None explicitly: truth-testing a multi-element tensor
    # raises, and a single-element zero tensor would pick the wrong branch.
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample


def downloading(url, outf, *, timeout=60, block_size=1024):
    """Stream ``url`` to ``outf`` with a progress bar; no-op if ``outf`` exists.

    The body is first written to ``<outf>.part`` and only renamed to ``outf``
    once the transfer has completed successfully. An interrupted or failed
    download therefore never leaves a truncated file behind that the existence
    check would later mistake for a finished checkpoint.

    Args:
        url: HTTP(S) URL of the file to download.
        outf: destination file path.
        timeout: connect/read timeout in seconds, passed to ``requests.get``.
        block_size: chunk size in bytes used while streaming the response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection or timeout errors.
        OSError: if the number of received bytes does not match
            ``Content-Length`` or the file cannot be written.
    """
    if os.path.exists(outf):
        return
    print(f"Downloading checkpoint to {outf}")
    tmp_path = f"{outf}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTTP error page would be saved as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            # iter_content() yields decoded data, so the byte count is only
            # comparable to Content-Length when no Content-Encoding was applied.
            size_checkable = (
                total_size_in_bytes != 0 and not response.headers.get('content-encoding')
            )
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar, \
                    open(tmp_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
            received = progress_bar.n
        if size_checkable and received != total_size_in_bytes:
            print("ERROR, something went wrong")
            raise OSError(
                f"Incomplete download of {url}: received {received} "
                f"of {total_size_in_bytes} bytes"
            )
        # Atomic rename: outf only ever exists as a complete file.
        os.replace(tmp_path, outf)
    except BaseException:
        # Drop the partial file so the next call retries instead of skipping.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise
    print(f"Downloaded successfully to {outf}")


def initialize_folder() -> None:
    """Create ``<S2I home>/.s2i`` and its ``weights`` subfolder if missing.

    The home directory is resolved by ``get_s2i_home``.

    Raises:
        OSError: if the folder cannot be created.
    """
    s2i_home_path = os.path.join(get_s2i_home(), ".s2i")
    weights_path = os.path.join(s2i_home_path, "weights")
    # makedirs creates the missing parent (.s2i) too, and exist_ok makes the
    # call idempotent, so no separate existence checks are needed.
    os.makedirs(weights_path, exist_ok=True)


def get_s2i_home() -> str:
    """Return the home directory used for storing model weights.

    Returns:
        str: the ``S2I_HOME`` environment variable if set, otherwise the
        current user's home directory.
    """
    return os.getenv("S2I_HOME", default=str(Path.home()))


def download_models():
    """Ensure the sketch-to-image LoRA checkpoints are available locally.

    Each checkpoint is downloaded only if it is not already present.

    Returns:
        dict[str, str]: mapping of model name to local checkpoint path.
    """
    urls = {
        '350k-adapter': 'https://huggingface.co/myn0908/sk2ks/resolve/main/adapter_weights_large_sketch2image_lora.pkl?download=true',
        '350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true'
    }
    # NOTE: checkpoints are stored directly in the S2I home directory, not in
    # the ".s2i/weights" folder created by initialize_folder(); the location is
    # kept unchanged so previously downloaded files are still reused.
    home = get_s2i_home()
    model_paths = {}
    for model_name, url in urls.items():
        outf = os.path.join(home, f"sketch2image_lora_{model_name}.pkl")
        downloading(url, outf)
        model_paths[model_name] = outf
    return model_paths


def get_model_path(model_name, model_paths, default="Model not found"):
    """Look up the local checkpoint path for ``model_name``.

    Args:
        model_name: key into ``model_paths`` (e.g. ``'350k'``).
        model_paths: mapping as returned by ``download_models``.
        default: value returned when ``model_name`` is unknown; the historical
            sentinel string ``"Model not found"`` is kept as the default.

    Returns:
        The stored path, or ``default`` if the name is not present.
    """
    return model_paths.get(model_name, default)