Update models/attn_model.py #3
by johnson906 - opened

models/attn_model.py  +6 -2  CHANGED
@@ -1,3 +1,4 @@
+import os
 import torch
 from .model import Model
 from .utils import sample_token, get_last_attn
@@ -5,6 +6,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch.nn.functional as F
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
+token = os.getenv("HF_TOKEN")
 
 class AttentionModel(Model):
     def __init__(self, config):
@@ -12,12 +14,14 @@ class AttentionModel(Model):
         self.name = config["model_info"]["name"]
         self.max_output_tokens = int(config["params"]["max_output_tokens"])
         model_id = config["model_info"]["model_id"]
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id,
+                                                       use_auth_token=token)
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             torch_dtype=torch.bfloat16,
             device_map=device,
-            attn_implementation="eager"
+            attn_implementation="eager",
+            use_auth_token=token
         ).eval()
         if config["params"]["important_heads"] == "all":
             attn_size = self.get_map_dim()
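In short, the change reads a Hugging Face access token from the HF_TOKEN environment variable and passes it to both from_pretrained calls, so gated or private model repos can be loaded inside the Space. Below is a minimal usage sketch; the config values and the models.attn_model import path are illustrative assumptions, not taken from this PR. (Recent transformers releases also accept token= in place of the now-deprecated use_auth_token=.)

# Minimal usage sketch, assuming the Space exposes AttentionModel as
# models.attn_model and that HF_TOKEN is already exported in the environment.
import os

from models.attn_model import AttentionModel  # assumed import path

# Hypothetical config; only the keys read in __init__ are shown.
config = {
    "model_info": {
        "name": "attn-demo",                       # illustrative value
        "model_id": "org/private-or-gated-model",  # placeholder repo id
    },
    "params": {
        "max_output_tokens": "256",
        "important_heads": "all",
    },
}

assert os.getenv("HF_TOKEN"), "export HF_TOKEN before loading gated models"
model = AttentionModel(config)  # tokenizer and weights now authenticate via HF_TOKEN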