import json
import logging
import subprocess
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from time import perf_counter

import firebase_admin
import gradio as gr
import numpy as np
from firebase_admin import credentials, firestore
from jinja2 import Environment, FileSystemLoader
from ragatouille import RAGPretrainedModel
from sentence_transformers import CrossEncoder

from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import table, retriever
# Column names read from the `table` search results (imported from
# backend.semantic_search): the embedding column and the passage text column.
VECTOR_COLUMN_NAME = "vector"
TEXT_COLUMN_NAME = "text"

# Directory containing this script; the templates folder is resolved from it.
proj_dir = Path(__file__).parent

# Root logging at INFO, plus this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Jinja2 environment rooted at <proj_dir>/templates. The plain-text template
# builds the LLM prompt; the HTML one renders the same prompt for display.
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
template, template_html = (
    env.get_template(template_name)
    for template_name in ('template.j2', 'template_html.j2')
)
#___________________
# service_account_key='firebase.json'
# # Create a Certificate object from the service account info
# cred = credentials.Certificate(service_account_key)
# # Initialize the Firebase Admin
# firebase_admin.initialize_app(cred)
# # # Create a reference to the Firestore database
# db = firestore.client()
# #db usage
# collection_name = 'Nirvachana' # Replace with your collection name
# field_name = 'message_count' # Replace with your field name for count
# Sample prompts offered to users through gr.Examples in the UI.
examples = [
    'Tabulate the difference between veins and arteries',
    'What are defects in Human eye?',
    'Frame 5 short questions and 5 MCQ on Chapter 2 ',
    'Suggest creative and engaging ideas to teach students on Chapter on Metals and Non Metals ',
]
# def get_and_increment_value_count(db , collection_name, field_name):
# """
# Retrieves a value count from the specified Firestore collection and field,
# increments it by 1, and updates the field with the new value."""
# collection_ref = db.collection(collection_name)
# doc_ref = collection_ref.document('count_doc') # Assuming a dedicated document for count
# # Use a transaction to ensure consistency across reads and writes
# try:
# with db.transaction() as transaction:
# # Get the current value count (or initialize to 0 if it doesn't exist)
# current_count_doc = doc_ref.get()
# current_count_data = current_count_doc.to_dict()
# if current_count_data:
# current_count = current_count_data.get(field_name, 0)
# else:
# current_count = 0
# # Increment the count
# new_count = current_count + 1
# # Update the document with the new count
# transaction.set(doc_ref, {field_name: new_count})
# return new_count
# except Exception as e:
# print(f"Error retrieving and updating value count: {e}")
# return None # Indicate error
# def update_count_html():
# usage_count = get_and_increment_value_count(db ,collection_name, field_name)
# count_html = gr.HTML(value=f"""
#
# No of Usages:
# {usage_count}
#
# """)
# return count_html
# def store_message(db,query,answer,cross_encoder):
# timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
# # Create a new document reference with a dynamic document name based on timestamp
# new_completion= db.collection('Nirvachana').document(f"chatlogs_{timestamp}")
# new_completion.set({
# 'query': query,
# 'answer':answer,
# 'created_time': firestore.SERVER_TIMESTAMP,
# 'embedding': cross_encoder,
# 'title': 'Expenditure observer bot'
# })
def add_text(history, text):
    """Append the user's message to the chat history and lock the input box.

    Args:
        history: Current chat history (list of [user, bot] pairs) or None.
        text: The message the user just submitted.

    Returns:
        A new history list with a ``[text, None]`` pair appended (the input
        list is not mutated), and a cleared, non-interactive Textbox so no
        further input is accepted while ``bot`` is generating.
    """
    history = [] if history is None else history
    # A mutable pair, because bot() fills in the reply slot in place via
    # history[-1][1]; a tuple here would only work thanks to Gradio's
    # preprocessing converting it.
    history = history + [[text, None]]
    return history, gr.Textbox(value="", interactive=False)
# Reranker choices from the UI Radio, mapped to their cross-encoder model ids.
_CROSS_ENCODER_MODELS = {
    '(FAST) MiniLM-L6v2': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    '(ACCURATE) BGE reranker': 'BAAI/bge-reranker-base',
}
# Radio choice that switches retrieval to the prebuilt ColBERT index.
_COLBERT_CHOICE = '(HIGH ACCURATE) ColBERT'
# Relative to the current working directory, as in the original code.
_COLBERT_INDEX_PATH = '.ragatouille/colbert/indexes/cbseclass10index'


@lru_cache(maxsize=None)
def _load_cross_encoder(model_name):
    """Load a CrossEncoder once per model id and reuse it for later queries."""
    return CrossEncoder(model_name)


@lru_cache(maxsize=1)
def _load_colbert_index():
    """Load the ColBERT index (and its encoder) once and reuse it."""
    # Shown only on the slow first load; later calls hit the cache.
    gr.Warning('Retrieving using ColBERT.. First time query will take a minute for model to load..pls wait')
    # from_index is a classmethod that loads the encoder recorded with the
    # index, so a prior from_pretrained("colbert-ir/colbertv2.0") would only
    # load a second, unused copy of the model.
    return RAGPretrainedModel.from_index(_COLBERT_INDEX_PATH)


