File size: 14,991 Bytes
cb5b71d
 
 
 
 
 
 
bbea1cc
cb5b71d
 
 
73ebcab
cb5b71d
 
 
bbea1cc
 
cb5b71d
bbea1cc
 
cb5b71d
 
bbea1cc
6a31b9a
0c5b67f
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
bbea1cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
041af8a
 
 
 
 
 
 
 
 
 
 
 
bbea1cc
 
 
 
 
 
8c11dd4
 
041af8a
 
 
cb5b71d
 
 
 
 
 
 
5a782ad
cb5b71d
dc92053
 
 
 
8c11dd4
041af8a
bbea1cc
 
 
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73ebcab
1b94fec
73ebcab
cb5b71d
73ebcab
 
 
 
 
 
 
 
 
 
 
 
cb5b71d
 
 
 
 
 
 
e92e659
cb5b71d
 
 
73ebcab
cb5b71d
 
 
 
 
 
 
 
 
73ebcab
cb5b71d
 
 
 
 
 
 
 
 
73ebcab
cb5b71d
 
edf454b
73ebcab
cb5b71d
 
 
 
 
 
 
73ebcab
cb5b71d
 
 
1b94fec
73ebcab
bc133ae
 
7fe906d
cb5b71d
bc133ae
cb5b71d
 
 
bc133ae
cb5b71d
 
 
 
 
 
 
 
73ebcab
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
73ebcab
 
cb5b71d
73ebcab
cb5b71d
 
73ebcab
 
 
 
 
 
 
 
 
cb5b71d
 
 
 
73ebcab
 
cb5b71d
73ebcab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb5b71d
 
 
 
 
73ebcab
cb5b71d
 
 
73ebcab
 
 
cb5b71d
 
 
 
73ebcab
 
cb5b71d
73ebcab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb5b71d
 
 
 
 
 
 
 
0c5b67f
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b94fec
cb5b71d
 
1b94fec
cb5b71d
1b94fec
cb5b71d
 
 
 
1b94fec
 
 
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c5b67f
 
73ebcab
 
 
 
 
 
 
 
 
 
6a31b9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
"""Streamlit session state.

In the future, this could be the serialization format between front and back.
"""

from __future__ import annotations

import base64
import dataclasses
import datetime
from typing import Any
import uuid

from etils import epath
import pandas as pd
import requests
import streamlit as st

from core.constants import OAUTH_CLIENT_ID
from core.constants import OAUTH_CLIENT_SECRET
from core.constants import PAST_PROJECTS_PATH
from core.constants import PROJECT_FOLDER_PATTERN
from core.constants import REDIRECT_URI
from core.constants import TABS
from core.names import find_unique_name
import mlcroissant as mlc


def create_class(mlc_class: type, instance: Any, **kwargs) -> Any:
    """Builds an `mlc_class` object out of `instance`.

    Every dataclass field of `mlc_class` that `instance` also exposes as an
    attribute is copied over by name. Values passed in `kwargs` take
    precedence: a field named in `kwargs` is never read from `instance`.

    Args:
        mlc_class: Target dataclass type to instantiate.
        instance: Object whose matching attributes are copied.
        **kwargs: Extra or overriding constructor arguments.

    Returns:
        The newly constructed `mlc_class` instance.
    """
    copied: dict[str, Any] = {
        spec.name: getattr(instance, spec.name)
        for spec in dataclasses.fields(mlc_class)
        if hasattr(instance, spec.name) and spec.name not in kwargs
    }
    return mlc_class(**copied, **kwargs)


