Spaces:
Sleeping
Sleeping
import torch.distributions as dist | |
from typing import List, Dict | |
import itertools | |
# Marker strings for the start and the end of a tokenized prompt.
start_token = "<|startoftext|>"
end_token = "<|endoftext|>"
def _get_outside_indices(subtree_indices, attn_map_idx_to_wp):
    """Return the attention-map indices that are not part of the subtree.

    ``subtree_indices`` may mix plain indices with lists of indices; it is
    flattened one level with ``_flatten_indices`` before the comparison. The
    result follows the iteration order of ``attn_map_idx_to_wp``.
    """
    inside = _flatten_indices(subtree_indices)
    return [idx for idx in attn_map_idx_to_wp if idx not in inside]
def _flatten_indices(related_indices): | |
flattened_related_indices = [] | |
for item in related_indices: | |
if isinstance(item, list): | |
flattened_related_indices.extend(item) | |
else: | |
flattened_related_indices.append(item) | |
return flattened_related_indices | |
def split_indices(related_indices: List[int]):
    """Split a group of related indices into ``(noun, modifier)``.

    The noun is assumed to always be the last entry of ``related_indices``;
    everything before it is treated as the modifier part.

    Args:
        related_indices: Non-empty sequence of indices, noun last.

    Returns:
        A pair ``(noun, modifier)`` where ``noun`` is a one-element list
        holding the last entry and ``modifier`` holds all preceding entries
        (empty when there is only one entry).

    Raises:
        IndexError: If ``related_indices`` is empty.
    """
    noun = [related_indices[-1]]
    # A slice is never an ``int``, so the former int-to-list wrapping branch
    # could not run and has been dropped.
    modifier = related_indices[:-1]
    return noun, modifier
def _symmetric_kl(attention_map1, attention_map2): | |
# Convert map into a single distribution: 16x16 -> 256 | |
if len(attention_map1.shape) > 1: | |
attention_map1 = attention_map1.reshape(-1) | |
if len(attention_map2.shape) > 1: | |
attention_map2 = attention_map2.reshape(-1) | |
p = dist.Categorical(probs=attention_map1) | |
q = dist.Categorical(probs=attention_map2) | |
kl_divergence_pq = dist.kl_divergence(p, q) | |
kl_divergence_qp = dist.kl_divergence(q, p) | |
avg_kl_divergence = (kl_divergence_pq + kl_divergence_qp) / 2 | |
return avg_kl_divergence | |
def calculate_positive_loss(attention_maps, modifier, noun):
    """Return the largest symmetric KL divergence between modifier and noun.

    ``modifier`` and ``noun`` may each be a single index or a list of
    indices. Every (modifier, noun) index pair is scored with
    ``_symmetric_kl`` on the corresponding entries of ``attention_maps`` and
    the maximum score is returned.
    """
    # A lone index is treated as a one-element list so that a single code
    # path covers every scalar/list combination.
    src_indices = modifier if isinstance(modifier, list) else [modifier]
    dest_indices = noun if isinstance(noun, list) else [noun]
    pair_losses = [
        _symmetric_kl(attention_maps[src], attention_maps[dest])
        for src, dest in itertools.product(src_indices, dest_indices)
    ]
    return max(pair_losses)
def _calculate_outside_loss(attention_maps, src_indices, outside_loss):
    """Score ``src_indices`` against every index in ``outside_loss``.

    ``outside_loss`` is iterated as a collection of attention-map indices.
    ``src_indices`` may be a single index or a list of indices. Each
    (source, outside) pair is scored at most once with ``_symmetric_kl``.

    Returns:
        A pair ``(losses, count)``. For a list of sources, ``losses`` gets one
        entry per outside index: the maximum over the newly scored pairs, or
        ``0`` when none was new. For a single source, ``losses`` gets one
        entry per newly scored pair. ``count`` is the number of entries
        appended.
    """
    losses = []
    seen_pairs = set()
    appended = 0
    many_sources = isinstance(src_indices, list)

    for outside_idx in outside_loss:
        if not many_sources:
            key = (src_indices, outside_idx)
            if key in seen_pairs:
                continue
            losses.append(
                _symmetric_kl(
                    attention_maps[src_indices], attention_maps[outside_idx]
                )
            )
            seen_pairs.add(key)
            # NOTE(review): the counter is assumed to advance only for newly
            # scored pairs in this branch; the original indentation was lost,
            # so confirm against the upstream file.
            appended += 1
            continue

        fresh = []
        for src in src_indices:
            key = (src, outside_idx)
            if key in seen_pairs:
                continue
            fresh.append(
                _symmetric_kl(attention_maps[src], attention_maps[outside_idx])
            )
            seen_pairs.add(key)
        losses.append(max(fresh) if fresh else 0)
        appended += 1

    return losses, appended
