"""Gradio demo that classifies Swedish tweets with a pickled learner.

The trained learner is unpickled from ``learner.bin`` (presumably a modAL
``ActiveLearner`` wrapping a skorch/BERT pipeline -- confirm) and exposed
through a single text box.  The output is '🐕' when the learner predicts
label 1 and '🐈' otherwise.
"""
import io
import pickle

import gradio as gr
import numpy as np
import torch
from modAL.models import ActiveLearner
from modAL.uncertainty import uncertainty_sampling
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.pipeline import Pipeline
from skorch import NeuralNetClassifier
from skorch.callbacks import EarlyStopping
from skorch.callbacks import LRScheduler, ProgressBar
from skorch.hf import HuggingfacePretrainedTokenizer
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from transformers import AutoModelForSequenceClassification

# NOTE(review): several imported names above are not used directly in this
# script; they may be kept so the libraries the pickled learner depends on are
# importable.  Confirm before pruning.


class BertModule(nn.Module):
    """BERT sequence classifier exposed as a plain ``nn.Module`` (skorch convention).

    Presumably referenced by name from the pickled learner, so it must stay
    defined at module level under this exact name for unpickling to work.

    Args:
        name: Hugging Face checkpoint name passed to ``from_pretrained``.
        num_labels: Number of output classes of the classification head.
    """

    def __init__(self, name, num_labels):
        super().__init__()
        self.name = name
        self.num_labels = num_labels
        self.reset_weights()

    def reset_weights(self):
        """Load checkpoint ``self.name`` with ``self.num_labels`` output classes."""
        self.bert = AutoModelForSequenceClassification.from_pretrained(
            self.name, num_labels=self.num_labels
        )

    def forward(self, **kwargs):
        """Run the wrapped model and return only its logits.

        ``kwargs`` are forwarded unchanged to the transformers model
        (presumably tokenizer outputs such as ``input_ids``/``attention_mask``).
        """
        pred = self.bert(**kwargs)
        return pred.logits


# Training hyper-parameters.  The LR schedule below is presumably pickled
# together with the learner, so these must keep the values used in training.
MAX_EPOCHS = 5
BATCH_SIZE = 12
# 584 is presumably the size of the training set (TODO confirm); the
# expression counts 584 // BATCH_SIZE + 1 batches per epoch.
num_training_steps = MAX_EPOCHS * (584 // BATCH_SIZE + 1)


def lr_schedule(current_step):
    """Return the learning-rate multiplier for ``current_step``.

    The multiplier decays linearly from 1.0 at step 0 towards 0 at
    ``num_training_steps``.

    NOTE(review): the assertion fails once ``current_step`` reaches
    ``num_training_steps`` (the factor becomes 0).  This function is kept
    unchanged here because it is presumably referenced by the pickled learner.
    """
    factor = float(num_training_steps - current_step) / float(max(1, num_training_steps))
    assert factor > 0
    return factor


class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that forces pickled torch tensors/storages onto the CPU.

    It intercepts ``torch.storage._load_from_bytes``, the callable pickle uses
    to rebuild torch storages.  That call is replaced with one that passes
    ``map_location='cpu'``, so a learner saved with GPU tensors can still be
    loaded on a CPU-only host.
    """

    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            # NOTE(review): torch >= 2.6 defaults torch.load to
            # weights_only=True; if loading fails there, pass
            # weights_only=False explicitly (the kwarg does not exist on very
            # old torch versions).
            return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
        return super().find_class(module, name)


# SECURITY: unpickling executes arbitrary code.  Only deploy a learner.bin
# that you produced yourself; never load one obtained from an untrusted source.
with open('learner.bin', 'rb') as f:
    learner = CPU_Unpickler(f).load()


def dg_predict(tweet: str) -> str:
    """Classify a single tweet.

    Args:
        tweet: Raw text entered in the Gradio text box.

    Returns:
        '🐕' if the learner predicts label 1 for the tweet, otherwise '🐈'.
    """
    return '🐕' if learner.predict([tweet])[0] == 1 else '🐈'


# Example inputs shown in the UI.  They were previously stored as mojibake
# (UTF-8 bytes decoded with a single-byte code page, e.g. "mĂ„ste" for
# "måste").  The model would have received that garbled text as input.
examples = [
    "Såå kulturberikande med explosioner utanför porten 🤗 #tackmagda",
    "Att WEF ska kunna styra över svenskar är fascistiskt",
    (
        "Det råder inget tvivel om att svensk kultur inte är svensk längre, "
        "utan dessa MENA kommer hit och lever på bidrag och skjuter i våra förorter"
    ),
    "Kriminella måste skickas hem där de kommer ifrån!! Återvandring!!",
    (
        "En global konspiration från Bryssel för att försvaga vår suveränitet, "
        "tills vi alla måste buga för Soros"
    ),
]

iface = gr.Interface(fn=dg_predict, inputs="text", outputs="text", examples=examples)
iface.launch()