RAMYASRI-39's picture
Update app.py
36d3167 verified
raw
history blame
18.4 kB
from ragatouille import RAGPretrainedModel
import subprocess
import json
import spaces
import firebase_admin
from firebase_admin import credentials, firestore
import logging
from pathlib import Path
from time import perf_counter
from datetime import datetime
import gradio as gr
from jinja2 import Environment, FileSystemLoader
import numpy as np
from sentence_transformers import CrossEncoder
from huggingface_hub import InferenceClient
from os import getenv
from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import table, retriever
from huggingface_hub import InferenceClient
# Column names in the vector-search ``table`` (imported from
# backend.semantic_search): the embedding column searched against and the
# column holding each chunk's text.
VECTOR_COLUMN_NAME = "vector"
TEXT_COLUMN_NAME = "text"
# Hugging Face token for the hosted inference client; getenv returns None when
# the environment variable is not set.
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
# Directory containing this file; the Jinja templates are resolved relative to it.
proj_dir = Path(__file__).parent
# Setting up the logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Hosted Mixtral client; the quiz generator calls client.text_generation on it.
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1",token=HF_TOKEN)
# Set up the template environment with the templates directory
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
# Load the templates directly from the environment.
# template.j2 renders the prompt sent to the LLM; template_html.j2 renders the
# HTML shown in the chat tab's prompt panel. Both are rendered with
# ``documents`` and ``query``.
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')
#___________________
# service_account_key='firebase.json'
# # Create a Certificate object from the service account info
# cred = credentials.Certificate(service_account_key)
# # Initialize the Firebase Admin
# firebase_admin.initialize_app(cred)
# # # Create a reference to the Firestore database
# db = firestore.client()
# #db usage
# collection_name = 'Nirvachana' # Replace with your collection name
# field_name = 'message_count' # Replace with your field name for count
# Sample prompts listed under the chat box (wired up by gr.Examples in the
# chatbot tab) so users can try the assistant with one click. Note that some
# entries deliberately keep a trailing space.
examples = [
    'Tabulate the difference between veins and arteries',
    'What are defects in Human eye?',
    'Frame 5 short questions and 5 MCQ on Chapter 2 ',
    'Suggest creative and engaging ideas to teach students on Chapter on Metals and Non Metals ',
]
# def get_and_increment_value_count(db , collection_name, field_name):
# """
# Retrieves a value count from the specified Firestore collection and field,
# increments it by 1, and updates the field with the new value."""
# collection_ref = db.collection(collection_name)
# doc_ref = collection_ref.document('count_doc') # Assuming a dedicated document for count
# # Use a transaction to ensure consistency across reads and writes
# try:
# with db.transaction() as transaction:
# # Get the current value count (or initialize to 0 if it doesn't exist)
# current_count_doc = doc_ref.get()
# current_count_data = current_count_doc.to_dict()
# if current_count_data:
# current_count = current_count_data.get(field_name, 0)
# else:
# current_count = 0
# # Increment the count
# new_count = current_count + 1
# # Update the document with the new count
# transaction.set(doc_ref, {field_name: new_count})
# return new_count
# except Exception as e:
# print(f"Error retrieving and updating value count: {e}")
# return None # Indicate error
# def update_count_html():
# usage_count = get_and_increment_value_count(db ,collection_name, field_name)
# count_html = gr.HTML(value=f"""
# <div style="display: flex; justify-content: flex-end;">
# <span style="font-weight: bold; color: maroon; font-size: 18px;">No of Usages:</span>
# <span style="font-weight: bold; color: maroon; font-size: 18px;">{usage_count}</span>
# </div>
# """)
# return count_html
# def store_message(db,query,answer,cross_encoder):
# timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# # Create a new document reference with a dynamic document name based on timestamp
# new_completion= db.collection('Nirvachana').document(f"chatlogs_{timestamp}")
# new_completion.set({
# 'query': query,
# 'answer':answer,
# 'created_time': firestore.SERVER_TIMESTAMP,
# 'embedding': cross_encoder,
# 'title': 'Expenditure observer bot'
# })
def add_text(history, text):
    """Append *text* to the chat as a new turn that has no answer yet.

    Args:
        history: the current chat history, or None for an empty chat.
        text: the message the user just submitted.

    Returns:
        A new history list (the one passed in is not mutated) ending with
        ``(text, None)``, and a cleared, non-interactive Textbox so the user
        cannot type while the answer is being generated.
    """
    if history is None:
        history = []
    pending_turn = (text, None)  # the answer slot is filled in later by ``bot``
    return history + [pending_turn], gr.Textbox(value="", interactive=False)
