---
license: mit
language:
- ru
---
A replica (utterance) classifier for Russian, fine-tuned from `xlm-roberta-base`. It assigns an input utterance to one of six classes: `dialog`, `trouble`, `quest`, `about_user`, `about_model`, `instruction`.
## Usage
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Punctuation stripped from the input before tokenization
del_symbs = ["?", "!", ".", ","]
classes = ["dialog", "trouble", "quest", "about_user", "about_model", "instruction"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = 'TeraSpace/replica_classification'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(classes)).to(device)

while True:
    text = input("=>").lower()
    for del_symb in del_symbs:
        text = text.replace(del_symb, "")
    inputs = tokenizer(text, truncation=True, max_length=512, padding='max_length',
                       return_tensors='pt').to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Sigmoid over the logits; report the highest-scoring class
    probas = torch.sigmoid(logits)[0].cpu().numpy()
    print(classes[probas.argmax()])
```
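For one-off calls, the same checkpoint can also be loaded through the `transformers` pipeline API. This is a minimal sketch, not part of the original card: it assumes the label names stored in the model config correspond to the `classes` list above (if the config only carries generic `LABEL_i` names, map them back through `classes`).

```python
from transformers import pipeline

# Sketch only: assumes the checkpoint's config provides usable id2label names.
# function_to_apply="sigmoid" mirrors the manual torch.sigmoid call above.
clf = pipeline(
    "text-classification",
    model="TeraSpace/replica_classification",
    function_to_apply="sigmoid",
)

# Example call with an arbitrary Russian utterance
print(clf("как тебя зовут"))  # -> [{'label': ..., 'score': ...}]
```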