Sujit Pal commited on
Commit
ad3fcf3
1 Parent(s): e2b6044

fix: moving common code to utils so loading happens once

Browse files
Files changed (1) hide show
  1. utils.py +32 -0
utils.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import nmslib
3
+ import numpy as np
4
+ import os
5
+ import streamlit as st
6
+
7
+ from PIL import Image
8
+ from transformers import CLIPProcessor, FlaxCLIPModel
9
+
10
+
11
+ @st.cache(allow_output_mutation=True)
12
+ def load_index(image_vector_file):
13
+ filenames, image_vecs = [], []
14
+ fvec = open(image_vector_file, "r")
15
+ for line in fvec:
16
+ cols = line.strip().split('\t')
17
+ filename = cols[0]
18
+ image_vec = np.array([float(x) for x in cols[1].split(',')])
19
+ filenames.append(filename)
20
+ image_vecs.append(image_vec)
21
+ V = np.array(image_vecs)
22
+ index = nmslib.init(method='hnsw', space='cosinesimil')
23
+ index.addDataPointBatch(V)
24
+ index.createIndex({'post': 2}, print_progress=True)
25
+ return filenames, index
26
+
27
+
28
+ @st.cache(allow_output_mutation=True)
29
+ def load_model(model_path, baseline_model):
30
+ model = FlaxCLIPModel.from_pretrained(model_path)
31
+ processor = CLIPProcessor.from_pretrained(baseline_model)
32
+ return model, processor