hwajjala commited on
Commit
267519a
β€’
1 Parent(s): f4b1311

Create text features

Browse files
Files changed (2) hide show
  1. app.py +17 -3
  2. requirements.txt +2 -0
app.py CHANGED
@@ -14,6 +14,7 @@ TEXT_PROMPTS_FILE_NAME = "text_prompts.json"
14
  LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
15
 
16
 
 
17
  clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
18
 
19
  with open(
@@ -28,15 +29,28 @@ with open(
28
  ) as f:
29
  lr_model = pickle.load(f)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- def greet(name):
33
  return "Hello " + name + "!"
34
 
35
 
36
  iface = gr.Interface(
37
- fn=greet,
38
  inputs="image",
39
  outputs="text",
40
  allow_flagging="manual"
41
  )
42
- iface.launch()
 
14
  LOGISTIC_REGRESSION_MODEL_FILE_NAME = "logistic_regression_l1_oct_2.pkl"
15
 
16
 
17
+
18
  clip_model, preprocess = clip.load(CLIP_MODEL_NAME, device="cpu")
19
 
20
  with open(
 
29
  ) as f:
30
  lr_model = pickle.load(f)
31
 
32
+ logger.info("Logistic regression model loaded, coefficients: ")
33
+
34
+
35
+ all_text_features = []
36
+ with torch.no_grad():
37
+ for k, prompts in text_prompts.items():
38
+ assert len(prompts) == 2
39
+ inputs = clip.tokenize(prompts)
40
+ outputs = clip_model.encode_text(inputs)
41
+ all_text_features.append(outputs)
42
+ all_text_features = torch.cat(all_text_features, dim=0)
43
+ all_text_features = all_text_features.cpu().numpy()
44
+
45
 
46
+ def predict_fn(name):
47
  return "Hello " + name + "!"
48
 
49
 
50
  iface = gr.Interface(
51
+ fn=predict_fn,
52
  inputs="image",
53
  outputs="text",
54
  allow_flagging="manual"
55
  )
56
+ iface.launch()
requirements.txt CHANGED
@@ -4,3 +4,5 @@ ftfy
4
  regex
5
  tqdm
6
  git+https://github.com/openai/CLIP.git
 
 
 
4
  regex
5
  tqdm
6
  git+https://github.com/openai/CLIP.git
7
+ scikit-learn
8
+ scipy