import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import spaces
import torch
from transformers import pipeline, AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt

models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # "Gemma-3-4b-it": "google/gemma-3-4b-it",
    "Gemma-2-2b-it": "google/gemma-2-2b-it",
    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
    # "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
    "Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
}

# List of model names for easy access
model_names = list(models.keys())


# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()
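
# Note: InterruptCriteria is defined here but not yet attached to the pipeline calls below;
# the code only polls generation_interrupt between steps. One possible way to enforce the
# interrupt mid-generation (a sketch, assuming the text-generation pipeline forwards
# generate() kwargs such as stopping_criteria):
#
#     stopping = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
#     pipe(text_input, max_new_tokens=512, stopping_criteria=stopping)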


@spaces.GPU
def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models sequentially.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []

    if "full_contexts" in example and example["full_contexts"]:
        for i, ctx in enumerate(example["full_contexts"]):
            content = ""

            # Extract content from either dict or string
            if isinstance(ctx, dict) and "content" in ctx:
                content = ctx["content"]
            elif isinstance(ctx, str):
                content = ctx

            # Add document number if not already present
            if not content.strip().startswith("Document"):
                content = f"Document {i+1}:\n{content}"

            context_parts.append(content)

        context_text = "\n\n".join(context_parts)
    else:
        # Provide a graceful fallback instead of raising an error
        print("Warning: No full context found in the example, using empty context")
        context_text = ""

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""

    # Run model A
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""

    # Run model B
    summary_b = run_inference(models[model_b_name], context_text, question)

    return summary_a, summary_b
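
# For reference, `example` is expected to look roughly like the following (shape inferred
# from the accessors above; the values are purely illustrative):
#
#     {
#         "question": "What does the policy cover?",
#         "full_contexts": [
#             {"content": "..."},         # dicts with a "content" key...
#             "or plain context strings",  # ...and bare strings are both accepted
#         ],
#     }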


@spaces.GPU
def run_inference(model_name, context, question):
    """
    Run inference using the specified model.
    Returns the generated text or empty string if interrupted.
    """
    # Check interrupt at the beginning
    if generation_interrupt.is_set():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = ""

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left", token=True)
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template else False  # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            return ""

        pipe = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tokenizer,
            device_map='auto',
            max_length=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )

        text_input = format_rag_prompt(question, context, accepts_sys)

        # Check interrupt before generation
        if generation_interrupt.is_set():
            return ""

        outputs = pipe(text_input, max_new_tokens=512)
        result = outputs[0]['generated_text'][-1]['content']

    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")
        result = f"Error generating response: {str(e)[:200]}..."

    finally:
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return result
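

if __name__ == "__main__":
    # Minimal local smoke test, illustrative only: the record below is made up, and the
    # @spaces.GPU decorator should act as a plain pass-through outside a ZeroGPU Space.
    # Run as a module (e.g. `python -m <package>.<module>`) so the relative imports resolve.
    sample = {
        "question": "What is the capital of France?",
        "full_contexts": [{"content": "Paris is the capital and largest city of France."}],
    }
    summary_a, summary_b = generate_summaries(
        sample, "Qwen2.5-1.5b-Instruct", "Llama-3.2-1b-Instruct"
    )
    print("Model A:", summary_a)
    print("Model B:", summary_b)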