@dataclasses.dataclass
class User:
    """The connected user.

    Holds the OAuth tokens and the username obtained from Hugging Face by
    `connect`.
    """

    access_token: str
    id_token: str
    username: str

    @staticmethod
    def _connection_error(last: Any) -> Exception:
        """Builds the error raised when the OAuth handshake cannot complete.

        `last` is the last thing received (an HTTP response or a decoded JSON
        payload) and is embedded in the message for debugging.
        """
        # NOTE(review): `last` can be the token payload; if only one of the two
        # tokens is present it ends up in the message — consider redacting.
        return Exception(
            f"Could not connect to Hugging Face. Please, go to {REDIRECT_URI}."
            f" (response={last!r})."
        )

    @classmethod
    def connect(cls, code: str, timeout: float | None = 30.0) -> User:
        """Exchanges an OAuth authorization `code` for a connected user.

        Posts the code to the Hugging Face token endpoint, then queries the
        userinfo endpoint with the returned access token to read
        `preferred_username`.

        Args:
            code: Authorization code to exchange.
            timeout: Seconds to wait on each HTTP request before `requests`
                gives up; `None` waits indefinitely.

        Returns:
            A new instance of `cls` carrying both tokens and the username.

        Raises:
            Exception: If either endpoint answers with a non-200 status, or if
                a token or the username is missing from the payload. Network
                errors (including timeouts) propagate from `requests`.
        """
        credentials = base64.b64encode(
            f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}".encode()
        ).decode()
        token_reply = requests.post(
            "https://huggingface.co/oauth/token",
            data={
                "client_id": OAUTH_CLIENT_ID,
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": REDIRECT_URI,
            },
            headers={"Authorization": f"Basic {credentials}"},
            # Without a timeout a stalled connection blocks the app forever.
            timeout=timeout,
        )
        if token_reply.status_code != 200:
            raise cls._connection_error(token_reply)

        tokens = token_reply.json()
        access_token = tokens.get("access_token")
        id_token = tokens.get("id_token")
        if not (access_token and id_token):
            raise cls._connection_error(tokens)

        userinfo_reply = requests.get(
            "https://huggingface.co/oauth/userinfo",
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=timeout,
        )
        if userinfo_reply.status_code != 200:
            raise cls._connection_error(userinfo_reply)

        userinfo = userinfo_reply.json()
        username = userinfo.get("preferred_username")
        if not username:
            raise cls._connection_error(userinfo)

        # `cls` (not a hard-coded `User`) so subclasses get their own type back.
        return cls(
            access_token=access_token,
            username=username,
            id_token=id_token,
        )


def get_user():
    """Returns what the Streamlit session holds under the `User` key.

    The `User` class itself is used as the session-state key; the result is
    None when nothing is stored under it.
    """
    session = st.session_state
    return session.get(User)


@dataclasses.dataclass
class CurrentProject:
    """The project currently selected, identified by its folder `path`."""

    path: epath.Path

    @classmethod
    def create_new(cls) -> CurrentProject | None:
        """Creates a project named after the current local time.

        The folder name is the current time formatted with
        `PROJECT_FOLDER_PATTERN`; see `from_timestamp` for the None case.
        """
        now = datetime.datetime.now()
        return cls.from_timestamp(now.strftime(PROJECT_FOLDER_PATTERN))

    @classmethod
    def from_timestamp(cls, timestamp: str) -> CurrentProject | None:
        """Returns the project stored under the folder named `timestamp`.

        Returns None when no user is in the session while `OAUTH_CLIENT_ID` is
        set. Otherwise the project lives in `PAST_PROJECTS_PATH(user)`, where
        `user` may be None if `OAUTH_CLIENT_ID` is not set.
        """
        user = get_user()
        if user is not None or not OAUTH_CLIENT_ID:
            return CurrentProject(path=PAST_PROJECTS_PATH(user) / timestamp)
        return None


class SelectedResource:
    """Marker type for the FileSet or FileObject selected on the `Resources` page.

    The class carries no state or behaviour of its own.
    """


@dataclasses.dataclass
class SelectedRecordSet:
    """The selected RecordSet on the `RecordSets` page."""

    # NOTE(review): presumably the index of `record_set` inside
    # `Metadata.record_sets` (the `Metadata` methods take an int key) — confirm.
    record_set_key: int
    # The editor-side `RecordSet` object that is selected.
    record_set: RecordSet


