File size: 921 Bytes
26290c2
 
 
 
 
 
 
 
 
e949db1
26290c2
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from catboost import CatBoostClassifier
import torch.nn as nn
import streamlit as st

@st.cache_resource
def load_model():
    """Load the CatBoost classifier, the tokenizer and the transformer model.

    The function is wrapped in ``st.cache_resource``, so Streamlit keeps the
    returned objects and hands them back on later calls instead of
    re-running the body.

    Returns:
        A ``(catboost_model, tokenizer, model)`` tuple, in that order.
    """
    checkpoint_name = 'cointegrated/rubert-tiny-toxicity'

    # Classifier weights come from a file path relative to the working
    # directory of the running app.
    classifier = CatBoostClassifier(random_seed=42, eval_metric='Accuracy')
    classifier.load_model("pages/anti_toxic/dont_be_toxic.pt")

    text_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
    encoder = AutoModelForSequenceClassification.from_pretrained(checkpoint_name)

    # Both the classification head and the dropout layer are replaced with
    # zero-probability dropout modules, which pass their input through
    # unchanged.
    # NOTE(review): presumably this makes the model's first output the raw
    # feature vector consumed by CatBoost rather than class logits - confirm
    # against the checkpoint's architecture.
    for layer_name in ("classifier", "dropout"):
        setattr(encoder, layer_name, nn.Dropout(0))

    return classifier, text_tokenizer, encoder

# Module-level handles used by predict(); load_model() is wrapped in
# st.cache_resource, so this call runs whenever the module is executed.
catboost_model, tokenizer, model = load_model()
def predict(text):
    """Score ``text`` with the transformer model followed by CatBoost.

    The text is tokenized (with truncation and padding enabled) into PyTorch
    tensors, run through the module-level ``model`` with gradient tracking
    disabled, and the first row of the model's first output is passed to
    ``catboost_model.predict_proba``.

    Args:
        text: Input forwarded as-is to the module-level ``tokenizer``.
            Only the first row of the model output is used, so any further
            rows produced by a batched input are ignored.

    Returns:
        Whatever ``catboost_model.predict_proba`` returns for that row.
    """
    encoded = tokenizer(text, padding=True, truncation=True, return_tensors='pt')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**encoded)

    # First element of the model output, converted to nested Python lists;
    # keep just the first row as the feature vector.
    features = outputs[0].tolist()[0]
    return catboost_model.predict_proba(features)