File size: 3,386 Bytes
5147f0d 12ca412 5147f0d 12ca412 90e8224 12ca412 1b98aa7 5147f0d fa0fdab 5147f0d 12ca412 fa0fdab 12ca412 1b98aa7 5147f0d 1b98aa7 12ca412 1b98aa7 90e8224 1b98aa7 12ca412 1b98aa7 12ca412 1b98aa7 12ca412 1b98aa7 12ca412 1b98aa7 12ca412 1b98aa7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os
import re
from pathlib import Path
import requests
import json
from .artifact import Artifact, Artifactory
# Separator between the nested collection levels of an artifact identifier;
# LocalCatalog.path() splits on it and turns every level into a path component.
COLLECTION_SEPARATOR = '.'
# Not referenced in the visible part of this file; presumably separates several
# catalog paths packed into a single string -- TODO confirm against callers.
PATHS_SEP = ':'
class Catalog(Artifactory):
    """Base class for catalogs: an ``Artifactory`` that has a name and a location.

    Both fields default to ``None``; the subclasses in this module supply
    concrete values.
    """

    name: str = None
    location: str = None
# Locate the directory that holds the built-in catalog. The installed ``unitxt``
# package is preferred; when it cannot be imported, the directory of this module
# is used instead.
try:
    import unitxt
except ImportError:
    _catalog_root = os.path.dirname(__file__)
else:
    _catalog_root = os.path.dirname(unitxt.__file__)

# os.path.join (rather than `+ "/catalog"`) uses the platform separator and does
# not produce the filesystem-root path "/catalog" when the directory part is "".
default_catalog_path = os.path.join(_catalog_root, "catalog")
class LocalCatalog(Catalog):
    """Catalog that keeps each artifact as a JSON file below ``location``.

    An identifier such as ``"a.b.c"`` is split on ``COLLECTION_SEPARATOR`` and
    mapped to ``<location>/a/b/c.json``.
    """

    name: str = "local"
    location: str = default_catalog_path

    def path(self, artifact_identifier: str):
        """Return the JSON file path that corresponds to ``artifact_identifier``.

        Raises:
            AssertionError: If the identifier is empty or only whitespace.
        """
        assert artifact_identifier.strip(), 'artifact_identifier should not be an empty string.'
        # split() always yields at least one element, so ``leaf`` is always bound.
        *collections, leaf = artifact_identifier.split(COLLECTION_SEPARATOR)
        return os.path.join(self.location, *collections, leaf + ".json")

    def load(self, artifact_identifier: str):
        """Load and return the artifact stored under ``artifact_identifier``.

        Raises:
            AssertionError: If no such artifact exists in this catalog.
        """
        assert artifact_identifier in self, "Artifact with name {} does not exist".format(artifact_identifier)
        return Artifact.load(self.path(artifact_identifier))

    def __getitem__(self, name) -> Artifact:
        """Support ``catalog[name]`` as a shorthand for ``catalog.load(name)``."""
        return self.load(name)

    def __contains__(self, artifact_identifier: str):
        """Return True when a file for ``artifact_identifier`` exists in this catalog."""
        if not os.path.exists(self.location):
            return False
        candidate = self.path(artifact_identifier)
        # A ``None`` path is treated as "not in the catalog".
        return candidate is not None and os.path.exists(candidate) and os.path.isfile(candidate)

    def save_artifact(self, artifact: Artifact, artifact_identifier: str, overwrite: bool = False):
        """Write ``artifact`` to the file that corresponds to ``artifact_identifier``.

        Missing parent directories are created first.

        Raises:
            AssertionError: If ``artifact`` is not an ``Artifact``, or if the
                identifier is already present and ``overwrite`` is false.
        """
        assert isinstance(artifact, Artifact), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
        # The existence check is skipped entirely when overwriting is allowed.
        assert (
            overwrite or artifact_identifier not in self
        ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
        target = self.path(artifact_identifier)
        parent_dir = Path(target).parent.absolute()
        os.makedirs(parent_dir, exist_ok=True)
        artifact.save(target)
# Upper bound, in seconds, for every HTTP request sent to GitHub. Without a
# timeout ``requests`` may wait forever on a stalled connection.
_GITHUB_REQUEST_TIMEOUT = 10


class GithubCatalog(LocalCatalog):
    """Catalog whose artifacts are fetched over HTTP from a GitHub repository.

    ``location`` is the raw-content base URL assembled by ``prepare`` from
    ``user``, ``repo``, ``branch`` and ``repo_dir``.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    branch = "master"

    def prepare(self):
        """Point ``location`` at the raw-content URL of the configured repository.

        NOTE(review): presumably invoked by the Artifact machinery during
        initialization -- confirm against ``Artifact``.
        """
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{self.branch}/{self.repo_dir}"

    def path(self, artifact_identifier: str):
        """Return the URL of the JSON file for ``artifact_identifier``.

        The inherited implementation joins with ``os.path.join``, which inserts
        backslashes on Windows and so produces an invalid URL there. URLs always
        use "/", hence ``posixpath.join``; on POSIX the result is unchanged.

        Raises:
            AssertionError: If the identifier is empty or only whitespace.
        """
        import posixpath

        assert artifact_identifier.strip(), 'artifact_identifier should not be an empty string.'
        parts = artifact_identifier.split(COLLECTION_SEPARATOR)
        parts[-1] = parts[-1] + ".json"
        return posixpath.join(self.location, *parts)

    def load(self, artifact_identifier: str):
        """Download the artifact JSON and build an ``Artifact`` from it.

        Raises:
            requests.HTTPError: If the server answers with an error status
                (for example 404 for an unknown artifact).
            requests.Timeout: If the server does not answer in time.
        """
        url = self.path(artifact_identifier)
        response = requests.get(url, timeout=_GITHUB_REQUEST_TIMEOUT)
        # Fail with the real HTTP error instead of a confusing JSON decode error
        # caused by parsing an error page.
        response.raise_for_status()
        return Artifact.from_dict(response.json())

    def __contains__(self, artifact_identifier: str):
        """Return True when a HEAD request for the artifact URL answers 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url, timeout=_GITHUB_REQUEST_TIMEOUT)
        return response.status_code == 200
