Spaces:
Build error
Build error
# Copyright 2020 Erik Härkönen. All rights reserved. | |
# This file is licensed to you under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. You may obtain a copy | |
# of the License at http://www.apache.org/licenses/LICENSE-2.0 | |
# Unless required by applicable law or agreed to in writing, software distributed under | |
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS | |
# OF ANY KIND, either express or implied. See the License for the specific language | |
# governing permissions and limitations under the License. | |
import string | |
import numpy as np | |
from pathlib import Path | |
import requests | |
import pickle | |
import sys | |
import re | |
import gdown | |
def prettify_name(name): | |
valid = "-_%s%s" % (string.ascii_letters, string.digits) | |
return ''.join(map(lambda c : c if c in valid else '_', name)) | |
# Add padding to sequence of images | |
# Used in conjunction with np.hstack/np.vstack | |
# By default: adds one 64th of the width of horizontal padding | |
def pad_frames(strip, pad_fract_horiz=64, pad_fract_vert=0, pad_value=None): | |
dtype = strip[0].dtype | |
if pad_value is None: | |
if dtype in [np.float32, np.float64]: | |
pad_value = 1.0 | |
else: | |
pad_value = np.iinfo(dtype).max | |
frames = [strip[0]] | |
for frame in strip[1:]: | |
if pad_fract_horiz > 0: | |
frames.append(pad_value*np.ones((frame.shape[0], frame.shape[1]//pad_fract_horiz, 3), dtype=dtype)) | |
elif pad_fract_vert > 0: | |
frames.append(pad_value*np.ones((frame.shape[0]//pad_fract_vert, frame.shape[1], 3), dtype=dtype)) | |
frames.append(frame) | |
return frames | |
def download_google_drive(url, output_name): | |
print('Downloading', url) | |
gdown.download(url, str(output_name)) | |
# session = requests.Session() | |
# r = session.get(url, allow_redirects=True) | |
# r.raise_for_status() | |
# # Google Drive virus check message | |
# if r.encoding is not None: | |
# tokens = re.search('(confirm=.+)&id', str(r.content)) | |
# assert tokens is not None, 'Could not extract token from response' | |
# url = url.replace('id=', f'{tokens[1]}&id=') | |
# r = session.get(url, allow_redirects=True) | |
# r.raise_for_status() | |
# assert r.encoding is None, f'Failed to download weight file from {url}' | |
# with open(output_name, 'wb') as f: | |
# f.write(r.content) | |
def download_generic(url, output_name): | |
print('Downloading', url) | |
session = requests.Session() | |
r = session.get(url, allow_redirects=True) | |
r.raise_for_status() | |
# No encoding means raw data | |
if r.encoding is None: | |
with open(output_name, 'wb') as f: | |
f.write(r.content) | |
else: | |
download_manual(url, output_name) | |
def download_manual(url, output_name): | |
outpath = Path(output_name).resolve() | |
while not outpath.is_file(): | |
print('Could not find checkpoint') | |
print(f'Please download the checkpoint from\n{url}\nand save it as\n{outpath}') | |
input('Press any key to continue...') | |
def download_ckpt(url, output_name): | |
if 'drive.google' in url: | |
download_google_drive(url, output_name) | |
elif 'mega.nz' in url: | |
download_manual(url, output_name) | |
else: | |
download_generic(url, output_name) |