Qifan Zhang
feat: add log
e691ea0
raw history blame
No virus
2.89 kB
import traceback
from io import StringIO
from typing import Optional
from typing import Tuple

import gradio as gr
import pandas as pd
from loguru import logger

from utils import pipeline
from utils.models import list_models
def read_data(filepath: str) -> pd.DataFrame:
    """Load a tabular input file into a DataFrame.

    The reader is chosen from the file extension, compared case-insensitively
    so that e.g. ``DATA.CSV`` is accepted as well as ``data.csv``:
    ``.xlsx`` is read with ``pandas.read_excel``, ``.csv`` with ``pandas.read_csv``.

    Args:
        filepath: Path of the file to read (a ``str`` or any path-like object).

    Returns:
        The parsed table.

    Raises:
        ValueError: If the extension is neither ``.xlsx`` nor ``.csv``.
    """
    # str() so pathlib.Path inputs work too; lower() for case-insensitive matching.
    name = str(filepath).lower()
    if name.endswith('.xlsx'):
        return pd.read_excel(filepath)
    if name.endswith('.csv'):
        return pd.read_csv(filepath)
    raise ValueError(f'File type not supported: {filepath}')
def _load_input(text: str, file=None) -> pd.DataFrame:
    """Build the input table from an uploaded file or, if none, from pasted CSV text.

    An uploaded ``file`` takes precedence over ``text``.

    Raises:
        ValueError: If neither source is given, or the parsed table has no rows.
    """
    if file:
        # The upload may arrive as a tempfile-like object or as a path string,
        # depending on the Gradio version; use ``.name`` when present, else the value itself.
        df = read_data(getattr(file, 'name', file))
    elif text:
        df = pd.read_csv(StringIO(text))
    else:
        raise ValueError('No input data')
    # Explicit check rather than ``assert`` so it still runs under ``python -O``,
    # and so it applies to uploaded files as well as pasted text.
    if len(df) == 0:
        raise ValueError('No input data')
    return df


def process(
    task_name: str,
    model_name: str,
    pooling: str,
    text: str,
    file=None,
) -> Tuple[Optional[dict], Optional[pd.DataFrame], Optional[str]]:
    """Run the selected scoring task and produce values for the three output components.

    Args:
        task_name: ``'Originality'`` or ``'Flexibility'``; selects the pipeline step.
        model_name: Forwarded unchanged to the pipeline step.
        pooling: Forwarded unchanged to the pipeline step (the UI offers ``'mean'``/``'cls'``).
        text: CSV text used when no file is uploaded.
        file: Optional uploaded ``.csv``/``.xlsx`` file; takes precedence over ``text``.

    Returns:
        On success, ``(None, preview, path)``. ``preview`` is the first 10 rows of the
        result, and ``path`` is a CSV file holding the full result.
        On any error, ``({'Info': 'Something wrong', 'Error': <traceback>}, None, None)``.
        The error is also logged as a warning together with the inputs.
    """
    import os
    import tempfile

    try:
        logger.info(f'Processing {task_name} with {model_name} and {pooling}')
        df = _load_input(text, file)

        tasks = {
            'Originality': pipeline.p0_originality,
            'Flexibility': pipeline.p1_flexibility,
        }
        if task_name not in tasks:
            raise ValueError('Task not supported')
        df = tasks[task_name](df, model_name, pooling)

        # Save into a fresh per-request directory rather than a shared ./output.csv.
        # Otherwise concurrent requests could overwrite or download each other's
        # results. NOTE: these directories are not cleaned up here.
        path = os.path.join(tempfile.mkdtemp(), 'output.csv')
        # utf-8-sig writes a BOM, presumably so spreadsheet tools detect UTF-8.
        df.to_csv(path, index=False, encoding='utf-8-sig')
        return None, df.iloc[:10], path
    except Exception:  # UI boundary: report any failure to the user instead of crashing
        error = traceback.format_exc()
        logger.warning({
            'error': error,
            'task_name': task_name,
            'model_name': model_name,
            'pooling': pooling,
            'text': text,
            'file': file,
        })
        return {'Info': 'Something wrong', 'Error': error}, None, None
# input
# Read the bundled example once and close the file, so the text box starts
# pre-filled with sample CSV input.
with open('data/example_xlm.csv', 'r', encoding='utf-8') as _example_file:
    _example_text = _example_file.read()

# These choices must match the task names that process() dispatches on.
task_name_dropdown = gr.components.Dropdown(
    label='Task Name',
    value='Originality',
    choices=['Originality', 'Flexibility'],
)
model_name_dropdown = gr.components.Dropdown(
    label='Model Name',
    value=list_models[0],
    choices=list_models,
)
pooling_dropdown = gr.components.Dropdown(
    label='Pooling',
    value='mean',
    choices=['mean', 'cls'],
)
text_input = gr.components.Textbox(
    value=_example_text,
    lines=10,
)
# Only the extensions that read_data() can parse are offered for upload.
file_input = gr.components.File(label='Input File', file_types=['.csv', '.xlsx'])
# output
# The outputs line up positionally with the 3-tuple process() returns:
# (message or error info, result preview, path to the full CSV).
text_output = gr.components.Textbox(label='Output')
dataframe_output = gr.components.Dataframe(label='DataFrame')
file_output = gr.components.File(label='Output File', file_types=['.csv', '.xlsx'])

# Read the description with an explicit encoding and close the file afterwards.
with open('data/description.txt', 'r', encoding='utf-8') as _description_file:
    _description = _description_file.read()

app = gr.Interface(
    fn=process,
    inputs=[task_name_dropdown, model_name_dropdown, pooling_dropdown, text_input, file_input],
    outputs=[text_output, dataframe_output, file_output],
    description=_description,
    title='TransDis-CreativityAutoAssessment',
)
app.launch()