import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
import spaces
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
from .prompts import format_rag_prompt
from .shared import generation_interrupt
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    # "Gemma-3-1b-it": "google/gemma-3-1b-it",
    # "Gemma-3-4b-it": "google/gemma-3-4b-it",
    "Gemma-2-2b-it": "google/gemma-2-2b-it",
    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
    "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
    # "Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
    # "MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
    "Qwen3-0.6b": "qwen/qwen3-0.6b",
    "Qwen3-1.7b": "qwen/qwen3-1.7b",
    "Qwen3-4b": "qwen/qwen3-4b",
    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
}
tokenizer_cache = {}
# List of model names for easy access
model_names = list(models.keys())
# Custom stopping criteria that checks the interrupt flag
class InterruptCriteria(StoppingCriteria):
    def __init__(self, interrupt_event):
        self.interrupt_event = interrupt_event

    def __call__(self, input_ids, scores, **kwargs):
        return self.interrupt_event.is_set()
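# Sketch of how InterruptCriteria could be wired into a raw `model.generate` call
# via StoppingCriteriaList. The pipeline-based path below does not currently use it;
# `model` and `inputs` here stand for an already-loaded model and tokenized prompt:
#
#   stopping = StoppingCriteriaList([InterruptCriteria(generation_interrupt)])
#   model.generate(**inputs, stopping_criteria=stopping, max_new_tokens=512)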


@spaces.GPU
def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models sequentially.
    """
    if generation_interrupt.is_set():
        return "", ""

    context_text = ""
    context_parts = []
    if "full_contexts" in example and example["full_contexts"]:
        for i, ctx in enumerate(example["full_contexts"]):
            content = ""
            # Extract content from either dict or string
            if isinstance(ctx, dict) and "content" in ctx:
                content = ctx["content"]
            elif isinstance(ctx, str):
                content = ctx
            # Add document number if not already present
            if not content.strip().startswith("Document"):
                content = f"Document {i+1}:\n{content}"
            context_parts.append(content)
        context_text = "\n\n".join(context_parts)
    else:
        # Provide a graceful fallback instead of raising an error
        print("Warning: No full context found in the example, using empty context")
        context_text = ""

    question = example.get("question", "")

    if generation_interrupt.is_set():
        return "", ""
    # Run model A
    summary_a = run_inference(models[model_a_name], context_text, question)

    if generation_interrupt.is_set():
        return summary_a, ""
    # Run model B
    summary_b = run_inference(models[model_b_name], context_text, question)
    return summary_a, summary_b


@spaces.GPU
def run_inference(model_name, context, question):
    """
    Run inference using the specified model.
    Returns the generated text or empty string if interrupted.
    """
    # Check interrupt at the beginning
    if generation_interrupt.is_set():
        return ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = ""

    # Kwargs forwarded to apply_chat_template (e.g. to make sure Qwen3 doesn't use thinking)
    tokenizer_kwargs = {
        "add_generation_prompt": True,
    }
    generation_kwargs = {
        "max_new_tokens": 512,
    }

    if "qwen3" in model_name.lower():
        print(f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False.")
        tokenizer_kwargs["enable_thinking"] = False
    try:
        if model_name in tokenizer_cache:
            tokenizer = tokenizer_cache[model_name]
        else:
            # tokenizer_kwargs are chat-template options, so they are not passed here;
            # they are applied later in apply_chat_template.
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side="left",
                token=True,
            )
            tokenizer_cache[model_name] = tokenizer

        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template else False  # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Check interrupt before loading the model
        if generation_interrupt.is_set():
            return ""

        pipe = pipeline(
            "text-generation",
            model=model_name,
            tokenizer=tokenizer,
            device_map='cuda',
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            model_kwargs={
                "attn_implementation": "eager",
            }
        )
        text_input = format_rag_prompt(question, context, accepts_sys)

        if "gemma-3" not in model_name.lower():
            formatted = pipe.tokenizer.apply_chat_template(
                text_input,
                tokenize=False,
                **tokenizer_kwargs,
            )
            input_length = len(formatted)

            # Check interrupt before generation
            if generation_interrupt.is_set():
                return ""

            outputs = pipe(formatted, **generation_kwargs)
            # The pipeline returns prompt + completion; strip the prompt prefix
            result = outputs[0]['generated_text'][input_length:]
        else:
            # Gemma-3 breaks when given a pre-rendered template, so pass the chat
            # messages directly and let the pipeline apply the template itself.
            outputs = pipe(text_input, **generation_kwargs)
            # With chat-style input the last message holds the assistant's reply
            result = outputs[0]['generated_text'][-1]['content']
    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")
        result = f"Error generating response: {str(e)[:200]}..."
    finally:
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return result
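

# Minimal local smoke test (illustrative sketch only): shows how generate_summaries
# might be called. The sample dict below is made up; real examples come from the
# app's dataset and carry "full_contexts" and "question" keys as handled above.
# Because of the relative imports, this module has to be run as a package module
# (python -m ...) rather than as a standalone script.
if __name__ == "__main__":
    sample = {
        "full_contexts": [
            {"content": "Paris is the capital of France."},
            "The Eiffel Tower is located in Paris.",
        ],
        "question": "Where is the Eiffel Tower located?",
    }
    summary_a, summary_b = generate_summaries(
        sample, "Qwen2.5-1.5b-Instruct", "Llama-3.2-1b-Instruct"
    )
    print("Model A:", summary_a)
    print("Model B:", summary_b)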