NCTCMumbai commited on
Commit
91fad79
·
verified ·
1 Parent(s): 7a71d68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +437 -437
app.py CHANGED
@@ -1,437 +1,437 @@
1
-
2
- from ragatouille import RAGPretrainedModel
3
- import subprocess
4
- import json
5
- import spaces
6
- import firebase_admin
7
- from firebase_admin import credentials, firestore
8
- import logging
9
- from pathlib import Path
10
- from time import perf_counter
11
- from datetime import datetime
12
- import gradio as gr
13
- from jinja2 import Environment, FileSystemLoader
14
- import numpy as np
15
- from sentence_transformers import CrossEncoder
16
- from huggingface_hub import InferenceClient
17
- from os import getenv
18
-
19
- from backend.query_llm import generate_hf, generate_openai
20
- from backend.semantic_search import table, retriever
21
- from huggingface_hub import InferenceClient
22
-
23
-
24
- VECTOR_COLUMN_NAME = "vector"
25
- TEXT_COLUMN_NAME = "text"
26
- HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
27
- proj_dir = Path(__file__).parent
28
- # Setting up the logging
29
- logging.basicConfig(level=logging.INFO)
30
- logger = logging.getLogger(__name__)
31
- client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1",token=HF_TOKEN)
32
- # Set up the template environment with the templates directory
33
- env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
34
-
35
- # Load the templates directly from the environment
36
- template = env.get_template('template.j2')
37
- template_html = env.get_template('template_html.j2')
38
- #___________________
39
- # service_account_key='firebase.json'
40
- # # Create a Certificate object from the service account info
41
- # cred = credentials.Certificate(service_account_key)
42
- # # Initialize the Firebase Admin
43
- # firebase_admin.initialize_app(cred)
44
-
45
- # # # Create a reference to the Firestore database
46
- # db = firestore.client()
47
- # #db usage
48
- # collection_name = 'Nirvachana' # Replace with your collection name
49
- # field_name = 'message_count' # Replace with your field name for count
50
- # Examples
51
- examples = ['Tabulate the difference between veins and arteries','What are defects in Human eye?',
52
- 'Frame 5 short questions and 5 MCQ on Chapter 2 ','Suggest creative and engaging ideas to teach students on Chapter on Metals and Non Metals '
53
- ]
54
-
55
-
56
-
57
- # def get_and_increment_value_count(db , collection_name, field_name):
58
- # """
59
- # Retrieves a value count from the specified Firestore collection and field,
60
- # increments it by 1, and updates the field with the new value."""
61
- # collection_ref = db.collection(collection_name)
62
- # doc_ref = collection_ref.document('count_doc') # Assuming a dedicated document for count
63
-
64
- # # Use a transaction to ensure consistency across reads and writes
65
- # try:
66
- # with db.transaction() as transaction:
67
- # # Get the current value count (or initialize to 0 if it doesn't exist)
68
- # current_count_doc = doc_ref.get()
69
- # current_count_data = current_count_doc.to_dict()
70
- # if current_count_data:
71
- # current_count = current_count_data.get(field_name, 0)
72
- # else:
73
- # current_count = 0
74
- # # Increment the count
75
- # new_count = current_count + 1
76
- # # Update the document with the new count
77
- # transaction.set(doc_ref, {field_name: new_count})
78
- # return new_count
79
- # except Exception as e:
80
- # print(f"Error retrieving and updating value count: {e}")
81
- # return None # Indicate error
82
-
83
- # def update_count_html():
84
- # usage_count = get_and_increment_value_count(db ,collection_name, field_name)
85
- # ccount_html = gr.HTML(value=f"""
86
- # <div style="display: flex; justify-content: flex-end;">
87
- # <span style="font-weight: bold; color: maroon; font-size: 18px;">No of Usages:</span>
88
- # <span style="font-weight: bold; color: maroon; font-size: 18px;">{usage_count}</span>
89
- # </div>
90
- # """)
91
- # return count_html
92
-
93
- # def store_message(db,query,answer,cross_encoder):
94
- # timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
95
- # # Create a new document reference with a dynamic document name based on timestamp
96
- # new_completion= db.collection('Nirvachana').document(f"chatlogs_{timestamp}")
97
- # new_completion.set({
98
- # 'query': query,
99
- # 'answer':answer,
100
- # 'created_time': firestore.SERVER_TIMESTAMP,
101
- # 'embedding': cross_encoder,
102
- # 'title': 'Expenditure observer bot'
103
- # })
104
-
105
-
106
- def add_text(history, text):
107
- history = [] if history is None else history
108
- history = history + [(text, None)]
109
- return history, gr.Textbox(value="", interactive=False)
110
-
111
-
112
- def bot(history, cross_encoder):
113
- top_rerank = 25
114
- top_k_rank = 20
115
- query = history[-1][0]
116
-
117
- if not query:
118
- gr.Warning("Please submit a non-empty string as a prompt")
119
- raise ValueError("Empty string was submitted")
120
-
121
- logger.warning('Retrieving documents...')
122
-
123
- # if COLBERT RAGATATOUILLE PROCEDURE :
124
- if cross_encoder=='(HIGH ACCURATE) ColBERT':
125
- gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
126
- RAG= RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
127
- RAG_db=RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
128
- documents_full=RAG_db.search(query,k=top_k_rank)
129
-
130
- documents=[item['content'] for item in documents_full]
131
- # Create Prompt
132
- prompt = template.render(documents=documents, query=query)
133
- prompt_html = template_html.render(documents=documents, query=query)
134
-
135
- generate_fn = generate_hf
136
-
137
- history[-1][1] = ""
138
- for character in generate_fn(prompt, history[:-1]):
139
- history[-1][1] = character
140
- yield history, prompt_html
141
- print('Final history is ',history)
142
- #store_message(db,history[-1][0],history[-1][1],cross_encoder)
143
- else:
144
- # Retrieve documents relevant to query
145
- document_start = perf_counter()
146
-
147
- query_vec = retriever.encode(query)
148
- logger.warning(f'Finished query vec')
149
- doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
150
-
151
-
152
-
153
- logger.warning(f'Finished search')
154
- documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
155
- documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
156
- logger.warning(f'start cross encoder {len(documents)}')
157
- # Retrieve documents relevant to query
158
- query_doc_pair = [[query, doc] for doc in documents]
159
- if cross_encoder=='(FAST) MiniLM-L6v2' :
160
- cross_encoder1 = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
161
- elif cross_encoder=='(ACCURATE) BGE reranker':
162
- cross_encoder1 = CrossEncoder('BAAI/bge-reranker-base')
163
-
164
- cross_scores = cross_encoder1.predict(query_doc_pair)
165
- sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
166
- logger.warning(f'Finished cross encoder {len(documents)}')
167
-
168
- documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
169
- logger.warning(f'num documents {len(documents)}')
170
-
171
- document_time = perf_counter() - document_start
172
- logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
173
-
174
- # Create Prompt
175
- prompt = template.render(documents=documents, query=query)
176
- prompt_html = template_html.render(documents=documents, query=query)
177
-
178
- generate_fn = generate_hf
179
-
180
- history[-1][1] = ""
181
- for character in generate_fn(prompt, history[:-1]):
182
- history[-1][1] = character
183
- yield history, prompt_html
184
- print('Final history is ',history)
185
- #store_message(db,history[-1][0],history[-1][1],cross_encoder)
186
-
187
- def system_instructions(question_difficulty, topic,documents_str):
188
- return f"""<s> [INST] Your are a great teacher and your task is to create 10 questions with 4 choices with a {question_difficulty} difficulty about topic request " {topic} " only from the below given documents, {documents_str} then create an answers. Index in JSON format, the questions as "Q#":"" to "Q#":"", the four choices as "Q#:C1":"" to "Q#:C4":"", and the answers as "A#":"Q#:C#" to "A#":"Q#:C#". [/INST]"""
189
-
190
-
191
- #with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
192
- with gr.Blocks(theme='NoCrypt/miku') as CHATBOT:
193
- with gr.Row():
194
- with gr.Column(scale=10):
195
- # gr.Markdown(
196
- # """
197
- # # Theme preview: `paris`
198
- # To use this theme, set `theme='earneleh/paris'` in `gr.Blocks()` or `gr.Interface()`.
199
- # You can append an `@` and a semantic version expression, e.g. @>=1.0.0,<2.0.0 to pin to a given version
200
- # of this theme.
201
- # """
202
- # )
203
- gr.HTML(value="""<div style="color: #FF4500;"><h1>CHEERFULL CBSE-</h1> <h1><span style="color: #008000">AI Assisted Fun Learning</span></h1>
204
- </div>""", elem_id='heading')
205
-
206
- gr.HTML(value=f"""
207
- <p style="font-family: sans-serif; font-size: 16px;">
208
- A free Artificial Intelligence Chatbot assistant trained on CBSE Class 10 Science Notes to engage and help students and teachers of Puducherry.
209
- </p>
210
- """, elem_id='Sub-heading')
211
- #usage_count = get_and_increment_value_count(db,collection_name, field_name)
212
- gr.HTML(value=f"""<p style="font-family: Arial, sans-serif; font-size: 14px;">Developed by K M Ramyasri , TGT,GHS.SUTHUKENY . Suggestions may be sent to <a href="mailto:ramyadevi1607@yahoo.com" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1 ')
213
-
214
- with gr.Column(scale=3):
215
- gr.Image(value='logo.png',height=200,width=200)
216
-
217
-
218
- # gr.HTML(value="""<div style="color: #FF4500;"><h1>CHEERFULL CBSE-</h1> <h1><span style="color: #008000">AI Assisted Fun Learning</span></h1>
219
- # <img src='logo.png' alt="Chatbot" width="50" height="50" />
220
- # </div>""", elem_id='heading')
221
-
222
- # gr.HTML(value=f"""
223
- # <p style="font-family: sans-serif; font-size: 16px;">
224
- # A free Artificial Intelligence Chatbot assistant trained on CBSE Class 10 Science Notes to engage and help students and teachers of Puducherry.
225
- # </p>
226
- # """, elem_id='Sub-heading')
227
- # #usage_count = get_and_increment_value_count(db,collection_name, field_name)
228
- # gr.HTML(value=f"""<p style="font-family: Arial, sans-serif; font-size: 16px;">Developed by K M Ramyasri , PGT . Suggestions may be sent to <a href="mailto:ramyadevi1607@yahoo.com" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1 ')
229
- # # count_html = gr.HTML(value=f"""
230
- # # <div style="display: flex; justify-content: flex-end;">
231
- # # <span style="font-weight: bold; color: maroon; font-size: 18px;">No of Usages:</span>
232
- # # <span style="font-weight: bold; color: maroon; font-size: 18px;">{usage_count}</span>
233
- # # </div>
234
- # # """)
235
-
236
- chatbot = gr.Chatbot(
237
- [],
238
- elem_id="chatbot",
239
- avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
240
- 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
241
- bubble_full_width=False,
242
- show_copy_button=True,
243
- show_share_button=True,
244
- )
245
-
246
- with gr.Row():
247
- txt = gr.Textbox(
248
- scale=3,
249
- show_label=False,
250
- placeholder="Enter text and press enter",
251
- container=False,
252
- )
253
- txt_btn = gr.Button(value="Submit text", scale=1)
254
-
255
- cross_encoder = gr.Radio(choices=['(FAST) MiniLM-L6v2','(ACCURATE) BGE reranker','(HIGH ACCURATE) ColBERT'], value='(ACCURATE) BGE reranker',label="Embeddings", info="Only First query to Colbert may take litte time)")
256
-
257
- prompt_html = gr.HTML()
258
- # Turn off interactivity while generating if you click
259
- txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
260
- bot, [chatbot, cross_encoder], [chatbot, prompt_html])#.then(update_count_html,[],[count_html])
261
-
262
- # Turn it back on
263
- txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
264
-
265
- # Turn off interactivity while generating if you hit enter
266
- txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
267
- bot, [chatbot, cross_encoder], [chatbot, prompt_html])#.then(update_count_html,[],[count_html])
268
-
269
- # Turn it back on
270
- txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
271
-
272
- # Examples
273
- gr.Examples(examples, txt)
274
-
275
-
276
- RAG_db=gr.State()
277
-
278
- with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as QUIZBOT:
279
- def load_model():
280
- RAG= RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
281
- RAG_db.value=RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
282
- return 'Ready to Go!!'
283
- with gr.Column(scale=4):
284
- gr.HTML("""
285
- <center>
286
- <h1><span style="color: purple;">AI NANBAN</span> - CBSE Class Quiz Maker</h1>
287
- <h2>AI-powered Learning Game</h2>
288
- <i>⚠️ Students create quiz from any topic /CBSE Chapter ! ⚠️</i>
289
- </center>
290
- """)
291
- #gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
292
- with gr.Column(scale=2):
293
- load_btn = gr.Button("Click to Load!🚀")
294
- load_text=gr.Textbox()
295
- load_btn.click(load_model,[],load_text)
296
-
297
-
298
- topic = gr.Textbox(label="Enter the Topic for Quiz", placeholder="Write any topic from CBSE notes")
299
-
300
- with gr.Row():
301
- radio = gr.Radio(
302
- ["easy", "average", "hard"], label="How difficult should the quiz be?"
303
- )
304
-
305
-
306
- generate_quiz_btn = gr.Button("Generate Quiz!🚀")
307
- quiz_msg=gr.Textbox()
308
-
309
- question_radios = [gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(
310
- visible=False), gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(
311
- visible=False), gr.Radio(visible=False), gr.Radio(visible=False)]
312
-
313
- print(question_radios)
314
-
315
- @spaces.GPU
316
- @generate_quiz_btn.click(inputs=[radio, topic], outputs=[quiz_msg]+question_radios, api_name="generate_quiz")
317
- def generate_quiz(question_difficulty, topic):
318
- top_k_rank=10
319
- RAG_db_=RAG_db.value
320
- documents_full=RAG_db_.search(topic,k=top_k_rank)
321
-
322
-
323
-
324
- generate_kwargs = dict(
325
- temperature=0.2,
326
- max_new_tokens=4000,
327
- top_p=0.95,
328
- repetition_penalty=1.0,
329
- do_sample=True,
330
- seed=42,
331
- )
332
- question_radio_list = []
333
- count=0
334
- while count<=3:
335
- try:
336
- documents=[item['content'] for item in documents_full]
337
- document_summaries = [f"[DOCUMENT {i+1}]: {summary}{count}" for i, summary in enumerate(documents)]
338
- documents_str='\n'.join(document_summaries)
339
- formatted_prompt = system_instructions(
340
- question_difficulty, topic,documents_str)
341
- print(formatted_prompt)
342
- pre_prompt = [
343
- {"role": "system", "content": formatted_prompt}
344
- ]
345
- response = client.text_generation(
346
- formatted_prompt, **generate_kwargs, stream=False, details=False, return_full_text=False,
347
- )
348
- output_json = json.loads(f"{response}")
349
-
350
-
351
- print(response)
352
- print('output json', output_json)
353
-
354
- global quiz_data
355
-
356
- quiz_data = output_json
357
-
358
-
359
-
360
- for question_num in range(1, 11):
361
- question_key = f"Q{question_num}"
362
- answer_key = f"A{question_num}"
363
-
364
- question = quiz_data.get(question_key)
365
- answer = quiz_data.get(quiz_data.get(answer_key))
366
-
367
- if not question or not answer:
368
- continue
369
-
370
- choice_keys = [f"{question_key}:C{i}" for i in range(1, 5)]
371
- choice_list = []
372
- for choice_key in choice_keys:
373
- choice = quiz_data.get(choice_key, "Choice not found")
374
- choice_list.append(f"{choice}")
375
-
376
- radio = gr.Radio(choices=choice_list, label=question,
377
- visible=True, interactive=True)
378
-
379
- question_radio_list.append(radio)
380
- if len(question_radio_list)==10:
381
- break
382
- else:
383
- print('10 questions not generated . So trying again!')
384
- count+=1
385
- continue
386
- except Exception as e:
387
- count+=1
388
- print(f"Exception occurred: {e}")
389
- if count==3:
390
- print('Retry exhausted')
391
- gr.Warning('Sorry. Pls try with another topic !')
392
- else:
393
- print(f"Trying again..{count} time...please wait")
394
- continue
395
-
396
- print('Question radio list ' , question_radio_list)
397
-
398
- return ['Quiz Generated!']+ question_radio_list
399
-
400
- check_button = gr.Button("Check Score")
401
-
402
- score_textbox = gr.Markdown()
403
-
404
- @check_button.click(inputs=question_radios, outputs=score_textbox)
405
- def compare_answers(*user_answers):
406
- user_anwser_list = []
407
- user_anwser_list = user_answers
408
-
409
- answers_list = []
410
-
411
- for question_num in range(1, 20):
412
- answer_key = f"A{question_num}"
413
- answer = quiz_data.get(quiz_data.get(answer_key))
414
- if not answer:
415
- break
416
- answers_list.append(answer)
417
-
418
- score = 0
419
-
420
- for item in user_anwser_list:
421
- if item in answers_list:
422
- score += 1
423
- if score>5:
424
- message = f"### Good ! You got {score} over 10!"
425
- elif score>7:
426
- message = f"### Excellent ! You got {score} over 10!"
427
- else:
428
- message = f"### You got {score} over 10! Dont worry . You can prepare well and try better next time !"
429
-
430
- return message
431
-
432
-
433
-
434
- demo = gr.TabbedInterface([CHATBOT,QUIZBOT], ["AI ChatBot", "AI Nanban-Quizbot"])
435
-
436
- demo.queue()
437
- demo.launch(debug=True)
 
