IvaElen commited on
Commit
6e828f7
1 Parent(s): 8e9828f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -27,10 +27,10 @@ with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
27
 
28
  st.title('Find my pic!')
29
 
30
- def find_image_disc(prompt, df):
31
  img_descs = []
32
  img_descs_vit = []
33
- list_images_names, list_images_names_vit = get_similiarity(prompt, model_resnet, model_vit, 3)
34
  for img in list_images_names:
35
  img_descs.append(random.choice(df[df['image_name'] == img.split('/')[-1]]['comment'].values).replace('.', ''))
36
  #vit
@@ -40,7 +40,7 @@ def find_image_disc(prompt, df):
40
  return list_images_names, img_descs, list_images_names_vit, img_descs_vit
41
 
42
  txt = st.text_area("Describe the picture you'd like to see")
43
-
44
  df = pd.read_csv('results.csv',
45
  sep = '|',
46
  names = ['image_name', 'comment_number', 'comment'],
@@ -49,8 +49,7 @@ df = pd.read_csv('results.csv',
49
 
50
  if txt is not None:
51
  if st.button('Find!'):
52
-
53
- list_images, img_desc, list_images_vit, img_descs_vit = find_image_disc(txt, df)
54
  col1, col2 = st.columns(2)
55
  col1.header('ResNet50')
56
  col2.header('ViT 32')
 
27
 
28
  st.title('Find my pic!')
29
 
30
+ def find_image_disc(prompt, df, top_k):
31
  img_descs = []
32
  img_descs_vit = []
33
+ list_images_names, list_images_names_vit = get_similiarity(prompt, model_resnet, model_vit, top_k)
34
  for img in list_images_names:
35
  img_descs.append(random.choice(df[df['image_name'] == img.split('/')[-1]]['comment'].values).replace('.', ''))
36
  #vit
 
40
  return list_images_names, img_descs, list_images_names_vit, img_descs_vit
41
 
42
  txt = st.text_area("Describe the picture you'd like to see")
43
+ top_k = st.slider('Number of images', 1, 5, 3)
44
  df = pd.read_csv('results.csv',
45
  sep = '|',
46
  names = ['image_name', 'comment_number', 'comment'],
 
49
 
50
  if txt is not None:
51
  if st.button('Find!'):
52
+ list_images, img_desc, list_images_vit, img_descs_vit = find_image_disc(txt, df, top_k)
 
53
  col1, col2 = st.columns(2)
54
  col1.header('ResNet50')
55
  col2.header('ViT 32')