File size: 3,602 Bytes
55a3c9a
 
 
3d31fa1
55a3c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cd0d204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55a3c9a
 
 
e4c85fa
5044478
55a3c9a
 
cd0d204
55a3c9a
 
cd0d204
55a3c9a
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import requests
from tqdm import tqdm
from pathlib import Path
from S2I.logger import logger

def sc_vae_encoder_fwd(self, sample):
    """Run an encoder forward pass while recording each down block's input.

    Written to be called with an encoder module as ``self`` (e.g. bound as its
    ``forward``). Before each down block runs, the activation it receives is
    appended to ``self.current_down_blocks``. The list is rebuilt on every
    call, so activations from an earlier pass are discarded.
    NOTE(review): presumably these are later copied into the decoder's
    ``incoming_skip_acts`` by calling code — confirm against the caller.

    Args:
        self: module exposing ``conv_in``, ``down_blocks``, ``mid_block``,
            ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        sample: the encoder input.

    Returns:
        The output of ``conv_out``.
    """
    hidden = self.conv_in(sample)

    self.current_down_blocks = []
    for block in self.down_blocks:
        # Record the activation *entering* this block, then advance.
        self.current_down_blocks.append(hidden)
        hidden = block(hidden)

    hidden = self.mid_block(hidden)
    hidden = self.conv_act(self.conv_norm_out(hidden))
    return self.conv_out(hidden)

def sc_vae_decoder_fwd(self, sample, latent_embeds=None):
    """Run a decoder forward pass, optionally adding encoder skip activations.

    Written to be called with a decoder module as ``self``. Unless
    ``self.ignore_skip`` is truthy, before the ``idx``-th up block the
    ``idx``-th entry of ``self.incoming_skip_acts`` *counted from the end* is
    scaled by ``self.gamma``. It is passed through ``skip_conv_{idx+1}`` and
    added to the running activation.

    Args:
        self: module exposing ``conv_in``, ``mid_block``, ``up_blocks``
            (with ``.parameters()``), ``conv_norm_out``, ``conv_act``,
            ``conv_out``, ``ignore_skip`` and, when skips are used,
            ``skip_conv_1`` .. ``skip_conv_4``, ``incoming_skip_acts`` and
            ``gamma``.
        sample: the latent to decode.
        latent_embeds: optional conditioning forwarded to the mid/up blocks
            and, when not ``None``, to ``conv_norm_out``.

    Returns:
        The output of ``conv_out``.
    """
    sample = self.conv_in(sample)
    # Cast to the dtype of the up blocks' parameters before running them.
    upscale_dtype = next(self.up_blocks.parameters()).dtype
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)

    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # The last recorded activation pairs with the first up block.
        reversed_skip_acts = self.incoming_skip_acts[::-1]
        for idx, (up_block, skip_conv) in enumerate(zip(self.up_blocks, skip_convs)):
            skip_in = skip_conv(reversed_skip_acts[idx] * self.gamma)
            # Out-of-place add: an in-place ``+=`` can corrupt tensors autograd
            # saved for backward, or mutate a tensor the caller still owns.
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)

    # Compare with None explicitly: truth-testing a multi-element tensor
    # raises, and a single zero-valued tensor would be treated as "absent".
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample

