myn0908 commited on
Commit
cd0d204
1 Parent(s): 55a3c9a

update code

Browse files
Files changed (1) hide show
  1. S2I/modules/utils.py +27 -4
S2I/modules/utils.py CHANGED
@@ -55,6 +55,31 @@ def downloading(url, outf):
55
  print("ERROR, something went wrong")
56
  print(f"Downloaded successfully to {outf}")
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  def download_models():
60
  urls = {
@@ -62,12 +87,10 @@ def download_models():
62
  '100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
63
  }
64
  # Get the current working directory
65
- ckpt_folder = os.path.join(os.getcwd(), 'checkpoints')
66
- os.makedirs(ckpt_folder, exist_ok=True)
67
-
68
  model_paths = {}
69
  for model_name, url in urls.items():
70
- outf = os.path.join(ckpt_folder, f"sketch2image_lora_{model_name}.pkl")
71
  downloading(url, outf)
72
  model_paths[model_name] = outf
73
 
 
55
  print("ERROR, something went wrong")
56
  print(f"Downloaded successfully to {outf}")
57
 
58
+ def initialize_folder() -> None:
59
+ """
60
+ Initialize the folder for storing model weights.
61
+
62
+ Raises:
63
+ OSError: if the folder cannot be created.
64
+ """
65
+ home = get_s2i_home()
66
+ s2i_home_path = home + "/.s2i"
67
+ weights_path = s2i_home_path + "/weights"
68
+ print(weights_path)
69
+ if not os.path.exists(s2i_home_path):
70
+ os.makedirs(s2i_home_path, exist_ok=True)
71
+
72
+ if not os.path.exists(weights_path):
73
+ os.makedirs(weights_path, exist_ok=True)
74
+
75
+ def get_s2i_home() -> str:
76
+ """
77
+ Get the home directory for storing model weights
78
+
79
+ Returns:
80
+ str: the home directory.
81
+ """
82
+ return str(os.getenv("S2I_HOME", default=str(Path.home())))
83
 
84
  def download_models():
85
  urls = {
 
87
  '100k': 'https://huggingface.co/myn0908/sk2ks/resolve/main/model_16001.pkl?download=true',
88
  }
89
  # Get the current working directory
90
+ home = get_s2i_home()
 
 
91
  model_paths = {}
92
  for model_name, url in urls.items():
93
+ outf = os.path.join(home, f"sketch2image_lora_{model_name}.pkl")
94
  downloading(url, outf)
95
  model_paths[model_name] = outf
96