|
import ast
import os
import re

import gradio as gr
import pandas as pd
from dotenv import load_dotenv
from pandas import DataFrame as PandasDataFrame

from llm import MessageChatCompletion
from customization import css, js
from examples import example_1, example_2, example_3, example_4
from prompt_template import system_message_template, user_message_template
|
|
|
# Pull variables from a local .env file (if present) into os.environ.
load_dotenv()

# Server-side default API key; None when the variable is not set.
API_KEY = os.environ.get("API_KEY")

# Subsector reference table. The path is relative to the current working
# directory. build_context and on_select read its 'Subsector', 'Definition',
# 'Keywords', 'Does include' and 'Does not include' columns.
df = pd.read_csv('subsectors.csv')

# In-memory classification log; click_button appends one row per request.
logs_columns = ['Abstract', 'Model', 'Results']
logs_df = PandasDataFrame(columns=logs_columns)
|
|
|
|
|
def download_logs(): |
|
global logs_df |
|
|
|
if os.name == 'nt': |
|
desktop = os.path.join(os.path.join(os.environ['USERPROFILE']), 'Desktop') |
|
else: |
|
desktop = os.path.join(os.path.join(os.path.expanduser('~')), 'Desktop') |
|
|
|
|
|
file_path = os.path.join(desktop, 'classification_logs.csv') |
|
|
|
|
|
logs_df.to_csv(file_path) |
|
|
|
|
|
def build_context(row):
    """Render one subsector record as a single prompt paragraph.

    ``row`` may be any mapping (e.g. a DataFrame row) with the keys
    'Subsector', 'Definition', 'Keywords', 'Does include' and
    'Does not include'. The result is five sentences separated by spaces and
    terminated by a newline.
    """
    name = row['Subsector']
    pieces = (
        f"Subsector name: {name}. ",
        f"{name} Definition: {row['Definition']}. ",
        f"{name} keywords: {row['Keywords']}. ",
        f"{name} Does include: {row['Does include']}. ",
        f"{name} Does not include: {row['Does not include']}.\n",
    )
    return "".join(pieces)
|
|
|
|
|
def _parse_match_scores(response_text):
    """Parse the first ``{...}`` span of *response_text* as a dict.

    The text is model output and therefore untrusted. It is parsed with
    ``ast.literal_eval``, which accepts Python literals only, instead of
    ``eval``. An empty dict is returned in any of these cases:

    - *response_text* is not a string;
    - it contains no brace pair;
    - the span is not a valid literal;
    - the span parses to something other than a dict (e.g. ``{1, 2}``).
    """
    if not isinstance(response_text, str):
        return {}
    # Non-greedy + DOTALL: the first brace pair, possibly spanning lines.
    match = re.search(r'\{.*?\}', response_text, re.DOTALL)
    if match is None:
        return {}
    try:
        parsed = ast.literal_eval(match.group(0))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def click_button(model, api_key, abstract):
    """Classify a patent abstract into the subsectors listed in ``df``.

    Args:
        model: model name passed to ``MessageChatCompletion``.
        api_key: API key passed to ``MessageChatCompletion``.
        abstract: the patent abstract text to classify.

    Returns:
        A tuple ``(match_score_dict, response_reasoning, logs_df)``:

        - ``match_score_dict`` is the dict parsed from the model reply. It is
          empty when the client reports an error or no dict literal is found.
        - ``response_reasoning`` is the raw reply.
        - ``logs_df`` is the log table with this request appended.
    """
    global logs_df

    labels = df['Subsector'].tolist()
    # Join the per-subsector paragraphs into one text. Formatting the list
    # itself would inject its repr (brackets, quotes, escaped newlines).
    prompt_context = "".join(build_context(row) for _, row in df.iterrows())

    language_model = MessageChatCompletion(model=model, api_key=api_key)
    system_message = system_message_template.format(prompt_context=prompt_context)
    user_message = user_message_template.format(labels=labels, abstract=abstract)
    language_model.new_system_message(content=system_message)
    language_model.new_user_message(content=user_message)
    language_model.send_message()

    response_reasoning = language_model.get_last_message()

    # Only try to parse the reply when the client reports no error.
    if language_model.error is False:
        match_score_dict = _parse_match_scores(response_reasoning)
    else:
        match_score_dict = {}

    new_log_entry = pd.DataFrame({'Abstract': [abstract], 'Model': [model], 'Results': [str(match_score_dict)]})
    logs_df = pd.concat([logs_df, new_log_entry], ignore_index=True)

    return match_score_dict, response_reasoning, logs_df
