HF-SillyTavern-Extras / talkinghead /tha3 /compute /cached_computation_protocol.py
TomatoCocotree
上传
6a62ffb
raw
history blame
1.57 kB
from abc import ABC, abstractmethod
from typing import Dict, List
from torch import Tensor
from torch.nn import Module
from tha3.compute.cached_computation_func import TensorCachedComputationFunc, TensorListCachedComputationFunc
class CachedComputationProtocol(ABC):
def get_output(self,
key: str,
modules: Dict[str, Module],
batch: List[Tensor],
outputs: Dict[str, List[Tensor]]):
if key in outputs:
return outputs[key]
else:
output = self.compute_output(key, modules, batch, outputs)
outputs[key] = output
return outputs[key]
@abstractmethod
def compute_output(self,
key: str,
modules: Dict[str, Module],
batch: List[Tensor],
outputs: Dict[str, List[Tensor]]) -> List[Tensor]:
pass
def get_output_tensor_func(self, key: str, index: int) -> TensorCachedComputationFunc:
def func(modules: Dict[str, Module],
batch: List[Tensor],
outputs: Dict[str, List[Tensor]]):
return self.get_output(key, modules, batch, outputs)[index]
return func
def get_output_tensor_list_func(self, key: str) -> TensorListCachedComputationFunc:
def func(modules: Dict[str, Module],
batch: List[Tensor],
outputs: Dict[str, List[Tensor]]):
return self.get_output(key, modules, batch, outputs)
return func