@dataclasses.dataclass
class Node:
    """Common base of the editor objects: an `mlc.Context`, an id and a name."""

    ctx: mlc.Context = dataclasses.field(default_factory=mlc.Context)
    id: str | None = None
    name: str | None = None

    def get_name_or_id(self):
        """Returns `name` when `ctx.is_v0()` is true, and `id` otherwise."""
        return self.name if self.ctx.is_v0() else self.id


@dataclasses.dataclass
class FileObject(Node):
    """FileObject analogue for the editor.

    Adds the attributes below to those of `Node` (`ctx`, `id`, `name`).
    `Metadata.to_canonical` turns an instance into an `mlc.FileObject` through
    `create_class`, which copies every attribute whose name matches a field of
    the target dataclass.
    """

    description: str | None = None
    # Each instance gets its own empty list by default. Entries are rewritten
    # by `Metadata.rename_distribution` and `Metadata.rename_id`.
    contained_in: list[str] | None = dataclasses.field(default_factory=list)
    content_size: str | None = None
    content_url: str | None = None
    encoding_format: str | None = None
    sha256: str | None = None
    # NOTE(review): `df` and `folder` look like editor-only state (a loaded
    # table and a local folder) rather than Croissant attributes — confirm.
    df: pd.DataFrame | None = None
    folder: epath.PathLike | None = None


@dataclasses.dataclass
class FileSet(Node):
    """FileSet analogue for the editor.

    Adds the attributes below to those of `Node` (`ctx`, `id`, `name`).
    `Metadata.to_canonical` turns an instance into an `mlc.FileSet` through
    `create_class`.
    """

    # Each instance gets its own empty list by default. Entries are rewritten
    # by `Metadata.rename_distribution` and `Metadata.rename_id`.
    contained_in: list[str] = dataclasses.field(default_factory=list)
    description: str | None = None
    # Defaults to "" here, whereas `FileObject.encoding_format` defaults to None.
    encoding_format: str | None = ""
    includes: str | None = ""


@dataclasses.dataclass
class Field(Node):
    """Field analogue for the editor.

    Adds the attributes below to those of `Node` (`ctx`, `id`, `name`).
    `Metadata.to_canonical` turns an instance into an `mlc.Field` through
    `create_class`.
    """

    description: str | None = None
    data_types: str | list[str] | None = None
    # `source` and `references` are `mlc.Source` objects. The `rename_*`
    # methods of `Metadata` rewrite their `field`, `file_object`, `file_set`
    # and `distribution` attributes when an identifier changes.
    source: mlc.Source | None = None
    references: mlc.Source | None = None


