resrer-pegasus-x / utils.py
seonglae's picture
Training in progress, step 500
a9082f6
raw
history blame
2.03 kB
from typing import TypedDict, List
from tiktoken import Encoding
class Row(TypedDict):
    """Dictionary shape shared by source documents and the passages cut from them."""

    # Identifier of the record.
    id: str
    # Title carried along with the record.
    title: str
    # URL carried along with the record.
    url: str
    # Text content of the record.
    text: str
def _starts_with_space(encoder: Encoding, token: int) -> bool:
    """Return True when *token* decodes to bytes that begin with a space."""
    return encoder.decode_single_token_bytes(token).startswith(b' ')


def _word_start(encoder: Encoding, tokens: List[int], start: int) -> int:
    """Move *start* left to the closest token that begins with a space.

    The scan stops at index 0 when no such token is found.
    """
    while start > 0 and not _starts_with_space(encoder, tokens[start]):
        start -= 1
    return start


def _word_end(encoder: Encoding, tokens: List[int], end: int) -> int:
    """Move the exclusive bound *end* right until the next space-led token.

    The scan stops at ``len(tokens)`` when no such token is found.
    """
    while end < len(tokens) and not _starts_with_space(encoder, tokens[end]):
        end += 1
    return end


def split_token(encoder: Encoding, rows: List[Row], input_texts: List[str], split: int = 512) -> List[Row]:
    """Cut each input text into passages of roughly ``split`` tokens.

    Every text is tokenized with ``encoder.encode_batch``. For a document of
    ``n`` tokens, ``(n - 1) // split`` windows of ``split`` tokens are taken
    from the start, followed by one final window covering the last ``split``
    tokens (which may overlap the previous window). Each window is widened so
    that it does not start or end in the middle of a word, where a token whose
    bytes begin with a space is treated as the start of a word:

    * the end is pushed right up to (not including) the next space-led token;
    * the start is pushed left to (and including) the previous space-led
      token, unless the window already starts on one.

    Args:
        encoder: Tokenizer providing ``encode_batch``,
            ``decode_single_token_bytes`` and ``decode``.
        rows: Source records; ``rows[i]`` supplies ``id``, ``title`` and
            ``url`` for ``input_texts[i]``.
        input_texts: Texts to split, aligned with ``rows`` by position.
        split: Nominal passage length in tokens. Must be positive.

    Returns:
        One record per passage, in document order. The ``id`` is
        ``"<row id>_<passage index>"`` and ``text`` is the decoded passage.
        A document that tokenizes to nothing yields a single passage with
        empty text.

    Raises:
        ValueError: If ``split`` is not positive.
        IndexError: If ``rows`` is shorter than ``input_texts``.
    """
    if split <= 0:
        raise ValueError(f"split must be a positive integer, got {split}")

    passages: List[Row] = []
    for doc_index, doc_tokens in enumerate(encoder.encode_batch(input_texts)):
        row = rows[doc_index]
        total = len(doc_tokens)
        # Number of full windows taken from the start; the remainder (at
        # least one token for a non-empty document) is left to the tail.
        passages_count = max(total - 1, 0) // split

        # Passages from start, as (start, end) index pairs into doc_tokens.
        spans = [
            (
                _word_start(encoder, doc_tokens, number * split),
                _word_end(encoder, doc_tokens, (number + 1) * split),
            )
            for number in range(passages_count)
        ]
        # Passage from end: align on the first token of the tail window
        # itself, so a tail that already starts a word is not extended.
        tail_start = _word_start(encoder, doc_tokens, max(total - split, 0))
        spans.append((tail_start, total))

        for number, (start, end) in enumerate(spans):
            passages.append({
                'id': f"{row['id']}_{number}",
                'title': row['title'],
                'url': row['url'],
                'text': encoder.decode(doc_tokens[start:end]),
            })
    return passages