File size: 6,759 Bytes
f11dfb5
 
34c3ae0
f11dfb5
 
 
 
 
 
 
 
 
23fee25
f11dfb5
 
 
 
 
 
 
 
0a08480
f11dfb5
 
d21c9dd
 
 
34c3ae0
 
 
 
d21c9dd
34c3ae0
 
d21c9dd
 
f11dfb5
 
 
 
 
 
 
 
 
 
 
 
 
23fee25
f11dfb5
23fee25
 
f11dfb5
 
 
 
 
 
 
0a08480
f11dfb5
0a08480
 
f11dfb5
0a08480
f11dfb5
0a08480
 
 
 
f11dfb5
0a08480
f11dfb5
23fee25
f11dfb5
 
 
 
 
 
 
d21c9dd
f11dfb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48d6619
f11dfb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cce25bb
f11dfb5
 
 
 
 
d21c9dd
f11dfb5
 
 
 
 
 
 
 
 
 
 
d21c9dd
f11dfb5
0a08480
 
 
 
d21c9dd
0a08480
 
 
d21c9dd
 
 
 
0a08480
 
 
 
 
d21c9dd
 
 
 
 
0a08480
 
 
 
f11dfb5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import re
import os
from io import BytesIO
from dotenv import load_dotenv

import gradio as gr
import pandas as pd
from pandas import DataFrame as PandasDataFrame

from llm import MessageChatCompletion
from customization import css, js
from examples import example_1, example_2, example_3, example_4
from prompt_template import system_message_template, user_message_template

# Load variables from a local .env file (if any) before reading the environment.
load_dotenv()

# Default value for the "API Key" textbox; None when the variable is not set.
API_KEY = os.environ.get("API_KEY")


# Subsector reference table used to build the prompt and the definitions tab.
df = pd.read_csv("subsectors.csv")

# In-memory log of classification requests, shown in the "Logs" tab.
logs_columns = ["Abstract", "Model", "Results"]
logs_df = pd.DataFrame(columns=logs_columns)


def download_logs():
    """Serialize the in-memory classification log to CSV.

    Returns:
        A tuple ``(buffer, filename)``: ``buffer`` is a ``BytesIO`` holding
        the CSV rendering of the module-level ``logs_df`` (no index column),
        rewound to position 0, and ``filename`` is the fixed string
        ``"classification_logs.csv"``.
    """
    csv_buffer = BytesIO()

    # The module-level frame is only read here, so no ``global`` is required.
    logs_df.to_csv(csv_buffer, index=False)

    # Rewind so a consumer reads from the first byte rather than from the end.
    csv_buffer.seek(0)

    return csv_buffer, "classification_logs.csv"


def build_context(row):
    """Render one subsector record as a single line of prompt context.

    Args:
        row: Mapping-like record supporting ``row[key]`` for the keys
            'Subsector', 'Definition', 'Keywords', 'Does include' and
            'Does not include'.

    Returns:
        A string that names the subsector and then lists its definition,
        keywords, inclusions and exclusions, each prefixed with the
        subsector name, ending with a newline.
    """
    name = row['Subsector']
    sections = (
        f"Subsector name: {name}. ",
        f"{name} Definition: {row['Definition']}. ",
        f"{name} keywords: {row['Keywords']}. ",
        f"{name} Does include: {row['Does include']}. ",
        f"{name} Does not include: {row['Does not include']}.\n",
    )
    return "".join(sections)


def _parse_score_dict(text):
    """Return the first ``{...}`` literal found in *text* as a dict, else ``{}``."""
    # Function-scope imports keep this block self-contained.
    import ast
    import json

    if not isinstance(text, str):
        return {}

    found = re.search(r'\{.*?\}', text, re.DOTALL)
    if found is None:
        return {}

    literal = found.group(0)
    try:
        # SECURITY: the text comes from the language model and is untrusted.
        # ``ast.literal_eval`` only accepts Python literals, unlike ``eval``,
        # which would execute arbitrary expressions embedded in the reply.
        parsed = ast.literal_eval(literal)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        try:
            # Fall back to JSON for replies that use true/false/null.
            parsed = json.loads(literal)
        except ValueError:
            return {}

    # A brace literal can also be a set (e.g. ``{1, 2}``); only dicts are useful.
    return parsed if isinstance(parsed, dict) else {}


def click_button(model, api_key, abstract):
    """Classify an abstract against the subsector table and log the result.

    Builds a system message from every row of the module-level ``df`` and a
    user message from the subsector labels plus *abstract*, sends both through
    ``MessageChatCompletion`` and extracts the first ``{...}`` literal from
    the reply as the score dictionary.

    Args:
        model: Model name forwarded to ``MessageChatCompletion``.
        api_key: API key forwarded to ``MessageChatCompletion``.
        abstract: Abstract text to classify.

    Returns:
        A tuple ``(match_score_dict, response_reasoning, logs_df)``:
        the parsed dict (empty when the model reported an error or no valid
        dict literal was found), the raw last message from the model, and
        the updated module-level log frame.
    """
    global logs_df

    labels = df['Subsector'].tolist()
    prompt_context = [build_context(row) for _, row in df.iterrows()]

    language_model = MessageChatCompletion(model=model, api_key=api_key)
    system_message = system_message_template.format(prompt_context=prompt_context)
    user_message = user_message_template.format(labels=labels, abstract=abstract)
    language_model.new_system_message(content=system_message)
    language_model.new_user_message(content=user_message)
    language_model.send_message()

    response_reasoning = language_model.get_last_message()

    # Only trust the reply when the client explicitly reports no error.
    if language_model.error is False:
        match_score_dict = _parse_score_dict(response_reasoning)
    else:
        match_score_dict = {}

    # Update Logs
    new_log_entry = pd.DataFrame({
        'Abstract': [abstract],
        'Model': [model],
        'Results': [str(match_score_dict)],
    })
    logs_df = pd.concat([logs_df, new_log_entry], ignore_index=True)

    return match_score_dict, response_reasoning, logs_df


