# tuto : https://gradio.app/creating_a_chatbot/
"""Gradio demo: chat with GPT-2 fine-tuned on SGD in the SimpleTOD format.

Each model output is a sequence ``<|belief|> ... <|action|> ... <|response|> ...``.
All three parts are displayed, but only the natural-language response is fed
back to the model as dialogue history.
"""
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = 'armandnlp/gpt2-TOD_finetuned_SGD'
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)

# Splits a decoded conversation into its alternating user/system turns.
# Raw string: '\|' in a plain literal is an invalid escape sequence
# (DeprecationWarning; SyntaxWarning since Python 3.12).
_TURN_SPLIT_RE = re.compile(r'<\|system\|>|<\|user\|>')

# Token id at which generation stops.
# NOTE(review): presumably the id of '<|endofresponse|>' in this checkpoint's
# vocabulary (that tag is stripped from the output below) — confirm.
_EOS_TOKEN_ID = 50262

# Cap on the TOTAL sequence length (prompt + generated tokens) for generate().
# Presumably chosen to match GPT-2's 1024-position context window.
_MAX_LENGTH = 1024


def format_resp(system_resp):
    """Replace the SimpleTOD segment tags of a system turn with readable labels.

    Parameters
    ----------
    system_resp : str
        One raw system turn, e.g. ``' <|belief|> ... <|action|> ... <|response|> ...'``.

    Returns
    -------
    str
        The same text with ``<|belief|>``, ``<|action|>`` and ``<|response|>``
        replaced by ``*Belief State: ``, ``*Actions: `` and ``*System Response: ``.
    """
    system_resp = system_resp.replace('<|belief|>', '*Belief State: ')
    system_resp = system_resp.replace('<|action|>', '*Actions: ')
    system_resp = system_resp.replace('<|response|>', '*System Response: ')
    return system_resp


def _split_turns(text):
    """Return the user/system turns of a decoded conversation, in order."""
    # [1:] drops whatever precedes the first tag.
    return _TURN_SPLIT_RE.split(text)[1:]


def _history_ids(history):
    """Return the ``[[token ids]]`` state as a (1, n) LongTensor ((1, 0) if empty)."""
    if not history:
        return torch.empty((1, 0), dtype=torch.long)
    return torch.tensor(history, dtype=torch.long)


def _clean_history(turns):
    """Rebuild the model-input history, keeping only each system response.

    ``turns`` alternates user and system turns. The model expects only user
    and system utterances as context, so belief and action segments are
    stripped from every system turn.
    """
    parts = []
    for user_turn, system_turn in zip(turns[0::2], turns[1::2]):
        segments = system_turn.split('<|response|>')
        # A truncated generation may lack a response tag: keep an empty
        # response instead of raising IndexError on every later turn.
        response = segments[1] if len(segments) > 1 else ''
        parts.append('<|user|>' + user_turn)
        parts.append('<|system|>' + response)
    return ''.join(parts)


def predict(input, history=None):
    """Generate the next system turn and return the updated conversation.

    Parameters
    ----------
    input : str
        The new user utterance. (The name shadows the builtin but is kept for
        interface compatibility.)
    history : list[list[int]] | None
        Gradio state: ``[[token ids]]`` covering every previous user turn and
        full system output (belief + action + response). ``None`` or ``[]``
        at the start of a conversation.

    Returns
    -------
    tuple[list[tuple[str, str]], list[list[int]]]
        ``(resps, history)``:

        - ``resps`` holds one ``(user, system)`` pair per exchange, with the
          system text formatted by ``format_resp``.
        - ``history`` is the updated state.
    """
    # Gradio may pass None for an uninitialised state; this also replaces the
    # former mutable default argument.
    history = history or []

    if history:
        # The model expects only user and system responses as context, so
        # belief and action segments are removed from the stored history.
        turns = _split_turns(tokenizer.decode(history[0]))
        history4input = tokenizer.encode(_clean_history(turns), return_tensors='pt')
    else:
        history4input = _history_ids(history)

    # Model input: <|context|> + history + new user turn + <|endofcontext|>
    new_user_input_ids = tokenizer.encode(' <|user|> ' + input, return_tensors='pt')
    context = tokenizer.encode('<|context|>', return_tensors='pt')
    endofcontext = tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
    model_input = torch.cat([context, history4input, new_user_input_ids, endofcontext], dim=-1)

    # NOTE(review): the prompt grows every turn and is never truncated. Since
    # max_length bounds prompt + output, a long conversation leaves no room for
    # new tokens — consider dropping the oldest turns.
    out = model.generate(model_input, max_length=_MAX_LENGTH, eos_token_id=_EOS_TOKEN_ID).tolist()[0]

    # Keep only the generated part and drop the end-of-segment tags.
    string_out = tokenizer.decode(out)
    system_out = string_out.split('<|endofcontext|>')[1]
    for tag in ('<|endofbelief|>', '<|endofaction|>', '<|endofresponse|>'):
        system_out = system_out.replace(tag, '')
    resp_tokenized = tokenizer.encode(' <|system|> ' + system_out, return_tensors='pt')

    # The state keeps the FULL system output (belief + action + response), so
    # that all three parts can be displayed for every turn, even though they
    # are stripped from the model input above.
    history = torch.cat([_history_ids(history), new_user_input_ids, resp_tokenized], dim=-1).tolist()

    # One (user, system) tuple per exchange; system text made readable.
    turns = _split_turns(tokenizer.decode(history[0]))
    resps = [(user_turn, format_resp(system_turn))
             for user_turn, system_turn in zip(turns[0::2], turns[1::2])]
    return resps, history


