dangduytung's picture
Add a simple chatbot that uses the DialoGPT model
d15ce79
raw
history blame
2.74 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import datetime
import __init__
# Model identifier and generation length limit, both read from the local
# ``__init__`` module imported above.
MODEL_NAME = __init__.MODEL_MICROSOFT_DIABLO_MEDIUM
OUTPUT_MAX_LENGTH = __init__.OUTPUT_MAX_LENGTH
# Tokenizer and model are loaded once, at import time, and are then used as
# module-level globals by every call to predict().
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
def print_f(session_id, text):
    """Print one timestamped log line for a chat session.

    The line is written to stdout as
    ``<current local time> | <session_id> | <text>``.
    """
    timestamp = datetime.datetime.now()
    log_line = f"{timestamp} | {session_id} | {text}"
    print(log_line)
def _drop_oldest_exchange(token_ids, eos_id):
    """Return ``token_ids`` without its first (user, bot) exchange.

    An exchange is two consecutive EOS-terminated segments. If fewer than
    two EOS tokens are present, nothing complete can be kept, so an empty
    list is returned.
    """
    eos_seen = 0
    for pos, token in enumerate(token_ids):
        if token == eos_id:
            eos_seen += 1
            if eos_seen == 2:
                return token_ids[pos + 1:]
    return []


def predict(input, history, request: gr.Request):
    """Generate the bot's reply to ``input`` and return the updated chat.

    Args:
        input: The text the user just submitted. (The name shadows the
            builtin but is kept so the function's interface is unchanged.)
        history: Token ids of the conversation so far, as produced by a
            previous call (``[[id, id, ...]]``), or ``[]`` on the first turn.
        request: The Gradio request, used only to build a session id for
            logging; may be falsy.

    Returns:
        A ``(response, history)`` pair: ``response`` is a list of
        ``(user_text, bot_text)`` tuples for the chatbot widget, and
        ``history`` is the new nested list of token ids to store as state.
    """
    session_id = 'UNKNOWN'
    # The client address is not guaranteed to be present on every request,
    # so fall back to 'UNKNOWN' instead of raising on ``None.host``.
    client = getattr(request, 'client', None) if request else None
    if client is not None:
        # Session id is client_ip + ':' + client_port.
        session_id = f"{client.host}:{client.port}"

    eos_id = tokenizer.eos_token_id

    # Tokenize the new user sentence, terminated by EOS.
    new_user_input_ids = tokenizer.encode(
        input + tokenizer.eos_token, return_tensors='pt')
    new_len = new_user_input_ids.shape[-1]

    # ``max_length`` bounds prompt + reply together, so the prompt must stay
    # strictly below it or the model has no room to answer. Forget the
    # oldest whole exchanges until the prompt fits.
    prior_ids = list(history[0]) if history else []
    while prior_ids and len(prior_ids) + new_len >= OUTPUT_MAX_LENGTH:
        prior_ids = _drop_oldest_exchange(prior_ids, eos_id)

    # Append the new user tokens to whatever history was kept.
    if prior_ids:
        bot_input_ids = torch.cat(
            [torch.LongTensor([prior_ids]), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Generate a response; the output contains the prompt plus the reply.
    history = model.generate(bot_input_ids, max_length=OUTPUT_MAX_LENGTH,
                             pad_token_id=eos_id).tolist()

    # A reply cut off by ``max_length`` has no EOS; add one so that it does
    # not merge with the next user turn and shift the (user, bot) pairing.
    if history[0][-1] != eos_id:
        history[0].append(eos_id)

    # Decode and split into turns on the tokenizer's own EOS string (the
    # same separator that was appended to the input above).
    turns = tokenizer.decode(history[0]).split(tokenizer.eos_token)

    # Pair turns up as (user, bot); an unpaired trailing piece is ignored.
    response = list(zip(turns[0::2], turns[1::2]))

    # Log the newest exchange.
    print_f(session_id, response[-1])

    return response, history
# Custom CSS handed to gr.Blocks below. It styles the two layout rows by
# element id: ``row_bot`` (chat transcript) and ``row_input`` (text box) are
# centered at 70% width, and widen to 100% on screens up to 768px wide.
css = """
#row_bot{width: 70%; height: var(--size-96); margin: 0 auto}
#row_bot .block{background: var(--color-grey-100); height: 100%}
#row_input{width: 70%; margin: 0 auto}
#row_input .block{background: var(--color-grey-100)}
@media screen and (max-width: 768px) {
#row_bot{width: 100%; height: var(--size-96); margin: 0 auto}
#row_bot .block{background: var(--color-grey-100); height: 100%}
#row_input{width: 100%; margin: 0 auto}
#row_input .block{background: var(--color-grey-100)}
}
"""
# Assemble the page: model-name header, chat transcript row, input row.
with gr.Blocks(css=css, title="Chatbot") as block:
    header_html = (
        "\n"
        f'<p style="font-size:20px; text-align: center">{MODEL_NAME}</p>'
        "\n"
    )
    gr.Markdown(header_html)

    with gr.Row(elem_id='row_bot'):
        chatbot = gr.Chatbot()

    with gr.Row(elem_id='row_input'):
        message = gr.Textbox(placeholder="Enter something")

    # Token history handed to predict() as ``history`` and replaced by its
    # second return value; starts out as an empty list.
    state = gr.State([])

    # On submit: run the model and refresh the transcript and the state...
    message.submit(
        fn=predict,
        inputs=[message, state],
        outputs=[chatbot, state],
    )
    # ...and, as a separate listener on the same event, empty the text box.
    message.submit(fn=lambda x: "", inputs=message, outputs=message)

# launch() also takes options such as debug=True, share=True,
# server_name="0.0.0.0", server_port=5050.
block.launch()