def on_select(evt: gr.SelectData):  # SelectData is a subclass of EventData
    """Look up the subsector row that was selected and return its details.

    Args:
        evt: Selection event; ``evt.index[0]`` is used as the positional row
            index into the module-level ``df``.

    Returns:
        A 5-tuple: a ``gr.Accordion`` labelled with the subsector name,
        followed by that row's 'Definition', 'Keywords', 'Does include' and
        'Does not include' values.
    """
    row_position = evt.index[0]
    record = df.iloc[[row_position]].iloc[0]

    # Read every field before building the component, as the original did.
    name = record['Subsector']
    detail_columns = ('Definition', 'Keywords', 'Does include', 'Does not include')
    details = [record[column] for column in detail_columns]

    name_accordion = gr.Accordion(label=name)
    return (name_accordion, *details)


# --- GRADIO INTERFACE --- #
# Three tabs: classification ("Patent Discovery"), a browser for the subsector
# table ("Subsector definitions") and the request log ("Logs"). Event handlers
# are wired at the bottom of the block, after all components exist.
with gr.Blocks(css=css, js=js) as demo:
    # NOTE(review): these two State objects are not referenced by any handler
    # in the visible code — confirm whether they can be removed.
    state_lotto = gr.State()
    selected_x_labels = gr.State()
    with gr.Tab("Patent Discovery"):
        # Model choice and API key side by side.
        with gr.Row():
            with gr.Column(scale=5):
                dropdown_model = gr.Dropdown(
                    label="Model",
                    choices=["gpt-4", "gpt-4-turbo-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-0125"],
                    value="gpt-3.5-turbo-0125",
                    multiselect=False,
                    interactive=True
                )
            with gr.Column(scale=5):
                # Pre-filled from the API_KEY environment variable when set.
                api_key = gr.Textbox(
                    label="API Key",
                    interactive=True,
                    lines=1,
                    max_lines=1,
                    type="password",
                    value=API_KEY
                )
        with gr.Row(equal_height=True):
            abstract_description = gr.Textbox(
                label="Abstract description",
                lines=5,
                max_lines=10000,
                interactive=True,
                placeholder="Input a patent abstract"
            )
        with gr.Row():
            with gr.Accordion(label="Example Abstracts", open=False):
                # NOTE(review): click_button takes (model, api_key, abstract)
                # but only one input is declared here; presumably harmless
                # while example caching stays disabled — confirm.
                gr.Examples(
                    examples=[example_1, example_2, example_3, example_4],
                    inputs=abstract_description,
                    fn=click_button,
                    label="",
                    # cache_examples=True,
                )
        with gr.Row():
            btn_get_result = gr.Button("Classify")
        # Results: score label on the left, the model's raw reply on the right.
        with gr.Row(elem_classes=['all_results']):
            with gr.Column(scale=4):
                label_result = gr.Label(num_top_classes=None)
            with gr.Column(scale=6):
                reasoning = gr.Markdown(label="Reasoning", elem_classes=['reasoning_results'])

    with gr.Tab("Subsector definitions"):
        with gr.Row():
            with gr.Column(scale=4):
                # Only the 'Subsector' column is listed; selecting a row fills
                # the detail panel via on_select.
                df_subsectors = gr.DataFrame(df[['Subsector']], interactive=False, height=800)
            with gr.Column(scale=6):
                # NOTE(review): the initial label names an AI/Big Data subsector
                # while the initial definition and keywords describe VR/AR —
                # looks like placeholder text; confirm against subsectors.csv.
                with gr.Accordion(label='Artificial Intelligence, Big Data and Analytics') as subsector_name:
                    s1_definition = gr.Textbox(label="Definition", lines=5, max_lines=100, value="Virtual reality (VR) is an artificial, computer-generated simulation or recreation of a real life environment or situation. Augmented reality (AR) is a technology that layers computer-generated enhancements atop an existing reality in order to make it more meaningful through the ability to interact with it. ")
                    s1_keywords = gr.Textbox(label="Keywords", lines=5, max_lines=100,
                                             value="Mixed Reality, 360 video, frame rate, metaverse, virtual world, cross reality, Artificial intelligence, computer vision")
                    does_include = gr.Textbox(label="Does include", lines=4)
                    does_not_include = gr.Textbox(label="Does not include", lines=3)

    with gr.Tab("Logs"):
        # Read-only view of logs_df; refreshed by the third output of click_button.
        output_dataframe = gr.Dataframe(
            value=logs_df,
            type="pandas",
            height=500,
            headers=['Abstract', 'Model', 'Results'],
            interactive=False,
            column_widths=["45%", "10%", "45%"],
        )
        btn_export = gr.Button(
            value="Export to CSV",
            size="sm",
        )

    # Classify: outputs map, in order, to the (scores, reply, logs) tuple
    # returned by click_button.
    btn_get_result.click(
        fn=click_button,
        inputs=[dropdown_model, api_key, abstract_description],
        outputs=[label_result, reasoning, output_dataframe])

    # NOTE(review): download_logs returns (buffer, filename) but no outputs are
    # declared here, so nothing in this wiring receives the CSV — confirm how
    # the export is meant to reach the user.
    btn_export.click(
        fn=download_logs,
    )

    # Row selection: outputs map, in order, to the 5-tuple returned by on_select.
    df_subsectors.select(
        fn=on_select,
        outputs=[subsector_name, s1_definition, s1_keywords, does_include, does_not_include]
    )

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # Queuing is left disabled; uncomment the next line to enable it.
    # demo.queue()
    demo.launch()