examples = [["I want to book a restaurant for 2 people on Saturday."],
            ["What's the weather in Cambridge today ?"],
            ["I need to find a bus to Boston."],
            ["I want to add an event to my calendar."],
            ["I would like to book a plane ticket to New York."],
            ["I want to find a concert around LA."],
            ["Hi, I'd like to find an apartment in London please."],
            ["Can you find me a hotel room near Seattle please ?"],
            ["I want to watch a film online, a comedy would be nice"],
            ["I want to transfer some money please."],
            ["I want to reserve a movie ticket for tomorrow evening"],
            ["Can you play the song Learning to Fly by Tom Petty ?"],
            ["I need to rent a small car."],
            ]

description = """
This is an interactive window to chat with GPT-2 fine-tuned on the Schema-Guided Dialogues dataset, in which we find domains such as travel, weather, media, calendar, banking, restaurant booking...
"""

article = """
### Model Outputs
This task-oriented dialogue system is trained end-to-end, following the method detailed in [SimpleTOD](https://arxiv.org/pdf/2005.00796.pdf), where GPT-2 is trained by casting task-oriented dialogue as a seq2seq task.

From the dialogue history, composed of the previous user and system responses, the model is trained to output the belief state, the action decisions and the system response as a sequence. We show all three outputs in this demo : the belief state tracks the user goal (restaurant cuisine : Indian or media genre : comedy for ex.), the action decisions show how the system should proceed (restaurants request city or media offer title for ex.) and the natural language response provides an output the user can interpret.

The model responses are *de-lexicalized* : database values in the training set have been replaced with their slot names to make the learning process database agnostic. These slots are meant to later be replaced by actual results from a database, using the belief state to issue calls.

The model is capable of dealing with multiple domains : a list of possible inputs is provided to get the conversation going.

### Dataset
The SGD dataset ([blogpost](https://ai.googleblog.com/2019/10/introducing-schema-guided-dialogue.html) and [article](https://arxiv.org/pdf/1909.05855.pdf)) contains multiple task domains... Here is a list of some of the services and their descriptions from the dataset:

* **Restaurants** : A leading provider for restaurant search and reservations
* **Weather** : Check the weather for any place and any date
* **Buses** : Find a bus to take you to the city you want
* **Calendar** : Calendar service to manage personal events and reservations
* **Flights** : Find your next flight
* **Events** : Get tickets for the coolest concerts and sports in your area
* **Homes** : A widely used service for finding apartments and scheduling visits
* **Hotels** : A popular service for searching and reserving rooms in hotels
* **Media** : A leading provider of movies for searching and watching on-demand
* **Banks** : Manage bank accounts and transfer money
* **Movies** : A go-to provider for finding movies, searching for show times and booking tickets
* **Music** : A popular provider of a wide range of music content for searching and listening
* **RentalCars** : Car rental service with extensive coverage of locations and cars
"""

demo = gr.Interface(fn=predict,
                    inputs=["text", "state"],
                    outputs=["chatbot", "state"],
                    title="Chatting with multi task-oriented GPT2",
                    examples=examples,
                    description=description,
                    article=article,
                    )

if __name__ == "__main__":
    demo.launch()