from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file
import faiss
import os
import datetime
import time

# from embedding_model import EmbeddingModel
# from app import EmbeddingModel
from arxiv_stuff import retrieve_arxiv_papers, ARXIV_CATEGORIES_FLAT
from sentence_transformers import SentenceTransformer

# Dataset details
dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
HF_TOKEN = os.getenv("HF_TOKEN")


class DatasetManager:

    def __init__(
        self,
        dataset_name: str,
        embedding_model: SentenceTransformer,
        hf_token: str = None,
    ):
        """
        Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.
        Args:
            dataset_name (str): The name of the dataset on Hugging Face Hub.
            embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
            hf_token (str, optional): The Hugging Face token for authentication.
        """
        self.dataset_name = dataset_name
        self.hf_token = hf_token
        self.embedding_model = embedding_model
        self.dataset = None

        self.setup_dataset()

    def get_revision_name(self):
        """Generate a timestamp-based revision name."""
        return datetime.datetime.now().strftime("v%Y-%m-%d")

    def get_latest_revision(self):
        """Return the latest timestamp-based revision."""
        api = HfApi()
        print(f"Fetching revisions for dataset: {self.dataset_name}")

        # List all tags in the repository
        refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
        tags = refs.tags

        print(f"Found tags: {[tag.name for tag in tags]}")

        # Keep only tags matching the "vYYYY-MM-DD" format
        timestamp_tags = []
        for tag in tags:
            try:
                datetime.datetime.strptime(tag.name, "v%Y-%m-%d")
                timestamp_tags.append(tag.name)
            except ValueError:
                continue

        if not timestamp_tags:
            print("No valid timestamp-based revisions found. Using `v1.0.0` as default.")
            return "v1.0.0"
        print(f"Valid timestamp-based revisions: {timestamp_tags}")

        # Sort and return the most recent tag
        latest_revision = sorted(timestamp_tags)[-1]
        print(f"Latest revision determined: {latest_revision}")
        return latest_revision

    def setup_dataset(self):
        """Load dataset with FAISS index."""
        print("Loading dataset from Hugging Face...")

        # Fetch the latest revision dynamically
        latest_revision = self.get_latest_revision()

        # Load dataset
        dataset = load_dataset(
            self.dataset_name,
            revision=latest_revision,
            token=self.hf_token,
        )

        # Try to load the index from the Hub
        try:
            print("Downloading pre-built FAISS index...")
            index_path = hf_hub_download(
                repo_id=self.dataset_name,
                filename="arxiv_faiss_index.faiss",
                revision=latest_revision,
                token=self.hf_token,
                repo_type="dataset",
            )

            print("Loading pre-built FAISS index...")
            dataset["train"].load_faiss_index("embedding", index_path)
            print("Pre-built FAISS index loaded successfully")

        except Exception as e:
            print(f"Could not load pre-built index: {e}")
            print("Building new FAISS index...")

            # The dataset must have an 'embedding' column to build a FAISS index
            if "embedding" not in dataset["train"].column_names:
                raise ValueError("Dataset doesn't have an 'embedding' column, cannot create a FAISS index")

            dataset["train"].add_faiss_index(
                column="embedding",
                metric_type=faiss.METRIC_INNER_PRODUCT,
                string_factory="HNSW,RFlat",  # Using reranking
            )

        print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")

        self.dataset = dataset
        return dataset

    def update_dataset_with_new_papers(self):
        """Fetch new papers from arXiv, ensure no duplicates, and update the dataset and FAISS index."""
        if self.dataset is None:
            self.setup_dataset()

        # Get the last update date from the dataset
        last_update_date = max(
            (datetime.datetime.strptime(d, "%Y-%m-%d") for d in self.dataset["train"]["update_date"]),
            default=datetime.datetime.now() - datetime.timedelta(days=1),
        )

        # Initialize variables for iterative querying
        start = 0
        max_results_per_query = 100
        all_new_papers = []

        while True:
            # Fetch new papers from arXiv since the last update
            new_papers = retrieve_arxiv_papers(
                categories=list(ARXIV_CATEGORIES_FLAT.keys()),
                start_date=last_update_date,
                end_date=datetime.datetime.now(),
                start=start,
                max_results=max_results_per_query,
            )

            if not new_papers:
                break

            all_new_papers.extend(new_papers)
            start += max_results_per_query

            # Respect the rate limit of 1 query every 3 seconds
            time.sleep(3)

        # Filter out duplicates
        existing_ids = set(self.dataset["train"]["id"])
        unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]

        if not unique_papers:
            print("No new papers to add.")
            return

        # Add new papers to the dataset. Dataset.add_item returns a new dataset,
        # so the result must be assigned back for the rows to be kept.
        for paper in unique_papers:
            embedding = self.embedding_model.encode(paper["abstract"]).tolist()
            self.dataset["train"] = self.dataset["train"].add_item(
                {
                    "id": paper["arxiv_id"],
                    "title": paper["title"],
                    "authors": ", ".join(paper["authors"]),
                    "categories": ", ".join(paper["categories"]),
                    "abstract": paper["abstract"],
                    "update_date": paper["published_date"],
                    "embedding": embedding,
                }
            )

        # Update the FAISS index
        self.dataset["train"].add_faiss_index(
            column="embedding",
            metric_type=faiss.METRIC_INNER_PRODUCT,
            string_factory="HNSW,RFlat",
        )

        # Save the FAISS index to the Hub
        self.save_faiss_index_to_hub()

        # Save the updated dataset to the Hub with a new revision
        new_revision = self.get_revision_name()
        self.dataset.push_to_hub(
            repo_id=self.dataset_name,
            token=self.hf_token,
            commit_message=f"Update dataset with new papers ({new_revision})",
            revision=new_revision,
        )

        print(f"Dataset updated and saved to the Hub with revision {new_revision}.")

    def save_faiss_index_to_hub(self):
        """Save the FAISS index to the Hub for easy access"""
        local_index_path = "arxiv_faiss_index.faiss"

        # 1. Save the index to a local file
        self.dataset["train"].save_faiss_index("embedding", local_index_path)
        print(f"FAISS index saved locally to {local_index_path}")

        # 2. Upload the index file to the Hub
        upload_file(
            path_or_fileobj=local_index_path,
            path_in_repo=local_index_path,  # keep the same filename on the Hub
            repo_id=self.dataset_name,  # the dataset repo
            token=self.hf_token,
            repo_type="dataset",
            revision=self.get_revision_name(),  # today's timestamp-based revision
            commit_message="Add FAISS index",
        )

        print(f"FAISS index uploaded to the Hub as {self.dataset_name}/{local_index_path}")

        # Remove the local file. It's now stored on the Hub.
        os.remove(local_index_path)
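

# Minimal usage sketch (illustrative): assumes HF_TOKEN is set in the environment.
# The embedding model name below is an assumption for illustration only; substitute
# whichever SentenceTransformer checkpoint the application actually uses.
if __name__ == "__main__":
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model, for illustration

    manager = DatasetManager(
        dataset_name=dataset_name,
        embedding_model=embedding_model,
        hf_token=HF_TOKEN,
    )

    # Fetch papers published since the last dataset update and push a new
    # revision (plus a rebuilt FAISS index) to the Hub.
    manager.update_dataset_with_new_papers()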