File size: 12,434 Bytes
cb5b71d
 
 
 
 
 
 
bbea1cc
cb5b71d
 
 
 
 
 
bbea1cc
 
cb5b71d
bbea1cc
 
cb5b71d
 
bbea1cc
6a31b9a
0c5b67f
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
bbea1cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
041af8a
 
 
 
 
 
 
 
 
 
 
 
bbea1cc
 
 
 
 
 
8c11dd4
 
041af8a
 
 
cb5b71d
 
 
 
 
 
 
5a782ad
cb5b71d
dc92053
 
 
 
8c11dd4
041af8a
bbea1cc
 
 
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e92e659
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edf454b
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
36f4fe3
bc133ae
 
36f4fe3
cb5b71d
bc133ae
cb5b71d
 
 
 
bc133ae
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edf454b
cb5b71d
 
edf454b
 
 
 
 
cb5b71d
 
 
 
 
 
 
edf454b
 
 
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c5b67f
 
cb5b71d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c5b67f
 
 
 
6a31b9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
"""Streamlit session state.

In the future, this could be the serialization format between front and back.
"""

from __future__ import annotations

import base64
import dataclasses
import datetime
from typing import Any

from etils import epath
import pandas as pd
import requests
import streamlit as st

from core.constants import OAUTH_CLIENT_ID
from core.constants import OAUTH_CLIENT_SECRET
from core.constants import PAST_PROJECTS_PATH
from core.constants import PROJECT_FOLDER_PATTERN
from core.constants import REDIRECT_URI
from core.constants import TABS
from core.names import find_unique_name
import mlcroissant as mlc


def create_class(mlc_class: type, instance: Any, **kwargs) -> Any:
    """Builds an `mlc_class` object from the attributes of the editor `instance`.

    Every dataclass field of `mlc_class` that `instance` has as an attribute is
    copied over by name. Fields passed explicitly in `kwargs` take precedence and
    are not read from `instance`.
    """
    copied = {
        dc_field.name: getattr(instance, dc_field.name)
        for dc_field in dataclasses.fields(mlc_class)
        if hasattr(instance, dc_field.name) and dc_field.name not in kwargs
    }
    return mlc_class(**copied, **kwargs)


@dataclasses.dataclass
class User:
    """The connected user."""

    access_token: str
    id_token: str
    username: str

    @classmethod
    def connect(cls, code: str, timeout: float | None = 30.0) -> User:
        """Exchanges an OAuth authorization `code` for a connected Hugging Face user.

        Two requests are made: the code is exchanged for tokens at the
        Hugging Face OAuth token endpoint, then the access token is used to
        fetch the user's `preferred_username` from the userinfo endpoint.

        Args:
            code: The authorization code obtained from the OAuth flow.
            timeout: Seconds to wait for each HTTP request. `None` waits forever.

        Returns:
            The connected `User`.

        Raises:
            Exception: If a request does not answer with HTTP 200 or if any of
                `access_token`, `id_token` or `preferred_username` is missing.
            requests.exceptions.RequestException: On network failures, including
                the timeout expiring.
        """
        credentials = base64.b64encode(
            f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}".encode()
        ).decode()
        headers = {
            "Authorization": f"Basic {credentials}",
        }
        data = {
            "client_id": OAUTH_CLIENT_ID,
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": REDIRECT_URI,
        }
        url = "https://huggingface.co/oauth/token"
        # Without a timeout, `requests` can block the app forever on a stalled server.
        response = requests.post(url, data=data, headers=headers, timeout=timeout)
        if response.status_code == 200:
            response = response.json()
            access_token = response.get("access_token")
            id_token = response.get("id_token")
            if access_token and id_token:
                url = "https://huggingface.co/oauth/userinfo"
                headers = {"Authorization": f"Bearer {access_token}"}
                response = requests.get(url, headers=headers, timeout=timeout)
                if response.status_code == 200:
                    response = response.json()
                    username = response.get("preferred_username")
                    if username:
                        return cls(
                            access_token=access_token,
                            username=username,
                            id_token=id_token,
                        )
        # NOTE(review): at this point `response` may be the decoded JSON dict of
        # the token endpoint, which can contain tokens; consider redacting it
        # before this message is surfaced to users or logs.
        raise Exception(
            f"Could not connect to Hugging Face. Please, go to {REDIRECT_URI}."
            f" ({response=})."
        )