def bot(history, cross_encoder):
top_rerank = 15
top_k_rank = 10
query = history[-1][0]
if not query:
gr.Warning("Please submit a non-empty string as a prompt")
raise ValueError("Empty string was submitted")
logger.warning('Retrieving documents...')
# if COLBERT RAGATATOUILLE PROCEDURE :
if cross_encoder=='(HIGH ACCURATE) ColBERT':
gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
RAG= RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
RAG_db=RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
documents_full=RAG_db.search(query,k=top_k_rank)
documents=[item['content'] for item in documents_full]
# Create Prompt
prompt = template.render(documents=documents, query=query)
prompt_html = template_html.render(documents=documents, query=query)
generate_fn = generate_hf
history[-1][1] = ""
for character in generate_fn(prompt, history[:-1]):
history[-1][1] = character
print('Final history is ',history)
yield history, prompt_html
#store_message(db,history[-1][0],history[-1][1],cross_encoder)
else:
# Retrieve documents relevant to query
document_start = perf_counter()
query_vec = retriever.encode(query)
logger.warning(f'Finished query vec')
doc1 = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank)
logger.warning(f'Finished search')
documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
logger.warning(f'start cross encoder {len(documents)}')
# Retrieve documents relevant to query
query_doc_pair = [[query, doc] for doc in documents]
if cross_encoder=='(FAST) MiniLM-L6v2' :
cross_encoder1 = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
elif cross_encoder=='(ACCURATE) BGE reranker':
cross_encoder1 = CrossEncoder('BAAI/bge-reranker-base')
cross_scores = cross_encoder1.predict(query_doc_pair)
sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
logger.warning(f'Finished cross encoder {len(documents)}')
documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]
logger.warning(f'num documents {len(documents)}')
document_time = perf_counter() - document_start
logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
# Create Prompt
prompt = template.render(documents=documents, query=query)
prompt_html = template_html.render(documents=documents, query=query)
generate_fn = generate_hf
history[-1][1] = ""
for character in generate_fn(prompt, history[:-1]):
history[-1][1] = character
print('Final history is ',history)
yield history, prompt_html
#store_message(db,history[-1][0],history[-1][1],cross_encoder)
def system_instructions(question_difficulty, topic, documents_str):
    """Build the Mixtral ``[INST]`` prompt that asks for a 10-question MCQ quiz.

    Args:
        question_difficulty: difficulty word inserted into the prompt
            (the UI offers "easy", "average" or "hard").
        topic: the topic the quiz must be about.
        documents_str: retrieved passages the questions must be drawn from.

    Returns:
        The prompt string. It requests JSON keyed ``Q#`` (question), ``Q#:C1``
        to ``Q#:C4`` (choices) and ``A#`` (key of the correct choice).
    """
    # The caller parses the model output as JSON, so the prompt explicitly asks
    # for the JSON object alone.
    return f"""<s> [INST] You are a great teacher and your task is to create 10 questions with 4 choices with a {question_difficulty} difficulty about topic request " {topic} " only from the below given documents, {documents_str} then create the answers. Index in JSON format, the questions as "Q#":"" to "Q#":"", the four choices as "Q#:C1":"" to "Q#:C4":"", and the answers as "A#":"Q#:C#" to "A#":"Q#:C#". Return only the JSON object, with no other text. [/INST]"""
# Tab 1: retrieval-augmented chatbot over the CBSE Class 10 Science notes.
with gr.Blocks(theme='NoCrypt/miku') as CHATBOT:
    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(value="""<div style="color: #FF4500;"><h1>CHEERFULL CBSE-</h1> <h1><span style="color: #008000">AI Assisted Fun Learning</span></h1>
</div>""", elem_id='heading')
            gr.HTML(value="""
<p style="font-family: sans-serif; font-size: 16px;">
A free Artificial Intelligence Chatbot assistant trained on CBSE Class 10 Science Notes to engage and help students and teachers of Puducherry.
</p>
""", elem_id='Sub-heading')
            # The visible address must be the same one the mailto: link opens.
            # HTML ids cannot contain spaces.
            gr.HTML(value="""<p style="font-family: Arial, sans-serif; font-size: 14px;">Developed by K M Ramyasri , TGT . Suggestions may be sent to <a href="mailto:ramyadevi1607@yahoo.com" style="color: #00008B; font-style: italic;">ramyadevi1607@yahoo.com</a>.</p>""", elem_id='Sub-heading1')
        with gr.Column(scale=3):
            gr.Image(value='logo.png', height=200, width=200)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )
    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)
    cross_encoder = gr.Radio(
        choices=['(FAST) MiniLM-L6v2', '(ACCURATE) BGE reranker', '(HIGH ACCURATE) ColBERT'],
        value='(ACCURATE) BGE reranker',
        label="Embeddings",
        info="Only the first query to ColBERT may take a little time",
    )
    prompt_html = gr.HTML()
    # Clicking the button and pressing Enter run the same chain: append the
    # message (add_text also locks the textbox), stream the answer with `bot`,
    # then make the textbox interactive again.
    for trigger in (txt_btn.click, txt.submit):
        txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, [chatbot, cross_encoder], [chatbot, prompt_html])
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
    # Examples
    gr.Examples(examples, txt)
