LayBraid commited on
Commit
4bd9371
1 Parent(s): d2f5566

:construction: update app

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. text_to_image.py +7 -6
requirements.txt CHANGED
@@ -6,4 +6,4 @@ Pillow~=9.0.1
6
  jax
7
  jaxlib
8
  flax
9
- Jinja2
 
6
  jax
7
  jaxlib
8
  flax
9
+ Jinja2==3.0.1
text_to_image.py CHANGED
@@ -34,17 +34,17 @@ def load_captions(caption_file):
34
  return image2caption
35
 
36
 
37
- def get_image(text):
38
  model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
39
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
40
  filename, index = load_index("./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv")
41
  image2caption = load_captions("./images/test-captions.json")
42
 
43
- inputs = processor(text=[text], image=None, return_tensors="jax", padding=True)
44
 
45
  vector = model.get_text_features(**inputs)
46
  vector = np.asarray(vector)
47
- ids, distances = index.knnQuery(vector, k=10)
48
  result_filenames = [filename[id] for id in ids]
49
  for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
50
  caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
@@ -56,8 +56,8 @@ def get_image(text):
56
  for caption in image2caption[result_filename]:
57
  caption_text.append("* {:s}".format(caption))
58
  col3.markdown("".join(caption_text))
59
- st.markdown("---")
60
- suggest_idx = -1
61
 
62
 
63
  def app():
@@ -65,7 +65,8 @@ def app():
65
  st.text("You want search an image with given text.")
66
 
67
  text = st.text_input("Enter text: ")
 
68
 
69
  if st.button("Search"):
70
- get_image(text)
71
 
 
34
  return image2caption
35
 
36
 
37
+ def get_image(text, number):
38
  model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
39
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
40
  filename, index = load_index("./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv")
41
  image2caption = load_captions("./images/test-captions.json")
42
 
43
+ inputs = processor(text=[text], images=None, return_tensors="jax", padding=True)
44
 
45
  vector = model.get_text_features(**inputs)
46
  vector = np.asarray(vector)
47
+ ids, distances = index.knnQuery(vector, k=number)
48
  result_filenames = [filename[id] for id in ids]
49
  for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
50
  caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
 
56
  for caption in image2caption[result_filename]:
57
  caption_text.append("* {:s}".format(caption))
58
  col3.markdown("".join(caption_text))
59
+ st.markdown("---")
60
+ suggest_idx = -1
61
 
62
 
63
  def app():
 
65
  st.text("You want search an image with given text.")
66
 
67
  text = st.text_input("Enter text: ")
68
+ number = st.number_input("Enter number of images result: ", min_value=1, max_value=10)
69
 
70
  if st.button("Search"):
71
+ get_image(text, number)
72