# Page header captured together with this file (not code):
# Author: joaomorossini
# Commit 1ada7a5: "Reverting problematic commit 'update download functionality to work in HF Spaces'"
# File size: 6.96 kB
import ast
import os
import re

from dotenv import load_dotenv
import gradio as gr
import pandas as pd
from pandas import DataFrame as PandasDataFrame

from customization import css, js
from examples import example_1, example_2, example_3, example_4
from llm import MessageChatCompletion
from prompt_template import system_message_template, user_message_template
# Pull variables from a local .env file (if one exists) into os.environ.
load_dotenv()

# Value pre-filled into the "API Key" textbox; None when the variable is unset.
API_KEY = os.environ.get("API_KEY")

# Subsector reference table: its rows build the classification prompt and
# populate the "Subsector definitions" tab.
# NOTE(review): the path is resolved against the current working directory.
df = pd.read_csv("subsectors.csv")

# In-memory log of every classification; shown in the "Logs" tab and written
# out by download_logs().
logs_columns = ["Abstract", "Model", "Results"]
logs_df = pd.DataFrame(columns=logs_columns)
def download_logs(file_path=None):
    """Write the accumulated classification log (``logs_df``) to a CSV file.

    By default the file is ``classification_logs.csv`` inside a ``Desktop``
    folder under the home directory of the user running this process, i.e.
    on the machine hosting the app rather than the visitor's machine.

    Args:
        file_path: Optional destination path. When omitted, the Desktop
            location described above is used.

    Returns:
        The path the CSV file was written to.
    """
    if file_path is None:
        # expanduser('~') resolves the home directory on Windows (via
        # USERPROFILE) as well as on macOS/Linux, so no per-OS branch is
        # needed and a missing USERPROFILE no longer raises KeyError.
        desktop = os.path.join(os.path.expanduser('~'), 'Desktop')
        file_path = os.path.join(desktop, 'classification_logs.csv')

    # Headless hosts often have no Desktop folder; create the parent
    # directory instead of failing with FileNotFoundError.
    parent_dir = os.path.dirname(file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # index=False: the RangeIndex is only an internal row counter, not log data.
    logs_df.to_csv(file_path, index=False)
    return file_path
def build_context(row):
    """Render one subsector row as a prompt snippet.

    *row* must support ``row[key]`` lookup for ``'Subsector'``,
    ``'Definition'``, ``'Keywords'``, ``'Does include'`` and
    ``'Does not include'`` (in this file it is a row from ``df.iterrows()``).

    The result is a sequence of sentences separated by single spaces: first
    ``"Subsector name: <name>."``, then ``"<name> <label>: <value>."`` for each
    remaining field. It ends with a newline.
    """
    name = row['Subsector']
    sentences = [f"Subsector name: {name}."]
    # (label shown in the prompt, column to read) -- note the lowercase
    # "keywords" label versus the "Keywords" column key.
    for label, column in (
        ('Definition', 'Definition'),
        ('keywords', 'Keywords'),
        ('Does include', 'Does include'),
        ('Does not include', 'Does not include'),
    ):
        sentences.append(f"{name} {label}: {row[column]}.")
    return " ".join(sentences) + "\n"
def _parse_match_scores(text):
    """Return the first ``{...}`` span in *text* parsed as a dict literal.

    Returns an empty dict when *text* is not a string, when no brace-delimited
    span exists, when the span is not a valid Python literal, or when it
    evaluates to something other than a dict (e.g. a set).
    """
    if not isinstance(text, str):
        return {}
    # Non-greedy: from the first '{' to the first '}' after it, across newlines.
    match = re.search(r'\{.*?\}', text, re.DOTALL)
    if match is None:
        return {}
    try:
        # The text comes from the LLM, i.e. it is untrusted. literal_eval only
        # accepts Python literals, so it cannot execute code the way eval() could.
        parsed = ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def click_button(model, api_key, abstract):
    """Classify a patent abstract into subsectors with the selected model.

    Builds a system prompt from every row of the module-level ``df`` and a
    user prompt from the subsector labels and *abstract*. Both are sent
    through ``MessageChatCompletion``, and the first ``{...}`` literal in the
    reply is parsed as the label -> score mapping. Every call is appended to
    the module-level ``logs_df``.

    Args:
        model: Model name chosen in the dropdown.
        api_key: API key passed to ``MessageChatCompletion``.
        abstract: Patent abstract text to classify.

    Returns:
        A tuple ``(match_score_dict, response_reasoning, logs_df)``:
        - the parsed scores (an empty dict if the call reported an error or
          nothing parseable was found);
        - the raw model reply;
        - the updated log DataFrame.
    """
    global logs_df

    labels = df['Subsector'].tolist()
    prompt_context = [build_context(row) for _, row in df.iterrows()]
    language_model = MessageChatCompletion(model=model, api_key=api_key)
    # NOTE(review): prompt_context and labels are lists, so str.format inserts
    # their Python repr (brackets, quotes, escaped "\n"). Kept unchanged;
    # confirm this is what the prompt templates expect.
    system_message = system_message_template.format(prompt_context=prompt_context)
    user_message = user_message_template.format(labels=labels, abstract=abstract)
    language_model.new_system_message(content=system_message)
    language_model.new_user_message(content=user_message)
    language_model.send_message()
    response_reasoning = language_model.get_last_message()

    # `is False` is kept from the original: any other value of `error`
    # (including None) is treated as a failure and yields no scores.
    if language_model.error is False:
        match_score_dict = _parse_match_scores(response_reasoning)
    else:
        match_score_dict = {}

    # Record this run in the log shown in the "Logs" tab.
    new_log_entry = pd.DataFrame(
        {'Abstract': [abstract], 'Model': [model], 'Results': [str(match_score_dict)]}
    )
    logs_df = pd.concat([logs_df, new_log_entry], ignore_index=True)
    return match_score_dict, response_reasoning, logs_df
def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Show the details of the subsector row clicked in the definitions table.

    Gradio injects *evt* because of the ``gr.SelectData`` annotation;
    ``evt.index[0]`` is used as the positional row index into ``df``.

    Returns:
        A tuple matching the wired outputs:
        - an ``Accordion`` labelled with the subsector name;
        - the row's Definition, Keywords, Does include and Does not include
          values.
    """
    record = df.iloc[evt.index[0]]
    # Read every field before building any component.
    title = record['Subsector']
    details = [
        record[column]
        for column in ('Definition', 'Keywords', 'Does include', 'Does not include')
    ]
    return (gr.Accordion(label=title), *details)
# --- GRADIO INTERFACE --- #
# Gradio places components in the order they are created within each `with`
# context, so the statement order below defines the on-screen layout.
with gr.Blocks(css=css, js=js) as demo:
    # NOTE(review): these two State objects are not referenced anywhere else
    # in this file; they look like leftovers -- confirm before removing.
    state_lotto = gr.State()
    selected_x_labels = gr.State()

    # Tab 1: classify a single patent abstract.
    with gr.Tab("Patent Discovery"):
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_model = gr.Dropdown(
                    label="Model",
                    choices=["gpt-4", "gpt-4-turbo-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-0125"],
                    value="gpt-3.5-turbo-0125",
                    multiselect=False,
                    interactive=True
                )
            with gr.Column(scale=5):
                # Pre-filled from the API_KEY environment variable (None if unset).
                api_key = gr.Textbox(
                    label="API Key",
                    interactive=True,
                    lines=1,
                    max_lines=1,
                    type="password",
                    value=API_KEY
                )
        with gr.Row(equal_height=True):
            abstract_description = gr.Textbox(
                label="Abstract description",
                lines=5,
                max_lines=10000,
                interactive=True,
                placeholder="Input a patent abstract"
            )
        with gr.Row():
            with gr.Accordion(label="Example Abstracts", open=False):
                # Clicking an example only fills abstract_description.
                # NOTE(review): `fn` is presumably only used when examples are
                # cached (cache_examples is commented out). click_button needs
                # three inputs but only one is wired here and no outputs are
                # given, so enabling caching as-is would likely fail.
                gr.Examples(
                    examples=[example_1, example_2, example_3, example_4],
                    inputs=abstract_description,
                    fn=click_button,
                    label="",
                    # cache_examples=True,
                )
        with gr.Row():
            btn_get_result = gr.Button("Classify")
        with gr.Row(elem_classes=['all_results']):
            with gr.Column(scale=4):
                # Receives the label -> score dict returned by click_button.
                label_result = gr.Label(num_top_classes=None)
            with gr.Column(scale=6):
                # Receives the raw model reply returned by click_button.
                reasoning = gr.Markdown(label="Reasoning", elem_classes=['reasoning_results'])

    # Tab 2: browse the subsector definitions loaded from subsectors.csv.
    with gr.Tab("Subsector definitions"):
        with gr.Row():
            with gr.Column(scale=4):
                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800)
            with gr.Column(scale=6):
                # NOTE(review): the initial Definition/Keywords text describes
                # VR/AR, not the "Artificial Intelligence, Big Data and
                # Analytics" label. All five fields are overwritten by on_select
                # once a table row is clicked.
                with gr.Accordion(label='Artificial Intelligence, Big Data and Analytics') as subsector_name:
                    s1_definition = gr.Textbox(label="Definition", lines=5, max_lines=100, value="Virtual reality (VR) is an artificial, computer-generated simulation or recreation of a real life environment or situation. Augmented reality (AR) is a technology that layers computer-generated enhancements atop an existing reality in order to make it more meaningful through the ability to interact with it. ")
                    s1_keywords = gr.Textbox(label="Keywords", lines=5, max_lines=100,
                                             value="Mixed Reality, 360 video, frame rate, metaverse, virtual world, cross reality, Artificial intelligence, computer vision")
                    does_include = gr.Textbox(label="Does include", lines=4)
                    does_not_include = gr.Textbox(label="Does not include", lines=3)

    # Tab 3: running log of classifications made in this server process.
    with gr.Tab("Logs"):
        output_dataframe = gr.Dataframe(
            value=logs_df,
            type="pandas",
            height=500,
            headers=['Abstract', 'Model', 'Results'],
            interactive=False,
            column_widths=["45%", "10%", "45%"],
        )
        # download_logs writes the CSV to a Desktop folder under the home
        # directory of the process running this app, not on the visitor's machine.
        btn_export = gr.Button(
            value="Export to CSV",
            size="sm",
        )

    # --- Event wiring ---
    # click_button returns (scores, reasoning, logs_df), in the same order as
    # these outputs.
    btn_get_result.click(
        fn=click_button,
        inputs=[dropdown_model, api_key, abstract_description],
        outputs=[label_result, reasoning, output_dataframe])
    btn_export.click(
        fn=download_logs,
    )
    # on_select returns (Accordion, definition, keywords, does_include,
    # does_not_include) for the clicked row.
    df_subsectors.select(
        fn=on_select,
        outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include]
    )

if __name__ == "__main__":
    # Request queuing is disabled (commented out); launch with default settings.
    # demo.queue()
    demo.launch()