# layout-guidance / utils.py
# Committed by silentchen in 17dd4b5 ("First Commit"); original file size 3.41 kB.
import torch
import math
def compute_ca_loss(attn_maps_mid, attn_maps_up, bboxes, object_positions):
    """Compute the cross-attention layout loss for a set of boxed objects.

    For every attention map considered and every object, the attention that
    each of the object's prompt tokens pays to the pixels inside the object's
    boxes is measured as a fraction of that token's total attention.  The
    per-token penalty is ``mean_over_batch((1 - fraction) ** 2)``; it is
    averaged over the object's tokens, summed over objects and maps, and
    finally divided by ``number_of_objects * number_of_maps``.

    Args:
        attn_maps_mid: Sequence of rank-3 attention tensors ``(B, i, j)``.
            Each is split in two along dim 0 with ``chunk(2)`` and only the
            second half is used.  Presumably that half is the conditional
            branch of a classifier-free-guidance batch (TODO confirm with the
            caller).  ``i`` must be a perfect square, because the spatial axis
            is reshaped to a square ``H x W`` map.  Dim 2 is indexed by token
            positions.
        attn_maps_up: Sequence of groups of attention tensors with the same
            layout; only the first group, ``attn_maps_up[0]``, is used.
        bboxes: One entry per object, each a list of boxes
            ``(x_min, y_min, x_max, y_max)`` expressed as fractions of the map
            width/height (they are scaled by ``W``/``H`` and truncated).
        object_positions: One entry per object, each a non-empty sequence of
            token indices into dim 2 of the attention maps.

    Returns:
        A scalar tensor.  It is zero when there are no objects or no attention
        maps.  The zero tensor is placed on the device of the first attention
        map, or on CUDA if available (else CPU) when no maps are given.

    Raises:
        ValueError: If an object has no token positions.
    """
    device = _loss_device(attn_maps_mid, attn_maps_up)
    object_number = len(bboxes)
    if object_number == 0:
        return torch.tensor(0.0, device=device)

    up_maps = attn_maps_up[0]  # only the first up-sampling group contributes
    num_maps = len(attn_maps_mid) + len(up_maps)
    if num_maps == 0:
        # Nothing to supervise; avoid dividing by zero below.
        return torch.tensor(0.0, device=device)

    loss = 0
    # Mid maps first, then up maps: the same accumulation order as before.
    for group in (attn_maps_mid, up_maps):
        for attn_map_integrated in group:
            attn_map = attn_map_integrated.chunk(2)[1]
            _, i, _ = attn_map.shape
            H = W = int(math.sqrt(i))  # spatial axis is assumed square
            for obj_idx in range(object_number):
                mask = _box_mask(bboxes[obj_idx], H, W, attn_map.device)
                loss += _object_ca_loss(attn_map, mask, object_positions[obj_idx])
    return loss / (object_number * num_maps)


def _loss_device(attn_maps_mid, attn_maps_up):
    """Return the first attention map's device, else CUDA if available, else CPU."""
    if len(attn_maps_mid) > 0:
        return attn_maps_mid[0].device
    if len(attn_maps_up) > 0 and len(attn_maps_up[0]) > 0:
        return attn_maps_up[0][0].device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _box_mask(boxes, H, W, device):
    """Build a float (H, W) mask that is 1 inside any of ``boxes`` and 0 elsewhere."""
    mask = torch.zeros(size=(H, W), device=device)
    for box in boxes:
        x_min, y_min = int(box[0] * W), int(box[1] * H)
        x_max, y_max = int(box[2] * W), int(box[3] * H)
        mask[y_min:y_max, x_min:x_max] = 1
    return mask


def _object_ca_loss(attn_map, mask, positions):
    """Average ``mean((1 - inside_fraction) ** 2)`` over one object's token positions."""
    if len(positions) == 0:
        raise ValueError("each object needs at least one token position")
    b = attn_map.shape[0]
    H, W = mask.shape
    obj_loss = 0
    for position in positions:
        ca_map_obj = attn_map[:, :, position].reshape(b, H, W)
        inside = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)
        total = ca_map_obj.reshape(b, -1).sum(dim=-1)
        activation_value = inside / total
        obj_loss += torch.mean((1 - activation_value) ** 2)
    return obj_loss / len(positions)