romnatall
deploy
e949db1
raw
history blame
921 Bytes
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from catboost import CatBoostClassifier
import torch.nn as nn
import streamlit as st
@st.cache_resource
def load_model():
    """Load the CatBoost classifier plus the rubert-tiny tokenizer and encoder.

    Decorated with ``st.cache_resource`` so Streamlit reuses the returned
    objects across script reruns instead of reloading them each time.

    Returns:
        tuple: ``(catboost_model, tokenizer, model)`` where

        * ``catboost_model`` is a ``CatBoostClassifier`` restored from
          ``pages/anti_toxic/dont_be_toxic.pt``;
        * ``tokenizer`` / ``model`` come from the
          ``cointegrated/rubert-tiny-toxicity`` checkpoint, with the
          classification head and its preceding dropout replaced by identity
          modules. The forward pass therefore returns the features that would
          have fed the head (presumably the pooled [CLS] vector — confirm for
          this architecture) rather than toxicity logits.
    """
    # Constructor arguments are kept from the original code; the saved model
    # file is what defines the classifier that is actually used.
    catboost_model = CatBoostClassifier(random_seed=42, eval_metric='Accuracy')
    catboost_model.load_model("pages/anti_toxic/dont_be_toxic.pt")

    model_checkpoint = 'cointegrated/rubert-tiny-toxicity'
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)

    # Strip the classification head: nn.Identity is an explicit pass-through,
    # equivalent to the previous nn.Dropout(0) but without implying any
    # regularization behaviour.
    model.classifier = nn.Identity()
    model.dropout = nn.Identity()

    # Inference only: make sure every remaining dropout layer in the encoder
    # is disabled, independent of how the model was loaded.
    model.eval()
    return catboost_model, tokenizer, model
# Load the models once when this module is imported (load_model is wrapped in
# st.cache_resource, so reruns reuse the cached objects). predict() reads
# these three module-level names directly.
catboost_model, tokenizer, model = load_model()
def predict(text):
    """Return CatBoost class probabilities for a piece of text.

    The text is encoded with the module-level ``tokenizer`` (truncated and
    padded, as PyTorch tensors). It is run through ``model`` without gradient
    tracking, and the output's first element (row 0 only) is turned into a
    plain list of floats and scored by ``catboost_model.predict_proba``.

    Args:
        text: input passed unchanged to the tokenizer. Only the first row of
            the resulting batch is scored.

    Returns:
        The result of ``catboost_model.predict_proba`` for that one feature
        vector.
    """
    encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**encoded)
    # outputs[0] is the first model output; keep only the first batch row.
    features = outputs[0].tolist()[0]
    return catboost_model.predict_proba(features)