SLM-RAG-Arena / utils /models.py
oliver-aizip's picture
first pass at async handling
665e5a3
raw
history blame
4.04 kB
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from .prompts import format_rag_prompt
# --- Dummy Model Summaries ---
# Define functions that simulate model summary generation
# models = {
# "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
# "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
# "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
# "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
# }
# Arena contestants: display name (the key callers pass as a model name to
# ``generate_summaries``) -> Hugging Face Hub repository id, which
# ``run_inference`` hands to ``from_pretrained``.
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    # NOTE(review): an earlier note here read "remove gated for now"; gated
    # Hub repos need access via the token used in run_inference — confirm.
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # "Bitnet-b1.58-2B-4T": "microsoft/bitnet-b1.58-2B-4T",
    # TODO: add more models
}

# Display names only, in the same insertion order as ``models``.
model_names = list(models)
def generate_summaries(example, model_a_name, model_b_name):
    """
    Generate one answer from each of two arena models for the same example.

    Args:
        example: Mapping with a ``"full_contexts"`` sequence and an optional
            ``"question"`` string. Only context entries that are dicts with a
            ``"content"`` key contribute text; any other entries are skipped.
        model_a_name: Key into ``models`` for the first model.
        model_b_name: Key into ``models`` for the second model.

    Returns:
        A ``(summary_a, summary_b)`` tuple with the text generated by each model.

    Raises:
        ValueError: If ``example`` has no ``"full_contexts"`` key.
        KeyError: If either model name is not a key of ``models``.
    """
    if "full_contexts" not in example:
        raise ValueError("No context found in the example.")

    # Flatten the passages into one plain-text context, separated by "---" rules.
    context_text = "\n---\n".join(
        ctx["content"]
        for ctx in example["full_contexts"]
        if isinstance(ctx, dict) and "content" in ctx
    )
    question = example.get("question", "")

    # Both models receive identical context and question so their outputs
    # can be compared side by side.
    summary_a = run_inference(models[model_a_name], context_text, question)
    summary_b = run_inference(models[model_b_name], context_text, question)
    return summary_a, summary_b
def run_inference(model_name, context, question):
    """
    Run a single RAG generation with a Hugging Face causal LM.

    The tokenizer and model are loaded from ``model_name`` on every call
    (nothing is cached here). They run on CUDA when available, otherwise CPU,
    and the model is loaded in bfloat16 with eager attention.

    Args:
        model_name: Hugging Face Hub repository id (e.g. a value of ``models``).
        context: Plain-text retrieved context to ground the answer in.
        question: The question to answer.

    Returns:
        The decoded completion: only the newly generated tokens (prompt
        stripped), with special tokens removed. At most 512 new tokens.

    Raises:
        ValueError: If the tokenizer defines no chat template, since the
            prompt is built with ``apply_chat_template``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
    chat_template = tokenizer.chat_template
    if chat_template is None:
        # Fail before the expensive model download/load, with a clearer error
        # than the TypeError that a `in None` membership test would raise.
        raise ValueError(f"Tokenizer for {model_name!r} has no chat template.")
    # Workaround for Gemma, whose template rejects system messages.
    # str() lets this also search templates stored as a dict of named templates.
    accepts_sys = "System role not supported" not in str(chat_template)

    # generate() needs a pad id; fall back to EOS when the tokenizer has none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
    ).to(device)

    messages = format_rag_prompt(question, context, accepts_sys)
    # NOTE(review): max_length only caps the prompt when truncation is enabled;
    # apply_chat_template appears to default to truncation=False, so this is
    # likely a no-op. Right-side truncation would cut the generation prompt,
    # so confirm the intent before enabling it.
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        tokenize=True,
        max_length=2048,
        add_generation_prompt=True,
    ).to(device)
    prompt_length = input_ids.shape[1]
    # A single unpadded prompt, so every position is attended to.
    # ones_like already allocates on input_ids' device.
    attention_mask = torch.ones_like(input_ids)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=512,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Drop the echoed prompt tokens so only the model's answer is returned.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)