# Source snapshot: uploaded by JustinLin610, commit 10b0761 ("update"), 8.35 kB.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
from functools import lru_cache
def convert_sentence_to_json(sentence):
    """Convert an annotated WSC sentence into a SuperGLUE-style sample dict.

    The candidate (query) is wrapped in underscores and the pronoun in square
    brackets, e.g.::

        "The _trophy_ doesn't fit into the brown suitcase because [it] is too large."

    The query markup is optional; the pronoun markup is required.

    Span indices count tokens of the cleaned text split on single spaces
    (``text.split(" ")``), so a span at the very start of the sentence gets
    index 0.

    Args:
        sentence: the annotated sentence.

    Returns:
        dict with ``idx`` (always 0), ``text`` (all ``_``, ``[`` and ``]``
        removed) and ``target`` holding ``span1_index``/``span1_text`` (the
        query, both ``None`` when there is no underscore markup) and
        ``span2_index``/``span2_text`` (the pronoun).

    Raises:
        ValueError: if the underscore or bracket markup is incomplete.
    """

    def token_index(prefix):
        # Index of the space-separated token that begins right after
        # ``prefix``. The previous ``len(prefix.rstrip().split(" "))`` gave 1
        # for an empty prefix, mis-indexing spans at the start of the sentence.
        return len(prefix.split(" ")) - 1

    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, _ = rest.split("_", 1)
        query_index = token_index(prefix)
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, _ = rest.split("]", 1)
    # The markup characters contain no spaces, so counting on the annotated
    # prefix yields the same index as counting on the cleaned text.
    pronoun_index = token_index(prefix)

    text = sentence.replace("_", "").replace("[", "").replace("]", "")
    return {
        "idx": 0,
        "text": text,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }
def extended_noun_chunks(sentence):
    """Return the parser's noun chunks plus maximal runs of nouns.

    A "run" is a maximal sequence of consecutive tokens that all have
    ``pos_ == "NOUN"`` or all have ``pos_ == "PROPN"``; switching between the
    two tags ends one run and starts another. Duplicate ``(start, end)``
    ranges are merged, and the result is ordered by ``(start, end)``.

    NOTE(review): chunk offsets from ``sentence.noun_chunks`` are used to
    slice ``sentence`` directly, which presumes ``sentence`` is a whole parsed
    document (or a span starting at 0) — confirm with callers.

    Args:
        sentence: parsed sentence exposing ``noun_chunks``, token iteration
            (tokens with ``pos_``), ``len`` and slicing.

    Returns:
        list of ``sentence[start:end]`` slices.
    """
    spans = {(chunk.start, chunk.end) for chunk in sentence.noun_chunks}
    run_start, run_type = 0, "NONE"
    for position, token in enumerate(sentence):
        kind = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if kind == run_type:
            continue
        # The current run (if any) ends just before this token.
        if run_type != "NONE":
            spans.add((run_start, position))
        if kind != "NONE":
            run_start = position
        run_type = kind
    # Close a run that extends to the end of the sentence.
    if run_type != "NONE":
        spans.add((run_start, len(sentence)))
    return [sentence[begin:end] for begin, end in sorted(spans)]
def find_token(sentence, start_pos):
    """Return the first token whose character offset ``idx`` equals ``start_pos``.

    Args:
        sentence: iterable of tokens exposing an ``idx`` attribute.
        start_pos: character offset to look for.

    Returns:
        The first matching token, or ``None`` if no token starts exactly at
        ``start_pos``.
    """
    return next((candidate for candidate in sentence if candidate.idx == start_pos), None)
def find_span(sentence, search_text, start=0):
    """Find the first token span whose text equals ``search_text`` (case-insensitive).

    A match must begin at a token boundary and end exactly at a token
    boundary. Candidate start tokens are taken from ``sentence[start:]``.

    Args:
        sentence: parsed document supporting slicing, with tokens exposing
            ``i`` (token index), ``idx`` (character offset) and ``text``.
        search_text: text to look for; compared lower-cased.
        start: token index from which candidate start tokens are considered.

    Returns:
        ``sentence[first:last + 1]`` for the first match, or ``None``.
    """
    search_text = search_text.lower()
    target_len = len(search_text)
    for tok in sentence[start:]:
        remainder = sentence[tok.i :].text.lower()
        if not remainder.startswith(search_text):
            continue
        span_start = tok.idx
        for next_tok in sentence[tok.i :]:
            consumed = next_tok.idx + len(next_tok.text) - span_start
            if consumed == target_len:
                return sentence[tok.i : next_tok.i + 1]
            if consumed > target_len:
                # The match ends mid-token. Later tokens only extend the span,
                # so no exact boundary can follow; try the next start token.
                break
    return None
@lru_cache(maxsize=1)
def get_detokenizer():
    """Return a shared English Moses detokenizer.

    ``sacremoses`` is imported inside the function, so importing this module
    does not require it. ``lru_cache(maxsize=1)`` makes every call after the
    first return the same instance.
    """
    from sacremoses import MosesDetokenizer

    return MosesDetokenizer(lang="en")
@lru_cache(maxsize=1)
def get_spacy_nlp():
    """Load the ``en_core_web_lg`` pipeline once and return it on every call.

    The model package is imported lazily, so importing this module does not
    require it. ``lru_cache(maxsize=1)`` memoizes the loaded pipeline.
    """
    import en_core_web_lg

    return en_core_web_lg.load()
