import gradio as gr
from model import GPTConfig, GPT
import torch
from contextlib import nullcontext
import os
import time
import pickle
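
# Case handling: uppercase characters are escaped as "^" + their lowercase form before
# encoding (add_caseifer) and restored after decoding (remove_caseifer), so the model
# only ever sees lowercase characters plus the "^" marker.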

def remove_caseifer(text):
    new_text = ""
    i = 0
    while i < len(text):
        if text[i] == "^":
            if i + 1 < len(text):
                new_text += text[i + 1].upper()
                i += 1  # also consume the escaped character
            # a trailing "^" with nothing after it is simply dropped
        else:
            new_text += text[i]
        i += 1
    return new_text

def add_caseifer(text):
    new_text = ""
    for char in text:
        if char.isupper():
            new_text += "^" + char.lower()
        else:
            new_text += char
    return new_text

max_new_tokens = 88 # number of tokens generated in each sample
temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = None # retain only the top_k most likely tokens, clamp others to have 0 probability
device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16'
out_dir = 'Eml' # directory containing ckpt.pt and meta.pkl for the model


torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
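# map the configured dtype string to a torch dtype; ctx runs generation under autocast on GPU (no-op on CPU)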
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# init from a model saved in a specific directory
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
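# checkpoints saved from a torch.compile()'d model prefix every parameter name with '_orig_mod.'; strip it so the keys match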
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

model.eval()
model.to(device)

meta_path = os.path.join(out_dir, 'meta.pkl')
assert os.path.exists(meta_path), f"expected character-level vocab at {meta_path}"

with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
    
def load_model(model_name):
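    """Load an alternative checkpoint from out_dir (mirrors the start-up loading above)."""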
    ckpt_path = os.path.join(out_dir, model_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model

def get_model_list():
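    """Return the names of the .pt checkpoint files available in out_dir."""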
    models_dir = out_dir
    model_files = os.listdir(models_dir)
    model_files = [f for f in model_files if f.endswith('.pt')]
    return model_files

def gen(prompt):
    generated_text = ''
    start_ids = encode(add_caseifer(prompt))
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    with ctx:
        for idx_next in model.generate_streaming(x, max_new_tokens, temperature=temperature, top_k=top_k):
            # convert the sampled index back to a character and accumulate it
            char = decode([idx_next])
            generated_text += char

            # the first newline marks the end of the reply
            if char == '\n':
                return remove_caseifer(generated_text)

    # no newline within max_new_tokens: return whatever was generated
    return remove_caseifer(generated_text)

chat_history = []
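# Minimal Gradio chat UI: each turn re-encodes the full conversation as the prompt for gen().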
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def respond(message, chat_history):
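        # replay the prior turns as context, then append the new user message terminated by a newline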
        temp_str = "".join([f"{t[0]} {t[1]}" for t in chat_history])
        message = message+'\n'
        bot_message = gen(temp_str+message)
        chat_history.append((message, bot_message))
        time.sleep(1)
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()