thesis / utils /modelling.py
LennardZuendorf's picture
fix: final fix of attention
517fd4c unverified
# modelling util module providing formatting functions for model functionalities
# external imports
import torch
import gradio as gr
from transformers import BitsAndBytesConfig
# function that limits the prompt to contain model runtime
# tries to keep as much as possible, always keeping at least message and system prompt
def prompt_limiter(
tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
):
# initializing the new prompt history empty
prompt_history = []
# getting the current token count for the message, system prompt, and knowledge
pre_count = (
token_counter(tokenizer, message)
+ token_counter(tokenizer, system_prompt)
+ token_counter(tokenizer, knowledge)
)
# validating the token count against threshold of 1024
# check if token count already too high without history
if pre_count > 1024:
# check if token count too high even without knowledge and history
if (
token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
> 1024
):
# show warning and raise error
gr.Warning("Message and system prompt are too long. Please shorten them.")
raise RuntimeError(
"Message and system prompt are too long. Please shorten them."
)
# show warning and return with empty history and empty knowledge
gr.Warning("""
Input too long.
Knowledge and conversation history have been removed to keep model running.
""")
return message, prompt_history, system_prompt, ""
# if token count small enough, adding history bit by bit
if pre_count < 800:
# setting the count to the pre-count
count = pre_count
# reversing the history to prioritize recent conversations
history.reverse()
# iterating through the history
for conversation in history:
# checking the token count i´with the current conversation
count += token_counter(tokenizer, conversation[0]) + token_counter(
tokenizer, conversation[1]
)
# add conversation or break loop depending on token count
if count < 1024:
prompt_history.append(conversation)
else:
break
# return the message, adapted, system prompt, and knowledge
return message, prompt_history, system_prompt, knowledge
# token counter function using the model tokenizer
def token_counter(tokenizer, text: str):
# tokenize the text
tokens = tokenizer(text, return_tensors="pt").input_ids
# return the token count
return len(tokens[0])
# function to determine the device to use
def get_device():
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
return device
# function to set device config
# CREDIT: Copied from captum llama 2 example
# see https://captum.ai/tutorials/Llama2_LLM_Attribution
def gpu_loading_config(max_memory: str = "15000MB"):
n_gpus = torch.cuda.device_count()
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
return n_gpus, max_memory, bnb_config
# formatting mistral attention values
# CREDIT: copied from BERTViz
# see https://github.com/jessevig/bertviz
def format_mistral_attention(attention_values, layers=None, heads=None):
if layers:
attention_values = [attention_values[layer_index] for layer_index in layers]
squeezed = []
for layer_attention in attention_values:
layer_attention = layer_attention.squeeze(0)
if heads:
layer_attention = layer_attention[heads]
squeezed.append(layer_attention)
return torch.stack(squeezed)