File size: 4,259 Bytes
e71a2ba
776e43c
 
e71a2ba
 
776e43c
0669a02
aae4195
 
776e43c
aae4195
776e43c
 
 
 
 
aae4195
 
776e43c
 
 
 
 
aae4195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776e43c
 
 
 
 
aae4195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0669a02
aae4195
 
 
 
 
 
 
 
 
 
 
 
 
776e43c
 
aae4195
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import sys

import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

sys.path.insert(0, './petals/')

from petals.client.remote_model import DistributedBloomForCausalLM

# --- Model registry ---------------------------------------------------------
# Petals-distributed BLOOM checkpoints: weights are served by a swarm of
# peers rather than loaded fully into local memory.
MODEL_NAME = "bigscience/test-bloomd-6b3"
# INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
tokenizer_bloomd_6b3 = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd_6b3 = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    # initial_peers=INITIAL_PEERS,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)

MODEL_NAME = "bigscience/bloom-petals"
tokenizer_bloomd = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model_bloomd = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)


def _load_dialogpt(size):
    """Return (tokenizer, model) for the DialoGPT variant of the given size."""
    repo = f"microsoft/DialoGPT-{size}"
    return AutoTokenizer.from_pretrained(repo), AutoModelForCausalLM.from_pretrained(repo)


tokenizer_DialoGPT_small, model_DialoGPT_small = _load_dialogpt("small")
tokenizer_DialoGPT_medium, model_DialoGPT_medium = _load_dialogpt("medium")
tokenizer_DialoGPT_large, model_DialoGPT_large = _load_dialogpt("large")


def predict(
        input_text,
        history=None,
        person_description=None,
        number_of_new_tokens=1000,
        model_name=None,
        del_hist=None
):
    """Generate the bot's next reply, conditioned on a persona description.

    Args:
        input_text: The user's new message (may be None/empty).
        history: Flat list of token ids from previous turns (gradio "state"),
            or None on the first call.
        person_description: Free-text persona the bot should adopt; prepended
            to the prompt but stripped from the returned history.
        number_of_new_tokens: How many tokens to generate beyond the prompt.
        model_name: One of the radio choices below; unknown/None falls back
            to DialoGPT-medium.
        del_hist: When equal to 'delete history', the dialogue is reset.

    Returns:
        (response, history): a list of (user, bot) message pairs for the
        Chatbot widget, and the updated token-id history for "state".
    """
    if history is None or del_hist == 'delete history':
        history = []

    # Robustness: gradio passes None for untouched textboxes; the original
    # code crashed on `None + tokenizer.eos_token`. Coalesce to "".
    input_text = input_text or ""
    person_description = person_description or ""

    # Model selection via dispatch table; unknown names fall back to medium,
    # matching the original if/elif chain's else branch.
    registry = {
        'DialoGPT-small': (model_DialoGPT_small, tokenizer_DialoGPT_small),
        'DialoGPT-medium': (model_DialoGPT_medium, tokenizer_DialoGPT_medium),
        'DialoGPT-large': (model_DialoGPT_large, tokenizer_DialoGPT_large),
        'test-bloomd-6b3': (model_bloomd_6b3, tokenizer_bloomd_6b3),
        'bloom-petals': (model_bloomd, tokenizer_bloomd),
    }
    model, tokenizer = registry.get(model_name, (model_DialoGPT_medium, tokenizer_DialoGPT_medium))

    # Each segment is terminated with EOS so the model sees turn boundaries.
    person_description_ids = tokenizer.encode(person_description + tokenizer.eos_token, return_tensors='pt')
    new_user_input_ids = tokenizer.encode(input_text + tokenizer.eos_token, return_tensors='pt')

    # Prompt layout: [persona] + [prior turns] + [new user message].
    bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
    input_with_desc_ids = torch.cat([person_description_ids, bot_input_ids], dim=-1)
    max_token_count = number_of_new_tokens + len(input_with_desc_ids[0])
    history = model.generate(input_with_desc_ids, max_length=max_token_count,
                             pad_token_id=tokenizer.eos_token_id).tolist()
    # Drop the persona prefix so it never appears in the rendered dialogue
    # and is not duplicated on the next turn.
    history[0] = history[0][len(person_description_ids[0]):]

    # Decode and split on EOS; consecutive segments alternate user/bot, so
    # pair them up for the Chatbot widget.
    response = tokenizer.decode(history[0]).split("<|endoftext|>")
    response = [(response[i], response[i + 1]) for i in range(0, len(response) - 1, 2)]
    return response, history


# Build the UI and start serving. NOTE: the original line ended with a stray
# trailing comma (`).launch(),`) which wrapped the launch result in a useless
# 1-tuple; removed.
gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label='Input message', lines=1, placeholder="Enter your message..."),
        "state",
        gr.Textbox(label='Person Description', lines=2, placeholder="Enter a description of the person..."),
        gr.Slider(label='Number of new tokens', minimum=2, maximum=100, value=10),
        gr.Radio(
            label='Model name',
            choices=[
                'DialoGPT-small',
                'DialoGPT-medium',
                'DialoGPT-large',
                'test-bloomd-6b3',
                'bloom-petals',
            ]
        ),
        gr.Radio(
            label='Delete history',
            value="Don't delete history",
            choices=[
                'delete history',
                "Don't delete history"
            ]),
    ],
    outputs=[gr.Chatbot(label='History of the dialogue'), "state"],
).launch()