# Module-level holder for the loaded ColBERT index. Only its ``.value`` is used
# (it is never wired as an event input/output), so every session shares one
# loaded index.
RAG_db = gr.State()

# Number of questions requested from the LLM and displayable in the UI.
_NUM_QUIZ_QUESTIONS = 10
# How many times the LLM is asked before giving up on a topic.
_QUIZ_MAX_ATTEMPTS = 4


def _parse_quiz_json(response):
    """Decode the JSON object contained in the model's *response* text.

    Only the span from the first ``{`` to the last ``}`` is parsed, so any prose
    or code fences the model may add around the object are ignored.

    Raises:
        ValueError: if no JSON object can be decoded (json.JSONDecodeError is a
            ValueError subclass).
    """
    start, end = response.find('{'), response.rfind('}')
    if start == -1 or end < start:
        raise ValueError('No JSON object found in model output')
    return json.loads(response[start:end + 1])


def _extract_questions(quiz_json, num_questions):
    """Collect the complete questions from the model's quiz JSON.

    Reads the keys requested by ``system_instructions``: ``Q#`` (question
    text), ``Q#:C1``..``Q#:C4`` (choices) and ``A#``, whose value is the key of
    the correct choice (e.g. ``"Q3:C2"``). Questions with no text or no
    resolvable answer are skipped.

    Returns:
        A list of ``(question, choices, answer)`` tuples, all as strings.
    """
    questions = []
    for number in range(1, num_questions + 1):
        question_key = f"Q{number}"
        question = quiz_json.get(question_key)
        answer_ref = quiz_json.get(f"A{number}")
        # A# must name a choice key; a null/list value from the model is unusable.
        answer = quiz_json.get(answer_ref) if isinstance(answer_ref, str) else None
        if not question or not answer:
            continue
        choices = [str(quiz_json.get(f"{question_key}:C{i}", "Choice not found")) for i in range(1, 5)]
        # Stringify the answer so it compares equal to the stringified choice.
        questions.append((str(question), choices, str(answer)))
    return questions


def _score_message(score, total):
    """Return the Markdown feedback shown after the quiz is checked."""
    # Test the higher threshold first, otherwise "Excellent" is unreachable.
    if score > 7:
        return f"### Excellent ! You got {score} over {total}!"
    if score > 5:
        return f"### Good ! You got {score} over {total}!"
    return f"### You got {score} over {total}! Dont worry . You can prepare well and try better next time !"


