import math
import warnings
from threading import Thread

import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import snapshot_download
from transformers import TextIteratorStreamer

from model import Rank1

print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Suppress CUDA initialization warning (NVML is unavailable at import time)
warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML")
MODEL_PATH = None
reranker = None
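
# process_input streams the model's "<think> ... </think>" reasoning to the UI
# token by token, then turns the final-step logits for the true/false answer
# tokens into a relevance probability. It is a generator so Gradio can update
# the output boxes live while generation runs.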
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of this call
def process_input(query: str, passage: str):
    """Process input through the reranker, yielding
    (relevance, confidence, reasoning) tuples as generation streams."""
    global reranker

    prompt = (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\n"
        f"Passage: {passage}\n"
        "<think>"
    )
    # Move the model to the GPU attached for this call
    reranker.model = reranker.model.to("cuda")

    inputs = reranker.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=reranker.context_size,
    ).to("cuda")
    # The streamer hands decoded text back to this thread as tokens arrive
    streamer = TextIteratorStreamer(
        reranker.tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )

    current_text = "<think>"
    generation_output = None
    def generate_with_output():
        # Run generation in a worker thread so this thread can consume the
        # streamer; keep the full output around for scoring afterwards.
        nonlocal generation_output
        generation_output = reranker.model.generate(
            **inputs,
            generation_config=reranker.generation_config,
            stopping_criteria=reranker.stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
            streamer=streamer,
        )
    thread = Thread(target=generate_with_output)
    thread.start()

    # Stream tokens as they're generated
    for new_text in streamer:
        current_text += new_text
        yield ("Processing...", "Processing...", current_text)
    thread.join()
    # If generation stopped before emitting "</think>", append the matched
    # stopping sequence so the displayed reasoning is well formed
    if "</think>" not in current_text:
        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence
    # Final relevance score: softmax over the true/false answer-token logits,
    # computed in the numerically stable sigmoid form to avoid exp() overflow
    with torch.no_grad():
        final_scores = generation_output.scores[-1][0]
        true_logit = final_scores[reranker.true_token].item()
        false_logit = final_scores[reranker.false_token].item()
        score = 1.0 / (1.0 + math.exp(false_logit - true_logit))

    yield (
        score > 0.5,
        f"{score:.4f}",
        current_text,
    )
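
# Worked example of the scoring above: true_logit = 2.0 and false_logit = -1.0
# give score = 1 / (1 + e^(-3)) ≈ 0.953, i.e. ~95% confidence the passage is
# relevant; equal logits give exactly 0.5.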
# Example inputs
examples = [
    [
        "What movies were directed by James Cameron?",
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
    ],
    [
        "What movies were directed by James Cameron?",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    ],
]
theme = gr.themes.Soft(
    primary_hue="indigo",
    neutral_hue="slate",
    radius_size="lg",
    font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
)
with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test-Time Compute in Reranking")
    gr.HTML(
        'NOTE: for demo purposes this is a <span style="color: red;">quantized</span> model '
        'limited to a <span style="color: red;">1024</span>-token context length. '
        'HF Spaces cannot run vLLM, so generation is <span style="color: red;">significantly slower</span>.'
    )
    gr.HTML(
        'Paper Link: <a href="https://arxiv.org/abs/2502.18418" target="_blank">'
        'https://arxiv.org/abs/2502.18418</a>'
    )
    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2,
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6,
            )
            submit_button = gr.Button("Check Relevance")

        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False,
            )
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=process_input,
        cache_examples=True,
    )
    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True,
    )
if __name__ == "__main__":
    # Download the model first, so requests don't have to wait for it
    MODEL_PATH = snapshot_download(repo_id="orionweller/rank1-7b-awq")
    print(f"Downloaded model to: {MODEL_PATH}")

    reranker = Rank1(model_name_or_path=MODEL_PATH)

    demo.launch(share=False)
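
# A minimal sketch of calling this demo remotely through the api_name="predict"
# endpoint wired above, using gradio_client. "user/space-name" is a placeholder
# for the actual Space id, not something defined in this file.
#
#     from gradio_client import Client
#
#     client = Client("user/space-name")
#     relevance, confidence, reasoning = client.predict(
#         "What movies were directed by James Cameron?",
#         "Avatar: The Way of Water is a 2022 film directed by James Cameron.",
#         api_name="/predict",
#     )
#     print(relevance, confidence)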