Spaces:
Running
Running
"""Streamlit session state.

In the future, this could become the serialization format between the front end
and the back end.
"""
from __future__ import annotations | |
import base64 | |
import dataclasses | |
import datetime | |
from typing import Any | |
import uuid | |
from etils import epath | |
import pandas as pd | |
import requests | |
import streamlit as st | |
from core.constants import OAUTH_CLIENT_ID | |
from core.constants import OAUTH_CLIENT_SECRET | |
from core.constants import PAST_PROJECTS_PATH | |
from core.constants import PROJECT_FOLDER_PATTERN | |
from core.constants import REDIRECT_URI | |
from core.constants import TABS | |
from core.names import find_unique_name | |
import mlcroissant as mlc | |
def create_class(mlc_class: type, instance: Any, **kwargs) -> Any:
    """Creates the dataclass `mlc_class` from the attributes of `instance`.

    Used in both directions: editor object -> mlcroissant object and back.
    Every constructor field of `mlc_class` is copied from the attribute of the
    same name on `instance`, unless that field is explicitly given in `kwargs`.

    Args:
        mlc_class: The dataclass to instantiate.
        instance: The object whose attributes provide the field values.
        **kwargs: Explicit constructor arguments. They take precedence over the
            attributes of `instance`.

    Returns:
        A new `mlc_class` instance.
    """
    params: dict[str, Any] = {}
    for field in dataclasses.fields(mlc_class):
        name = field.name
        # Fields declared with `init=False` are not constructor arguments:
        # forwarding them would raise a TypeError.
        if not field.init or name in kwargs:
            continue
        if hasattr(instance, name):
            params[name] = getattr(instance, name)
    return mlc_class(**params, **kwargs)
@dataclasses.dataclass
class User:
    """The user connected through Hugging Face OAuth."""

    access_token: str
    id_token: str
    username: str

    @classmethod
    def connect(cls, code: str, timeout: float = 10.0) -> User:
        """Exchanges an OAuth authorization `code` for a connected `User`.

        First trades `code` for tokens at the Hugging Face token endpoint, then
        reads the username from the Hugging Face user-info endpoint.

        Args:
            code: The OAuth authorization code to exchange.
            timeout: Maximum number of seconds to wait for each HTTP request.

        Returns:
            The connected user.

        Raises:
            Exception: If a request does not answer with HTTP 200, or if the
                expected tokens or username are missing from the answers.
            requests.RequestException: On network errors, including timeouts.
        """
        credentials = base64.b64encode(
            f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}".encode()
        ).decode()
        # Step 1: exchange the authorization code for tokens.
        token_response = requests.post(
            "https://huggingface.co/oauth/token",
            data={
                "client_id": OAUTH_CLIENT_ID,
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": REDIRECT_URI,
            },
            headers={"Authorization": f"Basic {credentials}"},
            # Without a timeout, `requests` can block forever on a stalled
            # connection and freeze the app.
            timeout=timeout,
        )
        if token_response.status_code != 200:
            raise cls._connection_error(token_response)
        tokens = token_response.json()
        access_token = tokens.get("access_token")
        id_token = tokens.get("id_token")
        if not (access_token and id_token):
            raise cls._connection_error(tokens)
        # Step 2: fetch the username associated with the access token.
        userinfo_response = requests.get(
            "https://huggingface.co/oauth/userinfo",
            headers={"Authorization": f"Bearer {access_token}"},
            timeout=timeout,
        )
        if userinfo_response.status_code != 200:
            raise cls._connection_error(userinfo_response)
        userinfo = userinfo_response.json()
        username = userinfo.get("preferred_username")
        if not username:
            raise cls._connection_error(userinfo)
        return cls(
            access_token=access_token,
            username=username,
            id_token=id_token,
        )

    @staticmethod
    def _connection_error(response: Any) -> Exception:
        """Builds the exception raised when the OAuth flow fails.

        Args:
            response: The HTTP response, or its decoded JSON payload, at which
                the flow stopped. It is embedded in the message for debugging.
        """
        return Exception(
            f"Could not connect to Hugging Face. Please, go to {REDIRECT_URI}."
            f" ({response=})."
        )
