import gc
import os
from asyncio import sleep
from typing import AsyncGenerator

from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS


class Answerer:
    """Wraps an RWKV language model and streams chat completions as an
    async generator of progressively longer result strings."""

    def __init__(self, model: str, vocab: str, strategy: str, ctx_limit: int):
        """Load the model weights and tokenizer pipeline.

        :param model: model file stem; weights are read from ``models/{model}.pth``
        :param vocab: vocabulary spec handed to ``PIPELINE``
        :param strategy: RWKV device/precision strategy string (e.g. ``"cuda fp16"``)
        :param ctx_limit: maximum number of prompt tokens fed to the model
        """
        os.environ["RWKV_JIT_ON"] = "1"
        # os.environ["RWKV_CUDA_ON"] = "1"  # uncomment to build the CUDA kernel
        self.__model = RWKV(f"models/{model}.pth", strategy=strategy)
        self.__pipeline = PIPELINE(self.__model, vocab)
        self.ctx_limit = ctx_limit

    async def __call__(
        self,
        input: str,
        max_output_length_tk: int,
        chaos: float = .1,
        repetitiveness: float = .3,
        diversity: float = 0,
        _count_penalty: float = 1,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion for ``input``, yielding the accumulated text
        after each decodable chunk.

        The friendly parameter names map onto the usual sampling knobs:
        ``chaos`` -> temperature, ``repetitiveness`` -> top_p,
        ``diversity`` -> presence penalty, ``_count_penalty`` -> frequency
        penalty.  Generation stops after ``max_output_length_tk`` tokens, on
        the stop token (0), or when the model begins a new "\n\nUser:" turn.
        """
        args = PIPELINE_ARGS(
            temperature=chaos,
            top_p=repetitiveness,
            alpha_frequency=_count_penalty,
            alpha_presence=diversity,
            token_ban=[],
            token_stop=[0],  # token 0 terminates generation
        )

        input = input.strip()

        result = ""
        # token id -> decayed occurrence count used for repetition penalties
        # (values become floats after the 0.996 decay below)
        occurrences: dict[int, float] = {}
        # sampled token ids not yet flushed into `result` as valid text
        tokens: list[int] = []
        current_token: int | None = None
        state = None
        # Pre-bind so the final `del` cleanup cannot raise NameError when the
        # loop runs zero iterations or breaks before assigning these.
        out = None
        tmp = ""

        for _ in range(max_output_length_tk):
            # First step feeds the (truncated) prompt; afterwards only the
            # newly sampled token is fed, with the recurrent state carrying
            # the context.  Compare against None — a falsy token id such as 0
            # must not re-feed the whole prompt.
            out, state = self.__model.forward(
                [current_token]
                if current_token is not None
                else self.__pipeline.encode(input)[-self.ctx_limit:],
                state,
            )

            # Apply presence/frequency penalties to already-seen tokens.
            for token, count in occurrences.items():
                out[token] -= args.alpha_presence + count * args.alpha_frequency

            current_token = self.__pipeline.sample_logits(
                out,
                temperature=args.temperature,
                top_p=args.top_p,
            )
            if current_token in args.token_stop:
                break
            tokens.append(current_token)

            # Exponentially decay old counts so the penalty fades with distance.
            for token in occurrences:
                occurrences[token] *= 0.996
            occurrences[current_token] = occurrences.get(current_token, 0) + 1

            tmp = self.__pipeline.decode(tokens)
            # Only emit once the pending tokens decode to valid text (no
            # U+FFFD replacement character from a split multi-byte sequence).
            if "\ufffd" not in tmp:
                tokens.clear()
                result += tmp
                if result.rstrip().endswith("\n\nUser:"):
                    # The model started the next user turn: yield the answer
                    # without that trailing marker and stop.
                    yield result.rstrip().removesuffix("\n\nUser:")
                    break
                yield result
                await sleep(.02)  # let the event loop breathe between chunks

        # Best-effort release of the large intermediates before returning.
        tokens.clear()
        occurrences.clear()
        del out, tmp, occurrences, tokens, current_token, state
        gc.collect()