anthony.galtier
renamed to app.py
b44559b
raw
history blame
1.54 kB
import streamlit as st
import numpy as np
import torch
from babel.numbers import format_currency
from bert.tokenize import get_tokenizer
from bert.model import CamembertRegressor
from bert.performance import predict
# Checkpoint produced by training, resolved relative to the working directory.
MODEL_STATE_DICT_PATH = './bert/trained_model/model_epoch_5.pt'
# ENVIRONMENT SET UP
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL LOADING
def _load_model_artifacts(checkpoint_path, device_name):
    """Load the fine-tuned regressor and the objects needed to run it.

    Args:
        checkpoint_path: path of the checkpoint saved by training.
        device_name: torch device string (e.g. "cpu" or "cuda") to load onto.

    Returns:
        Tuple ``(model, tokenizer, max_len, scaler)`` where ``max_len`` and
        ``scaler`` are read from the checkpoint's ``'max_input_len'`` and
        ``'labels_scaler'`` entries.
    """
    target_device = torch.device(device_name)
    # NOTE(review): torch.load unpickles arbitrary Python objects; only point
    # it at a trusted checkpoint. Recent torch releases default to
    # weights_only=True, which may reject a non-tensor 'labels_scaler' entry;
    # confirm against the installed torch version.
    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    regressor = CamembertRegressor()
    regressor.load_state_dict(checkpoint['model_state_dict'])
    # Keep the weights on the device that `predict` is given (presumably where
    # it places the inputs), and switch the module to inference mode.
    regressor.to(target_device)
    regressor.eval()
    return (regressor, get_tokenizer(), checkpoint['max_input_len'],
            checkpoint['labels_scaler'])


# Streamlit re-executes this whole script on every widget interaction; cache
# the expensive checkpoint load so it happens once per server process (the
# loaded objects are then shared across sessions). Streamlit releases without
# `st.cache_resource` keep the previous behaviour of loading on every run.
if hasattr(st, "cache_resource"):
    _load_model_artifacts = st.cache_resource(_load_model_artifacts)

model, tokenizer, max_len, scaler = _load_model_artifacts(
    MODEL_STATE_DICT_PATH, str(device))
# WEB APP
# Batch size handed to `predict`; a single description is scored per run.
_PREDICT_BATCH_SIZE = 32

st.title("Text 2 Price - Real Estate")
st.markdown("")
example_description = "Superbe maison de 500m2 à Pétaouchnok..."
description = st.text_area("Décris ton bien immobilier : ", example_description)
# Predict only once the user has entered their own text: skip empty or
# whitespace-only input as well as the untouched placeholder description.
if description.strip() and description != example_description:
    predicted_price = predict([description], tokenizer, scaler, model, device,
                              max_len, _PREDICT_BATCH_SIZE,
                              return_scaled=False)[0]
    predicted_price_formatted = format_currency(predicted_price, 'EUR',
                                                locale='fr_FR')
    st.markdown('')
    st.markdown('')
    st.markdown('On estime que ton bien immobilier serait annoncé à :')
    # Raw HTML is used only to centre the heading; the interpolated text comes
    # from format_currency, not from the user's description.
    st.markdown(
        f"<h1 style='text-align: center;'>{predicted_price_formatted}</h1>",
        unsafe_allow_html=True)