def get_user():
    """Looks up the connected `User` in Streamlit's session state.

    Returns:
        The stored user, or None when no user has been stored yet.
    """
    session_state = st.session_state
    return session_state.get(User)
@dataclasses.dataclass
class CurrentProject:
    """The selected project."""

    # Project folder, built as `PAST_PROJECTS_PATH(user) / timestamp`.
    path: epath.Path

    @classmethod
    def create_new(cls) -> CurrentProject | None:
        """Creates a project whose folder is named after the current time.

        Returns:
            The new project, or None when OAuth is configured but no user is
            connected.
        """
        timestamp = datetime.datetime.now().strftime(PROJECT_FOLDER_PATTERN)
        return cls.from_timestamp(timestamp)

    @classmethod
    def from_timestamp(cls, timestamp: str) -> CurrentProject | None:
        """Returns the project stored under `timestamp` for the current user.

        Args:
            timestamp: Name of the project folder, formatted with
                `PROJECT_FOLDER_PATTERN`.

        Returns:
            The project, or None when `OAUTH_CLIENT_ID` is set but no user is
            connected.
        """
        user = get_user()
        if user is None and OAUTH_CLIENT_ID:
            return None
        # NOTE(review): without OAuth, `user` may be None here; presumably
        # `PAST_PROJECTS_PATH` handles that case. Confirm.
        path = PAST_PROJECTS_PATH(user)
        return cls(path=path / timestamp)
class SelectedResource:
    """Represents the FileSet or FileObject chosen on the `Resources` page."""
@dataclasses.dataclass
class SelectedRecordSet:
    """The selected RecordSet on the `RecordSets` page."""

    # Presumably the index into `Metadata.record_sets` (same name as the
    # parameter of `Metadata._find_record_set`); confirm against callers.
    record_set_key: int
    record_set: RecordSet
@dataclasses.dataclass
class Node:
    """Base class of the editor objects, holding the fields they all share."""

    ctx: mlc.Context = dataclasses.field(default_factory=mlc.Context)
    id: str | None = None
    name: str | None = None

    def get_name_or_id(self):
        """Returns `name` when `ctx.is_v0()` is true, and `id` otherwise."""
        if self.ctx.is_v0():
            return self.name
        return self.id
@dataclasses.dataclass
class FileObject(Node):
    """Editor counterpart of `mlc.FileObject`."""

    description: str | None = None
    contained_in: list[str] | None = dataclasses.field(default_factory=list)
    content_size: str | None = None
    content_url: str | None = None
    encoding_format: str | None = None
    sha256: str | None = None
    # NOTE(review): `df` and `folder` look like editor-only state (presumably
    # not fields of `mlc.FileObject`, so `create_class` would not copy them).
    # Confirm.
    df: pd.DataFrame | None = None
    folder: epath.PathLike | None = None
@dataclasses.dataclass
class FileSet(Node):
    """Editor counterpart of `mlc.FileSet`."""

    contained_in: list[str] = dataclasses.field(default_factory=list)
    description: str | None = None
    encoding_format: str | None = ""
    includes: str | None = ""
@dataclasses.dataclass
class Field(Node):
    """Editor counterpart of `mlc.Field`."""

    description: str | None = None
    # Either a single data type or a list of them.
    data_types: str | list[str] | None = None
    source: mlc.Source | None = None
    references: mlc.Source | None = None
