# Source: AioMedica / feature_extractor / build_graphs.py
# Author: chris1nexus
# Commit: d60982d ("First commit"), 5.03 kB
from cl import IClassifier
from build_graph_utils import *
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.transforms.functional as VF
from torchvision import transforms
import sys, argparse, os, glob
import pandas as pd
import numpy as np
from PIL import Image
from collections import OrderedDict
def compute_feats(args, bags_list, i_classifier, device, save_path=None, whole_slide_path=None):
    """Embed the patches of every bag and save the per-slide graph inputs.

    For each bag directory, this collects its patch images: ``*.jpeg`` at 20x,
    ``*.jpg`` at 5x/10x. It embeds them with ``i_classifier`` and writes the
    following files into ``<save_path>/simclr_files/<slide name>/``:

    * ``c_idx.txt``   -- written by ``save_coords`` from the patch file list;
    * ``features.pt`` -- node features, one stacked row per patch;
    * ``adj_s.pt``    -- the matrix returned by ``adj_matrix``.

    A bag is skipped if its output directory already exists or if it holds
    no matching patch files.

    Args:
        args: Parsed CLI namespace. ``args.magnification`` selects the patch
            file pattern. ``args`` is also forwarded to ``bag_dataset``.
        bags_list: Paths of the bag (patch) directories to process.
        i_classifier: Callable applied to a batch of patches that returns a
            ``(features, classes)`` pair. Only the features are kept.
        device: Torch device that patches and features are moved to.
        save_path: Root output directory.
        whole_slide_path: Unused; kept for interface compatibility.

    Raises:
        ValueError: If ``args.magnification`` is not '5x', '10x' or '20x'.
    """
    patterns = {'20x': '*.jpeg', '10x': '*.jpg', '5x': '*.jpg'}
    pattern = patterns.get(args.magnification)
    if pattern is None:
        # Previously this fell through with ``csv_file_path`` undefined.
        raise ValueError(
            'Unsupported magnification {!r}; expected one of {}'.format(
                args.magnification, sorted(patterns)))

    num_bags = len(bags_list)
    for i, bag_path in enumerate(bags_list):
        glob_path = os.path.join(bag_path, pattern)
        # Sorted so the node order in features/adjacency/coords is reproducible
        # across runs. All three outputs below are built from this same list.
        csv_file_path = sorted(glob.glob(glob_path))
        # Slide name = the component two levels above the glob pattern (the
        # parent of the bag directory), cut at the first underscore. It is now
        # computed for every magnification; previously only for 20x.
        file_name = glob_path.replace(os.sep, '/').split('/')[-3].split('_')[0]
        print('{} files to be processed: {}'.format(len(csv_file_path), file_name))

        out_dir = os.path.join(save_path, 'simclr_files', file_name)
        # Both skip checks run before the dataset is built to avoid wasted work.
        if os.path.isdir(out_dir):
            print('already exists')
            continue
        if not csv_file_path:
            print('no patches found, skipping')
            continue

        # NOTE(review): rows of features.pt are assumed to follow the order of
        # csv_file_path, i.e. bag_dataset must not shuffle -- confirm.
        dataloader, _ = bag_dataset(args, csv_file_path)
        feats_list = []
        with torch.no_grad():
            for batch in dataloader:
                patches = batch['input'].float().to(device)
                feats, _ = i_classifier(patches)
                feats_list.extend(feats)  # one entry per patch

        os.makedirs(out_dir, exist_ok=True)
        # ``with`` guarantees the coordinate file is flushed and closed.
        with open(os.path.join(out_dir, 'c_idx.txt'), 'w+') as txt_file:
            save_coords(txt_file, csv_file_path)
        # Save node features.
        output = torch.stack(feats_list, dim=0).to(device)
        torch.save(output, os.path.join(out_dir, 'features.pt'))
        # Save the adjacency matrix.
        adj_s = adj_matrix(csv_file_path, output, device=device)
        torch.save(adj_s, os.path.join(out_dir, 'adj_s.pt'))
        print('\r Computed: {}/{}'.format(i + 1, num_bags))
