Someshfengde committed
Commit 2063d73 • 1 Parent(s): ad36e03

Upload folder using huggingface_hub

.github/workflows/update_space.yml ADDED
@@ -0,0 +1,28 @@
+name: Run Python script
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.9'
+
+      - name: Install Gradio
+        run: python -m pip install gradio
+
+      - name: Log in to Hugging Face
+        run: python -c 'import huggingface_hub; huggingface_hub.login(token="${{ secrets.hf_token }}")'
+
+      - name: Deploy to Spaces
+        run: gradio deploy
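
The deploy job authenticates with the `hf_token` repository secret and then runs `gradio deploy`. A minimal local sketch of those two final steps, assuming a write-scoped token is exported in an `HF_TOKEN` environment variable (name chosen for illustration):

```python
# Local equivalent of the workflow's login + deploy steps (sketch; assumes
# HF_TOKEN holds a write-scoped Hugging Face token and gradio is installed).
import os
import subprocess

import huggingface_hub

huggingface_hub.login(token=os.environ["HF_TOKEN"])  # same call the workflow runs
subprocess.run(["gradio", "deploy"], check=True)      # same CLI command as the final step
```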
README.md CHANGED
@@ -1,12 +1,6 @@
 ---
-title: Visualized BGE Demo
-emoji: 📉
-colorFrom: yellow
-colorTo: red
+title: Visualized_BGE_demo
+app_file: app.py
 sdk: gradio
 sdk_version: 4.44.0
-app_file: app.py
-pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

Visualized_base_en_v1.5.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:07e58cf70ee6962530490ef1ac5b632e7e0153ba8c7ed49d55e0f41ec97bf6a6
+size 392860018
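
The checkpoint is committed as a Git LFS pointer (about 393 MB). A hedged alternative sketch: fetch the weights at startup with `huggingface_hub` instead of tracking them in the repo, assuming they are published as `Visualized_base_en_v1.5.pth` in the `BAAI/bge-visualized` model repository:

```python
# Sketch: download the Visualized-BGE checkpoint instead of storing it in LFS.
# The repo_id/filename below are assumptions about where the weights are hosted.
from huggingface_hub import hf_hub_download

model_weight_path = hf_hub_download(
    repo_id="BAAI/bge-visualized",
    filename="Visualized_base_en_v1.5.pth",
)
print(model_weight_path)  # pass this path to Visualized_BGE(model_weight=...)
```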
app.py ADDED
@@ -0,0 +1,221 @@
+import os
+import gradio as gr
+import torch
+from FlagEmbedding.visual.modeling import Visualized_BGE
+from torchvision import transforms
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from pdf2image import convert_from_path
+import numpy as np
+import torch.nn.functional as F
+import io
+
+# Initialize the Visualized-BGE model
+def load_bge_model(model_name: str, model_weight_path: str):
+    model = Visualized_BGE(model_name_bge=model_name, model_weight=model_weight_path)
+    model.eval()
+    return model
+
+# Load the BGE model (ensure you have downloaded the weights and provide the correct path)
+model_name = "BAAI/bge-base-en-v1.5"  # or "BAAI/bge-m3" for multilingual
+model_weight_path = "./Visualized_base_en_v1.5.pth"
+model = load_bge_model(model_name, model_weight_path)
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+model.to(device)
+
+# Function to encode images
+import tempfile
+import os
+
+def encode_image(image_input):
+    """
+    Encodes an image for retrieval.
+
+    Args:
+        image_input: Can be a file path (str), a NumPy array, or a PIL Image.
+
+    Returns:
+        torch.Tensor: The image embedding.
+    """
+    delete_temp_file = False
+    if isinstance(image_input, str):
+        image_path = image_input
+    else:
+        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
+            if isinstance(image_input, np.ndarray):
+                image = Image.fromarray(image_input)
+            elif isinstance(image_input, Image.Image):
+                image = image_input
+            else:
+                raise ValueError("Unsupported image input type for image encoding.")
+
+            image.save(tmp_file.name)
+            image_path = tmp_file.name
+            delete_temp_file = True  # Mark that we need to delete this temp file
+
+    try:
+        with torch.no_grad():
+            embed = model.encode(image=image_path)
+            embed = embed.squeeze(0)
+    finally:
+        if delete_temp_file:
+            # Remove the temporary file
+            os.remove(image_path)
+
+    return embed.cpu()
+
+
+# Function to encode text
+def encode_text(text):
+    with torch.no_grad():
+        embed = model.encode(text=text)  # Assuming encode returns [1, D]
+        embed = embed.squeeze(0)  # Remove the batch dimension if present
+    return embed.cpu()
+
+# Function to index uploaded files (PDFs or images)
+def index_files(files, embeddings_state, metadata_state):
+    print("Indexing files...")
+    embeddings = []
+    metadata = []
+
+    for file in files:
+        if file.name.lower().endswith('.pdf'):
+            images = convert_from_path(file.name, thread_count=4)
+            for idx, img in enumerate(images):
+                img_path = f"{file.name}_page_{idx}.png"
+                img.save(img_path)
+                embed = encode_image(img_path)
+                print(f"Embedding shape after encoding image: {embed.shape}")  # Should be [768]
+                embeddings.append(embed)
+                metadata.append({"type": "image", "path": img_path, "info": f"Page {idx}"})
+        elif file.name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
+            img_path = file.name
+            embed = encode_image(img_path)
+            print(f"Embedding shape after encoding image: {embed.shape}")  # Should be [768]
+            embeddings.append(embed)
+            metadata.append({"type": "image", "path": img_path, "info": "Uploaded Image"})
+        else:
+            raise gr.Error("Unsupported file type. Please upload PDFs or image files.")
+
+    embeddings = torch.stack(embeddings).to(device)  # Should result in shape [N, 768]
+    print(f"Stacked embeddings shape: {embeddings.shape}")
+    embeddings_state = embeddings
+    metadata_state = metadata
+    return f"Indexed {len(embeddings)} items.", embeddings_state, metadata_state
+
+def search(query_text, query_image, k, embeddings_state, metadata_state):
+    embeddings = embeddings_state
+    metadata = metadata_state
+
+    if embeddings is None or embeddings.size(0) == 0:
+        raise gr.Error("No embeddings indexed. Please upload and index files first.")
+
+    query_emb = None
+
+    if query_text and query_image is not None:
+        gr.Warning("Please provide either a text query or an image query, not both. Using text query by default.")
+        # text_emb = encode_text(query_text)  # [D]
+        # image_emb = encode_image(query_image)  # [D]
+        # query_emb = (text_emb + image_emb) / 2  # [D]
+        # print("Combined text and image embeddings for query.")
+        query_emb = encode_text(query_text)  # [D]
+    if query_text:
+        query_emb = encode_text(query_text)  # [D]
+        print("Encoded text query.")
+    elif query_image is not None:
+        print(query_image)
+        query_emb = encode_image(query_image)  # [D]
+        print("Encoded image query.")
+    else:
+        raise gr.Error("Please provide at least a text query or an image query.")
+
+    # Ensure query_emb has shape [1, D]
+    if query_emb.dim() == 1:
+        query_emb = query_emb.unsqueeze(0)  # [1, D]
+
+    # Normalize embeddings for cosine similarity
+    query_emb = F.normalize(query_emb.to(device), p=2, dim=1)  # [1, D]
+    indexed_emb = F.normalize(embeddings.to(device), p=2, dim=1)  # [N, D]
+
+    print(f"Query embedding shape: {query_emb.shape}")  # Should be [1, 768]
+    print(f"Indexed embeddings shape: {indexed_emb.shape}")  # Should be [N, 768]
+
+    # Compute cosine similarities
+    similarities = torch.matmul(query_emb, indexed_emb.T).squeeze(0)  # [N]
+    print(f"Similarities shape: {similarities.shape}")
+
+    # Get top-k results
+    topk = torch.topk(similarities, k)
+    topk_indices = topk.indices.cpu().numpy()
+    topk_scores = topk.values.cpu().numpy()
+
+    print(f"Top-{k} indices: {topk_indices}")
+    print(f"Top-{k} scores: {topk_scores}")
+
+    results = []
+    for idx, score in zip(topk_indices, topk_scores):
+        item = metadata[idx]
+        if item["type"] == "image":
+            # Load image from path
+            img = Image.open(item["path"]).convert("RGB")
+            results.append((img, f"Score: {score:.4f} | {item['info']}"))
+        else:
+            # Handle text data if applicable
+            results.append((item["data"], f"Score: {score:.4f} | {item['info']}"))
+
+    return results
+
+# Gradio Interface
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# Visualized-BGE: Multimodal Retrieval Demo 🎉")
+    gr.Markdown("""
+    Upload PDF or image files to index them. Then, perform searches using text, images, or both to retrieve the most relevant items.
+
+    **Note:** Ensure that you have indexed the files before performing a search.
+    """)
+
+    # Initialize state variables
+    embeddings_state = gr.State(None)
+    metadata_state = gr.State(None)
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            gr.Markdown("## 1️⃣ Upload and Index Files")
+            file_input = gr.File(file_types=["pdf", "png", "jpg", "jpeg", "bmp", "gif"], file_count="multiple", label="Upload Files")
+            index_button = gr.Button("🔄 Index Files")
+            index_status = gr.Textbox("No files indexed yet.", label="Indexing Status")
+
+        with gr.Column(scale=3):
+            gr.Markdown("## 2️⃣ Perform Search")
+            with gr.Row():
+                query_text = gr.Textbox(placeholder="Enter your text query here...", label="Text Query")
+                query_image = gr.Image(label="Image Query (Optional)")
+            k = gr.Slider(minimum=1, maximum=20, step=1, label="Number of Results", value=5)
+            search_button = gr.Button("🔍 Search")
+            output_gallery = gr.Gallery(label="Retrieved Results", show_label=True, columns=2)
+
+    # Define button actions
+    index_button.click(
+        index_files,
+        inputs=[file_input, embeddings_state, metadata_state],
+        outputs=[index_status, embeddings_state, metadata_state]
+    )
+    search_button.click(
+        search,
+        inputs=[query_text, query_image, k, embeddings_state, metadata_state],
+        outputs=output_gallery
+    )
+
+    gr.Markdown("""
+    ---
+    ## About
+    This demo uses the **Visualized-BGE** model for efficient multimodal retrieval tasks. Upload your documents or images, index them, and perform searches using text, images, or a combination of both.
+
+    **References:**
+    - [Visualized-BGE Paper](https://arxiv.org/abs/2406.04292)
+    - [FlagEmbedding GitHub](https://github.com/FlagOpen/FlagEmbedding)
+    """)
+
+if __name__ == "__main__":
+    demo.launch(debug=True, share=True)
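
For reference, the retrieval core that `app.py` wires into Gradio reduces to encoding text and images with `Visualized_BGE` and ranking by cosine similarity. A condensed standalone sketch of that scoring path (the image path is a placeholder; the weight file is the one added in this commit):

```python
# Minimal sketch of the scoring path used by search() in app.py above.
import torch
import torch.nn.functional as F
from FlagEmbedding.visual.modeling import Visualized_BGE

model = Visualized_BGE(model_name_bge="BAAI/bge-base-en-v1.5",
                       model_weight="./Visualized_base_en_v1.5.pth")
model.eval()

with torch.no_grad():
    query_emb = model.encode(text="diagram of a transformer block")  # [1, D]
    page_emb = model.encode(image="example_page.png")                # [1, D], placeholder image path

score = F.cosine_similarity(query_emb, page_emb).item()
print(f"cosine similarity: {score:.4f}")
```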
packages.txt ADDED
@@ -0,0 +1 @@
+poppler-utils
requirements.txt ADDED
@@ -0,0 +1,3 @@
+pdf2image
+gradio
+FlagEmbedding