import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt
import threading
import queue

models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}

# List of model names for easy access
model_names = list(models.keys())

# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()
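
# Note: model.generate() consults each StoppingCriteria between decoding steps,
# so setting generation_interrupt stops an in-flight generation at the next
# token boundary rather than requiring the worker thread to be killed.
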
def generate_summaries(example, model_a_name, model_b_name):
"""
Generates summaries for the given example using the assigned models.
"""
if generation_interrupt.is_set():
return "", ""
context_text = ""
context_parts = []
if "full_contexts" in example:
for ctx in example["full_contexts"]:
if isinstance(ctx, dict) and "content" in ctx:
context_parts.append(ctx["content"])
context_text = "\n---\n".join(context_parts)
else:
raise ValueError("No context found in the example.")
question = example.get("question", "")
if generation_interrupt.is_set():
return "", ""

    # --- Model A ---
    # Use a queue to get results from the worker thread
    result_queue_a = queue.Queue()
    thread_a = threading.Thread(target=run_inference, args=(models[model_a_name], context_text, question, result_queue_a))
    thread_a.start()

    summary_a = ""
    while thread_a.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model A ({model_a_name})...")
            # The InterruptCriteria within the thread will handle stopping generate();
            # here we just return early from the main control flow.
            thread_a.join(timeout=1.0)  # Give the thread a moment to stop
            return "", ""
        try:
            summary_a = result_queue_a.get(timeout=0.1)  # Poll the queue periodically
            break  # Got a result
        except queue.Empty:
            continue  # Still running; check the interrupt flag again

    # The thread may have finished without us reading the result
    # (e.g., it was interrupted just before putting it in the queue).
    if not summary_a and not result_queue_a.empty():
        summary_a = result_queue_a.get_nowait()
    elif not summary_a and generation_interrupt.is_set():  # Thread finished quickly after an interrupt
        return "", ""

    if generation_interrupt.is_set():  # Check between models
        return summary_a, ""

    # --- Model B ---
    result_queue_b = queue.Queue()
    thread_b = threading.Thread(target=run_inference, args=(models[model_b_name], context_text, question, result_queue_b))
    thread_b.start()

    summary_b = ""
    while thread_b.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model B ({model_b_name})...")
            thread_b.join(timeout=1.0)
            return summary_a, ""  # Return the summary_a obtained so far
        try:
            summary_b = result_queue_b.get(timeout=0.1)
            break
        except queue.Empty:
            continue

    if not summary_b and not result_queue_b.empty():
        summary_b = result_queue_b.get_nowait()
    elif not summary_b and generation_interrupt.is_set():
        return summary_a, ""

    return summary_a, summary_b

# run_inference runs in a worker thread and reports its result via a queue
def run_inference(model_name, context, question, result_queue):
"""
Run inference using the specified model. Designed to be run in a thread.
Puts the result or an error string into the result_queue.
"""
# Check interrupt at the very beginning of the thread
if generation_interrupt.is_set():
result_queue.put("")
return
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None
tokenizer = None
result = ""
try:
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
accepts_sys = (
"System role not supported" not in tokenizer.chat_template
if tokenizer.chat_template else False # Handle missing chat_template
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Check interrupt before loading the model
if generation_interrupt.is_set():
result_queue.put("")
return
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
        ).to(device)
        model.eval()  # Set model to evaluation mode

        text_input = format_rag_prompt(question, context, accepts_sys)

        # Check interrupt before tokenization/template application
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        actual_input = tokenizer.apply_chat_template(
            text_input,
            return_tensors="pt",
            tokenize=True,
            # Note: truncation is not enabled here, so a very long context/question
            # may still exceed the model's maximum sequence length.
            max_length=2048,
            add_generation_prompt=True,
        ).to(device)

        input_length = actual_input.shape[1]
        attention_mask = torch.ones_like(actual_input).to(device)
        # Check interrupt before generation
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])

        with torch.inference_mode():
            outputs = model.generate(
                actual_input,
                attention_mask=attention_mask,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria,
                do_sample=True,
                temperature=0.6,
                top_p=0.9,
            )

        # Check interrupt immediately after generation finishes or stops
        if generation_interrupt.is_set():
            result = ""  # Discard a potentially partial result if interrupted
        else:
            # Decode the generated tokens, excluding the input tokens
            result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

        result_queue.put(result)
    except Exception as e:
        print(f"Error in inference thread for {model_name}: {e}")
        # Put an error message in the queue for the main thread to handle/display
        result_queue.put(f"Error generating response: {str(e)[:100]}...")
    finally:
        # Clean up resources within the thread
        del model
        del tokenizer
        del actual_input
        del outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
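

# Illustrative usage sketch, not part of the Space's runtime path. The example
# dict below is hypothetical and only shows the fields generate_summaries expects:
# "full_contexts" as a list of {"content": ...} dicts plus a "question" string.
# Because this module uses relative imports, it would have to be run as part of
# its package (e.g. `python -m <package>.<module>`), and the model downloads
# and GPU/CPU inference costs still apply.
if __name__ == "__main__":
    example = {
        "full_contexts": [
            {"content": "Paris is the capital and largest city of France."},
            {"content": "France is a country in Western Europe."},
        ],
        "question": "What is the capital of France?",
    }
    summary_a, summary_b = generate_summaries(example, model_names[0], model_names[1])
    print("Model A:", summary_a)
    print("Model B:", summary_b)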