|
|
|
|
|
__all__ = ['generate_TS_df', 'normalize_columns', 'remove_constant_columns', 'ReferenceArtifact', 'PrintLayer', |
|
'get_wandb_artifacts', 'get_pickle_artifact'] |
|
|
|
|
|
from .imports import * |
|
from fastcore.all import * |
|
import wandb |
|
import pickle |
|
import pandas as pd |
|
import numpy as np |
|
|
|
import torch.nn as nn |
|
from fastai.basics import * |
|
|
|
|
|
def generate_TS_df(rows, cols):
    """Generate a DataFrame containing a random multivariate time series.

    Each column represents a variable and each row a time point (sample).
    The timestamp is in the index of the dataframe, starting at the current
    time with an even spacing of 1 second between samples.

    Args:
        rows: number of time points (samples) to generate.
        cols: number of variables (columns).

    Returns:
        pd.DataFrame of shape (rows, cols) filled with standard-normal data
        and indexed by `rows` timestamps spaced 1 second apart.
    """
    # The previous np.arange(start, start + (rows-1)s, 1s) excluded the stop
    # value and therefore produced only rows-1 samples; date_range with
    # `periods` yields exactly `rows` timestamps.
    index = pd.date_range(start=pd.Timestamp.now(), periods=rows,
                          freq=pd.Timedelta(1, 'seconds'))
    data = np.random.randn(len(index), cols)
    return pd.DataFrame(data, index=index)
|
|
|
|
|
def normalize_columns(df:pd.DataFrame):
    "Normalize columns from `df` to have 0 mean and 1 standard deviation"
    # The small epsilon keeps the division finite for constant columns
    # (whose standard deviation is 0).
    centered = df - df.mean()
    return centered / (df.std() + 1e-7)
|
|
|
|
|
def remove_constant_columns(df:pd.DataFrame):
    "Drop every column of `df` whose values are all identical to its first row."
    varying = (df != df.iloc[0]).any(axis=0)
    return df.loc[:, varying]
|
|
|
|
|
class ReferenceArtifact(wandb.Artifact):
    """Artifact holding a single `file://` reference to a pickled Python object.

    The object passed to the constructor is pickled, named after the hash of
    its pickle bytes, stored in `folder` (defaults to
    `~/data/wandb_artifacts/`) and added to the artifact as a reference.
    The hash and the object's class are recorded in the artifact metadata so
    the object can be recovered later (see `to_obj`).

    Note: the descriptive string previously sat *after* `default_storage_path`,
    so it was a stray expression statement rather than the class docstring.
    """
    default_storage_path = Path('data/wandb_artifacts/')

    @delegates(wandb.Artifact.__init__)
    def __init__(self, obj, name, type='object', folder=None, **kwargs):
        super().__init__(type=type, name=name, **kwargs)
        # NOTE(review): hash() on bytes is salted per interpreter session
        # (PYTHONHASHSEED), so the same object may get a different file name
        # across runs; consumers rely on the hash stored in metadata, not on
        # recomputing it.
        hash_code = str(hash(pickle.dumps(obj)))
        folder = Path(ifnone(folder, Path.home()/self.default_storage_path))
        # Fix: open() failed with FileNotFoundError when the storage folder
        # did not exist yet.
        folder.mkdir(parents=True, exist_ok=True)
        with open(f'{folder}/{hash_code}', 'wb') as f:
            pickle.dump(obj, f)
        self.add_reference(f'file://{folder}/{hash_code}')
        if self.metadata is None:
            self.metadata = dict()
        self.metadata['ref'] = dict()
        self.metadata['ref']['hash'] = hash_code
        self.metadata['ref']['type'] = str(obj.__class__)
|
|
|
|
|
@patch
def to_obj(self:wandb.apis.public.Artifact):
    """Download the files of a saved ReferenceArtifact and get the referenced object. The artifact must \
    come from a call to `run.use_artifact` with a proper wandb run."""
    ref_meta = self.metadata.get('ref')
    if ref_meta is None:
        print(f'ERROR:{self} does not come from a saved ReferenceArtifact')
        return None
    # Prefer the copy in the local storage folder; fall back to downloading
    # the artifact's files when it is not present.
    local_path = ReferenceArtifact.default_storage_path/ref_meta['hash']
    if not local_path.exists():
        local_path = Path(self.download()).ls()[0]
    with open(local_path, 'rb') as f:
        return pickle.load(f)
|
|
|
|
|
import torch.nn as nn |
|
class PrintLayer(nn.Module):
    """Identity module that prints the shape of its input — a debugging aid
    to be inserted between layers of an `nn.Sequential`."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Log the tensor shape, then pass the input through unchanged.
        print(x.shape)
        return x
|
|
|
|
|
@patch
def export_and_get(self:Learner, keep_exported_file=False):
    """
    Export the learner into an auxiliary file, load it and return it back.

    Fix: fastai's `Learner.export(fname)` writes to `self.path/fname`, but the
    previous code loaded and unlinked `'aux.pkl'` relative to the current
    working directory, failing whenever `self.path` differed from the CWD.
    The auxiliary file is now resolved against `self.path` everywhere.

    Args:
        keep_exported_file: when False (default), the auxiliary file is
            deleted after the learner has been loaded back.

    Returns:
        The learner re-loaded from the exported file.
    """
    aux_path = Path(self.path)/'aux.pkl'
    self.export(fname='aux.pkl')
    aux_learn = load_learner(aux_path)
    if not keep_exported_file: aux_path.unlink()
    return aux_learn
|
|
|
|
|
def get_wandb_artifacts(project_path, type=None, name=None, last_version=True):
    """
    Get the artifacts logged in a wandb project.
    Input:
        - `project_path` (str): entity/project_name
        - `type` (str): whether to return only one type of artifacts
        - `name` (str): Leave none to have all artifact names
        - `last_version`: whether to return only the last version of each artifact or not

    Output: List of artifacts
    """
    public_api = wandb.Api()
    # Restrict the search to a single artifact type when requested;
    # otherwise iterate over every type in the project.
    if type is not None:
        types = [public_api.artifact_type(type, project_path)]
    else:
        types = public_api.artifact_types(project_path)

    res = L()
    for kind in types:
        for collection in kind.collections():
            # `name is None` means "all collections"; otherwise filter by
            # exact collection name.
            if name is None or name == collection.name:
                versions = public_api.artifact_versions(
                    kind.type,
                    "/".join([kind.entity, kind.project, collection.name]),
                    per_page=1,
                )
                # NOTE(review): presumably versions are ordered newest-first,
                # so next(versions) yields the latest one; per_page=1 keeps
                # that first fetch small but makes iterating *all* versions
                # one request per item — confirm against the wandb API.
                # next(versions) will raise StopIteration if a collection has
                # no versions — TODO confirm whether that can happen here.
                if last_version: res += next(versions)
                else: res += L(versions)
    return list(res)
|
|
|
|
|
def get_pickle_artifact(filename):
    "Load and return the object stored in the pickle file `filename`."
    # NOTE: pickle.load can execute arbitrary code — only use on trusted files.
    with open(filename, "rb") as file:
        return pickle.load(file)