dejanseo committed on
Commit
84116b7
·
verified ·
1 Parent(s): 5379c4e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +344 -0
app.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import tempfile
5
+ import os
6
+ from dejan.veczip import veczip
7
+ import csv
8
+ import ast
9
+ from huggingface_hub import hf_hub_download, HfApi
10
+ from transformers import AutoTokenizer, AutoModel
11
+ import torch
12
+
13
+ # Helper function definitions: is_numeric, parse_as_array, get_line_pattern, detect_header, looks_like_id_column, detect_columns, load_and_validate_embeddings, save_compressed_embeddings, run_veczip
14
+ # -----------------
15
def is_numeric(s):
    """Return True if ``s`` can be converted with ``float()``, else False.

    Args:
        s: Any value; typically a string cell read from a CSV file.

    Returns:
        bool: True when ``float(s)`` succeeds, False otherwise.
    """
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type (e.g. None); ValueError: non-numeric
        # text; OverflowError: an int too large to be represented as a float.
        return False
    return True
22
+
23
def parse_as_array(val):
    """Parse a value as a list of numbers.

    Accepted forms:
      * an ``int`` or ``float``: returned as the single-element list ``[val]``;
      * text that starts with "[" and ends with "]": evaluated with
        ``ast.literal_eval`` and returned unchanged when the result is a list
        whose elements all satisfy ``is_numeric(str(x))``;
      * comma-separated text with more than one part, all numeric: returned
        as a list of floats.

    Args:
        val: The value to parse (any type; non-numbers are converted with
            ``str()`` and stripped first).

    Returns:
        list | None: The parsed list, or None when ``val`` matches none of
        the accepted forms.
    """
    if isinstance(val, (int, float)):
        return [val]

    val_str = str(val).strip()
    if val_str.startswith("[") and val_str.endswith("]"):
        try:
            arr = ast.literal_eval(val_str)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # These are the exceptions literal_eval raises on malformed input.
            return None
        if isinstance(arr, list) and all(is_numeric(str(x)) for x in arr):
            return arr
        return None

    parts = [p.strip() for p in val_str.split(",")]
    if len(parts) > 1 and all(is_numeric(p) for p in parts):
        return [float(p) for p in parts]
    return None
40
+
41
def get_line_pattern(row):
    """Classify every cell of a row as 'arr', 'num' or 'text'.

    A cell is 'arr' when ``parse_as_array`` accepts it, otherwise 'num' when
    ``is_numeric`` accepts it, otherwise 'text'.

    Args:
        row: Iterable of cell values.

    Returns:
        list[str]: One label per cell, in row order.
    """
    def classify(cell):
        if parse_as_array(cell) is not None:
            return 'arr'
        return 'num' if is_numeric(cell) else 'text'

    return [classify(cell) for cell in row]
54
+
55
def detect_header(lines):
    """Guess whether the first of the given CSV rows is a header row.

    The first row counts as a header when its cell pattern (see
    ``get_line_pattern``) differs from the pattern of the second row and all
    remaining rows share that second-row pattern.

    Args:
        lines: Sequence of parsed CSV rows (each a sequence of cells).

    Returns:
        bool: True if a header is detected; False otherwise, including when
        fewer than two rows are given.
    """
    if len(lines) < 2:
        return False

    head_pattern, *body_patterns = (get_line_pattern(r) for r in lines)
    reference = body_patterns[0]
    if head_pattern == reference:
        return False
    # With a single body row this is trivially True, matching the
    # "only one data row" case.
    return all(p == reference for p in body_patterns)
68
+
69
def looks_like_id_column(col_values):
    """Check whether a column looks like an ID column (sequential integers).

    A column qualifies when every value converts to a whole number and the
    values increase by exactly 1 from the first value onward.

    Args:
        col_values: Iterable of cell values (strings or numbers).

    Returns:
        bool: True for a consecutive integer sequence; False for an empty
        column, non-numeric values, or values with a fractional part.
    """
    try:
        as_floats = [float(v) for v in col_values]
    except (TypeError, ValueError, OverflowError):
        return False

    # Require whole numbers: truncating with int() would make e.g.
    # 0.1, 1.2, 2.3 look like the ID sequence 0, 1, 2. is_integer() is also
    # False for nan/inf, which int() could not convert anyway.
    if not as_floats or not all(f.is_integer() for f in as_floats):
        return False

    nums = [int(f) for f in as_floats]
    return nums == list(range(nums[0], nums[0] + len(nums)))
