# (Hugging Face Spaces page header captured with this source; the Space status was "Runtime error".)
from cl import IClassifier | |
from build_graph_utils import * | |
import torch | |
import torch.nn as nn | |
from torch.utils.data import DataLoader | |
import torchvision.models as models | |
import torchvision.transforms.functional as VF | |
from torchvision import transforms | |
import sys, argparse, os, glob | |
import pandas as pd | |
import numpy as np | |
from PIL import Image | |
from collections import OrderedDict | |
def compute_feats(args, bags_list, i_classifier, device, save_path=None, whole_slide_path=None):
    """Embed the patches of every bag and save the resulting graph files.

    For each bag directory in ``bags_list`` the patch images are collected
    with ``glob``, passed batch-wise through ``i_classifier`` and three files
    are written to ``<save_path>/simclr_files/<file_name>/``:

    * ``c_idx.txt``   -- written by ``save_coords`` from the patch paths
    * ``features.pt`` -- feature rows of all patches, concatenated
    * ``adj_s.pt``    -- the value returned by ``adj_matrix`` for the bag

    ``file_name`` is the third-from-last ``/``-separated component of the
    glob pattern, truncated at its first underscore.  A bag is skipped when
    its output directory already exists or when it contains no patch image.

    Args:
        args: Parsed options.  ``args.magnification`` selects the patch
            extension (``'20x'`` -> ``*.jpeg``; ``'5x'``/``'10x'`` ->
            ``*.jpg``); the whole object is forwarded to ``bag_dataset``.
        bags_list: Paths of the bag directories to process.
        i_classifier: Callable returning ``(feats, classes)`` for a batch of
            patches; only ``feats`` is used.
        device: Torch device that patches and the stacked features are moved to.
        save_path: Root directory for the output files.
        whole_slide_path: Unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``args.magnification`` is not one of the supported values.
    """
    magnification = args.magnification
    if magnification == '20x':
        extension = '*.jpeg'
    elif magnification in ('5x', '10x'):
        extension = '*.jpg'
    else:
        raise ValueError(
            "Unsupported magnification {!r}; expected '5x', '10x' or '20x'".format(magnification)
        )

    num_bags = len(bags_list)
    for i, bag_path in enumerate(bags_list):
        glob_path = os.path.join(bag_path, extension)
        csv_file_path = glob.glob(glob_path)
        # The slide name is derived from the glob pattern (one component
        # deeper than the bag path) for every magnification, so 5x/10x runs
        # no longer hit an undefined ``file_name``.
        file_name = glob_path.split('/')[-3].split('_')[0]
        print('{} files to be processed: {}'.format(len(csv_file_path), file_name))

        bag_dir = os.path.join(save_path, 'simclr_files', file_name)
        # Decide whether to skip before building the data loader so skipped
        # bags cost nothing.
        if os.path.isdir(bag_dir):
            print('already exists')
            continue
        if not csv_file_path:
            print('no patches found, skipping')
            continue

        dataloader, _ = bag_dataset(args, csv_file_path)
        feats_batches = []
        with torch.no_grad():
            for batch in dataloader:
                patches = batch['input'].float().to(device)
                feats, _ = i_classifier(patches)
                feats_batches.append(feats)

        # The directory is created only after feature extraction succeeded,
        # so a failed extraction does not leave a directory that would make
        # the bag look "already processed" on the next run.
        os.makedirs(bag_dir, exist_ok=True)

        # ``with`` guarantees the handle is released even if ``save_coords``
        # raises; closing an already-closed file is a no-op.
        with open(os.path.join(bag_dir, 'c_idx.txt'), "w+") as txt_file:
            save_coords(txt_file, csv_file_path)

        # save node features (concatenating batches along dim 0 yields the
        # same tensor as stacking the individual rows)
        output = torch.cat(feats_batches, dim=0).to(device)
        torch.save(output, os.path.join(bag_dir, 'features.pt'))

        # save adjacent matrix
        adj_s = adj_matrix(csv_file_path, output, device=device)
        torch.save(adj_s, os.path.join(bag_dir, 'adj_s.pt'))

        print('\r Computed: {}/{}'.format(i + 1, num_bags))
def main():
    """Command-line entry point: load the embedder and build graphs for all bags.

    Steps:
      1. Parse the command-line options.
      2. Select the device (CUDA when available, otherwise CPU).
      3. Return early, after printing a notice, when ``--weights`` is absent.
      4. Build the chosen torchvision ResNet with ``InstanceNorm2d`` layers,
         frozen parameters and an identity ``fc``, wrapped in ``IClassifier``.
      5. Load the checkpoint and copy its tensors into the classifier.
      6. Expand ``--dataset`` with ``glob`` and call ``compute_feats`` with
         ``--output`` as the save path.

    An unknown ``--backbone`` or a missing ``--dataset``/``--output`` ends the
    program through ``argparse`` (``SystemExit``) with a usage message.
    """
    # Output feature width of each supported backbone once ``fc`` is removed.
    backbone_feats = {
        'resnet18': 512,
        'resnet34': 512,
        'resnet50': 2048,
        'resnet101': 2048,
    }

    parser = argparse.ArgumentParser(description='Compute TCGA features from SimCLR embedder')
    parser.add_argument('--num_classes', default=2, type=int, help='Number of output classes')
    parser.add_argument('--num_feats', default=512, type=int, help='Feature size')
    parser.add_argument('--batch_size', default=128, type=int, help='Batch size of dataloader')
    parser.add_argument('--num_workers', default=0, type=int, help='Number of threads for datalodaer')
    parser.add_argument('--dataset', default=None, type=str, help='path to patches')
    parser.add_argument('--backbone', default='resnet18', type=str, choices=list(backbone_feats),
                        help='Embedder backbone')
    parser.add_argument('--magnification', default='20x', type=str, help='Magnification to compute features')
    parser.add_argument('--weights', default=None, type=str, help='path to the pretrained weights')
    parser.add_argument('--output', default=None, type=str, help='path to the output graph folder')
    args = parser.parse_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Running on:", device)

    # Nothing can be computed without pretrained weights, so stop before
    # building the network.
    if args.weights is None:
        print('No feature extractor')
        return
    if args.dataset is None or args.output is None:
        parser.error('--dataset and --output are required')

    # The torchvision constructor has the same name as the backbone option.
    resnet = getattr(models, args.backbone)(pretrained=False, norm_layer=nn.InstanceNorm2d)
    num_feats = backbone_feats[args.backbone]
    for param in resnet.parameters():
        param.requires_grad = False
    resnet.fc = nn.Identity()
    i_classifier = IClassifier(resnet, num_feats, output_class=args.num_classes).to(device)

    # load feature extractor
    # ``map_location`` lets a checkpoint saved on GPU be loaded on the CPU fallback.
    state_dict_weights = torch.load(args.weights, map_location=device)
    state_dict_init = i_classifier.state_dict()
    new_state_dict = OrderedDict()
    # Checkpoint entries are paired with the model's entries by position;
    # pairs whose checkpoint key lacks 'features' are dropped, the rest are
    # stored under the model's key name.
    # NOTE(review): this presumes both state dicts list their tensors in the
    # same order -- confirm against the checkpoint format.
    for (k, v), k_0 in zip(state_dict_weights.items(), state_dict_init):
        if 'features' not in k:
            continue
        new_state_dict[k_0] = v
    i_classifier.load_state_dict(new_state_dict, strict=False)

    os.makedirs(args.output, exist_ok=True)
    bags_list = glob.glob(args.dataset)
    print(bags_list)
    compute_feats(args, bags_list, i_classifier, device, args.output)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()