|
import traceback |
|
from io import StringIO |
|
from typing import Optional |
|
|
|
import gradio as gr |
|
import pandas as pd |
|
|
|
from utils import pipeline |
|
from utils.models import list_models |
|
from loguru import logger |
|
|
|
|
|
def read_data(filepath: str) -> pd.DataFrame:
    """Load a tabular input file into a DataFrame.

    The reader is chosen from the file extension, which is compared
    case-insensitively so that names such as ``DATA.CSV`` are accepted.

    Args:
        filepath: Path to a ``.xlsx`` or ``.csv`` file.

    Returns:
        The parsed contents of the file.

    Raises:
        ValueError: If the extension is neither ``.xlsx`` nor ``.csv``.
    """
    # Only the extension test is case-folded; the path handed to pandas is
    # left untouched so case-sensitive filesystems still resolve it.
    lowered = filepath.lower()
    if lowered.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if lowered.endswith('.csv'):
        return pd.read_csv(filepath)
    raise ValueError('File type not supported')
|
|
|
|
|
def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> (None, pd.DataFrame, str):
    """Run the selected scoring task on the supplied data.

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``; any other value is
            reported as an error.
        model_name: Passed through unchanged to the pipeline function.
        pooling: Passed through unchanged to the pipeline function.
        text: CSV-formatted text, used only when ``file`` is falsy.
        file: Optional uploaded file object; its ``name`` attribute is used
            as the path to read. Takes precedence over ``text``.

    Returns:
        A 3-tuple ``(message, preview, path)``.

        * On success: ``(None, first 10 rows of the result, 'output.csv')``.
        * On failure: ``({'Info': ..., 'Error': traceback}, None, None)``.

        Errors are never propagated to the caller; they are logged and
        returned in the first tuple slot instead.
    """
    try:
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')

        # An uploaded file wins over the text box when both are provided.
        if file:
            df = read_data(file.name)
        elif text:
            df = pd.read_csv(StringIO(text))
            # Explicit raise rather than `assert`, which is removed under -O.
            if len(df) < 1:
                raise ValueError('No input data')
        else:
            raise ValueError('No input data')

        if task_name == 'Originality':
            df = pipeline.p0_originality(df, model_name, pooling)
        elif task_name == 'Flexibility':
            df = pipeline.p1_flexibility(df, model_name, pooling)
        else:
            raise ValueError('Task not supported')

        # NOTE(review): fixed relative path, overwritten on every call;
        # overlapping requests would share this file -- confirm acceptable.
        path = 'output.csv'
        df.to_csv(path, index=False, encoding='utf-8-sig')
        return None, df.iloc[:10], path

    except Exception:
        # Top-level boundary for the UI callback: report the failure instead
        # of raising, but let KeyboardInterrupt/SystemExit pass through.
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        return {'Info': 'Something wrong', 'Error': error}, None, None
|
|
|
|
|
|
|
# --- Input widgets, in the same order as the parameters of `process` ---
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility'],
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models,
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value='mean',
    choices=['mean', 'cls'],
)

# Pre-fill the text box with the bundled example; `with` closes the handle
# once the content has been read.
# NOTE(review): the platform default encoding is used -- confirm the file's
# encoding if this runs where the default is not UTF-8.
with open('data/example_xlm.csv', 'r') as _example_file:
    _example_text = _example_file.read()

text_input = gr.components.Textbox(
    value=_example_text,
    lines=10,
)
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])

# --- Output widgets, in the same order as the tuple returned by `process` ---
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])

# Description text shown in the interface, read from disk at start-up.
with open('data/description.txt', 'r') as _description_file:
    _description_text = _description_file.read()

app = gr.Interface(
    fn=process,
    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[text_output, dataframe_output, file_output],
    description=_description_text,
    title='TransDis-CreativityAutoAssessment',
)

# Start the app as soon as this module is executed.
app.launch()
|
|