76
+
77
def detect_columns(file_path):
    """Detect embedding and metadata columns in a CSV file.

    Only the first 10 rows are inspected. The delimiter is sniffed from the
    first 10 KiB of the file (candidates: comma, tab, semicolon, pipe) and
    falls back to a comma when sniffing fails.

    Column classification, based on the sampled data rows:
      * every cell parses as an array of the same length -> embedding column;
      * every cell parses as an array but lengths differ -> metadata;
      * every cell is numeric and the column is a consecutive integer
        sequence -> metadata (ID-like);
      * every cell is numeric otherwise -> "numeric candidate": treated as
        embedding only when there are no array columns and more than one
        numeric candidate exists, else as metadata;
      * anything else -> metadata.
    The first header cell equal to "id" (case-insensitive) is always
    metadata.

    Args:
        file_path: Path of the CSV file (read as UTF-8).

    Returns:
        tuple: ``(has_header, emb_cols, meta_cols, delimiter)``. Columns are
        reported by header name when a header was detected and covers that
        position, otherwise by integer position, in ascending position order.

    Raises:
        ValueError: If the file contains no rows.
    """
    from itertools import islice

    with open(file_path, "r", newline="", encoding="utf-8") as f:
        sample = f.read(1024 * 10)  # Read a larger sample for sniffing
        try:
            dialect = csv.Sniffer().sniff(sample, delimiters=[',', '\t', ';', '|'])
            delimiter = dialect.delimiter
        except csv.Error:
            delimiter = ','
        f.seek(0)  # reset file pointer
        reader = csv.reader(f, delimiter=delimiter)
        # Only 10 rows are needed; do not parse the whole file.
        first_lines = list(islice(reader, 10))

    if not first_lines:
        raise ValueError("No data")

    has_header = detect_header(first_lines)
    if has_header:
        header = first_lines[0]
        data = first_lines[1:]
    else:
        header = []
        data = first_lines

    if not data:
        return has_header, [], [], delimiter

    candidate_arrays = []
    candidate_numeric = []
    id_like_columns = set()
    text_like_columns = set()

    # zip(*data) transposes rows into columns (truncated to the shortest row).
    for ci, col in enumerate(zip(*data)):
        col = list(col)
        parsed_rows = [parse_as_array(val) for val in col]

        if all(r is not None for r in parsed_rows):
            if len({len(r) for r in parsed_rows}) == 1:
                candidate_arrays.append(ci)
            else:
                # Arrays of inconsistent length cannot be one embedding space.
                text_like_columns.add(ci)
            continue

        if not all(is_numeric(v) for v in col):
            text_like_columns.add(ci)
        elif looks_like_id_column(col):
            id_like_columns.add(ci)
        else:
            candidate_numeric.append(ci)

    identified_embedding_columns = set(candidate_arrays)
    identified_metadata_columns = set()

    if not candidate_arrays and len(candidate_numeric) > 1:
        # No array columns: several plain numeric columns are taken to be
        # one embedding dimension per column.
        identified_embedding_columns.update(candidate_numeric)
    else:
        identified_metadata_columns.update(candidate_numeric)

    identified_metadata_columns.update(id_like_columns)
    identified_metadata_columns.update(text_like_columns)

    if header:
        for ci, col_name in enumerate(header):
            if col_name.lower() == 'id':
                identified_embedding_columns.discard(ci)
                identified_metadata_columns.add(ci)
                break

    def column_label(index):
        return header[index] if header and index < len(header) else index

    # Sort so the result does not depend on set iteration order.
    emb_cols = [column_label(i) for i in sorted(identified_embedding_columns)]
    meta_cols = [column_label(i) for i in sorted(identified_metadata_columns)]

    return has_header, emb_cols, meta_cols, delimiter
160
+
161
def load_and_validate_embeddings(input_file, target_dims):
    """Load a CSV, drop rows with unparsable embeddings, and print a summary.

    Columns are classified with ``detect_columns``; the file is then read
    with pandas and only rows whose every embedding column parses via
    ``parse_as_array`` are kept.

    Args:
        input_file: Path of the CSV file to load.
        target_dims: Not used by this function.

    Returns:
        tuple: ``(data, embedding_columns, metadata_columns, has_header,
        original_columns)`` where ``data`` is the filtered DataFrame and
        ``original_columns`` is the list of its column labels.
    """
    print(f"Loading data from {input_file}...")
    has_header, embedding_columns, metadata_columns, delimiter = detect_columns(input_file)

    header_row = 0 if has_header else None
    frame = pd.read_csv(input_file, header=header_row, delimiter=delimiter)

    # Keep a row only if each of its embedding cells parses as an array.
    keep_mask = frame.apply(
        lambda record: all(
            parse_as_array(record[name]) is not None for name in embedding_columns
        ),
        axis=1,
    )
    frame = frame[keep_mask]

    print("\n=== File Summary ===")
    print(f"File: {input_file}")
    print(f"Rows: {len(frame)}")
    print(f"Metadata Columns: {metadata_columns}")
    print(f"Embedding Columns: {embedding_columns}")
    print("====================\n")

    return frame, embedding_columns, metadata_columns, has_header, list(frame.columns)
