Update README for PyTorchModelHubMixin

#4
Files changed (1) hide show
  1. README.md +4 -6
README.md CHANGED
@@ -45,13 +45,13 @@ A simple example to get started:
45
 
46
  ```python
47
  import torch
 
48
  import pandas as pd
49
  from gluonts.dataset.pandas import PandasDataset
50
  from gluonts.dataset.split import split
51
- from huggingface_hub import hf_hub_download
52
 
53
  from uni2ts.eval_util.plot import plot_single
54
- from uni2ts.model.moirai import MoiraiForecast
55
 
56
 
57
  SIZE = "small" # model size: choose from {'small', 'base', 'large'}
@@ -85,9 +85,7 @@ test_data = test_template.generate_instances(
85
 
86
  # Prepare pre-trained model by downloading model weights from huggingface hub
87
  model = MoiraiForecast.load_from_checkpoint(
88
- checkpoint_path=hf_hub_download(
89
- repo_id=f"Salesforce/moirai-R-{SIZE}", filename="model.ckpt"
90
- ),
91
  prediction_length=PDT,
92
  context_length=CTX,
93
  patch_size=PSZ,
@@ -95,7 +93,6 @@ model = MoiraiForecast.load_from_checkpoint(
95
  target_dim=1,
96
  feat_dynamic_real_dim=ds.num_feat_dynamic_real,
97
  past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
98
- map_location="cuda:0" if torch.cuda.is_available() else "cpu",
99
  )
100
 
101
  predictor = model.create_predictor(batch_size=BSZ)
@@ -117,6 +114,7 @@ plot_single(
117
  name="pred",
118
  show_label=True,
119
  )
 
120
  ```
121
 
122
  ## The Moirai Family
 
45
 
46
  ```python
47
  import torch
48
+ import matplotlib.pyplot as plt
49
  import pandas as pd
50
  from gluonts.dataset.pandas import PandasDataset
51
  from gluonts.dataset.split import split
 
52
 
53
  from uni2ts.eval_util.plot import plot_single
54
+ from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
55
 
56
 
57
  SIZE = "small" # model size: choose from {'small', 'base', 'large'}
 
85
 
86
  # Prepare pre-trained model by downloading model weights from huggingface hub
87
  model = MoiraiForecast.load_from_checkpoint(
88
+ module=MoiraiModule.from_pretrained(f"Salesforce/moirai-1.0-R-{SIZE}"),
 
 
89
  prediction_length=PDT,
90
  context_length=CTX,
91
  patch_size=PSZ,
 
93
  target_dim=1,
94
  feat_dynamic_real_dim=ds.num_feat_dynamic_real,
95
  past_feat_dynamic_real_dim=ds.num_past_feat_dynamic_real,
 
96
  )
97
 
98
  predictor = model.create_predictor(batch_size=BSZ)
 
114
  name="pred",
115
  show_label=True,
116
  )
117
+ plt.show()
118
  ```
119
 
120
  ## The Moirai Family