import gradio as gr import pickle import numpy as np import torch import io from torch import nn from transformers import AutoModelForSequenceClassification from sklearn.pipeline import Pipeline from skorch import NeuralNetClassifier from skorch.callbacks import LRScheduler, ProgressBar from skorch.hf import HuggingfacePretrainedTokenizer from torch.optim.lr_scheduler import LambdaLR from skorch.callbacks import EarlyStopping from sklearn.metrics import precision_recall_fscore_support from sklearn.metrics import balanced_accuracy_score from modAL.models import ActiveLearner from modAL.uncertainty import uncertainty_sampling class BertModule(nn.Module): """ BERT model according to Skorch convention """ def __init__(self, name, num_labels): super().__init__() self.name = name self.num_labels = num_labels self.reset_weights() def reset_weights(self): self.bert = AutoModelForSequenceClassification.from_pretrained( self.name, num_labels=self.num_labels ) def forward(self, **kwargs): pred = self.bert(**kwargs) return pred.logits MAX_EPOCHS = 5 BATCH_SIZE = 12 num_training_steps = MAX_EPOCHS * (584 // BATCH_SIZE + 1) def lr_schedule(current_step): factor = float(num_training_steps - current_step) / float(max(1, num_training_steps)) assert factor > 0 return factor class CPU_Unpickler(pickle.Unpickler): def find_class(self, module, name): if module == 'torch.storage' and name == '_load_from_bytes': return lambda b: torch.load(io.BytesIO(b), map_location='cpu') else: return super().find_class(module, name) with open('learner.bin', 'rb') as f: learner = CPU_Unpickler(f).load() def dg_predict(tweet): return learner.predict(tweet) examples = [ "Såå kulturberikande med explosioner utanför porten 🤗 #tackmagda", "Att WEF ska kunna styra över svenskar är fascistiskt", "Det råder inget tvivel om att svensk kultur inte är svensk längre, utan dessa MENA kommer hit och lever på bidrag och skjuter i våra förorter", "Kriminella måste skickas hem där de kommer ifrån!! Återvandring!!", "En global konspiration från Bryssel för att försvaga vår suveränitet, tills vi alla måste buga för Soros" ] iface = gr.Interface(fn=dg_predict, inputs="text", outputs="text", examples=examples) iface.launch()