def align_wordpieces_indices(
    wordpieces2indices, start_idx, target_word
):
    """
    Aligns a `target_word` that contains more than one wordpiece (the first wordpiece is `start_idx`)

    Walks forward from `start_idx`, gluing consecutive wordpieces (with the
    "</w>" marker removed) for as long as the result is still a prefix of
    `target_word`. Returns the list of visited indices, or an empty list as
    soon as a following wordpiece does not fit. If the sequence runs out
    before the word is complete, the indices gathered so far are returned.
    """
    assembled = wordpieces2indices[start_idx].replace("</w>", "")
    matched = [start_idx]
    total = len(wordpieces2indices)
    # Start with the wordpiece right after `start_idx`.
    next_idx = start_idx + 1
    while next_idx < total and assembled != target_word:
        piece = wordpieces2indices[next_idx].replace("</w>", "")
        candidate = assembled + piece
        if piece == target_word or not target_word.startswith(candidate):
            # No match: discard everything collected and stop.
            return []
        assembled = candidate
        matched.append(next_idx)
        next_idx += 1
    return matched
def extract_attribution_indices(prompt, parser):
    """Collect, for each head noun in ``prompt``, the modifiers attached to it.

    ``parser(prompt)`` must yield tokens exposing ``pos_``, ``dep_`` and
    ``children``. A token is a head when its ``pos_`` is NOUN or PROPN and its
    own ``dep_`` is not a modifier relation. For each head, the direct
    children with a modifier relation are gathered, then their descendants
    reached through modifier or "conj" relations.

    Returns:
        A list of token lists, one per head that has at least one modifier,
        each ending with the head token itself.
    """
    modifier_deps = ("amod", "nmod", "compound", "npadvmod", "advmod", "acomp")
    head_pos = ("NOUN", "PROPN")
    collected = []

    for token in parser(prompt):
        is_head = token.pos_ in head_pos and token.dep_ not in modifier_deps
        if not is_head:
            continue

        group = []
        pending = []
        for child in token.children:
            if child.dep_ not in modifier_deps:
                continue
            group.append(child)
            pending.extend(child.children)

        # Depth-first walk over the descendants of the direct modifiers.
        while pending:
            current = pending.pop()
            if current.dep_ in modifier_deps or current.dep_ == "conj":
                group.append(current)
                pending.extend(current.children)

        if group:
            group.append(token)
            collected.append(group)

    return collected
def calculate_negative_loss(
    attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
):
    """Return the negative loss of a modifier/noun pair against outside tokens.

    The indices of ``attn_map_idx_to_wp`` that are not part of
    ``subtree_indices`` are the "outside" indices. The modifier and the noun
    are each scored against them with ``_calculate_outside_loss``; each list
    of scores is summed, negated and divided by the number of outside
    indices, and the two results are averaged.

    Returns:
        The averaged negative loss, or ``0`` when there are no outside
        indices (previously this case raised ``ZeroDivisionError``).
    """
    outside_indices = _get_outside_indices(subtree_indices, attn_map_idx_to_wp)
    if not outside_indices:
        # Every mapped index belongs to the subtree, so there is nothing to
        # score against and no meaningful denominator.
        return 0

    # The pair counts returned alongside the losses are not needed here.
    modifier_losses, _ = _calculate_outside_loss(
        attention_maps, modifier, outside_indices
    )
    noun_losses, _ = _calculate_outside_loss(
        attention_maps, noun, outside_indices
    )

    num_outside = len(outside_indices)
    negative_modifier_loss = -sum(modifier_losses) / num_outside
    negative_noun_loss = -sum(noun_losses) / num_outside
    return (negative_modifier_loss + negative_noun_loss) / 2
def get_indices(tokenizer, prompt: str) -> Dict[int, str]:
    """Utility function to list the indices of the tokens you wish to alter.

    Args:
        tokenizer: Callable whose result exposes ``input_ids`` and which
            provides ``convert_ids_to_tokens``.
        prompt: Text to tokenize.

    Returns:
        A dict mapping each 0-based position in the tokenized prompt to the
        token string found at that position.
    """
    ids = tokenizer(prompt).input_ids
    return dict(enumerate(tokenizer.convert_ids_to_tokens(ids)))
def get_attention_map_index_to_wordpiece(tokenizer, prompt):
    """Map attention-map indices to their wordpieces, minus the "</w>" marker.

    The first and last entries of the tokenized prompt (`start_token` and
    `end_token`) are left out of the mapping.
    """
    wordpieces2indices = get_indices(tokenizer, prompt)
    inner_indices = list(wordpieces2indices)[1:-1]
    return {
        idx: wordpieces2indices[idx].replace("</w>", "")
        for idx in inner_indices
    }