def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    """Yield parsed WSC-style examples from a JSON-lines file.

    Each non-blank line must be a JSON object with a ``"text"`` field, a
    ``"target"`` dict holding ``span1_text`` (the query, may be ``None``),
    ``span2_index`` and ``span2_text`` (the pronoun), and optionally a
    ``"label"``. The text is split on single spaces, detokenized with the
    Moses detokenizer and re-parsed with the spaCy pipeline, and the pronoun
    is located in that parse.

    Args:
        input_fname: path to the JSON-lines file (read as UTF-8).
        positive_only: skip samples whose ``"label"`` is present and falsy.
        ngram_order: not used by this function; kept for interface
            compatibility.
        eval: if true, yield the sentence re-rendered as a string with the
            query wrapped in ``_..._`` and the pronoun in ``[...]``.

    Yields:
        ``(sentence_str, label)`` when ``eval`` is true, otherwise
        ``(parsed_sentence, pronoun_span, query, label)``. ``label`` is
        ``None`` when the sample has none. Samples whose query contains a
        newline are skipped.

    Raises:
        ValueError: if the pronoun cannot be aligned with ``span2_index`` or
            located in the parse, or (in eval mode) the query is missing or
            cannot be located.
    """

    def strip_pronoun(x):
        # Trailing '.', ',' and '"' may be glued to the pronoun token.
        return x.rstrip('.,"')

    detok = get_detokenizer()
    nlp = get_spacy_nlp()
    with open(input_fname, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            sample = json.loads(line)
            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue

            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith((".", ",")):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned by one; guard the
                # lookahead so a last-token pronoun reports misalignment
                # rather than raising IndexError
                next_idx = pronoun_idx + 1
                if next_idx < len(tokens) and strip_pronoun(tokens[next_idx]) == pronoun:
                    pronoun_idx = next_idx
                else:
                    raise ValueError("Misaligned pronoun!")

            # split tokens before and after the pronoun
            # NOTE(review): any '.', ',' or '"' glued to tokens[pronoun_idx] is
            # in neither ``before`` nor ``after``, so it is dropped from the
            # rebuilt sentence. Preserved as-is; confirm whether intended.
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            # NOTE(review): ``pronoun`` was already rstripped of '.,"' above,
            # so this presumably fires only if detokenization adds one.
            if pronoun.endswith((".", ",")):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith((".", ",")):
                trailing_space = ""

            # parse sentence with spacy
            text = before + leading_space + pronoun + trailing_space + after
            sentence = nlp(text)

            # find pronoun span
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            if first_pronoun_tok is None:
                raise ValueError(
                    "No token starts at pronoun offset {} in {!r}".format(start, text)
                )
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            if pronoun_span is None:
                raise ValueError(
                    "Pronoun {!r} not found in {!r}".format(pronoun, text)
                )
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                if query is None:
                    raise ValueError(
                        "eval mode requires span1_text; got None for {!r}".format(text)
                    )
                query_span = find_span(sentence, query)
                if query_span is None:
                    raise ValueError(
                        "Query {!r} not found in {!r}".format(query, text)
                    )
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)
def winogrande_jsonl_iterator(input_fname, eval=False):
    """Yield WinoGrande examples from a JSON-lines file.

    Each non-blank line must be a JSON object with ``sentence`` (containing a
    ``_`` blank), ``option1`` and ``option2``. Outside eval mode it must also
    carry ``answer``, where ``"1"`` (or the integer ``1``) selects
    ``option1`` and any other value selects ``option2``.

    Args:
        input_fname: path to the JSON-lines file (read as UTF-8).
        eval: if true, ``answer`` is not read and the options are yielded in
            file order (``option1`` as query, ``option2`` as candidate).

    Yields:
        ``(sentence, pronoun_span, query, cand)`` where ``pronoun_span`` is the
        ``(start, end)`` character range of the first ``_``. Outside eval
        mode ``query`` is the answer option and ``cand`` the other one.

    Raises:
        ValueError: if a sentence contains no ``_``.
    """
    with open(input_fname, encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at EOF)
                continue
            sample = json.loads(line)
            sentence = sample["sentence"]
            option1, option2 = sample["option1"], sample["option2"]
            blank = sentence.index("_")
            pronoun_span = (blank, blank + 1)
            if eval:
                query, cand = option1, option2
            # str() so an integer answer of 1 is not silently read as option 2
            elif str(sample["answer"]) == "1":
                query, cand = option1, option2
            else:
                query, cand = option2, option1
            yield sentence, pronoun_span, query, cand
def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    """Drop pronoun chunks and/or chunks that overlap a query string.

    Args:
        chunks: sequence of chunks exposing ``text``, ``lemma_`` and
            iteration over tokens with ``pos_``.
        exclude_pronouns: drop chunks whose ``lemma_`` is ``"-PRON-"`` or whose
            tokens are all tagged ``PRON``.
        exclude_query: if not ``None``, drop chunks whose lower-cased text
            equals the lower-cased query or, unless ``exact_match``, is a
            substring of it or contains it.
        exact_match: restrict query exclusion to exact (case-insensitive)
            equality.

    Returns:
        The filtered chunks. When no filter is active, ``chunks`` itself is
        returned unchanged.
    """

    def is_pronoun_chunk(chunk):
        return chunk.lemma_ == "-PRON-" or all(tok.pos_ == "PRON" for tok in chunk)

    def overlaps_query(chunk_text, query_text):
        if chunk_text == query_text:
            return True
        if exact_match:
            return False
        return chunk_text in query_text or query_text in chunk_text

    if exclude_pronouns:
        chunks = [chunk for chunk in chunks if not is_pronoun_chunk(chunk)]
    if exclude_query is None:
        return chunks
    query_text = exclude_query.lower()
    return [
        chunk
        for chunk in chunks
        if not overlaps_query(chunk.text.lower(), query_text)
    ]