Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
from __future__ import absolute_import, division, print_function | |
import cv2 | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
from torchvision import transforms | |
from utils.metrics import ConfusionMatrix | |
from PIL import Image | |
import os | |
# torch.cuda.synchronize() | |
# torch.backends.cudnn.benchmark = True | |
torch.backends.cudnn.deterministic = True | |
def collate(batch):
    """Merge a list of per-sample dicts into one dict of per-field lists.

    No stacking or padding is done here; every field is simply gathered
    into a Python list in batch order.

    Args:
        batch: sequence of mappings, each providing the keys
            'image', 'label', 'id' and 'adj_s'.

    Returns:
        dict with the keys 'image', 'label', 'id', 'adj_s' (in that order);
        each value is the list of that field taken from every sample.

    Raises:
        KeyError: if a sample is missing one of the four keys.
    """
    # Driving the gather from a field tuple avoids a local named `id`,
    # which would shadow the builtin.
    fields = ('image', 'label', 'id', 'adj_s')
    return {field: [sample[field] for sample in batch] for field in fields}
def preparefeatureLabel(batch_graph, batch_label, batch_adjs, device='cpu'):
    """Zero-pad a batch of variable-size graphs into dense tensors.

    Every graph is padded up to the largest node count in the batch so the
    whole batch fits in rectangular tensors; a mask records which node
    slots hold real data.

    Args:
        batch_graph: sequence of per-graph node-feature arrays. ``shape[0]``
            is used as the node count and the last dimension as the feature
            width (expected layout: ``(num_nodes, feat_dim)``).
        batch_label: sequence of per-graph labels, indexable by position;
            each entry is stored into an int64 tensor.
        batch_adjs: sequence of per-graph adjacency matrices, each
            assignable to a ``(num_nodes, num_nodes)`` slice.
        device: device the returned tensors are moved to (default 'cpu').

    Returns:
        Tuple ``(node_feat, labels, adjs, masks)``:
            node_feat: float tensor ``(batch, max_nodes, feat_dim)``.
            labels:    int64 tensor ``(batch,)``.
            adjs:      float tensor ``(batch, max_nodes, max_nodes)``.
            masks:     float tensor ``(batch, max_nodes)``; 1 for real
                       nodes, 0 for padding.
    """
    batch_size = len(batch_graph)

    # Initialised allocation instead of the legacy uninitialised LongTensor.
    labels = torch.zeros(batch_size, dtype=torch.long)
    for i in range(batch_size):
        labels[i] = batch_label[i]

    max_node_num = max((graph.shape[0] for graph in batch_graph), default=0)

    # Feature width follows the data rather than a fixed 512; 512 is kept
    # only as the fallback for an empty batch, where nothing can be inferred.
    feat_dim = batch_graph[0].shape[-1] if batch_size else 512

    masks = torch.zeros(batch_size, max_node_num)
    adjs = torch.zeros(batch_size, max_node_num, max_node_num)
    batch_node_feat = torch.zeros(batch_size, max_node_num, feat_dim)

    for i, graph in enumerate(batch_graph):
        cur_node_num = graph.shape[0]
        # node attribute features
        batch_node_feat[i, 0:cur_node_num] = graph
        # adjacency matrix (top-left block; the rest stays zero padding)
        adjs[i, 0:cur_node_num, 0:cur_node_num] = batch_adjs[i]
        # mark the real (non-padding) nodes
        masks[i, 0:cur_node_num] = 1

    node_feat = batch_node_feat.to(device)
    labels = labels.to(device)
    adjs = adjs.to(device)
    masks = masks.to(device)

    return node_feat, labels, adjs, masks
class Trainer(object):
    """Runs training forward passes and owns a confusion-matrix tracker.

    Args:
        n_class: number of classes, forwarded to ``ConfusionMatrix``.
    """

    def __init__(self, n_class):
        self.metrics = ConfusionMatrix(n_class)

    def get_scores(self):
        """Return whatever ``ConfusionMatrix.get_scores()`` reports."""
        return self.metrics.get_scores()

    def reset_metrics(self):
        """Reset the underlying confusion matrix."""
        self.metrics.reset()

    def plot_cm(self):
        """Delegate to ``ConfusionMatrix.plotcm()``."""
        self.metrics.plotcm()

    def train(self, sample, model, device='cpu'):
        """Run one forward pass of ``model`` on a collated batch.

        Args:
            sample: mapping with the keys 'image', 'label' and 'adj_s'
                (per-sample lists, as produced by ``collate``).
            model: callable module taking ``(node_feat, labels, adjs, masks)``
                and returning ``(pred, labels, loss)``.
            device: device the padded batch tensors are placed on
                (default 'cpu', matching the previous fixed behaviour).

        Returns:
            The ``(pred, labels, loss)`` triple returned by the model.
        """
        node_feat, labels, adjs, masks = preparefeatureLabel(
            sample['image'], sample['label'], sample['adj_s'], device=device)
        # Call the module itself rather than .forward() so that any
        # registered forward hooks are executed.
        pred, labels, loss = model(node_feat, labels, adjs, masks)
        return pred, labels, loss
class Evaluator(object):
    """Runs evaluation forward passes and owns a confusion-matrix tracker.

    Args:
        n_class: number of classes, forwarded to ``ConfusionMatrix``.
    """

    def __init__(self, n_class):
        self.metrics = ConfusionMatrix(n_class)

    def get_scores(self):
        """Return whatever ``ConfusionMatrix.get_scores()`` reports."""
        return self.metrics.get_scores()

    def reset_metrics(self):
        """Reset the underlying confusion matrix."""
        self.metrics.reset()

    def plot_cm(self):
        """Delegate to ``ConfusionMatrix.plotcm()``."""
        self.metrics.plotcm()

    def eval_test(self, sample, model, graphcam_flag=False, device='cpu'):
        """Run one evaluation forward pass of ``model`` on a collated batch.

        Args:
            sample: mapping with the keys 'image', 'label' and 'adj_s'
                (per-sample lists, as produced by ``collate``).
            model: callable module taking ``(node_feat, labels, adjs, masks)``
                and, when ``graphcam_flag`` is set, a ``graphcam_flag``
                keyword; returns ``(pred, labels, loss)``.
            graphcam_flag: when False the pass runs under ``torch.no_grad()``;
                when True gradients are enabled for the pass and the flag is
                forwarded to the model.
            device: device the padded batch tensors are placed on
                (default 'cpu', matching the previous fixed behaviour).

        Returns:
            The ``(pred, labels, loss)`` triple returned by the model.
        """
        node_feat, labels, adjs, masks = preparefeatureLabel(
            sample['image'], sample['label'], sample['adj_s'], device=device)

        if not graphcam_flag:
            with torch.no_grad():
                pred, labels, loss = model(node_feat, labels, adjs, masks)
        else:
            # Scope gradient enabling to this forward pass. Calling
            # set_grad_enabled() as a plain function would leave autograd
            # switched on after returning, even inside a caller's no_grad().
            with torch.set_grad_enabled(True):
                pred, labels, loss = model(
                    node_feat, labels, adjs, masks, graphcam_flag=graphcam_flag)
        return pred, labels, loss