Perm_classifier / app.py
Westwing's picture
Update app.py
282c3a9
raw
history blame contribute delete
No virus
1.04 kB
import pandas
import pickle
import gradio
import sklearn
# Load the pickled classifier once at import time; Perm_pred reuses this single
# instance for every request. Presumably a scikit-learn RandomForestClassifier
# (file name RFC_model.pk, `import sklearn` above) — TODO confirm.
# NOTE: pickle can execute arbitrary code while loading, so this must only ever
# point at the model file shipped with the app, never at user-supplied data.
# The `with` block closes the handle even if unpickling fails.
with open('RFC_model.pk', 'rb') as file:
    rf_clf = pickle.load(file)
def Perm_pred(file_obj):
    """Predict a permanence label for every SKU in an uploaded CSV file.

    Parameters
    ----------
    file_obj : str | os.PathLike | object with a ``.name`` attribute
        The uploaded CSV. Objects whose ``.name`` holds the on-disk path
        (e.g. a tempfile wrapper) are accepted, as are plain path strings
        and path-like objects.

    Returns
    -------
    pandas.DataFrame
        Columns ``['SKU', 'Prediction']``, one row per CSV row, in file
        order. ``Prediction`` is ``'Never_Perm'`` where the model predicts
        2, ``'Middle_Perm'`` where it predicts 1, and ``'Perm_A'`` for any
        other predicted value.

    Raises
    ------
    KeyError
        If the CSV lacks any feature column or the ``SKU`` column.
    """
    import os

    # Feature columns, in the order the model is fed them.
    feature_columns = ['pdp_avg', 'RETURN_RATE', 'conversion',
                       'Cat_1', 'Cat_2', 'Cat_3']
    # Any class not listed here falls back to 'Perm_A'.
    label_for_class = {2: 'Never_Perm', 1: 'Middle_Perm'}

    # A pathlib.Path also has ``.name``, but that is only the basename, so
    # path-like inputs are used directly instead of through ``.name``.
    if isinstance(file_obj, (str, os.PathLike)):
        path = file_obj
    else:
        path = file_obj.name

    dataframe = pandas.read_csv(path, delimiter=',')
    features = dataframe[feature_columns]
    skus = dataframe['SKU']

    if features.empty:
        # scikit-learn rejects 0-sample input; a header-only CSV yields an
        # empty result rather than an error.
        predictions = []
    else:
        predicted = rf_clf.predict(features).tolist()
        predictions = [label_for_class.get(score, 'Perm_A')
                       for score in predicted]

    # to_numpy() pairs SKUs and predictions by position, not by index label.
    return pandas.DataFrame({'SKU': skus.to_numpy(),
                             'Prediction': predictions})
# Web UI: upload a CSV, receive a table of SKU -> predicted label.
# Uses the top-level gradio.File / gradio.Dataframe components; the
# gradio.inputs / gradio.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4.
demo = gradio.Interface(
    fn=Perm_pred,
    inputs=[gradio.File(label='Enter CSV File')],
    outputs=[gradio.Dataframe(label='Predicted Label')],
)
demo.launch()