def verify_legal_catalog_name(name):
    """Check that ``name`` is a legal catalog artifact name.

    A legal name is non-empty and consists only of word characters (as matched
    by ``\\w``) and ``COLLECTION_SEPARATOR``, which marks directory levels.

    Raises:
        AssertionError: If ``name`` contains any other character, including a
            trailing newline, or is empty.
    """
    # Raw string avoids the invalid "\w" escape; the separator is escaped so it
    # stays literal whatever character it is.
    legal_name = r'[\w' + re.escape(COLLECTION_SEPARATOR) + r']+'
    # fullmatch rejects a trailing newline, which "^...$" silently accepted.
    if not re.fullmatch(legal_name, name):
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``
        # while callers still see the same exception type.
        raise AssertionError(
            f'Catalog name should be alphanumeric, "{COLLECTION_SEPARATOR}" should specify dirs (instead of "/").'
        )
def add_to_catalog(artifact: Artifact, name: str, catalog: Catalog = None, overwrite: bool = False,
                   catalog_path: str = None):
    """Save ``artifact`` under ``name`` in a catalog.

    Args:
        artifact: The artifact to store.
        name: Catalog name of the artifact; it is validated with
            ``verify_legal_catalog_name`` before saving.
        catalog: Catalog to save into. When ``None``, a ``LocalCatalog`` is
            created for ``catalog_path``.
        overwrite: Passed on to ``save_artifact``.
        catalog_path: Location of the local catalog, used only when ``catalog``
            is ``None``. Defaults to ``default_catalog_path``.
    """
    if catalog is None:
        location = default_catalog_path if catalog_path is None else catalog_path
        catalog = LocalCatalog(location=location)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite)
|