import datetime

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import __init__

# Model name and maximum generation length come from the package-level constants.
MODEL_NAME = __init__.MODEL_MICROSOFT_DIABLO_MEDIUM
OUTPUT_MAX_LENGTH = __init__.OUTPUT_MAX_LENGTH

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)


def print_f(session_id, text):
    """Log *text* to stdout, prefixed with a timestamp and the session id."""
    print(f"{datetime.datetime.now()} | {session_id} | {text}")


def predict(input, history, request: gr.Request):
    """Generate the chatbot's next reply for one user message.

    Parameters
    ----------
    input : str
        The user's new message.
    history : list
        Token ids of the conversation so far (held in ``gr.State``;
        starts out as an empty list).
    request : gr.Request
        Used only to derive a per-client session id for logging.

    Returns
    -------
    tuple
        ``(chat pairs for gr.Chatbot, updated token-id history for gr.State)``.
    """
    session_id = 'UNKNOWN'
    if request:
        # Session id is client_ip:client_port — enough to tell clients apart
        # in the stdout log.
        session_id = request.client.host + ':' + str(request.client.port)

    # Tokenize the new input sentence, terminated by the EOS token so the
    # model sees a complete turn.
    new_user_input_ids = tokenizer.encode(
        input + tokenizer.eos_token, return_tensors='pt')

    # Append the new user input tokens to the chat history.
    # BUGFIX: on the first turn `history` is [], and torch.LongTensor([])
    # has shape (0,), which cannot be concatenated with the (1, n) input
    # tensor — torch.cat raised a shape-mismatch error and the very first
    # message crashed. Start from the new input alone in that case.
    if history:
        bot_input_ids = torch.cat(
            [torch.LongTensor(history), new_user_input_ids], dim=-1)
    else:
        bot_input_ids = new_user_input_ids

    # Generate a response; the full token sequence becomes the new history.
    history = model.generate(
        bot_input_ids,
        max_length=OUTPUT_MAX_LENGTH,
        pad_token_id=tokenizer.eos_token_id).tolist()

    # Decode the tokens and split on the EOS marker. The split yields
    # alternating user/bot utterances plus a trailing empty string, hence
    # the `len(response) - 1` bound below.
    response = tokenizer.decode(history[0]).split("<|endoftext|>")

    # Pair consecutive utterances into (user, bot) tuples for gr.Chatbot.
    response = [(response[i], response[i + 1])
                for i in range(0, len(response) - 1, 2)]

    # Log only the newest exchange.
    print_f(session_id, response[-1])

    return response, history


css = """ #row_bot{width: 70%; height: var(--size-96); margin: 0 auto} #row_bot .block{background: var(--color-grey-100); height: 100%} #row_input{width: 70%; margin: 0 auto} #row_input .block{background: var(--color-grey-100)} @media screen and (max-width: 768px) { #row_bot{width: 100%; height: var(--size-96); margin: 0 auto} #row_bot .block{background: var(--color-grey-100); height: 100%} #row_input{width: 100%; margin: 0 auto} #row_input .block{background: var(--color-grey-100)} } """

block = gr.Blocks(css=css, title="Chatbot")
with block:
    gr.Markdown(f"""

{MODEL_NAME}

""")
    with gr.Row(elem_id='row_bot'):
        chatbot = gr.Chatbot()
    with gr.Row(elem_id='row_input'):
        message = gr.Textbox(placeholder="Enter something")
        state = gr.State([])
        message.submit(predict, inputs=[message, state],
                       outputs=[chatbot, state])
        # Clear the textbox after each submission.
        message.submit(lambda x: "", message, message)

# Params ex: debug=True, share=True, server_name="0.0.0.0", server_port=5050
block.launch()