|
|
|
from ragatouille import RAGPretrainedModel |
|
import subprocess |
|
import json |
|
import spaces |
|
import firebase_admin |
|
from firebase_admin import credentials, firestore |
|
import logging |
|
from pathlib import Path |
|
from time import perf_counter |
|
from datetime import datetime |
|
import gradio as gr |
|
from jinja2 import Environment, FileSystemLoader |
|
import numpy as np |
|
from sentence_transformers import CrossEncoder |
|
from huggingface_hub import InferenceClient |
|
from os import getenv |
|
|
|
from backend.query_llm import generate_hf, generate_openai |
|
from backend.semantic_search import table, retriever |
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
# Column names used when querying the vector `table` imported from
# backend.semantic_search (read by `bot`).
VECTOR_COLUMN_NAME, TEXT_COLUMN_NAME = "vector", "text"

# Hugging Face token from the environment; None when the variable is unset.
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")

# Directory holding this script; the Jinja templates are loaded relative to it.
proj_dir = Path(__file__).parent

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Remote text-generation client (used by the quiz generator).
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=HF_TOKEN,
)

# `template` renders the LLM prompt; `template_html` renders the HTML view of
# the same documents/query that is shown next to the chat.
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
template, template_html = (
    env.get_template(name) for name in ('template.j2', 'template_html.j2')
)

