A-New-Day-001 commited on
Commit
46f6438
·
1 Parent(s): 9e453e5

Update screens/predict.py

Browse files
Files changed (1) hide show
  1. screens/predict.py +152 -101
screens/predict.py CHANGED
@@ -1,85 +1,130 @@
1
  import streamlit as st
2
  import json
3
  from autogluon.multimodal import MultiModalPredictor
4
- from autogluon.tabular import TabularPredictor
5
  import pandas as pd
 
6
  import os
7
  import tempfile
8
 
9
  def predict_page():
10
- # Add a title with an icon
11
- st.title("💵 Property Price Estimator")
12
-
13
- # User role selection (buyer or seller)
14
- user_role = st.selectbox("Select Your Role", options=("Buyer", "Seller"), index=0)
15
-
16
- # Load geocoder and model based on user role
17
- @st.cache(allow_output_mutation=True)
18
-
19
-
20
- if user_role == "Seller":
21
- model_path = "models/mm-nlp-image-transformer/"
22
- else:
23
- model_path = "models/tabular/"
24
-
25
- @st.cache(allow_output_mutation=True)
26
- def load_mm_text_model():
27
- if user_role == "Seller":
28
- return MultiModalPredictor.load(model_path, verbosity=0)
29
- elif user_role == "Buyer":
30
- return TabularPredictor.load(model_path, verbosity=0)
31
-
32
- mm_text_predictor = load_mm_text_model()
33
-
34
- # Common user input fields
35
- st.header("Location Details")
36
- city_map = json.load(open("city-map.json"))
37
- city = st.selectbox("Choose city 🏙️", options=list(city_map.values()))
38
- city_district_map = json.load(open("city-district-map.json"))
39
- district = st.selectbox("Choose district 🏘️", options=list(city_district_map[city].values()))
40
- location = st.text_input("Enter precise location 📍")
41
-
42
- st.header("Property Specifications")
43
- area = st.number_input("Area (m2) 📏", min_value=1.0)
44
- bedrooms = st.number_input("Number of bedrooms 🛏️", min_value=1, value=1)
45
- bathrooms = st.number_input("Number of bathrooms 🚽", min_value=1, value=1)
46
- floors = st.number_input("Number of floors 🏢", min_value=1, value=1)
47
- front_width = st.number_input("Front width (m) 📏", min_value=0.0, value=0.0, step=0.1)
48
- road_width = st.number_input("Road width (m) 🚗", min_value=0.0, value=0.0, step=0.1)
49
-
50
- st.header("Additional Details")
51
- timestamp = st.date_input("Date posted 📅")
52
- cert_status = st.selectbox("Certification status 📜", options=["Không có", "hợp đồng", "sổ đỏ / sổ hồng"])
53
- direction = st.selectbox("Direction 🧭", options=["Không có", "Tây - Nam", "Đông - Nam", "Đông - Bắc", "Tây - Bắc", "Nam", "Tây", "Bắc", "Đông"])
54
- balcony_direction = st.selectbox("Balcony direction 🌞", options=["Không có", "Tây - Nam", "Đông - Nam", "Đông - Bắc", "Tây - Bắc", "Nam", "Tây", "Bắc", "Đông"])
55
-
56
- # Description and image upload (conditional for sellers)
57
- if user_role == "Seller":
58
- description = st.text_area("Description ✍️")
59
- title = description.split(".", maxsplit=1)[0]
60
-
61
- uploaded_image = st.file_uploader("Upload an image 📷")
62
- image_tmp = None
63
- if uploaded_image:
64
- image_tmp = tempfile.NamedTemporaryFile(suffix=uploaded_image.name)
65
- image_tmp.write(uploaded_image.read())
66
- st.image(uploaded_image, caption='Uploaded Image', use_column_width=True)
67
- else:
68
- description = ""
69
- title = ""
70
-
71
- # Calculate latitude and longitude
72
-
 
 
73
  latitude = float("nan")
74
  longitude = float("nan")
75
-
76
- # Display location map
77
- st.map([(latitude, longitude)], zoom=15)
78
-
79
- # Create a button with an icon to get estimated price
80
- if st.button("Get Estimated Price with Text 💰"):
81
- if user_role == "Seller":
82
- input_data = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  "Title": title,
84
  "Area": area,