1
+
2
+ from ragatouille import RAGPretrainedModel
3
+ import subprocess
4
+ import json
5
+ import spaces
6
+ import firebase_admin
7
+ from firebase_admin import credentials, firestore
8
+ import logging
9
+ from pathlib import Path
10
+ from time import perf_counter
11
+ from datetime import datetime
12
+ import gradio as gr
13
+ from jinja2 import Environment, FileSystemLoader
14
+ import numpy as np
15
+ from sentence_transformers import CrossEncoder
16
+ from huggingface_hub import InferenceClient
17
+ from os import getenv
18
+
19
+ from backend.query_llm import generate_hf, generate_openai
20
+ from backend.semantic_search import table, retriever
21
+ from huggingface_hub import InferenceClient
22
+
23
+
24
+ VECTOR_COLUMN_NAME = "vector"
25
+ TEXT_COLUMN_NAME = "text"
26
+ HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
27
+ proj_dir = Path(__file__).parent
28
+ # Setting up the logging
29
+ logging.basicConfig(level=logging.INFO)
30
+ logger = logging.getLogger(__name__)
31
+ client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1",token=HF_TOKEN)
32
+ # Set up the template environment with the templates directory
33
+ env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
34
+
35
+ # Load the templates directly from the environment
36
+ template = env.get_template('template.j2')
37
+ template_html = env.get_template('template_html.j2')
38
+ #___________________
39
+ # service_account_key='firebase.json'
40
+ # # Create a Certificate object from the service account info
41
+ # cred = credentials.Certificate(service_account_key)
42
+ # # Initialize the Firebase Admin
43
+ # firebase_admin.initialize_app(cred)
44
+
45
+ # # # Create a reference to the Firestore database
46
+ # db = firestore.client()
47
+ # #db usage
48
+ # collection_name = 'Nirvachana' # Replace with your collection name
49
+ # field_name = 'message_count' # Replace with your field name for count
50
+ # Examples
51
+ examples = ['Tabulate the difference between veins and arteries','What are defects in Human eye?',
52
+ 'Frame 5 short questions and 5 MCQ on Chapter 2 ','Suggest creative and engaging ideas to teach students on Chapter on Metals and Non Metals '
53
+ ]
54
+
55
+
56
+
57
+ # def get_and_increment_value_count(db , collection_name, field_name):
58
+ # """
59
+ # Retrieves a value count from the specified Firestore collection and field,
60
+ # increments it by 1, and updates the field with the new value."""
61
+ # collection_ref = db.collection(collection_name)
62
+ # doc_ref = collection_ref.document('count_doc') # Assuming a dedicated document for count
63
+
64
+ # # Use a transaction to ensure consistency across reads and writes
65
+ # try:
66
+ # with db.transaction() as transaction:
67
+ # # Get the current value count (or initialize to 0 if it doesn't exist)
68
+ # current_count_doc = doc_ref.get()
69
+ # current_count_data = current_count_doc.to_dict()
70
+ # if current_count_data:
71
+ # current_count = current_count_data.get(field_name, 0)
72
+ # else:
73
+ # current_count = 0
74
+ # # Increment the count
75
+ # new_count = current_count + 1
76
+ # # Update the document with the new count
77
+ # transaction.set(doc_ref, {field_name: new_count})
78
+ # return new_count
79
+ # except Exception as e:
80
+ # print(f"Error retrieving and updating value count: {e}")
81
+ # return None # Indicate error
82
+
83
+ # def update_count_html():
84
+ # usage_count = get_and_increment_value_count(db ,collection_name, field_name)
85
+ # ccount_html = gr.HTML(value=f"""
86
+ # <div style="display: flex; justify-content: flex-end;">
87
+ # <span style="font-weight: bold; color: maroon; font-size: 18px;">No of Usages:</span>
88
+ # <span style="font-weight: bold; color: maroon; font-size: 18px;">{usage_count}</span>
89
+ # </div>
90
+ # """)
91
+ # return count_html
92
+
93
+ # def store_message(db,query,answer,cross_encoder):
94
+ # timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
95
+ # # Create a new document reference with a dynamic document name based on timestamp
96
+ # new_completion= db.collection('Nirvachana').document(f"chatlogs_{timestamp}")
97
+ # new_completion.set({
98
+ # 'query': query,
99
+ # 'answer':answer,
100
+ # 'created_time': firestore.SERVER_TIMESTAMP,
101
+ # 'embedding': cross_encoder,
102
+ # 'title': 'Expenditure observer bot'
103
+ # })
104
+
105
+
106
+ def add_text(history, text):
107
+ history = [] if history is None else history
108
+ history = history + [(text, None)]
109
+ return history, gr.Textbox(value="", interactive=False)
110
+
111
+
112
+ def bot(history, cross_encoder):
113
+ top_rerank = 25
114
+ top_k_rank = 20
115
+ query = history[-1][0]
116
+
117
+ if not query:
118
+ gr.Warning("Please submit a non-empty string as a prompt")
119
+ raise ValueError("Empty string was submitted")
120
+
121
+ logger.warning('Retrieving documents...')
122
+
123
+ # if COLBERT RAGATATOUILLE PROCEDURE :
124
+ if cross_encoder=='(HIGH ACCURATE) ColBERT':
125
+ gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
126
+ RAG= RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
127
+ RAG_db=RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
128
+ documents_full=RAG_db.search(query,k=top_k_rank)
129
+
130
+ documents=[item['content'] for item in documents_full]
131
+ # Create Prompt
132
+ prompt = template.render(documents=documents, query=query)
133
+ prompt_html = template_html.render(documents=documents, query=query)
134
+
135
+ generate_fn = generate_hf
136
+
137
+ history[-1][1] = ""
138
+ for character in generate_fn(prompt, history[:-1]):
139
+ history[-1][1] = character
140
+ yield history, prompt_html
141
+ print('Final history is ',history)
142
+ #store_message(db,history[-1][0],history[-1][1],cross_encoder)
143
+ else:
144
+ # Retrieve documents relevant to query
145
+ document_start = perf_counter()
146
+
147
+ query_vec = retriever.encode(query)
148
+ logger.warning(f'Finished query vec')
149
+ doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
150
+
151
+
152
+
153
+ logger.warning(f'Finished search')
154
+ documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
155
+ documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
156
+ logger.warning(f'start cross encoder {len(documents)}')
157
+ # Retrieve documents relevant to query
158
+ query_doc_pair = [[query, doc] for doc in documents]
159
+ if cross_encoder=='(FAST) MiniLM-L6v2' :
160
+ cross_encoder1 = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
161
+ elif cross_encoder=='(ACCURATE) BGE reranker':
162
+ cross_encoder1 = CrossEncoder('BAAI/bge-reranker-base')
163
+
164
+ cross_scores = cross_encoder1.predict(query_doc_pair)
165
+ sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
166
+ logger.warning(f'Finished cross encoder {len(documents)}')
167
+
168
+ documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
169
+ logger.warning(f'num documents {len(documents)}')
170
+
171
+ document_time = perf_counter() - document_start
172
+ logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
173
+
174
+ # Create Prompt
175
+ prompt = template.render(documents=documents, query=query)
176
+ prompt_html = template_html.render(documents=documents, query=query)
177
+
178
+ generate_fn = generate_hf
179
+
180
+ history[-1][1] = ""
181
+ for character in generate_fn(prompt, history[:-1]):
182
+ history[-1][1] = character
183
+ yield history, prompt_html
184
+ print('Final history is ',history)
185
+ #store_message(db,history[-1][0],history[-1][1],cross_encoder)
186
+
187
+ def system_instructions(question_difficulty, topic,documents_str):
188
+ return f"""<s> [INST] Your are a great teacher and your task is to create 10 questions with 4 choices with a {question_difficulty} difficulty about topic request " {topic} " only from the below given documents, {documents_str} then create an answers. Index in JSON format, the questions as "Q#":"" to "Q#":"", the four choices as "Q#:C1":"" to "Q#:C4":"", and the answers as "A#":"Q#:C#" to "A#":"Q#:C#". [/INST]"""
189
+
190
+
191
+ #with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
192
+ with gr.Blocks(theme='NoCrypt/miku') as CHATBOT:
193
+ with gr.Row():
194
+ with gr.Column(scale=10):
195
+ # gr.Markdown(
196
+ # """
197
+ # # Theme preview: `paris`
198
+ # To use this theme, set `theme='earneleh/paris'` in `gr.Blocks()` or `gr.Interface()`.
199
+ # You can append an `@` and a semantic version expression, e.g. @>=1.0.0,<2.0.0 to pin to a given version
200
+ # of this theme.
201
+ # """
202
+ # )
203
+ gr.HTML(value="""<div style="color: #FF4500;"><h1>CHEERFULL CBSE-</h1> <h1><span style="color: #008000">AI Assisted Fun Learning</span></h1>
204
+ </div>""", elem_id='heading')
205
+
206
+ gr.HTML(value=f"""
207
+ <p style="font-family: sans-serif; font-size: 16px;">
208
+ A free Artificial Intelligence Chatbot assistant trained on CBSE Class 10 Science Notes to engage and help students and teachers of Puducherry.
209
+ </p>
210
+ """, elem_id='Sub-heading')
211
+ #usage_count = get_and_increment_value_count(db,collection_name, field_name)
212
+ gr.HTML(value=f"""<p style="font-family: Arial, sans-serif; font-size: 14px;">Developed by K M Ramyasri , TGT,GHS.SUTHUKENY . Suggestions may be sent to <a href="mailto:ramyadevi1607@yahoo.com" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1 ')
213
+
214
+ with gr.Column(scale=3):
215
+ gr.Image(value='logo.png',height=200,width=200)
216
+
217
+
218
+ # gr.HTML(value="""<div style="color: #FF4500;"><h1>CHEERFULL CBSE-</h1> <h1><span style="color: #008000">AI Assisted Fun Learning</span></h1>
219
+ # <img src='logo.png' alt="Chatbot" width="50" height="50" />
220
+ # </div>""", elem_id='heading')
221
+
222
+ # gr.HTML(value=f"""
223
+ # <p style="font-family: sans-serif; font-size: 16px;">
224
+ # A free Artificial Intelligence Chatbot assistant trained on CBSE Class 10 Science Notes to engage and help students and teachers of Puducherry.
225
+ # </p>
226
+ # """, elem_id='Sub-heading')
227
+ # #usage_count = get_and_increment_value_count(db,collection_name, field_name)
228
+ # gr.HTML(value=f"""<p style="font-family: Arial, sans-serif; font-size: 16px;">Developed by K M Ramyasri , PGT . Suggestions may be sent to <a href="mailto:ramyadevi1607@yahoo.com" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1 ')
229
+ # # count_html = gr.HTML(value=f"""
230
+ # # <div style="display: flex; justify-content: flex-end;">
231
+ # # <span style="font-weight: bold; color: maroon; font-size: 18px;">No of Usages:</span>
232
+ # # <span style="font-weight: bold; color: maroon; font-size: 18px;">{usage_count}</span>
233
+ # # </div>
234
+ # # """)
235
+
236
+ chatbot = gr.Chatbot(
237
+ [],
238
+ elem_id="chatbot",
239
+ avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
240
+ 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
241
+ bubble_full_width=False,
242
+ show_copy_button=True,
243
+ show_share_button=True,
244
+ )
245
+
246
+ with gr.Row():
247
+ txt = gr.Textbox(
248
+ scale=3,
249
+ show_label=False,
250
+ placeholder="Enter text and press enter",
251
+ container=False,
252
+ )
253
+ txt_btn = gr.Button(value="Submit text", scale=1)
254
+
255
+ cross_encoder = gr.Radio(choices=['(FAST) MiniLM-L6v2','(ACCURATE) BGE reranker','(HIGH ACCURATE) ColBERT'], value='(ACCURATE) BGE reranker',label="Embeddings", info="Only First query to Colbert may take litte time)")
256
+
257
+ prompt_html = gr.HTML()
258
+ # Turn off interactivity while generating if you click
259
+ txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
260
+ bot, [chatbot, cross_encoder], [chatbot, prompt_html])#.then(update_count_html,[],[count_html])
261
+
262
+ # Turn it back on
263
+ txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
264
+
265
+ # Turn off interactivity while generating if you hit enter
266
+ txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
267
+ bot, [chatbot, cross_encoder], [chatbot, prompt_html])#.then(update_count_html,[],[count_html])
268
+
269
+ # Turn it back on
270
+ txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
271
+
272
+ # Examples
273
+ gr.Examples(examples, txt)
274
+
275
+
276
+ RAG_db=gr.State()
277
+
278
+ with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as QUIZBOT:
279
+ def load_model():
280
+ RAG= RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
281
+ RAG_db.value=RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
282
+ return 'Ready to Go!!'
283
+ with gr.Column(scale=4):
284
+ gr.HTML("""
285
+ <center>
286
+ <h1><span style="color: purple;">ADWITIYA</span> Customs Manual Quizbot</h1>
287
+ <h2>Generative AI-powered Capacity building for Training Officers</h2>
288
+ <i>⚠️ NACIN Faculties create quiz from any topic dynamically for classroom evaluation after their sessions ! ⚠️</i>
289
+ </center>
290
+ """)
291
+ #gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
292
+ with gr.Column(scale=2):
293
+ load_btn = gr.Button("Click to Load!🚀")
294
+ load_text=gr.Textbox()
295
+ load_btn.click(load_model,[],load_text)
296
+
297
+
298
+ topic = gr.Textbox(label="Enter the Topic for Quiz", placeholder="Write any topic/details from Customs Manual")
299
+
300
+ with gr.Row():
301
+ radio = gr.Radio(
302
+ ["easy", "average", "hard"], label="How difficult should the quiz be?"
303
+ )
304
+
305
+
306
+ generate_quiz_btn = gr.Button("Generate Quiz!🚀")
307
+ quiz_msg=gr.Textbox()
308
+
309
+ question_radios = [gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(
310
+ visible=False), gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(visible=False), gr.Radio(
311
+ visible=False), gr.Radio(visible=False), gr.Radio(visible=False)]
312
+
313
+ print(question_radios)
314
+
315
+ @spaces.GPU
316
+ @generate_quiz_btn.click(inputs=[radio, topic], outputs=[quiz_msg]+question_radios, api_name="generate_quiz")
317
+ def generate_quiz(question_difficulty, topic):
318
+ top_k_rank=10
319
+ RAG_db_=RAG_db.value
320
+ documents_full=RAG_db_.search(topic,k=top_k_rank)
321
+
322
+
323
+
324
+ generate_kwargs = dict(
325
+ temperature=0.2,
326
+ max_new_tokens=4000,
327
+ top_p=0.95,
328
+ repetition_penalty=1.0,
329
+ do_sample=True,
330
+ seed=42,
331
+ )
332
+ question_radio_list = []
333
+ count=0
334
+ while count<=3:
335
+ try:
336
+ documents=[item['content'] for item in documents_full]
337
+ document_summaries = [f"[DOCUMENT {i+1}]: {summary}{count}" for i, summary in enumerate(documents)]
338
+ documents_str='\n'.join(document_summaries)
339
+ formatted_prompt = system_instructions(
340
+ question_difficulty, topic,documents_str)
341
+ print(formatted_prompt)
342
+ pre_prompt = [
343
+ {"role": "system", "content": formatted_prompt}
344
+ ]
345
+ response = client.text_generation(
346
+ formatted_prompt, **generate_kwargs, stream=False, details=False, return_full_text=False,
347
+ )
348
+ output_json = json.loads(f"{response}")
349
+
350
+
351
+ print(response)
352
+ print('output json', output_json)
353
+
354
+ global quiz_data
355
+
356
+ quiz_data = output_json
357
+
358
+
359
+
360
+ for question_num in range(1, 11):
361
+ question_key = f"Q{question_num}"
362
+ answer_key = f"A{question_num}"
363
+
364
+ question = quiz_data.get(question_key)
365
+ answer = quiz_data.get(quiz_data.get(answer_key))
366
+
367
+ if not question or not answer:
368
+ continue
369
+
370
+ choice_keys = [f"{question_key}:C{i}" for i in range(1, 5)]
371
+ choice_list = []
372
+ for choice_key in choice_keys:
373
+ choice = quiz_data.get(choice_key, "Choice not found")
374
+ choice_list.append(f"{choice}")
375
+
376
+ radio = gr.Radio(choices=choice_list, label=question,
377
+ visible=True, interactive=True)
378
+
379
+ question_radio_list.append(radio)
380
+ if len(question_radio_list)==10:
381
+ break
382
+ else:
383
+ print('10 questions not generated . So trying again!')
384
+ count+=1
385
+ continue
386
+ except Exception as e:
387
+ count+=1
388
+ print(f"Exception occurred: {e}")
389
+ if count==3:
390
+ print('Retry exhausted')
391
+ gr.Warning('Sorry. Pls try with another topic !')
392
+ else:
393
+ print(f"Trying again..{count} time...please wait")
394
+ continue
395
+
396
+ print('Question radio list ' , question_radio_list)
397
+
398
+ return ['Quiz Generated!']+ question_radio_list
399
+
400
+ check_button = gr.Button("Check Score")
401
+
402
+ score_textbox = gr.Markdown()
403
+
404
+ @check_button.click(inputs=question_radios, outputs=score_textbox)
405
+ def compare_answers(*user_answers):
406
+ user_anwser_list = []
407
+ user_anwser_list = user_answers
408
+
409
+ answers_list = []
410
+
411
+ for question_num in range(1, 20):
412
+ answer_key = f"A{question_num}"
413
+ answer = quiz_data.get(quiz_data.get(answer_key))
414
+ if not answer:
415
+ break
416
+ answers_list.append(answer)
417
+
418
+ score = 0
419
+
420
+ for item in user_anwser_list:
421
+ if item in answers_list:
422
+ score += 1
423
+ if score>5:
424
+ message = f"### Good ! You got {score} over 10!"
425
+ elif score>7:
426
+ message = f"### Excellent ! You got {score} over 10!"
427
+ else:
428
+ message = f"### You got {score} over 10! Dont worry . You can prepare well and try better next time !"
429
+
430
+ return message
431
+
432
+
433
+
434
+ demo = gr.TabbedInterface([CHATBOT,QUIZBOT], ["AI ChatBot", "AI Quizbot"])
435
+
436
+ demo.queue()
437
+ demo.launch(debug=True)