#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing
from joblib import Parallel, delayed
import json
import argparse
from transformers import AutoTokenizer, AutoModel
import spacy
import re
from convert_common import read_stopwords, SpacyTextParser, get_retokenized
from pyserini.analysis import Analyzer, get_lucene_analyzer
import time
import os
"""
add fields to jsonl with text(lemmatized), text_unlemm, contents(analyzer), raw, text_bert_tok(BERT token)
"""
parser = argparse.ArgumentParser(description='Convert MSMARCO-adhoc documents.')
parser.add_argument('--input', metavar='input file', help='input file',
                    type=str, required=True)
parser.add_argument('--input-format', metavar='input format', help='input format',
                    type=str, default='passage')
parser.add_argument('--output', metavar='output file', help='output file',
                    type=str, required=True)
parser.add_argument('--max_doc_size', metavar='max doc size bytes',
                    help='threshold for the document size; documents larger than this are truncated',
                    type=int, default=16536)
parser.add_argument('--proc_qty', metavar='# of processes', help='# of NLP processes to spawn',
                    type=int, default=16)  # multiprocessing.cpu_count() - 2
args = parser.parse_args()
print(args)
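# Example invocation (paths are illustrative; run from the repository root so the
# stopwords path used in batch_process resolves):
#   python scripts/ltr_msmarco/convert_passage_doc.py \
#       --input collections/msmarco-passage/collection_jsonl/docs00.json \
#       --output collections/msmarco-passage/ltr_collection/docs00.json \
#       --proc_qty 16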
arg_vars = vars(args)
inpFile = open(args.input)
outFile = open(args.output, 'w')
maxDocSize = args.max_doc_size
def batch_file(iterable, n=10000):
    """Yield successive batches of up to n lines from an iterable."""
    batch = []
    for line in iterable:
        batch.append(line)
        if len(batch) == n:
            yield batch
            batch = []
    if len(batch) > 0:
        yield batch
def batch_process(batch):
    # Initialize the NLP resources inside the worker so each joblib process gets its own copies.
    # The script is assumed to be called from the repository root (see the stopwords path below).
    stopwords = read_stopwords('./scripts/ltr_msmarco/stopwords.txt', lower_case=True)
    nlp = SpacyTextParser('en_core_web_sm', stopwords, keep_only_alpha_num=True, lower_case=True)
    analyzer = Analyzer(get_lucene_analyzer())
    bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    def process(line):
        if not line:
            return None
        json_line = json.loads(line)
        pid = json_line['id']
        body = json_line['contents']
        # url = json_line['url']
        # title = json_line['title']
        text, text_unlemm = nlp.proc_text(body)
        # _, title_unlemm = nlp.proc_text(title)
        analyzed = analyzer.analyze(body)
        for token in analyzed:
            assert ' ' not in token
        contents = ' '.join(analyzed)
        doc = {"id": pid,
               "text": text,
               "text_unlemm": text_unlemm,
               "contents": contents,
               # "title_unlemm": title_unlemm,
               # "url": url,
               "raw": body}
        # Truncate long passages to the first 512 characters before BERT tokenization;
        # shorter passages are unaffected by the slice.
        doc["text_bert_tok"] = get_retokenized(bert_tokenizer, body.lower()[:512])
        return doc
    res = []
    start = time.time()
    for line in batch:
        res.append(process(line))
        if len(res) % 10000 == 0:
            end = time.time()
            print(f'finished {len(res)} lines in {end - start:.1f}s')
            start = end
    return res
if __name__ == '__main__':
    proc_qty = args.proc_qty
    print(f'Spawning {proc_qty} processes')
    pool = Parallel(n_jobs=proc_qty, verbose=10)
    ln = 0
    for batch_json in pool([delayed(batch_process)(batch) for batch in batch_file(inpFile)]):
        for docJson in batch_json:
            ln = ln + 1
            if docJson is not None:
                outFile.write(json.dumps(docJson) + '\n')
            else:
                print('Ignoring misformatted line %d' % ln)
            if ln % 100 == 0:
                print('Processed %d passages' % ln)
    print('Processed %d passages' % ln)
    inpFile.close()
    outFile.close()