Upload attribution.py

#2
by jshenoy - opened
Files changed (1)
  1. attribution.py +145 -0
attribution.py ADDED
@@ -0,0 +1,145 @@
+ import os
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+
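+ # Dependencies implied by the imports and calls below (package names assumed,
+ # versions unpinned): gradio, spaces, torch, transformers, datasets,
+ # sentence-transformers, plus faiss-cpu or faiss-gpu for add_faiss_index,
+ # bitsandbytes for 4-bit loading, and accelerate for device_map="auto".
+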
+ token = os.environ["HF_TOKEN"]
+ ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+
+ dataset = load_dataset("AI-4-Health/embedded-dataset")
+
+ data = dataset["train"]
+ data = data.add_faiss_index("embeddings")  # column that holds the precomputed embeddings
+
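+ # The "embeddings" column is assumed to be precomputed with the same
+ # mxbai-embed-large-v1 encoder used for queries below. A minimal, illustrative
+ # sketch of how such a dataset could be built (file paths are assumptions):
+ #   docs = load_dataset("text", data_files="datasets/*.txt")["train"]
+ #   docs = docs.map(lambda row: {"embeddings": ST.encode(row["text"])})
+ #   docs.push_to_hub("AI-4-Health/embedded-dataset")
+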
+ model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+ # Use 4-bit quantization to lower GPU memory usage
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     quantization_config=bnb_config,
+     token=token,
+ )
+ terminators = [
+     tokenizer.eos_token_id,
+     tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+ ]
+
+ SYS_PROMPT = """You are an assistant for answering questions.
+ You are given the extracted parts of a long document and a question. Provide a conversational answer.
+ If you don't know the answer, just say "I do not know." Don't make up an answer."""
+
+ def search(query: str, k: int = 3):
+     """Embed a new query and return the top-k nearest dataset entries."""
+     embedded_query = ST.encode(query)  # embed the new query
+     scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
+         "embeddings", embedded_query,  # compare the embedded query with the dataset embeddings
+         k=k,  # keep only the top k results
+     )
+     return scores, retrieved_examples
+
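+ # Example (illustrative): get_nearest_examples returns a 1-D array of FAISS
+ # scores and a dict of columns, e.g.
+ #   scores, docs = search("What is HPP?", k=3)
+ #   docs["text"][0]      # best-matching chunk
+ #   docs["filename"][0]  # source file of that chunk
+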
+ def format_prompt(prompt, retrieved_documents, k):
+     """Build the user prompt from the question plus the k retrieved documents."""
+     PROMPT = f"Question:{prompt}\nContext:"
+     for idx in range(k):
+         PROMPT += f"{retrieved_documents['text'][idx]}\n"
+     return PROMPT
+
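+ # Example (illustrative): for a prompt "What is HPP?" and one retrieved chunk,
+ # format_prompt returns a string of the form:
+ #   Question:What is HPP?
+ #   Context:<text of the retrieved chunk>
+
+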
+ TITLE = "# RAG"
+
+ DESCRIPTION = """
+ HPP Chatbot
+ """
+
+
+ @spaces.GPU(duration=150)
+ def talk(prompt):
+     k = 1  # number of retrieved documents
+     scores, retrieved_documents = search(prompt, k)
+     filename = retrieved_documents["filename"][0]  # source file of the top-ranked chunk
+     print("filename is ", filename)
+     formatted_prompt = format_prompt(prompt, retrieved_documents, k)
+     formatted_prompt = formatted_prompt[:2000]  # truncate to avoid GPU OOM
+     messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
+     input_ids = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         return_tensors="pt",
+     ).to(model.device)
+     # Generate in a background thread and consume tokens from the streamer as they arrive
+     streamer = TextIteratorStreamer(
+         tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         temperature=0.75,
+         eos_token_id=terminators,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+     return "".join(outputs), filename
+
+ def update_document(filename):
+     """Read the content of the given source file for display."""
+     with open("datasets/" + filename, "r", encoding="iso-8859-15") as file:
+         content = file.read()
+     return content
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         prompt_input = gr.Textbox(label="Enter your prompt")
+         submit_button = gr.Button("Submit")
+     chat_output = gr.Textbox(label="Chat Response", lines=5)
+     filename = gr.Textbox(label="File Name", lines=1)
+     file_display = gr.Textbox(label="File Content", lines=10)
+
+     submit_button.click(
+         fn=talk,
+         inputs=prompt_input,
+         outputs=[chat_output, filename],
+     )
+
+     # When the file name changes, load that file's content into the display box
+     filename.change(
+         fn=update_document,
+         inputs=filename,
+         outputs=file_display,
+     )
+
+ demo.launch(debug=True, share=True)