85
  "Location": location,
@@ -93,36 +138,42 @@ def predict_page():
93
  "Description": description,
94
  "Image URL": image_tmp.name if image_tmp else None,
95
  "Road width": road_width or float("nan"),
96
- "City_code": city,
97
- "DistrictId": district,
98
  "Lattitude": latitude,
99
  "Longitude": longitude,
100
  "Balcony_Direction": balcony_direction,
101
  }
102
- elif user_role == "Buyer":
103
- input_data = {
104
- "Area": area,
105
- "Location": location,
106
- "Time stamp": timestamp,
107
- "Certification status": cert_status,
108
- "Direction": direction,
109
- "Bedrooms": bedrooms,
110
- "Bathrooms": bathrooms,
111
- "Front width": front_width or float("nan"),
112
- "Floor": floors,
113
- "Image URL": image_tmp.name if image_tmp else None,
114
- "Road width": road_width or float("nan"),
115
- "City_code": city,
116
- "DistrictId": district,
117
- "Lattitude": latitude,
118
- "Longitude": longitude,
119
- "Balcony_Direction": balcony_direction,
120
- }
121
-
122
- input_df = pd.DataFrame([input_data])
123
- predicted_price = mm_text_predictor.predict(input_df, as_pandas=False).item()
124
-
125
- # Display the estimated price with an icon
126
- st.subheader("Estimated Price 💲")
127
- st.write(f"{predicted_price * 1e6:,.0f} VND")
128
-
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import json
3
  from autogluon.multimodal import MultiModalPredictor
 
4
  import pandas as pd
5
+ from geopy.geocoders import GoogleV3
6
  import os
7
  import tempfile
8
 
9
  def predict_page():
10
+ if "price_text" not in st.session_state:
11
+ st.session_state.price_text = 0
12
+
13
+ @st.cache_resource
14
+ def load_mm_text_no_price_model():
15
+ return MultiModalPredictor.load("models/mm-text-no-price/", verbosity=0)
16
+
17
+
18
+ mm_text_no_price_predictor = load_mm_text_no_price_model()
19
+
20
+
21
+ @st.cache_resource
22
+ def load_city_map():
23
+ return json.load(open("city-map.json"))
24
+
25
+
26
+ city_map = load_city_map()
27
+
28
+
29
+ @st.cache_resource
30
+ def load_city_district_map():
31
+ return json.load(open("city-district-map.json"))
32
+
33
+
34
+ city_district_map = load_city_district_map()
35
+
36
+ CERT_STATUS = pd.CategoricalDtype(
37
+ categories=["Không ", "hợp đồng", "sổ đỏ / sổ hồng"], ordered=False
38
+ )
39
+ DIRECTION = pd.CategoricalDtype(
40
+ categories=[
41
+ "Không có",
42
+ "Tây - Nam",
43
+ "Đông - Nam",
44
+ "Đông - Bắc",
45
+ "Tây - Bắc",
46
+ "Nam",
47
+ "Tây",
48
+ "Bắc",
49
+ "Đông",
50
+ ],
51
+ ordered=False,
52
+ )
53
+ CITY = pd.CategoricalDtype(categories=city_map.keys(), ordered=False)
54
+ DISTRICT = pd.CategoricalDtype(
55
+ categories=sum([list(map(int, v.keys())) for v in city_district_map.values()], []),
56
+ ordered=False,
57
+ )
58
+
59
+ location_options = st.columns([1, 1, 2, 1, 1])
60
+ with location_options[0]:
61
+ city = st.selectbox(
62
+ "Choose city", options=city_map.items(), format_func=lambda x: x[1]
63
+ )
64
+ with location_options[1]:
65
+ district = st.selectbox(
66
+ "Choose district",
67
+ options=city_district_map[city[0]].items(),
68
+ format_func=lambda x: x[1],
69
+ )
70
+ with location_options[2]:
71
+ location = st.text_input("Enter precise location")
72
+
73
+ location = (location + ", " if location else "") + city[1] + ", " + district[1]
74
+ geocode_result = geocoder.geocode(query=location, region="vn", language="vi")
75
  latitude = float("nan")
