Spaces:
Sleeping
Sleeping
File size: 1,600 Bytes
7a7a56a 428c8b1 7a7a56a 428c8b1 7a7a56a 428c8b1 7a7a56a 428c8b1 7a7a56a 428c8b1 7a7a56a 428c8b1 7a7a56a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import transformers
import gradio as gr
# import warnings
import torch
# warnings.simplefilter('ignore')
# Module-level setup: build the tokenizer and model once at import time so the
# request handlers below can reuse them.

# Record whether a CUDA device is visible to torch.
# NOTE(review): nothing in the visible code reads `device` afterwards; the
# checkpoint below is mapped onto the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

# Register the padding, begin-of-string and end-of-string markers as special
# tokens.
tokenizer.add_special_tokens(
    dict(
        pad_token="<pad>",
        bos_token="<startstring>",
        eos_token="<endstring>",
    )
)

# The bot marker is added as a regular token because it is not a special token.
tokenizer.add_tokens(["<bot>:"])
print("=====Done 1")

model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')

# Resize the embeddings to the enlarged vocabulary before loading the
# checkpoint weights.
model.resize_token_embeddings(len(tokenizer))

# NOTE(review): torch.load unpickles the file, so the checkpoint must come from
# a trusted source.
model.load_state_dict(
    torch.load(
        './gpt2talk.pt',
        map_location=torch.device('cpu'),
    )
)
print("=====Done 2")

# Switch the model to evaluation mode for inference.
model.eval()
def inference(quiz):
    """Generate the bot's reply to a single user message.

    The message is wrapped in the prompt format
    ``<startstring>{quiz} <bot>:`` and the model continues the text from
    there.

    Args:
        quiz: The user's message as plain text.

    Returns:
        The text generated after the prompt, with any ``" <bot>:"`` markers
        removed and a trailing ``'.'`` appended.
    """
    prompt = "<startstring>" + quiz + " <bot>:"

    # Keep the inputs on whatever device the model currently lives on.
    encoded = tokenizer(prompt, return_tensors='pt').to(model.device)
    prompt_length = encoded["input_ids"].shape[1]

    # Decoding stays greedy. The previous call also passed top_k=0.7 and
    # top_p=0.1, but without do_sample=True those knobs are not applied, and
    # top_k must be an integer count rather than a fraction.
    # NOTE(review): if sampling is wanted, pass do_sample=True together with an
    # integer top_k and the desired top_p.
    #
    # The tokenizer's own end/pad ids are passed explicitly so generation can
    # stop at "<endstring>" instead of relying on the model config's defaults.
    output_ids = model.generate(
        **encoded,
        max_length=200,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )[0]

    # Drop the prompt by position rather than by text replacement: replacing
    # the question string would also delete it wherever the reply repeats it,
    # and would miss it if decoding changed the prompt's spacing.
    reply_ids = output_ids[prompt_length:]
    reply = tokenizer.decode(reply_ids, skip_special_tokens=True)

    # "<bot>:" is a regular token, so skip_special_tokens does not remove it.
    return reply.replace(" <bot>:", "") + '.'
def chatbot(input_text):
    """Return the reply produced by `inference` for `input_text`.

    Args:
        input_text: The text entered by the user.

    Returns:
        Whatever `inference` returns for that text.
    """
    return inference(input_text)
# Create the Gradio interface: one text box in, one text box out, backed by
# `chatbot`.
print("=====Done 3")
iface = gr.Interface(
    fn=chatbot,
    inputs=gr.Textbox(),
    outputs=gr.Textbox(),
    # live=False: per Gradio's documentation the function is not re-run
    # automatically whenever the input changes; the user has to submit.
    live=False,
    # The former interpretation="chat" argument was removed because "chat" is
    # not a documented value for that parameter.
    title="ChatFinance",
    description="Ask a question and see its response!",
)
print("=====Done 4")

# Launch the Gradio interface (blocks while serving when run as a script).
iface.launch()