# NOTE: page-scrape residue - the hosting "Spaces" banner reported "Runtime error" when this file was captured.
import gradio as gr
import os
import datetime
import json
import src.constants as constants_utils
import src.kkms_kssw as kkms_kssw
import src.weather as weather_utils

# SECURITY NOTE(review): CURL_CA_BUNDLE is blanked for the whole process.
# Presumably this is meant to make HTTP clients skip TLS certificate
# verification for some upstream endpoint - confirm that this is intended and
# prefer pointing it at a proper CA bundle instead.
os.environ["CURL_CA_BUNDLE"] = ""

import warnings

# Silences every warning category process-wide (including deprecations).
warnings.filterwarnings('ignore')
class DomState:
    """State holder and event handlers behind the Gradio demo UI.

    Each ``click_handler_*`` method is wired to a Gradio event. Handlers
    delegate the real work to helper objects hanging off ``self.kkms_kssw_obj``
    and cache the latest result on the instance before returning it.
    """

    # Order of the top-level rows wired as outputs of ``select_widget``.
    _WIDGET_CHOICES = (
        "Custom Query",
        "General (AgGPT)",
        "Mandi Price",
        "Weather",
        "Load Custom Data",
        "Display Data Sources",
    )

    # Order of the upload rows wired as outputs of ``select_files_urls``.
    _DOC_TYPE_CHOICES = ("PDF", "Online PDF", "Text File", "URLs")

    def __init__(
        self,
        index_type,
        load_from_existing_index_file
    ):
        """Initialise the cached UI values and load/create the index.

        Args:
            index_type: Name of the vector-store backend; compared against
                'FAISS' and 'Chroma' when post-processing query results.
            load_from_existing_index_file: Stored on the instance only; it is
                not passed to ``load_create_index()`` in this class.
        """
        self.index_type = index_type
        self.load_from_existing_index_file = load_from_existing_index_file

        self.relevant_paragraphs = ''
        self.sources_relevant_paragraphs = ''
        self.answer = ''
        self.summary = ''
        self.mandi_price = ''

        # Default mandi date range: the last five days up to today (local time).
        now = datetime.datetime.now()
        self.mandi_from_date = (now - datetime.timedelta(days=5)).strftime('%Y-%m-%d')
        self.mandi_to_date = now.strftime('%Y-%m-%d')

        self.weather_info = ''
        self.weather_forecast = ''
        self.weather_forecast_summary = ''
        self.indic_translation = ''
        self.kb_sources = ''

        # Initialize index (vector store).
        self.kkms_kssw_obj = kkms_kssw.KKMS_KSSW()
        self.kkms_kssw_obj.load_create_index()

    def click_handler_for_get_relevant_paragraphs(
        self,
        question_category,
        question
    ):
        """Query the index and return the paragraphs relevant to ``question``.

        For the 'FAISS' and 'Chroma' index types the returned documents are
        split into their metadata (cached as the sources) and their page
        content with newlines removed and tabs replaced by spaces. For any
        other index type the raw query result is returned unchanged.
        """
        docs = self.kkms_kssw_obj.query(
            question=question,
            question_category=question_category
        )
        self.relevant_paragraphs = docs

        if self.index_type in ('FAISS', 'Chroma'):
            self.sources_relevant_paragraphs = [doc.metadata for doc in docs]
            self.relevant_paragraphs = [
                doc.page_content.replace('\n', '').replace('\t', ' ')
                for doc in docs
            ]

        return self.relevant_paragraphs

    def click_handler_for_relevant_paragraphs_source(
        self,
        relevant_paragraphs
    ):
        """Return the sources cached by the last relevant-paragraph query.

        ``relevant_paragraphs`` is only the triggering component's value and
        is not used.
        """
        return self.sources_relevant_paragraphs

    def click_handler_for_summary(
        self,
        answer
    ):
        """Summarise ``answer`` and cache the result in ``self.summary``."""
        # The original stored the result under the misspelled attribute
        # ``sumamry``, leaving ``self.summary`` permanently empty.
        self.summary = self.kkms_kssw_obj.langchain_utils_obj.get_textual_summary(answer)
        return self.summary

    def click_handler_for_get_answer(
        self,
        relevant_paragraphs,
        question
    ):
        """Derive an answer to ``question`` from ``relevant_paragraphs``."""
        self.answer = self.kkms_kssw_obj.langchain_utils_obj.get_answer_from_para(
            relevant_paragraphs,
            question
        )
        return self.answer

    def click_handler_for_mandi_price(
        self,
        state_name,
        apmc_name,
        commodity_name,
        from_date,
        to_date
    ):
        """Fetch the mandi price when every input is provided.

        If any input is empty no lookup is made and the previously cached
        price (initially an empty string) is returned.
        """
        if state_name and apmc_name and commodity_name and from_date and to_date:
            self.mandi_price = self.kkms_kssw_obj.mandi_utils_obj.get_mandi_price(
                state_name, apmc_name, commodity_name, from_date, to_date
            )
        return self.mandi_price

    def click_handler_for_get_weather(
        self,
        city
    ):
        """Return a one-line description of the current weather in ``city``.

        An empty selection yields an empty string instead of failing on
        ``city.capitalize()``.
        """
        if not city:
            self.weather_info = ''
            return self.weather_info

        time, info, temperature = self.kkms_kssw_obj.weather_utils_obj.get_weather(city)
        self.weather_info = f'Weather in {city.capitalize()} on {time} is {temperature} with {info}.'
        return self.weather_info

    def click_handler_for_get_weather_forecast(
        self,
        state,
        district
    ):
        """Fetch, cache and return the weather forecast for a district."""
        self.weather_forecast = self.kkms_kssw_obj.weather_utils_obj.get_weather_forecast(state, district)
        return self.weather_forecast

    def click_handler_for_weather_forecast_summary(
        self,
        weather_forecast
    ):
        """Summarise ``weather_forecast`` and cache the summary."""
        self.weather_forecast_summary = (
            self.kkms_kssw_obj.langchain_utils_obj.get_weather_forecast_summary(weather_forecast)
        )
        return self.weather_forecast_summary

    def click_handler_for_load_files_urls(
        self,
        doc_type,
        files_or_urls,
        question_category
    ):
        """Upload the given files/URLs into the index for ``question_category``.

        ``doc_type`` is the UI label; it is mapped through
        ``constants_utils.DATA_SOURCES`` before being passed on. Returns None.
        """
        self.kkms_kssw_obj.upload_data(
            doc_type=constants_utils.DATA_SOURCES[doc_type],
            files_or_urls=files_or_urls,
            index_category=question_category
        )

    def click_handler_for_get_indic_translation(
        self,
        eng_ans,
        language='Hindi'
    ):
        """Translate ``eng_ans`` into ``language`` (Hindi when not given)."""
        self.indic_translation = self.kkms_kssw_obj.translator_utils_obj.get_indic_google_translate(
            eng_ans, language
        )
        return self.indic_translation

    def click_handler_for_weather_forecast_districts_dropdown_list_update(
        self,
        state,
        district=None
    ):
        """Return a dropdown update listing the districts of ``state``.

        ``district`` is unused. It defaults to None because the UI wires this
        handler with the state as its only input, which previously raised a
        ``TypeError`` for the missing positional argument.
        """
        return gr.update(
            choices=self.kkms_kssw_obj.weather_utils_obj.get_district_names(state)
        )

    def click_handler_for_weather_forecast_district(
        self,
        state,
        district,
        weather=None
    ):
        """Return the weather forecast for ``district`` in ``state``.

        ``weather`` is unused. It defaults to None because the UI wires this
        handler with ``[state, district]`` only, which previously raised a
        ``TypeError`` for the missing positional argument.
        """
        return self.kkms_kssw_obj.weather_utils_obj.get_weather_forecast(state, district)

    def click_handler_for_get_kb_sources(
        self
    ):
        """Return a plain-text report of the indexed data sources.

        The helper returns a mapping of question category to a mapping of
        document type to an iterable of sources. Empty categories and empty
        document types are skipped. The formatted report is cached in
        ``self.kb_sources``.
        """
        sources_by_category = (
            self.kkms_kssw_obj.langchain_utils_obj.get_index_category_wise_data_sources()
        )

        parts = []
        for index_category, doc_types in sources_by_category.items():
            if not doc_types:
                continue

            parts.append('=' * 100 + '\n')
            parts.append(f'Question Category: {index_category}')

            for dt, sources in doc_types.items():
                if not sources:
                    continue

                parts.append('\n' + '=' * 50 + '\n')
                parts.append(f'Document type: {dt}')
                parts.append('\n' + '=' * 25)
                parts.extend(f'\t\t\t\n{doc}' for doc in sources)

            # NOTE(review): separator assumed to close each category block;
            # the flattened source does not show its original nesting level.
            parts.append('\n\n')

        self.kb_sources = ''.join(parts)
        return self.kb_sources

    def _upload_file(self, files):
        """Return the ``name`` attribute of every uploaded file object."""
        return [file.name for file in files]

    def select_widget(
        self,
        choice
    ):
        """Return visibility updates for the six top-level rows.

        Exactly the row matching ``choice`` is made visible. An unrecognised
        choice hides every row; the original returned a single update in that
        case although six outputs are wired to this handler.
        """
        return [
            gr.update(visible=(choice == name))
            for name in self._WIDGET_CHOICES
        ]

    def select_files_urls(
        self,
        choice
    ):
        """Return visibility updates for the four upload rows.

        An unrecognised choice falls back to showing the "PDF" row.
        """
        selected = choice if choice in self._DOC_TYPE_CHOICES else "PDF"
        return [
            gr.update(visible=(selected == name))
            for name in self._DOC_TYPE_CHOICES
        ]