def main():
    """Command-line entry point: embed patch bags and build per-slide graphs.

    Steps:

    1. Parse the CLI arguments.
    2. Build a ResNet backbone with ``nn.InstanceNorm2d`` normalisation and
       its ``fc`` layer replaced by ``nn.Identity``, freeze it, and wrap it in
       ``IClassifier``.
    3. Load the pretrained embedder weights from ``--weights``.
    4. Call ``compute_feats`` on every bag directory matched by the
       ``--dataset`` glob pattern, writing results under ``--output``.

    If ``--weights`` is omitted, prints 'No feature extractor' and returns
    without building the model. Exits via argparse (status 2) if
    ``--backbone`` is unsupported, or if ``--dataset``/``--output`` is
    missing while ``--weights`` is given.
    """
    # Output width of each supported backbone once its ``fc`` head is removed.
    feat_dims = {'resnet18': 512, 'resnet34': 512, 'resnet50': 2048, 'resnet101': 2048}

    parser = argparse.ArgumentParser(description='Compute TCGA features from SimCLR embedder')
    parser.add_argument('--num_classes', default=2, type=int, help='Number of output classes')
    # NOTE: main() takes the embedding width from feat_dims, not from this
    # option. It may still be read by helpers that receive ``args``.
    parser.add_argument('--num_feats', default=512, type=int, help='Feature size')
    parser.add_argument('--batch_size', default=128, type=int, help='Batch size of dataloader')
    parser.add_argument('--num_workers', default=0, type=int, help='Number of threads for datalodaer')
    parser.add_argument('--dataset', default=None, type=str, help='path to patches')
    # ``choices`` rejects unknown backbones up front. Previously they caused a
    # NameError on the undefined ``resnet``.
    parser.add_argument('--backbone', default='resnet18', type=str, choices=list(feat_dims),
                        help='Embedder backbone')
    parser.add_argument('--magnification', default='20x', type=str, help='Magnification to compute features')
    parser.add_argument('--weights', default=None, type=str, help='path to the pretrained weights')
    parser.add_argument('--output', default=None, type=str, help='path to the output graph folder')
    args = parser.parse_args()

    # Checked before building the model so an early exit does no wasted work.
    if args.weights is None:
        print('No feature extractor')
        return
    if args.dataset is None or args.output is None:
        parser.error('--dataset and --output are required when --weights is given')

    resnet = getattr(models, args.backbone)(pretrained=False, norm_layer=nn.InstanceNorm2d)
    num_feats = feat_dims[args.backbone]
    for param in resnet.parameters():
        param.requires_grad = False
    resnet.fc = nn.Identity()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Running on:", device)
    i_classifier = IClassifier(resnet, num_feats, output_class=args.num_classes).to(device)

    # Load the feature extractor. map_location lets GPU-saved checkpoints
    # load on CPU-only machines.
    state_dict_weights = torch.load(args.weights, map_location=device)
    state_dict_init = i_classifier.state_dict()
    # Checkpoint tensors are assigned to model keys by POSITION (the two state
    # dicts are zipped), keeping only checkpoint keys containing 'features'.
    # NOTE(review): this assumes both dicts list corresponding tensors in the
    # same order -- confirm against the checkpoint format.
    new_state_dict = OrderedDict(
        (k_0, v)
        for (k, v), (k_0, _) in zip(state_dict_weights.items(), state_dict_init.items())
        if 'features' in k
    )
    i_classifier.load_state_dict(new_state_dict, strict=False)

    os.makedirs(args.output, exist_ok=True)
    bags_list = glob.glob(args.dataset)
    print(bags_list)
    compute_feats(args, bags_list, i_classifier, device, args.output)
# Run the command-line pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()