import torch
from transformers import AutoModelForMaskedLM, AutoModelForSeq2SeqLM, AutoTokenizer

from model import BirdModel_Attention_lstm


def loading_model(path):
    """Load a masked-language-model from *path*.

    Args:
        path: Model identifier or directory, forwarded unchanged to
            ``AutoModelForMaskedLM.from_pretrained``.

    Returns:
        The model object returned by ``AutoModelForMaskedLM.from_pretrained``.
    """
    return AutoModelForMaskedLM.from_pretrained(path)


def init(path):
    """Build the model and tokenizer and restore weights from a checkpoint.

    The tokenizer and the base masked-LM are both created from
    ``"bert-base-uncased"``; the base model is wrapped in
    ``BirdModel_Attention_lstm`` and the checkpoint at *path* is loaded into
    that wrapper on CPU.

    Args:
        path: Location of the checkpoint passed to ``torch.load``. The loaded
            object is handed directly to ``load_state_dict``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is returned in evaluation
        mode; call ``model.train()`` explicitly if it is to be fine-tuned.
    """
    base_name = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    models = loading_model(base_name)

    # NOTE(review): the meaning of the second positional argument (True) is
    # defined by BirdModel_Attention_lstm and is not visible here - confirm.
    model = BirdModel_Attention_lstm(models, True)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)

    # The wrapper is freshly constructed, so it starts in training mode.
    # Switch to evaluation mode so that dropout (and similar layers) do not
    # make inference results vary between calls.
    model.eval()
    return model, tokenizer


def clip(text, tokenizer, max_len):
    """Tokenize *text*, padding or truncating it to exactly *max_len* tokens.

    Args:
        text: Input passed to the tokenizer.
        tokenizer: Callable tokenizer; its result must expose ``'input_ids'``
            and ``'attention_mask'`` entries.
        max_len: Target sequence length used for both padding and truncation.

    Returns:
        A ``(input_ids, attention_mask)`` tuple taken from the tokenizer
        output, requested as PyTorch tensors (``return_tensors="pt"``).
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        return_tensors="pt",
        max_length=max_len,
        padding="max_length",
        truncation=True,
    )
    return encoding['input_ids'], encoding['attention_mask']


def answer(input, model, tokenizer, max_len=512):
    """Run the model on *input* and return its raw outputs.

    Args:
        input: Text to tokenize. The name shadows the builtin but is kept
            because it is part of the caller-visible signature.
        model: Callable model accepting ``input_ids`` and ``masks`` keyword
            arguments.
        tokenizer: Tokenizer forwarded to :func:`clip`.
        max_len: Sequence length used for padding/truncation. Defaults to
            512, the value that was previously hard-coded.

    Returns:
        Whatever ``model(input_ids=..., masks=...)`` returns.
    """
    input_ids, masks = clip(input, tokenizer, max_len)
    # Inference only: skip autograd bookkeeping so activations are not kept
    # alive for a backward pass that never happens.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, masks=masks)
    return outputs