# SynGen / compute_loss.py
import itertools
from typing import Dict, List

import torch.distributions as dist

start_token = "<|startoftext|>"
end_token = "<|endoftext|>"

def _get_outside_indices(subtree_indices, attn_map_idx_to_wp):
    """Return the attention-map indices that fall outside the given subtree."""
    flattened_subtree_indices = _flatten_indices(subtree_indices)
    outside_indices = [
        map_idx
        for map_idx in attn_map_idx_to_wp.keys()
        if map_idx not in flattened_subtree_indices
    ]
    return outside_indices
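# Example (hypothetical indices): with subtree_indices = [[2], 3] and attention
# maps for token positions {1, 2, 3, 4, 5}, the subtree flattens to [2, 3],
# so the outside indices are [1, 4, 5].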

def _flatten_indices(related_indices):
    """Flatten a mixed list of ints and lists of ints into one flat list."""
    flattened_related_indices = []
    for item in related_indices:
        if isinstance(item, list):
            flattened_related_indices.extend(item)
        else:
            flattened_related_indices.append(item)
    return flattened_related_indices
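# Example: _flatten_indices([2, [3, 4]]) -> [2, 3, 4]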

def split_indices(related_indices: List[int]):
    """Split a subtree's indices into the noun and its modifiers."""
    noun = [related_indices[-1]]  # assumes the noun is always last in the list
    modifier = related_indices[:-1]
    if isinstance(modifier, int):  # defensive; slicing a list already yields a list
        modifier = [modifier]
    return noun, modifier
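# Example: split_indices([2, 3, 5]) -> ([5], [2, 3])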

def _symmetric_kl(attention_map1, attention_map2):
    """Average of KL(p||q) and KL(q||p) between two attention maps."""
    # Flatten each map into a single distribution: e.g. 16x16 -> 256
    if len(attention_map1.shape) > 1:
        attention_map1 = attention_map1.reshape(-1)
    if len(attention_map2.shape) > 1:
        attention_map2 = attention_map2.reshape(-1)

    p = dist.Categorical(probs=attention_map1)
    q = dist.Categorical(probs=attention_map2)

    kl_divergence_pq = dist.kl_divergence(p, q)
    kl_divergence_qp = dist.kl_divergence(q, p)
    avg_kl_divergence = (kl_divergence_pq + kl_divergence_qp) / 2
    return avg_kl_divergence
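# Example (hypothetical tensors): for p = torch.tensor([0.7, 0.2, 0.1]) and
# q = torch.tensor([0.1, 0.2, 0.7]), _symmetric_kl(p, q) is a scalar tensor
# (~1.17 here), while _symmetric_kl(p, p) is 0.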

def calculate_positive_loss(attention_maps, modifier, noun):
    """Positive loss: pull the modifier's and noun's attention maps together.

    Takes the max over wordpiece pairs, so the worst-aligned pair dominates.
    """
    src_indices = modifier
    dest_indices = noun

    if isinstance(src_indices, list) and isinstance(dest_indices, list):
        wp_pos_loss = [
            _symmetric_kl(attention_maps[s], attention_maps[d])
            for (s, d) in itertools.product(src_indices, dest_indices)
        ]
        positive_loss = max(wp_pos_loss)
    elif isinstance(dest_indices, list):
        wp_pos_loss = [
            _symmetric_kl(attention_maps[src_indices], attention_maps[d])
            for d in dest_indices
        ]
        positive_loss = max(wp_pos_loss)
    elif isinstance(src_indices, list):
        wp_pos_loss = [
            _symmetric_kl(attention_maps[s], attention_maps[dest_indices])
            for s in src_indices
        ]
        positive_loss = max(wp_pos_loss)
    else:
        positive_loss = _symmetric_kl(
            attention_maps[src_indices], attention_maps[dest_indices]
        )

    return positive_loss
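# Example (hypothetical): with modifier indices [2] and noun indices [3],
# calculate_positive_loss(attention_maps, [2], [3]) reduces to
# _symmetric_kl(attention_maps[2], attention_maps[3]); minimizing it pulls
# the two attention maps toward the same distribution.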

def _calculate_outside_loss(attention_maps, src_indices, outside_indices):
    """Symmetric KL between the source tokens' maps and every map outside the subtree."""
    negative_loss = []
    computed_pairs = set()
    pair_counter = 0

    for outside_idx in outside_indices:
        if isinstance(src_indices, list):
            wp_neg_loss = []
            for t in src_indices:
                pair_key = (t, outside_idx)
                if pair_key not in computed_pairs:
                    wp_neg_loss.append(
                        _symmetric_kl(attention_maps[t], attention_maps[outside_idx])
                    )
                    computed_pairs.add(pair_key)
            negative_loss.append(max(wp_neg_loss) if wp_neg_loss else 0)
            pair_counter += 1
        else:
            pair_key = (src_indices, outside_idx)
            if pair_key not in computed_pairs:
                negative_loss.append(
                    _symmetric_kl(
                        attention_maps[src_indices], attention_maps[outside_idx]
                    )
                )
                computed_pairs.add(pair_key)
                pair_counter += 1

    return negative_loss, pair_counter
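# Example (hypothetical): with src_indices = [2, 3] (a multi-wordpiece word)
# and outside_indices = [5, 6], each outside map is compared against both
# source maps and the max per outside index is kept, so negative_loss gets
# two entries and pair_counter == 2.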

def align_wordpieces_indices(wordpieces2indices, start_idx, target_word):
    """
    Aligns a `target_word` that spans more than one wordpiece
    (the first wordpiece is at `start_idx`).
    """
    wp_indices = [start_idx]
    wp = wordpieces2indices[start_idx].replace("</w>", "")

    # Walk over the following wordpieces in the sequence (hence the +1)
    for wp_idx in range(start_idx + 1, len(wordpieces2indices)):
        if wp == target_word:
            break

        wp2 = wordpieces2indices[wp_idx].replace("</w>", "")
        if target_word.startswith(wp + wp2) and wp2 != target_word:
            wp += wp2
            wp_indices.append(wp_idx)
        else:
            # No match: clear the list and stop
            wp_indices = []
            break

    return wp_indices
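# Example (hypothetical tokenization): if "basketball" were split into
# wordpieces {4: "basket", 5: "ball</w>"}, then
# align_wordpieces_indices(wordpieces2indices, 4, "basketball") -> [4, 5]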

def extract_attribution_indices(doc):
    """Collect, for each noun in a parsed `doc` (a spaCy parse of the prompt),
    the subtree of its modifiers."""
    subtrees = []
    modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]

    for w in doc:
        if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers:
            continue
        subtree = []
        stack = []
        for child in w.children:
            if child.dep_ in modifiers:
                subtree.append(child)
                stack.extend(child.children)

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                subtree.append(node)
                stack.extend(node.children)

        if subtree:
            subtree.append(w)
            subtrees.append(subtree)

    return subtrees
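# Example (assuming a spaCy parse where "red" and "blue" are amod children of
# their nouns): for "a red dog and a blue cat" this returns
# [[red, dog], [blue, cat]] as lists of spaCy tokens.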

def extract_attribution_indices_with_verbs(doc):
    """This function specifically addresses cases where a verb sits between
    a noun and its modifier. For instance, in "a dog that is red",
    the aux is between 'dog' and 'red'."""
    subtrees = []
    modifiers = [
        "amod", "nmod", "compound", "npadvmod", "advmod", "acomp", "relcl",
    ]

    for w in doc:
        if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers:
            continue
        subtree = []
        stack = []
        for child in w.children:
            if child.dep_ in modifiers:
                if child.pos_ not in ["AUX", "VERB"]:
                    subtree.append(child)
                stack.extend(child.children)

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                # We don't want to add 'is' or other verbs to the loss,
                # only their children
                if node.pos_ not in ["AUX", "VERB"]:
                    subtree.append(node)
                stack.extend(node.children)

        if subtree:
            subtree.append(w)
            subtrees.append(subtree)

    return subtrees
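# Example (assuming a spaCy parse of "a dog that is red" where "is" heads a
# relative clause on "dog" and "red" is its acomp): the verb "is" is skipped
# and the returned subtree is [red, dog].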

def extract_attribution_indices_with_verb_root(doc):
    """This function specifically addresses cases where a verb sits between
    a noun and its modifier. For instance, in "a dog that is red",
    the aux is between 'dog' and 'red'. Unlike
    `extract_attribution_indices_with_verbs`, it starts from the AUX token
    itself and gathers its noun and modifier children."""
    subtrees = []
    modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]

    for w in doc:
        subtree = []
        stack = []

        # If w is an aux and has a noun child and a modifier child,
        # add them to the stack
        if w.pos_ != "AUX" or w.dep_ in modifiers:
            continue

        for child in w.children:
            if child.dep_ in modifiers or child.pos_ in ["NOUN", "PROPN"]:
                if child.pos_ not in ["AUX", "VERB"]:
                    subtree.append(child)
                stack.extend(child.children)

        # Did not find a pair of noun and modifier
        if len(subtree) < 2:
            continue

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                # We don't want to add 'is' or other verbs to the loss,
                # only their children
                if node.pos_ not in ["AUX"]:
                    subtree.append(node)
                stack.extend(node.children)

        if subtree:
            if w.pos_ not in ["AUX"]:
                subtree.append(w)
            subtrees.append(subtree)

    return subtrees

def calculate_negative_loss(
    attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
):
    """Negative loss: push the subtree's attention maps away from all maps
    outside the subtree."""
    outside_indices = _get_outside_indices(subtree_indices, attn_map_idx_to_wp)
    # Guard against prompts where every token belongs to the subtree
    # (avoids division by zero below)
    if not outside_indices:
        return 0

    negative_modifier_loss, num_modifier_pairs = _calculate_outside_loss(
        attention_maps, modifier, outside_indices
    )
    negative_noun_loss, num_noun_pairs = _calculate_outside_loss(
        attention_maps, noun, outside_indices
    )

    negative_modifier_loss = -sum(negative_modifier_loss) / len(outside_indices)
    negative_noun_loss = -sum(negative_noun_loss) / len(outside_indices)
    negative_loss = (negative_modifier_loss + negative_noun_loss) / 2

    return negative_loss
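# Example (hypothetical): for "a red dog" with subtree_indices = [2, 3]
# ("red", "dog") and outside index 1 ("a"), the result is the negated mean
# symmetric KL between each of "red"/"dog" and "a"; minimizing it pushes
# their attention away from unrelated tokens.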

def get_indices(tokenizer, prompt: str) -> Dict[int, str]:
    """Utility function to list the indices of the tokens you wish to alter."""
    ids = tokenizer(prompt).input_ids
    indices = {i: tok for i, tok in enumerate(tokenizer.convert_ids_to_tokens(ids))}
    return indices
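# Example (assuming a CLIP-style tokenizer): get_indices(tokenizer, "a red dog")
# -> {0: "<|startoftext|>", 1: "a</w>", 2: "red</w>", 3: "dog</w>",
#     4: "<|endoftext|>"}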

def get_attention_map_index_to_wordpiece(tokenizer, prompt):
    """Map each attention-map index to its wordpiece, stripping `</w>` suffixes."""
    attn_map_idx_to_wp = {}

    wordpieces2indices = get_indices(tokenizer, prompt)

    # Ignore `start_token` and `end_token`
    for i in list(wordpieces2indices.keys())[1:-1]:
        wordpiece = wordpieces2indices[i]
        wordpiece = wordpiece.replace("</w>", "")
        attn_map_idx_to_wp[i] = wordpiece

    return attn_map_idx_to_wp
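# Example (continuing the sketch above): for "a red dog" this yields
# {1: "a", 2: "red", 3: "dog"}; the special tokens at positions 0 and 4 are
# dropped and the "</w>" markers are stripped.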