TeraSpace's picture
Update README.md (#2)
9247da5
|
raw
history blame
2.75 kB
---
license: mit
language:
- ru
widget:
- text: 'привет'
example_title: example_1
- text: 'тебя как звать'
example_title: example_2
- text: 'как приготовить рагу'
example_title: example_3
- text: 'в чем смысл жизни'
example_title: example_4
- text: 'у меня кот сбежал'
example_title: example_5
- text: 'что такое спидометр'
example_title: example_6
- text: 'меня артур зовут'
example_title: example_7
---
# TeraSpace/replica_classification
Сделано на основе [xlm-roberta-base](https://huggingface.co/xlm-roberta-base)
0. dialog - реагирует на диалоговые реплики. Например, "привет"
1. trouble - реагирует на реплики, где пользователь рассказывает о своих проблемах. Например, "у меня болит зуб, мне проткнули колесо"
2. question - реагирует на вопрос не относящийся к диалогу, например: "когда родился пушкин" или "когда я стану миллионером"
3. about_user - реагирует, когда пользователь говорит о себе. Например, "меня зовут андрей"
4. about_model - реагирует на вопросы о личности ассистента. Например, "как тебя зовут, ты кто такая"
5. instruct - реагирует на вопросы, ответ на которые представляет собой инструкцию. Например, "как установить windows, как приготовить борщ"
# Использование
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
del_symbs = ["?","!",".",","]
classes = ["dialog","trouble","question","about_user","about_model","instruct"]
device = torch.device("cuda")
model_name = 'TeraSpace/replica_classification'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels = len(classes)).to(device)
while True:
text = input("=>").lower()
for del_symb in del_symbs:
text = text.replace(del_symb,"")
inputs = tokenizer(text, truncation=True, max_length=512, padding='max_length',
return_tensors='pt').to(device)
with torch.no_grad():
logits = model(**inputs).logits
probas = list(torch.sigmoid(logits)[0].cpu().detach().numpy())
out = classes[probas.index(max(probas))]
print(out)
```