def downloading(url, outf, block_size=1024 * 1024, timeout=60):
    """Download ``url`` to ``outf`` unless ``outf`` already exists.

    The body is streamed into ``outf + ".part"`` and renamed to ``outf`` only
    after it has been fully received. A failed or interrupted download
    therefore never leaves a truncated file at ``outf``. Because an existing
    ``outf`` is skipped, such a file would otherwise never be fetched again.

    Args:
        url: HTTP(S) URL to fetch.
        outf: destination file path.
        block_size: bytes requested per streamed chunk.
        timeout: passed to ``requests.get`` (connect / between-bytes read
            timeout in seconds); ``None`` waits indefinitely.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        OSError: if the number of bytes received differs from a non-zero
            ``Content-Length``, or the file cannot be written.
    """
    if os.path.exists(outf):
        return

    logger.info(f"Downloading checkpoint to {outf}")
    tmp_path = outf + ".part"
    parent = os.path.dirname(outf)
    if parent:
        os.makedirs(parent, exist_ok=True)

    try:
        # ``with`` releases the streamed connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTML error page would be saved as the checkpoint.
            response.raise_for_status()
            total_size_in_bytes = int(response.headers.get('content-length', 0))
            written = 0
            with open(tmp_path, 'wb') as file, \
                    tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:
                for data in response.iter_content(block_size):
                    file.write(data)
                    written += len(data)
                    progress_bar.update(len(data))

        if total_size_in_bytes != 0 and written != total_size_in_bytes:
            raise OSError(
                f"Incomplete download of {url}: received {written} of "
                f"{total_size_in_bytes} bytes"
            )
        # Atomic on the same filesystem: ``outf`` appears only when complete.
        os.replace(tmp_path, outf)
    except BaseException:
        # Remove the partial file (also on Ctrl-C) so the next call retries.
        try:
            os.remove(tmp_path)
        except FileNotFoundError:
            pass
        raise

    logger.info(f"Downloaded successfully to {outf}")

def initialize_folder() -> None:
    """
    Create the ``<S2I home>/.s2i/weights`` folder for storing model weights.

    The S2I home directory comes from ``get_s2i_home()``. Missing parent
    directories (including ``.s2i``) are created as well. Directories that
    already exist are left untouched, so calling this repeatedly is harmless.

    NOTE(review): ``download_models()`` in this file saves checkpoints directly
    under the S2I home, not in this folder — confirm which location is intended.

    Raises:
        OSError: if the folder cannot be created.
    """
    weights_path = os.path.join(get_s2i_home(), ".s2i", "weights")
    logger.info(f"S2I weights folder: {weights_path}")
    # One makedirs call creates ``.s2i`` and ``weights``; exist_ok replaces the
    # separate os.path.exists checks (and their check-then-create race).
    os.makedirs(weights_path, exist_ok=True)

def get_s2i_home() -> str:
    """
    Get the home directory for storing model weights.

    Uses the ``S2I_HOME`` environment variable when it is set to a non-empty
    value, expanding a leading ``~``. Otherwise it uses the current user's home
    directory (``Path.home()``).

    Returns:
        str: the home directory.
    """
    # Treat an empty S2I_HOME as unset. Otherwise callers would build paths
    # such as "/.s2i" at the filesystem root, or paths relative to the CWD.
    home = os.getenv("S2I_HOME") or str(Path.home())
    return os.path.expanduser(home)

def download_models(urls=None):
    """
    Download the sketch-to-image LoRA checkpoints into the S2I home directory.

    Each checkpoint is saved as ``<S2I home>/sketch2image_lora_<name>.pkl``.
    ``downloading`` skips files that already exist.

    Args:
        urls: optional mapping of model name -> download URL. Defaults to the
            two checkpoints hosted in the ``myn0908/sk2ks`` Hugging Face repo.

    Returns:
        dict: model name -> local checkpoint path.
    """
    if urls is None:
        urls = {
            '350k-adapter': 'https://huggingface.co/myn0908/sk2ks/resolve/main/adapter_weights_large_sketch2image_lora.pkl?download=true',
            '350k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/sketch_to_image_mixed_weights_350k_lora.pkl?download=true',
        }
    # NOTE: these are pickle files; if they are loaded with pickle, loading can
    # execute arbitrary code. Only pass URLs from trusted sources.

    # Base directory: S2I_HOME if set, otherwise the user's home (not the CWD).
    home = get_s2i_home()
    model_paths = {}
    for model_name, url in urls.items():
        outf = os.path.join(home, f"sketch2image_lora_{model_name}.pkl")
        downloading(url, outf)
        model_paths[model_name] = outf

    return model_paths


def get_model_path(model_name, model_paths):
    """Look up the local checkpoint path registered for ``model_name``.

    Args:
        model_name: key to look up, e.g. one of the names returned by
            ``download_models()``.
        model_paths: mapping of model name -> checkpoint path.

    Returns:
        The mapped path. If ``model_name`` is absent, returns the literal
        string ``"Model not found"`` rather than raising. Callers must check
        for this sentinel before treating the result as a file path.
    """
    if model_name in model_paths:
        return model_paths[model_name]
    return "Model not found"