mattikris's picture
Update app.py
a95e12c
raw
history blame contribute delete
No virus
2.16 kB
import streamlit as st
# Intro text rendered at the top of the page.
st.markdown(" This is a Streamlit App ")
import streamlit as st
import pandas as pd
import numpy as np
import simpletransformers
import pickle
import torch
import chardet
from pathlib import Path
from detect_delimiter import detect
# Build the training frame for the keyword-intent classifier from
# training_data.csv (reads the 'keywords' and 'Category' columns).
data = pd.read_csv("training_data.csv")

# Fold both "Information" variants into a single "Informational" class.
data['Category'] = data['Category'].replace({
    'Information - Sammenligning': 'Informational',
    'Information': 'Informational',
})

# Balance the classes: draw 1500 rows per category, with replacement so that
# categories with fewer rows are oversampled.
# DataFrameGroupBy.sample keeps every column, including 'Category'. The
# previous groupby(...).apply(lambda x: x.sample(...)) relied on apply()
# operating on the grouping column, which newer pandas deprecates and removes.
data = (
    data.groupby('Category')
    .sample(n=1500, replace=True)
    .reset_index(drop=True)
)

# simpletransformers expects a 'text' column and an integer 'labels' column.
# Category codes are assigned in sorted order of the category names.
train_df = pd.DataFrame({
    'text': data['keywords'],
    'labels': data['Category'].astype('category').cat.codes,
})
n_labels = train_df['labels'].nunique()
from simpletransformers.ner import NERModel
from simpletransformers.classification import ClassificationModel
# Fine-tune the Danish BERT checkpoint as a keyword-intent classifier with one
# output per category code in train_df['labels'].
# use_cuda follows actual GPU availability: simpletransformers raises an error
# when use_cuda=True is requested on a host without CUDA.
# NOTE(review): Streamlit re-executes the script on every interaction, so this
# presumably retrains the model on each rerun - consider caching the trained
# model or loading a saved one instead.
model = ClassificationModel(
    'bert',
    'Maltehb/danish-bert-botxo',
    num_labels=n_labels,
    use_cuda=torch.cuda.is_available(),
    args={'reprocess_input_data': True, 'overwrite_output_dir': True},
)
model.train_model(train_df)
# Map the integer label codes predicted by the model back to category names.
# Derived from the same categorical encoding that produced train_df['labels'],
# so the mapping cannot drift from the training data if a category is added,
# removed or renamed. For the five categories this app was written against it
# yields: 0 Brandsøgning, 1 Informational, 2 Inspiration, 3 Navigational,
# 4 Transactional.
label_dict = dict(enumerate(data['Category'].astype('category').cat.categories))
# Let the user upload a keyword CSV, classify up to the first 5000 keywords
# and offer the labelled result as a CSV download.
upload_file = st.file_uploader("Choose a file", type="csv")
#model = pickle.load(open("finalized_model.sav","rb"))
if upload_file is not None:
    # Sniff the text encoding from the raw bytes of the upload.
    detected = chardet.detect(upload_file.getvalue())
    encoding_value = detected["encoding"]

    # UTF-16 uploads are parsed as whitespace-delimited, everything else as
    # comma-separated. sep=r"\s+" is the documented equivalent of the
    # deprecated delim_whitespace=True.
    # NOTE(review): splitting on any whitespace also splits multi-word
    # keywords into extra fields, and on_bad_lines='skip' then drops those
    # rows silently - confirm whether a tab separator is what is intended.
    separator = r"\s+" if encoding_value == "UTF-16" else ","
    df = pd.read_csv(
        upload_file,
        on_bad_lines='skip',
        encoding=encoding_value,
        sep=separator,
    )

    # Cap the workload at the first 5000 rows.
    keywords = df['Keyword'][:5000]

    # predict() takes a list of texts, so classify everything in one batched
    # call instead of one call per keyword. Values are cast to str because
    # the tokenizer cannot handle non-string cells (e.g. NaN or numbers).
    keyword_texts = keywords.astype(str).tolist()
    if keyword_texts:
        predictions, _ = model.predict(keyword_texts)
    else:
        predictions = []

    labelled = pd.DataFrame({
        'Keyword': keywords,
        'volume': df['Volume'][:5000],
        'Classes': [label_dict[int(code)] for code in predictions],
    })
    st.download_button(
        label="Download CSV file",
        data=labelled.to_csv().encode('utf-8'),
        file_name='labbeled_data.csv',
        mime='text/csv',
    )