Spaces:
Running
on
Zero
Running
on
Zero
import os | |
os.environ['MKL_THREADING_LAYER'] = 'GNU' | |
import spaces | |
import torch | |
from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList | |
from .prompts import format_rag_prompt | |
from .shared import generation_interrupt | |
# Maps the display name of each selectable model to the model identifier that
# generate_summaries() hands to run_inference(). Commented-out entries are
# currently disabled.
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # "Gemma-3-4b-it": "google/gemma-3-4b-it",
    "Gemma-2-2b-it": "google/gemma-2-2b-it",
    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
    # "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
    "Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
}

# Display names only, in the same order as they are declared above.
model_names = [*models]
class InterruptCriteria(StoppingCriteria):
    """Stopping criterion whose verdict is the current state of an interrupt flag."""

    def __init__(self, interrupt_event):
        """Store the flag object to poll.

        Args:
            interrupt_event: Object exposing an ``is_set()`` method.
        """
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        """Return the result of ``interrupt_event.is_set()``.

        ``input_ids``, ``scores`` and any extra keyword arguments are accepted
        but not inspected; only the flag decides the outcome.
        """
        stop_requested = self.interrupt_event.is_set()
        return stop_requested
def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models sequentially.

    The interrupt flag is checked before any work, before model A runs and
    again before model B runs, so an interrupt skips whatever has not started.

    Args:
        example: Mapping that may contain ``"full_contexts"`` (a sequence of
            strings or of dicts with a ``"content"`` key) and ``"question"``.
        model_a_name: Key into ``models`` selecting the first model.
        model_b_name: Key into ``models`` selecting the second model.

    Returns:
        A ``(summary_a, summary_b)`` tuple of strings. Both are ``""`` when
        interrupted before model A runs; ``summary_b`` is ``""`` when
        interrupted after model A has finished.

    Raises:
        KeyError: If a model name that is actually looked up is not in ``models``.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = _build_context_text(example)
    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # Run model A
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""

    # Run model B
    summary_b = run_inference(models[model_b_name], context_text, question)
    return summary_a, summary_b


def _build_context_text(example):
    """Join the entries of ``example["full_contexts"]`` into one context string."""
    contexts = example.get("full_contexts")
    if not contexts:
        # Provide a graceful fallback instead of raising an error
        print("Warning: No full context found in the example, using empty context")
        return ""

    parts = []
    for number, ctx in enumerate(contexts, start=1):
        # Extract content from either dict or string; anything else is empty.
        if isinstance(ctx, dict):
            content = ctx.get("content", "")
        elif isinstance(ctx, str):
            content = ctx
        else:
            content = ""

        # A dict may carry a missing/None or non-string "content"; normalise
        # it so the .strip() below cannot raise AttributeError.
        if content is None:
            content = ""
        elif not isinstance(content, str):
            content = str(content)

        # Add document number if not already present.
        # NOTE(review): any text that begins with "Document" (for example
        # "Documentation ...") is treated as already labelled.
        if not content.strip().startswith("Document"):
            content = f"Document {number}:\n{content}"
        parts.append(content)

    return "\n\n".join(parts)
def run_inference(model_name, context, question):
    """
    Run inference using the specified model.

    Loads the tokenizer and a text-generation pipeline for ``model_name``,
    builds the prompt with ``format_rag_prompt`` and generates a reply. The
    interrupt flag is polled before each expensive step and, through
    ``InterruptCriteria``, during generation itself.

    Args:
        model_name: Model identifier passed to ``AutoTokenizer.from_pretrained``
            and ``pipeline``.
        context: Context text forwarded to ``format_rag_prompt``.
        question: Question text forwarded to ``format_rag_prompt``.

    Returns:
        The generated text, ``""`` if the interrupt flag was set at any
        checkpoint (including while generating), or a string starting with
        ``"Error generating response: "`` if an exception was raised.
    """
    # Local import: only the cleanup in ``finally`` needs it.
    import gc

    # Check interrupt at the beginning
    if generation_interrupt.is_set():
        return ""

    result = ""
    pipe = None  # bound up front so ``finally`` can always release it
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)

        # A missing/empty chat template is treated as "no system role".
        chat_template = tokenizer.chat_template
        accepts_sys = bool(chat_template) and "System role not supported" not in chat_template

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            return ""

        pipe = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tokenizer,
            device_map='auto',
            max_length=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )

        text_input = format_rag_prompt(question, context, accepts_sys)

        # Check interrupt before generation
        if generation_interrupt.is_set():
            return ""

        # Wire the interrupt flag into generation so a request to stop takes
        # effect while tokens are being produced, not only between steps.
        stop_on_interrupt = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
        outputs = pipe(text_input, max_new_tokens=512, stopping_criteria=stop_on_interrupt)

        # A generation cut short by the interrupt yields partial text; honour
        # the documented contract and report it as "interrupted" instead.
        if generation_interrupt.is_set():
            return ""

        # NOTE(review): this indexing assumes the pipeline returns a list of
        # chat messages, i.e. format_rag_prompt produced chat-style input -- confirm.
        result = outputs[0]['generated_text'][-1]['content']
    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")
        result = f"Error generating response: {str(e)[:200]}..."
    finally:
        # Clean up resources. The pipeline reference must be dropped first:
        # empty_cache() can only release memory that no live tensor occupies.
        pipe = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return result