|
|
|
|
|
def on_select(evt: gr.SelectData):
    """Populate the definition panel for the subsector clicked in the table.

    ``evt.index[0]`` is used as the positional row index into ``df``.
    Returns a ``gr.Accordion`` whose label is the subsector name, followed by
    that row's 'Definition', 'Keywords', 'Does include' and
    'Does not include' values.
    """
    record = df.iloc[evt.index[0]]
    subsector = record['Subsector']
    details = tuple(
        record[column]
        for column in ('Definition', 'Keywords', 'Does include', 'Does not include')
    )
    return (gr.Accordion(label=subsector), *details)
|
|
|
|
|
|
|
def _classify_with_server_key(model, user_api_key, abstract):
    """Handle the "Show classification" click by delegating to ``click_button``.

    When the API-key field is left blank, the server-side ``API_KEY`` is
    substituted here, so the secret never has to be sent to the browser.
    If neither is set, "" is passed, as the blank textbox did before.
    """
    return click_button(model, user_api_key or API_KEY or "", abstract)


# Initial content for the "Subsector definitions" panel. It is taken from the
# first row of ``df`` so the accordion label and the texts describe the same
# subsector.
if df.empty:
    _initial_subsector = dict.fromkeys(
        ['Subsector', 'Definition', 'Keywords', 'Does include', 'Does not include'], ''
    )
else:
    _initial_subsector = df.iloc[0]


with gr.Blocks(css=css, js=js) as demo:
    # NOTE(review): these State holders are not wired to any event in this file.
    state_lotto = gr.State()
    selected_x_labels = gr.State()

    with gr.Tab("Patent Discovery"):
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_model = gr.Dropdown(
                    label="Model",
                    choices=["gpt-4", "gpt-4-turbo-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-0125"],
                    value="gpt-3.5-turbo-0125",
                    multiselect=False,
                    interactive=True,
                )
            with gr.Column(scale=5):
                # Deliberately not pre-filled with API_KEY. Gradio sends a
                # component's initial value to every browser, and
                # type="password" only masks it on screen. A blank field is
                # replaced with API_KEY server-side by _classify_with_server_key.
                api_key = gr.Textbox(
                    label="API KEY",
                    interactive=True,
                    lines=1,
                    max_lines=1,
                    type="password",
                    placeholder="Leave blank to use the server's default key (if configured)",
                )
        with gr.Row(equal_height=True):
            abstract_description = gr.Textbox(
                label="Abstract description",
                lines=5,
                max_lines=10000,
                interactive=True,
                placeholder="Input a patent abstract",
            )
        with gr.Row():
            with gr.Accordion(label="Example Abstracts", open=False):
                # Examples only fill the abstract box. No `fn` is attached:
                # click_button takes three inputs and no outputs are declared
                # here, so it could not run correctly from an example.
                gr.Examples(
                    examples=[example_1, example_2, example_3, example_4],
                    inputs=abstract_description,
                    label="",
                    cache_examples=False,
                )
        with gr.Row():
            btn_get_result = gr.Button("Show classification")
        with gr.Row(elem_classes=['all_results']):
            with gr.Column(scale=4):
                label_result = gr.Label(num_top_classes=None)
            with gr.Column(scale=6):
                reasoning = gr.Markdown(label="Reasoning", elem_classes=['reasoning_results'])

    with gr.Tab("Subsector definitions"):
        with gr.Row():
            with gr.Column(scale=4):
                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800)
            with gr.Column(scale=6):
                with gr.Accordion(label=_initial_subsector['Subsector']) as subsector_name:
                    s1_definition = gr.Textbox(
                        label="Definition", lines=5, max_lines=100,
                        value=_initial_subsector['Definition'],
                    )
                    s1_keywords = gr.Textbox(
                        label="Keywords", lines=5, max_lines=100,
                        value=_initial_subsector['Keywords'],
                    )
                    does_include = gr.Textbox(
                        label="Does include", lines=4,
                        value=_initial_subsector['Does include'],
                    )
                    does_not_include = gr.Textbox(
                        label="Does not include", lines=3,
                        value=_initial_subsector['Does not include'],
                    )

    with gr.Tab("Logs"):
        output_dataframe = gr.Dataframe(
            value=logs_df,
            type="pandas",
            height=500,
            headers=['Abstract', 'Model', 'Results'],
            interactive=False,
            column_widths=["45%", "10%", "45%"],
        )
        btn_export = gr.Button(
            value="Export to CSV",
            size="sm",
        )

    btn_get_result.click(
        fn=_classify_with_server_key,
        inputs=[dropdown_model, api_key, abstract_description],
        outputs=[label_result, reasoning, output_dataframe],
        # Expose the endpoint as /click_button rather than the wrapper's name.
        api_name="click_button",
    )

    btn_export.click(
        fn=download_logs,
    )

    df_subsectors.select(
        fn=on_select,
        outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include],
    )
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    demo.launch()
|
|