def get_user():
    """Returns the `User` stored in Streamlit's session state, or None if absent."""
    session = st.session_state
    return session.get(User, None)


@dataclasses.dataclass
class CurrentProject:
    """The project currently opened in the editor, identified by its folder."""

    path: epath.Path

    @classmethod
    def create_new(cls) -> CurrentProject | None:
        """Creates a project named after the current local time.

        The folder name is `datetime.now()` formatted with PROJECT_FOLDER_PATTERN.
        Returns None under the same conditions as `from_timestamp`.
        """
        return cls.from_timestamp(
            datetime.datetime.now().strftime(PROJECT_FOLDER_PATTERN)
        )

    @classmethod
    def from_timestamp(cls, timestamp: str) -> CurrentProject | None:
        """Builds the project stored in folder `timestamp` of the user's projects.

        Returns None when OAuth is configured (OAUTH_CLIENT_ID is truthy) but no
        user is logged in. Otherwise the current user (possibly None when OAuth is
        not configured) is passed to PAST_PROJECTS_PATH to locate the base folder.
        """
        current_user = get_user()
        if OAUTH_CLIENT_ID and current_user is None:
            return None
        projects_root = PAST_PROJECTS_PATH(current_user)
        return CurrentProject(path=projects_root / timestamp)


class SelectedResource:
    """Stands for the FileSet or FileObject selected on the `Resources` page."""


@dataclasses.dataclass
class SelectedRecordSet:
    """The selected RecordSet on the `RecordSets` page."""

    # Integer index of the selected RecordSet (presumably into
    # `Metadata.record_sets` — confirm against the caller).
    record_set_key: int
    # The selected RecordSet itself.
    record_set: RecordSet


@dataclasses.dataclass
class FileObject:
    """FileObject analogue for editor.

    Converted to and from `mlcroissant.FileObject` by copying fields that share
    the same name (see `create_class`).
    """

    name: str | None = None
    description: str | None = None
    # Names of the resources this file is contained in; entries are rewritten by
    # `Metadata.rename_distribution` when a resource is renamed.
    contained_in: list[str] | None = dataclasses.field(default_factory=list)
    content_size: str | None = None
    content_url: str | None = None
    encoding_format: str | None = None
    sha256: str | None = None
    # NOTE(review): presumably an in-memory preview of the file's content — confirm.
    df: pd.DataFrame | None = None
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)
    # NOTE(review): usage not visible here; presumably a local folder for the file.
    folder: epath.PathLike | None = None


@dataclasses.dataclass
class FileSet:
    """FileSet analogue for editor.

    Converted to and from `mlcroissant.FileSet` by copying fields that share the
    same name (see `create_class`).
    """

    # Names of the resources this file set is contained in; entries are rewritten
    # by `Metadata.rename_distribution` when a resource is renamed.
    contained_in: list[str] = dataclasses.field(default_factory=list)
    description: str | None = None
    encoding_format: str | None = ""
    # NOTE(review): presumably a glob pattern selecting the files — confirm.
    includes: str | None = ""
    name: str = ""
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)


@dataclasses.dataclass
class Field:
    """Field analogue for editor.

    Converted to and from `mlcroissant.Field` by copying fields that share the
    same name (see `create_class`).
    """

    name: str | None = None
    description: str | None = None
    # A single data type or a list of them.
    data_types: str | list[str] | None = None
    # Where the field's data comes from. Its `uid` is either a node name or
    # `<node>/<field>`; the `Metadata.rename_*` methods keep it up to date.
    source: mlc.Source | None = None
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)
    # Field referenced by this one; its `uid` follows the same format as `source`.
    references: mlc.Source | None = None


