File size: 5,187 Bytes
b03a999
 
553d99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3630c9b
6a0ca81
 
 
 
553d99f
 
 
 
 
6a0ca81
553d99f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a0ca81
 
553d99f
 
 
 
6a0ca81
 
 
553d99f
6a0ca81
553d99f
3802cb6
553d99f
 
 
b03a999
 
 
 
553d99f
 
 
b03a999
553d99f
 
b03a999
553d99f
b03a999
553d99f
2ad304c
553d99f
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import sys
import argparse
import gradio as gr

import torch
from model import IntentPredictModel
from transformers import (T5Tokenizer, 
                          GPT2Tokenizer, GPT2Config, GPT2LMHeadModel)
from diffusers import StableDiffusionPipeline

from chatbot import Chat


def main(args):
    """Load the three Tiger sub-models and serve the Gradio chat demo.

    The function runs three loading stages in order (intent prediction
    classifier, textual dialogue response generator, text-to-image
    translator), hands the loaded objects to ``Chat``, builds the Gradio
    ``Blocks`` UI, and finally calls ``launch`` on it.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``intent_predict_model_name``,
            ``intent_predict_model_weights_path``,
            ``text_dialog_model_name``,
            ``text_dialog_model_weights_path``,
            ``text2image_model_weights_path`` and ``device``.
    """
    # ------------------------------------------------------------------
    # Stage 1: intent prediction classifier (T5 tokenizer + project model)
    # ------------------------------------------------------------------
    print("Loading Intent Prediction Classifier...")
    # NOTE(review): truncation_side='left' presumably keeps the most recent
    # dialogue turns when the input is too long -- confirm against Chat.
    intent_tok = T5Tokenizer.from_pretrained(
        args.intent_predict_model_name,
        truncation_side='left',
    )
    intent_tok.add_special_tokens({'sep_token': '[SEP]'})
    intent_clf = IntentPredictModel(
        pretrained_model_name_or_path=args.intent_predict_model_name,
        num_classes=2,
    )
    # The checkpoint is a state dict; map_location places its tensors on
    # the requested device while loading.
    intent_state = torch.load(
        args.intent_predict_model_weights_path,
        map_location=args.device,
    )
    intent_clf.load_state_dict(intent_state)
    print("Intent Prediction Classifier loading completed.")

    # ------------------------------------------------------------------
    # Stage 2: textual dialogue response generator (GPT-2 family)
    # ------------------------------------------------------------------
    print("Loading Textual Dialogue Response Generator...")
    dialog_tok = GPT2Tokenizer.from_pretrained(
        args.text_dialog_model_name,
        truncation_side='left',
    )
    dialog_tok.add_tokens(['[UTT]', '[DST]'])
    vocab_total = len(dialog_tok)
    print(vocab_total)
    dialog_cfg = GPT2Config.from_pretrained(args.text_dialog_model_name)
    # Grow the configured vocabulary when the added tokens pushed the
    # tokenizer past the size recorded in the base config.
    if dialog_cfg.vocab_size < vocab_total:
        dialog_cfg.vocab_size = vocab_total
    dialog_lm = GPT2LMHeadModel.from_pretrained(
        args.text_dialog_model_weights_path,
        config=dialog_cfg,
    )
    print("Textual Dialogue Response Generator loading completed.")

    # ------------------------------------------------------------------
    # Stage 3: text-to-image translator (Stable Diffusion pipeline, fp32)
    # ------------------------------------------------------------------
    print("Loading Text-to-Image Translator...")
    image_pipe = StableDiffusionPipeline.from_pretrained(
        args.text2image_model_weights_path,
        torch_dtype=torch.float32,
    )
    print("Text-to-Image Translator loading completed.")

    # Bundle everything into the chat controller whose methods back the UI.
    session = Chat(
        intent_clf, intent_tok,
        dialog_lm, dialog_tok,
        image_pipe,
        args.device,
    )

    # Static HTML snippets shown at the top of the page, in display order.
    header_blocks = (
        """<h1 align="center">Demo of Tiger</h1>""",
        """<h2>This is the demo of Tiger (Generative Multimodal Dialogue Model).</h2>""",
        """<h2>Input text start chatting!</h2>""",
        """<hr>""",
        """<h3>Input:  text (English)</h3>""",
        """<h3>Output:  text / image</h3>""",
    )
    # Settings shared by both sliders.
    slider_common = dict(minimum=1, step=1, interactive=True)

    with gr.Blocks() as demo:
        for block_html in header_blocks:
            gr.Markdown(block_html)

        with gr.Row():
            # Left column: generation settings and the start/restart buttons.
            with gr.Column(scale=0.33):
                beam_slider = gr.Slider(
                    maximum=10,
                    value=5,
                    label="beam search numbers",
                    **slider_common,
                )
                seed_slider = gr.Slider(
                    maximum=100,
                    value=42,
                    label="seed for text-to-image",
                    **slider_common,
                )
                start_btn = gr.Button("Start Chat", variant="primary")
                # Created disabled; it is one of the outputs of the start
                # and restart callbacks below.
                restart_btn = gr.Button(
                    "Restart Chat (Clear dialogue history)",
                    interactive=False,
                )

            # Right column: conversation view and the user's input box.
            with gr.Column():
                state = gr.State()
                history_box = gr.Chatbot(label='Tiger')
                # Created non-interactive; it is an output of every callback.
                user_box = gr.Textbox(
                    label='User',
                    placeholder="Please click the <Start Chat> button to start chat!",
                    interactive=False,
                )

        # Event wiring: (callback, inputs, outputs).
        start_btn.click(
            session.start_chat,
            [state],
            [user_box, start_btn, restart_btn, state],
        )
        user_box.submit(
            session.respond,
            [user_box, beam_slider, seed_slider, history_box, state],
            [user_box, history_box, state],
        )
        restart_btn.click(
            session.restart_chat,
            [state],
            [history_box, user_box, start_btn, restart_btn, state],
            queue=False,
        )

    demo.launch(share=False, enable_queue=False)


if __name__ == "__main__":
    # Script entry point: define the CLI, parse it, and hand off to main().
    #
    # Default checkpoint locations are resolved against sys.path[0], which
    # for a directly executed script is the directory containing the script.
    script_dir = sys.path[0]
    default_intent_weights = os.path.join(script_dir, "model_weights/Tiger_t5_base_encoder.pth")
    default_dialog_weights = os.path.join(script_dir, "model_weights/Tiger_DialoGPT_medium.pth")
    default_t2i_weights = os.path.join(script_dir, "model_weights/stable-diffusion-2-1-realistic")

    # (flag, add_argument keyword arguments), in the order they are registered.
    cli_options = (
        # Intent prediction classifier
        ('--intent_predict_model_name', dict(type=str, default="t5-base")),
        ('--intent_predict_model_weights_path', dict(type=str, default=default_intent_weights)),
        # Textual dialogue response generator
        ('--text_dialog_model_name', dict(type=str, default="microsoft/DialoGPT-medium")),
        ('--text_dialog_model_weights_path', dict(type=str, default=default_dialog_weights)),
        # Text-to-image translator
        ('--text2image_model_weights_path', dict(type=str, default=default_t2i_weights)),
        # Device passed to torch.load(map_location=...) and to Chat
        ('--device', dict(default="cpu")),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    main(args)