# Spaces:
# Sleeping
# Sleeping
from transformers import GPT2Tokenizer, GPT2LMHeadModel, StoppingCriteria, StoppingCriteriaList | |
import torch, os | |
import gradio as gr | |
# Identifier passed to ``from_pretrained`` for both the tokenizer and the model.
model_path = "iopwsy/MatGPT-synthesis"

# The tokenizer is given '<|endoftext|>' as its padding token.
tokenizer = GPT2Tokenizer.from_pretrained(
    model_path,
    pad_token='<|endoftext|>',
)

# Load the language model and switch it to evaluation mode right away
# (``eval()`` returns the module itself, so the call can be chained).
model = GPT2LMHeadModel.from_pretrained(model_path).eval()

# Reuse the end-of-sequence id as the padding id in the model config.
model.config.pad_token_id = model.config.eos_token_id
class StopforGPT2(StoppingCriteria):
    """Stopping criterion that halts generation once an end-of-sequence token appears.

    Only the first sequence of the batch (row 0) is inspected, which matches
    how this file calls ``generate`` with a single prompt.

    Args:
        eos_token_id: Token id that signals the end of a sequence. Defaults
            to 50256, the value this file previously hard-coded (the id of
            ``<|endoftext|>`` in the standard GPT-2 vocabulary).
    """

    def __init__(self, eos_token_id: int = 50256) -> None:
        super().__init__()
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the most recent token of the first sequence is EOS.

        Args:
            input_ids: Token ids generated so far; indexed as ``[batch, position]``.
            scores: Prediction scores supplied by the generation loop (unused).
            **kwargs: Extra arguments supplied by the generation loop (unused).

        Returns:
            A plain Python ``bool``.
        """
        # ``.item()`` extracts a Python int directly, so the comparison yields
        # a real ``bool`` instead of a ``numpy.bool_`` and avoids the
        # detach/cpu/numpy round trip for a single scalar.
        return input_ids[0, -1].item() == self.eos_token_id
def get_path(text):
    """Generate the model's continuation for a prompt.

    The prompt is tokenized, extended greedily by up to 300 new tokens
    (generation also stops when ``StopforGPT2`` fires), and only the newly
    generated part is decoded and returned.

    Args:
        text: Prompt string entered by the user.

    Returns:
        The decoded continuation with the prompt tokens removed and special
        tokens skipped. An empty string is returned for an empty prompt.
    """
    # An empty prompt tokenizes to zero tokens, leaving generation nothing to
    # condition on; answer with an empty string instead of failing the request.
    if not text:
        return ""

    input_ = tokenizer(
        text,
        add_special_tokens=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_special_tokens_mask=False,
        return_tensors='pt',
    )
    # Number of prompt tokens, used below to slice the prompt off the output.
    prompt_length = input_['input_ids'].shape[1]

    # Inference only: no gradients are needed.
    with torch.no_grad():
        res = model.generate(
            **input_,
            do_sample=False,  # greedy decoding
            max_new_tokens=300,
            stopping_criteria=StoppingCriteriaList([StopforGPT2()]),
        )

    # ``res`` holds prompt + continuation; keep only the continuation.
    return tokenizer.decode(res[0, prompt_length:], skip_special_tokens=True)
# Web UI: one text box for the prompt, one for the generated answer,
# wired to ``get_path``.
prompt_box = gr.Textbox(label="Input", info="How to synthesis xxx?\n")
answer_box = gr.Textbox(label="Output")
demo = gr.Interface(fn=get_path, inputs=prompt_box, outputs=answer_box)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()