|
from functools import lru_cache

from langchain.llms.base import LLM
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema.messages import get_buffer_string
from transformers import GPT2TokenizerFast
|
|
|
@lru_cache(maxsize=1)
def _get_tokenizer():
    """Build the GPT-2 tokenizer once and reuse it across calls.

    The original code constructed a fresh ``GPT2TokenizerFast`` on every
    call, which re-reads the vocab files each time — very expensive when
    counting tokens per-message in a loop.
    """
    return GPT2TokenizerFast.from_pretrained("gpt2")


def get_num_tokens(text):
    """Return the number of GPT-2 tokens in *text*.

    Args:
        text: The string to tokenize.

    Returns:
        int: count of GPT-2 subword tokens.
    """
    return len(_get_tokenizer().tokenize(text))
|
|
|
def get_memory_num_tokens(memory):
    """Return the total GPT-2 token count of all messages held in *memory*.

    Args:
        memory: A langchain memory object exposing ``chat_memory.messages``.

    Returns:
        int: sum of per-message token counts (each message is rendered via
        ``get_buffer_string`` before tokenizing).
    """
    buffer = memory.chat_memory.messages
    # Generator expression instead of a throwaway list inside sum().
    return sum(get_num_tokens(get_buffer_string([m])) for m in buffer)
|
|
|
def validate_memory_len(memory, max_token_limit=2000):
    """Trim the oldest messages from *memory* until it fits the token budget.

    Mutates ``memory.chat_memory.messages`` in place, dropping messages
    from the front (oldest first) while the total token count exceeds
    ``max_token_limit``.

    Args:
        memory: A langchain memory object exposing ``chat_memory.messages``.
        max_token_limit: Maximum total token count to allow (default 2000).

    Returns:
        The same *memory* object, for call-chaining convenience.
    """
    buffer = memory.chat_memory.messages
    curr_buffer_length = get_memory_num_tokens(memory)
    # Subtract the evicted message's own token count instead of
    # re-tokenizing the entire remaining buffer on every iteration
    # (the original recomputed the full sum per pop — O(n^2) tokenizations).
    # The `buffer` guard makes the loop safe even if counts disagree.
    while curr_buffer_length > max_token_limit and buffer:
        removed = buffer.pop(0)
        curr_buffer_length -= get_num_tokens(get_buffer_string([removed]))
    return memory
|
|
|
if __name__ == '__main__':
    # Smoke test: count the tokens of a tiny string using the module's
    # own helper instead of duplicating the tokenizer setup inline.
    text = '''Hi'''
    print(get_num_tokens(text))