Reverting problematic commit 'update download functionality to work in HF Spaces'
1ada7a5
import re | |
import os | |
from dotenv import load_dotenv | |
import gradio as gr | |
import pandas as pd | |
from pandas import DataFrame as PandasDataFrame | |
from llm import MessageChatCompletion | |
from customization import css, js | |
from examples import example_1, example_2, example_3, example_4 | |
from prompt_template import system_message_template, user_message_template | |
# Load variables from a local .env file (python-dotenv) so API_KEY can be read below.
load_dotenv()
API_KEY = os.getenv("API_KEY")

# Subsector reference table. The app reads its 'Subsector', 'Definition',
# 'Keywords', 'Does include' and 'Does not include' columns.
# Prefer the working-directory copy (the historical behaviour); if it is not
# there, fall back to the copy shipped next to this script so the app also
# starts when launched from a different directory.
_SUBSECTORS_CSV = 'subsectors.csv'
if not os.path.exists(_SUBSECTORS_CSV):
    _SUBSECTORS_CSV = os.path.join(os.path.dirname(os.path.abspath(__file__)), _SUBSECTORS_CSV)
df = pd.read_csv(_SUBSECTORS_CSV)

# In-memory classification log; one row is appended per classification and
# shown in the "Logs" tab.
logs_columns = ['Abstract', 'Model', 'Results']
logs_df = PandasDataFrame(columns=logs_columns)
def download_logs():
    """Write the accumulated classification logs to ``~/Desktop/classification_logs.csv``.

    The CSV is written with ``DataFrame.to_csv`` on the filesystem of the
    process running this app, not in the visitor's browser. The Desktop
    directory is created first if it is missing, so the export does not fail
    on hosts (e.g. headless Linux machines) that have no Desktop folder.
    """
    # expanduser('~') resolves the home directory on Windows, macOS and Linux
    # alike, and unlike os.environ['USERPROFILE'] it does not raise KeyError
    # when that variable is unset.
    desktop = os.path.join(os.path.expanduser('~'), 'Desktop')
    os.makedirs(desktop, exist_ok=True)
    file_path = os.path.join(desktop, 'classification_logs.csv')
    # logs_df is only read here, so no `global` declaration is needed.
    logs_df.to_csv(file_path)
def build_context(row):
    """Render one subsector row as a single prompt paragraph.

    ``row`` must support ``row[column]`` lookups for 'Subsector', 'Definition',
    'Keywords', 'Does include' and 'Does not include' (for example a pandas
    Series yielded by ``DataFrame.iterrows()``, or a plain dict). Every field
    is prefixed with the subsector name; fields are separated by ". " and the
    paragraph ends with ".\\n".
    """
    name = row['Subsector']
    sections = [
        f"Subsector name: {name}",
        f"{name} Definition: {row['Definition']}",
        f"{name} keywords: {row['Keywords']}",
        f"{name} Does include: {row['Does include']}",
        f"{name} Does not include: {row['Does not include']}",
    ]
    return ". ".join(sections) + ".\n"
# First {...} span in the model reply; DOTALL lets it span several lines.
_SCORE_DICT_PATTERN = re.compile(r'\{.*?\}', re.DOTALL)


def _parse_score_dict(text):
    """Return the first ``{...}`` literal found in ``text`` as a dict, else ``{}``.

    The model reply is not trusted input (it is shaped by the user-supplied
    abstract), so it is parsed with ``ast.literal_eval`` -- which only accepts
    literals -- with a JSON fallback, never with ``eval``. Returns ``{}`` when
    ``text`` is not a string, contains no braces, or the braces do not hold a
    dict literal.
    """
    import ast
    import json

    if not isinstance(text, str):
        return {}
    match = _SCORE_DICT_PATTERN.search(text)
    if match is None:
        return {}
    candidate = match.group(0)
    try:
        parsed = ast.literal_eval(candidate)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # JSON spellings such as true/false/null are not Python literals.
        try:
            parsed = json.loads(candidate)
        except ValueError:
            return {}
    return parsed if isinstance(parsed, dict) else {}


def click_button(model, api_key, abstract):
    """Classify ``abstract`` into subsectors with the chosen LLM and log the result.

    Args:
        model: Chat model name passed to ``MessageChatCompletion``.
        api_key: API key passed to ``MessageChatCompletion``.
        abstract: Patent abstract text to classify.

    Returns:
        ``(match_score_dict, response_reasoning, logs_df)``: the scores parsed
        from the model reply (``{}`` when the client reports an error or no
        dict literal can be parsed), the raw reply, and the updated log table.
    """
    global logs_df

    labels = df['Subsector'].tolist()
    # Join the per-subsector paragraphs into one newline-separated string.
    # Formatting the list itself would embed its repr (brackets, quotes and
    # literal "\n" escapes) in the system prompt.
    prompt_context = "".join(build_context(row) for _, row in df.iterrows())

    language_model = MessageChatCompletion(model=model, api_key=api_key)
    system_message = system_message_template.format(prompt_context=prompt_context)
    user_message = user_message_template.format(labels=labels, abstract=abstract)
    language_model.new_system_message(content=system_message)
    language_model.new_user_message(content=user_message)
    language_model.send_message()
    response_reasoning = language_model.get_last_message()

    # Only parse scores when the client reports no error; the parse itself
    # tolerates missing or malformed dicts.
    if language_model.error is False:
        match_score_dict = _parse_score_dict(response_reasoning)
    else:
        match_score_dict = {}

    # Update Logs
    new_log_entry = pd.DataFrame({'Abstract': [abstract], 'Model': [model], 'Results': [str(match_score_dict)]})
    logs_df = pd.concat([logs_df, new_log_entry], ignore_index=True)
    return match_score_dict, response_reasoning, logs_df
