oliver-aizip committed on
Commit ddaff53 · 1 Parent(s): 6f5e355

v1 inference code

Files changed (3)
  1. requirements.txt +1 -0
  2. utils/models.py +74 -11
  3. utils/prompts.py +39 -0
requirements.txt CHANGED
@@ -6,3 +6,4 @@ numpy==1.26.4
  openai>=1.60.2
  torch>=2.5.1
  tqdm==4.67.1
+ flash-attn>=2.7.4
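Note: flash-attn builds a CUDA extension at install time, so it generally requires a CUDA toolchain and an already-installed torch; it is commonly installed with pip install flash-attn --no-build-isolation.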
utils/models.py CHANGED
@@ -1,41 +1,104 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from .prompts import format_rag_prompt
+
  # --- Dummy Model Summaries ---
  # Define functions that simulate model summary generation
- dummy_models = {
-     "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
-     "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
-     "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
-     "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
+ # models = {
+ #     "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
+ #     "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
+ #     "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
+ #     "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
+ # }
+
+ models = {
+     "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
+     "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
+     #TODO add more models
  }

  # List of model names for easy access
- model_names = list(dummy_models.keys())
+ model_names = list(models.keys())
+

  def generate_summaries(example, model_a_name, model_b_name):
      """
      Generates summaries for the given example using the assigned models.
      """
+
      # Create a plain text version of the contexts for the models
      context_text = ""
+     context_parts = []
      if "contexts" in example and example["contexts"]:
-         context_parts = []
          for ctx in example["contexts"]:
              if isinstance(ctx, dict) and "content" in ctx:
                  context_parts.append(ctx["content"])
          context_text = "\n---\n".join(context_parts)
      else:
          # Fallback to full contexts if highlighted contexts are not available
-         context_parts = []
          if "full_contexts" in example:
              for ctx in example["full_contexts"]:
                  if isinstance(ctx, dict) and "content" in ctx:
                      context_parts.append(ctx["content"])
              context_text = "\n---\n".join(context_parts)
-
+
      # Pass 'Answerable' status to models (they might use it)
      answerable = example.get("Answerable", True)
      question = example.get("question", "")

      # Call the dummy model functions
-     summary_a = dummy_models[model_a_name](context_text, question, answerable)
-     summary_b = dummy_models[model_b_name](context_text, question, answerable)
+     summary_a = run_inference(models[model_a_name], context_text, question)
+     summary_b = run_inference(models[model_b_name], context_text, question)
      return summary_a, summary_b
+
+
+ def run_inference(model_name, context, question):
+     """
+     Run inference using the specified model.
+     """
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Load the model and tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+     accepts_sys = (
+         "System role not supported" not in tokenizer.chat_template
+     )  # Workaround for Gemma
+
+     # Set padding token if not set
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name, torch_dtype=torch.bfloat16, attn_implementation="eager"
+     ).to(device)
+
+     text_input = format_rag_prompt(question, context, accepts_sys)
+
+     # Tokenize the input
+     actual_input = tokenizer.apply_chat_template(
+         text_input,
+         return_tensors="pt",
+         tokenize=True,
+         max_length=2048,
+         add_generation_prompt=True,
+     ).to(device)
+
+     input_length = actual_input.shape[1]
+
+     # Create attention mask (1 for all tokens since we're not padding)
+     attention_mask = torch.ones_like(actual_input).to(device)
+
+     # Generate output
+     with torch.inference_mode():
+         # Disable gradient calculation for inference
+         outputs = model.generate(
+             actual_input,
+             attention_mask=attention_mask,
+             max_new_tokens=512,  # Use max_new_tokens instead of max_length
+             pad_token_id=tokenizer.pad_token_id,
+         )
+
+     # Decode the output
+     result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+
+     return result
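For reference, a minimal usage sketch of the new generate_summaries entry point. The example dict fields mirror what the function reads (question, contexts, Answerable); the values themselves are made up for illustration, and it assumes the utils package is importable from the repository root:

# Minimal usage sketch; hypothetical data, not part of this commit.
from utils.models import generate_summaries, model_names

example = {
    "question": "What is the warranty period for the device?",
    "contexts": [{"content": "The device ships with a two-year limited warranty."}],
    "Answerable": True,
}

# Compare the first two registered models on the same example.
summary_a, summary_b = generate_summaries(example, model_names[0], model_names[1])
print(summary_a)
print(summary_b)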
utils/prompts.py CHANGED
@@ -0,0 +1,39 @@
+ def format_rag_prompt(query: str, context: str, accepts_sys: bool) -> list:
+     system_prompt = """
+     You are a helpful assistant that provides answers to queries based on the provided context.
+
+     You MUST clearly refuse to answer the query and ask for additional information from the user if the answer cannot be found in the context.
+     The output should not contain your judgment on answerability, only your answer OR your refusal + clarifications.
+
+     Stay within the bounds of the provided context and avoid making assumptions.
+
+
+     """
+     user_prompt = f"""
+
+     # Role and Task Description
+     Judge if the following query is answerable from ONLY the provided context.
+     If so, provide a complete, grounded answer to the query, and do not mention your judgement.
+     Try to address all aspects of the query, but if certain parts are not answerable, clearly state that you do not have enough information.
+
+     OTHERWISE, refuse clearly to answer and ask for the additional information you require from the user.
+     You should give a concise explanation of why you cannot answer the query based on the context, and ask for more relevant information from the user.
+
+     # Task
+     Given the following query and context, please provide your response:
+     Query: {query}
+
+     Context: {context}
+
+     WITHOUT mentioning your judgement, give either your grounded answer OR your refusal and clarifications:
+     """
+
+     messages = (
+         [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt},
+         ]
+         if accepts_sys
+         else [{"role": "user", "content": system_prompt + user_prompt}]
+     )
+     return messages
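For illustration, a small sketch of what format_rag_prompt returns and how run_inference consumes it; the query/context strings below are placeholders:

# Sketch: inspect the chat messages built by format_rag_prompt (placeholder strings).
from utils.prompts import format_rag_prompt

messages = format_rag_prompt(
    query="What is the warranty period?",
    context="The device ships with a two-year limited warranty.",
    accepts_sys=True,
)

# With accepts_sys=True this is a system + user pair; with False, a single user
# message carrying both prompts. run_inference passes the list straight to
# tokenizer.apply_chat_template(...).
for message in messages:
    print(message["role"], "->", len(message["content"]), "chars")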