RoboApocalypse committed on
Commit
1f9e30b
·
1 Parent(s): 435181d

Refactor generate_embedding function to remove unneeded variables

Browse files
Files changed (1) hide show
  1. app.py +8 -45
app.py CHANGED
@@ -1,10 +1,9 @@
 
1
  import gradio as gr
2
  from numpy import empty
3
  import open_clip
4
  import torch
5
  import PIL.Image as Image
6
- from io import BytesIO
7
- import base64
8
 
9
  # Set device to GPU if available
10
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -12,8 +11,8 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
  # Load the OpenCLIP model and the necessary preprocessors
13
  # openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
14
  # openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
15
- openclip_model = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
16
- openclip_model = 'hf-hub:' + openclip_model
17
  model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
18
  model_name=openclip_model,
19
  device=device
@@ -21,7 +20,7 @@ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
21
 
22
 
23
  # Define function to generate text embeddings
24
- def generate_text_embedding(text_data):
25
  """
26
  Generate embeddings for text data using the OpenCLIP model.
27
 
@@ -76,7 +75,7 @@ def generate_text_embedding(text_data):
76
  return text_embeddings
77
 
78
  # Define function to generate image embeddings
79
- def generate_image_embedding(image_data):
80
  """
81
  Generate embeddings for image data using the OpenCLIP model.
82
 
@@ -129,7 +128,7 @@ def generate_image_embedding(image_data):
129
 
130
 
131
  # Define function to generate embeddings
132
- def generate_embedding(text_data, image_data, image_data_base64):
133
  """
134
  Generate embeddings for text and image data using the OpenCLIP model.
135
 
@@ -139,8 +138,6 @@ def generate_embedding(text_data, image_data, image_data_base64):
139
  Text data to embed.
140
  image_data : PIL.Image.Image or tuple of PIL.Image.Image
141
  Image data to embed.
142
- image_data_base64 : str or tuple of str
143
- Base64 encoded image data to embed.
144
 
145
  Returns
146
  -------
@@ -150,8 +147,6 @@ def generate_embedding(text_data, image_data, image_data_base64):
150
  List of image embeddings.
151
  similarity : list of str
152
  List of cosine similarity between text and image embeddings.
153
- image_data_base64_embeddings : str or tuple of str
154
- List of image embeddings for base64 encoded image data.
155
  """
156
 
157
  # Embed text data
@@ -193,38 +188,7 @@ def generate_embedding(text_data, image_data, image_data_base64):
193
  for i in empty_data_indices:
194
  similarity.insert(i, "")
195
 
196
- # Embed base64 encoded image data
197
- decoded_image_data = []
198
- if image_data_base64:
199
- # If image_data_base64 is a string, convert to list of strings
200
- if isinstance(image_data_base64, str):
201
- image_data_base64 = [image_data_base64]
202
-
203
- # If image_data_base64 is a tuple of strings, convert to list of strings
204
- if isinstance(image_data_base64, tuple):
205
- image_data_base64 = list(image_data_base64)
206
-
207
- # If image_data_base64 is not a list of strings, raise error
208
- if not isinstance(image_data_base64, list):
209
- raise TypeError("image_data_base64 must be a string or a tuple of strings.")
210
-
211
- # Keep track of indices of empty image strings
212
- empty_data_indices = [i for i, img in enumerate(image_data_base64) if img == ""]
213
-
214
- # Remove empty image strings
215
- image_data_base64 = [img for img in image_data_base64 if img != ""]
216
-
217
- if image_data_base64:
218
- # Decode base64 encoded image data
219
- decoded_image_data = [Image.open(BytesIO(base64.b64decode(img))) for img in image_data_base64]
220
-
221
- # Insert empty strings at indices of empty image strings
222
- for i in empty_data_indices:
223
- decoded_image_data.insert(i, None)
224
-
225
- image_data_base64_embeddings = generate_image_embedding(tuple(decoded_image_data))
226
-
227
- return (text_embeddings, image_embeddings, similarity, image_data_base64_embeddings)
228
 
229
 
230
  # Define Gradio interface
@@ -233,13 +197,12 @@ demo = gr.Interface(
233
  inputs=[
234
  gr.Textbox(lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed"),
235
  gr.Image(height=512, type="pil", label="Image to Embed"),
236
- gr.Textbox(lines=5, max_lines=5, label="Base64 Encoded Image", autoscroll=False)
237
  ],
238
  outputs=[
239
  gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
240
  gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
241
  gr.Textbox(label="Cosine Similarity"),
242
- gr.Textbox(lines=5, max_lines=5, label="Embedding of Base64 Encoded Images", autoscroll=False)
243
  ],
244
  title="OpenCLIP Embedding Generator",
245
  description="Generate embeddings using OpenCLIP model for text and images.",
 
1
+ from typing import Union
2
  import gradio as gr
3
  from numpy import empty
4
  import open_clip
5
  import torch
6
  import PIL.Image as Image
 
 
7
 
8
  # Set device to GPU if available
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
11
  # Load the OpenCLIP model and the necessary preprocessors
12
  # openclip_model = 'laion/CLIP-ViT-B-32-laion2B-s34B-b79K'
13
  # openclip_model = 'laion/CLIP-ViT-B-16-laion2B-s34B-b88K'
14
+ openclip_model_name = 'laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
15
+ openclip_model = "hf-hub:" + openclip_model_name
16
  model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(
17
  model_name=openclip_model,
18
  device=device
 
20
 
21
 
22
  # Define function to generate text embeddings
23
+ def generate_text_embedding(text_data: Union[str, tuple[str]]) -> list[str]:
24
  """
25
  Generate embeddings for text data using the OpenCLIP model.
26
 
 
75
  return text_embeddings
76
 
77
  # Define function to generate image embeddings
78
+ def generate_image_embedding(image_data: Union[Image.Image, tuple[Image.Image]]) -> list[str]:
79
  """
80
  Generate embeddings for image data using the OpenCLIP model.
81
 
 
128
 
129
 
130
  # Define function to generate embeddings
131
+ def generate_embedding(text_data: Union[str, tuple[str]], image_data: Union[Image.Image, tuple[Image.Image]]) -> tuple[list[str], list[str], list[str]]:
132
  """
133
  Generate embeddings for text and image data using the OpenCLIP model.
134
 
 
138
  Text data to embed.
139
  image_data : PIL.Image.Image or tuple of PIL.Image.Image
140
  Image data to embed.
 
 
141
 
142
  Returns
143
  -------
 
147
  List of image embeddings.
148
  similarity : list of str
149
  List of cosine similarity between text and image embeddings.
 
 
150
  """
151
 
152
  # Embed text data
 
188
  for i in empty_data_indices:
189
  similarity.insert(i, "")
190
 
191
+ return (text_embeddings, image_embeddings, similarity, openclip_model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
 
194
  # Define Gradio interface
 
197
  inputs=[
198
  gr.Textbox(lines=5, max_lines=5, placeholder="Enter Text Here...", label="Text to Embed"),
199
  gr.Image(height=512, type="pil", label="Image to Embed"),
 
200
  ],
201
  outputs=[
202
  gr.Textbox(lines=5, max_lines=5, label="Text Embedding", autoscroll=False),
203
  gr.Textbox(lines=5, max_lines=5, label="Image Embedding", autoscroll=False),
204
  gr.Textbox(label="Cosine Similarity"),
205
+ gr.Textbox(label="Embedding Model"),
206
  ],
207
  title="OpenCLIP Embedding Generator",
208
  description="Generate embeddings using OpenCLIP model for text and images.",