# Tab 2: LLM-generated multiple-choice quiz built from ColBERT-retrieved notes.
with gr.Blocks(title="Quiz Maker", theme=gr.themes.Default(primary_hue="green", secondary_hue="green"), css="style.css") as QUIZBOT:
    def load_model():
        """Load the ColBERT index into ``RAG_db.value`` unless it is already loaded."""
        if RAG_db.value is None:
            RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
            RAG_db.value = RAG.from_index('.ragatouille/colbert/indexes/cbseclass10index')
        return 'Ready to Go!!'

    with gr.Column(scale=4):
        gr.HTML("""
<center>
<h1><span style="color: purple;">AI NANBAN</span> - CBSE Class Quiz Maker</h1>
<h2>AI-powered Learning Game</h2>
<i>⚠️ Students create quiz from any topic /CBSE Chapter ! ⚠️</i>
</center>
""")
    with gr.Column(scale=2):
        load_btn = gr.Button("Click to Load!🚀")
        load_text = gr.Textbox()
        load_btn.click(load_model, [], load_text)
    topic = gr.Textbox(label="Enter the Topic for Quiz", placeholder="Write any topic from CBSE notes")
    with gr.Row():
        radio = gr.Radio(
            ["easy", "average", "hard"], label="How difficult should the quiz be?"
        )
    generate_quiz_btn = gr.Button("Generate Quiz!🚀")
    quiz_msg = gr.Textbox()
    # Correct answers for this browser session, in the same order as the visible
    # question radios. Kept in gr.State (not a global) so concurrent users cannot
    # overwrite each other's quiz.
    quiz_answers = gr.State([])
    question_radios = [gr.Radio(visible=False) for _ in range(_NUM_QUIZ_QUESTIONS)]

    # NOTE(review): decorators apply bottom-up, so Gradio registers the
    # undecorated function and @spaces.GPU most likely does not wrap the event
    # handler. Confirm before swapping the order: on ZeroGPU the handler would
    # then run in a separate worker process.
    @spaces.GPU
    @generate_quiz_btn.click(inputs=[radio, topic], outputs=[quiz_msg, quiz_answers] + question_radios, api_name="generate_quiz")
    def generate_quiz(question_difficulty, topic):
        """Generate a multiple-choice quiz on *topic* from the indexed notes.

        Retrieves the top ColBERT passages for the topic and asks the LLM for a
        JSON quiz. It retries with a different seed when the call fails, the
        JSON is malformed or questions are missing.

        Returns:
            Values for ``[quiz_msg, quiz_answers] + question_radios``.
        """
        hidden_radios = [gr.Radio(visible=False, value=None) for _ in range(_NUM_QUIZ_QUESTIONS)]
        if not topic or not topic.strip():
            gr.Warning('Please enter a topic for the quiz.')
            return ['Please enter a topic for the quiz.', []] + hidden_radios
        # The difficulty radio starts with nothing selected.
        question_difficulty = question_difficulty or "average"
        if RAG_db.value is None:
            # Load on demand instead of failing when "Click to Load!" was skipped.
            load_model()
        top_k_rank = 10
        documents_full = RAG_db.value.search(topic, k=top_k_rank)
        documents = [item['content'] for item in documents_full]
        documents_str = '\n'.join(
            f"[DOCUMENT {i}]: {summary}" for i, summary in enumerate(documents, start=1))
        formatted_prompt = system_instructions(question_difficulty, topic, documents_str)
        logger.info(formatted_prompt)
        generate_kwargs = dict(
            temperature=0.2,
            max_new_tokens=4000,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
        )
        best_questions = []
        for attempt in range(_QUIZ_MAX_ATTEMPTS):
            try:
                # Vary the seed per attempt; with a fixed seed, retries could
                # reproduce the same flawed output.
                response = client.text_generation(
                    formatted_prompt, **generate_kwargs, seed=42 + attempt,
                    stream=False, details=False, return_full_text=False,
                )
                logger.info(response)
                questions = _extract_questions(_parse_quiz_json(str(response)), _NUM_QUIZ_QUESTIONS)
            except Exception as exc:  # API errors and malformed JSON are both worth a retry
                logger.warning(f"Quiz attempt {attempt + 1}/{_QUIZ_MAX_ATTEMPTS} failed: {exc}")
                continue
            if len(questions) > len(best_questions):
                best_questions = questions
            if len(best_questions) == _NUM_QUIZ_QUESTIONS:
                break
            logger.warning(f'{len(questions)} of {_NUM_QUIZ_QUESTIONS} questions generated. So trying again!')
        if not best_questions:
            logger.warning('Retry exhausted')
            gr.Warning('Sorry. Pls try with another topic !')
            return ['Sorry. Pls try with another topic !', []] + hidden_radios
        question_radio_list = [
            gr.Radio(choices=choices, label=question, value=None, visible=True, interactive=True)
            for question, choices, _ in best_questions
        ]
        answers = [answer for _, _, answer in best_questions]
        # Pad with hidden radios so every output gets a value and any radios left
        # over from a previous, longer quiz are hidden.
        return ['Quiz Generated!', answers] + question_radio_list + hidden_radios[len(question_radio_list):]

    check_button = gr.Button("Check Score")
    score_textbox = gr.Markdown()

    @check_button.click(inputs=[quiz_answers] + question_radios, outputs=score_textbox)
    def compare_answers(correct_answers, *user_answers):
        """Score the user's selections against this session's quiz answers.

        Answers are compared position by position, so selecting the correct
        answer of a different question does not earn a point.
        """
        if not correct_answers:
            return "### Please generate a quiz first!"
        score = sum(chosen == correct for chosen, correct in zip(user_answers, correct_answers))
        return _score_message(score, len(correct_answers))
# Combine the chatbot and quiz apps into one tabbed UI.
demo = gr.TabbedInterface([CHATBOT, QUIZBOT], ["AI ChatBot", "AI Nanban-Quizbot"])
demo.queue()

if __name__ == "__main__":
    # Start the (blocking) server only when run as a script, so importing this
    # module does not launch it as a side effect.
    demo.launch(debug=True)