Spaces:
Runtime error
Runtime error
import streamlit as st | |
import numpy as np | |
import torch | |
from babel.numbers import format_currency | |
from bert.tokenize import get_tokenizer | |
from bert.model import CamembertRegressor | |
from bert.performance import predict | |
# Checkpoint file read by the model-loading step below.
MODEL_STATE_DICT_PATH = './bert/trained_model/model_epoch_5.pt'

# ENVIRONMENT SET UP
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL LOADING
def _load_inference_artifacts():
    """Load the trained regressor and the objects needed to run it.

    Reads the checkpoint at ``MODEL_STATE_DICT_PATH``, restores the model
    weights, places the model on ``device`` and switches it to eval mode.

    Returns:
        A ``(model, tokenizer, max_len, scaler)`` tuple, where ``max_len``
        and ``scaler`` come from the checkpoint's ``'max_input_len'`` and
        ``'labels_scaler'`` entries.
    """
    # NOTE(review): the checkpoint stores more than tensors ('labels_scaler'
    # is presumably a pickled scaler object), so torch.load unpickles
    # arbitrary objects here -- only load trusted checkpoints. Recent PyTorch
    # releases default to weights_only=True, which may reject this file;
    # confirm against the pinned torch version.
    checkpoint = torch.load(MODEL_STATE_DICT_PATH, map_location=device)

    regressor = CamembertRegressor()
    regressor.load_state_dict(checkpoint['model_state_dict'])
    # load_state_dict copies weights into the existing parameters, so the
    # module itself must still be moved to the target device explicitly.
    regressor.to(device)
    # Inference only: disable dropout and other train-time behaviour.
    regressor.eval()

    return (
        regressor,
        get_tokenizer(),
        checkpoint['max_input_len'],
        checkpoint['labels_scaler'],
    )


# Streamlit re-executes this script on every user interaction; cache the
# loaded objects when the installed Streamlit provides cache_resource so the
# checkpoint is read once per process. Older releases load on each run, as
# before.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_inference_artifacts = _cache_resource(_load_inference_artifacts)

model, tokenizer, max_len, scaler = _load_inference_artifacts()
# WEB APP
st.title("Text 2 Price - Real Estate")
st.markdown("")

# Placeholder pre-filled in the text area; no prediction is made for it.
example_description = "Superbe maison de 500m2 à Pétaouchnok..."
description = st.text_area("Décris ton bien immobilier : ", example_description)

# Only predict once the user has entered real text: skip empty or
# whitespace-only input and the untouched placeholder.
entered_text = description.strip()
if entered_text and entered_text != example_description:
    # predict is called with a one-element list and its first result is used.
    # NOTE(review): the positional 32 is presumably the batch size -- confirm
    # against bert.performance.predict.
    predicted_price = predict(
        [description], tokenizer, scaler, model, device,
        max_len, 32, return_scaled=False,
    )[0]

    # NOTE(review): predict's return type is not visible here; float()
    # normalizes a numpy/torch scalar (or one-element array) to a built-in
    # number before it is handed to babel.
    predicted_price_formatted = format_currency(
        float(predicted_price), 'EUR', locale='fr_FR',
    )

    # Two empty markdown elements act as vertical spacing.
    st.markdown('')
    st.markdown('')
    st.markdown('On estime que ton bien immobilier serait annoncé à :')
    # The interpolated value comes from format_currency, not from user input,
    # so rendering it as raw HTML does not expose user-supplied markup.
    st.markdown(
        f"<h1 style='text-align: center;'>{predicted_price_formatted}</h1>",
        unsafe_allow_html=True,
    )