76
  longitude = float("nan")
77
+
78
+ with location_options[3]:
79
+ latitude = st.number_input(
80
+ "Enter latitude", value=latitude, step=1e-8, format="%.7f"
81
+ )
82
+ with location_options[4]:
83
+ longitude = st.number_input(
84
+ "Enter longitude", value=longitude, step=1e-8, format="%.7f"
85
+ )
86
+
87
+ numerical_options = st.columns(6)
88
+ with numerical_options[0]:
89
+ area = st.number_input("Area (m2)", min_value=1.0)
90
+ with numerical_options[1]:
91
+ bedrooms = st.number_input("Number of bedrooms", min_value=1, value=1)
92
+ with numerical_options[2]:
93
+ bathrooms = st.number_input("Number of bathrooms", min_value=1, value=1)
94
+ with numerical_options[3]:
95
+ floors = st.number_input("Number of floors", min_value=1, value=1)
96
+ with numerical_options[4]:
97
+ front_width = st.number_input(
98
+ "Front width, leave 0 for N/A", min_value=0.0, value=0.0, step=0.1
99
+ )
100
+ with numerical_options[5]:
101
+ road_width = st.number_input(
102
+ "Road width, leave 0 for N/A", min_value=0.0, value=0.0, step=0.1
103
+ )
104
+
105
+ cat_time_columns = st.columns(4)
106
+ with cat_time_columns[0]:
107
+ timestamp = st.date_input("Date posted", format="DD/MM/YYYY")
108
+ with cat_time_columns[1]:
109
+ cert_status = st.selectbox("Certification status", options=CERT_STATUS.categories)
110
+ with cat_time_columns[2]:
111
+ direction = st.selectbox("Direction", options=DIRECTION.categories)
112
+ with cat_time_columns[3]:
113
+ balcony_direction = st.selectbox("Balcony direction", options=DIRECTION.categories)
114
+
115
+ description = st.text_area("Description")
116
+ title = description.split(".", maxsplit=1)[0]
117
+
118
+ uploaded_image = st.file_uploader("Upload an image")
119
+ image_tmp = None
120
+ if uploaded_image:
121
+ image_tmp = tempfile.NamedTemporaryFile(suffix=uploaded_image.name)
122
+ image_tmp.write(uploaded_image.read())
123
+ print(image_tmp.name)
124
+
125
+ df = pd.DataFrame(
126
+ [
127
+ {
128
  "Title": title,
129
  "Area": area,
130
  "Location": location,
 
138
  "Description": description,
139
  "Image URL": image_tmp.name if image_tmp else None,
140
  "Road width": road_width or float("nan"),
141
+ "City_code": city[0],
142
+ "DistrictId": int(district[0]),
143
  "Lattitude": latitude,
144
  "Longitude": longitude,
145
  "Balcony_Direction": balcony_direction,
146
  }
147
+ ]
148
+ ).astype(
149
+ {
150
+ "Title": "str",
151
+ "Area": "float",
152
+ "Location": "str",
153
+ "Time stamp": "datetime64[ns]",
154
+ "Certification status": CERT_STATUS,
155
+ "Direction": DIRECTION,
156
+ "Bedrooms": "int",
157
+ "Bathrooms": "int",
158
+ "Front width": "float",
159
+ "Floor": "int",
160
+ "Description": "str",
161
+ "Image URL": "str",
162
+ "Road width": "float",
163
+ "City_code": CITY,
164
+ "DistrictId": DISTRICT,
165
+ "Lattitude": "float",
166
+ "Longitude": "float",
167
+ "Balcony_Direction": DIRECTION,
168
+ }
169
+ )
170
+
171
+ if st.button("Get estimated price with text"):
172
+ st.session_state.price_text = mm_text_no_price_predictor.predict(
173
+ df, as_pandas=False
174
+ ).item()
175
+ st.text(
176
+ "Estimated price: {0:,} VND".format(int(st.session_state.price_text * 1e6))
177
+ if st.session_state.price_text
178
+ else "No price estimated."
179
+ )