# -*- coding: utf-8 -*-
import os
from pathlib import Path
from functools import lru_cache

import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing


def setup_tokenizer():
    # Download the distilroberta-base tokenizer and save its files
    # (vocab.json / merges.txt) locally so ByteLevelBPETokenizer can
    # load them from disk.
    os.makedirs("tokenizer", exist_ok=True)
    tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
    tokenizer.save_pretrained("tokenizer")


setup_tokenizer()


# from https://github.com/digantamisra98/Mish/blob/b5f006660ac0b4c46e2c6958ad0301d7f9c59651/Mish/Torch/mish.py
@torch.jit.script
def mish(input):
    return input * torch.tanh(F.softplus(input))


class Mish(nn.Module):
    def forward(self, input):
        return mish(input)


class EmoModel(nn.Module):
    def __init__(self, base_model, n_classes=2, base_model_output_size=768, dropout=0.05):
        super().__init__()
        self.base_model = base_model
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(base_model_output_size, base_model_output_size),
            Mish(),
            nn.Dropout(dropout),
            # Originally n_classes was 6 (discrete emotions); for the
            # valence/arousal (VA) head we use 2 outputs instead.
            nn.Linear(base_model_output_size, n_classes),
        )

        for layer in self.classifier:
            if isinstance(layer, nn.Linear):
                layer.weight.data.normal_(mean=0.0, std=0.02)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def forward(self, input_, *args):
        X, attention_mask = input_
        hidden_states = self.base_model(X, attention_mask=attention_mask)
        # Use the hidden state of the first token (<s>, RoBERTa's
        # CLS-equivalent) as the sentence representation.
        return self.classifier(hidden_states[0][:, 0, :])


# pretrained_path = "on_plurk_new_fix_data_arch_1_epoch_2_bs_16.pt"
pretrained_path = "arch1_unfreeze_all.pt"  # the latest weights!
assert Path(pretrained_path).is_file()

# Note: AutoModelWithLMHead is deprecated in recent transformers releases;
# AutoModelForMaskedLM is the drop-in replacement for RoBERTa-style models.
model = EmoModel(AutoModelWithLMHead.from_pretrained("distilroberta-base").base_model)
model.load_state_dict(torch.load(pretrained_path, map_location=torch.device("cpu")))
model.eval()


@lru_cache(maxsize=1)
def get_tokenizer(max_tokens=512):
    voc_file = "tokenizer/vocab.json"
    merg_file = "tokenizer/merges.txt"
    # Re-download the tokenizer files if they are missing.
    if not os.path.isfile(voc_file) or not os.path.isfile(merg_file):
        setup_tokenizer()

    t = ByteLevelBPETokenizer(voc_file, merg_file)
    # RoBERTa-style post-processing: wrap every sequence as <s> ... </s>.
    t._tokenizer.post_processor = BertProcessing(
        ("</s>", t.token_to_id("</s>")),
        ("<s>", t.token_to_id("<s>")),
    )
    t.enable_truncation(max_tokens)
    t.enable_padding(length=max_tokens, pad_id=t.token_to_id("<pad>"))
    return t


# Cell
def convert_text_to_tensor(text, tokenizer=None):
    if tokenizer is None:
        tokenizer = get_tokenizer()
    enc = tokenizer.encode(text)
    X = torch.tensor(enc.ids).unsqueeze(0)                 # (1, max_tokens) token ids
    attn = torch.tensor(enc.attention_mask).unsqueeze(0)   # (1, max_tokens) mask
    return (X, attn)


def get_output(text, model, tokenizer=None, return_tensor=False):
    # A try/except around the model call would be more defensive, but it
    # obscures the happy path, so errors are left to propagate.
    with torch.no_grad():
        model.eval()
        out = model(convert_text_to_tensor(text, tokenizer))
    if return_tensor:
        return out
    # out has shape (1, 2); flatten it to two floats (valence, arousal).
    tt = out[0]
    return float(tt[0]), float(tt[1])


import gradio as gr


def fn2(text, model=model, return_tensor=False):
    return get_output(text, model, return_tensor=return_tensor)


interface = gr.Interface(
    fn=fn2,
    inputs="text",
    outputs=["number", "number"],
)

interface.launch()
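# Cell
# Optional sanity check (a minimal sketch, not part of the original notebook;
# the sample sentence below is made up). It exercises the full
# tokenizer -> tensors -> model -> (valence, arousal) path without Gradio.
# In a notebook, launch() above does not block, so this cell runs fine after
# it; when running as a plain script, move this above interface.launch().
X, attn = convert_text_to_tensor("I am thrilled about the results!")
print("input ids shape:", X.shape)  # (1, 512) after padding/truncation
valence, arousal = get_output("I am thrilled about the results!", model)
print(f"valence={valence:.4f}, arousal={arousal:.4f}")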