from langchain.memory import ConversationBufferWindowMemory
from langchain.schema.messages import get_buffer_string
from transformers import GPT2TokenizerFast

# Load the GPT-2 tokenizer once at module level instead of re-instantiating it
# on every call.
_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def get_num_tokens(text):
    """Return the number of GPT-2 tokens in a piece of text."""
    return len(_tokenizer.tokenize(text))


def get_memory_num_tokens(memory):
    """Return the total token count of all messages stored in the memory buffer."""
    buffer = memory.chat_memory.messages
    return sum(get_num_tokens(get_buffer_string([m])) for m in buffer)


def validate_memory_len(memory, max_token_limit=2000):
    """Drop the oldest messages until the buffer fits within max_token_limit tokens."""
    buffer = memory.chat_memory.messages
    while buffer and get_memory_num_tokens(memory) > max_token_limit:
        buffer.pop(0)
    return memory


if __name__ == '__main__':
    text = '''Hi'''
    print(get_num_tokens(text))
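    # Hedged usage sketch (not part of the original file): exercise the memory
    # helpers with a plain ConversationBufferWindowMemory. The demo_memory name,
    # the sample messages, and the k=10 window are assumptions for illustration.
    demo_memory = ConversationBufferWindowMemory(k=10)
    demo_memory.chat_memory.add_user_message("Hi")
    demo_memory.chat_memory.add_ai_message("Hello! How can I help you today?")
    print(get_memory_num_tokens(demo_memory))  # total tokens currently stored
    demo_memory = validate_memory_len(demo_memory, max_token_limit=2000)
    print(get_memory_num_tokens(demo_memory))  # unchanged here; trimming only kicks in above the limit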