sarahyurick
commited on
Commit
•
68fd8f2
1
Parent(s):
7556b38
Add PyTorchModelHubMixin
Browse files
README.md
CHANGED
@@ -81,9 +81,8 @@ To use this AEGIS classifiers, you must get access to Llama Guard on Hugging Fac
|
|
81 |
```python
|
82 |
import torch
|
83 |
import torch.nn.functional as F
|
84 |
-
from huggingface_hub import
|
85 |
from peft import PeftModel
|
86 |
-
from safetensors.torch import load_file
|
87 |
from torch.nn import Dropout, Linear
|
88 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
89 |
|
@@ -103,8 +102,8 @@ tokenizer = AutoTokenizer.from_pretrained(
|
|
103 |
)
|
104 |
tokenizer.pad_token = tokenizer.unk_token
|
105 |
|
106 |
-
class InstructionDataGuardNet(torch.nn.Module):
|
107 |
-
def __init__(self, input_dim, dropout=0.7):
|
108 |
super().__init__()
|
109 |
self.input_dim = input_dim
|
110 |
self.dropout = Dropout(dropout)
|
@@ -129,13 +128,8 @@ class InstructionDataGuardNet(torch.nn.Module):
|
|
129 |
return x
|
130 |
|
131 |
# Load Instruction-Data-Guard classifier
|
132 |
-
instruction_data_guard = InstructionDataGuardNet
|
133 |
-
|
134 |
-
repo_id="nvidia/instruction-data-guard",
|
135 |
-
filename="model.safetensors",
|
136 |
-
)
|
137 |
-
state_dict = load_file(weights_path)
|
138 |
-
instruction_data_guard.load_state_dict(state_dict)
|
139 |
instruction_data_guard = instruction_data_guard.eval()
|
140 |
|
141 |
# Function to compute results
|
|
|
81 |
```python
|
82 |
import torch
|
83 |
import torch.nn.functional as F
|
84 |
+
from huggingface_hub import PyTorchModelHubMixin
|
85 |
from peft import PeftModel
|
|
|
86 |
from torch.nn import Dropout, Linear
|
87 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
88 |
|
|
|
102 |
)
|
103 |
tokenizer.pad_token = tokenizer.unk_token
|
104 |
|
105 |
+
class InstructionDataGuardNet(torch.nn.Module, PyTorchModelHubMixin):
|
106 |
+
def __init__(self, input_dim=4096, dropout=0.7):
|
107 |
super().__init__()
|
108 |
self.input_dim = input_dim
|
109 |
self.dropout = Dropout(dropout)
|
|
|
128 |
return x
|
129 |
|
130 |
# Load Instruction-Data-Guard classifier
|
131 |
+
instruction_data_guard = InstructionDataGuardNet.from_pretrained("nvidia/instruction-data-guard")
|
132 |
+
instruction_data_guard = instruction_data_guard.to(device)
|
|
|
|
|
|
|
|
|
|
|
133 |
instruction_data_guard = instruction_data_guard.eval()
|
134 |
|
135 |
# Function to compute results
|