# Source captured from a GitHub file view: author davidlee1102,
# last commit "Update model" (708199f), file size 2.26 kB.
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from scipy.spatial.distance import euclidean
st.title("Surrey 2023 - Image Retrieval")

# Load the trained TensorFlow/Keras model whose predictions are compared by
# Euclidean distance to rank the uploaded images.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model is reloaded each time; caching the load (e.g. with
# st.cache_resource on Streamlit versions that provide it) would avoid that.
MODEL_PATH = "model/adcv_model"
model = tf.keras.models.load_model(MODEL_PATH)
def preprocess_image(image, target_size=(128, 128)):
    """Resize an image and wrap it as a single-item batch array.

    Args:
        image: object with a ``resize(size)`` method, e.g. a ``PIL.Image.Image``,
            whose resized result numpy can convert into an array.
        target_size: size handed unchanged to ``image.resize`` (for PIL this
            is ``(width, height)``).

    Returns:
        A numpy array of the resized pixels with a new leading axis of
        length 1 (the batch axis). Values are NOT rescaled; the ``/ 255.0``
        normalisation is disabled, so presumably the model expects raw pixel
        values. Confirm against how the model was trained.
    """
    pixels = np.array(image.resize(target_size))
    return pixels[np.newaxis, ...]
def distance_quadruplet(result):
    """Score every row of ``result`` after the first against the first row.

    Args:
        result: indexable sequence of vectors (e.g. the rows returned by
            ``model.predict``); row 0 is the reference/query vector.

    Returns:
        A list with one float per row ``1..len(result)-1``, each equal to
        ``1 / (1 + euclidean(result[0], row))``. Despite the function name,
        these are similarities: 1.0 for an identical vector, tending to 0 as
        the distance grows. An input with fewer than two rows yields ``[]``.
    """
    return [
        1 / (1 + euclidean(result[0], result[idx]))
        for idx in range(1, len(result))
    ]
# --- Streamlit page body: collect images, embed them, and rank by similarity ---

MAX_LIST_IMAGES = 10  # upper bound promised by the list uploader's label
IMAGE_SIZE = (128, 128)  # resolution every image is resized to before predict

col1, col2 = st.columns(2)
image_list = st.file_uploader("Choose List Image You Want To Search - No More Than 10 Images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
image_query = st.file_uploader("Choose Images For Querying", type=["jpg", "jpeg", "png"], accept_multiple_files=False)

image_process_list = []  # resized RGB candidate images, in upload order
image_names = []  # file names parallel to image_process_list
image_query_process = []  # holds the resized RGB query image, if any

# With accept_multiple_files=True the uploader returns a (possibly empty)
# list rather than None, so test for emptiness instead of `is not None`.
if image_list:
    # "No More Than 10" means exactly 10 is still allowed.
    if len(image_list) > MAX_LIST_IMAGES:
        col1.write("Your list image have problem - Try to refresh and upload again")
    else:
        columns = st.columns(MAX_LIST_IMAGES)
        for column, uploaded_file in zip(columns, image_list):
            # convert('RGB') normalises RGBA/greyscale uploads to 3 channels.
            img = Image.open(uploaded_file).convert('RGB').resize(IMAGE_SIZE)
            image_process_list.append(img)
            image_names.append(uploaded_file.name)
            column.image(img, caption=f"Uploaded image: {uploaded_file.name}", width=64)
else:
    col1.write("Upload Image")

if image_query is not None:
    # Open the query upload itself, not a leftover file from the list loop.
    img_qr = Image.open(image_query).convert('RGB').resize(IMAGE_SIZE)
    image_query_process.append(img_qr)
    col1.image(img_qr, caption="Query Image", use_column_width=True, width=98)
else:
    col1.write("Upload Image")

if st.button("Classify"):
    if not image_query_process or not image_process_list:
        # Without both a query and candidates there is nothing to compare, and
        # np.stack([]) would raise (or a list image would pose as the query).
        col2.write("Please upload a query image and at least one list image first")
    else:
        # Row 0 is the query and rows 1.. are the candidates, which is the
        # layout distance_quadruplet expects.
        batch = np.stack([np.asarray(img) for img in image_query_process + image_process_list])
        result = model.predict(batch)
        distance_list = distance_quadruplet(result)
        col2.write("Similarity to query (higher = more similar):")
        ranked = sorted(zip(image_names, distance_list), key=lambda pair: pair[1], reverse=True)
        for name, score in ranked:
            col2.write(f"{name}: {score:.4f}")