Spaces:
Runtime error
Runtime error
liujch1998
commited on
Commit
β’
b921db9
1
Parent(s):
be2d0e3
bf16
Browse files
app.py
CHANGED
@@ -32,35 +32,35 @@ repo.git_pull()
|
|
32 |
class Interactive:
|
33 |
def __init__(self):
|
34 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
|
42 |
def run(self, statement):
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
# return {
|
53 |
-
# 'logit': logit.item(),
|
54 |
-
# 'logit_calibrated': logit_calibrated.item(),
|
55 |
-
# 'score': score.item(),
|
56 |
-
# 'score_calibrated': score_calibrated.item(),
|
57 |
-
# }
|
58 |
return {
|
59 |
-
'logit':
|
60 |
-
'logit_calibrated':
|
61 |
-
'score':
|
62 |
-
'score_calibrated':
|
63 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
interactive = Interactive()
|
66 |
|
|
|
32 |
class Interactive:
|
33 |
def __init__(self):
|
34 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD)
|
35 |
+
self.model = transformers.T5EncoderModel.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN_DOWNLOAD, low_cpu_mem_usage=True, device_map='auto', torch_dtype='auto')
|
36 |
+
self.linear = torch.nn.Linear(self.model.shared.embedding_dim, 1, dtype=self.model.dtype).to(device)
|
37 |
+
self.linear.weight = torch.nn.Parameter(self.model.shared.weight[32099, :].unsqueeze(0)) # (1, D)
|
38 |
+
self.linear.bias = torch.nn.Parameter(self.model.shared.weight[32098, 0].unsqueeze(0)) # (1)
|
39 |
+
self.model.eval()
|
40 |
+
self.t = self.model.shared.weight[32097, 0].item()
|
41 |
|
42 |
def run(self, statement):
|
43 |
+
input_ids = self.tokenizer.batch_encode_plus([statement], return_tensors='pt', padding='longest').input_ids.to(device)
|
44 |
+
with torch.no_grad():
|
45 |
+
output = self.model(input_ids)
|
46 |
+
last_hidden_state = output.last_hidden_state.to(device) # (B=1, L, D)
|
47 |
+
hidden = last_hidden_state[0, -1, :] # (D)
|
48 |
+
logit = self.linear(hidden).squeeze(-1) # ()
|
49 |
+
logit_calibrated = logit / self.t
|
50 |
+
score = logit.sigmoid()
|
51 |
+
score_calibrated = logit_calibrated.sigmoid()
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
return {
|
53 |
+
'logit': logit.item(),
|
54 |
+
'logit_calibrated': logit_calibrated.item(),
|
55 |
+
'score': score.item(),
|
56 |
+
'score_calibrated': score_calibrated.item(),
|
57 |
}
|
58 |
+
# return {
|
59 |
+
# 'logit': 0.0,
|
60 |
+
# 'logit_calibrated': 0.0,
|
61 |
+
# 'score': 0.5,
|
62 |
+
# 'score_calibrated': 0.5,
|
63 |
+
# }
|
64 |
|
65 |
interactive = Interactive()
|
66 |
|