abyildirim committed
Commit 9bfd721
1 Parent(s): 2d42726

Model files are moved to Hugging Face from Google Drive due to the download limit error

Files changed (1)
app.py +17 -6
app.py CHANGED
@@ -8,9 +8,20 @@ import utils
 from ldm.util import instantiate_from_config
 from omegaconf import OmegaConf
 from zipfile import ZipFile
-import gdown
 import os
+import requests
+import shutil
 
+def download_model(url):
+    os.makedirs("models", exist_ok=True)
+    local_filename = url.split('/')[-1]
+    with requests.get(url, stream=True) as r:
+        with open(os.path.join("models", local_filename), 'wb') as file:
+            shutil.copyfileobj(r.raw, file)
+    with ZipFile("models/gqa_inpaint.zip", 'r') as zObject:
+        zObject.extractall(path="models/")
+    os.remove("models/gqa_inpaint.zip")
+
 MODEL = None
 
 def inference(image: np.ndarray, instruction: str, center_crop: bool):
@@ -37,11 +48,11 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    gdown.download(id="1tp0aHAS-ccrIfNz7XrGTSdNIPNZjOVSp", output="models/")
-    with ZipFile("models/gqa_inpaint.zip", 'r') as zObject:
-        zObject.extractall(path="models/")
-    os.remove("models/gqa_inpaint.zip")
-
+    print("## Downloading the model file")
+    download_model("https://huggingface.co/abyildirim/inst-inpaint-models/resolve/main/gqa_inpaint.zip")
+    print("## Download is completed")
+
+    print("## Running the demo")
     parsed_config = OmegaConf.load(args.config)
     MODEL = instantiate_from_config(parsed_config["model"])
    model_state_dict = torch.load(args.checkpoint, map_location="cpu")["state_dict"]
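
For reference, the new download path can be exercised on its own. The sketch below mirrors the pattern this commit introduces in app.py: stream the gqa_inpaint.zip archive from the Hugging Face repository with requests, extract it into models/, and delete the zip. The raise_for_status() check, the type hints, and the __main__ block are additions for this sketch only and are not part of app.py.

import os
import shutil
from zipfile import ZipFile

import requests

MODEL_ZIP_URL = "https://huggingface.co/abyildirim/inst-inpaint-models/resolve/main/gqa_inpaint.zip"

def download_model(url: str, target_dir: str = "models") -> str:
    """Stream a zip archive to disk, extract it, and remove the archive."""
    os.makedirs(target_dir, exist_ok=True)
    local_path = os.path.join(target_dir, url.split("/")[-1])
    # Stream the response body straight to disk so the whole archive
    # is never held in memory.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()  # not in app.py; added here to fail loudly on HTTP errors
        with open(local_path, "wb") as file:
            shutil.copyfileobj(r.raw, file)
    # Unpack next to the archive, then delete the zip itself.
    with ZipFile(local_path, "r") as archive:
        archive.extractall(path=target_dir)
    os.remove(local_path)
    return target_dir

if __name__ == "__main__":
    print("## Downloading the model file")
    download_model(MODEL_ZIP_URL)
    print("## Download is completed")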