# -*- coding:utf-8 -*- import os import sys import shutil from tqdm import tqdm import yaml import random import importlib from PIL import Image import imageio import numpy as np import cv2 import torch from torchvision import utils from scipy.interpolate import PchipInterpolator def split_filename(filename): absname = os.path.abspath(filename) dirname, basename = os.path.split(absname) split_tmp = basename.rsplit('.', maxsplit=1) if len(split_tmp) == 2: rootname, extname = split_tmp elif len(split_tmp) == 1: rootname = split_tmp[0] extname = None else: raise ValueError("programming error!") return dirname, rootname, extname def data2file(data, filename, type=None, override=False, printable=False, **kwargs): dirname, rootname, extname = split_filename(filename) print_did_not_save_flag = True if type: extname = type if not os.path.exists(dirname): os.makedirs(dirname, exist_ok=True) if not os.path.exists(filename) or override: if extname in ['jpg', 'png', 'jpeg']: utils.save_image(data, filename, **kwargs) elif extname == 'gif': imageio.mimsave(filename, data, format='GIF', duration=kwargs.get('duration'), loop=0) elif extname == 'txt': if kwargs is None: kwargs = {} max_step = kwargs.get('max_step') if max_step is None: max_step = np.Infinity with open(filename, 'w', encoding='utf-8') as f: for i, e in enumerate(data): if i < max_step: f.write(str(e) + '\n') else: break else: raise ValueError('Do not support this type') if printable: print('Saved data to %s' % os.path.abspath(filename)) else: if print_did_not_save_flag: print( 'Did not save data to %s because file exists and override is False' % os.path.abspath( filename)) def file2data(filename, type=None, printable=True, **kwargs): dirname, rootname, extname = split_filename(filename) print_load_flag = True if type: extname = type if extname in ['pth', 'ckpt', 'bin']: data = torch.load(filename, map_location=kwargs.get('map_location')) if "state_dict" in data.keys(): data = data["state_dict"] data = {k.replace("_forward_module.", ""):v for k,v in data.items()} elif extname == 'txt': top = kwargs.get('top', None) with open(filename, encoding='utf-8') as f: if top: data = [f.readline() for _ in range(top)] else: data = [e for e in f.read().split('\n') if e] elif extname == 'yaml': with open(filename, 'r') as f: data = yaml.load(f) else: raise ValueError('type can only support h5, npy, json, txt') if printable: if print_load_flag: print('Loaded data from %s' % os.path.abspath(filename)) return data def ensure_dirname(dirname, override=False): if os.path.exists(dirname) and override: print('Removing dirname: %s' % os.path.abspath(dirname)) try: shutil.rmtree(dirname) except OSError as e: raise ValueError('Failed to delete %s because %s' % (dirname, e)) if not os.path.exists(dirname): print('Making dirname: %s' % os.path.abspath(dirname)) os.makedirs(dirname, exist_ok=True) def import_filename(filename): spec = importlib.util.spec_from_file_location("mymodule", filename) module = importlib.util.module_from_spec(spec) sys.modules[spec.name] = module spec.loader.exec_module(module) return module def adaptively_load_state_dict(target, state_dict): target_dict = target.state_dict() try: common_dict = {k: v for k, v in state_dict.items() if k in target_dict and v.size() == target_dict[k].size()} # unmatch_dict = {k: v for k, v in state_dict.items() if k not in target_dict or v.size() != target_dict[k].size()} except Exception as e: print('load error %s', e) common_dict = {k: v for k, v in state_dict.items() if k in target_dict} if 'param_groups' in common_dict and common_dict['param_groups'][0]['params'] != \ target.state_dict()['param_groups'][0]['params']: print('Detected mismatch params, auto adapte state_dict to current') common_dict['param_groups'][0]['params'] = target.state_dict()['param_groups'][0]['params'] target_dict.update(common_dict) target.load_state_dict(target_dict) missing_keys = [k for k in target_dict.keys() if k not in common_dict] unexpected_keys = [k for k in state_dict.keys() if k not in common_dict] if len(unexpected_keys) != 0: print( f"Some weights of state_dict were not used in target: {unexpected_keys}" ) if len(missing_keys) != 0: print( f"Some weights of state_dict are missing used in target {missing_keys}" ) if len(unexpected_keys) == 0 and len(missing_keys) == 0: print("Strictly Loaded state_dict.") def set_seed(seed=42): random.seed(seed) os.environ['PYHTONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True def image2pil(filename): return Image.open(filename) def image2arr(filename): pil = image2pil(filename) return pil2arr(pil) def pil2arr(pil): if isinstance(pil, list): arr = np.array( [np.array(e.convert('RGB').getdata(), dtype=np.uint8).reshape(e.size[1], e.size[0], 3) for e in pil]) else: arr = np.array(pil) return arr def arr2pil(arr): if arr.ndim == 3: return Image.fromarray(arr.astype('uint8'), 'RGB') elif arr.ndim == 4: return [Image.fromarray(e.astype('uint8'), 'RGB') for e in list(arr)] else: raise ValueError('arr must has ndim of 3 or 4, but got %s' % arr.ndim) def interpolate_trajectory(points, n_points): x = [point[0] for point in points] y = [point[1] for point in points] t = np.linspace(0, 1, len(points)) fx = PchipInterpolator(t, x) fy = PchipInterpolator(t, y) new_t = np.linspace(0, 1, n_points) new_x = fx(new_t) new_y = fy(new_t) new_points = list(zip(new_x, new_y)) return new_points def visualize_drag(background_image_path, splited_tracks, drag_mode, width, height, model_length): if drag_mode=='object': color = (255, 0, 0, 255) elif drag_mode=='camera': color = (0, 0, 255, 255) background_image = Image.open(background_image_path).convert('RGBA') background_image = background_image.resize((width, height)) w, h = background_image.size transparent_background = np.array(background_image) transparent_background[:, :, -1] = 128 transparent_background = Image.fromarray(transparent_background) # Create a transparent layer with the same size as the background image transparent_layer = np.zeros((h, w, 4)) for splited_track in splited_tracks: if len(splited_track) > 1: splited_track = interpolate_trajectory(splited_track, model_length) splited_track = splited_track[:model_length] for i in range(len(splited_track)-1): start_point = (int(splited_track[i][0]), int(splited_track[i][1])) end_point = (int(splited_track[i+1][0]), int(splited_track[i+1][1])) vx = end_point[0] - start_point[0] vy = end_point[1] - start_point[1] arrow_length = np.sqrt(vx**2 + vy**2) if i == len(splited_track)-2: cv2.arrowedLine(transparent_layer, start_point, end_point, color, 2, tipLength=8 / arrow_length) else: cv2.line(transparent_layer, start_point, end_point, color, 2) else: cv2.circle(transparent_layer, (int(splited_track[0][0]), int(splited_track[0][1])), 5, color, -1) transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8)) trajectory_map = Image.alpha_composite(transparent_background, transparent_layer) return trajectory_map, transparent_layer