"""Gradio demo for Tiger (Generative Multimodal Dialogue Model).

Loads three models (an intent-prediction classifier, a textual dialogue
response generator and a text-to-image pipeline), wraps them in a ``Chat``
object and serves a Gradio Blocks UI.
"""
import os
import sys
import argparse

import gradio as gr
import torch
from model import IntentPredictModel
from transformers import (T5Tokenizer, GPT2Tokenizer, GPT2Config, GPT2LMHeadModel)
from diffusers import StableDiffusionPipeline
from chatbot import Chat


def _load_intent_predictor(args):
    """Load the intent-prediction tokenizer and classifier.

    Args:
        args: parsed CLI namespace; reads ``intent_predict_model_name``,
            ``intent_predict_model_weights_path`` and ``device``.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    print("Loading Intent Prediction Classifier...")
    # truncation_side='left' keeps the end of over-long inputs
    # (presumably the most recent dialogue context).
    tokenizer = T5Tokenizer.from_pretrained(args.intent_predict_model_name, truncation_side='left')
    tokenizer.add_special_tokens({'sep_token': '[SEP]'})

    model = IntentPredictModel(pretrained_model_name_or_path=args.intent_predict_model_name, num_classes=2)
    state_dict = torch.load(args.intent_predict_model_weights_path, map_location=args.device)
    model.load_state_dict(state_dict)
    # A freshly constructed module starts in training mode; this demo only runs
    # inference, so disable dropout & co. (assumes IntentPredictModel is a
    # torch.nn.Module, as its load_state_dict suggests — confirm).
    model.eval()
    print("Intent Prediction Classifier loading completed.")
    return tokenizer, model


def _load_text_dialog_generator(args):
    """Load the GPT-2 based textual dialogue tokenizer and LM head model.

    Args:
        args: parsed CLI namespace; reads ``text_dialog_model_name`` and
            ``text_dialog_model_weights_path``.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    print("Loading Textual Dialogue Response Generator...")
    tokenizer = GPT2Tokenizer.from_pretrained(args.text_dialog_model_name, truncation_side='left')
    tokenizer.add_tokens(['[UTT]', '[DST]'])
    print(f"Textual dialogue tokenizer size: {len(tokenizer)}")

    config = GPT2Config.from_pretrained(args.text_dialog_model_name)
    # The added tokens can grow the vocabulary past the base config; widen the
    # config so it matches the (presumably fine-tuned, resized) weights.
    if len(tokenizer) > config.vocab_size:
        config.vocab_size = len(tokenizer)

    model = GPT2LMHeadModel.from_pretrained(args.text_dialog_model_weights_path, config=config)
    print("Textual Dialogue Response Generator loading completed.")
    return tokenizer, model


def _load_text2image(args):
    """Load the Stable Diffusion text-to-image pipeline in float32.

    Args:
        args: parsed CLI namespace; reads ``text2image_model_weights_path``.
    """
    print("Loading Text-to-Image Translator...")
    pipeline = StableDiffusionPipeline.from_pretrained(args.text2image_model_weights_path, torch_dtype=torch.float32)
    print("Text-to-Image Translator loading completed.")
    return pipeline


def _build_demo(chat):
    """Build the Gradio Blocks UI and wire its events to ``chat``'s handlers.

    Args:
        chat: a ``Chat`` instance exposing ``start_chat``, ``respond`` and
            ``restart_chat``.

    Returns:
        The (not yet launched) ``gr.Blocks`` demo.
    """
    title = "\n\nDemo of Tiger\n\n"
    description1 = "\n\nThis is the demo of Tiger (Generative Multimodal Dialogue Model).\n\n"
    description2 = "\n\nInput text start chatting!\n\n"
    # NOTE(review): the name suggests a horizontal rule, but the literal is only
    # a newline, so this renders nothing — confirm the intended content.
    hr = "\n"
    description_input = "\n\nInput: text (English)\n\n"
    description_output = "\n\nOutput: text / image\n\n"

    with gr.Blocks() as demo:
        for markdown_text in (title, description1, description2, hr, description_input, description_output):
            gr.Markdown(markdown_text)

        with gr.Row():
            with gr.Column(scale=0.33):
                num_beams = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    interactive=True,
                    label="beam search numbers",
                )
                text2image_seed = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=42,
                    step=1,
                    interactive=True,
                    label="seed for text-to-image",
                )
                start = gr.Button("Start Chat", variant="primary")
                # Disabled until a chat is started; start_chat returns updates for it.
                clear = gr.Button("Restart Chat (Clear dialogue history)", interactive=False)

            with gr.Column():
                chat_state = gr.State()
                chatbot = gr.Chatbot(label='Tiger')
                text_input = gr.Textbox(label='User', placeholder="Please click the button to start chat!", interactive=False)

        start.click(chat.start_chat, [chat_state], [text_input, start, clear, chat_state])
        text_input.submit(chat.respond, [text_input, num_beams, text2image_seed, chatbot, chat_state], [text_input, chatbot, chat_state])
        clear.click(chat.restart_chat, [chat_state], [chatbot, text_input, start, clear, chat_state], queue=False)

    return demo


def main(args):
    """Load all models, build the UI and launch the Gradio server.

    Args:
        args: parsed CLI namespace (see ``_parse_args``).
    """
    intent_predict_tokenizer, intent_predict_model = _load_intent_predictor(args)
    text_dialog_tokenizer, text_dialog_model = _load_text_dialog_generator(args)
    text2image_model = _load_text2image(args)

    chat = Chat(intent_predict_model, intent_predict_tokenizer, text_dialog_model, text_dialog_tokenizer, text2image_model, args.device)

    demo = _build_demo(chat)
    # NOTE(review): `enable_queue` (and float Column `scale`) are Gradio 3.x-era
    # APIs; confirm the pinned Gradio version before upgrading.
    demo.launch(share=False, enable_queue=False)


def _parse_args():
    """Parse CLI arguments; default weight paths are relative to this script."""
    # Anchor on this file rather than sys.path[0], which is not the script's
    # directory when run via `python -m` or embedded.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    intent_predict_model_weights_path = os.path.join(base_dir, "model_weights/Tiger_t5_base_encoder.pth")
    text_dialog_model_weights_path = os.path.join(base_dir, "model_weights/Tiger_DialoGPT_medium.pth")
    text2image_model_weights_path = os.path.join(base_dir, "model_weights/stable-diffusion-2-1-realistic")

    parser = argparse.ArgumentParser()
    parser.add_argument('--intent_predict_model_name', type=str, default="t5-base")
    parser.add_argument('--intent_predict_model_weights_path', type=str, default=intent_predict_model_weights_path)
    parser.add_argument('--text_dialog_model_name', type=str, default="microsoft/DialoGPT-medium")
    parser.add_argument('--text_dialog_model_weights_path', type=str, default=text_dialog_model_weights_path)
    parser.add_argument('--text2image_model_weights_path', type=str, default=text2image_model_weights_path)
    parser.add_argument('--device', type=str, default="cpu")
    return parser.parse_args()


if __name__ == "__main__":
    main(_parse_args())