sarahyurick commited on
Commit
68fd8f2
1 Parent(s): 7556b38

Add PyTorchModelHubMixin

Browse files
Files changed (1) hide show
  1. README.md +5 -11
README.md CHANGED
@@ -81,9 +81,8 @@ To use this AEGIS classifiers, you must get access to Llama Guard on Hugging Fac
81
  ```python
82
  import torch
83
  import torch.nn.functional as F
84
- from huggingface_hub import hf_hub_download
85
  from peft import PeftModel
86
- from safetensors.torch import load_file
87
  from torch.nn import Dropout, Linear
88
  from transformers import AutoModelForCausalLM, AutoTokenizer
89
 
@@ -103,8 +102,8 @@ tokenizer = AutoTokenizer.from_pretrained(
103
  )
104
  tokenizer.pad_token = tokenizer.unk_token
105
 
106
- class InstructionDataGuardNet(torch.nn.Module):
107
- def __init__(self, input_dim, dropout=0.7):
108
  super().__init__()
109
  self.input_dim = input_dim
110
  self.dropout = Dropout(dropout)
@@ -129,13 +128,8 @@ class InstructionDataGuardNet(torch.nn.Module):
129
  return x
130
 
131
  # Load Instruction-Data-Guard classifier
132
- instruction_data_guard = InstructionDataGuardNet(4096).to(device)
133
- weights_path = hf_hub_download(
134
- repo_id="nvidia/instruction-data-guard",
135
- filename="model.safetensors",
136
- )
137
- state_dict = load_file(weights_path)
138
- instruction_data_guard.load_state_dict(state_dict)
139
  instruction_data_guard = instruction_data_guard.eval()
140
 
141
  # Function to compute results
 
81
  ```python
82
  import torch
83
  import torch.nn.functional as F
84
+ from huggingface_hub import PyTorchModelHubMixin
85
  from peft import PeftModel
 
86
  from torch.nn import Dropout, Linear
87
  from transformers import AutoModelForCausalLM, AutoTokenizer
88
 
 
102
  )
103
  tokenizer.pad_token = tokenizer.unk_token
104
 
105
+ class InstructionDataGuardNet(torch.nn.Module, PyTorchModelHubMixin):
106
+ def __init__(self, input_dim=4096, dropout=0.7):
107
  super().__init__()
108
  self.input_dim = input_dim
109
  self.dropout = Dropout(dropout)
 
128
  return x
129
 
130
  # Load Instruction-Data-Guard classifier
131
+ instruction_data_guard = InstructionDataGuardNet.from_pretrained("nvidia/instruction-data-guard")
132
+ instruction_data_guard = instruction_data_guard.to(device)
 
 
 
 
 
133
  instruction_data_guard = instruction_data_guard.eval()
134
 
135
  # Function to compute results