import torch
from torch import nn
import torch.nn.functional as F
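
# NOTE (inferred from the reshaping below; not documented in the source):
# `_crossattn_similarity` is a list of cross-attention maps, each shaped
# (2, heads, H, W, tokens). Concatenating on dim 1 stacks heads across
# entries, and indexing [1] keeps the text-conditioned half of a
# classifier-free-guidance batch. `mask` is an (H, W) binary segmentation
# mask, and `token_idx` selects the prompt tokens expected to attend to it.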
def l1(_crossattn_similarity, mask, token_idx=[1, 2]):
    # Take the conditional branch, average over attention heads, and move the
    # token axis first: (heads, H, W, tokens) -> (tokens, H, W).
    similarity = torch.cat(_crossattn_similarity, 1)[1]
    similarity = similarity.mean(0).permute(2, 0, 1)
    # similarity = similarity.softmax(dim=0)
    # Score: total raw attention mass of the selected tokens inside the mask.
    return (similarity[token_idx] * mask.cuda()).sum()
def bce(_crossattn_similarity, mask, token_idx=[1, 2]):
    similarity = torch.cat(_crossattn_similarity, 1)[1]
    similarity = similarity.mean(0).permute(2, 0, 1)
    # similarity = similarity.softmax(dim=0)
    # Treat each selected token's attention map (shifted by -1.0) as logits
    # and compare it to the mask; negated so that, like l1, higher is better.
    return -sum([
        F.binary_cross_entropy_with_logits(x - 1.0, mask.cuda())
        for x in similarity[token_idx]
    ])
def softmax(_crossattn_similarity, mask, token_idx=[1, 2]):
    similarity = torch.cat(_crossattn_similarity, 1)[1]
    similarity = similarity.mean(0).permute(2, 0, 1)
    # Drop token 0 and compute the softmax over the remaining tokens to
    # obtain per-pixel probability values.
    similarity = similarity[1:].softmax(dim=0)
    token_idx = [x - 1 for x in token_idx]  # shift indices after dropping token 0
    # Sum all relevant tokens to get the pixel-wise probability of belonging
    # to the correct class, then take per-pixel log-probabilities.
    score = similarity[token_idx].sum(dim=0)
    score = torch.log(score)
    # Sum log-probabilities (equivalent to multiplying probabilities) over
    # all pixels inside the mask.
    return (score * mask.cuda()).sum()
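
# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). All shapes
# are assumptions inferred from the reshaping above: two hypothetical
# "layers" of attention maps shaped (2, heads=8, H=16, W=16, tokens=77) and a
# random binary mask. A CUDA device is required because the scorers call
# mask.cuda().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    sims = [torch.randn(2, 8, 16, 16, 77, device="cuda") for _ in range(2)]
    mask = (torch.rand(16, 16, device="cuda") > 0.5).float()
    print("l1:     ", l1(sims, mask).item())
    print("bce:    ", bce(sims, mask).item())
    print("softmax:", softmax(sims, mask).item())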