adamelliotfields commited on
Commit
507ffa3
1 Parent(s): 767128b

Ensure files are downloaded

Browse files
Files changed (4) hide show
  1. README.md +5 -1
  2. app.py +4 -1
  3. lib/__init__.py +2 -1
  4. lib/download.py +38 -0
README.md CHANGED
@@ -21,6 +21,7 @@ models:
21
  - h94/IP-Adapter
22
  - Linaqruf/anything-v3-1
23
  - Lykon/dreamshaper-8
 
24
  - prompthero/openjourney-v4
25
  - SG161222/Realistic_Vision_V5.1_noVAE
26
  - XpucT/Deliberate
@@ -45,7 +46,10 @@ preload_from_hub:
45
  anything-v3-2.safetensors
46
  - >-
47
  Lykon/dreamshaper-8
48
- text_encoder/model.fp16.safetensors,unet/diffusion_pytorch_model.fp16.safetensors,vae/diffusion_pytorch_model.fp16.safetensors,model_index.json
 
 
 
49
  - >-
50
  prompthero/openjourney-v4
51
  openjourney-v4.ckpt
 
21
  - h94/IP-Adapter
22
  - Linaqruf/anything-v3-1
23
  - Lykon/dreamshaper-8
24
+ - madebyollin/taesd
25
  - prompthero/openjourney-v4
26
  - SG161222/Realistic_Vision_V5.1_noVAE
27
  - XpucT/Deliberate
 
46
  anything-v3-2.safetensors
47
  - >-
48
  Lykon/dreamshaper-8
49
+ feature_extractor/preprocessor_config.json,safety_checker/config.json,scheduler/scheduler_config.json,text_encoder/config.json,text_encoder/model.fp16.safetensors,tokenizer/merges.txt,tokenizer/special_tokens_map.json,tokenizer/tokenizer_config.json,tokenizer/vocab.json,unet/config.json,unet/diffusion_pytorch_model.fp16.safetensors,vae/config.json,vae/diffusion_pytorch_model.fp16.safetensors,model_index.json
50
+ - >-
51
+ madebyollin/taesd
52
+ config.json,diffusion_pytorch_model.safetensors
53
  - >-
54
  prompthero/openjourney-v4
55
  openjourney-v4.ckpt
app.py CHANGED
@@ -4,7 +4,7 @@ import random
4
 
5
  import gradio as gr
6
 
7
- from lib import Config, async_call, generate
8
 
9
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
10
  refresh_seed_js = """
@@ -476,6 +476,9 @@ if __name__ == "__main__":
476
  parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
477
  args = parser.parse_args()
478
 
 
 
 
479
  # https://www.gradio.app/docs/gradio/interface#interface-queue
480
  demo.queue().launch(
481
  server_name=args.server,
 
4
 
5
  import gradio as gr
6
 
7
+ from lib import Config, async_call, download_repo_files, generate
8
 
9
  # the CSS `content` attribute expects a string so we need to wrap the number in quotes
10
  refresh_seed_js = """
 
476
  parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
477
  args = parser.parse_args()
478
 
479
+ # download to hub cache
480
+ download_repo_files()
481
+
482
  # https://www.gradio.app/docs/gradio/interface#interface-queue
483
  demo.queue().launch(
484
  server_name=args.server,
lib/__init__.py CHANGED
@@ -1,6 +1,7 @@
1
  from .config import Config
 
2
  from .inference import async_call, generate
3
  from .loader import Loader
4
  from .upscaler import RealESRGAN
5
 
6
- __all__ = ["Config", "Loader", "RealESRGAN", "async_call", "generate"]
 
1
  from .config import Config
2
+ from .download import download_repo_files
3
  from .inference import async_call, generate
4
  from .loader import Loader
5
  from .upscaler import RealESRGAN
6
 
7
+ __all__ = ["Config", "Loader", "RealESRGAN", "async_call", "download_repo_files", "generate"]
lib/download.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from huggingface_hub._snapshot_download import snapshot_download
4
+
5
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
6
+
7
+ SPACES_ZERO_GPU = os.environ.get("SPACES_ZERO_GPU", "").lower() == "true"
8
+
9
+ REPO = "Lykon/dreamshaper-8"
10
+
11
+ FILES = [
12
+ "feature_extractor/preprocessor_config.json",
13
+ "safety_checker/config.json",
14
+ "scheduler/scheduler_config.json",
15
+ "text_encoder/config.json",
16
+ "text_encoder/model.fp16.safetensors",
17
+ "tokenizer/merges.txt",
18
+ "tokenizer/special_tokens_map.json",
19
+ "tokenizer/tokenizer_config.json",
20
+ "tokenizer/vocab.json",
21
+ "unet/config.json",
22
+ "unet/diffusion_pytorch_model.fp16.safetensors",
23
+ "vae/config.json",
24
+ "vae/diffusion_pytorch_model.fp16.safetensors",
25
+ "model_index.json",
26
+ ]
27
+
28
+
29
+ def download_repo_files():
30
+ global REPO, FILES
31
+ return snapshot_download(
32
+ repo_id=REPO,
33
+ repo_type="model",
34
+ revision="main",
35
+ token=HF_TOKEN,
36
+ allow_patterns=FILES,
37
+ ignore_patterns=None,
38
+ )