# Last change by orionweller: "Update app.py" (commit c09090d, verified).
import math
import os
import shutil
import sys
import warnings
from collections.abc import Iterator
from functools import partial
from threading import Thread

# Third-party. `spaces` stays ahead of transformers/torch, matching the
# original order (HF ZeroGPU guidance is to import it first).
import spaces
from transformers import TextIteratorStreamer
from huggingface_hub import snapshot_download
import gradio as gr
import torch
import numpy as np

from model import Rank1
# Report library versions at startup to make environment mismatches easy to spot in the logs.
for _lib_name, _lib_version in (("NumPy", np.__version__), ("PyTorch", torch.__version__)):
    print(f"{_lib_name} version: {_lib_version}")

# Silence the "Can't initialize NVML" UserWarning (CUDA initialization noise).
warnings.filterwarnings(
    "ignore",
    message="Can't initialize NVML",
    category=UserWarning,
)

# Placeholders; both are assigned in the __main__ block before demo.launch().
MODEL_PATH = None
reranker = None
def _build_prompt(query: str, passage: str) -> str:
    """Return the Rank1 relevance prompt for a query/passage pair, ending in an open ``<think>`` tag."""
    return (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\n"
        f"Passage: {passage}\n"
        "<think>"
    )


def _relevance_probability(true_logit: float, false_logit: float) -> float:
    """Return P("true") from a two-way softmax over the "true"/"false" scores.

    This is algebraically ``exp(t) / (exp(t) + exp(f))``. It is evaluated as a
    sigmoid of the score difference, so large scores cannot overflow
    ``math.exp`` and very negative ones cannot collapse to ``0 / 0``.
    """
    diff = true_logit - false_logit
    if diff >= 0:
        return 1.0 / (1.0 + math.exp(-diff))
    odds = math.exp(diff)
    return odds / (1.0 + odds)


@spaces.GPU
def process_input(query: str, passage: str) -> Iterator[tuple[object, object, str]]:
    """Stream Rank1's reasoning for a query/passage pair, then yield its verdict.

    For every chunk the streamer produces, yields
    ``("Processing...", "Processing...", reasoning_so_far)``. Ends with one
    ``(is_relevant, score, reasoning)`` tuple. ``score`` is the two-way
    softmax probability of the "true" token at the last generation step, and
    ``is_relevant`` is ``score > 0.5``.

    Raises:
        Any exception raised by ``reranker.model.generate`` in the worker
        thread is re-raised here, after the stream has been unblocked.
    """
    prompt = _build_prompt(query, passage)

    # The model is moved to the GPU inside the @spaces.GPU-decorated call.
    reranker.model = reranker.model.to("cuda")
    inputs = reranker.tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=reranker.context_size,
    ).to("cuda")

    streamer = TextIteratorStreamer(
        reranker.tokenizer,
        skip_prompt=True,
        skip_special_tokens=False,
    )

    generation_output = None
    generation_error = None

    def generate_with_output():
        nonlocal generation_output, generation_error
        try:
            generation_output = reranker.model.generate(
                **inputs,
                generation_config=reranker.generation_config,
                stopping_criteria=reranker.stopping_criteria,
                return_dict_in_generate=True,
                output_scores=True,
                streamer=streamer,
            )
        except Exception as exc:  # re-raised in the caller after join()
            generation_error = exc
            # generate() never sent the stop signal; without this the
            # consumer loop below would block forever.
            streamer.end()

    thread = Thread(target=generate_with_output)
    thread.start()

    # skip_prompt=True drops the prompt (which ends in "<think>") from the
    # stream, so seed the transcript with that opening tag.
    current_text = "<think>"
    for new_text in streamer:
        current_text += new_text
        yield ("Processing...", "Processing...", current_text)
    thread.join()

    if generation_error is not None:
        raise generation_error

    # If the reasoning never closed with "</think>", append what the first
    # stopping criterion reports as matched (presumably the stop sequence
    # that ended generation; see Rank1).
    if "</think>" not in current_text:
        current_text += "\n" + reranker.stopping_criteria[0].matched_sequence

    # Scores of the final generation step for the single (batch index 0) prompt.
    final_scores = generation_output.scores[-1][0]
    true_logit = final_scores[reranker.true_token].item()
    false_logit = final_scores[reranker.false_token].item()
    score = _relevance_probability(true_logit, false_logit)
    yield (score > 0.5, score, current_text)
# Example inputs shown under the form: one relevant and one irrelevant passage
# for the same query, as [query, passage] rows.
examples = [
    ["What movies were directed by James Cameron?", passage]
    for passage in (
        "Avatar: The Way of Water is a 2022 American epic science fiction film directed by James Cameron.",
        "Common symptoms of COVID-19 include fever, cough, fatigue, loss of taste or smell, and difficulty breathing.",
    )
]
# Soft theme: indigo primary hue, slate neutrals, "lg" corner radius,
# and an Inter-first font stack with system fallbacks.
theme = gr.themes.Soft(
    primary_hue="indigo",
    neutral_hue="slate",
    radius_size="lg",
    font=[
        "Inter",
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
with gr.Blocks(theme=theme, css=".red-text { color: red; }") as demo:
    gr.Markdown("# Rank1: Test Time Compute in Reranking")
    gr.HTML('NOTE: for demo purposes this is a <span style="color: red;">quantized</span> model limited to a <span style="color: red;">1024</span> context length. HF spaces cannot use vLLM so this is <span style="color: red;">significantly slower</span>')
    # \U0001F4C4 is the "page facing up" emoji; escaped so the file's
    # encoding cannot garble it.
    gr.HTML('\U0001F4C4 Paper Link: <a href="https://arxiv.org/abs/2502.18418" target="_blank">https://arxiv.org/abs/2502.18418</a>')

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
                lines=2,
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=6,
            )
            submit_button = gr.Button("Check Relevance")

        # Right column: verdict, confidence and the streamed reasoning trace,
        # in the same order as the tuples process_input yields.
        with gr.Column():
            relevance_output = gr.Textbox(label="Relevance")
            confidence_output = gr.Textbox(label="Confidence")
            reasoning_output = gr.Textbox(
                label="Model Reasoning",
                lines=10,
                interactive=False,
            )

    # cache_examples=True makes Gradio precompute the example outputs by
    # running process_input on each example row.
    gr.Examples(
        examples=examples,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        fn=process_input,
        cache_examples=True,
    )

    submit_button.click(
        fn=process_input,
        inputs=[query_input, passage_input],
        outputs=[relevance_output, confidence_output, reasoning_output],
        api_name="predict",
        queue=True,
    )
if __name__ == "__main__":
    # Fetch the checkpoint up front so it is on disk before the UI starts
    # serving requests.
    MODEL_PATH = snapshot_download(repo_id="orionweller/rank1-7b-awq")
    print(f"Downloaded model to: {MODEL_PATH}")

    # process_input reads this module-level reranker on every request.
    reranker = Rank1(model_name_or_path=MODEL_PATH)

    demo.launch(share=False)