Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
import abc
from typing import Any, Generator, Generic, Tuple, TypeVar

import pydantic
T = TypeVar("T")  # type of the items yielded by create_iter()
C = TypeVar("C")  # type of the state object returned by get_state()


class StatefulIterator(Generic[T, C], abc.ABC):
    """Abstract base for iterators that can report their state.

    Both methods are declared with ``abc.abstractmethod``. Because this class
    inherits ``abc.ABC``, a subclass that does not implement both of them
    cannot be instantiated (``TypeError``). This replaces silently returning
    ``None`` at call time.
    """

    @abc.abstractmethod
    def get_state(self) -> C:
        """Return a state object (of type ``C``) describing this iterator."""
        raise NotImplementedError

    @abc.abstractmethod
    def create_iter(self) -> Generator[T, Any, None]:
        """Return a generator that yields items of type ``T``."""
        raise NotImplementedError
class IteratorState(Generic[C]):
    """Base for state objects that can build a ``StatefulIterator``."""

    # NOTE: this class does not inherit abc.ABC, so the abstractmethod marker
    # is only enforced at instantiation time when a concrete subclass's
    # metaclass is abc.ABCMeta (or derived from it).
    @abc.abstractmethod
    def build(self) -> StatefulIterator[Any, C]:
        """Return a ``StatefulIterator`` constructed from this state.

        The item type is ``Any`` because ``IteratorState`` is generic only in
        the state type ``C``. The previous annotation used ``T``, which is not
        bound by this class.
        """
        raise NotImplementedError
class PydanticIteratorState(pydantic.BaseModel, IteratorState):
    # IteratorState whose fields are declared as a pydantic model.
    # extra="forbid" makes validation fail on unknown fields instead of
    # silently ignoring them (pydantic's default), so state data carrying
    # unexpected keys is rejected loudly.
    model_config = pydantic.ConfigDict(extra="forbid")
def get_state_and_refresh(
    iterator: StatefulIterator,
) -> Tuple[IteratorState, StatefulIterator, Generator[Any, Any, None]]:
    """Capture ``iterator``'s state and rebuild a fresh iterator from it.

    Re-initializing the data loader and the Python iterator is necessary
    because ``get_state()`` on a multiprocessing iterator shuts down
    multiprocessing in order to persist state correctly. It must be
    restarted before iteration can continue.

    Args:
        iterator: The iterator to snapshot. Its ``get_state()`` is expected
            to return an object with a ``build()`` method (an
            ``IteratorState``).

    Returns:
        A ``(state, data_loader, py_iterator)`` tuple containing:

        - the captured state;
        - the ``StatefulIterator`` returned by ``state.build()``;
        - a new generator from that iterator's ``create_iter()``.

        Presumably callers should continue with these returned objects
        rather than the original ``iterator``.
    """
    state = iterator.get_state()
    data_loader = state.build()
    py_iterator = data_loader.create_iter()
    return state, data_loader, py_iterator