@dataclasses.dataclass
class RecordSet(Node):
    """RecordSet analogue for the editor.

    Adds the attributes below to those of `Node` (`ctx`, `id`, `name`).
    `Metadata.to_canonical` turns an instance into an `mlc.RecordSet` through
    `create_class`, passing the converted fields explicitly.
    """

    data: list[Any] | None = None
    data_types: list[str] | None = None
    description: str | None = None
    is_enumeration: bool | None = None
    key: str | list[str] | None = None
    # Editor-side `Field` objects; each instance gets its own empty list by
    # default. `Metadata.add_field` / `Metadata.remove_field` mutate this list.
    fields: list[Field] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class Metadata(Node):
    """Main Croissant data object of the editor.

    `to_canonical` and `from_canonical` convert between this class and
    `mlc.Metadata`. The `rename_*` methods keep the cross-references between
    resources, RecordSets and fields consistent when an identifier changes.
    """

    description: str | None = None
    cite_as: str | None = None
    creators: list[mlc.Person] = dataclasses.field(default_factory=list)
    data_biases: str | None = None
    data_collection: str | None = None
    date_published: datetime.datetime | None = None
    license: str | None = ""
    personal_sensitive_information: str | None = None
    url: str = ""
    distribution: list[FileObject | FileSet] = dataclasses.field(default_factory=list)
    record_sets: list[RecordSet] = dataclasses.field(default_factory=list)
    version: str | None = None

    def __bool__(self) -> bool:
        """Returns True when neither `name` nor `url` is the empty string."""
        # NOTE(review): `name` defaults to None (see `Node`) and `None != ""`
        # is True, so an unnamed instance with a url is truthy — confirm that
        # this is intended before tightening the check.
        return self.name != "" and self.url != ""

    @staticmethod
    def _rename_record_set_in_source(
        source: Any, old_name: str, new_name: str
    ) -> None:
        """Rewrites one source/references object after a RecordSet/resource rename."""
        if not source:
            return
        # `field` looks like "<record set>/<field>": only the leading segment
        # is replaced, and only when it is exactly `old_name`.
        if source.field and source.field.startswith(f"{old_name}/"):
            source.field = source.field.replace(old_name, new_name, 1)
        for attribute in ("file_object", "file_set", "distribution"):
            value = getattr(source, attribute)
            if value and value == old_name:
                setattr(source, attribute, new_name)

    @staticmethod
    def _rename_field_in_source(source: Any, old_name: str, new_name: str) -> None:
        """Rewrites one source/references object after a field rename."""
        if not (source and source.field):
            return
        if source.field.endswith(f"/{old_name}"):
            # Replace the trailing segment explicitly. `str.replace(..., 1)`
            # would hit the leftmost occurrence, corrupting values such as
            # "name_set/name" when renaming the field "name".
            prefix = source.field[: len(source.field) - len(old_name)]
            source.field = prefix + new_name

    def rename_distribution(self, old_name: str, new_name: str) -> None:
        """Renames a resource by changing all the references to this resource."""
        for resource in self.distribution:
            if resource.id == old_name:
                resource.id = new_name
            contained_in = resource.contained_in
            if contained_in and old_name in contained_in:
                resource.contained_in = [
                    new_name if name == old_name else name for name in contained_in
                ]
        # Updating source/references works just as with RecordSets.
        self.rename_record_set(old_name, new_name)

    def rename_record_set(self, old_name: str, new_name: str) -> None:
        """Renames a RecordSet by changing all the references to this RecordSet."""
        for record_set in self.record_sets:
            if record_set.id == old_name:
                record_set.id = new_name
            for field in record_set.fields:
                self._rename_record_set_in_source(field.source, old_name, new_name)
                self._rename_record_set_in_source(
                    field.references, old_name, new_name
                )

    def rename_field(self, old_name: str, new_name: str) -> None:
        """Renames a field by changing all the references to this field.

        Only `source.field` / `references.field` values ending in
        "/<old_name>" are rewritten, and only that trailing segment changes.
        """
        for record_set in self.record_sets:
            for field in record_set.fields:
                self._rename_field_in_source(field.source, old_name, new_name)
                self._rename_field_in_source(field.references, old_name, new_name)

    def rename_id(self, old_id: str, new_id: str) -> None:
        """Replaces every value exactly equal to `old_id` with `new_id`.

        Covers resource ids and their `contained_in` entries, RecordSet ids,
        field ids, and the `distribution`, `field`, `file_object` and
        `file_set` attributes of each field's source and references.
        """
        for resource in self.distribution:
            if resource.id == old_id:
                resource.id = new_id
            if resource.contained_in and old_id in resource.contained_in:
                resource.contained_in = [
                    new_id if parent == old_id else parent
                    for parent in resource.contained_in
                ]
        for record_set in self.record_sets:
            if record_set.id == old_id:
                record_set.id = new_id
            for field in record_set.fields:
                if field.id == old_id:
                    field.id = new_id
                for attribute in ("distribution", "field", "file_object", "file_set"):
                    if field.source and getattr(field.source, attribute) == old_id:
                        setattr(field.source, attribute, new_id)
                    if (
                        field.references
                        and getattr(field.references, attribute) == old_id
                    ):
                        setattr(field.references, attribute, new_id)

    def add_distribution(self, distribution: FileSet | FileObject) -> None:
        """Appends a resource to `distribution`."""
        self.distribution.append(distribution)

    def remove_distribution(self, key: int) -> None:
        """Deletes the resource at index `key` of `distribution`."""
        del self.distribution[key]

    def add_record_set(self, record_set: RecordSet) -> None:
        """Appends `record_set`, first making its name unique.

        The name is replaced by the result of `find_unique_name` applied to
        the names already in use (see `names`).
        """
        record_set.name = find_unique_name(self.names(), record_set.name)
        self.record_sets.append(record_set)

    def remove_record_set(self, key: int) -> None:
        """Deletes the RecordSet at index `key` of `record_sets`."""
        del self.record_sets[key]

    def _find_record_set(self, record_set_key: int) -> RecordSet:
        """Returns the RecordSet at index `record_set_key`.

        Raises:
            ValueError: If the index is past the end of `record_sets`.
        """
        if record_set_key >= len(self.record_sets):
            raise ValueError(f"Wrong index when finding a RecordSet: {record_set_key}")
        return self.record_sets[record_set_key]

    def add_field(self, record_set_key: int, field: Field) -> None:
        """Appends `field` to the RecordSet at index `record_set_key`."""
        record_set = self._find_record_set(record_set_key)
        record_set.fields.append(field)

    def remove_field(self, record_set_key: int, field_key: int) -> None:
        """Deletes the field at `field_key` from the RecordSet at `record_set_key`.

        Raises:
            ValueError: If either index is past the end of its list.
        """
        record_set = self._find_record_set(record_set_key)
        if field_key >= len(record_set.fields):
            raise ValueError(f"Wrong index when removing field: {field_key}")
        del record_set.fields[field_key]

    def to_canonical(self) -> mlc.Metadata:
        """Converts this editor object into an `mlc.Metadata`.

        Resources, RecordSets and fields are converted with `create_class`,
        sharing this object's `ctx`. Entries of `distribution` that are
        neither `FileObject` nor `FileSet` are skipped.
        """
        ctx = self.ctx
        distribution = []
        for file in self.distribution:
            if isinstance(file, FileObject):
                distribution.append(create_class(mlc.FileObject, file, ctx=ctx))
            elif isinstance(file, FileSet):
                distribution.append(create_class(mlc.FileSet, file, ctx=ctx))
        record_sets = []
        for record_set in self.record_sets:
            fields = [
                create_class(mlc.Field, field, ctx=ctx) for field in record_set.fields
            ]
            record_sets.append(
                create_class(mlc.RecordSet, record_set, ctx=ctx, fields=fields)
            )
        return create_class(
            mlc.Metadata,
            self,
            distribution=distribution,
            record_sets=record_sets,
        )

    @classmethod
    def from_canonical(cls, canonical_metadata: mlc.Metadata) -> Metadata:
        """Builds the editor object out of an `mlc.Metadata`.

        Every distribution entry that is not an `mlc.FileObject` is converted
        to a `FileSet`.
        """
        distribution = []
        for file in canonical_metadata.distribution:
            if isinstance(file, mlc.FileObject):
                distribution.append(create_class(FileObject, file))
            else:
                distribution.append(create_class(FileSet, file))
        record_sets = []
        for record_set in canonical_metadata.record_sets:
            fields = [create_class(Field, field) for field in record_set.fields]
            record_sets.append(
                create_class(
                    RecordSet,
                    record_set,
                    fields=fields,
                )
            )
        return create_class(
            cls,
            canonical_metadata,
            distribution=distribution,
            record_sets=record_sets,
        )

    def names(self) -> set[str]:
        """Returns `get_name_or_id()` of every resource, RecordSet and field."""
        found = {resource.get_name_or_id() for resource in self.distribution}
        for record_set in self.record_sets:
            found.add(record_set.get_name_or_id())
            found.update(field.get_name_or_id() for field in record_set.fields)
        return found


class OpenTab:
    """Session-state key under which the index of the open tab is stored."""


def get_tab():
    """Returns the value stored under `OpenTab` in the session state.

    Falls back to 0 when nothing (or None) is stored.
    """
    stored = st.session_state.get(OpenTab)
    return 0 if stored is None else stored


def set_tab(tab: str):
    """Stores the position of `tab` within `TABS` under the `OpenTab` key.

    Does nothing when `tab` is not one of `TABS`.
    """
    if tab in TABS:
        st.session_state[OpenTab] = TABS.index(tab)