185
+
186
+
187
def save_compressed_embeddings(output_file, metadata, compressed_embeddings, embedding_columns, original_columns, has_header):
    """Write metadata plus compressed embeddings to a CSV file.

    Args:
        output_file: Destination CSV path.
        metadata: DataFrame of metadata columns; it is copied, not modified.
        compressed_embeddings: Sequence of 2-D arrays, one per entry of
            ``embedding_columns`` and in the same order; each array row is
            stored as a list in that column.
        embedding_columns: Labels of the embedding columns to add.
        original_columns: Column order for the output; when empty, the
            current column order is kept.
        has_header: Whether a header row is written.
    """
    print(f"Saving compressed data to {output_file}...")
    table = metadata.copy()

    for position, column in enumerate(embedding_columns):
        block = compressed_embeddings[position]
        n_rows = block.shape[0]
        table[column] = [block[row].tolist() for row in range(n_rows)]

    if original_columns:
        # Restore the column order of the input file.
        table = table.reindex(columns=original_columns)

    table.to_csv(output_file, index=False, header=bool(has_header))
    print(f"Data saved to {output_file}.")
200
+
201
def run_veczip(input_file, target_dims=16):
    """Run veczip compression on the embeddings found in a CSV file.

    The embeddings of all detected embedding columns are stacked row-wise
    and passed to ``veczip(...).compress``; the indices it returns are then
    used to select the same dimensions from every embedding column.

    Args:
        input_file: Path of the input CSV file.
        target_dims: Value passed to ``veczip`` as ``target_dims``.

    Returns:
        str: Path of a temporary CSV file holding the compressed data. The
        caller is responsible for deleting it.

    Raises:
        ValueError: If no embedding columns are detected or no row has
            parsable embeddings.
    """
    data, embedding_columns, metadata_columns, has_header, original_columns = load_and_validate_embeddings(input_file, target_dims)

    # Fail with a clear message instead of an obscure numpy error below.
    if not embedding_columns:
        raise ValueError("No embedding columns were detected in the input file.")
    if data.empty:
        raise ValueError("No rows with valid embeddings were found in the input file.")

    all_embeddings = [
        np.array([parse_as_array(x) for x in data[col].values])
        for col in embedding_columns
    ]

    combined_embeddings = np.concatenate(all_embeddings, axis=0)
    compressor = veczip(target_dims=target_dims)
    retained_indices = compressor.compress(combined_embeddings)

    compressed_embeddings = [
        embeddings[:, retained_indices] for embeddings in all_embeddings
    ]

    # Only the path is needed: close the handle right away so it is not
    # leaked (and so the file can be reopened for writing on every platform).
    with tempfile.NamedTemporaryFile(suffix='.csv', delete=False) as temp_output:
        output_path = temp_output.name

    try:
        save_compressed_embeddings(output_path, data[metadata_columns], compressed_embeddings, embedding_columns, original_columns, has_header)
    except Exception:
        # Do not leave an orphaned temp file behind when saving fails.
        os.unlink(output_path)
        raise
    return output_path
