chitlchow's picture
Update app.py
22a3411
import gradio as gr
import torch
from src.model import BertClassifier, RobertaClassifier
from transformers import BertTokenizer
from datetime import datetime
# All inference in this app runs on the CPU.
device = torch.device('cpu')
model_name = 'bert-base-uncased'

# Build the classifier and restore the fine-tuned weights. The second argument
# (0.5) is passed straight to BertClassifier; presumably a dropout rate —
# TODO confirm against src/model.py.
model = BertClassifier(model_name, 0.5)
model.to(device)
# map_location remaps tensors saved on another device (e.g. a GPU) onto `device`.
model.load_state_dict(torch.load('models/bert-all-data.pth', map_location=device))
# nn.Module starts in training mode; switch to eval mode so train-only layers
# such as dropout are disabled and predictions are deterministic.
model.eval()

# Load the tokenizer from the same pretrained checkpoint name given to the model.
tokenizer = BertTokenizer.from_pretrained(model_name)
def ai_text_classifier(text: str) -> dict:
    """Classify *text* as AI-generated vs. other and return label scores.

    Args:
        text: Raw input text (supplied by the Gradio textbox).

    Returns:
        A dict in the shape a Gradio ``"label"`` output accepts: ``"AI"`` maps
        to the scalar the model produces and ``"Others"`` to ``1 -`` that value.
    """
    # Tokenize into a fixed 512-token window; longer inputs are truncated, so
    # only the leading part of the text influences the score.
    tokens = tokenizer(
        text,
        return_tensors='pt',
        max_length=512,
        padding='max_length',
        truncation=True,
    ).to(device)
    # Inference only: disable autograd so no computation graph is built and
    # retained for each request.
    with torch.no_grad():
        # .item() requires a one-element tensor; presumably the model emits a
        # single probability per example — TODO confirm in src/model.py.
        prob = model(tokens['input_ids'], tokens['attention_mask']).item()
    return {
        "AI": prob,
        'Others': 1 - prob
    }
# Expose the classifier through a minimal web UI: a text box in, label scores out.
demo = gr.Interface(
    ai_text_classifier,
    inputs="text",
    outputs="label",
)
# Start the Gradio server (runs when this script is executed).
demo.launch()