import math

import torch


def _layout_loss_for_map(attn_map, bboxes, object_positions):
    """Layout loss of one cross-attention map against all object boxes.

    Args:
        attn_map: tensor of shape (b, i, j) — batch, flattened spatial
            positions (i must be a perfect square), text-token axis.
        bboxes: per-object list of boxes, each box as normalized
            [x_min, y_min, x_max, y_max] in [0, 1].
        object_positions: per-object list of token indices along axis j
            that refer to that object.

    Returns:
        Scalar tensor: sum over objects of the mean squared shortfall of
        in-box attention mass, each object averaged over its tokens.
    """
    b, i, _ = attn_map.shape
    H = W = int(math.sqrt(i))  # spatial grid is assumed square
    total = 0
    for obj_idx, obj_boxes in enumerate(bboxes):
        # Binary mask of all boxes belonging to this object, built on the
        # same device as the attention map (no hard-coded .cuda()).
        mask = torch.zeros((H, W), device=attn_map.device)
        for obj_box in obj_boxes:
            x_min = int(obj_box[0] * W)
            y_min = int(obj_box[1] * H)
            x_max = int(obj_box[2] * W)
            y_max = int(obj_box[3] * H)
            mask[y_min:y_max, x_min:x_max] = 1

        obj_loss = 0
        for token_idx in object_positions[obj_idx]:
            ca_map_obj = attn_map[:, :, token_idx].reshape(b, H, W)
            # Fraction of this token's attention mass that falls inside
            # the object's box(es); ideal value is 1.
            inside = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)
            activation_value = inside / ca_map_obj.reshape(b, -1).sum(dim=-1)
            obj_loss += torch.mean((1 - activation_value) ** 2)

        # Average over this object's tokens so token count doesn't weight it.
        total += obj_loss / len(object_positions[obj_idx])
    return total


def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """Cross-attention layout loss over mid- and first-up-block attention maps.

    Encourages each object's text tokens to attend inside that object's
    bounding box(es). Only the conditional half of each classifier-free
    guidance batch (second chunk along dim 0) is used.

    Args:
        attn_maps_mid: iterable of attention tensors from the mid block,
            each of shape (2b, i, j) — uncond/cond stacked along dim 0.
        attn_maps_up: list of lists of up-block attention tensors; only
            attn_maps_up[0] is used (same shape convention).
        bboxes: per-object list of normalized [x_min, y_min, x_max, y_max].
        object_positions: per-object list of token indices.

    Returns:
        Scalar tensor: loss averaged over objects and over all maps used.
        Zero scalar when there are no objects.
    """
    object_number = len(bboxes)
    if object_number == 0:
        # Device-agnostic zero (original hard-coded .cuda(), which crashes
        # on CPU-only machines); stays on CUDA when CUDA is available so
        # GPU callers can still accumulate it into a CUDA loss.
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        return torch.tensor(0.0, device=device)

    loss = 0
    for attn_map_integrated in attn_maps_mid:
        attn_map = attn_map_integrated.chunk(2)[1]  # conditional half of CFG batch
        loss += _layout_loss_for_map(attn_map, bboxes, object_positions)

    for attn_map_integrated in attn_maps_up[0]:
        attn_map = attn_map_integrated.chunk(2)[1]  # conditional half of CFG batch
        loss += _layout_loss_for_map(attn_map, bboxes, object_positions)

    # Normalize by number of objects and number of maps contributing.
    loss = loss / (object_number * (len(attn_maps_up[0]) + len(attn_maps_mid)))
    return loss