Miaoran000
commited on
Commit
•
6f7b340
1
Parent(s):
ade58fc
update for pipeline
Browse files- config.json +2 -1
- modeling_hhem_v2.py +9 -2
config.json
CHANGED
@@ -8,5 +8,6 @@
|
|
8 |
},
|
9 |
"model_type": "HHEMv2Config",
|
10 |
"torch_dtype": "float32",
|
11 |
-
"transformers_version": "4.39.3"
|
|
|
12 |
}
|
|
|
8 |
},
|
9 |
"model_type": "HHEMv2Config",
|
10 |
"torch_dtype": "float32",
|
11 |
+
"transformers_version": "4.39.3",
|
12 |
+
"id2label": {"0": "hallucinated", "1": "consistent"}
|
13 |
}
|
modeling_hhem_v2.py
CHANGED
@@ -45,8 +45,15 @@ class HHEMv2ForSequenceClassification(PreTrainedModel):
|
|
45 |
# combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
|
46 |
# self.t5 = combined_model
|
47 |
|
48 |
-
def forward(self, **kwargs):
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def predict(self, text_pairs):
|
52 |
tokenizer = self.tokenzier
|
|
|
45 |
# combined_model = PeftModel.from_pretrained(base_model, checkpoint, is_trainable=False)
|
46 |
# self.t5 = combined_model
|
47 |
|
48 |
+
def forward(self, **kwargs): # To cope with `text-classiication` pipeline
|
49 |
+
self.t5.eval()
|
50 |
+
with torch.no_grad():
|
51 |
+
outputs = self.t5(**kwargs)
|
52 |
+
logits = outputs.logits
|
53 |
+
logits = logits[:, 0, :]
|
54 |
+
outputs.logits = logits
|
55 |
+
return outputs
|
56 |
+
# return self.t5(**kwargs)
|
57 |
|
58 |
def predict(self, text_pairs):
|
59 |
tokenizer = self.tokenzier
|