harisankar99 committed on
Commit 374a0ea
1 Parent(s): 36d181b

Added articles and app

app.py CHANGED
@@ -1,63 +1,212 @@
-import gradio as gr
 from huggingface_hub import InferenceClient
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
 
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
 
-    messages.append({"role": "user", "content": message})
 
-    response = ""
 
     for message in client.chat_completion(
         messages,
-        max_tokens=max_tokens,
         stream=True,
-        temperature=temperature,
-        top_p=top_p,
     ):
         token = message.choices[0].delta.content
 
         response += token
         yield response
 
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 """
 demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
 )
-
-
 if __name__ == "__main__":
-    demo.launch()
+# import gradio as gr
 from huggingface_hub import InferenceClient
 
+# """
+# For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+# """
+# client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+# def respond(
+#     message,
+#     history: list[tuple[str, str]],
+#     system_message,
+#     max_tokens,
+#     temperature,
+#     top_p,
+# ):
+#     messages = [{"role": "system", "content": system_message}]
+
+#     for val in history:
+#         if val[0]:
+#             messages.append({"role": "user", "content": val[0]})
+#         if val[1]:
+#             messages.append({"role": "assistant", "content": val[1]})
+
+#     messages.append({"role": "user", "content": message})
+
+#     response = ""
+
+#     for message in client.chat_completion(
+#         messages,
+#         max_tokens=max_tokens,
+#         stream=True,
+#         temperature=temperature,
+#         top_p=top_p,
+#     ):
+#         token = message.choices[0].delta.content
+
+#         response += token
+#         yield response
+
+# """
+# For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+# """
+# demo = gr.ChatInterface(
+#     respond,
+#     additional_inputs=[
+#         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+#         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+#         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+#         gr.Slider(
+#             minimum=0.1,
+#             maximum=1.0,
+#             value=0.95,
+#             step=0.05,
+#             label="Top-p (nucleus sampling)",
+#         ),
+#     ],
+# )
+
+
+# if __name__ == "__main__":
+#     demo.launch()
+import gradio as gr
+import os
+import spaces
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
+import torch
+from threading import Thread
+from datasets import load_from_disk
+import time
+from sentence_transformers import SentenceTransformer
+ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+dataset = load_from_disk('./articles_embedded')
+
+data = dataset
+data = data.add_faiss_index("embeddings")  # column name that has the embeddings of the dataset
+
+model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
+client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
+# model_id = r"D:\Meta-Llama-3-8B-Instruct"
+
+# use quantization to lower GPU usage
+# bnb_config = BitsAndBytesConfig(
+#     load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
+# )
+# tokenizer = AutoTokenizer.from_pretrained(model_id)
+# model = AutoModelForCausalLM.from_pretrained(
+#     model_id,
+#     torch_dtype=torch.bfloat16,
+#     device_map="auto",
+#     quantization_config=bnb_config,
+# )
+# model = AutoModelForCausalLM.from_pretrained(
+#     "microsoft/Phi-3-mini-4k-instruct",
+#     device_map="cuda",
+#     torch_dtype="auto",
+#     trust_remote_code=True,
+# )
+# tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
+# terminators = [
+#     tokenizer.eos_token_id,
+#     tokenizer.convert_tokens_to_ids("<|eot_id|>")
+# ]
 
+SYS_PROMPT = """You are a battery assistant named EVolve, made by the company Lime.ai in Bangalore, for answering questions only about batteries and the EV industry.
+You are given the extracted parts of a long document and a question. Provide a conversational answer.
+If the context is more than 50 percent related to the question, give any doi or reference if prompted; otherwise do not give one even if prompted.
+If you don't know the answer, say "I don't know." Try to answer out of context only if you are more than 50 percent confident; otherwise say "I don't know." Do not mention the context you have been given. Answer as if you already know the context and are not reading from it. Don't make up an answer."""  # just say "I do not know." Don't make up an answer.
+
+
+def search(query: str, k: int = 3):
+    """Embed a new query and return the most similar documents from the index."""
+    embedded_query = ST.encode(query)  # embed the new query
+    scores, retrieved_examples = data.get_nearest_examples(  # retrieve results
+        "embeddings", embedded_query,  # compare the embedded query with the dataset embeddings
+        k=k  # get only the top k results
+    )
+    return scores, retrieved_examples
+
+
+def format_prompt(prompt, retrieved_documents, k):
+    """Use the retrieved documents to build the prompt sent to the model."""
+    PROMPT = f"Question:{prompt}\nTell me the reference and doi from where you have taken the answer if it is available.\nContext:"
+    for idx in range(k):
+        PROMPT += (
+            "Reference: " + str(retrieved_documents["title"][idx])
+            + "\n doi: " + str(retrieved_documents["doi"][idx])
+            + "\n Authors:" + str(retrieved_documents["author"][idx])
+            + "\n Page Number:" + str(retrieved_documents["pages"][idx])
+            + "\n Content: " + str(retrieved_documents["text"][idx]) + "\n"
+        )
+    return PROMPT
+
+
+@spaces.GPU(duration=150)
+def talk(prompt, history):
+    k = 1  # number of retrieved documents
+    scores, retrieved_documents = search(prompt, k)
+    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
+    formatted_prompt = formatted_prompt[:2000]  # truncate to avoid GPU OOM
+    messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
+    response = ""  # accumulator for the streamed answer
+    # tell the model to generate
+    # input_ids = tokenizer.apply_chat_template(
+    #     messages,
+    #     add_generation_prompt=True,
+    #     return_tensors="pt"
+    # ).to(model.device)
     for message in client.chat_completion(
         messages,
+        max_tokens=1024,
         stream=True,
+        temperature=0.6,
+        top_p=0.9,
     ):
         token = message.choices[0].delta.content
 
         response += token
         yield response
+    # outputs = model.generate(
+    #     input_ids,
+    #     max_new_tokens=1024,
+    #     eos_token_id=terminators,
+    #     do_sample=True,
+    #     temperature=0.6,
+    #     top_p=0.9,
+    # )
+    # streamer = TextIteratorStreamer(
+    #     tokenizer, timeout=10000.0, skip_prompt=True, skip_special_tokens=True
+    # )
+    # generate_kwargs = dict(
+    #     input_ids=input_ids,
+    #     streamer=streamer,
+    #     max_new_tokens=1024,
+    #     do_sample=True,
+    #     top_p=0.95,
+    #     temperature=0.75,
+    #     eos_token_id=terminators,
+    # )
+    # t = Thread(target=model.generate, kwargs=generate_kwargs)
+    # t.start()
 
+    # outputs = []
+    # for text in streamer:
+    #     outputs.append(text)
+    #     print(outputs)
+    #     yield "".join(outputs)
+
+
+TITLE = "EVolve AI"
+
+DESCRIPTION = """
+This is a project by Lime.ai
+
+Resources used to build this project:
+
+* embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
+* faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
+* chatbot : https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
 """
 demo = gr.ChatInterface(
+    fn=talk,
+    chatbot=gr.Chatbot(
+        show_label=True,
+        show_share_button=True,
+        show_copy_button=True,
+        likeable=True,
+        layout="bubble",
+        bubble_full_width=False,
+    ),
+    theme="Soft",
+    examples=[
+        ["What are the reasons of capacity fade due to LAM and LLI?"],
+        ["How many cycles do Li-air batteries last before degradation?"],
+        ["What are different types of battery chemistries?"],
+    ],
+    title=TITLE,
+    description=DESCRIPTION,
 )
 
 if __name__ == "__main__":
+    demo.launch(debug=True, share=True)
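As a quick sanity check (not part of the commit), the retrieval step of the new app.py can be exercised outside the Gradio UI. This is a minimal sketch under the assumption that the repository dependencies are installed and the ./articles_embedded folder is on disk; importing app runs its module-level setup (embedding model, dataset, FAISS index) but does not launch the interface, since demo.launch() is guarded by __main__. The example question is made up for illustration.

# Illustrative only: run from the repo root with the dependencies installed
# and the ./articles_embedded dataset present.
from app import search, format_prompt  # import triggers dataset loading and FAISS indexing

question = "What causes capacity fade in Li-ion cells?"  # hypothetical example query
k = 1
scores, retrieved_documents = search(question, k)
print(scores)                               # FAISS distances for the top-k matches
print(retrieved_documents["title"][:k])      # titles of the retrieved article chunks
print(format_prompt(question, retrieved_documents, k)[:500])  # what the model actually sees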
articles_embedded/data-00000-of-00001.arrow ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5213cd83fb8bab5ddcb64c85b9c7e7b66049a15d69a96d6ef92a64ed70b82267
+size 5332744
articles_embedded/dataset_info.json ADDED
@@ -0,0 +1,35 @@
+{
+  "citation": "",
+  "description": "",
+  "features": {
+    "doi": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "title": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "author": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "pages": {
+      "dtype": "int64",
+      "_type": "Value"
+    },
+    "text": {
+      "dtype": "string",
+      "_type": "Value"
+    },
+    "embeddings": {
+      "feature": {
+        "dtype": "float32",
+        "_type": "Value"
+      },
+      "_type": "Sequence"
+    }
+  },
+  "homepage": "",
+  "license": ""
+}
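dataset_info.json above describes the schema of the articles_embedded dump (doi, title, author, pages and text strings plus an embeddings sequence of float32), which app.py reads back with load_from_disk. The actual preprocessing pipeline is not part of this commit; the following is a hypothetical sketch of how such a dump could be produced with the same mixedbread-ai/mxbai-embed-large-v1 encoder used in app.py.

# Hypothetical preprocessing sketch; the real pipeline behind this commit is not shown in the diff.
from datasets import Dataset
from sentence_transformers import SentenceTransformer

ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Assumed inputs: one entry per article chunk, matching the schema in dataset_info.json.
records = {
    "doi": ["10.0000/example-doi"],
    "title": ["Example article title"],
    "author": ["Example Author"],
    "pages": [1],
    "text": ["Example chunk of article text about battery degradation."],
}

dataset = Dataset.from_dict(records)
# Add an "embeddings" column computed from the text column.
dataset = dataset.map(lambda row: {"embeddings": ST.encode(row["text"])})
dataset.save_to_disk("./articles_embedded")  # what app.py later loads and indexes with FAISS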
articles_embedded/state.json ADDED
@@ -0,0 +1,13 @@
+{
+  "_data_files": [
+    {
+      "filename": "data-00000-of-00001.arrow"
+    }
+  ],
+  "_fingerprint": "c280083923fa8c13",
+  "_format_columns": null,
+  "_format_kwargs": {},
+  "_format_type": null,
+  "_output_all_columns": false,
+  "_split": null
+}