Samuel Schmidt committed on
Commit
cb03df9
·
1 Parent(s): 49a4ce1

Bugfix, comma

Browse files
Files changed (1) hide show
  1. src/app.py +5 -5
src/app.py CHANGED
@@ -49,14 +49,14 @@ def check_index(ds):
49
 
50
  else:
51
  return index_dataset(ds)
52
-
53
 
54
  dataset_with_embeddings = check_index(candidate_subset)
55
 
56
  # Main function, to find similar images
57
  # TODO: implement different distance measures
58
 
59
- def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
60
  """Returns the top k nearest examples to the query image.
61
 
62
  Args:
@@ -75,10 +75,10 @@ def get_neighbors(query_image, selected_descriptor, selected_distance top_k=5):
75
  'color_embeddings', qi_np, k=top_k)
76
  elif selected_distance == "Chi-squared":
77
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
78
- retrieved_examples = tmp_dataset.sort("distance")
79
- else:
80
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
81
- retrieved_examples = tmp_dataset.sort("distance")
82
  images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
83
  return images
84
  if "CLIP" == selected_descriptor:
 
49
 
50
  else:
51
  return index_dataset(ds)
52
+
53
 
54
  dataset_with_embeddings = check_index(candidate_subset)
55
 
56
  # Main function, to find similar images
57
  # TODO: implement different distance measures
58
 
59
+ def get_neighbors(query_image, selected_descriptor, selected_distance, top_k=5):
60
  """Returns the top k nearest examples to the query image.
61
 
62
  Args:
 
75
  'color_embeddings', qi_np, k=top_k)
76
  elif selected_distance == "Chi-squared":
77
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': chi2_distance(histA=query_vector, histB=row['color_embeddings'])})
78
+ retrieved_examples = tmp_dataset.sort("distance")[:5]
79
+ else:
80
  tmp_dataset = dataset_with_embeddings.map(lambda row: {'distance': euclidian_distance(histA=query_vector, histB=row['color_embeddings'])})
81
+ retrieved_examples = tmp_dataset.sort("distance")[:5]
82
  images = retrieved_examples['image'] #retrieved images is a dict, with images and embeddings
83
  return images
84
  if "CLIP" == selected_descriptor: