amoldwalunj's picture
Update app.py
97fa3b6
import streamlit as st
import pickle
from tensorflow.keras.models import load_model
import tensorflow as tf
import pickle
import os
#import tensorflow_hub as hub
import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras .preprocessing.sequence import pad_sequences
import pickle
import json
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at *path*."""
    # NOTE: unpickling can execute arbitrary code; only ship pickle files
    # that come from a trusted source next to this script.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Inference artifacts, loaded at module level in the same order as before.
tokenizer = _load_pickle('tokenizer.pkl')  # used via texts_to_sequences()
id2label = _load_pickle('id2label.pkl')  # looked up with the predicted class index
label2id = _load_pickle('label2id.pkl')  # presumably the inverse of id2label -- verify
model = load_model(r'model.h5')
def get_prediction(text):
    """Return the label the model predicts for a single piece of text.

    The text is converted to a token-id sequence with the module-level
    ``tokenizer``, padded or truncated at the end to 500 entries, and passed
    to the module-level ``model``. The index of the highest score in the
    first output row is then looked up in ``id2label``.

    Args:
        text: The text to classify; it is tokenized as a one-item batch.

    Returns:
        The ``id2label`` entry for the predicted class index.
    """
    sequences = tokenizer.texts_to_sequences([text])
    # 'post' padding/truncation: pad and cut at the end of the sequence.
    model_input = pad_sequences(
        sequences,
        maxlen=500,
        padding='post',
        truncating='post',
    )
    scores = model.predict(model_input)
    best_index = int(scores.argmax(axis=1)[0])
    return id2label[best_index]
def main():
    """Render the Spend Classification page and classify the entered text.

    The page shows a title, a prompt with an example product description, a
    text box and an "Enter" button. When the button is pressed, the text is
    passed to ``get_prediction`` and the returned label is split on " >> ";
    each part is written on its own line as one hierarchy level, numbered
    from 1.

    If the button is pressed while the text box is empty or contains only
    whitespace, a warning is shown and the model is not called, because the
    model would otherwise be fed an all-padding sequence and the page would
    display an arbitrary category as if it were a real result.
    """
    st.set_page_config(page_title="Spend Classification App", page_icon=":smiley:", layout="wide")
    st.title("Spend Classification App :smiley:")

    st.header("Spend Classification")
    st.write("Enter a product description:")
    st.write("e.g. Key Features of Alisha Solid Women's Cycling Shorts Cotton Lycra Navy, Red, Navy,Specifications of Alisha Solid Women's Cycling Shorts Shorts Details Number of Contents in Sales Package Pack of 3 Fabric Cotton Lycra Type Cycling Shorts General Details Pattern Solid Ideal For Women's Fabric Care Gentle Machine Wash in Lukewarm Water, Do Not Bleach Additional Details Style Code ALTHT_3P_21 In the Box 3 shorts")
    input_string = st.text_input("")

    # Nothing to do until the user presses the button.
    if not st.button("Enter"):
        return

    # Guard: do not classify blank input (see docstring).
    if not input_string.strip():
        st.warning("Please enter a product description to classify.")
        return

    st.write("classification is:")
    pred = get_prediction(input_string)
    # The label is expected to be a " >> "-separated category path
    # (assumes id2label values are strings -- confirm against the pickle).
    for level, category in enumerate(pred.split(" >> "), 1):
        st.write(f'Hierarchy {level} classification: {category}')
# Build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()