Spaces:
Sleeping
Sleeping
Update compute_loss.py
Browse filesenabling new syntax structures (e.g., 'an apple is blue')
- compute_loss.py +71 -3
compute_loss.py
CHANGED
@@ -142,8 +142,8 @@ def align_wordpieces_indices(
|
|
142 |
return wp_indices
|
143 |
|
144 |
|
145 |
-
def extract_attribution_indices(
|
146 |
-
doc = parser(prompt)
|
147 |
subtrees = []
|
148 |
modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
|
149 |
|
@@ -167,6 +167,74 @@ def extract_attribution_indices(prompt, parser):
|
|
167 |
subtrees.append(subtree)
|
168 |
return subtrees
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
def calculate_negative_loss(
|
172 |
attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
|
@@ -187,7 +255,7 @@ def calculate_negative_loss(
|
|
187 |
return negative_loss
|
188 |
|
189 |
def get_indices(tokenizer, prompt: str) -> Dict[str, int]:
|
190 |
-
"""Utility function to list the indices of the tokens you wish to
|
191 |
ids = tokenizer(prompt).input_ids
|
192 |
indices = {
|
193 |
i: tok
|
|
|
142 |
return wp_indices
|
143 |
|
144 |
|
145 |
+
def extract_attribution_indices(doc):
|
146 |
+
# doc = parser(prompt)
|
147 |
subtrees = []
|
148 |
modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
|
149 |
|
|
|
167 |
subtrees.append(subtree)
|
168 |
return subtrees
|
169 |
|
170 |
+
def extract_attribution_indices_with_verbs(doc):
|
171 |
+
'''This function specifically addresses cases where a verb is between
|
172 |
+
a noun and its modifier. For instance: "a dog that is red"
|
173 |
+
here, the aux is between 'dog' and 'red'. '''
|
174 |
+
|
175 |
+
subtrees = []
|
176 |
+
modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp",
|
177 |
+
'relcl']
|
178 |
+
for w in doc:
|
179 |
+
if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers:
|
180 |
+
continue
|
181 |
+
subtree = []
|
182 |
+
stack = []
|
183 |
+
for child in w.children:
|
184 |
+
if child.dep_ in modifiers:
|
185 |
+
if child.pos_ not in ['AUX', 'VERB']:
|
186 |
+
subtree.append(child)
|
187 |
+
stack.extend(child.children)
|
188 |
+
|
189 |
+
while stack:
|
190 |
+
node = stack.pop()
|
191 |
+
if node.dep_ in modifiers or node.dep_ == "conj":
|
192 |
+
# we don't want to add 'is' or other verbs to the loss, we want their children
|
193 |
+
if node.pos_ not in ['AUX', 'VERB']:
|
194 |
+
subtree.append(node)
|
195 |
+
stack.extend(node.children)
|
196 |
+
if subtree:
|
197 |
+
subtree.append(w)
|
198 |
+
subtrees.append(subtree)
|
199 |
+
return subtrees
|
200 |
+
|
201 |
+
def extract_attribution_indices_with_verb_root(doc):
|
202 |
+
'''This function specifically addresses cases where a verb is between
|
203 |
+
a noun and its modifier. For instance: "a dog that is red"
|
204 |
+
here, the aux is between 'dog' and 'red'. '''
|
205 |
+
|
206 |
+
subtrees = []
|
207 |
+
modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"]
|
208 |
+
for w in doc:
|
209 |
+
subtree = []
|
210 |
+
stack = []
|
211 |
+
|
212 |
+
# if w is a verb/aux and has a noun child and a modifier child, add them to the stack
|
213 |
+
if w.pos_ != 'AUX' or w.dep_ in modifiers:
|
214 |
+
continue
|
215 |
+
|
216 |
+
for child in w.children:
|
217 |
+
if child.dep_ in modifiers or child.pos_ in ['NOUN', 'PROPN']:
|
218 |
+
if child.pos_ not in ['AUX', 'VERB']:
|
219 |
+
subtree.append(child)
|
220 |
+
stack.extend(child.children)
|
221 |
+
# did not find a pair of noun and modifier
|
222 |
+
if len(subtree) < 2:
|
223 |
+
continue
|
224 |
+
|
225 |
+
while stack:
|
226 |
+
node = stack.pop()
|
227 |
+
if node.dep_ in modifiers or node.dep_ == "conj":
|
228 |
+
# we don't want to add 'is' or other verbs to the loss, we want their children
|
229 |
+
if node.pos_ not in ['AUX']:
|
230 |
+
subtree.append(node)
|
231 |
+
stack.extend(node.children)
|
232 |
+
|
233 |
+
if subtree:
|
234 |
+
if w.pos_ not in ['AUX']:
|
235 |
+
subtree.append(w)
|
236 |
+
subtrees.append(subtree)
|
237 |
+
return subtrees
|
238 |
|
239 |
def calculate_negative_loss(
|
240 |
attention_maps, modifier, noun, subtree_indices, attn_map_idx_to_wp
|
|
|
255 |
return negative_loss
|
256 |
|
257 |
def get_indices(tokenizer, prompt: str) -> Dict[str, int]:
|
258 |
+
"""Utility function to list the indices of the tokens you wish to alter"""
|
259 |
ids = tokenizer(prompt).input_ids
|
260 |
indices = {
|
261 |
i: tok
|