Sujit Pal commited on
Commit
6c0a88f
1 Parent(s): ad3fcf3

fix: replaced st.form call with simpler workflow

Browse files
Files changed (1) hide show
  1. dashboard_featurefinder.py +5 -33
dashboard_featurefinder.py CHANGED
@@ -11,6 +11,7 @@ from torchvision.transforms import Compose, Resize, ToPILImage
11
  from transformers import CLIPProcessor, FlaxCLIPModel
12
  from PIL import Image
13
 
 
14
 
15
  BASELINE_MODEL = "openai/clip-vit-base-patch32"
16
  # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
@@ -23,32 +24,6 @@ IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
23
  # IMAGES_DIR = "/home/shared/data/rsicd_images"
24
  IMAGES_DIR = "./images"
25
 
26
- 2
27
- # @st.cache(allow_output_mutation=True)
28
- # def load_index():
29
- # filenames, image_vecs = [], []
30
- # fvec = open(IMAGE_VECTOR_FILE, "r")
31
- # for line in fvec:
32
- # cols = line.strip().split('\t')
33
- # filename = cols[0]
34
- # image_vec = np.array([float(x) for x in cols[1].split(',')])
35
- # filenames.append(filename)
36
- # image_vecs.append(image_vec)
37
- # V = np.array(image_vecs)
38
- # index = nmslib.init(method='hnsw', space='cosinesimil')
39
- # index.addDataPointBatch(V)
40
- # index.createIndex({'post': 2}, print_progress=True)
41
- # return filenames, index
42
-
43
-
44
- @st.cache(allow_output_mutation=True)
45
- def load_model():
46
- # model = FlaxCLIPModel.from_pretrained(MODEL_PATH)
47
- # processor = CLIPProcessor.from_pretrained(BASELINE_MODEL)
48
- model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
49
- processor = CLIPProcessor.from_pretrained("flax-community/clip-rsicd-v2")
50
- return model, processor
51
-
52
 
53
  def split_image(X):
54
  num_rows = X.shape[0] // 224
@@ -86,7 +61,7 @@ def get_image_ranks(probs):
86
 
87
 
88
  def app():
89
- model, processor = load_model()
90
 
91
  st.title("Find Features in Images")
92
  st.markdown("""
@@ -114,12 +89,9 @@ def app():
114
  aerial photographs. You will need to download from Unsplash to your
115
  computer and upload it to the demo app.
116
  """)
117
- with st.form(key="form_3"):
118
- buf = st.file_uploader("Upload Image for Analysis")
119
- searched_feature = st.text_input(label="Feature to find")
120
- submit_button = st.form_submit_button("Find")
121
-
122
- if submit_button:
123
  ftmp = NamedTemporaryFile()
124
  ftmp.write(buf.getvalue())
125
  image = plt.imread(ftmp.name)
 
11
  from transformers import CLIPProcessor, FlaxCLIPModel
12
  from PIL import Image
13
 
14
+ import utils
15
 
16
  BASELINE_MODEL = "openai/clip-vit-base-patch32"
17
  # MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
 
24
  # IMAGES_DIR = "/home/shared/data/rsicd_images"
25
  IMAGES_DIR = "./images"
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  def split_image(X):
29
  num_rows = X.shape[0] // 224
 
61
 
62
 
63
  def app():
64
+ model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
65
 
66
  st.title("Find Features in Images")
67
  st.markdown("""
 
89
  aerial photographs. You will need to download from Unsplash to your
90
  computer and upload it to the demo app.
91
  """)
92
+ buf = st.file_uploader("Upload Image for Analysis")
93
+ searched_feature = st.text_input("Feature to find")
94
+ if st.button("Find"):
 
 
 
95
  ftmp = NamedTemporaryFile()
96
  ftmp.write(buf.getvalue())
97
  image = plt.imread(ftmp.name)