|
import os |
|
import requests |
|
from tqdm import tqdm |
|
from pathlib import Path |
|
from S2I.logger import logger |
|
|
|
def sc_vae_encoder_fwd(self, sample):
    """Run the encoder forward pass while recording the input of each down block.

    Written to be bound as an encoder's ``forward``; ``self`` must expose
    ``conv_in``, ``down_blocks``, ``mid_block``, ``conv_norm_out``, ``conv_act``
    and ``conv_out``.

    Side effect: ``self.current_down_blocks`` is replaced by a fresh list that
    holds the tensor fed into every down block, in execution order (presumably
    read later as decoder skip activations -- confirm with the caller).

    Args:
        sample: input passed to ``conv_in``.

    Returns:
        The output of ``conv_out``.
    """
    hidden = self.conv_in(sample)

    # Attach the list to ``self`` before filling it, so a block that raises
    # leaves the partially-filled list visible, just like appending in place.
    skips = self.current_down_blocks = []
    for block in self.down_blocks:
        skips.append(hidden)
        hidden = block(hidden)

    # Bottleneck followed by the output head, applied in sequence.
    for layer in (self.mid_block, self.conv_norm_out, self.conv_act, self.conv_out):
        hidden = layer(hidden)
    return hidden
|
|
|
def sc_vae_decoder_fwd(self, sample, latent_embeds=None):
    """Run the decoder forward pass, optionally adding encoder skip activations.

    Written to be bound as a decoder's ``forward``. ``self`` must expose
    ``conv_in``, ``mid_block``, ``up_blocks`` (iterable, with ``parameters()``),
    ``conv_norm_out``, ``conv_act``, ``conv_out`` and ``ignore_skip``; when
    ``ignore_skip`` is false it must also expose ``skip_conv_1`` .. ``skip_conv_4``,
    ``incoming_skip_acts`` and ``gamma``.

    Args:
        sample: latent input passed to ``conv_in``.
        latent_embeds: optional conditioning forwarded to ``mid_block`` and every
            up block, and to ``conv_norm_out`` when it is not ``None``.

    Returns:
        The output of ``conv_out``.
    """
    sample = self.conv_in(sample)
    # Cast the mid-block output to the dtype of the up blocks' parameters.
    upscale_dtype = next(self.up_blocks.parameters()).dtype
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)

    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # Reversed so the last recorded activation feeds the first up block
        # (presumably the encoder's ``current_down_blocks`` -- confirm with caller).
        reversed_skip_acts = self.incoming_skip_acts[::-1]
        for idx, (up_block, skip_conv) in enumerate(zip(self.up_blocks, skip_convs)):
            skip_in = skip_conv(reversed_skip_acts[idx] * self.gamma)
            # Out-of-place add: avoids mutating a tensor autograd may have saved.
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)

    # Compare against None explicitly: truth-testing a multi-element tensor
    # raises, and a single zero-valued tensor would be misread as "absent".
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample
|
|
|
def downloading(url, outf):
    """Download ``url`` to ``outf`` with a progress bar unless ``outf`` already exists.

    The body is streamed into ``outf + ".part"`` and renamed to ``outf`` only
    once it is complete. Because existing files are never re-downloaded, writing
    straight to ``outf`` would let a truncated or failed download masquerade as
    a valid checkpoint on every later call.

    Args:
        url: HTTP(S) URL to fetch.
        outf: destination path; missing parent directories are created.

    Raises:
        requests.HTTPError: if the server replies with an error status.
        OSError: if fewer bytes arrive than the advertised Content-Length, or
            the file cannot be written.
    """
    if os.path.exists(outf):
        return

    print(f"Downloading checkpoint to {outf}")
    parent_dir = os.path.dirname(outf)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    tmp_path = outf + ".part"
    block_size = 1024

    # ``with`` releases the streamed connection even if writing fails.
    with requests.get(url, stream=True) as response:
        # Without this, an HTML error page would be saved as the checkpoint.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        try:
            with open(tmp_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
        finally:
            progress_bar.close()

    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        print("ERROR, something went wrong")
        os.remove(tmp_path)
        raise OSError(
            f"Incomplete download of {url}: received {progress_bar.n} of "
            f"{total_size_in_bytes} bytes"
        )

    # Publish the file only once it is complete.
    os.replace(tmp_path, outf)
    print(f"Downloaded successfully to {outf}")
|
|
|
def initialize_folder() -> None:
    """
    Initialize the folder for storing model weights.

    Creates ``<S2I home>/.s2i/weights`` (including the intermediate ``.s2i``
    directory) if it does not exist yet, and prints the weights path.

    Raises:
        OSError: if the folder cannot be created (including when a non-directory
            file already occupies one of the paths).
    """
    s2i_home_path = os.path.join(get_s2i_home(), ".s2i")
    weights_path = os.path.join(s2i_home_path, "weights")
    print(weights_path)
    # makedirs creates every missing parent, so one call covers both levels;
    # exist_ok replaces the racy "check exists, then create" pattern.
    os.makedirs(weights_path, exist_ok=True)
|
|
|
def get_s2i_home() -> str:
    """
    Get the home directory for storing model weights.

    ``S2I_HOME`` takes precedence when set to a non-empty value; otherwise the
    current user's home directory is used. An empty ``S2I_HOME`` is treated as
    unset, because an empty base would make derived paths such as
    ``"" + "/.s2i"`` point at the filesystem root.

    Returns:
        str: the home directory.
    """
    env_home = os.getenv("S2I_HOME")
    if env_home:
        return env_home
    # Resolved lazily: Path.home() can raise RuntimeError when no home
    # directory can be determined, which must not matter if S2I_HOME is set.
    return str(Path.home())
|
|
|
def download_models():
    """Ensure every known sketch-to-image LoRA checkpoint is present locally.

    Each checkpoint is fetched into the directory returned by ``get_s2i_home()``
    via ``downloading``, which leaves files that already exist untouched.

    Returns:
        dict: model name -> local checkpoint path, in the order listed below.
    """
    checkpoint_sources = (
        ('350k-adapter', 'https://huggingface.co/myn0908/sk2ks/resolve/main/adapter_weights_large_sketch2image_lora.pkl?download=true'),
        ('350k', 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true'),
    )

    target_dir = get_s2i_home()
    local_paths = {}
    for name, source_url in checkpoint_sources:
        destination = os.path.join(target_dir, f"sketch2image_lora_{name}.pkl")
        downloading(source_url, destination)
        local_paths[name] = destination
    return local_paths
|
|
|
|
|
def get_model_path(model_name, model_paths):
    """Return the checkpoint path registered for ``model_name``.

    Args:
        model_name: key to look up, e.g. one of the names used by ``download_models``.
        model_paths: mapping of model name -> checkpoint path.

    Returns:
        The stored path, or the literal string ``"Model not found"`` when the
        name is absent. NOTE(review): callers must compare against that
        sentinel; nothing is raised for unknown names.
    """
    missing_marker = "Model not found"
    if model_name not in model_paths:
        return missing_marker
    return model_paths[model_name]
|
|