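# Jinyi-Guard / main.py
# (The lines below are metadata scraped from the Hugging Face file viewer,
#  kept as comments so the module remains valid Python.)
# changsr's picture
# Update main.py
# f09c638 verified
# raw
# history blame contribute delete
# No virus
# 1.07 kB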
from model import BirdModel_Attention_lstm
from transformers import AutoTokenizer
import torch
from transformers import AutoModelForSeq2SeqLM, AutoModelForMaskedLM
def loading_model(path):
    """Load and return a masked-language-model backbone.

    Args:
        path: Anything accepted by ``AutoModelForMaskedLM.from_pretrained``
            (for example a local directory or a Hugging Face model id).

    Returns:
        The model object produced by ``from_pretrained``.
    """
    return AutoModelForMaskedLM.from_pretrained(path)
def init(path, *, pretrained_name="bert-base-uncased", map_location=None):
    """Build the model, load its checkpoint weights, and return it with a tokenizer.

    Args:
        path: Checkpoint file for ``torch.load``. Its contents are passed
            straight to ``load_state_dict``, so it must be a state dict
            matching ``BirdModel_Attention_lstm``.
        pretrained_name: Hugging Face model id used for both the tokenizer
            and the masked-LM backbone. Defaults to the previously
            hard-coded ``"bert-base-uncased"``.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            means CPU, matching the original behavior.

    Returns:
        ``(model, tokenizer)``, with the model already switched to eval mode.
    """
    if map_location is None:
        map_location = torch.device('cpu')

    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    backbone = loading_model(pretrained_name)
    # NOTE(review): the meaning of the second positional flag is defined in
    # model.py, which is not visible here.
    model = BirdModel_Attention_lstm(backbone, True)

    # SECURITY: torch.load unpickles the file and can execute arbitrary code.
    # Only load trusted checkpoints; on torch>=1.13 consider weights_only=True.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint)
    # Inference entry point: disable dropout and other training-only
    # behavior so predictions are deterministic.
    model.eval()
    return model, tokenizer
def clip(text, tokenizer, max_len):
    """Tokenize *text* into PyTorch tensors of length *max_len*.

    The tokenizer is asked to add special tokens, pad to ``max_len``, and
    truncate anything longer.

    Returns:
        ``(input_ids, attention_mask)`` taken from the tokenizer output.
    """
    tokenizer_kwargs = dict(
        add_special_tokens=True,
        return_tensors="pt",
        max_length=max_len,
        padding="max_length",
        truncation=True,
    )
    encoded = tokenizer(text, **tokenizer_kwargs)
    return encoded['input_ids'], encoded['attention_mask']
def answer(input, model, tokenizer, *, max_len=512):
    """Run *model* on one piece of text and return its raw outputs.

    Args:
        input: Text to classify. The name shadows the builtin ``input`` but is
            kept so that existing keyword callers still work.
        model: Callable invoked as ``model(input_ids=..., masks=...)``.
        tokenizer: Tokenizer forwarded to ``clip``.
        max_len: Pad/truncate length passed to ``clip``. Defaults to the
            previously hard-coded 512.

    Returns:
        Whatever ``model`` returns; its structure is defined in model.py.
    """
    input_ids, masks = clip(input, tokenizer, max_len)
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        outputs = model(input_ids=input_ids, masks=masks)
    return outputs