FrozenBurning commited on
Commit
8eda766
·
1 Parent(s): d099347

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -75
app.py CHANGED
@@ -11,81 +11,6 @@ os.system("git clone https://github.com/FrozenBurning/SceneDreamer.git")
11
  os.system("cp -r SceneDreamer/* ./")
12
  os.system("bash install.sh")
13
 
14
- pretrained_model = dict(file_url='https://drive.google.com/uc?id=1IFu1vNrgF1EaRqPizyEgN_5Vt7Fyg0Mj',
15
- alt_url='', file_size=330571863,
16
- file_path='./scenedreamer_released.pt',)
17
-
18
-
19
- def download_file(session, file_spec, use_alt_url=False, chunk_size=128, num_attempts=10):
20
- file_path = file_spec['file_path']
21
- if use_alt_url:
22
- file_url = file_spec['alt_url']
23
- else:
24
- file_url = file_spec['file_url']
25
-
26
- file_dir = os.path.dirname(file_path)
27
- tmp_path = file_path + '.tmp.' + uuid.uuid4().hex
28
- if file_dir:
29
- os.makedirs(file_dir, exist_ok=True)
30
-
31
- progress_bar = tqdm(total=file_spec['file_size'], unit='B', unit_scale=True)
32
- for attempts_left in reversed(range(num_attempts)):
33
- data_size = 0
34
- progress_bar.reset()
35
- try:
36
- # Download.
37
- data_md5 = hashlib.md5()
38
- with session.get(file_url, stream=True) as res:
39
- res.raise_for_status()
40
- with open(tmp_path, 'wb') as f:
41
- for chunk in res.iter_content(chunk_size=chunk_size<<10):
42
- progress_bar.update(len(chunk))
43
- f.write(chunk)
44
- data_size += len(chunk)
45
- data_md5.update(chunk)
46
-
47
- # Validate.
48
- if 'file_size' in file_spec and data_size != file_spec['file_size']:
49
- raise IOError('Incorrect file size', file_path)
50
- if 'file_md5' in file_spec and data_md5.hexdigest() != file_spec['file_md5']:
51
- raise IOError('Incorrect file MD5', file_path)
52
- break
53
-
54
- except Exception as e:
55
- # print(e)
56
- # Last attempt => raise error.
57
- if not attempts_left:
58
- raise
59
-
60
- # Handle Google Drive virus checker nag.
61
- if data_size > 0 and data_size < 8192:
62
- with open(tmp_path, 'rb') as f:
63
- data = f.read()
64
- links = [html.unescape(link) for link in data.decode('utf-8').split('"') if 'confirm=t' in link]
65
- if len(links) == 1:
66
- file_url = requests.compat.urljoin(file_url, links[0])
67
- continue
68
-
69
- progress_bar.close()
70
-
71
- # Rename temp file to the correct name.
72
- os.replace(tmp_path, file_path) # atomic
73
-
74
- # Attempt to clean up any leftover temps.
75
- for filename in glob.glob(file_path + '.tmp.*'):
76
- try:
77
- os.remove(filename)
78
- except:
79
- pass
80
-
81
- print('Downloading SceneDreamer pretrained model...')
82
- with requests.Session() as session:
83
- try:
84
- download_file(session, pretrained_model)
85
- except:
86
- print('Google Drive download failed.\n')
87
-
88
-
89
 
90
  import os
91
  import torch
 
11
  os.system("cp -r SceneDreamer/* ./")
12
  os.system("bash install.sh")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  import os
16
  import torch