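# Page header captured along with this file (Hugging Face file view):
#   user: myn0908
#   commit: 5044478 - "replace upload image by url input"
#   size: 3.6 kB (view links: raw, history, blame)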
import os
import requests
from tqdm import tqdm
from pathlib import Path
from S2I.logger import logger
def sc_vae_encoder_fwd(self, sample):
    """Replacement ``forward`` for a VAE encoder that records skip activations.

    Runs ``conv_in``, every block of ``self.down_blocks`` in order, then
    ``mid_block``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.
    The input seen by each down block is stored, in order, in a fresh list at
    ``self.current_down_blocks`` (replaced on every call).
    NOTE(review): presumably these are handed to the decoder as its
    ``incoming_skip_acts``; that wiring is not visible here.

    Args:
        self: The encoder module this function is bound to.
        sample: Input passed to ``self.conv_in``.

    Returns:
        The output of ``self.conv_out``.
    """
    hidden = self.conv_in(sample)
    skip_acts = self.current_down_blocks = []
    for block in self.down_blocks:
        # Capture the activation *before* the down block transforms it.
        skip_acts.append(hidden)
        hidden = block(hidden)
    hidden = self.mid_block(hidden)
    hidden = self.conv_norm_out(hidden)
    hidden = self.conv_act(hidden)
    return self.conv_out(hidden)
def sc_vae_decoder_fwd(self, sample, latent_embeds=None):
    """Replacement ``forward`` for a VAE decoder with optional skip connections.

    Pipeline: ``conv_in`` -> ``mid_block`` -> cast to the dtype of the first
    parameter of ``self.up_blocks`` -> each up block -> ``conv_norm_out`` ->
    ``conv_act`` -> ``conv_out``.

    Unless ``self.ignore_skip`` is set, the skip path runs before the i-th up
    block:

    1. Take the i-th entry of the *reversed* ``self.incoming_skip_acts``.
    2. Scale it by ``self.gamma``.
    3. Pass it through ``self.skip_conv_{i+1}``.
    4. Add the result to the running sample.

    Args:
        self: The decoder module this function is bound to.
        sample: Input to ``self.conv_in``.
        latent_embeds: Optional conditioning. It is forwarded to the mid and up
            blocks, and to ``conv_norm_out`` only when it is not ``None``.

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample)
    upscale_dtype = next(self.up_blocks.parameters()).dtype
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # The up path runs in the opposite order from the stored activations.
        # Assumes incoming_skip_acts is ordered first-down-block-first; TODO confirm.
        reversed_skip_acts = self.incoming_skip_acts[::-1]
        # NOTE(review): zip() stops after four up blocks (one per skip conv);
        # any additional up block would be silently skipped on this path.
        for idx, (up_block, skip_conv) in enumerate(zip(self.up_blocks, skip_convs)):
            skip_in = skip_conv(reversed_skip_acts[idx] * self.gamma)
            # Out-of-place add. `.to()` may return the very tensor produced by
            # the previous block, and mutating it in place breaks autograd
            # whenever that tensor was saved for the backward pass.
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # Compare against None explicitly. Truth-testing a multi-element tensor
    # raises "Boolean value of Tensor ... is ambiguous".
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample
def downloading(url, outf, *, timeout=60, block_size=1024):
    """Download ``url`` to ``outf`` with a tqdm progress bar, unless ``outf`` exists.

    If ``outf`` is already present, nothing is fetched. A file at ``outf`` must
    therefore only ever be a *complete* download. To guarantee that:

    - The body is streamed into ``outf + ".part"``.
    - It is renamed onto ``outf`` only after the transfer finished and the byte
      count matches ``Content-Length`` (when the server sent that header).
    - On any failure, the partial file is removed and the exception propagates.

    Args:
        url: HTTP(S) URL to fetch.
        outf: Destination file path.
        timeout: Connect/read timeout in seconds, forwarded to ``requests.get``.
            Without one, a stalled server can block forever.
        block_size: Chunk size in bytes used while streaming the body.

    Raises:
        requests.HTTPError: if the server replies with a 4xx/5xx status.
        requests.RequestException: for connection errors and timeouts.
        OSError: if the received size differs from ``Content-Length``, or the
            file cannot be written.
    """
    if os.path.exists(outf):
        return
    print(f"Downloading checkpoint to {outf}")
    part_path = outf + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check, an error page (e.g. a 404 body) would be saved
            # as the checkpoint and then skipped on every later run.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                with open(part_path, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                received = progress_bar.n
        if total_size_in_bytes != 0 and received != total_size_in_bytes:
            raise OSError(
                f"Incomplete download of {url}: received {received} of "
                f"{total_size_in_bytes} bytes"
            )
        # Same directory => atomic rename on POSIX, so outf never holds a
        # partially written file.
        os.replace(part_path, outf)
    except BaseException:
        # Drop the partial file so the next call retries from scratch.
        if os.path.exists(part_path):
            os.remove(part_path)
        raise
    print(f"Downloaded successfully to {outf}")
def initialize_folder() -> None:
    """Create the ``<S2I home>/.s2i/weights`` directory tree if it is missing.

    The base directory comes from ``get_s2i_home()``. Calling this repeatedly is
    harmless: existing directories are left untouched.

    Raises:
        OSError: if the folder cannot be created (e.g. a regular file already
            occupies one of the path components, or permissions are missing).
    """
    weights_path = os.path.join(get_s2i_home(), ".s2i", "weights")
    # makedirs also creates the intermediate ".s2i" directory, and
    # exist_ok=True makes the call idempotent without separate exists() checks.
    os.makedirs(weights_path, exist_ok=True)
def get_s2i_home() -> str:
    """Return the base directory under which S2I stores its files.

    The ``S2I_HOME`` environment variable takes precedence. When it is unset or
    empty, the current user's home directory (``Path.home()``) is used.

    A leading ``~`` in ``S2I_HOME`` is expanded. Values coming from ``.env``
    files or quoted shell strings are not expanded by the shell.

    Returns:
        str: the home directory.
    """
    configured = os.getenv("S2I_HOME")
    if not configured:
        # Treat "" as unset. An empty base is not a usable directory, and
        # string-concatenated paths built from it point at the filesystem root.
        return str(Path.home())
    return os.path.expanduser(configured)
def download_models():
    """Fetch the sketch-to-image LoRA checkpoints into the S2I home directory.

    Each checkpoint is saved as ``<S2I home>/sketch2image_lora_<name>.pkl`` via
    ``downloading``, which skips files that already exist.

    Returns:
        dict: maps each model name (``'350k-adapter'``, ``'350k'``) to the local
        path of its checkpoint.
    """
    urls = {
        '350k-adapter': 'https://huggingface.co/myn0908/sk2ks/resolve/main/adapter_weights_large_sketch2image_lora.pkl?download=true',
        '350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true'
    }
    # Checkpoints go directly into the S2I home directory (S2I_HOME or the
    # user's home), not the current working directory.
    home = get_s2i_home()
    if home:
        # S2I_HOME may name a directory that does not exist yet; without this
        # the first download fails when opening its output file.
        os.makedirs(home, exist_ok=True)
    model_paths = {}
    for model_name, url in urls.items():
        outf = os.path.join(home, f"sketch2image_lora_{model_name}.pkl")
        downloading(url, outf)
        model_paths[model_name] = outf
    return model_paths
def get_model_path(model_name, model_paths):
    """Look up the local checkpoint path registered under ``model_name``.

    Args:
        model_name: Key to look up.
        model_paths: Mapping from model name to path (presumably the dict
            returned by ``download_models``; verify against callers).

    Returns:
        The mapped value, or the literal string ``"Model not found"`` when the
        key is absent. Callers must check for that sentinel themselves.
    """
    if model_name not in model_paths:
        return "Model not found"
    return model_paths[model_name]