import streamlit as st
import torch
import pandas as pd
from io import StringIO
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class preProcess:
    def __init__(self, filename, titlename):
        self.filename = filename
        self.title = titlename + '\n'

    def read_data(self):
        df = pd.read_csv(self.filename)
        return df

    def check_columns(self, df):
        # Reject files with more than four columns or with no columns at all.
        if len(df.columns) > 4:
            st.error('File has more than four columns.')
            return False
        if len(df.columns) == 0:
            st.error('File has no columns.')
            return False
        return True
    def format_data(self, df):
        # Linearize the table: column values are zipped row-wise, each row is
        # space-joined, and rows are separated by ' , '.
        headers = [list(df[col]) for col in df.columns]
        zipped = list(zip(*headers))
        res = [' '.join(map(str, tups)) for tups in zipped]
        if len(df.columns) < 3:
            input_format = ' x-y values ' + ' - '.join(list(df.columns)) + ' values ' + ' , '.join(res)
        else:
            input_format = ' labels ' + ' - '.join(list(df.columns)) + ' values ' + ' , '.join(res)
        return input_format

    def combine_title_data(self, df):
        data = self.format_data(df)
        title_data = ' '.join([self.title, data])
        return title_data
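
# Illustrative example (hypothetical data, not part of the app): a two-column CSV
# with columns 'State' and 'Minimum wage' and rows ('Washington', 13.5) and
# ('California', 13.0) would be linearized by combine_title_data roughly as:
#   '<title>\n  x-y values State - Minimum wage values Washington 13.5 , California 13.0'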

class Model:
    def __init__(self, text, mode):
        self.padding = 'max_length'
        self.truncation = True
        self.prefix = 'C2T: '
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.text = text
        # Load the checkpoint that matches the requested summary style.
        if mode.lower() == 'simple':
            self.tokenizer = AutoTokenizer.from_pretrained('saadob12/t5_C2T_big')
            self.model = AutoModelForSeq2SeqLM.from_pretrained('saadob12/t5_C2T_big').to(self.device)
        elif mode.lower() == 'analytical':
            self.tokenizer = AutoTokenizer.from_pretrained('saadob12/t5_autochart_2')
            self.model = AutoModelForSeq2SeqLM.from_pretrained('saadob12/t5_autochart_2').to(self.device)

    def generate(self):
        tokens = self.tokenizer.encode(self.prefix + self.text, truncation=self.truncation, padding=self.padding, return_tensors='pt').to(self.device)
        generated = self.model.generate(tokens, num_beams=4, max_length=256)
        tgt_text = self.tokenizer.decode(generated[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
        summary = str(tgt_text).strip('[]""')
        # Replace chart-type mentions with the neutral word 'statistic'.
        # str.replace returns a new string, so the result must be reassigned.
        chart_terms = ['barchart', 'bar graph', 'bar plot',
                       'scatter plot', 'scatter graph', 'scatter chart', 'scatterchart',
                       'line plot', 'line graph', 'linechart', 'graph']
        for term in chart_terms:
            if term in summary:
                summary = summary.replace(term, 'statistic')
        return summary
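
# Minimal offline sketch of the two classes above (assumptions: a local file named
# 'data.csv' and a title of your choosing; the checkpoints are downloaded on first use):
#   p = preProcess('data.csv', 'Example title')
#   df = p.read_data()
#   if p.check_columns(df):
#       print(Model(p.combine_title_data(df), 'simple').generate())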

st.title('Chart and Data Summarization')
st.write('This application generates a summary of a data file (.csv), or of the underlying data of a chart. It currently only summarizes files with a maximum of four columns; if the file contains more than four columns, the app will show an error.')

mode = st.selectbox('What kind of summary do you want?',
                    ('Simple', 'Analytical'))
st.write('You selected: ' + mode + ' summary.')

title = st.text_input('Add an appropriate title for the .csv file', 'State minimum wage rates in the United States as of January 1 , 2020')
st.write('Title of the file is: ' + title)

uploaded_file = st.file_uploader('Upload a .csv file')

if uploaded_file is not None and mode is not None and title is not None:
    st.write('Preprocessing file...')
    p = preProcess(uploaded_file, title)
    contents = p.read_data()
    check = p.check_columns(contents)
    if check:
        st.write('Your file contents:\n')
        st.write(contents)
        title_data = p.combine_title_data(contents)
        st.write('Linearized input format of the data file:\n')
        st.markdown('**' + title_data + '**')
        st.write('Loading model...')
        model = Model(title_data, mode)
        st.write('Model loading done!\nGenerating summary...')
        summary = model.generate()
        st.write('Generated summary:\n')
        st.markdown('**' + summary + '**')
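
# To try the app locally (assumption: this file is saved as app.py, with streamlit,
# torch, pandas, and transformers installed):
#   streamlit run app.py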