orionweller committed on
Commit
0df179d
1 Parent(s): e187e6e
Files changed (2) hide show
  1. app.py +65 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+
5
# Model loading and setup
# NOTE(review): `.cuda()` requires a CUDA device to be present; this script
# has no CPU fallback.
model_name = "jhu-clsp/FollowIR-7B"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()

# Left padding is requested once, at construction time. The end-of-sequence
# token is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Build the vocabulary mapping a single time and read both answer-token ids
# from it; a missing "true"/"false" entry raises KeyError at startup.
_vocab = tokenizer.get_vocab()
token_false_id = _vocab["false"]
token_true_id = _vocab["true"]

# Prompt sent to the model. `{query}` and `{text}` are filled in with
# `str.format`; the prompt ends right after "[/INST] " so that the next
# generated token is the model's one-word answer.
template = """<s> [INST] You are an expert Google searcher, whose job is to determine if the following document is relevant to the query (true/false). Answer using only one word, one of those two choices.

Query: {query}
Document: {text}
Relevant (only output one word, either "true" or "false"): [/INST] """
19
+
20
def check_relevance(query, instruction, passage):
    """Return the model's probability that ``passage`` is relevant.

    The query and instruction are joined with a single space, inserted into
    the module-level ``template`` together with the passage, and run through
    the model once. The logits at the final sequence position for the
    "false" and "true" tokens are normalised against each other, and the
    share assigned to "true" is reported.

    Args:
        query: The search query text.
        instruction: Additional instruction text appended to the query.
        passage: The document text to judge.

    Returns:
        The probability of "true" formatted as a string with four decimals.
    """
    full_query = f"{query} {instruction}"
    prompt = template.format(query=full_query, text=passage)

    tokens = tokenizer(
        [prompt],
        padding=True,
        truncation=True,
        return_tensors="pt",
        pad_to_multiple_of=None,
    )

    # Send the inputs to whichever device the model lives on rather than
    # hard-coding CUDA here.
    device = model.device
    inputs = {name: tensor.to(device) for name, tensor in tokens.items()}

    # Inference only: without no_grad every request would record an autograd
    # graph for the whole forward pass and hold its activations in memory.
    with torch.no_grad():
        last_position_logits = model(**inputs).logits[:, -1, :]

    # Keep only the two answer tokens; column 0 is "false", column 1 is "true".
    answer_logits = torch.stack(
        [
            last_position_logits[:, token_false_id],
            last_position_logits[:, token_true_id],
        ],
        dim=1,
    )
    # Softmax over the pair equals exp(log_softmax(...)) used previously.
    score = torch.softmax(answer_logits, dim=1)[:, 1].item()

    return f"{score:.4f}"
43
+
44
# Gradio Interface
# Two-column layout: the inputs and the submit button on the left, the
# resulting probability on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# FollowIR Relevance Checker")
    gr.Markdown("This app uses the FollowIR-7B model to determine the relevance of a passage to a given query and instruction.")

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Query",
                placeholder="Enter your search query here",
            )
            instruction_input = gr.Textbox(
                label="Instruction",
                placeholder="Enter additional instructions or criteria",
            )
            passage_input = gr.Textbox(
                label="Passage",
                placeholder="Enter the passage to check for relevance",
                lines=5,
            )
            submit_button = gr.Button("Check Relevance")

        with gr.Column():
            output = gr.Textbox(label="Relevance Probability")

    # The click handler is registered inside the Blocks context; the three
    # textboxes are passed to check_relevance in this order.
    submit_button.click(
        check_relevance,
        inputs=[query_input, instruction_input, passage_input],
        outputs=[output],
    )

# Start the server only when this file is executed as a script, so that
# importing the module does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ torch
3
+ transformers