myn0908's picture
replace upload image by url input
5044478
import os
import requests
from tqdm import tqdm
from pathlib import Path
from S2I.logger import logger
def sc_vae_encoder_fwd(self, sample):
    """Run an encoder forward pass while recording each down block's input.

    Intended as a replacement ``forward`` for an encoder-like module. ``self``
    must provide ``conv_in``, ``down_blocks`` (iterable of callables),
    ``mid_block``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.

    Side effect: ``self.current_down_blocks`` is reset to a fresh list on every
    call and receives, in order, the tensor fed into each down block.
    NOTE(review): presumably these are consumed as skip activations by a
    decoder; confirm against the caller.

    Args:
        sample: Input passed to ``conv_in``.

    Returns:
        The output of ``conv_out``.
    """
    hidden = self.conv_in(sample)
    # Same list object is stored on the module and appended to locally.
    captured = self.current_down_blocks = []
    for block in self.down_blocks:
        captured.append(hidden)
        hidden = block(hidden)
    hidden = self.mid_block(hidden)
    for layer in (self.conv_norm_out, self.conv_act, self.conv_out):
        hidden = layer(hidden)
    return hidden
def sc_vae_decoder_fwd(self, sample, latent_embeds=None):
    """Run a decoder forward pass, optionally adding encoder skip activations.

    Intended as a replacement ``forward`` for a decoder-like module. ``self``
    must provide ``conv_in``, ``mid_block``, ``up_blocks`` (iterable with a
    ``parameters()`` method), ``ignore_skip``, ``conv_norm_out``, ``conv_act``
    and ``conv_out``. When ``ignore_skip`` is false it must also provide
    ``skip_conv_1`` .. ``skip_conv_4``, ``gamma`` and ``incoming_skip_acts``.

    In the skip path, before each of the first four up blocks the
    ``idx``-th entry of the *reversed* ``incoming_skip_acts`` is scaled by
    ``gamma``, passed through the matching ``skip_conv_*`` and added to the
    running sample.
    NOTE(review): ``incoming_skip_acts`` is assigned outside this function,
    presumably from the encoder's ``current_down_blocks``; confirm.

    Args:
        sample: Latent input passed to ``conv_in``.
        latent_embeds: Optional conditioning forwarded to the mid/up blocks
            and, when not ``None``, to ``conv_norm_out``.

    Returns:
        The output of ``conv_out``.
    """
    sample = self.conv_in(sample)
    # The up path runs in the dtype of the up-block weights, which may differ
    # from the mid-block output dtype.
    upscale_dtype = next(self.up_blocks.parameters()).dtype
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # Consumed last-recorded-first, i.e. in reverse of the stored order.
        reversed_skip_acts = self.incoming_skip_acts[::-1]
        for idx, (up_block, skip_conv) in enumerate(zip(self.up_blocks, skip_convs)):
            skip_in = skip_conv(reversed_skip_acts[idx] * self.gamma)
            # Out-of-place add: never mutates a tensor autograd may still need.
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # Compare against None explicitly: truthiness of a multi-element tensor is
    # an error, and a falsy-but-present value must still be forwarded.
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample
def downloading(url, outf, timeout=60, block_size=1024):
    """Download ``url`` to ``outf`` with a progress bar unless ``outf`` exists.

    The body is streamed into ``outf + ".part"`` and moved onto ``outf`` only
    after the transfer completes. An interrupted, failed or truncated download
    therefore never leaves a partial file at ``outf``. Such a file would
    otherwise be treated as a valid cached checkpoint by the existence check
    on the next run.

    Args:
        url: Address to fetch.
        outf: Destination file path; its parent directory must already exist.
        timeout: Seconds passed to ``requests.get`` as connect/read timeout.
        block_size: Chunk size in bytes used while streaming the body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        OSError: if the received byte count differs from Content-Length.
    """
    if os.path.exists(outf):
        return
    print(f"Downloading checkpoint to {outf}")
    tmp_path = f"{outf}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Otherwise an HTML error page (404/5xx) would be saved as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            # Content-Length counts encoded bytes while iter_content yields decoded
            # bytes, so only compare sizes when no Content-Encoding was applied.
            check_size = total_size_in_bytes != 0 and 'content-encoding' not in response.headers
            received = 0
            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar, \
                    open(tmp_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
                    received += len(data)
        if check_size and received != total_size_in_bytes:
            raise OSError(
                f"ERROR, something went wrong: expected {total_size_in_bytes} bytes "
                f"from {url}, received {received}"
            )
        # Atomic on the same filesystem: outf is either absent or complete.
        os.replace(tmp_path, outf)
    except BaseException:
        # Drop the partial file so the next call retries from scratch.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Downloaded successfully to {outf}")
def initialize_folder() -> None:
    """
    Initialize the folder for storing model weights.

    Creates ``<S2I home>/.s2i/weights`` (including ``<S2I home>/.s2i``) if it
    does not exist yet and prints the weights path.
    Raises:
        OSError: if the folder cannot be created (including when a path
            component exists as a regular file).
    """
    home = get_s2i_home()
    # os.path.join instead of "+ '/'": an empty or trailing-slash home no
    # longer yields "/.s2i" at the filesystem root or doubled separators.
    s2i_home_path = os.path.join(home, ".s2i")
    weights_path = os.path.join(s2i_home_path, "weights")
    print(weights_path)
    # makedirs creates missing parents (so .s2i as well), and exist_ok turns an
    # existing directory into a no-op; no separate exists() checks are needed.
    os.makedirs(weights_path, exist_ok=True)
def get_s2i_home() -> str:
    """
    Get the home directory for storing model weights.

    Uses the ``S2I_HOME`` environment variable when it is set to a non-empty
    value (a leading ``~`` is expanded). Otherwise it falls back to the current
    user's home directory.
    Returns:
        str: the home directory.
    """
    configured = os.getenv("S2I_HOME")
    # Treat an empty S2I_HOME like an unset one; returning "" would make callers
    # build paths relative to the working directory (or the filesystem root).
    if configured:
        return os.path.expanduser(configured)
    return str(Path.home())
def download_models():
    """Download the sketch-to-image LoRA checkpoints into the S2I home directory.

    Each checkpoint is stored as ``<S2I home>/sketch2image_lora_<name>.pkl``.
    ``downloading`` skips files that already exist at that path.
    NOTE(review): ``initialize_folder`` prepares ``<S2I home>/.s2i/weights``,
    but checkpoints are written directly into ``<S2I home>``; confirm which
    location is intended.

    Returns:
        dict: model name ('350k-adapter', '350k') -> local checkpoint path.
    """
    urls = {
        '350k-adapter': 'https://huggingface.co/myn0908/sk2ks/resolve/main/adapter_weights_large_sketch2image_lora.pkl?download=true',
        '350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true'
    }
    # Resolve the S2I home directory (S2I_HOME or the user's home directory).
    home = get_s2i_home()
    # S2I_HOME may name a directory that does not exist yet; create it so that
    # opening the checkpoint file for writing does not fail.
    os.makedirs(home, exist_ok=True)
    model_paths = {}
    for model_name, url in urls.items():
        outf = os.path.join(home, f"sketch2image_lora_{model_name}.pkl")
        downloading(url, outf)
        model_paths[model_name] = outf
    return model_paths
def get_model_path(model_name, model_paths):
    """Look up the local path registered for ``model_name``.

    Args:
        model_name: Key to look up, e.g. a name returned by ``download_models``.
        model_paths: Mapping from model name to checkpoint path.

    Returns:
        The stored path, or the string ``"Model not found"`` when the name
        is not a key of ``model_paths``.
    """
    if model_name not in model_paths:
        return "Model not found"
    return model_paths[model_name]