222
+ # -----------------
223
+
224
+ # Embedding Generation Function
225
@st.cache_resource
def load_embedding_model(model_name="mixedbread-ai/mxbai-embed-large-v1"):
    """Load the tokenizer and model for ``model_name``.

    The result is cached with ``st.cache_resource``.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModel.from_pretrained(model_name),
    )
231
+
232
@st.cache_data
def _cached_mean_pooled_embeddings(text_list, model_key, _tokenizer, _model):
    """Compute mean-pooled embeddings; cached on ``text_list`` and ``model_key``.

    ``_tokenizer`` and ``_model`` carry a leading underscore so that
    ``st.cache_data`` does not try to hash them.
    """
    encoded_input = _tokenizer(
        text_list, padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():
        model_output = _model(**encoded_input)
    hidden = model_output.last_hidden_state

    attention_mask = encoded_input.get("attention_mask")
    if attention_mask is None:
        # No mask available: fall back to a plain mean over all positions.
        pooled = hidden.mean(dim=1)
    else:
        # Average over real tokens only; padding positions would otherwise
        # make a text's embedding depend on the other texts in the batch.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        pooled = (hidden * mask).sum(dim=1) / token_counts
    return pooled.cpu().numpy()


def generate_embeddings(text_list, tokenizer, model):
    """Generate embeddings for a list of text entries.

    Each text is tokenized (with padding and truncation), run through the
    model without gradient tracking, and its last hidden states are averaged
    over the non-padding tokens. Results are cached per text list and model.

    Args:
        text_list: List of strings to embed.
        tokenizer: Tokenizer callable returning PyTorch tensors.
        model: Model whose output exposes ``last_hidden_state``.

    Returns:
        numpy.ndarray: One embedding row per input text.
    """
    # The model object itself is excluded from the cache key, so its
    # ``name_or_path`` (when it has one) keeps results of different models
    # apart.
    model_key = getattr(model, "name_or_path", None)
    return _cached_mean_pooled_embeddings(list(text_list), model_key, tokenizer, model)
242
+
243
+
244
+ # Streamlit App
245
def main():
    """Render the Streamlit app.

    The page has two tabs:
      * "Compress Your Embeddings": takes an uploaded CSV, runs
        ``run_veczip`` on it and offers the result as a download;
      * "Generate & Compress Embeddings": embeds the entered lines of text,
        compresses the embeddings to 16 dimensions and offers them as a CSV
        download.
    Errors raised while processing are shown with ``st.error``.
    """
    st.title("Veczip Embeddings Tool")

    st.markdown(
        """
        This tool offers two ways to compress your embeddings:

        1. **Compress Your Embeddings:** Upload a CSV file containing pre-existing embeddings and reduce their dimensionality using `dejan.veczip`.
        2. **Generate & Compress Embeddings:** Provide a list of text entries, and this tool will generate embeddings using `mxbai-embed-large-v1` and then compress them.
        """
    )
    st.markdown(
        """
        **General Usage Guide**

        * Both tools work best with larger datasets (hundreds or thousands of entries).
        * For CSV files with embeddings, ensure that numeric embedding columns are parsed as arrays (e.g. '[1,2,3]' or '1,2,3') and metadata columns are parsed as text or numbers.
        * Output files are compressed to 16 dimensions.
        """
    )

    tab1, tab2 = st.tabs(["Compress Your Embeddings", "Generate & Compress Embeddings"])

    with tab1:
        st.header("Compress Your Embeddings")
        st.markdown(
            """
            Upload a CSV file containing pre-existing embeddings.
            This will reduce the dimensionality of the embeddings to 16 dimensions using `dejan.veczip`.
            """
        )
        uploaded_file = st.file_uploader(
            "Upload CSV file with embeddings", type=["csv"],
            help="Ensure the CSV file has columns where embedding arrays are represented as text. Examples: '[1,2,3]' or '1,2,3'",
        )
        if uploaded_file:
            temp_path = None
            output_file_path = None
            try:
                with st.spinner("Analyzing and compressing embeddings..."):
                    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                        temp_path = temp_file.name
                        # getvalue() returns the full content regardless of
                        # the current read position of the upload.
                        temp_file.write(uploaded_file.getvalue())
                    output_file_path = run_veczip(temp_path)
                    # Read the result into memory so the temp file can be
                    # removed before the button is rendered.
                    with open(output_file_path, 'rb') as f:
                        compressed_bytes = f.read()
                st.download_button(
                    label="Download Compressed CSV",
                    data=compressed_bytes,
                    file_name="compressed_embeddings.csv",
                    mime="text/csv"
                )
                st.success("Compression complete! Download your compressed file below.")
            except Exception as e:
                st.error(f"Error processing file: {e}")
            finally:
                # Always clean up, including when processing failed midway.
                for path in (temp_path, output_file_path):
                    if path and os.path.exists(path):
                        os.unlink(path)

    with tab2:
        st.header("Generate & Compress Embeddings")
        st.markdown(
            """
            Provide a list of text entries (one per line), and this tool will:
            1. Generate embeddings using `mixedbread-ai/mxbai-embed-large-v1`.
            2. Compress those embeddings to 16 dimensions using `dejan.veczip`.
            """
        )
        text_input = st.text_area(
            "Enter text entries (one per line)",
            help="Enter each text entry on a new line. This tool works best with a large sample size.",
        )

        if text_input:
            # Drop blank lines: str.split never returns an empty list, so an
            # emptiness check on its result could never trigger the warning.
            text_list = [
                line.strip() for line in text_input.splitlines() if line.strip()
            ]
            if not text_list:
                st.warning("Please enter some text for embedding")
            else:
                try:
                    with st.spinner("Generating and compressing embeddings..."):
                        tokenizer, model = load_embedding_model()
                        embeddings = generate_embeddings(text_list, tokenizer, model)
                        compressor = veczip(target_dims=16)
                        retained_indices = compressor.compress(embeddings)
                        compressed_embeddings = embeddings[:, retained_indices]
                        df = pd.DataFrame(
                            {"text": text_list, "embeddings": compressed_embeddings.tolist()}
                        )
                    st.dataframe(df)
                    csv_file = df.to_csv(index=False).encode()
                    st.download_button(
                        label="Download Compressed Embeddings (CSV)",
                        data=csv_file,
                        file_name="generated_compressed_embeddings.csv",
                        mime="text/csv",
                    )
                    st.success("Generated and compressed! Download your file below.")

                except Exception as e:
                    st.error(f"Error: {e}")
341
+
342
+
343
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()