# Build the Gradio UI. One row per "widget"; the radio at the top toggles which
# row is visible. NOTE(review): the nesting below was reconstructed from a
# flattened copy of the file - confirm the column layout against the original.
with gr.Blocks(title='KKMS-Smart-Search-Demo') as demo:
    dom = DomState(
        index_type=constants_utils.INDEX_TYPE,
        load_from_existing_index_file=constants_utils.LOAD_FROM_EXISTING_INDEX_STORE
    )

    widgets = gr.Radio(
        [
            "Custom Query",
            "General (AgGPT)",
            "Mandi Price",
            "Weather",
            "Load Custom Data",
            "Display Data Sources",
        ],
        label="Query related to",
        value="Custom Query"
    )

    #############################################################################
    # Widget for Custom Queries
    with gr.Row(visible=True) as rowCustomQuery:
        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Relevant paragraphs'):
                question_category = gr.Dropdown(
                    constants_utils.INDEX_CATEGORY,
                    label="Select Question Category",
                    value=constants_utils.INDEX_CATEGORY[0]
                )

                question = gr.Textbox(label="Enter your question", placeholder='Type the question here')

                # Get the relevant paragraphs for the question asked
                relevant_paragraphs = gr.Textbox(label="Relevant paragraphs are:", value=dom.relevant_paragraphs, interactive=False)
                b_relevant_paragraphs = gr.Button("Get Relevant paragraphs").style(size='sm')
                b_relevant_paragraphs.click(
                    fn=dom.click_handler_for_get_relevant_paragraphs,
                    inputs=[question_category, question],
                    outputs=[relevant_paragraphs]
                )

        with gr.Column(scale=1):
            with gr.Tab(label='Sources of relevant paragraphs'):
                # Get the sources of the relevant paragraphs
                sources_relevant_paragraphs = gr.Textbox(label="Sources of relevant paragraphs are:", interactive=False)
                relevant_paragraphs.change(
                    dom.click_handler_for_relevant_paragraphs_source,
                    relevant_paragraphs,
                    sources_relevant_paragraphs
                )

        # Get the exact answer for the question asked from the retrieved relevant paragraphs
        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Answer'):
                answer = gr.Textbox(label="Answer is:", value=dom.answer, interactive=False)
                relevant_paragraphs.change(
                    dom.click_handler_for_get_answer,
                    [relevant_paragraphs, question],
                    answer
                )

        # Convert the answer to an Indian language
        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Answer in selected language'):
                # Select the language
                language = gr.Dropdown(
                    list(constants_utils.INDIC_LANGUAGE.keys()),
                    label="Select language",
                    value=list(constants_utils.INDIC_LANGUAGE.keys())[0]
                )

                indic_lang_answer = gr.Textbox(label="Answer in the selected language is:", interactive=False)
                # Only the answer is passed here, so the handler's default
                # language is used until the button below is clicked.
                answer.change(
                    dom.click_handler_for_get_indic_translation,
                    answer,
                    indic_lang_answer
                )
                b_indic_lang_answer = gr.Button("Get answer in selected language").style(size='sm')
                b_indic_lang_answer.click(fn=dom.click_handler_for_get_indic_translation, inputs=[answer, language], outputs=[indic_lang_answer])

    #############################################################################
    # Widget for General Query using AgGPT
    with gr.Row(visible=False) as rowGeneral:
        with gr.Column(scale=1, min_width=600):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()

            submit = gr.Button("Submit")
            clear = gr.Button("Clear")
            submit.click(
                dom.kkms_kssw_obj.langchain_utils_obj.user, [msg, chatbot], [msg, chatbot]
            ).then(dom.kkms_kssw_obj.langchain_utils_obj.bot, chatbot, chatbot)
            clear.click(
                dom.kkms_kssw_obj.langchain_utils_obj.clear_history, None, chatbot, queue=False)

    #############################################################################
    # Widget for Mandi Price
    with gr.Row(visible=False) as rowMandiPrice:
        with gr.Column(scale=1, min_width=600):
            # Select State
            state_name = gr.Dropdown(
                constants_utils.MANDI_PRICE_STATES,
                label="Select state",
                value=constants_utils.MANDI_PRICE_STATES[0]
            )

            # APMC name
            apmc_name = gr.Textbox(label="Enter APMC name", placeholder='Type the APMC name here')

            # Commodity name
            commodity_name = gr.Textbox(label="Enter Commodity name", placeholder='Type the Commodity name here')

            # From/To date in yyyy-mm-dd format
            from_date = gr.Textbox(label="From date?", value=dom.mandi_from_date, placeholder='Please enter the From date here in yyyy-mm-dd format')
            to_date = gr.Textbox(label="To date?", value=dom.mandi_to_date, placeholder='Please enter the To date here in yyyy-mm-dd format')

        with gr.Column(scale=1, min_width=600):
            mandi_price = gr.Textbox(label="Mandi Price is:", value=dom.mandi_price, interactive=False)
            b_summary = gr.Button("Get Mandi Price").style(size='sm')
            b_summary.click(fn=dom.click_handler_for_mandi_price, inputs=[state_name, apmc_name, commodity_name, from_date, to_date], outputs=[mandi_price])

    #############################################################################
    # Widget for Weather Info
    with gr.Row(visible=False) as rowWeather:
        ########### Weather Forecast ###########
        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Weather Forecast for next 5 days'):
                # Select the State
                state = gr.Dropdown(
                    list(constants_utils.WEATHER_FORECAST_STATE_CODES.keys()),
                    label="Select state",
                    value=list(constants_utils.WEATHER_FORECAST_STATE_CODES.keys())[0]
                )

                # Select District
                district = gr.Dropdown(
                    choices=[],
                    label="Select District"
                )

                # Get districts of the selected state
                state.change(
                    dom.click_handler_for_weather_forecast_districts_dropdown_list_update,
                    state,
                    district
                )

                # Get weather forecast on district selection event
                district_weather = gr.Textbox(label="Weather forecast is:", interactive=False)
                district.change(
                    dom.click_handler_for_weather_forecast_district,
                    [state, district],
                    district_weather
                )

        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Weather Forecast Summary'):
                # Get the summary of the weather forecast
                weather_forecast_summary = gr.Textbox(label="Weather Forecast Summary is:", interactive=False)
                # Summarise once the forecast textbox has actually been
                # refreshed. Listening on `district.change` here read the
                # forecast of the previous selection (or an empty string).
                district_weather.change(
                    dom.click_handler_for_weather_forecast_summary,
                    district_weather,
                    weather_forecast_summary
                )

        # Convert the weather forecast summary to an Indian language
        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Weather Forecast Summary in selected language'):
                # Select the language
                language = gr.Dropdown(
                    list(constants_utils.INDIC_LANGUAGE.keys()),
                    label="Select language",
                    value=list(constants_utils.INDIC_LANGUAGE.keys())[0]
                )

                indic_weather_forecast_summary = gr.Textbox(label="Weather Forecast Summary in the selected language is:", interactive=False)

                # By default display weather forecast summary in Hindi. User can change it later on.
                weather_forecast_summary.change(
                    dom.click_handler_for_get_indic_translation,
                    weather_forecast_summary,
                    indic_weather_forecast_summary
                )

                # User can get the weather forecast summary in their preferred language as well
                b_indic_weather_forecast_summary = gr.Button("Get answer in selected language").style(size='sm')
                b_indic_weather_forecast_summary.click(fn=dom.click_handler_for_get_indic_translation, inputs=[weather_forecast_summary, language], outputs=[indic_weather_forecast_summary])

        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Weather Info'):
                weather = gr.Textbox(label="Current weather is:", interactive=False)
                district.change(
                    dom.click_handler_for_get_weather,
                    district,
                    weather
                )

    #############################################################################
    # Widget to load and process from the custom data source
    with gr.Row(visible=False) as rowLoadCustomData:
        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Load Custom Data (Do not upload data from the same file/url again. Once it is uploaded, it gets stored forever.)'):
                question_category = gr.Dropdown(
                    constants_utils.INDEX_CATEGORY,
                    label="Select Query Type",
                    value=constants_utils.INDEX_CATEGORY[0]
                )

                doc_type = gr.Radio(
                    list(constants_utils.DATA_SOURCES.keys()),
                    label="Select data source (Supports uploading multiple Files/URLs)",
                    value="PDF"
                )

                with gr.Row(visible=True) as rowUploadPdf:
                    with gr.Column(scale=1, min_width=600):
                        file_output = gr.File()
                        upload_button = gr.UploadButton(
                            "Click to Upload PDF Files",
                            file_types=['.pdf'],
                            file_count="multiple"
                        )
                        upload_button.upload(dom._upload_file, upload_button, file_output)
                        b_files = gr.Button("Load PDF Files").style(size='sm')
                        b_files.click(
                            fn=dom.click_handler_for_load_files_urls,
                            inputs=[doc_type, file_output, question_category]
                        )

                with gr.Row(visible=False) as rowUploadOnlinePdf:
                    with gr.Column(scale=1, min_width=600):
                        urls = gr.Textbox(label="Enter URLs for Online PDF (Supports uploading from multiple URLs. Enter the URLs in comma (,) separated format.)", placeholder='Type the URLs here')
                        b_urls = gr.Button("Load Online PDFs").style(size='sm')
                        b_urls.click(
                            fn=dom.click_handler_for_load_files_urls,
                            inputs=[doc_type, urls, question_category]
                        )

                with gr.Row(visible=False) as rowUploadTextFile:
                    with gr.Column(scale=1, min_width=600):
                        file_output = gr.File()
                        upload_button = gr.UploadButton(
                            "Click to Upload Text Files",
                            file_types=['.txt'],
                            file_count="multiple"
                        )
                        upload_button.upload(dom._upload_file, upload_button, file_output)
                        b_files = gr.Button("Load Text Files").style(size='sm')
                        b_files.click(
                            fn=dom.click_handler_for_load_files_urls,
                            inputs=[doc_type, file_output, question_category]
                        )

                with gr.Row(visible=False) as rowUploadUrls:
                    with gr.Column(scale=1, min_width=600):
                        urls = gr.Textbox(label="Enter URLs (Supports uploading from multiple URLs. Enter the URLs in comma (,) separated format.)", placeholder='Type the URLs here')
                        b_urls = gr.Button("Load URLs").style(size='sm')
                        b_urls.click(
                            fn=dom.click_handler_for_load_files_urls,
                            inputs=[doc_type, urls, question_category]
                        )

                # Show only the upload row matching the selected data source
                doc_type.change(
                    fn=dom.select_files_urls,
                    inputs=doc_type,
                    outputs=[
                        rowUploadPdf,
                        rowUploadOnlinePdf,
                        rowUploadTextFile,
                        rowUploadUrls,
                    ],
                )

    #############################################################################
    # Widget to display what all PDFs/Text files, URLs are ingested and indexed for querying in the KB (Knowledge Base)
    with gr.Row(visible=False) as rowDisplayDataSources:
        with gr.Column(scale=1, min_width=600):
            with gr.Tab(label='Following PDFs, Text files, and URLs have been ingested and indexed in the Knowledge Base and are available for querying.'):
                kb_sources = gr.Textbox(label="Data loaded from:", value=dom.kb_sources, interactive=False)
                b_kb_sources = gr.Button("Display Data Sources").style(size='sm')
                b_kb_sources.click(
                    fn=dom.click_handler_for_get_kb_sources,
                    outputs=kb_sources
                )

    # Show only the row matching the selected top-level widget
    widgets.change(
        fn=dom.select_widget,
        inputs=widgets,
        outputs=[
            rowCustomQuery,
            rowGeneral,
            rowMandiPrice,
            rowWeather,
            rowLoadCustomData,
            rowDisplayDataSources,
        ],
    )

demo.launch(share=False)