Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pickle | |
from tensorflow.keras.models import load_model | |
import tensorflow as tf | |
import pickle | |
import os | |
#import tensorflow_hub as hub | |
import numpy as np | |
import pandas as pd | |
from tensorflow.keras.preprocessing.text import Tokenizer | |
from tensorflow.keras .preprocessing.sequence import pad_sequences | |
import pickle | |
import json | |
def _cache_across_reruns(func):
    """Wrap *func* in ``st.cache_resource`` when the installed Streamlit has it.

    Streamlit re-executes this whole script on every widget interaction, so
    without a resource cache the pickles and the Keras model below would be
    reloaded from disk on each click. ``show_spinner=False`` keeps the cache
    from rendering a spinner element ahead of ``st.set_page_config`` in
    ``main``. On Streamlit versions without ``cache_resource`` the function is
    returned unchanged, i.e. the artifacts are loaded on every run as before.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is None:
        return func
    return cache_resource(show_spinner=False)(func)


@_cache_across_reruns
def _load_artifacts():
    """Load the tokenizer, both label maps and the model from the working dir.

    Returns:
        A ``(tokenizer, id2label, label2id, model)`` tuple, loaded in that order.
    """
    # NOTE(review): pickle.load can execute arbitrary code while unpickling;
    # these files must come from a trusted source.
    with open('tokenizer.pkl', 'rb') as f:
        loaded_tokenizer = pickle.load(f)
    with open('id2label.pkl', 'rb') as f:
        loaded_id2label = pickle.load(f)
    with open('label2id.pkl', 'rb') as f:
        loaded_label2id = pickle.load(f)
    loaded_model = load_model(r'model.h5')
    return loaded_tokenizer, loaded_id2label, loaded_label2id, loaded_model


# Module-level names kept for the rest of the script (and any importers).
tokenizer, id2label, label2id, model = _load_artifacts()
def get_prediction(text, max_length=500):
    """Classify a product description and return the predicted label.

    The text is converted to token ids with the module-level ``tokenizer``.
    The ids are padded or truncated, both at the end, to ``max_length`` and
    passed to ``model``. The index of the highest-scoring class is then
    mapped back to a label through ``id2label``.

    Args:
        text: Product description entered by the user.
        max_length: Sequence length fed to the model. Defaults to 500, the
            value previously hard-coded here. It presumably has to match the
            length the model was trained with — TODO confirm before changing.

    Returns:
        ``id2label[class_index]`` for the arg-max class.
    """
    sequences = tokenizer.texts_to_sequences([text])
    padded = pad_sequences(sequences, maxlen=max_length, padding='post',
                           truncating='post')
    # verbose=0: don't print a Keras progress bar to the server log per request.
    scores = model.predict(padded, verbose=0)
    # Only one text is submitted, so take the arg-max of the first (only) row.
    # NOTE(review): assumes id2label is keyed by int class ids — confirm.
    predicted_id = int(scores.argmax(axis=1)[0])
    return id2label[predicted_id]
def main():
    """Render the Spend Classification page and classify the user's text.

    When the "Enter" button is clicked, the text box contents are passed to
    ``get_prediction``. The resulting label is split on ``" >> "`` and each
    part is shown as one hierarchy level. Blank input shows a warning instead
    of a prediction.
    """
    # Must stay the first Streamlit call of the run.
    st.set_page_config(page_title="Spend Classification App", page_icon=":smiley:", layout="wide")
    st.title("Spend Classification App :smiley:")
    st.header("Spend Classification")
    st.write("Enter a product description:")
    st.write("e.g. Key Features of Alisha Solid Women's Cycling Shorts Cotton Lycra Navy, Red, Navy,Specifications of Alisha Solid Women's Cycling Shorts Shorts Details Number of Contents in Sales Package Pack of 3 Fabric Cotton Lycra Type Cycling Shorts General Details Pattern Solid Ideal For Women's Fabric Care Gentle Machine Wash in Lukewarm Water, Do Not Bleach Additional Details Style Code ALTHT_3P_21 In the Box 3 shorts")
    # NOTE(review): newer Streamlit versions warn about an empty widget label.
    input_string = st.text_input("")
    if st.button("Enter"):
        if not input_string.strip():
            # Blank text produces no token ids, so the model would see only
            # padding and still return some class; ask for input instead.
            st.warning("Please enter a product description first.")
            return
        st.write("classification is:")
        pred = get_prediction(input_string)
        # Labels are presumably category paths joined by " >> " (e.g.
        # "A >> B >> C"). A label without the separator is shown as level 1 only.
        for level, category in enumerate(pred.split(" >> "), start=1):
            st.write(f'Hierarchy {level} classification: {category}')
# `streamlit run` executes this file as __main__, so the page is built on every
# rerun. A plain import of the module does not render any UI.
if __name__ == "__main__":
    main()