# Hugging Face Space (status banner at capture time: "Running on Zero").
import sys | |
import warnings | |
print("Warning: This application requires specific library versions. Please ensure you have the correct versions installed.") | |
import spaces | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import numpy as np | |
print(f"NumPy version: {np.__version__}") | |
print(f"PyTorch version: {torch.__version__}") | |
# Ignore the "Can't initialize NVML" UserWarning.
warnings.filterwarnings(
    "ignore",
    message="Can't initialize NVML",
    category=UserWarning,
)

# Pick the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load the model (bfloat16) and a left-padding tokenizer; exit with status 1
# if either load raises ValueError.
model_name = "jhu-clsp/FollowIR-7B"
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
except ValueError as load_error:
    print(f"Error loading model or tokenizer: {load_error}")
    print("Please ensure you have the correct versions of transformers and sentencepiece installed.")
    sys.exit(1)

# Reuse the end-of-sequence token as the padding token and pad on the left.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Vocabulary ids of the two answer tokens whose logits are compared.
_vocab = tokenizer.get_vocab()
token_false_id = _vocab["false"]
token_true_id = _vocab["true"]

# Prompt template; `{query}` and `{text}` are filled in per request.
# The trailing space after "[/INST]" is part of the prompt.
template = (
    "<s> [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.\n"
    "Query: {query}\n"
    "Document: {text}\n"
    "Relevant (only output one word, either \"true\" or \"false\"): [/INST] "
)
def check_relevance(query, instruction, passage):
    """Return the model's probability that ``passage`` is relevant.

    The query and instruction are joined with a single space, inserted into
    the module-level ``template`` together with the passage, and run through
    the model once. The logits at the final position for the "false" and
    "true" tokens are normalised against each other, and the probability
    assigned to "true" is returned.

    Args:
        query: Search query text.
        instruction: Extra relevance criteria; appended to the query.
        passage: Document text to judge.

    Returns:
        The probability of "true" formatted as a string with four decimals.
    """
    # Bind the device on every path: assigning it only inside the CUDA branch
    # leaves the name unbound (UnboundLocalError) on CPU-only hosts.
    use_cuda = torch.cuda.is_available()
    run_device = torch.device("cuda" if use_cuda else "cpu")
    if use_cuda:
        # nn.Module.to moves the parameters in place, so no rebinding of the
        # module-level `model` is required.
        model.to(run_device)

    full_query = f"{query} {instruction}"
    prompt = template.format(query=full_query, text=passage)

    encoded = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        return_tensors="pt",
        pad_to_multiple_of=None,
    )
    inputs = {name: tensor.to(run_device) for name, tensor in encoded.items()}

    with torch.no_grad():
        # Only the logits at the last position are needed.
        last_logits = model(**inputs).logits[:, -1, :]
        # Column 0 = "false", column 1 = "true".
        answer_logits = torch.stack(
            [last_logits[:, token_false_id], last_logits[:, token_true_id]],
            dim=1,
        )
        log_probs = torch.nn.functional.log_softmax(answer_logits, dim=1)
        score = log_probs[:, 1].exp().item()

    return f"{score:.4f}"
# Example inputs: both rows share the query and instruction and differ only in
# the passage (co-directed vs. solely directed).
_EXAMPLE_QUERY = "What movies were directed by James Cameron?"
_EXAMPLE_INSTRUCTION = "A relevant document would describe any movie that was directed by James Cameron but not any that are co-directed."
_EXAMPLE_PASSAGES = (
    "Avatar: The Way of Water is a 2022 American epic science fiction film co-produced and co-directed by James Cameron and Rick Jaffe.",
    "Avatar: The Way of Water is a 2022 American epic science fiction film co-produced and directed by James Cameron. Rick Jaffe helped write the script.",
)

# Each row is [query, instruction, passage].
examples = [
    [_EXAMPLE_QUERY, _EXAMPLE_INSTRUCTION, passage]
    for passage in _EXAMPLE_PASSAGES
]
# Gradio Interface
# Builds the `demo` app: three text inputs and a button feed `check_relevance`,
# whose returned string is shown in a single output textbox.
# NOTE(review): the nesting below (Row -> two Columns, with Examples and the
# click handler at Blocks level) is the presumed layout — confirm.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Relevance Using Instructions")
    gr.Markdown("This app uses the FollowIR-7B model to determine the relevance of a passage to a given query and instruction.")
    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(label="Query", placeholder="Enter your search query here")
            instruction_input = gr.Textbox(label="Instruction", placeholder="Enter additional instructions or criteria")
            passage_input = gr.Textbox(label="Passage", placeholder="Enter the passage to check for relevance", lines=5)
            submit_button = gr.Button("Check Relevance")
        with gr.Column():
            output = gr.Textbox(label="Relevance Probability")
    # Example rows map positionally onto (query, instruction, passage).
    # With cache_examples=True, Gradio presumably precomputes the outputs by
    # calling check_relevance on each example row — confirm against the
    # installed Gradio version.
    gr.Examples(
        examples=examples,
        inputs=[query_input, instruction_input, passage_input],
        outputs=output,
        fn=check_relevance,
        cache_examples=True,
    )
    # The button runs the same function on whatever is in the three inputs.
    submit_button.click(
        check_relevance,
        inputs=[query_input, instruction_input, passage_input],
        outputs=[output]
    )
if __name__ == "__main__":
    # Refuse to start under NumPy 2.x: report the problem and exit with
    # status 1 before the UI is launched.
    running_numpy_2 = np.__version__[:2] == "2."
    if running_numpy_2:
        print("Error: This application is not compatible with NumPy 2.x. Please downgrade to NumPy < 2.0.0.")
        sys.exit(1)

    demo.launch()