import os
import numpy as np
import cv2
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image, ImageFilter
class Gaze360(Dataset):
    """Gaze360 dataset. Each label-file line stores the face-crop path (field 0)
    and a comma-separated 2D gaze (pitch, yaw) in radians (field 5); samples
    whose gaze exceeds the requested angle range are discarded."""

    def __init__(self, path, root, transform, angle, binwidth, train=True):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.angle = angle
        if not train:
            # Keep the full +/-90 degree range when evaluating.
            angle = 90
        self.binwidth = binwidth
        self.lines = []
        if isinstance(path, list):
            for i in path:
                with open(i) as f:
                    line = f.readlines()
                    line.pop(0)  # drop header line
                    self.lines.extend(line)
        else:
            with open(path) as f:
                lines = f.readlines()
                lines.pop(0)  # drop header line
                self.orig_list_len = len(lines)
                for line in lines:
                    gaze2d = line.strip().split(" ")[5]
                    label = np.array(gaze2d.split(",")).astype("float")
                    if abs(label[0] * 180 / np.pi) <= angle and abs(label[1] * 180 / np.pi) <= angle:
                        self.lines.append(line)

        print("{} items removed from dataset that have an angle > {}".format(
            self.orig_list_len - len(self.lines), angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        line = line.strip().split(" ")

        face = line[0]
        lefteye = line[1]   # eye-crop paths are parsed but not used here
        righteye = line[2]
        name = line[3]
        gaze2d = line[5]

        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)

        # Convert gaze from radians to degrees.
        pitch = label[0] * 180 / np.pi
        yaw = label[1] * 180 / np.pi

        img = Image.open(os.path.join(self.root, face))

        # fimg = cv2.imread(os.path.join(self.root, face))
        # fimg = cv2.resize(fimg, (448, 448)) / 255.0
        # fimg = fimg.transpose(2, 0, 1)
        # img = torch.from_numpy(fimg).type(torch.FloatTensor)

        if self.transform:
            img = self.transform(img)

        # Bin the continuous angles into classification targets.
        bins = np.array(range(-1 * self.angle, self.angle, self.binwidth))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])

        return img, labels, cont_labels, name
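
# --- Usage sketch (illustrative, not from the original repository) ---
# A minimal example of how Gaze360 might be wired into a DataLoader; the
# label-file path, image root, normalization statistics, angle range, and
# bin width below are assumptions for illustration only.
#
# from torch.utils.data import DataLoader
#
# gaze360_transform = transforms.Compose([
#     transforms.Resize(448),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                          std=[0.229, 0.224, 0.225]),
# ])
# gaze360_train = Gaze360(path="datasets/Gaze360/Label/train.label",  # hypothetical path
#                         root="datasets/Gaze360/Image",              # hypothetical path
#                         transform=gaze360_transform,
#                         angle=180, binwidth=4, train=True)
# gaze360_loader = DataLoader(gaze360_train, batch_size=16, shuffle=True)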
class Mpiigaze(Dataset):
    """MPIIGaze dataset with leave-one-subject-out folds. `pathorg` is a list
    of per-subject label files; for training the file at index `fold` is
    removed, for evaluation only that file is used. Field 7 of each line
    holds the comma-separated 2D gaze (pitch, yaw) in radians."""

    def __init__(self, pathorg, root, transform, train, angle, fold=0):
        self.transform = transform
        self.root = root
        self.orig_list_len = 0
        self.lines = []

        path = pathorg.copy()
        if train:
            path.pop(fold)     # train on every subject except the held-out fold
        else:
            path = path[fold]  # evaluate on the held-out fold only

        if isinstance(path, list):
            for i in path:
                with open(i) as f:
                    lines = f.readlines()
                    lines.pop(0)  # drop header line
                    self.orig_list_len += len(lines)
                    for line in lines:
                        gaze2d = line.strip().split(" ")[7]
                        label = np.array(gaze2d.split(",")).astype("float")
                        if abs(label[0] * 180 / np.pi) <= angle and abs(label[1] * 180 / np.pi) <= angle:
                            self.lines.append(line)
        else:
            with open(path) as f:
                lines = f.readlines()
                lines.pop(0)  # drop header line
                self.orig_list_len += len(lines)
                for line in lines:
                    gaze2d = line.strip().split(" ")[7]
                    label = np.array(gaze2d.split(",")).astype("float")
                    if abs(label[0] * 180 / np.pi) <= 42 and abs(label[1] * 180 / np.pi) <= 42:
                        self.lines.append(line)

        print("{} items removed from dataset that have an angle > {}".format(
            self.orig_list_len - len(self.lines), angle))

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        line = self.lines[idx]
        line = line.strip().split(" ")

        name = line[3]
        gaze2d = line[7]
        head2d = line[8]
        lefteye = line[1]   # eye-crop paths and head pose are parsed but not used here
        righteye = line[2]
        face = line[0]

        label = np.array(gaze2d.split(",")).astype("float")
        label = torch.from_numpy(label).type(torch.FloatTensor)

        # Convert gaze from radians to degrees.
        pitch = label[0] * 180 / np.pi
        yaw = label[1] * 180 / np.pi

        img = Image.open(os.path.join(self.root, face))

        # fimg = cv2.imread(os.path.join(self.root, face))
        # fimg = cv2.resize(fimg, (448, 448)) / 255.0
        # fimg = fimg.transpose(2, 0, 1)
        # img = torch.from_numpy(fimg).type(torch.FloatTensor)

        if self.transform:
            img = self.transform(img)

        # Bin the continuous angles into classification targets
        # (fixed +/-42 degree range with a 3-degree bin width).
        bins = np.array(range(-42, 42, 3))
        binned_pose = np.digitize([pitch, yaw], bins) - 1

        labels = binned_pose
        cont_labels = torch.FloatTensor([pitch, yaw])

        return img, labels, cont_labels, name
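
# --- Usage sketch (illustrative, not from the original repository) ---
# How Mpiigaze might be used for leave-one-subject-out evaluation: train on
# every subject except `fold`, evaluate on the held-out subject. The label
# paths, image root, transform, and angle threshold are assumptions for
# illustration only.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # Hypothetical per-subject label files p00.label ... p14.label.
    label_files = ["datasets/MPIIFaceGaze/Label/p{:02d}.label".format(i) for i in range(15)]
    image_root = "datasets/MPIIFaceGaze/Image"  # hypothetical path

    mpii_transform = transforms.Compose([
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    fold = 0
    train_set = Mpiigaze(label_files, image_root, mpii_transform, train=True, angle=42, fold=fold)
    test_set = Mpiigaze(label_files, image_root, mpii_transform, train=False, angle=42, fold=fold)

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=16, shuffle=False)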