Spaces:
Runtime error
Runtime error
import gradio as gr | |
import pickle | |
import numpy as np | |
import torch | |
import io | |
from torch import nn | |
from transformers import AutoModelForSequenceClassification | |
from sklearn.pipeline import Pipeline | |
from skorch import NeuralNetClassifier | |
from skorch.callbacks import LRScheduler, ProgressBar | |
from skorch.hf import HuggingfacePretrainedTokenizer | |
from torch.optim.lr_scheduler import LambdaLR | |
from skorch.callbacks import EarlyStopping | |
from sklearn.metrics import precision_recall_fscore_support | |
from sklearn.metrics import balanced_accuracy_score | |
from modAL.models import ActiveLearner | |
from modAL.uncertainty import uncertainty_sampling | |
class BertModule(nn.Module): | |
""" BERT model according to Skorch convention """ | |
def __init__(self, name, num_labels): | |
super().__init__() | |
self.name = name | |
self.num_labels = num_labels | |
self.reset_weights() | |
def reset_weights(self): | |
self.bert = AutoModelForSequenceClassification.from_pretrained( | |
self.name, num_labels=self.num_labels | |
) | |
def forward(self, **kwargs): | |
pred = self.bert(**kwargs) | |
return pred.logits | |
MAX_EPOCHS = 5 | |
BATCH_SIZE = 12 | |
num_training_steps = MAX_EPOCHS * (584 // BATCH_SIZE + 1) | |
def lr_schedule(current_step): | |
factor = float(num_training_steps - current_step) / float(max(1, num_training_steps)) | |
assert factor > 0 | |
return factor | |
class CPU_Unpickler(pickle.Unpickler): | |
def find_class(self, module, name): | |
if module == 'torch.storage' and name == '_load_from_bytes': | |
return lambda b: torch.load(io.BytesIO(b), map_location='cpu') | |
else: return super().find_class(module, name) | |
with open('learner.bin', 'rb') as f: | |
learner = CPU_Unpickler(f).load() | |
def dg_predict(tweet): | |
return '🐕' if learner.predict([tweet])[0] == 1 else '🐈' | |
examples = [ | |
"Såå kulturberikande med explosioner utanför porten 🤗 #tackmagda", | |
"Att WEF ska kunna styra över svenskar är fascistiskt", | |
"Det råder inget tvivel om att svensk kultur inte är svensk längre, utan dessa MENA kommer hit och lever på bidrag och skjuter i våra förorter", | |
"Kriminella måste skickas hem där de kommer ifrån!! Återvandring!!", | |
"En global konspiration från Bryssel för att försvaga vår suveränitet, tills vi alla måste buga för Soros" | |
] | |
iface = gr.Interface(fn=dg_predict, inputs="text", outputs="text", examples=examples) | |
iface.launch() |