# Sample questions offered under the chat textbox.
examples = [
    'My transhipment cargo is missing',
    'can u explain and tabulate difference between b 17 bond and a warehousing bond',
    'What are benefits of the AEO Scheme and eligibility criteria?',
    'What are penalties for customs offences? ',
    'what are penalties to customs officers misusing their powers under customs act?',
    'What are eligibility criteria for exemption from cost recovery charges',
    'list in detail what is procedure for obtaining new approval for openeing a CFS attached to an ICD',
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_text(history, text):
    """Append the user's message to the chat history and lock the textbox.

    Args:
        history: Current chatbot history (a list of [user, bot] pairs) or None.
        text: Message typed by the user.

    Returns:
        A tuple of (new history, Textbox update). The history ends with a
        pending ``[text, None]`` pair. The update clears the textbox and makes
        it non-interactive.
    """
    history = list(history or [])
    # Use a mutable [user, bot] pair rather than a tuple: the reply is later
    # written in place via history[-1][1].
    history.append([text, None])
    return history, gr.Textbox(value="", interactive=False)
|
|
|
|
|
# Radio label that routes retrieval through the local ColBERT index.
_COLBERT_CHOICE = '(HIGH ACCURATE) ColBERT'
# RAGatouille index searched when the ColBERT option is selected.
_COLBERT_INDEX_PATH = '.ragatouille/colbert/indexes/cbseclass10index'
# The remaining radio labels, mapped to the cross-encoder checkpoint used to
# rerank vector-search hits.
_CROSS_ENCODER_MODELS = {
    '(FAST) MiniLM-L6v2': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    '(ACCURATE) BGE reranker': 'BAAI/bge-reranker-base',
}
# Models loaded by `bot`, kept for the life of the process so each checkpoint
# is loaded once instead of on every chat message.
_bot_model_cache = {}


def _get_colbert_index():
    """Return the ColBERT model bound to the local index, loading it on first use."""
    rag = _bot_model_cache.get(_COLBERT_CHOICE)
    if rag is None:
        gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
        # `from_index` is a classmethod that loads the encoder stored with the
        # index, so a preceding `from_pretrained` call would only load it twice.
        rag = RAGPretrainedModel.from_index(_COLBERT_INDEX_PATH)
        _bot_model_cache[_COLBERT_CHOICE] = rag
    return rag


def _get_cross_encoder(choice):
    """Return the cached CrossEncoder for a reranker radio label.

    Raises:
        ValueError: if *choice* is not one of the known reranker labels.
    """
    model_name = _CROSS_ENCODER_MODELS.get(choice)
    if model_name is None:
        gr.Warning(f'Unknown reranker option: {choice}')
        raise ValueError(f'Unknown reranker option: {choice!r}')
    encoder = _bot_model_cache.get(model_name)
    if encoder is None:
        encoder = CrossEncoder(model_name)
        _bot_model_cache[model_name] = encoder
    return encoder


def _retrieve_colbert(query, top_k):
    """Return the text of the top *top_k* ColBERT hits for *query*."""
    hits = _get_colbert_index().search(query, k=top_k)
    return [hit['content'] for hit in hits]


def _retrieve_reranked(query, choice, top_rerank, top_k):
    """Fetch *top_rerank* vector-search candidates and keep the *top_k* best after reranking."""
    encoder = _get_cross_encoder(choice)  # fail fast on an unknown option
    query_vec = retriever.encode(query)
    logger.info('Finished query vec')
    hits = (
        table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
        .limit(top_rerank)
        .to_list()
    )
    documents = [hit[TEXT_COLUMN_NAME] for hit in hits]
    logger.info('Finished search; reranking %d documents', len(documents))
    if not documents:
        return []
    scores = encoder.predict([[query, doc] for doc in documents])
    # Candidate indices ordered by descending cross-encoder score.
    best_first = np.argsort(scores)[::-1]
    return [documents[idx] for idx in best_first[:top_k]]


def bot(history, cross_encoder):
    """Answer the latest chat message with retrieval-augmented generation.

    Args:
        history: Chatbot history whose last entry is ``[user_message, None]``.
        cross_encoder: Label selected in the retrieval radio. Either the
            ColBERT option or one of the cross-encoder reranker options.

    Yields:
        ``(history, prompt_html)`` for each value produced by ``generate_hf``.
        Each value replaces the reply text in ``history[-1][1]``; presumably it
        is the cumulative answer so far.

    Raises:
        ValueError: on an empty history or query, or an unknown retrieval option.
    """
    top_rerank = 25  # candidates fetched from vector search before reranking
    top_k_rank = 20  # documents passed to the prompt

    if not history:
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty chat history was submitted")
    # Ensure the last pair is mutable: the reply is written into history[-1][1].
    history[-1] = list(history[-1])
    query = history[-1][0]
    if not query or not query.strip():
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty string was submitted")

    logger.info('Retrieving documents...')
    document_start = perf_counter()
    if cross_encoder == _COLBERT_CHOICE:
        documents = _retrieve_colbert(query, top_k_rank)
    else:
        documents = _retrieve_reranked(query, cross_encoder, top_rerank, top_k_rank)
    document_time = perf_counter() - document_start
    logger.info('Finished retrieving %d documents in %.2f seconds', len(documents), document_time)

    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    history[-1][1] = ""
    for partial in generate_hf(prompt, history[:-1]):
        history[-1][1] = partial
        yield history, prompt_html
    logger.info('Final history is %s', history)
|
|
|
|
|
def system_instructions(question_difficulty, topic, documents_str):
    """Build the Mixtral ``[INST]`` prompt that asks for a JSON multiple-choice quiz.

    Args:
        question_difficulty: Difficulty word inserted into the request.
        topic: Quiz topic, quoted in the request.
        documents_str: Retrieved source passages the questions must come from.

    Returns:
        The full prompt string, including the answer-format instructions:
        "Q#" keys for questions, "Q#:C1".."Q#:C4" for choices, and "A#" pointing
        at the correct choice key.
    """
    task = (
        "Your are a great teacher and your task is to create 10 questions "
        f"with 4 choices with a {question_difficulty} difficulty about topic "
        f'request " {topic} " only from the below given documents, '
        f"{documents_str} then create an answers. "
    )
    schema = (
        'Index in JSON format, the questions as "Q#":"" to "Q#":"", '
        'the four choices as "Q#:C1":"" to "Q#:C4":"", '
        'and the answers as "A#":"Q#:C#" to "A#":"Q#:C#".'
    )
    return "<s> [INST] " + task + schema + " [/INST]"
|
|
|
|
|
|
|
with gr.Blocks(theme='NoCrypt/miku') as CHATBOT:
    # Header row: title, description and contact line beside the logo.
    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(value="""<div style="color: #FF4500;"><h1>ADWITIYA-</h1> <h1><span style="color: #008000">Custom Manual Chatbot and Quizbot</span></h1>
</div>""", elem_id='heading')

            gr.HTML(value="""
<p style="font-family: sans-serif; font-size: 16px;">
Using GenAI for CBIC Capacity Building - A free chat bot developed by National Customs Targeting Center using Open source LLMs for CBIC Officers
</p>
""", elem_id='Sub-heading')

            # NOTE(review): the mailto target (nctc-admin@gov.in) differs from the
            # address shown as the link text (ramyadevi1607@yahoo.com); confirm
            # which address should receive suggestions.
            gr.HTML(value="""<p style="font-family: Arial, sans-serif; font-size: 14px;">Developed by NCTC,Mumbai . Suggestions may be sent to <a href="mailto:nctc-admin@gov.in" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1')

        with gr.Column(scale=3):
            # Resolve the logo relative to this script, as the templates are,
            # so the app does not depend on the current working directory.
            gr.Image(value=str(proj_dir / 'logo.png'), height=200, width=200)

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)

    # Retrieval strategy; `bot` matches these label strings exactly.
    cross_encoder = gr.Radio(
        choices=['(FAST) MiniLM-L6v2', '(ACCURATE) BGE reranker', '(HIGH ACCURATE) ColBERT'],
        value='(ACCURATE) BGE reranker',
        label="Embeddings",
        info="Only the first query to ColBERT may take a little time",
    )

    # Receives the HTML rendering of the documents/query prompt from `bot`.
    prompt_html = gr.HTML()

    # Button click and Enter run the same chain: append the message and lock
    # the textbox, stream the answer, then unlock the textbox.
    for trigger in (txt_btn.click, txt.submit):
        txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, [chatbot, cross_encoder], [chatbot, prompt_html])
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    gr.Examples(examples, txt)
|
|
|
|
|
# Number of multiple-choice questions requested from the model and shown in the UI.
NUM_QUIZ_QUESTIONS = 10
# Maximum LLM calls per "Generate Quiz" click.
_QUIZ_MAX_ATTEMPTS = 4
# Local RAGatouille/ColBERT index the quiz questions are grounded in.
_QUIZ_INDEX_PATH = '.ragatouille/colbert/indexes/cbseclass10index'

# Holder for the loaded ColBERT model. It is never wired to an event, so its
# `.value` acts as a process-wide cache shared by every browser session.
RAG_db = gr.State()


def _get_quiz_rag():
    """Return the ColBERT model for the quiz index, loading it on first use."""
    if RAG_db.value is None:
        # `from_index` is a classmethod that loads the encoder stored with the
        # index, so a separate `from_pretrained` call is unnecessary.
        RAG_db.value = RAGPretrainedModel.from_index(_QUIZ_INDEX_PATH)
    return RAG_db.value


def _parse_quiz_json(text):
    """Decode the JSON object in an LLM reply, ignoring any text around it.

    Raises:
        ValueError: if no ``{...}`` span exists or it is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start, end = text.find('{'), text.rfind('}')
    if start == -1 or end < start:
        raise ValueError('No JSON object found in the model output')
    return json.loads(text[start:end + 1])


def _extract_quiz(quiz_data):
    """Collect (question, choices, answer) triples from the model's flat quiz dict.

    Reads keys "Q#", "Q#:C1".."Q#:C4" and "A#". The value of "A#" is the key
    of the correct choice, e.g. "Q3:C2". Questions without text or a
    resolvable answer are skipped. Choices and answers are returned as
    strings so they compare equal to the values the Radio components send back.
    """
    items = []
    for num in range(1, NUM_QUIZ_QUESTIONS + 1):
        question = quiz_data.get(f"Q{num}")
        answer_ref = quiz_data.get(f"A{num}")
        # A non-string reference (e.g. a list) cannot be used as a dict key.
        answer = quiz_data.get(answer_ref) if isinstance(answer_ref, str) else None
        if not question or not answer:
            continue
        choices = [str(quiz_data.get(f"Q{num}:C{i}", "Choice not found")) for i in range(1, 5)]
        items.append((question, choices, str(answer)))
    return items


def _hidden_question_radios():
    """Return one update per question radio that hides it."""
    return [gr.Radio(visible=False) for _ in range(NUM_QUIZ_QUESTIONS)]


with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as QUIZBOT:
    def load_model():
        """Load the ColBERT quiz index ahead of time so the first quiz is faster."""
        _get_quiz_rag()
        return 'Ready to Go!!'

    with gr.Column(scale=4):
        gr.HTML("""
<center>
<h1><span style="color: purple;">ADWITIYA</span> Customs Manual Quizbot</h1>
<h2>Generative AI-powered Capacity building for Training Officers</h2>
<i>⚠️ NACIN Faculties create quiz from any topic dynamically for classroom evaluation after their sessions ! ⚠️</i>
</center>
""")

    with gr.Column(scale=2):
        # NOTE(review): the trailing "π" in the button labels looks like a
        # mis-encoded emoji; confirm the intended character.
        load_btn = gr.Button("Click to Load!π")
        load_text = gr.Textbox()
        load_btn.click(load_model, [], load_text)

    topic = gr.Textbox(label="Enter the Topic for Quiz", placeholder="Write any topic/details from Customs Manual")

    with gr.Row():
        radio = gr.Radio(
            ["easy", "average", "hard"], label="How difficult should the quiz be?"
        )

    generate_quiz_btn = gr.Button("Generate Quiz!π")
    quiz_msg = gr.Textbox()

    # Hidden until a quiz is generated; radio i shows question i+1.
    question_radios = [gr.Radio(visible=False) for _ in range(NUM_QUIZ_QUESTIONS)]

    # Per-session answer key (correct choice text per question), written by
    # generate_quiz and read by compare_answers, so concurrent users do not
    # overwrite each other's quiz.
    quiz_answers = gr.State()

    # NOTE(review): `spaces.GPU` wraps whatever the click decorator returns,
    # which may not be the function Gradio registered; the `spaces` docs put
    # `@spaces.GPU` directly on the function. Confirm whether this handler
    # needs a GPU before reordering.
    @spaces.GPU
    @generate_quiz_btn.click(inputs=[radio, topic], outputs=[quiz_msg] + question_radios + [quiz_answers], api_name="generate_quiz")
    def generate_quiz(question_difficulty, topic):
        """Generate a quiz on *topic* grounded in the ColBERT index.

        Returns a status message, then one update per question radio, then the
        answer key for this session (None when no quiz was produced). This
        always matches the declared outputs.
        """
        if not topic or not topic.strip():
            gr.Warning('Please enter a topic for the quiz.')
            return ['Please enter a topic for the quiz.'] + _hidden_question_radios() + [None]

        top_k_rank = 10
        documents_full = _get_quiz_rag().search(topic, k=top_k_rank)
        documents_str = '\n'.join(
            f"[DOCUMENT {i}]: {item['content']}"
            for i, item in enumerate(documents_full, start=1)
        )
        formatted_prompt = system_instructions(question_difficulty, topic, documents_str)
        logger.info('Quiz prompt: %s', formatted_prompt)

        generate_kwargs = dict(
            temperature=0.2,
            max_new_tokens=4000,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
        )

        for attempt in range(_QUIZ_MAX_ATTEMPTS):
            try:
                # The prompt is identical on every attempt; varying the seed
                # lets a retry produce a different completion.
                response = client.text_generation(
                    formatted_prompt, **generate_kwargs, seed=42 + attempt,
                    stream=False, details=False, return_full_text=False,
                )
                quiz_data = _parse_quiz_json(response)
            except Exception as exc:  # remote-call or JSON failure: retry
                logger.warning('Quiz attempt %d/%d failed: %s', attempt + 1, _QUIZ_MAX_ATTEMPTS, exc)
                continue

            items = _extract_quiz(quiz_data)
            if len(items) == NUM_QUIZ_QUESTIONS:
                radios = [
                    gr.Radio(choices=choices, label=question, value=None,
                             visible=True, interactive=True)
                    for question, choices, _ in items
                ]
                answers = [answer for _, _, answer in items]
                return ['Quiz Generated!'] + radios + [answers]
            logger.info('Only %d/%d questions generated; trying again',
                        len(items), NUM_QUIZ_QUESTIONS)

        gr.Warning('Sorry. Pls try with another topic !')
        return ['Sorry. Pls try with another topic !'] + _hidden_question_radios() + [None]

    check_button = gr.Button("Check Score")
    score_textbox = gr.Markdown()

    @check_button.click(inputs=question_radios + [quiz_answers], outputs=score_textbox)
    def compare_answers(*inputs):
        """Score the selected choices against this session's answer key.

        *inputs* holds one selected value per question radio (None when
        unanswered), followed by the answer key stored by ``generate_quiz``.
        Each selection is compared only with the answer to its own question.
        """
        *user_answers, answer_key = inputs
        if not answer_key:
            return "### Please generate a quiz first!"
        score = sum(given == expected for given, expected in zip(user_answers, answer_key))
        total = len(answer_key)
        # Check the higher threshold first so "Excellent" is reachable.
        if score > 7:
            return f"### Excellent ! You got {score} over {total}!"
        if score > 5:
            return f"### Good ! You got {score} over {total}!"
        return f"### You got {score} over {total}! Dont worry . You can prepare well and try better next time !"
|
|
|
|
|
# Combine the chat and quiz apps into one tabbed UI, enable request queuing,
# and start the server with debug output.
demo = gr.TabbedInterface(
    interface_list=[CHATBOT, QUIZBOT],
    tab_names=["AI ChatBot", "AI Quizbot"],
)
demo.queue().launch(debug=True)
|
|