skogsgren's picture
Update app.py
1457416
import gradio as gr
import pickle
import numpy as np
import torch
import io
from torch import nn
from transformers import AutoModelForSequenceClassification
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler, ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from torch.optim.lr_scheduler import LambdaLR
from skorch.callbacks import EarlyStopping
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import balanced_accuracy_score
from modAL.models import ActiveLearner
from modAL.uncertainty import uncertainty_sampling
class BertModule(nn.Module):
""" BERT model according to Skorch convention """
def __init__(self, name, num_labels):
super().__init__()
self.name = name
self.num_labels = num_labels
self.reset_weights()
def reset_weights(self):
self.bert = AutoModelForSequenceClassification.from_pretrained(
self.name, num_labels=self.num_labels
)
def forward(self, **kwargs):
pred = self.bert(**kwargs)
return pred.logits
MAX_EPOCHS = 5
BATCH_SIZE = 12
num_training_steps = MAX_EPOCHS * (584 // BATCH_SIZE + 1)
def lr_schedule(current_step):
factor = float(num_training_steps - current_step) / float(max(1, num_training_steps))
assert factor > 0
return factor
class CPU_Unpickler(pickle.Unpickler):
def find_class(self, module, name):
if module == 'torch.storage' and name == '_load_from_bytes':
return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
else: return super().find_class(module, name)
with open('learner.bin', 'rb') as f:
learner = CPU_Unpickler(f).load()
def dg_predict(tweet):
return '🐕' if learner.predict([tweet])[0] == 1 else '🐈'
examples = [
"Såå kulturberikande med explosioner utanför porten 🤗 #tackmagda",
"Att WEF ska kunna styra över svenskar är fascistiskt",
"Det råder inget tvivel om att svensk kultur inte är svensk längre, utan dessa MENA kommer hit och lever på bidrag och skjuter i våra förorter",
"Kriminella måste skickas hem där de kommer ifrån!! Återvandring!!",
"En global konspiration från Bryssel för att försvaga vår suveränitet, tills vi alla måste buga för Soros"
]
iface = gr.Interface(fn=dg_predict, inputs="text", outputs="text", examples=examples)
iface.launch()