import itertools
from typing import Dict, List

import torch.distributions as dist

start_token = "<|startoftext|>"
end_token = "<|endoftext|>"


def _get_outside_indices(subtree_indices, attn_map_idx_to_wp):
    """Return the attention-map indices that do not belong to the subtree."""
    flattened_subtree_indices = _flatten_indices(subtree_indices)
    outside_indices = [
        map_idx
        for map_idx in attn_map_idx_to_wp.keys()
        if map_idx not in flattened_subtree_indices
    ]
    return outside_indices


def _flatten_indices(related_indices):
    flattened_related_indices = []
    for item in related_indices:
        if isinstance(item, list):
            flattened_related_indices.extend(item)
        else:
            flattened_related_indices.append(item)
    return flattened_related_indices


def split_indices(related_indices: List[int]):
    noun = [related_indices[-1]]  # assumes the noun is always last in the list
    modifier = related_indices[:-1]
    if isinstance(modifier, int):
        modifier = [modifier]
    return noun, modifier


def _symmetric_kl(attention_map1, attention_map2):
    # Flatten each map into a single distribution: 16x16 -> 256
    if len(attention_map1.shape) > 1:
        attention_map1 = attention_map1.reshape(-1)
    if len(attention_map2.shape) > 1:
        attention_map2 = attention_map2.reshape(-1)

    p = dist.Categorical(probs=attention_map1)
    q = dist.Categorical(probs=attention_map2)

    kl_divergence_pq = dist.kl_divergence(p, q)
    kl_divergence_qp = dist.kl_divergence(q, p)

    avg_kl_divergence = (kl_divergence_pq + kl_divergence_qp) / 2
    return avg_kl_divergence


def calculate_positive_loss(attention_maps, modifier, noun):
    src_indices = modifier
    dest_indices = noun

    if isinstance(src_indices, list) and isinstance(dest_indices, list):
        wp_pos_loss = [
            _symmetric_kl(attention_maps[s], attention_maps[d])
            for (s, d) in itertools.product(src_indices, dest_indices)
        ]
        positive_loss = max(wp_pos_loss)
    elif isinstance(dest_indices, list):
        wp_pos_loss = [
            _symmetric_kl(attention_maps[src_indices], attention_maps[d])
            for d in dest_indices
        ]
        positive_loss = max(wp_pos_loss)
    elif isinstance(src_indices, list):
        wp_pos_loss = [
            _symmetric_kl(attention_maps[s], attention_maps[dest_indices])
            for s in src_indices
        ]
        positive_loss = max(wp_pos_loss)
    else:
        positive_loss = _symmetric_kl(
            attention_maps[src_indices], attention_maps[dest_indices]
        )

    return positive_loss


def _calculate_outside_loss(attention_maps, src_indices, outside_indices):
    negative_loss = []
    computed_pairs = set()
    pair_counter = 0

    for outside_idx in outside_indices:
        if isinstance(src_indices, list):
            wp_neg_loss = []
            for t in src_indices:
                pair_key = (t, outside_idx)
                if pair_key not in computed_pairs:
                    wp_neg_loss.append(
                        _symmetric_kl(
                            attention_maps[t], attention_maps[outside_idx]
                        )
                    )
                    computed_pairs.add(pair_key)
            negative_loss.append(max(wp_neg_loss) if wp_neg_loss else 0)
            pair_counter += 1
        else:
            pair_key = (src_indices, outside_idx)
            if pair_key not in computed_pairs:
                negative_loss.append(
                    _symmetric_kl(
                        attention_maps[src_indices], attention_maps[outside_idx]
                    )
                )
                computed_pairs.add(pair_key)
                pair_counter += 1

    return negative_loss, pair_counter


def align_wordpieces_indices(wordpieces2indices, start_idx, target_word):
    """
    Aligns a `target_word` that spans more than one wordpiece
    (the first wordpiece is at `start_idx`).
    """
    wp_indices = [start_idx]
    wp = wordpieces2indices[start_idx].replace("</w>", "")

    # Run over the next wordpieces in the sequence (which is why we use +1)
    for wp_idx in range(start_idx + 1, len(wordpieces2indices)):
        if wp == target_word:
            break

        wp2 = wordpieces2indices[wp_idx].replace("</w>", "")
        if target_word.startswith(wp + wp2) and wp2 != target_word:
            wp += wordpieces2indices[wp_idx].replace("</w>", "")
            wp_indices.append(wp_idx)
        else:
            wp_indices = []  # if there's no match, clear the list and finish
            break

    return wp_indices


def extract_attribution_indices(prompt, parser):
    doc = parser(prompt)
    subtrees = []
    modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]

    for w in doc:
        if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers:
            continue
        subtree = []
        stack = []
        for child in w.children:
            if child.dep_ in modifiers:
                subtree.append(child)
                stack.extend(child.children)

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                subtree.append(node)
                stack.extend(node.children)

        if subtree:
            subtree.append(w)
            subtrees.append(subtree)

    return subtrees


def calculate_negative_loss(
    attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
):
    outside_indices = _get_outside_indices(subtree_indices, attn_map_idx_to_wp)

    negative_modifier_loss, num_modifier_pairs = _calculate_outside_loss(
        attention_maps, modifier, outside_indices
    )
    negative_noun_loss, num_noun_pairs = _calculate_outside_loss(
        attention_maps, noun, outside_indices
    )

    negative_modifier_loss = -sum(negative_modifier_loss) / len(outside_indices)
    negative_noun_loss = -sum(negative_noun_loss) / len(outside_indices)
    negative_loss = (negative_modifier_loss + negative_noun_loss) / 2

    return negative_loss


def get_indices(tokenizer, prompt: str) -> Dict[int, str]:
    """Utility function to list the indices of the tokens you wish to alter."""
    ids = tokenizer(prompt).input_ids
    indices = {
        i: tok for i, tok in enumerate(tokenizer.convert_ids_to_tokens(ids))
    }
    return indices


def get_attention_map_index_to_wordpiece(tokenizer, prompt):
    attn_map_idx_to_wp = {}

    wordpieces2indices = get_indices(tokenizer, prompt)

    # Ignore `start_token` and `end_token`
    for i in list(wordpieces2indices.keys())[1:-1]:
        wordpiece = wordpieces2indices[i]
        wordpiece = wordpiece.replace("</w>", "")
        attn_map_idx_to_wp[i] = wordpiece

    return attn_map_idx_to_wp
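

if __name__ == "__main__":
    # Illustrative sketch only: wires the utilities above together on dummy
    # attention maps so the loss terms can be inspected in isolation. The
    # spaCy model and CLIP tokenizer names below are assumptions for the
    # demo, not requirements of this module, and the random maps stand in
    # for real cross-attention maps from a diffusion model.
    import spacy
    import torch
    from transformers import CLIPTokenizer

    prompt = "a red car and a blue bird"
    parser = spacy.load("en_core_web_sm")  # assumed dependency parser
    tokenizer = CLIPTokenizer.from_pretrained(
        "openai/clip-vit-base-patch32"  # assumed tokenizer checkpoint
    )

    attn_map_idx_to_wp = get_attention_map_index_to_wordpiece(tokenizer, prompt)

    # Stand-in attention maps: one 16x16 distribution per prompt token.
    num_tokens = len(tokenizer(prompt).input_ids)
    attention_maps = torch.rand(num_tokens, 16, 16)
    attention_maps = attention_maps / attention_maps.sum(dim=(1, 2), keepdim=True)

    for subtree in extract_attribution_indices(prompt, parser):
        # Simple case: every word in the subtree is a single wordpiece, so we
        # can match spaCy tokens to attention-map indices by surface form.
        subtree_indices = [
            idx
            for token in subtree
            for idx, wp in attn_map_idx_to_wp.items()
            if wp == token.text.lower()
        ]
        if len(subtree_indices) < 2:
            continue

        noun, modifier = split_indices(subtree_indices)
        pos = calculate_positive_loss(attention_maps, modifier, noun)
        neg = calculate_negative_loss(
            attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
        )
        print(subtree, pos.item(), neg.item())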