# app.py — Gradio demo for the iopwsy/MatGPT-synthesis GPT-2 model.
# Hugging Face transformers + torch for the model, gradio for the web UI.
from transformers import GPT2Tokenizer, GPT2LMHeadModel, StoppingCriteria, StoppingCriteriaList
import torch, os
import gradio as gr
# Fine-tuned GPT-2 checkpoint hosted on the Hugging Face Hub.
model_path = "iopwsy/MatGPT-synthesis"
# GPT-2 has no dedicated pad token; reuse the end-of-text token for padding.
tokenizer = GPT2Tokenizer.from_pretrained(model_path, pad_token = '<|endoftext|>')
model = GPT2LMHeadModel.from_pretrained(model_path)
# Mirror that choice on the model config so generate() pads with eos
# instead of warning about a missing pad token id.
model.config.pad_token_id = model.config.eos_token_id
model.eval()  # inference mode: disables dropout
class StopforGPT2(StoppingCriteria):
    """Stop generation as soon as the last emitted token is a stop token.

    The stop id defaults to 50256, GPT-2's '<|endoftext|>' token, so
    existing no-argument callers (``StopforGPT2()``) behave as before.
    Only the first sequence in the batch is inspected, matching the
    original implementation (this app generates with batch size 1).
    """

    def __init__(self, stop_token_id: int = 50256):
        super().__init__()
        # Token id that terminates generation (GPT-2 eos by default).
        self.stop_token_id = stop_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # .item() replaces the detach/cpu/numpy round-trip: it extracts a
        # plain Python int, so the comparison yields a real bool as the
        # StoppingCriteria contract expects.
        return input_ids[0, -1].item() == self.stop_token_id
def get_path(text, max_new_tokens=300):
    """Generate a synthesis route for *text* with the fine-tuned GPT-2 model.

    Args:
        text: User prompt, e.g. "How to synthesis xxx?".
        max_new_tokens: Upper bound on generated tokens. Defaults to 300,
            preserving the original behaviour for existing callers.

    Returns:
        The generated continuation only (prompt stripped), decoded with
        special tokens removed.
    """
    # Calling the tokenizer directly is the supported replacement for the
    # deprecated encode_plus; defaults already add special tokens and
    # return an attention mask.
    inputs = tokenizer(text, return_attention_mask=True, return_tensors='pt')
    with torch.no_grad():  # inference only — no gradients needed
        res = model.generate(**inputs,
                             do_sample=False,  # greedy decoding: deterministic output
                             max_new_tokens=max_new_tokens,
                             stopping_criteria=StoppingCriteriaList([StopforGPT2()]))
    # Slice off the prompt so only newly generated tokens are decoded.
    prompt_len = inputs['input_ids'].shape[1]
    return tokenizer.decode(res[0, prompt_len:], skip_special_tokens=True)
# Wire the generator into a minimal two-textbox web UI.
demo = gr.Interface(
    fn=get_path,
    inputs=gr.Textbox(label="Input", info="How to synthesis xxx?\n"),
    outputs=gr.Textbox(label="Output"),
)

if __name__ == "__main__":
    # Serve only when executed as a script, not on import.
    demo.launch()