|
--- |
|
license: mit |
|
language: |
|
- ru |
|
widget: |
|
- text: 'привет' |
|
example_title: example_1 |
|
- text: 'тебя как звать' |
|
example_title: example_2 |
|
- text: 'как приготовить рагу' |
|
example_title: example_3 |
|
- text: 'в чем смысл жизни' |
|
example_title: example_4 |
|
- text: 'у меня кот сбежал' |
|
example_title: example_5 |
|
- text: 'что такое спидометр' |
|
example_title: example_6 |
|
- text: 'меня артур зовут' |
|
example_title: example_7 |
|
--- |
|
# TeraSpace/replica_classification |
|
Сделано на основе [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) |
|
|
|
0. dialog - реагирует на диалоговые реплики. Например, "привет" |
|
1. trouble - реагирует на реплики, где пользователь рассказывает о своих проблемах. Например, "у меня болит зуб, мне проткнули колесо" |
|
2. question - реагирует на вопрос не относящийся к диалогу, например: "когда родился пушкин" или "когда я стану миллионером" |
|
3. about_user - реагирует, когда пользователь говорит о себе. Например, "меня зовут андрей" |
|
4. about_model - реагирует на вопросы о личности ассистента. Например, "как тебя зовут, ты кто такая" |
|
5. instruct - реагирует на вопросы, ответ на которые представляет собой инструкцию. Например, "как установить windows, как приготовить борщ" |
|
|
|
# Использование |
|
```python |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import torch |
|
|
|
del_symbs = ["?","!",".",","] |
|
classes = ["dialog","trouble","question","about_user","about_model","instruct"] |
|
|
|
device = torch.device("cuda") |
|
model_name = 'TeraSpace/replica_classification' |
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels = len(classes)).to(device) |
|
|
|
while True: |
|
text = input("=>").lower() |
|
for del_symb in del_symbs: |
|
text = text.replace(del_symb,"") |
|
|
|
inputs = tokenizer(text, truncation=True, max_length=512, padding='max_length', |
|
return_tensors='pt').to(device) |
|
with torch.no_grad(): |
|
logits = model(**inputs).logits |
|
probas = list(torch.sigmoid(logits)[0].cpu().detach().numpy()) |
|
|
|
out = classes[probas.index(max(probas))] |
|
print(out) |
|
``` |