TheLitttleThings commited on
Commit
71ee83a
1 Parent(s): 3d33526

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -25,7 +25,7 @@ if 'tab' not in st.session_state:
25
  st.session_state['tab'] = 0
26
 
27
  @st.experimental_memo
28
- def image_search(query, top_k=12):
29
  with torch.no_grad():
30
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
31
  _, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
@@ -54,7 +54,7 @@ def most_similars(embeddings_1, embeddings_2):
54
  values, indices = torch.cosine_similarity(embeddings_1, embeddings_2).sort(descending=True)
55
  return values.cpu(), indices.cpu()
56
 
57
- def analogy(input_image_path: str, top_k=12, additional_text: str = '', input_include=True):
58
  """ Analogies with embedding space arithmetic.
59
  Args:
60
  input_image_path (str): The path to original image
@@ -68,7 +68,7 @@ def analogy(input_image_path: str, top_k=12, additional_text: str = '', input_i
68
 
69
  return [links[i] for i in indices[:top_k]]
70
 
71
- def image_comparison(base_image, top_k=12):
72
  image_embedding = image_query_embedding(base_image)
73
  #additional_embedding = text_query_embedding(query=additional_text)
74
  new_image_embedding = image_embedding #+ additional_embedding
25
  st.session_state['tab'] = 0
26
 
27
  @st.experimental_memo
28
+ def image_search(query, top_k=24):
29
  with torch.no_grad():
30
  text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
31
  _, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
54
  values, indices = torch.cosine_similarity(embeddings_1, embeddings_2).sort(descending=True)
55
  return values.cpu(), indices.cpu()
56
 
57
+ def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
58
  """ Analogies with embedding space arithmetic.
59
  Args:
60
  input_image_path (str): The path to original image
68
 
69
  return [links[i] for i in indices[:top_k]]
70
 
71
+ def image_comparison(base_image, top_k=24):
72
  image_embedding = image_query_embedding(base_image)
73
  #additional_embedding = text_query_embedding(query=additional_text)
74
  new_image_embedding = image_embedding #+ additional_embedding