birdortyedi commited on
Commit
74de975
1 Parent(s): 2a92dc2

hf hub added

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -1,24 +1,20 @@
1
- import requests
2
- import os
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
  import torchvision.models as models
 
7
 
8
  from configs.default import get_cfg_defaults
9
  from modeling.build import build_model
10
  from utils.data_utils import linear_scaling
11
 
12
 
13
- url = "https://www.dropbox.com/s/uxvax5sjx5iysyl/cifr.pth?dl=0"
14
- r = requests.get(url, stream=True)
15
- if not os.path.exists("cifr.pth"):
16
- with open("cifr.pth", 'wb') as f:
17
- for data in r:
18
- f.write(data)
19
 
20
  cfg = get_cfg_defaults()
21
- cfg.MODEL.CKPT = "cifr.pth"
22
  net, _ = build_model(cfg)
23
  net = net.eval()
24
  vgg16 = models.vgg16(pretrained=True).features.eval()
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
  import torchvision.models as models
5
+ from huggingface_hub import hf_hub_url, cached_download
6
 
7
  from configs.default import get_cfg_defaults
8
  from modeling.build import build_model
9
  from utils.data_utils import linear_scaling
10
 
11
 
12
+ url = hf_hub_url(repo_id="birdortyedi/cifr", filename="cifr.pth")
13
+ model_path = cached_download(url)
14
+
 
 
 
15
 
16
  cfg = get_cfg_defaults()
17
+ cfg.MODEL.CKPT = model_path
18
  net, _ = build_model(cfg)
19
  net = net.eval()
20
  vgg16 = models.vgg16(pretrained=True).features.eval()