update code
Browse files- S2I/modules/utils.py +27 -4
S2I/modules/utils.py
CHANGED
@@ -55,6 +55,31 @@ def downloading(url, outf):
|
|
55 |
print("ERROR, something went wrong")
|
56 |
print(f"Downloaded successfully to {outf}")
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
def download_models():
|
60 |
urls = {
|
@@ -62,12 +87,10 @@ def download_models():
|
|
62 |
'100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
|
63 |
}
|
64 |
# Get the current working directory
|
65 |
-
|
66 |
-
os.makedirs(ckpt_folder, exist_ok=True)
|
67 |
-
|
68 |
model_paths = {}
|
69 |
for model_name, url in urls.items():
|
70 |
-
outf = os.path.join(
|
71 |
downloading(url, outf)
|
72 |
model_paths[model_name] = outf
|
73 |
|
|
|
55 |
print("ERROR, something went wrong")
|
56 |
print(f"Downloaded successfully to {outf}")
|
57 |
|
58 |
+
def initialize_folder() -> None:
|
59 |
+
"""
|
60 |
+
Initialize the folder for storing model weights.
|
61 |
+
|
62 |
+
Raises:
|
63 |
+
OSError: if the folder cannot be created.
|
64 |
+
"""
|
65 |
+
home = get_s2i_home()
|
66 |
+
s2i_home_path = home + "/.s2i"
|
67 |
+
weights_path = s2i_home_path + "/weights"
|
68 |
+
print(weights_path)
|
69 |
+
if not os.path.exists(s2i_home_path):
|
70 |
+
os.makedirs(s2i_home_path, exist_ok=True)
|
71 |
+
|
72 |
+
if not os.path.exists(weights_path):
|
73 |
+
os.makedirs(weights_path, exist_ok=True)
|
74 |
+
|
75 |
+
def get_s2i_home() -> str:
|
76 |
+
"""
|
77 |
+
Get the home directory for storing model weights
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
str: the home directory.
|
81 |
+
"""
|
82 |
+
return str(os.getenv("S2I_HOME", default=str(Path.home())))
|
83 |
|
84 |
def download_models():
|
85 |
urls = {
|
|
|
87 |
'100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
|
88 |
}
|
89 |
# Get the current working directory
|
90 |
+
home = get_s2i_home()
|
|
|
|
|
91 |
model_paths = {}
|
92 |
for model_name, url in urls.items():
|
93 |
+
outf = os.path.join(home, f"sketch2image_lora_{model_name}.pkl")
|
94 |
downloading(url, outf)
|
95 |
model_paths[model_name] = outf
|
96 |
|