def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Show the full definition of the subsector row the user clicked.

    ``evt.index[0]`` is used as a row position in ``df``; this assumes the
    clicked table lists the rows of ``df`` in their original order.

    Returns:
        A new ``gr.Accordion`` titled with the subsector name, followed by the
        row's 'Definition', 'Keywords', 'Does include' and 'Does not include'
        values.
    """
    record = df.iloc[evt.index[0]]
    title = record['Subsector']
    details = tuple(
        record[column]
        for column in ('Definition', 'Keywords', 'Does include', 'Does not include')
    )
    return (gr.Accordion(label=title), *details)
# --- GRADIO INTERFACE --- # | |
with gr.Blocks(css=css, js=js) as demo:
    # NOTE(review): these two State components are not wired to any event in
    # this file and look like leftovers; kept so module-level names are unchanged.
    state_lotto = gr.State()
    selected_x_labels = gr.State()

    # Seed the "Subsector definitions" panel from the first row of `df` so the
    # accordion title and the texts under it always describe the same
    # subsector. The previous hard-coded defaults paired an "Artificial
    # Intelligence, Big Data and Analytics" title with a VR/AR definition.
    # An empty table leaves the panel blank instead of failing at startup.
    initial_subsector = df.iloc[0] if len(df) else pd.Series(dtype=object)

    with gr.Tab("Patent Discovery"):
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_model = gr.Dropdown(
                    label="Model",
                    choices=["gpt-4", "gpt-4-turbo-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-0125"],
                    value="gpt-3.5-turbo-0125",
                    multiselect=False,
                    interactive=True
                )
            with gr.Column(scale=5):
                api_key = gr.Textbox(
                    label="API Key",
                    interactive=True,
                    lines=1,
                    max_lines=1,
                    type="password",
                    value=API_KEY
                )
        with gr.Row(equal_height=True):
            abstract_description = gr.Textbox(
                label="Abstract description",
                lines=5,
                max_lines=10000,
                interactive=True,
                placeholder="Input a patent abstract"
            )
        with gr.Row():
            with gr.Accordion(label="Example Abstracts", open=False):
                gr.Examples(
                    examples=[example_1, example_2, example_3, example_4],
                    inputs=abstract_description,
                    # NOTE(review): click_button takes (model, api_key, abstract)
                    # but only the abstract is listed as an input and no outputs
                    # are given; enabling cache_examples would need both fixed.
                    fn=click_button,
                    label="",
                    # cache_examples=True,
                )
        with gr.Row():
            btn_get_result = gr.Button("Classify")
        with gr.Row(elem_classes=['all_results']):
            with gr.Column(scale=4):
                label_result = gr.Label(num_top_classes=None)
            with gr.Column(scale=6):
                reasoning = gr.Markdown(label="Reasoning", elem_classes=['reasoning_results'])
    with gr.Tab("Subsector definitions"):
        with gr.Row():
            with gr.Column(scale=4):
                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800)
            with gr.Column(scale=6):
                with gr.Accordion(label=initial_subsector.get('Subsector')) as subsector_name:
                    s1_definition = gr.Textbox(label="Definition", lines=5, max_lines=100,
                                               value=initial_subsector.get('Definition'))
                    s1_keywords = gr.Textbox(label="Keywords", lines=5, max_lines=100,
                                             value=initial_subsector.get('Keywords'))
                    does_include = gr.Textbox(label="Does include", lines=4,
                                              value=initial_subsector.get('Does include'))
                    does_not_include = gr.Textbox(label="Does not include", lines=3,
                                                  value=initial_subsector.get('Does not include'))
    with gr.Tab("Logs"):
        output_dataframe = gr.Dataframe(
            value=logs_df,
            type="pandas",
            height=500,
            headers=['Abstract', 'Model', 'Results'],
            interactive=False,
            column_widths=["45%", "10%", "45%"],
        )
        btn_export = gr.Button(
            value="Export to CSV",
            size="sm",
        )

    # Event wiring.
    btn_get_result.click(
        fn=click_button,
        inputs=[dropdown_model, api_key, abstract_description],
        outputs=[label_result, reasoning, output_dataframe])
    btn_export.click(
        fn=download_logs,
    )
    df_subsectors.select(
        fn=on_select,
        outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include]
    )
if __name__ == "__main__":
    # Start the Gradio app only when this file is executed directly, not on import.
    # Queueing is currently disabled (commented out below).
    # demo.queue()
    demo.launch()