Royir committed on
Commit
19cb368
1 Parent(s): cfaf86b

Update compute_loss.py

Browse files

enabling new syntax structures (e.g., 'an apple is blue')

Files changed (1) hide show
  1. compute_loss.py +71 -3
compute_loss.py CHANGED
@@ -142,8 +142,8 @@ def align_wordpieces_indices(
142
  return wp_indices
143
 
144
 
145
- def extract_attribution_indices(prompt, parser):
146
- doc = parser(prompt)
147
  subtrees = []
148
  modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
149
 
@@ -167,6 +167,74 @@ def extract_attribution_indices(prompt, parser):
167
  subtrees.append(subtree)
168
  return subtrees
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  def calculate_negative_loss(
172
  attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
@@ -187,7 +255,7 @@ def calculate_negative_loss(
187
  return negative_loss
188
 
189
  def get_indices(tokenizer, prompt: str) -> Dict[str, int]:
190
- """Utility function to list the indices of the tokens you wish to alte"""
191
  ids = tokenizer(prompt).input_ids
192
  indices = {
193
  i: tok
 
142
  return wp_indices
143
 
144
 
145
+ def extract_attribution_indices(doc):
146
+ # doc = parser(prompt)
147
  subtrees = []
148
  modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
149
 
 
167
  subtrees.append(subtree)
168
  return subtrees
169
 
170
def extract_attribution_indices_with_verbs(doc):
    """Collect noun/proper-noun heads together with modifiers that are
    linked to them through a verb or auxiliary token.

    Example: in "a dog that is red" the aux 'is' sits between 'dog' and
    'red'; the returned subtree pairs 'red' with 'dog' while leaving the
    verb/aux tokens out.
    """
    _MODIFIER_DEPS = ("amod", "nmod", "compound", "npadvmod", "advmod",
                      "acomp", "relcl")
    collected = []
    for head in doc:
        # Only start from noun-like heads that are not themselves modifiers.
        if head.pos_ in ("NOUN", "PROPN") and head.dep_ not in _MODIFIER_DEPS:
            group = []
            pending = []
            for child in head.children:
                if child.dep_ in _MODIFIER_DEPS:
                    # Skip verb/aux tokens themselves but still descend
                    # into their children.
                    if child.pos_ not in ("AUX", "VERB"):
                        group.append(child)
                    pending.extend(child.children)

            while pending:
                tok = pending.pop()
                if tok.dep_ in _MODIFIER_DEPS or tok.dep_ == "conj":
                    # We don't want to add 'is' or other verbs to the loss,
                    # we want their children.
                    if tok.pos_ not in ("AUX", "VERB"):
                        group.append(tok)
                    pending.extend(tok.children)

            if group:
                group.append(head)
                collected.append(group)
    return collected
200
+
201
def extract_attribution_indices_with_verb_root(doc):
    """Handle prompts whose syntactic root is an auxiliary verb.

    For instance, in "an apple is blue" the AUX 'is' is the parse root,
    with 'apple' (nsubj, NOUN) and 'blue' (acomp) as its direct children.
    (The previous docstring was copy-pasted from
    extract_attribution_indices_with_verbs and described the wrong case.)

    Returns a list of subtrees; each subtree is a list of tokens pairing
    a noun with its modifier(s). The AUX root itself is never included.
    """
    subtrees = []
    modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
    for w in doc:
        # Only auxiliary roots that are not themselves modifiers qualify.
        if w.pos_ != 'AUX' or w.dep_ in modifiers:
            continue

        subtree = []
        stack = []
        # Gather the AUX's direct noun and modifier children; deeper
        # descendants are explored via the stack.
        for child in w.children:
            if child.dep_ in modifiers or child.pos_ in ['NOUN', 'PROPN']:
                if child.pos_ not in ['AUX', 'VERB']:
                    subtree.append(child)
                stack.extend(child.children)

        # Did not find a pair of noun and modifier under this AUX.
        if len(subtree) < 2:
            continue

        while stack:
            node = stack.pop()
            if node.dep_ in modifiers or node.dep_ == "conj":
                # We don't want to add 'is' or other auxiliaries to the
                # loss, we want their children.
                if node.pos_ not in ['AUX']:
                    subtree.append(node)
                stack.extend(node.children)

        # `subtree` is guaranteed non-empty here (len >= 2 above), and the
        # AUX root `w` is deliberately excluded: the original
        # `if w.pos_ not in ['AUX']: subtree.append(w)` guard was
        # unreachable because the loop guard already ensures w.pos_ == 'AUX'.
        subtrees.append(subtree)
    return subtrees
238
 
239
  def calculate_negative_loss(
240
  attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
 
255
  return negative_loss
256
 
257
  def get_indices(tokenizer, prompt: str) -> Dict[str, int]:
258
+ """Utility function to list the indices of the tokens you wish to alter"""
259
  ids = tokenizer(prompt).input_ids
260
  indices = {
261
  i: tok