import json
import os

import numpy as np
import torch
from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm
from transformers import AutoProcessor, AutoModel

# --- Configuration -----------------------------------------------------------
# Checkpoint id passed to the Hugging Face `from_pretrained` loaders.
MODEL_ID = "EYEDOL/siglipFULL-agri-finetuned"
# Fifteen dataset ids, suffixed 1 through 15 (range end is exclusive).
DATASET_NAMES = [f"EYEDOL/AGRILLAVA-image-text{i}" for i in range(1, 16)]
# Number of texts encoded per forward pass.
BATCH_SIZE = 16
# Output folder and the two files written into it.
OUT_DIR = "faiss_free_data"
EMBEDS_FILE = os.path.join(OUT_DIR, "text_embeds.npy")
TEXTS_JSONL = os.path.join(OUT_DIR, "texts.jsonl")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Create the output folder up front; exist_ok=True makes reruns harmless.
os.makedirs(OUT_DIR, exist_ok=True)

# --- Load datasets -----------------------------------------------------------
# Take the "train" split of every dataset id and stack them into one dataset.
print("Loading datasets...")
all_splits = [load_dataset(name)["train"] for name in DATASET_NAMES]
dataset = concatenate_datasets(all_splits)
# Copy the "text" column into a plain Python list so it can be sliced and
# iterated with ordinary list semantics.
texts = list(dataset["text"])
print(f"Got {len(texts)} texts.")

# --- Load model & processor --------------------------------------------------
# Both come from the same checkpoint id; the model is moved to DEVICE and put
# into eval mode before any inference is run.
print("Loading model & processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()

# --- Encode texts ------------------------------------------------------------
# Produces `all_embeds`: one L2-normalised float32 row per entry of `texts`,
# in the same order.
if not texts:
    # np.concatenate() on an empty list raises an opaque "need at least one
    # array" ValueError; fail with a message that names the real cause.
    raise ValueError("No texts to encode: the loaded datasets are empty.")

embed_chunks = []
for start in tqdm(range(0, len(texts), BATCH_SIZE), desc="Encoding texts"):
    batch = texts[start:start + BATCH_SIZE]
    # NOTE(review): the SigLIP documentation says the model was trained with
    # padding="max_length". padding=True only pads to the longest text in the
    # current batch, so the same text would be padded differently depending on
    # its batch neighbours. Whatever encodes queries against these embeddings
    # should use the same padding setting -- confirm against the consumer.
    inputs = processor(
        text=batch,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(DEVICE)
    with torch.no_grad():
        embeds = model.get_text_features(**inputs)
        # normalize() clamps the norm at a small epsilon, so an all-zero
        # vector stays zero instead of becoming NaN from a 0/0 division.
        embeds = torch.nn.functional.normalize(embeds, p=2, dim=-1)
    # Cast on the tensor side so half/bfloat16 outputs also convert cleanly
    # (NumPy has no bfloat16 dtype).
    embed_chunks.append(embeds.to(torch.float32).cpu().numpy())
    # No per-batch torch.cuda.empty_cache(): the caching allocator reuses the
    # blocks freed when `inputs`/`embeds` are rebound on the next iteration,
    # and emptying the cache every step only adds overhead.

all_embeds = np.concatenate(embed_chunks, axis=0)
print("Embeddings shape:", all_embeds.shape)

# --- Save outputs ------------------------------------------------------------
# Embedding matrix in NumPy's .npy format.
np.save(EMBEDS_FILE, all_embeds)
print(f"Saved embeddings to {EMBEDS_FILE}")

# Texts as JSON Lines: one {"text": ...} object per line, written in the order
# of `texts`. ensure_ascii=False keeps non-ASCII characters readable in the
# UTF-8 file instead of \u-escaping them.
with open(TEXTS_JSONL, "w", encoding="utf-8") as f:
    f.writelines(
        json.dumps({"text": t}, ensure_ascii=False) + "\n" for t in texts
    )
print(f"Saved texts to {TEXTS_JSONL}")

# Report the folder through OUT_DIR so the message cannot drift from the
# configured output location.
print(f"Done. Upload the folder '{OUT_DIR}' to your Space repository (git lfs or upload_file).")