Spaces: Running on Zero
Commit · ddaff53
1 Parent(s): 6f5e355

v1 inference code

Files changed:
- requirements.txt +1 -0
- utils/models.py +74 -11
- utils/prompts.py +39 -0
requirements.txt CHANGED
@@ -6,3 +6,4 @@ numpy==1.26.4
 openai>=1.60.2
 torch>=2.5.1
 tqdm==4.67.1
+flash-attn>=2.7.4
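Note: flash-attn is added to the requirements here, while run_inference in utils/models.py below loads models with attn_implementation="eager". A minimal sketch of what actually opting into FlashAttention 2 would look like in transformers (illustration only, not part of the commit; it assumes a CUDA device and bf16 weights):

# Sketch only: the committed loader uses attn_implementation="eager".
# FlashAttention 2 needs the flash-attn package, a CUDA device, and fp16/bf16 weights.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "qwen/qwen2.5-1.5b-instruct",             # model id from the models dict below
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # picks up the flash-attn dependency added above
).to("cuda")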
utils/models.py CHANGED
@@ -1,41 +1,104 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from .prompts import format_rag_prompt
+
 # --- Dummy Model Summaries ---
 # Define functions that simulate model summary generation
-models = {
-    "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
-    "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
-    "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
-    "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
+# models = {
+#     "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
+#     "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
+#     "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
+#     "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
+# }
+
+models = {
+    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
+    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
+    #TODO add more models
 }
 
 # List of model names for easy access
-model_names = list(
+model_names = list(models.keys())
+
 
 def generate_summaries(example, model_a_name, model_b_name):
     """
     Generates summaries for the given example using the assigned models.
     """
+
     # Create a plain text version of the contexts for the models
     context_text = ""
+    context_parts = []
     if "contexts" in example and example["contexts"]:
-        context_parts = []
         for ctx in example["contexts"]:
             if isinstance(ctx, dict) and "content" in ctx:
                 context_parts.append(ctx["content"])
         context_text = "\n---\n".join(context_parts)
     else:
         # Fallback to full contexts if highlighted contexts are not available
-        context_parts = []
         if "full_contexts" in example:
             for ctx in example["full_contexts"]:
                 if isinstance(ctx, dict) and "content" in ctx:
                     context_parts.append(ctx["content"])
             context_text = "\n---\n".join(context_parts)
-
+
     # Pass 'Answerable' status to models (they might use it)
     answerable = example.get("Answerable", True)
     question = example.get("question", "")
 
     # Call the dummy model functions
-    summary_a =
-    summary_b =
+    summary_a = run_inference(models[model_a_name], context_text, question)
+    summary_b = run_inference(models[model_b_name], context_text, question)
     return summary_a, summary_b
+
+
+def run_inference(model_name, context, question):
+    """
+    Run inference using the specified model.
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # Load the model and tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+    accepts_sys = (
+        "System role not supported" not in tokenizer.chat_template
+    )  # Workaround for Gemma
+
+    # Set padding token if not set
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager"
+    ).to(device)
+
+    text_input = format_rag_prompt(question, context, accepts_sys)
+
+    # Tokenize the input
+    actual_input = tokenizer.apply_chat_template(
+        text_input,
+        return_tensors="pt",
+        tokenize=True,
+        max_length=2048,
+        add_generation_prompt=True,
+    ).to(device)
+
+    input_length = actual_input.shape[1]
+
+    # Create attention mask (1 for all tokens since we're not padding)
+    attention_mask = torch.ones_like(actual_input).to(device)
+
+    # Generate output
+    with torch.inference_mode():
+        # Disable gradient calculation for inference
+        outputs = model.generate(
+            actual_input,
+            attention_mask=attention_mask,
+            max_new_tokens=512,  # Use max_new_tokens instead of max_length
+            pad_token_id=tokenizer.pad_token_id,
+        )
+
+    # Decode the output
+    result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+
+    return result
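For reference, a minimal sketch of driving generate_summaries; the example dict shape (question, contexts with "content" fields, Answerable) mirrors what the function reads, and the values here are hypothetical:

# Hypothetical example record; field names follow what generate_summaries expects.
example = {
    "question": "When was the Eiffel Tower completed?",
    "contexts": [{"content": "The Eiffel Tower was completed in 1889."}],
    "Answerable": True,
}

summary_a, summary_b = generate_summaries(
    example, "Qwen2.5-1.5b-Instruct", "Llama-3.2-1b-Instruct"
)
print(summary_a)
print(summary_b)

Note that run_inference reloads the tokenizer and model on every call, so a loop over many examples pays the full load cost twice per example; caching the loaded models would be a natural follow-up.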
utils/prompts.py CHANGED
@@ -0,0 +1,39 @@
+def format_rag_prompt(query: str, context: str, accepts_sys: bool) -> list:
+    system_prompt = """
+    You are a helpful assistant that provides answers to queries based on the provided context.
+
+    You MUST clearly refuse to answer the query and ask for additional information from the user if the answer cannot be found in the context.
+    The output should not contain your judgment on answerability, only your answer OR your refusal + clarifications.
+
+    Stay within the bounds of the provided context and avoid making assumptions.
+
+
+    """
+    user_prompt = f"""
+
+    # Role and Task Description
+    Judge if the following query is answerable from ONLY the provided context.
+    If so, provide a complete, grounded answer to the query, and do not mention your judgement.
+    Try to address all aspects of the query, but if certain parts are not answerable, clearly state that you do not have enough information.
+
+    OTHERWISE, refuse clearly to answer and ask for the additional information you require from the user.
+    You should give a concise explanation of why you cannot answer the query based on the context, and ask for more relevant information from the user.
+
+    # Task
+    Given the following query and context, please provide your response:
+    Query: {query}
+
+    Context: {context}
+
+    WITHOUT mentioning your judgement, provide either your grounded answer OR your refusal and clarifications:
+    """
+
+    messages = (
+        [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        if accepts_sys
+        else [{"role": "user", "content": system_prompt + user_prompt}]
+    )
+    return messages
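For illustration, format_rag_prompt returns a chat-style list of messages rather than a plain string, which is what tokenizer.apply_chat_template in run_inference consumes. A small sketch (the model id and inputs are placeholders):

# Sketch: inspecting the message list before tokenization.
from transformers import AutoTokenizer
from utils.prompts import format_rag_prompt

messages = format_rag_prompt(
    query="When was the Eiffel Tower completed?",       # hypothetical inputs
    context="The Eiffel Tower was completed in 1889.",
    accepts_sys=True,
)
# -> [{"role": "system", "content": ...}, {"role": "user", "content": ...}]

tokenizer = AutoTokenizer.from_pretrained("qwen/qwen2.5-1.5b-instruct")
rendered = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(rendered)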