# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
from functools import lru_cache


def convert_sentence_to_json(sentence):
    """Convert a marked-up WSC sentence (query wrapped in underscores,
    pronoun wrapped in square brackets) into a SuperGLUE-style dict."""
    if "_" in sentence:
        prefix, rest = sentence.split("_", 1)
        query, rest = rest.split("_", 1)
        query_index = len(prefix.rstrip().split(" "))
    else:
        query, query_index = None, None

    prefix, rest = sentence.split("[", 1)
    pronoun, rest = rest.split("]", 1)
    pronoun_index = len(prefix.rstrip().split(" "))

    sentence = sentence.replace("_", "").replace("[", "").replace("]", "")

    return {
        "idx": 0,
        "text": sentence,
        "target": {
            "span1_index": query_index,
            "span1_text": query,
            "span2_index": pronoun_index,
            "span2_text": pronoun,
        },
    }


def extended_noun_chunks(sentence):
    """Return spaCy's noun chunks plus maximal runs of NOUN/PROPN tokens."""
    noun_chunks = {(np.start, np.end) for np in sentence.noun_chunks}
    np_start, cur_np = 0, "NONE"
    for i, token in enumerate(sentence):
        np_type = token.pos_ if token.pos_ in {"NOUN", "PROPN"} else "NONE"
        if np_type != cur_np:
            if cur_np != "NONE":
                noun_chunks.add((np_start, i))
            if np_type != "NONE":
                np_start = i
            cur_np = np_type
    if cur_np != "NONE":
        noun_chunks.add((np_start, len(sentence)))
    return [sentence[s:e] for (s, e) in sorted(noun_chunks)]


def find_token(sentence, start_pos):
    """Return the token whose character offset is start_pos, or None."""
    found_tok = None
    for tok in sentence:
        if tok.idx == start_pos:
            found_tok = tok
            break
    return found_tok


def find_span(sentence, search_text, start=0):
    """Return the first span at or after token index `start` whose text
    matches search_text (case-insensitive), or None."""
    search_text = search_text.lower()
    for tok in sentence[start:]:
        remainder = sentence[tok.i :].text.lower()
        if remainder.startswith(search_text):
            len_to_consume = len(search_text)
            start_idx = tok.idx
            for next_tok in sentence[tok.i :]:
                end_idx = next_tok.idx + len(next_tok.text)
                if end_idx - start_idx == len_to_consume:
                    span = sentence[tok.i : next_tok.i + 1]
                    return span
    return None


@lru_cache(maxsize=1)
def get_detokenizer():
    from sacremoses import MosesDetokenizer

    detok = MosesDetokenizer(lang="en")
    return detok


@lru_cache(maxsize=1)
def get_spacy_nlp():
    import en_core_web_lg

    nlp = en_core_web_lg.load()
    return nlp


def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue
            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    continue
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # find pronoun span
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)
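
# Usage sketch (illustrative only, not part of the original API): the iterator
# yields different tuples depending on `eval`. The file names below are
# hypothetical; the jsonl schema is the SuperGLUE WSC format consumed above.
#
#     for masked_sentence, label in jsonl_iterator("val.jsonl", eval=True):
#         ...  # masked_sentence is a plain string such as
#              # "The city refused the _demonstrators_ a permit because
#              #  [they] feared violence."
#
#     for doc, pronoun_span, query, label in jsonl_iterator("train.jsonl"):
#         ...  # doc is a spaCy Doc, pronoun_span a spaCy Span over the
#              # pronoun, query the candidate antecedent text (or None)
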
def winogrande_jsonl_iterator(input_fname, eval=False):
    """Iterate over Winogrande jsonl records, yielding the raw sentence, the
    character span of the "_" placeholder, the query and the candidate."""
    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())
            sentence, option1, option2 = (
                sample["sentence"],
                sample["option1"],
                sample["option2"],
            )

            pronoun_span = (sentence.index("_"), sentence.index("_") + 1)

            if eval:
                query, cand = option1, option2
            else:
                query = option1 if sample["answer"] == "1" else option2
                cand = option2 if sample["answer"] == "1" else option1
            yield sentence, pronoun_span, query, cand


def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    """Filter a list of spaCy spans, optionally dropping pronoun chunks and
    chunks that (partially) match the query text."""
    if exclude_pronouns:
        chunks = [
            np
            for np in chunks
            if (np.lemma_ != "-PRON-" and not all(tok.pos_ == "PRON" for tok in np))
        ]

    if exclude_query is not None:
        excl_txt = [exclude_query.lower()]
        filtered_chunks = []
        for chunk in chunks:
            lower_chunk = chunk.text.lower()
            found = False
            for excl in excl_txt:
                if (
                    not exact_match and (lower_chunk in excl or excl in lower_chunk)
                ) or lower_chunk == excl:
                    found = True
                    break
            if not found:
                filtered_chunks.append(chunk)
        chunks = filtered_chunks

    return chunks
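

if __name__ == "__main__":
    # Minimal smoke test: a hedged sketch, not part of the original module.
    # The example sentence is hypothetical and uses the markup expected by
    # convert_sentence_to_json (query in underscores, pronoun in square
    # brackets). The candidate-extraction step requires the spaCy
    # en_core_web_lg model to be installed.
    example = "The _trophy_ does not fit in the suitcase because [it] is too big."
    sample = convert_sentence_to_json(example)
    assert sample["target"]["span1_text"] == "trophy"
    assert sample["target"]["span2_index"] == 9  # 0-based token index of "it"

    # extract and filter candidate antecedents from the cleaned text
    doc = get_spacy_nlp()(sample["text"])
    candidates = filter_noun_chunks(
        extended_noun_chunks(doc),
        exclude_pronouns=True,
        exclude_query=sample["target"]["span1_text"],
    )
    print([span.text for span in candidates])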