import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt
import threading
import queue

models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
}
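# NOTE: these are Hugging Face Hub repo IDs. The Llama 3.2 and Gemma 3 repos
# are typically gated, so the token=True arguments below assume a logged-in
# HF account that has been granted access.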

# List of model names for easy access
model_names = list(models.keys())

# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
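    """Stopping criterion that halts generation once the shared interrupt event is set.

    transformers evaluates stopping criteria between decoding steps, so an
    interrupt takes effect at the next token boundary.
    """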
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event
        
    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()

def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the two assigned models.

    Returns a (summary_a, summary_b) tuple; either entry may be an empty
    string if generation was interrupted before that model finished.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []
    if "full_contexts" in example:
        for ctx in example["full_contexts"]:
            if isinstance(ctx, dict) and "content" in ctx:
                context_parts.append(ctx["content"])
        context_text = "\n---\n".join(context_parts)
    else:
        raise ValueError("No context found in the example.")

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # --- Model A ---
    # Run inference in a worker thread and collect its result via a queue.
    result_queue_a = queue.Queue()
    thread_a = threading.Thread(target=run_inference, args=(models[model_a_name], context_text, question, result_queue_a))
    thread_a.start()

    summary_a = ""
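    # Poll the worker: wake up every 100 ms either to pick up the finished
    # summary or to react promptly when the user sets the interrupt flag.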
    while thread_a.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model A ({model_a_name})...")
            # The InterruptCriteria within the thread will handle stopping generate
            # We return early from the main control flow.
            thread_a.join(timeout=1.0) # Give thread a moment to potentially stop
            return "", ""
        try:
            summary_a = result_queue_a.get(timeout=0.1) # Check queue periodically
            break # Got result
        except queue.Empty:
            continue # Still running, check interrupt again

    # If the thread finished but we never received a result (e.g. it was
    # interrupted just before putting anything in the queue), drain the queue.
    if not summary_a and not result_queue_a.empty():
        summary_a = result_queue_a.get_nowait()
    elif not summary_a and generation_interrupt.is_set():  # Thread may have finished right as the interrupt arrived
        return "", ""

    if generation_interrupt.is_set(): # Check between models
        return summary_a, ""

    # --- Model B ---
    result_queue_b = queue.Queue()
    thread_b = threading.Thread(target=run_inference, args=(models[model_b_name], context_text, question, result_queue_b))
    thread_b.start()

    summary_b = ""
    while thread_b.is_alive():
        if generation_interrupt.is_set():
            print(f"Interrupting model B ({model_b_name})...")
            thread_b.join(timeout=1.0)
            return summary_a, "" # Return summary_a obtained so far
        try:
            summary_b = result_queue_b.get(timeout=0.1)
            break
        except queue.Empty:
            continue

    if not summary_b and not result_queue_b.empty():
        summary_b = result_queue_b.get_nowait()
    elif not summary_b and generation_interrupt.is_set():
        return summary_a, ""

    return summary_a, summary_b


# Runs inside a worker thread started by generate_summaries; results (or an
# error string) are passed back through result_queue.
def run_inference(model_name, context, question, result_queue):
    """
    Run inference using the specified model. Designed to be run in a thread.
    Puts the result or an error string into the result_queue.
    """
    # Check interrupt at the very beginning of the thread
    if generation_interrupt.is_set():
        result_queue.put("")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pre-declare everything the finally block deletes so cleanup is safe even
    # if an exception is raised before these variables are assigned.
    model = None
    tokenizer = None
    actual_input = None
    outputs = None
    result = ""

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
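        # Heuristic: some chat templates raise "System role not supported" when
        # given a system message; accepts_sys tells format_rag_prompt whether it
        # may emit one.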
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template else False # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager", token=True
        ).to(device)
        model.eval() # Set model to evaluation mode

        text_input = format_rag_prompt(question, context, accepts_sys)

        # Check interrupt before tokenization/template application
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        actual_input = tokenizer.apply_chat_template(
            text_input,
            return_tensors="pt",
            tokenize=True,
            max_length=2048,  # Cap the prompt length for these small models
            truncation=True,  # max_length alone does not truncate without this
            add_generation_prompt=True,
        ).to(device)

        input_length = actual_input.shape[1]
        attention_mask = torch.ones_like(actual_input).to(device)

        # Check interrupt before generation
        if generation_interrupt.is_set():
            result_queue.put("")
            return

        stopping_criteria = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])

        with torch.inference_mode():
            outputs = model.generate(
                actual_input,
                attention_mask=attention_mask,
                max_new_tokens=512,
                pad_token_id=tokenizer.pad_token_id,
                stopping_criteria=stopping_criteria,
                do_sample=True,  # Sampling (with temperature/top_p below) rather than greedy decoding
                temperature=0.6,
                top_p=0.9,
            )

        # Check interrupt immediately after generation finishes or stops
        if generation_interrupt.is_set():
            result = "" # Discard potentially partial result if interrupted
        else:
            # Decode the generated tokens, excluding the input tokens
            result = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

        result_queue.put(result)

    except Exception as e:
        print(f"Error in inference thread for {model_name}: {e}")
        # Put error message in queue for the main thread to handle/display
        result_queue.put(f"Error generating response: {str(e)[:100]}...")

    finally:
        # Clean up resources within the thread
        del model
        del tokenizer
        del actual_input
        del outputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
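

# Minimal usage sketch (hypothetical caller, not part of this module): the UI
# layer is assumed to run generate_summaries off the main thread and to set
# generation_interrupt when the user cancels. The example payload below is
# illustrative only.
#
#   example = {
#       "question": "What does the report conclude?",
#       "full_contexts": [
#           {"content": "First retrieved passage ..."},
#           {"content": "Second retrieved passage ..."},
#       ],
#   }
#   summary_a, summary_b = generate_summaries(example, model_names[0], model_names[1])
#   # From another thread, generation_interrupt.set() stops both generations
#   # at the next decoding step.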