def _retrieve_colbert(query, top_k):
    """Return the text of the top-``top_k`` passages found in the ColBERT index."""
    results = _load_colbert_index().search(query, k=top_k)
    return [item['content'] for item in results]


def _retrieve_reranked(query, model_name, top_rerank, top_k):
    """Fetch ``top_rerank`` vector-search candidates, rerank them, keep ``top_k``.

    Candidates are ordered by descending cross-encoder score.
    """
    query_vec = retriever.encode(query)
    logger.info('Finished query vec')
    candidates = (
        table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME)
        .limit(top_rerank)
        .to_list()
    )
    documents = [doc[TEXT_COLUMN_NAME] for doc in candidates]
    logger.info('Finished search; start cross encoder on %d documents', len(documents))
    if not documents:
        return []
    scores = _load_cross_encoder(model_name).predict([[query, doc] for doc in documents])
    # argsort is ascending; reverse it for best-first order.
    best_first = np.argsort(scores)[::-1][:top_k]
    logger.info('Finished cross encoder on %d documents', len(documents))
    return [documents[idx] for idx in best_first]


def _stream_answer(history, documents, query):
    """Render the prompts for ``documents`` and stream the LLM reply into history[-1]."""
    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)
    history[-1][1] = ""
    # Each value from generate_hf replaces the reply text (presumably the
    # accumulated answer so far; confirm in backend.query_llm).
    for partial in generate_hf(prompt, history[:-1]):
        history[-1][1] = partial
        yield history, prompt_html
    logger.info('Final history is %s', history)


def bot(history, cross_encoder):
    """Stream a retrieval-augmented answer to the latest message in ``history``.

    Args:
        history: Chat history as a list of [user_message, bot_reply] pairs.
            The reply slot of the last pair is overwritten as the answer streams.
        cross_encoder: The Radio label chosen in the UI. '(HIGH ACCURATE) ColBERT'
            searches the ColBERT index. '(FAST) MiniLM-L6v2' and
            '(ACCURATE) BGE reranker' rerank vector-search candidates with the
            corresponding cross-encoder.

    Yields:
        ``(history, prompt_html)`` after every chunk produced by generate_hf.

    Raises:
        ValueError: if there is no message, the message is blank, or
            ``cross_encoder`` is not one of the known options.
    """
    top_rerank = 15  # candidates fetched from the vector table before reranking
    top_k_rank = 10  # passages kept for the prompt

    query = history[-1][0] if history else None
    if not query or not query.strip():
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty string was submitted")

    logger.info('Retrieving documents...')
    document_start = perf_counter()
    if cross_encoder == _COLBERT_CHOICE:
        documents = _retrieve_colbert(query, top_k_rank)
    elif cross_encoder in _CROSS_ENCODER_MODELS:
        documents = _retrieve_reranked(
            query, _CROSS_ENCODER_MODELS[cross_encoder], top_rerank, top_k_rank
        )
    else:
        # Previously this fell through to an unbound local (UnboundLocalError).
        raise ValueError(f"Unknown retrieval option: {cross_encoder!r}")
    logger.info(
        'Finished Retrieving documents in %s seconds...',
        round(perf_counter() - document_start, 2),
    )

    yield from _stream_answer(history, documents, query)
    # store_message(db, history[-1][0], history[-1][1], cross_encoder)
# Chat UI: header text, chatbot pane, input row, retriever selector, the
# rendered-prompt panel and example prompts; then queue (needed for the
# streaming bot generator) and launch.
# with gr.Blocks(theme='Insuz/SimpleIndigo') as demo:
with gr.Blocks(theme='dawood/dracula_test') as demo:
    gr.HTML(value="""CHEERFULL CBSE-
AI Assisted Fun Learning
""", elem_id='heading')
    gr.HTML(value="""
A free Artificial Intelligence Chatbot assistant trained on CBSE Class 10 Science Notes to engage and help students and teachers of Puducherry.
""", elem_id='Sub-heading')
    # usage_count = get_and_increment_value_count(db, collection_name, field_name)
    # elem_id must not contain whitespace to be a valid HTML id / CSS target.
    gr.HTML(value="""Developed by K M Ramyasri , PGT . Suggestions may be sent to mramesh.irs@gov.in.
""", elem_id='Sub-heading1')
    # Usage counter, disabled together with the Firestore helpers:
    # count_html = gr.HTML(value=f"""
    #
    # No of Usages:
    # {usage_count}
    #
    # """)
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)

    cross_encoder = gr.Radio(
        choices=['(FAST) MiniLM-L6v2', '(ACCURATE) BGE reranker', '(HIGH ACCURATE) ColBERT'],
        value='(ACCURATE) BGE reranker',
        label="Embeddings",
        info="Only the first query to ColBERT may take a little time",
    )

    prompt_html = gr.HTML()

    # Clicking the button and pressing Enter run the same pipeline, registered
    # in that order:
    #   1. add_text appends the message and disables the textbox (unqueued);
    #   2. bot streams the answer and the rendered prompt;
    #   3. the textbox is re-enabled once bot finishes.
    for trigger in (txt_btn.click, txt.submit):
        txt_msg = trigger(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
            bot, [chatbot, cross_encoder], [chatbot, prompt_html]
        )  # .then(update_count_html, [], [count_html])
        txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    gr.Examples(examples, txt)

demo.queue()
demo.launch(debug=True)