"""Extract per-patient feature tensors from YOLO-cropped images.

Reads crop boxes from ``yolo_crop_file``, drops report images, runs every
remaining crop through a TorchScript model on the GPU, and saves one tensor
per patient barcode to ``{outdir}/{barcode}.pt``. Patients whose output file
already exists are skipped, so the script can be re-run to resume.
"""
import os

import pandas as pd
import torch
import torchvision.transforms as tvt
from PIL import Image
from tqdm import tqdm

torch.set_num_threads(2)

outdir = 'pt_files/train'
yolo_crop_file = 'image_yolo.txt'


def crop(img, x, y, w, h):
    """Crop a box given as a normalized center and size.

    ``x``/``w`` are fractions of the image width and ``y``/``h`` fractions of
    its height, with (x, y) being the box center. The box is not clamped to
    the image bounds.

    Returns a new PIL image containing the cropped region.
    """
    W, H = img.size
    x1 = x * W - w * W / 2.0
    x2 = x * W + w * W / 2.0
    y1 = y * H - h * H / 2.0
    y2 = y * H + h * H / 2.0
    return img.crop((x1, y1, x2, y2))


def is_report_file(s):
    """Return True if the path ``s`` contains the substring 'RPT'.

    NOTE(review): the substring is matched anywhere in the path, including
    directory names, not only the filename -- confirm this is intended.
    """
    return 'RPT' in s


def get_barcode(s):
    """Return the third-from-last '/'-separated component of path ``s``.

    Presumably paths look like ``.../<barcode>/<dir>/<file>`` -- TODO confirm.
    Raises IndexError for paths with fewer than three components.
    """
    return s.split('/')[-3]


def load_rgb(path):
    """Open the image at ``path`` and return a fully loaded RGB copy."""
    with open(path, 'rb') as f:
        # convert() forces PIL's lazy decode while the file is still open.
        return Image.open(f).convert('RGB')


CHANNEL = 3
IMAGE_SIZE = 448
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
normalize = tvt.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
                          std=torch.tensor(IMAGENET_DEFAULT_STD))
# Scale the shorter side to IMAGE_SIZE, then center-crop to a square.
transform_ops = tvt.Compose([
    tvt.Resize(IMAGE_SIZE),
    tvt.CenterCrop(IMAGE_SIZE),
    tvt.ToTensor(),
    normalize,
])

model_path = './traced_swav_imagenet_layer2.pt'

# Columns read below: orig (image path) and x, y, w, h (box passed to crop()).
df = pd.read_csv(yolo_crop_file)
df.insert(0, 'is_report_file', [is_report_file(s) for s in df.orig])
df.insert(0, 'patient_barcode', [get_barcode(s) for s in df.orig])
# astype(bool) keeps the mask boolean even when the CSV has no rows.
df = df[~df.is_report_file.astype(bool)]

net = torch.jit.load(model_path)
net = net.cuda()
net.eval()

os.makedirs(outdir, exist_ok=True)

groups = df.groupby('patient_barcode')
for patient_barcode, dfg in tqdm(groups, total=groups.ngroups):
    outfile = f"{outdir}/{patient_barcode}.pt"
    if os.path.exists(outfile):
        continue  # already extracted on a previous run
    N = len(dfg)
    image_tensors = torch.zeros(N, CHANNEL, IMAGE_SIZE, IMAGE_SIZE)
    rows = zip(dfg.orig, dfg.x, dfg.y, dfg.w, dfg.h)
    for i, (image_file, x, y, w, h) in enumerate(rows):
        img = crop(load_rgb(image_file), x, y, w, h)
        image_tensors[i] = transform_ops(img)
    image_tensors = image_tensors.cuda()
    with torch.no_grad():
        features = net(image_tensors).cpu()
    # Save to a temp file and rename it into place, so an interrupted save
    # cannot leave a truncated .pt that the exists() check above would then
    # skip on every later run.
    tmp_file = f"{outfile}.tmp"
    torch.save(features, tmp_file)
    os.replace(tmp_file, outfile)