oliver-aizip kai-aizip committed on
Commit 1db9e92 · verified · 1 Parent(s): 69f6a43

Handled interruption (#10)


- Handled interruption (9a1fcf079875ce647f4228f03d39b0a16a575134)


Co-authored-by: Kai <kai-aizip@users.noreply.huggingface.co>

Files changed (1)
  1. utils/models.py +81 -55
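
Note: the diff below imports `generation_interrupt` from `utils/shared.py`, which this commit does not touch. Judging from the `.is_set()` calls, it is presumably a module-level `threading.Event`; a minimal sketch of the assumed counterpart (module layout and comments are assumptions, not part of this change):

# utils/shared.py (assumed, not shown in this diff)
import threading

# Process-wide flag: the UI sets it to abort in-flight generation
# and clears it again before starting a new run.
generation_interrupt = threading.Event()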
utils/models.py CHANGED
@@ -1,36 +1,32 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 from .prompts import format_rag_prompt
-
-# --- Dummy Model Summaries ---
-# Define functions that simulate model summary generation
-# models = {
-# "Model Alpha": lambda context, question, answerable: f"Alpha Summary: Based on the context for '{question[:20]}...', it appears the question is {'answerable' if answerable else 'unanswerable'}.",
-# "Model Beta": lambda context, question, answerable: f"Beta Summary: Regarding '{question[:20]}...', the provided documents {'allow' if answerable else 'do not allow'} for a conclusive answer based on the text.",
-# "Model Gamma": lambda context, question, answerable: f"Gamma Summary: For the question '{question[:20]}...', I {'can' if answerable else 'cannot'} provide a specific answer from the given text snippets.",
-# "Model Delta (Refusal Specialist)": lambda context, question, answerable: f"Delta Summary: The context for '{question[:20]}...' is {'sufficient' if answerable else 'insufficient'} to formulate a direct response. Therefore, I must refuse."
-# }
+from .shared import generation_interrupt
 
 models = {
     "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
-    #"Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct", # remove gated for now
-    #"Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
     "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
-    "Gemma-3-1b-it" : "google/gemma-3-1b-it",
-    #"Bitnet-b1.58-2B-4T": "microsoft/bitnet-b1.58-2B-4T",
-    #TODO add more models
+    "Gemma-3-1b-it": "google/gemma-3-1b-it",
 }
 
 # List of model names for easy access
 model_names = list(models.keys())
 
+# Custom stopping criteria that checks the interrupt flag
+class InterruptCriteria(StoppingCriteria):
+    def __init__(self, interrupt_event):
+        self.interrupt_event = interrupt_event
+
+    def __call__(self, input_ids, scores, **kwargs):
+        return self.interrupt_event.is_set()
 
 def generate_summaries(example, model_a_name, model_b_name):
     """
     Generates summaries for the given example using the assigned models.
     """
+    if generation_interrupt.is_set():
+        return "", ""
 
-    # Create a plain text version of the contexts for the models
     context_text = ""
     context_parts = []
     if "full_contexts" in example:
@@ -41,12 +37,16 @@ def generate_summaries(example, model_a_name, model_b_name):
     else:
         raise ValueError("No context found in the example.")
 
-    # Pass 'Answerable' status to models (they might use it)
-    answerable = example.get("Answerable", True)
     question = example.get("question", "")
 
-    # Call the dummy model functions
+    if generation_interrupt.is_set():
+        return "", ""
+
     summary_a = run_inference(models[model_a_name], context_text, question)
+
+    if generation_interrupt.is_set():
+        return summary_a, ""
+
     summary_b = run_inference(models[model_b_name], context_text, question)
     return summary_a, summary_b
 
@@ -54,46 +54,72 @@ def run_inference(model_name, context, question):
     """
     Run inference using the specified model.
    """
+    if generation_interrupt.is_set():
+        return ""
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # Load the model and tokenizer
-    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
-    accepts_sys = (
-        "System role not supported" not in tokenizer.chat_template
-    ) # Workaround for Gemma
-
-    # Set padding token if not set
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
+        accepts_sys = (
+            "System role not supported" not in tokenizer.chat_template
+        )
 
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
-    ).to(device)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        if generation_interrupt.is_set():
+            return ""
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
+        ).to(device)
 
-    text_input = format_rag_prompt(question, context, accepts_sys)
+        text_input = format_rag_prompt(question, context, accepts_sys)
 
-    # Tokenize the input
-    actual_input = tokenizer.apply_chat_template(
-        text_input,
-        return_tensors="pt",
-        tokenize=True,
-        max_length=2048,
-        add_generation_prompt=True,
-    ).to(device)
+        if generation_interrupt.is_set():
+            return ""
+
+        actual_input = tokenizer.apply_chat_template(
+            text_input,
+            return_tensors="pt",
+            tokenize=True,
+            max_length=2048,
+            add_generation_prompt=True,
+        ).to(device)
 
-    input_length = actual_input.shape[1]
-    attention_mask = torch.ones_like(actual_input).to(device)
-
-    # Generate output
-    with torch.inference_mode():
-        outputs = model.generate(
-            actual_input,
-            attention_mask=attention_mask,
-            max_new_tokens=512,
-            pad_token_id=tokenizer.pad_token_id,
-        )
-
-    # Decode the output
-    result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+        input_length = actual_input.shape[1]
+        attention_mask = torch.ones_like(actual_input).to(device)
+
+        if generation_interrupt.is_set():
+            return ""
+
+        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
+
+        with torch.inference_mode():
+            outputs = model.generate(
+                actual_input,
+                attention_mask=attention_mask,
+                max_new_tokens=512,
+                pad_token_id=tokenizer.pad_token_id,
+                stopping_criteria=stopping_criteria
+            )
 
-    return result
+        if generation_interrupt.is_set():
+            return ""
+
+        result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
+
+        return result
+
+    except Exception as e:
+        print(f"Error in inference: {e}")
+        return f"Error generating response: {str(e)[:100]}..."
+
+    finally:
+        if 'model' in locals():
+            del model
+        if 'tokenizer' in locals():
+            del tokenizer
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
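
Taken together, the change cancels work at two levels: the early `generation_interrupt.is_set()` checks in `generate_summaries` and `run_inference` skip stages that have not started yet, while the `InterruptCriteria` passed to `model.generate` via `stopping_criteria` is consulted between decoding steps, so an in-flight generation stops within roughly one token of the flag being set. A hedged sketch of how a caller might drive the flag (the handler names below are illustrative, not part of this commit):

# Caller-side sketch; only generation_interrupt and generate_summaries come from the diff above.
from utils.shared import generation_interrupt
from utils.models import generate_summaries

def on_stop_clicked():
    # Ask any pending or in-flight generation to stop as soon as possible.
    generation_interrupt.set()

def on_new_comparison(example, model_a_name, model_b_name):
    # Clear the flag before a fresh run, then produce both summaries.
    generation_interrupt.clear()
    return generate_summaries(example, model_a_name, model_b_name)

The new finally block (del model / del tokenizer plus torch.cuda.empty_cache()) releases model memory whether the run finishes, raises, or is interrupted.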