---
language:
- zh
tags:
- pytorch
- zh
- Conversational
---

[roberta-zh](https://github.com/brightmart/roberta_zh) fine-tuned on human-annotated self-chat data from conversational models. It performs binary (2-class) classification of whether a response in a multi-turn dialogue is sensible.

Usage example:

NOTE: the model should be applied to data with a distribution similar to its training data.

```python
import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('thu-coai/roberta-zh-sensible')
model = BertForSequenceClassification.from_pretrained('thu-coai/roberta-zh-sensible', num_labels=2)
model.eval()  # switch to inference mode (disables dropout)

# Each context holds the earlier dialogue turns, separated by "\t".
# Here: "What is your favorite obscure classical poem?" followed by a line of classical poetry.
context = [
    "你大爱的冷门古诗词是什么?\t一枝红艳露凝香,云雨巫山枉断肠",
    "你大爱的冷门古诗词是什么?\t一枝红艳露凝香,云雨巫山枉断肠",
]

# Candidate responses, one per context:
# "My favorite is 'A Moonlit Night on the Spring River'" and "I like it a lot too".
response = [
    "最爱春江花月夜",
    "我也很喜欢",
]

# Encode each (context, response) as a sentence pair.
model_input = tokenizer(context, response, return_tensors='pt', padding=True)
with torch.no_grad():
    model_output = model(**model_input, return_dict=True)
    logits = model_output.logits
    preds_all = torch.argmax(logits, dim=-1).cpu()

print(preds_all)  # 1 for a sensible response, else 0
```
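
For a longer conversation, one way to build the inputs is to score every turn against all of the turns before it. The sketch below assumes, based only on the example above, that earlier turns are joined with `\t` into a single context string; `build_context` and `make_pairs` are hypothetical helpers, not part of this model or of `transformers`.

```python
from typing import List, Sequence, Tuple


def build_context(history: Sequence[str], sep: str = "\t") -> str:
    """Join earlier dialogue turns into one context string.

    ASSUMPTION: the "\t" separator is inferred from the usage example,
    where a two-turn history is written as "turn1\tturn2".
    """
    return sep.join(history)


def make_pairs(dialogue: Sequence[str]) -> List[Tuple[str, str]]:
    """Turn a dialogue [u1, u2, ..., un] into (context, response) pairs.

    Every turn after the first becomes a response, paired with all turns
    that precede it as its context (hypothetical helper).
    """
    return [(build_context(dialogue[:i]), dialogue[i]) for i in range(1, len(dialogue))]
```

The resulting pairs can be unzipped into `context` and `response` lists and passed to `tokenizer(context, response, return_tensors='pt', padding=True)` as in the usage example. If a score is preferred over a hard label, `torch.softmax(logits, dim=-1)[:, 1]` gives the model's softmax probability for label 1 (sensible).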