@dataclasses.dataclass
class RecordSet:
    """Record Set analogue for editor.

    Converted to and from `mlcroissant.RecordSet`; its `fields` are converted one
    by one (see `Metadata.to_canonical` / `Metadata.from_canonical`).
    """

    name: str = ""
    # NOTE(review): presumably inline records (e.g. for enumerations) — confirm.
    data: list[Any] | None = None
    description: str | None = None
    is_enumeration: bool | None = None
    # A single key or a list of keys.
    key: str | list[str] | None = None
    fields: list[Field] = dataclasses.field(default_factory=list)
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)


@dataclasses.dataclass
class Metadata:
    """Main Croissant data object edited by the app.

    `to_canonical` converts it to an `mlcroissant.Metadata` and `from_canonical`
    builds it back from one.
    """

    name: str = ""
    description: str | None = None
    citation: str | None = None
    creator: mlc.PersonOrOrganization | None = None
    data_biases: str | None = None
    data_collection: str | None = None
    date_published: datetime.datetime | None = None
    license: str | None = ""
    personal_sensitive_information: str | None = None
    url: str = ""
    distribution: list[FileObject | FileSet] = dataclasses.field(default_factory=list)
    record_sets: list[RecordSet] = dataclasses.field(default_factory=list)
    rdf: mlc.Rdf = dataclasses.field(default_factory=mlc.Rdf)
    version: str | None = None

    def __bool__(self):
        """A Metadata counts as set only once both `name` and `url` are non-empty."""
        return self.name != "" and self.url != ""

    def _uid_nodes(self) -> list[Any]:
        """Lists every field `source`/`references` that carries a non-empty `uid`.

        Nodes are listed record set by record set and field by field, `source`
        before `references`, so renames are applied in a stable order.
        """
        nodes = []
        for record_set in self.record_sets:
            for field in record_set.fields:
                for node in (field.source, field.references):
                    if node and node.uid:
                        nodes.append(node)
        return nodes

    def rename_distribution(self, old_name: str, new_name: str):
        """Renames a resource by changing all the references to this resource.

        Rewrites `old_name` in the `contained_in` lists of all resources, then
        updates field `source`/`references` uids exactly like `rename_record_set`.
        """
        for resource in self.distribution:
            contained_in = resource.contained_in
            if contained_in and old_name in contained_in:
                resource.contained_in = [
                    new_name if name == old_name else name for name in contained_in
                ]
        # Updating source/references works just as with RecordSets.
        self.rename_record_set(old_name, new_name)

    def rename_record_set(self, old_name: str, new_name: str):
        """Renames a RecordSet by changing all the references to this RecordSet.

        A uid refers to the RecordSet when it equals `old_name` or starts with
        `f"{old_name}/"`; only that leading segment is rewritten.
        """
        if not old_name:
            # An empty name cannot identify a node; don't touch unrelated uids.
            return
        prefix = f"{old_name}/"
        for node in self._uid_nodes():
            if node.uid == old_name or node.uid.startswith(prefix):
                node.uid = new_name + node.uid[len(old_name) :]

    def rename_field(self, old_name: str, new_name: str):
        """Renames a field by changing all the references to this field.

        A uid refers to the field when its last `/`-separated segment is exactly
        `old_name`; only that trailing segment is rewritten.
        """
        if not old_name:
            # An empty name cannot identify a field, and `uid[:-0]` would be "".
            return
        # Including the "/" avoids matching fields whose name merely ends with
        # `old_name` (e.g. `rs/first_name` when renaming `name`), and slicing the
        # suffix avoids rewriting an earlier occurrence such as the record-set
        # prefix in `name_rs/name`.
        suffix = f"/{old_name}"
        for node in self._uid_nodes():
            if node.uid.endswith(suffix):
                node.uid = node.uid[: -len(old_name)] + new_name

    def add_distribution(self, distribution: FileSet | FileObject) -> None:
        """Appends a resource to the distribution."""
        self.distribution.append(distribution)

    def remove_distribution(self, key: int) -> None:
        """Removes the resource at index `key`."""
        del self.distribution[key]

    def add_record_set(self, record_set: RecordSet) -> None:
        """Appends `record_set`, renaming it first so its name is unique."""
        record_set.name = find_unique_name(self.names(), record_set.name)
        self.record_sets.append(record_set)

    def remove_record_set(self, key: int) -> None:
        """Removes the RecordSet at index `key`."""
        del self.record_sets[key]

    def _find_record_set(self, record_set_key: int) -> RecordSet:
        """Returns the RecordSet at `record_set_key`.

        Raises:
            ValueError: If the index is out of range (negative indices included,
                which would otherwise silently select from the end).
        """
        if not 0 <= record_set_key < len(self.record_sets):
            raise ValueError(f"Wrong index when finding a RecordSet: {record_set_key}")
        return self.record_sets[record_set_key]

    def add_field(self, record_set_key: int, field: Field) -> None:
        """Appends `field` to the RecordSet at index `record_set_key`."""
        record_set = self._find_record_set(record_set_key)
        record_set.fields.append(field)

    def remove_field(self, record_set_key: int, field_key: int) -> None:
        """Removes field `field_key` from the RecordSet at `record_set_key`.

        Raises:
            ValueError: If either index is out of range (negative included).
        """
        record_set = self._find_record_set(record_set_key)
        if not 0 <= field_key < len(record_set.fields):
            raise ValueError(f"Wrong index when removing field: {field_key}")
        del record_set.fields[field_key]

    def to_canonical(self) -> mlc.Metadata:
        """Converts the editor metadata into an `mlcroissant.Metadata`.

        Distribution entries that are neither FileObject nor FileSet are skipped.
        """
        distribution = []
        for file in self.distribution:
            if isinstance(file, FileObject):
                distribution.append(create_class(mlc.FileObject, file))
            elif isinstance(file, FileSet):
                distribution.append(create_class(mlc.FileSet, file))
        record_sets = [
            create_class(
                mlc.RecordSet,
                record_set,
                fields=[create_class(mlc.Field, field) for field in record_set.fields],
            )
            for record_set in self.record_sets
        ]
        return create_class(
            mlc.Metadata,
            self,
            distribution=distribution,
            record_sets=record_sets,
        )

    @classmethod
    def from_canonical(cls, canonical_metadata: mlc.Metadata) -> Metadata:
        """Builds the editor metadata from an `mlcroissant.Metadata`.

        Every distribution entry that is not an `mlc.FileObject` becomes a FileSet.
        """
        distribution = [
            create_class(FileObject, file)
            if isinstance(file, mlc.FileObject)
            else create_class(FileSet, file)
            for file in canonical_metadata.distribution
        ]
        record_sets = [
            create_class(
                RecordSet,
                record_set,
                fields=[create_class(Field, field) for field in record_set.fields],
            )
            for record_set in canonical_metadata.record_sets
        ]
        return create_class(
            cls,
            canonical_metadata,
            distribution=distribution,
            record_sets=record_sets,
        )

    def names(self) -> set[str]:
        """Returns the names of all resources and RecordSets."""
        return {node.name for node in [*self.distribution, *self.record_sets]}


class OpenTab:
    """Session-state key under which the index of the open tab is stored."""


def get_tab():
    """Returns the open tab index from session state, defaulting to 0."""
    stored = st.session_state.get(OpenTab)
    return 0 if stored is None else stored


def set_tab(tab: str):
    """Stores the index of `tab` within TABS as the open tab.

    Names that are not in TABS are silently ignored.
    """
    if tab in TABS:
        st.session_state[OpenTab] = TABS.index(tab)