import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from .prompts import format_rag_prompt
# --- Dummy Model Summaries ---
# Define functions that simulate model summary generation
# models = {
# "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
# "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
# "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
# "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
# }
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",  # remove gated for now
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # "Bitnet-b1.58-2B-4T": "microsoft/bitnet-b1.58-2B-4T",
    # TODO: add more models
}
# List of model names for easy access
model_names = list(models.keys())
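
# Illustrative helper (an assumption, not part of the original Space logic): an
# arena-style comparison would typically sample two distinct display names from
# model_names and hand them to generate_summaries() below.
def pick_random_model_pair():
    """Return two distinct model display names (illustrative sketch only)."""
    import random

    return random.sample(model_names, 2)
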
def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the two assigned models.
    """
    # Build a plain-text version of the retrieved contexts for the models.
    context_text = ""
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    # Pass the 'Answerable' status through; models may use it.
    answerable = example.get("Answerable", True)
    question = example.get("question", "")

    # Run inference with each assigned model.
    summary_a = run_inference(models[model_a_name], context_text, question)
    summary_b = run_inference(models[model_b_name], context_text, question)

    return summary_a, summary_b
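
# Note on the joined context format above (illustrative, assumed inputs): two
# retrieved passages are separated by a "---" delimiter line before being
# handed to the prompt builder, e.g.
#
#   >>> "\n---\n".join(["First passage.", "Second passage."])
#   'First passage.\n---\nSecond passage.'
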
def run_inference(model_name, context, question):
    """
    Run inference for a single (context, question) pair using the specified model.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
    accepts_sys = (
        "System role not supported" not in tokenizer.chat_template
    )  # Workaround for Gemma, whose chat template rejects system messages

    # Set the padding token if it is not defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load the model.
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
    ).to(device)

    text_input = format_rag_prompt(question, context, accepts_sys)

    # Tokenize the chat-formatted input (truncation enabled so max_length takes effect).
    actual_input = tokenizer.apply_chat_template(
        text_input,
        return_tensors="pt",
        tokenize=True,
        truncation=True,
        max_length=2048,
        add_generation_prompt=True,
    ).to(device)

    input_length = actual_input.shape[1]
    attention_mask = torch.ones_like(actual_input).to(device)

    # Generate the output.
    with torch.inference_mode():
        outputs = model.generate(
            actual_input,
            attention_mask=attention_mask,
            max_new_tokens=512,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Decode only the newly generated tokens.
    result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
    return result
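

# --- Illustrative usage (assumed, not part of the original module) ---
# A minimal smoke-test sketch for generate_summaries(), assuming an example dict
# with the fields accessed above ("full_contexts" as a list of {"content": ...}
# dicts, plus "question" and "Answerable"). Because of the relative import at the
# top, run it as a module, e.g. `python -m your_package.this_module` (package and
# module names here are placeholders).
if __name__ == "__main__":
    sample_example = {
        "full_contexts": [
            {"content": "The Eiffel Tower was completed in 1889."},
            {"content": "It stands on the Champ de Mars in Paris."},
        ],
        "question": "When was the Eiffel Tower completed?",
        "Answerable": True,
    }
    summary_a, summary_b = generate_summaries(
        sample_example, "Qwen2.5-1.5b-Instruct", "Llama-3.2-1b-Instruct"
    )
    print("Model A:", summary_a)
    print("Model B:", summary_b)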