Spaces:
Running
Running
| import sys | |
| import os | |
| import math | |
| from math import cos, sin | |
| from pathlib import Path | |
| import subprocess | |
| import re | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import scipy.io as sio | |
| import cv2 | |
| import torchvision | |
| from torchvision import transforms | |
| from .model import L2CS | |
# Per-channel statistics handed to transforms.Normalize below. These are the
# values commonly used with ImageNet-pretrained backbones.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to each image before it reaches the network:
# array -> PIL image -> resize (448) -> tensor -> per-channel normalisation.
transformations = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])
def atoi(text):
    """Return ``int(text)`` when *text* is a run of decimal digits, else *text*.

    Args:
        text: The string to convert.

    Returns:
        An ``int`` if every character of *text* is a decimal digit (and the
        string is non-empty); otherwise *text* itself, unchanged.
    """
    # isdecimal() rather than isdigit(): isdigit() is also true for characters
    # such as superscripts ("²") that int() cannot parse, which used to raise
    # ValueError here.
    return int(text) if text.isdecimal() else text
def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)

    Splits *text* into alternating non-digit / digit pieces and returns them
    as a list in which every digit run has been converted to ``int``, e.g.
    ``"img12.png" -> ["img", 12, ".png"]``.
    '''
    # Splitting on a capturing group always yields text pieces at even
    # indices and the captured digit runs at odd indices, so conversion is
    # decided by position. This keeps digit-like characters that are not
    # matched by \d (e.g. "²") as plain text instead of passing them to int().
    pieces = re.split(r'(\d+)', text)
    return [int(piece) if idx % 2 else piece for idx, piece in enumerate(pieces)]
def prep_input_numpy(img: np.ndarray, device: str):
    """Turn an image array (or a stack of them) into an L2CS-Net input tensor.

    Args:
        img: Either one image or, when the array is 4-dimensional, a batch of
            images indexed along the first axis. Each image is passed through
            the module-level ``transformations`` pipeline.
        device: Target device for the resulting tensor (passed to ``.to``).

    Returns:
        The transformed tensor on *device*. A 3-dimensional result gets a
        leading batch axis added, so a single image comes back as a batch
        of one.
    """
    if img.ndim == 4:
        # Batch input: transform every image on its own, then re-stack.
        tensor = torch.stack([transformations(frame) for frame in img])
    else:
        tensor = transformations(img)

    tensor = tensor.to(device)

    # Single image: add the missing batch axis.
    return tensor.unsqueeze(0) if tensor.dim() == 3 else tensor
def gazeto3d(gaze):
    """Convert a two-angle gaze direction into a 3-D unit vector.

    Args:
        gaze: Indexable with two angles in radians. ``gaze[0]`` rotates the
            direction within the x-z plane and ``gaze[1]`` lifts it out of
            that plane (presumably yaw and pitch -- confirm against callers).

    Returns:
        A NumPy array of shape ``(3,)``:
        ``[-cos(g1)*sin(g0), -sin(g1), -cos(g1)*cos(g0)]``.
    """
    azimuth, elevation = gaze[0], gaze[1]
    cos_elevation = np.cos(elevation)

    direction = np.zeros(3)
    direction[0] = -cos_elevation * np.sin(azimuth)
    direction[1] = -np.sin(elevation)
    direction[2] = -cos_elevation * np.cos(azimuth)
    return direction
def angular(gaze, label):
    """Return the angle, in degrees, between two vectors.

    Args:
        gaze: NumPy array holding the first vector.
        label: NumPy array holding the second vector (same shape as *gaze*).

    Returns:
        The angle between the two vectors in degrees. Because the cosine is
        capped at 0.9999999, identical directions give a small positive angle
        (about 0.026 degrees) rather than exactly 0.
    """
    cosine = np.sum(gaze * label) / (np.linalg.norm(gaze) * np.linalg.norm(label))
    # Floating-point rounding can push the cosine marginally outside [-1, 1],
    # where arccos returns NaN. The upper cap is the long-standing 0.9999999;
    # the lower bound guards the anti-parallel case the same way.
    cosine = np.clip(cosine, -1.0, 0.9999999)
    return np.arccos(cosine) * 180 / np.pi
def select_device(device='', batch_size=None):
    """Choose the torch device to run on and set CUDA visibility accordingly.

    Args:
        device: ``'cpu'`` (case-insensitive), a CUDA index such as ``'0'``, a
            comma-separated list such as ``'0,1,2,3'``, or ``''`` to use
            whatever CUDA devices are available.
        batch_size: Optional batch size. When more than one GPU is selected
            it must be a multiple of the GPU count.

    Returns:
        ``torch.device('cuda:0')`` if CUDA was not disabled and is available,
        otherwise ``torch.device('cpu')``.

    Raises:
        AssertionError: If several GPUs are selected and *batch_size* is not
            divisible by their number.

    Side effects:
        Sets the ``CUDA_VISIBLE_DEVICES`` environment variable when *device*
        is non-empty (``'-1'`` for CPU, otherwise *device* verbatim).
    """
    # No status banner is assembled here: nothing ever logged or returned it,
    # and building it cost a git subprocess plus per-GPU property queries.
    cpu = device.lower() == 'cpu'
    if cpu:
        # Hide every GPU so torch.cuda.is_available() reports False.
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    elif device:
        # Explicit non-CPU request: expose exactly the listed devices.
        os.environ['CUDA_VISIBLE_DEVICES'] = device

    cuda = not cpu and torch.cuda.is_available()
    if cuda:
        devices = device.split(',') if device else range(torch.cuda.device_count())  # i.e. 0,1,6,7
        n = len(devices)  # device count
        if n > 1 and batch_size:
            # Kept as an assert so callers catching AssertionError still work.
            assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'

    return torch.device('cuda:0' if cuda else 'cpu')
def spherical2cartesial(x):
    """Convert a batch of two-angle directions into 3-D unit vectors.

    Args:
        x: Tensor indexed as ``x[:, 0]`` and ``x[:, 1]``, holding two angles
            in radians per row.

    Returns:
        Tensor of shape ``(x.size(0), 3)`` whose rows are
        ``[cos(a1)*sin(a0), sin(a1), -cos(a1)*cos(a0)]``.
    """
    azimuth = x[:, 0]
    elevation = x[:, 1]
    cos_elevation = torch.cos(elevation)

    # NOTE(review): allocated with torch's default dtype and device, not
    # those of ``x`` -- confirm that this is what callers expect.
    output = torch.zeros(x.size(0), 3)
    output[:, 0] = cos_elevation * torch.sin(azimuth)
    output[:, 1] = torch.sin(elevation)
    output[:, 2] = -cos_elevation * torch.cos(azimuth)
    return output
def compute_angular_error(input, target):
    """Return the mean angular error, in degrees, between two angle batches.

    Both arguments are converted to 3-D direction vectors with
    ``spherical2cartesial``; the angle between each corresponding pair is
    taken and the angles are averaged.

    Args:
        input: Tensor of predicted angles, one row per sample.
        target: Tensor of reference angles with the same layout as *input*.

    Returns:
        A 0-dimensional tensor holding the mean angle in degrees. It is
        detached from the autograd graph.
    """
    predicted = spherical2cartesial(input).view(-1, 3, 1)
    reference = spherical2cartesial(target).view(-1, 1, 3)

    # (N, 1, 3) @ (N, 3, 1) -> one dot product per sample.
    cosine = torch.bmm(reference, predicted).view(-1)

    # Rounding can push the dot product of two unit vectors marginally
    # outside [-1, 1], where acos yields NaN and poisons the mean.
    cosine = torch.clamp(cosine, -1.0, 1.0)

    angles = torch.acos(cosine).detach()
    return 180 * torch.mean(angles) / math.pi
def softmax_temperature(tensor, temperature):
    """Apply a temperature-scaled softmax along dimension 1.

    Args:
        tensor: Tensor of scores with at least two dimensions; normalisation
            runs over dimension 1.
        temperature: Divisor applied to the scores before the softmax.
            Values above 1 flatten the distribution, values below 1 sharpen
            it.

    Returns:
        Tensor of the same shape as *tensor* whose entries along dimension 1
        are non-negative and sum to 1.
    """
    # torch.softmax is numerically stable; a hand-written exp()/sum()
    # overflows to inf/inf = NaN for large scores.
    return torch.softmax(tensor / temperature, dim=1)
def git_describe(path=Path(__file__).parent):  # path must be a directory
    """Return a human-readable git description of the repository at *path*.

    Runs ``git describe --tags --long --always`` in *path*, producing e.g.
    ``v5.0-5-g3e25f1e`` (https://git-scm.com/docs/git-describe).

    Args:
        path: Directory inside the repository to describe. Defaults to the
            directory containing this file.

    Returns:
        The description with surrounding whitespace removed, or ``''`` when
        the command fails (for example *path* is not in a git repository) or
        git cannot be executed at all.
    """
    # Argument list, no shell: a path containing spaces or shell
    # metacharacters is passed to git verbatim instead of being interpreted.
    cmd = ['git', '-C', str(path), 'describe', '--tags', '--long', '--always']
    try:
        output = subprocess.check_output(cmd, stderr=subprocess.STDOUT)
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran and failed; OSError: git is not runnable.
        return ''
    return output.decode().strip()
def getArch(arch, bins):
    """Build an L2CS model on the requested ResNet backbone layout.

    Args:
        arch: One of ``'ResNet18'``, ``'ResNet34'``, ``'ResNet50'``,
            ``'ResNet101'`` or ``'ResNet152'``. Any other value prints a
            warning and falls back to ResNet50.
        bins: Passed straight through as the third argument of ``L2CS``.

    Returns:
        The constructed ``L2CS`` instance.
    """
    resnet = torchvision.models.resnet

    # (name, residual block class, number of blocks per stage)
    layouts = (
        ('ResNet18', resnet.BasicBlock, [2, 2, 2, 2]),
        ('ResNet34', resnet.BasicBlock, [3, 4, 6, 3]),
        ('ResNet50', resnet.Bottleneck, [3, 4, 6, 3]),
        ('ResNet101', resnet.Bottleneck, [3, 4, 23, 3]),
        ('ResNet152', resnet.Bottleneck, [3, 8, 36, 3]),
    )
    for name, block, layers in layouts:
        if arch == name:
            return L2CS(block, layers, bins)

    # Unknown architecture: warn and use the ResNet50 layout.
    print('Invalid value for architecture is passed! '
          'The default value of ResNet50 will be used instead!')
    return L2CS(resnet.Bottleneck, [3, 4, 6, 3], bins)