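# ChatFinance: a GPT-2 model fine-tuned for finance Q&A, served through a
# Gradio text interface.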
import transformers
import gradio as gr
# import warnings
import torch
# warnings.simplefilter('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
# Add padding, begin-of-string, and end-of-string special tokens.
tokenizer.add_special_tokens(
    {
        "pad_token": "<pad>",
        "bos_token": "<startstring>",
        "eos_token": "<endstring>"
    })
# Add the <bot>: marker as a regular token since it is not a special token.
tokenizer.add_tokens(["<bot>:"])
print("=====Done 1")
model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
# Resize the embedding matrix to account for the newly added tokens.
model.resize_token_embeddings(len(tokenizer))
model.load_state_dict(torch.load('./gpt2talk.pt', map_location=torch.device('cpu')))
model.to(device)
print("=====Done 2")
model.eval()
def inference(quiz):
    quiz1 = quiz
    quiz = "<startstring>" + quiz + " <bot>:"
    quiztoken = tokenizer(quiz, return_tensors='pt').to(device)
    # Sample a response. do_sample=True is required for top_k/top_p to take
    # effect, and top_k must be an integer; the original top_k=0.7 looks like
    # a misplaced probability, so plausible sampling values are used here.
    answer = model.generate(
        **quiztoken,
        max_length=200,
        do_sample=True,
        top_k=50,
        top_p=0.7,
        pad_token_id=tokenizer.pad_token_id,
    )[0]
    answer = tokenizer.decode(answer, skip_special_tokens=True)
    # Strip the <bot>: marker and the echoed question from the decoded text.
    answer = answer.replace(" <bot>:", "").replace(quiz1, "") + '.'
    return answer
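# Quick sanity check (hypothetical example question):
# print(inference("What is compound interest?"))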
def chatbot(input_text):
    response = inference(input_text)
    return response
# Create the Gradio interface. The original interpretation="chat" argument is
# not a valid value for gr.Interface, so it is dropped here.
print("=====Done 3")
iface = gr.Interface(
    fn=chatbot,
    inputs='text',
    outputs='text',
    live=False,  # submit on button click rather than on every keystroke
    title="ChatFinance",
    description="Ask a question and see the model's response!",
)
print("=====Done 4") | |
# Launch the Gradio interface | |
iface.launch() |
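# Passing share=True to launch() would additionally create a temporary
# public link for the demo.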