thesis / utils /modelling.py
LennardZuendorf's picture
Bump to Version 1.0.1 (#4)
fe1089d unverified
raw
history blame
No virus
2.31 kB
# module for modelling utilities
# external imports
import gradio as gr
def prompt_limiter(
    tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Limit the prompt parts to the token budget used by this module.

    Token counts are obtained through ``token_counter`` with the given
    tokenizer. Depending on the counts, the knowledge is dropped and/or as
    much of the most recent history as fits is kept.

    Args:
        tokenizer: tokenizer object forwarded to ``token_counter``.
        message: the current user message.
        history: earlier conversations; each entry is indexed with ``[0]``
            and ``[1]`` (presumably user message and model answer - confirm
            against the caller). The list is NOT modified.
        system_prompt: the system prompt.
        knowledge: optional additional knowledge text.

    Returns:
        A tuple ``(message, prompt_history, system_prompt, knowledge)``.
        ``prompt_history`` holds the kept conversations ordered most recent
        first; it is empty when the knowledge had to be removed or when the
        base token count is 800 or more. ``knowledge`` is ``""`` when it was
        removed.

    Raises:
        RuntimeError: if message and system prompt alone exceed 1024 tokens.
    """
    # hard token limit and the threshold below which history is added
    max_tokens = 1024
    history_threshold = 800

    # initializing the prompt history empty
    prompt_history = []

    # counting message and system prompt once, the value is needed twice below
    base_count = token_counter(tokenizer, message) + token_counter(
        tokenizer, system_prompt
    )
    # getting the token count for the message, system prompt, and knowledge
    pre_count = base_count + token_counter(tokenizer, knowledge)

    # validating the token count
    # check if token count already too high
    if pre_count > max_tokens:
        # check if token count too high even without knowledge
        if base_count > max_tokens:
            # show warning and raise error
            gr.Warning("Message and system prompt are too long. Please shorten them.")
            raise RuntimeError(
                "Message and system prompt are too long. Please shorten them."
            )

        # show warning and remove knowledge
        gr.Warning("Knowledge is too long. It has been removed to keep model running.")
        return message, prompt_history, system_prompt, ""

    # if token count small enough, add history
    if pre_count < history_threshold:
        # setting the count to the precount
        count = pre_count

        # iterating newest-first to prioritize recent conversations;
        # reversed() is used instead of history.reverse() so the caller's
        # list is not reordered as a side effect
        for conversation in reversed(history):
            # checking the token count with the current conversation
            count += token_counter(tokenizer, conversation[0]) + token_counter(
                tokenizer, conversation[1]
            )

            # stop as soon as the limit is reached, older entries are dropped
            if count >= max_tokens:
                break
            prompt_history.append(conversation)

    # return the message, prompt history, system prompt, and knowledge
    return message, prompt_history, system_prompt, knowledge
# token counter function using the model tokenizer
def token_counter(tokenizer, text: str):
    """Return the number of tokens the tokenizer produces for ``text``.

    The tokenizer is called with ``return_tensors="pt"`` and the length of
    the first entry of the resulting ``input_ids`` is returned.
    """
    # running the tokenizer on the text
    encoded = tokenizer(text, return_tensors="pt")
    # only the first (and here only) sequence is of interest
    first_sequence = encoded.input_ids[0]
    # its length is the token count
    return len(first_sequence)