@dataclasses.dataclass
class RecordSet(Node):
    """Editor counterpart of `mlc.RecordSet`."""

    data: list[Any] | None = None
    data_types: list[str] | None = None
    description: str | None = None
    is_enumeration: bool | None = None
    key: str | list[str] | None = None
    fields: list[Field] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class Metadata(Node):
    """Main Croissant data object of the editor.

    `to_canonical` and `from_canonical` convert it to and from `mlc.Metadata`.
    """

    description: str | None = None
    cite_as: str | None = None
    creators: list[mlc.Person] = dataclasses.field(default_factory=list)
    data_biases: str | None = None
    data_collection: str | None = None
    date_published: datetime.datetime | None = None
    license: str | None = ""
    personal_sensitive_information: str | None = None
    url: str = ""
    distribution: list[FileObject | FileSet] = dataclasses.field(
        default_factory=list
    )
    record_sets: list[RecordSet] = dataclasses.field(default_factory=list)
    version: str | None = None

    def __bool__(self):
        # NOTE(review): `name` defaults to None, and `None != ""`, so a
        # Metadata without a name but with a non-empty `url` is truthy.
        # Confirm this is intended.
        return self.name != "" and self.url != ""

    def rename_distribution(self, old_name: str, new_name: str):
        """Renames a resource by changing all the references to this resource.

        Updates the matching resource `id`, the `contained_in` lists of all
        resources, and every field `source`/`references` pointing to it.
        """
        for resource in self.distribution:
            if resource.id == old_name:
                resource.id = new_name
            contained_in = resource.contained_in
            if contained_in and old_name in contained_in:
                resource.contained_in = [
                    new_name if name == old_name else name for name in contained_in
                ]
        # Updating source/references works just as with RecordSets.
        self.rename_record_set(old_name, new_name)

    def rename_record_set(self, old_name: str, new_name: str):
        """Renames a RecordSet by changing all the references to this RecordSet.

        In the `source` and `references` of every field:
        - field paths starting with `"{old_name}/"` get that prefix renamed;
        - `file_object`, `file_set` and `distribution` equal to `old_name`
          are replaced by `new_name`.
        """
        prefix = f"{old_name}/"
        for record_set in self.record_sets:
            if record_set.id == old_name:
                record_set.id = new_name
            for field in record_set.fields:
                # `source` is updated before `references`.
                for source in (field.source, field.references):
                    if not source:
                        continue
                    if source.field and source.field.startswith(prefix):
                        # The match is at index 0: swap exactly that prefix.
                        source.field = new_name + source.field[len(old_name) :]
                    for attr in ("file_object", "file_set", "distribution"):
                        value = getattr(source, attr)
                        if value and value == old_name:
                            setattr(source, attr, new_name)

    def rename_field(self, old_name: str, new_name: str):
        """Renames a field by changing all the references to this field.

        Field paths in `source`/`references` ending with `"/{old_name}"` get
        their last component replaced by `new_name`.
        """
        suffix = f"/{old_name}"
        for record_set in self.record_sets:
            for field in record_set.fields:
                for source in (field.source, field.references):
                    # The difference with RecordSet is the `.endswith` here:
                    if source and source.field and source.field.endswith(suffix):
                        # Rewrite the trailing component only. `str.replace`
                        # would hit the first occurrence instead, e.g. the
                        # record set in "name/name" or "names/name".
                        stem = source.field[: len(source.field) - len(old_name)]
                        source.field = stem + new_name

    def rename_id(self, old_id: str, new_id: str):
        """Replaces `old_id` by `new_id` wherever it is used as an id or reference."""
        for resource in self.distribution:
            if resource.id == old_id:
                resource.id = new_id
            if resource.contained_in and old_id in resource.contained_in:
                resource.contained_in = [
                    new_id if ref == old_id else ref for ref in resource.contained_in
                ]
        for record_set in self.record_sets:
            if record_set.id == old_id:
                record_set.id = new_id
            for field in record_set.fields:
                if field.id == old_id:
                    field.id = new_id
                for p in ["distribution", "field", "file_object", "file_set"]:
                    if field.source and getattr(field.source, p) == old_id:
                        setattr(field.source, p, new_id)
                    if field.references and getattr(field.references, p) == old_id:
                        setattr(field.references, p, new_id)

    def add_distribution(self, distribution: FileSet | FileObject) -> None:
        """Appends a resource to the distribution."""
        self.distribution.append(distribution)

    def remove_distribution(self, key: int) -> None:
        """Deletes the resource at index `key` (IndexError if out of range)."""
        del self.distribution[key]

    def add_record_set(self, record_set: RecordSet) -> None:
        """Appends `record_set`, renaming it so its name is unique.

        NOTE(review): the uniqueness check uses `record_set.name`, whereas
        `names()` collects ids when the context is not v0. Confirm.
        """
        name = find_unique_name(self.names(), record_set.name)
        record_set.name = name
        self.record_sets.append(record_set)

    def remove_record_set(self, key: int) -> None:
        """Deletes the RecordSet at index `key` (IndexError if out of range)."""
        del self.record_sets[key]

    def _find_record_set(self, record_set_key: int) -> RecordSet:
        """Returns the RecordSet at `record_set_key`.

        Raises:
            ValueError: If `record_set_key` is past the end of `record_sets`.
        """
        if record_set_key >= len(self.record_sets):
            raise ValueError(f"Wrong index when finding a RecordSet: {record_set_key}")
        return self.record_sets[record_set_key]

    def add_field(self, record_set_key: int, field: Field) -> None:
        """Appends `field` to the RecordSet at index `record_set_key`."""
        record_set = self._find_record_set(record_set_key)
        record_set.fields.append(field)

    def remove_field(self, record_set_key: int, field_key: int) -> None:
        """Deletes field `field_key` of the RecordSet at `record_set_key`.

        Raises:
            ValueError: If either index is past the end of its list.
        """
        record_set = self._find_record_set(record_set_key)
        if field_key >= len(record_set.fields):
            raise ValueError(f"Wrong index when removing field: {field_key}")
        del record_set.fields[field_key]

    def to_canonical(self) -> mlc.Metadata:
        """Converts the editor metadata to an `mlc.Metadata`.

        Resources, RecordSets and fields are all created with this metadata's
        `ctx`. Distribution items that are neither FileObject nor FileSet are
        skipped.
        """
        ctx = self.ctx
        distribution = []
        for file in self.distribution:
            if isinstance(file, FileObject):
                distribution.append(create_class(mlc.FileObject, file, ctx=ctx))
            elif isinstance(file, FileSet):
                distribution.append(create_class(mlc.FileSet, file, ctx=ctx))
        record_sets = []
        for record_set in self.record_sets:
            fields = [
                create_class(mlc.Field, field, ctx=ctx) for field in record_set.fields
            ]
            record_sets.append(
                create_class(mlc.RecordSet, record_set, ctx=ctx, fields=fields)
            )
        return create_class(
            mlc.Metadata,
            self,
            distribution=distribution,
            record_sets=record_sets,
        )

    @classmethod
    def from_canonical(cls, canonical_metadata: mlc.Metadata) -> Metadata:
        """Builds the editor metadata from an `mlc.Metadata`.

        Distribution items that are not `mlc.FileObject` become `FileSet`s.
        """
        distribution = []
        for file in canonical_metadata.distribution:
            if isinstance(file, mlc.FileObject):
                distribution.append(create_class(FileObject, file))
            else:
                distribution.append(create_class(FileSet, file))
        record_sets = []
        for record_set in canonical_metadata.record_sets:
            fields = [create_class(Field, field) for field in record_set.fields]
            record_sets.append(create_class(RecordSet, record_set, fields=fields))
        return create_class(
            cls,
            canonical_metadata,
            distribution=distribution,
            record_sets=record_sets,
        )

    def names(self) -> set[str]:
        """Returns the names (v0) or ids of all resources, RecordSets and fields."""
        distribution = {resource.get_name_or_id() for resource in self.distribution}
        record_sets = {
            record_set.get_name_or_id() for record_set in self.record_sets
        }
        fields = {
            field.get_name_or_id()
            for record_set in self.record_sets
            for field in record_set.fields
        }
        return distribution.union(record_sets).union(fields)
class OpenTab:
    """Session-state key under which the index of the open tab is stored."""
def get_tab():
    """Returns the index of the open tab, or 0 when none has been stored."""
    open_tab = st.session_state.get(OpenTab)
    return 0 if open_tab is None else open_tab
def set_tab(tab: str):
    """Stores the index of `tab` within TABS as the open tab.

    Names that are not in TABS are silently ignored.
    """
    if tab in TABS:
        st.session_state[OpenTab] = TABS.index(tab)