Vivien commited on
Commit
000d238
1 Parent(s): 3f8dd94

Simplify code

Browse files
Files changed (1) hide show
  1. app.py +3 -4
app.py CHANGED
@@ -10,14 +10,13 @@ from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
10
  dict: lambda _: None})
11
  def load():
12
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
13
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
14
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
15
  df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
16
  embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
17
  for k in [0, 1]:
18
  embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
19
- return model, text_model, processor, df, embeddings
20
- model, text_model, processor, df, embeddings = load()
21
 
22
  source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}
23
 
@@ -33,7 +32,7 @@ def get_html(url_list, height=200):
33
 
34
  def compute_text_embeddings(list_of_strings):
35
  inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
36
- return model.text_projection(text_model(**inputs).pooler_output)
37
 
38
  st.cache(show_spinner=False)
39
  def image_search(query, corpus, n_results=24):
 
10
  dict: lambda _: None})
11
  def load():
12
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
 
13
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
14
  df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
15
  embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
16
  for k in [0, 1]:
17
  embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
18
+ return model, processor, df, embeddings
19
+ model, processor, df, embeddings = load()
20
 
21
  source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}
22
 
 
32
 
33
  def compute_text_embeddings(list_of_strings):
34
  inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
35
+ return model.get_text_features(**inputs)
36
 
37
  st.cache(show_spinner=False)
38
  def image_search(query, corpus, n_results=24):