# Source revision: commit 011c680 "Update model" by davidlee1102.
import numpy as np
import streamlit as st
import tensorflow as tf
from PIL import Image
from scipy.spatial.distance import euclidean
# Page header shown at the top of the Streamlit app.
APP_TITLE = "Surrey 2023 - Image Retrieval"
st.title(APP_TITLE)

# Load the trained TensorFlow/Keras model used to embed images.
# Presumably MODEL_PATH is a SavedModel directory (no file extension) - confirm.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model is reloaded each time; consider caching it
# (e.g. st.cache_resource on Streamlit >= 1.18) once the deployed version is known.
MODEL_PATH = "model/adcv_model"
model = tf.keras.models.load_model(MODEL_PATH)
def preprocess_image(image, target_size=(128, 128)):
    """Resize *image* and return it as a single-item batch array.

    Args:
        image: Object with a ``resize(size)`` method (a PIL image here) whose
            result ``np.array`` can convert.
        target_size: ``(width, height)`` passed straight to ``image.resize``.

    Returns:
        numpy.ndarray: the resized pixels with a new leading batch axis of
        length 1.
    """
    resized = image.resize(target_size)
    # Pixels are not rescaled: the original "/ 255.0" normalisation was left
    # commented out, presumably because the model expects raw values - confirm.
    pixels = np.array(resized)
    return pixels[np.newaxis, ...]
def distance_quadruplet(result):
    """Score every row of *result* against its first row.

    Row 0 is the anchor. For each later row the Euclidean distance ``d`` to
    the anchor is mapped to a similarity ``1 / (1 + d)`` in (0, 1], where
    1 means identical vectors.

    Args:
        result: Indexable sequence of equal-length vectors (e.g. the array
            returned by ``model.predict``).

    Returns:
        list[float]: ``len(result) - 1`` scores in the order of
        ``result[1:]``; empty when ``result`` has fewer than two rows.
    """
    # result[0] is only read when there is at least one row to compare, so an
    # empty input yields [] rather than an IndexError.
    return [
        1 / (1 + euclidean(result[0], result[k]))
        for k in range(1, len(result))
    ]
# Two-column layout: col1 receives status messages and the query preview,
# col2 receives the best-match result.
col1, col2 = st.columns(2)

# Candidate gallery (multiple files) followed by the single query image.
image_list = st.file_uploader(
    "Choose List Image You Want To Search - No More Than 10 Images",
    type=["jpg", "jpeg", "png"],
    accept_multiple_files=True,
)
image_query = st.file_uploader(
    "Choose Images For Querying",
    type=["jpg", "jpeg", "png"],
    accept_multiple_files=False,
)

# Preprocessed PIL images for the candidates and the query, filled in below.
image_process_list, image_query_process = [], []
# Validate and preview the candidate images.
# Per the Streamlit docs, file_uploader(accept_multiple_files=True) returns a
# list (empty when nothing is uploaded), never None, so truthiness is the
# correct "anything uploaded?" test; "is not None" made the else-branch dead.
MAX_LIST_IMAGES = 10  # matches the "No More Than 10 Images" uploader label
if image_list:
    # "> 10" (not ">= 10") so that exactly 10 images is accepted, as promised.
    if len(image_list) > MAX_LIST_IMAGES:
        col1.write("Your list image have problem - Try to refresh and upload again")
    else:
        # One thumbnail slot per allowed image; at most MAX_LIST_IMAGES files
        # reach this point, so columns[i] is always in range.
        columns = st.columns(MAX_LIST_IMAGES)
        for i, uploaded_file_1 in enumerate(image_list):
            # Same RGB + 128x128 normalisation as the query image so every
            # input can be stacked into one batch for model.predict.
            img = Image.open(uploaded_file_1).convert('RGB')
            img = img.resize((128, 128))
            image_process_list.append(img)
            columns[i].image(img, caption=f"Uploaded image: {uploaded_file_1.name}", width=64)
else:
    col1.write("Upload Image")
# Preview the query image. It is normalised like the list images (RGB,
# 128x128) so it can be stacked with them into a single batch.
if image_query is None:
    col1.write("Upload Image")
else:
    img_qr = Image.open(image_query).convert('RGB').resize((128, 128))
    image_query_process.append(img_qr)
    col1.image(img_qr, caption="Query Image", use_column_width=True, width=98)
# Rank the uploaded candidates by similarity to the query when requested.
image_pr = []
if st.button("Classify"):
    if not image_query_process or not image_process_list:
        # np.stack([]) raises ValueError, and without a query result[0] would be
        # a list image, shifting every index below - report instead of crashing.
        col2.write("Upload a query image and at least one list image first")
    else:
        # Query first, then candidates in upload order: row 0 of `result` is
        # the query and row k (k >= 1) corresponds to image_list[k - 1].
        image_pr = np.stack(image_query_process + image_process_list)
        result = model.predict(image_pr)
        # Reuse the shared scorer: 1 / (1 + euclidean distance) of each
        # candidate embedding against the query embedding.
        distance_list = distance_quadruplet(result)
        # Highest similarity == smallest distance; index j maps to image_list[j].
        max_index = int(np.argmax(distance_list))
        image_matches = image_list[max_index]
        # Show the file name rather than the UploadedFile object's repr.
        col2.write(f"The image that have the most similarity: {image_matches.name}")