Spaces:
Running
on
Zero
Running
on
Zero
import pandas as pd | |
from pydantic import ConfigDict | |
from bytelatent.data.data_types import BltExample | |
from bytelatent.data.iterators.abstract_iterator import ( | |
PydanticIteratorState, | |
StatefulIterator, | |
) | |
class BltTestIteratorState(PydanticIteratorState):
    """Serializable snapshot of a ``BltTestIterator``.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")
    # Value of the iterator's ``position`` counter when the state was taken.
    position: int
    # The ``total`` the iterator was constructed with.
    total: int

    def build(self) -> "BltTestIterator":
        """Recreate the iterator described by this state.

        Returns:
            A ``BltTestIterator`` built with this state's ``total`` and with
            its ``position`` attribute set to this state's ``position``.
        """
        # Build the iterator rather than another state object: the state
        # declares ``position`` without a default, so constructing the state
        # from ``total`` alone cannot produce a usable object.
        blt_iter = BltTestIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
class BltTestIterator(StatefulIterator):
    """Test iterator producing ``total`` synthetic, text-only examples."""

    def __init__(self, total: int):
        """Store the number of examples to emit and reset the counter."""
        self.total = total
        self.position = 0

    def get_state(self):
        """Return a ``BltTestIteratorState`` with the current counters."""
        snapshot = BltTestIteratorState(total=self.total, position=self.position)
        return snapshot

    def create_iter(self):
        """Yield one ``BltExample`` per index in ``range(self.total)``.

        ``self.position`` is incremented before each example is built and
        yielded. Only ``sample_id`` and ``text`` are populated; every other
        field is ``None``.
        """
        for index in range(self.total):
            self.position += 1
            example = BltExample(
                sample_id=f"test_{index}",
                text=f"This is some test {index} text.",
                tokens=None,
                mask=None,
                entropies=None,
                patch_lengths=None,
            )
            yield example
class BltTestWithEntropiesIteratorState(PydanticIteratorState):
    """Serializable snapshot of a ``BltTestWithEntropiesIterator``.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")
    # Value of the iterator's ``position`` counter when the state was taken.
    position: int
    # The ``total`` the iterator was constructed with.
    total: int

    def build(self) -> "BltTestWithEntropiesIterator":
        """Recreate the iterator described by this state.

        Returns:
            A ``BltTestWithEntropiesIterator`` built with this state's
            ``total`` and with its ``position`` attribute set to this
            state's ``position``.
        """
        # Build the iterator rather than another state object: the state
        # declares ``position`` without a default, so constructing the state
        # from ``total`` alone cannot produce a usable object.
        blt_iter = BltTestWithEntropiesIterator(total=self.total)
        blt_iter.position = self.position
        return blt_iter
class BltTestWithEntropiesIterator(StatefulIterator):
    """Test iterator yielding a fixed example with tokens and entropies.

    Every yielded ``BltExample`` carries the same text, token ids and
    entropies, read from ``fixtures/tokens_with_entropies.json``; only
    ``sample_id`` differs between examples.
    """

    def __init__(self, total: int):
        """Store the number of examples to emit and reset the counter."""
        self.position = 0
        self.total = total

    def get_state(self):
        """Return a ``BltTestWithEntropiesIteratorState`` for this iterator."""
        # Use this iterator's own state class, so the state is tied to the
        # with-entropies iterator and not to the text-only ``BltTestIterator``.
        return BltTestWithEntropiesIteratorState(
            position=self.position, total=self.total
        )

    def create_iter(self):
        """Yield ``self.total`` copies of the fixture-backed example.

        The fixture path is relative, so it is resolved against the current
        working directory. ``self.position`` is incremented before each
        example is yielded.

        Raises:
            AssertionError: if the fixture does not hold exactly
                ``len(text) + 2`` tokens.
        """
        text = "Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin."
        df = pd.read_json("fixtures/tokens_with_entropies.json")
        tokens = df["token_ids"].tolist()
        entropies = df["entropies"].tolist()
        # One token per character of ``text`` plus two extra: BOS and EOS.
        assert len(tokens) == len(text) + 2, (
            f"expected {len(text) + 2} tokens in fixture, got {len(tokens)}"
        )
        for i in range(self.total):
            self.position += 1
            yield BltExample(
                sample_id=f"test_{i}",
                text=text,
                tokens=tokens,
                # Every token is marked True in the mask.
                mask=[True] * len(tokens),
                entropies=entropies,
                patch_lengths=None,
            )