saadob12's picture
Update app.py
a154d32
raw
history blame
4.71 kB
import streamlit as st
import torch
import pandas as pd
from io import StringIO
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class preProcess:
    """Read a .csv file, validate its column count, and linearize its
    contents (title + column headers + row values) into the plain-text
    format the C2T summarization model expects.
    """

    def __init__(self, filename, titlename):
        # filename may be a filesystem path or a file-like object
        # (e.g. a Streamlit UploadedFile) — anything pd.read_csv accepts.
        self.filename = filename
        # Trailing newline keeps the title visually separate from the data.
        self.title = titlename + '\n'

    def read_data(self):
        """Load the csv into a pandas DataFrame and return it."""
        df = pd.read_csv(self.filename)
        return df

    def check_columns(self, df):
        """Return True when df has between 1 and 4 columns.

        Otherwise display a Streamlit error and return False.
        """
        if len(df.columns) > 4:
            # Fix: the message previously said "more than 3 coloumns",
            # contradicting both this > 4 check and the UI help text.
            st.error('File has more than 4 columns.')
            return False
        if len(df.columns) == 0:
            st.error('File has no column.')
            return False
        return True

    def format_data(self, df):
        """Linearize df into the model's input format.

        Rows are transposed into space-joined value groups; 1-2 column
        files use the ' x-y values' prefix, 3+ columns use ' labels'.
        """
        # Transpose column-major data into per-row tuples, then render
        # each row as a space-separated string.
        columns = [list(df[name]) for name in df.columns]
        rows = list(zip(*columns))
        res = [' '.join(map(str, tup)) for tup in rows]
        if len(df.columns) < 3:
            input_format = ' x-y values ' + ' - '.join(list(df.columns)) + ' values ' + ' , '.join(res)
        else:
            input_format = ' labels ' + ' - '.join(list(df.columns)) + ' values ' + ' , '.join(res)
        return input_format

    def combine_title_data(self, df):
        """Return the title followed by the linearized data."""
        data = self.format_data(df)
        title_data = ' '.join([self.title, data])
        return title_data
class Model:
    """Wrap a pretrained chart-to-text (C2T) seq2seq model.

    mode 'simple' loads saadob12/t5_C2T_big; mode 'analytical' loads
    saadob12/t5_autochart_2. generate() produces a summary of the
    linearized text and rewrites chart-type mentions as 'statistic'.
    """

    def __init__(self, text, mode):
        self.padding = 'max_length'
        self.truncation = True
        # 'C2T: ' is the task prefix the fine-tuned T5 checkpoints expect.
        self.prefix = 'C2T: '
        # Fix: dropped the redundant chained assignment
        # (`self.device = device = ...`) that leaked a useless local.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.text = text
        if mode.lower() == 'simple':
            self.tokenizer = AutoTokenizer.from_pretrained('saadob12/t5_C2T_big')
            self.model = AutoModelForSeq2SeqLM.from_pretrained('saadob12/t5_C2T_big').to(self.device)
        elif mode.lower() == 'analytical':
            self.tokenizer = AutoTokenizer.from_pretrained('saadob12/t5_autochart_2')
            self.model = AutoModelForSeq2SeqLM.from_pretrained('saadob12/t5_autochart_2').to(self.device)

    def generate(self):
        """Tokenize the prefixed input, run beam search, decode, and
        return the summary with chart-type wording neutralized."""
        tokens = self.tokenizer.encode(self.prefix + self.text, truncation=self.truncation, padding=self.padding, return_tensors='pt').to(self.device)
        generated = self.model.generate(tokens, num_beams=4, max_length=256)
        tgt_text = self.tokenizer.decode(generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        summary = str(tgt_text).strip('[]""')
        return self._neutralize_chart_terms(summary)

    def _neutralize_chart_terms(self, summary):
        """Replace the first matching chart-type phrase (and any leftover
        'graph') in summary with 'statistic'.

        Bug fix: str.replace returns a new string; the original code
        discarded every replacement result, making the whole chain a
        no-op. Also the 'scatterchart' branch tried to replace
        'scatter chart' (with a space), which could never match.
        """
        chart_terms = ('barchart', 'bar graph', 'bar plot', 'scatter plot',
                       'scatter graph', 'scatterchart', 'line plot',
                       'line graph', 'linechart')
        for term in chart_terms:
            if term in summary:
                # Preserve the original elif-chain semantics: only the
                # first matching term is replaced.
                summary = summary.replace(term, 'statistic')
                break
        if 'graph' in summary:
            summary = summary.replace('graph', 'statistic')
        return summary
# --- Streamlit page: header and static description ---
st.title('Chart and Data Summarization')
st.write('This application generates a summary of a datafile (.csv) (or the underlying data of a chart). Right now, it only generates summaries of files with maximum of four columns. If the file contains more than four columns, the app will throw an error.')

# --- User inputs: summary mode, file title, and the .csv upload ---
summary_mode = st.selectbox('What kind of summary do you want?',
    ('Simple', 'Analytical'))
st.write('You selected: ' + summary_mode + ' summary.')

chart_title = st.text_input('Add appropriate Title of the .csv file', 'State minimum wage rates in the United States as of January 1 , 2020')
st.write('Title of the file is: ' + chart_title)

uploaded_file = st.file_uploader("Upload only .csv file")

# Run the pipeline only once every input widget has a value.
if not (uploaded_file is None or summary_mode is None or chart_title is None):
    st.write('Preprocessing file...')
    preprocessor = preProcess(uploaded_file, chart_title)
    frame = preprocessor.read_data()
    if preprocessor.check_columns(frame):
        st.write('Your file contents:\n')
        st.write(frame)
        linearized = preprocessor.combine_title_data(frame)
        st.write('Linearized input format of the data file:\n ')
        st.markdown('**' + linearized + '**')
        st.write('Loading model...')
        summarizer = Model(linearized, summary_mode)
        st.write('Model loading done!\nGenerating Summary...')
        summary = summarizer.generate()
        st.write('Generated Summary:\n')
        st.markdown('**' + summary + '**')