File size: 4,829 Bytes
9e03479 e90d33c 4c6dd99 e90b9d0 359fdc3 b503ddc 9e03479 e90d33c d9a32ba e90d33c 1993817 e90d33c 1993817 e90d33c 9e03479 5b13017 e90d33c 630a257 13498f3 fdbe694 c9f8963 e90d33c 645214e 7579d43 01a82e4 b731873 d9a32ba 64f85a2 c836a45 64f85a2 6bdea46 64f85a2 7579d43 64f85a2 8c1d44d 32008c0 d617597 420c094 a4cf9cb 562d3e8 daa38c3 411fde6 f71a80b 562d3e8 a4cf9cb aa188fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import gradio as gr
from model import GPTConfig, GPT
import torch
from contextlib import nullcontext
import os
import time
import pickle
def remove_caseifer(text):
    """Undo ``add_caseifer``: turn every ``"↨x"`` pair back into ``"X"``.

    The "↨" marker means "upper-case the next character". The marker itself
    is never emitted. A trailing "↨" with nothing after it is dropped.

    Args:
        text: case-folded text, e.g. decoded model output.

    Returns:
        The text with capitalization restored.
    """
    out = []
    chars = iter(text)
    for ch in chars:
        if ch == "↨":
            # Consume the following character (if any) and emit it upper-cased.
            nxt = next(chars, None)
            if nxt is not None:
                out.append(nxt.upper())
        else:
            out.append(ch)
    return "".join(out)
def add_caseifer(text):
    """Normalize text into the model's case-folded character vocabulary.

    - Upper-case ASCII letters become ``"↨"`` followed by the lower-case letter.
    - ``{`` and ``(`` become ``[``; ``}`` and ``)`` become ``]``.
    - Any character not in ``tokenlist`` is dropped. This keeps ``encode``'s
      per-character vocabulary lookup from failing on unknown characters.

    Args:
        text: arbitrary user text.

    Returns:
        The normalized text.
    """
    tokenlist = set("\n\" !$&'#,/+=-<>*@.:;[]{}()^_?0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzèé")
    upperlist = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    replace_map = {  # Define a mapping of characters to be replaced
        "{": "[",
        "(": "[",
        "}": "]",
        ")": "]",
    }
    out = []
    for char in text:
        if char not in tokenlist:
            continue  # out-of-vocabulary character: drop it
        if char in upperlist:
            out.append("↨" + char.lower())
        else:
            out.append(replace_map.get(char, char))
    return "".join(out)
# --- Sampling / runtime configuration ---
max_new_tokens = 88  # maximum number of new tokens requested per generation call
temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = None  # retain only the top_k most likely tokens, clamp others to have 0 probability
device = 'cpu'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16'  # 'float32' or 'bfloat16' or 'float16'; only used for autocast on non-CPU devices
out_dir = 'Eml'  # directory holding the checkpoint file(s) and meta.pkl
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed-precision autocast context on GPU; a no-op context on CPU.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# init from a model saved in a specific directory
# NOTE(review): the original joined out_dir with an empty file name, which handed
# the directory itself to torch.load. 'ckpt.pt' is nanoGPT's default checkpoint
# name; confirm it matches the file actually shipped in out_dir.
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
# Keys saved from a torch.compile()-wrapped model typically carry an '_orig_mod.'
# prefix; strip it so they match the plain GPT module's parameter names.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
# Without loading the weights the model would generate from its random init.
model.load_state_dict(state_dict)
model.eval()
model.to(device)

# Character-level vocabulary (stoi / itos tables) stored next to the checkpoint.
meta_path = os.path.join(out_dir, 'meta.pkl')
load_meta = os.path.exists(meta_path)
if not load_meta:
    raise FileNotFoundError(f"vocabulary file not found: {meta_path}")
# NOTE: pickle.load can execute arbitrary code; only load meta.pkl from a trusted source.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']


def encode(s):
    """Map each character of ``s`` to its vocabulary id (KeyError if unknown)."""
    return [stoi[c] for c in s]


def decode(l):
    """Map a sequence of vocabulary ids back to a string."""
    return ''.join([itos[i] for i in l])
def load_model(model_name):
    """Build a GPT from the checkpoint file ``model_name`` inside ``out_dir``.

    Strips the ``'_orig_mod.'`` key prefix, loads the weights into the model,
    switches it to eval mode and moves it to ``device``.

    Args:
        model_name: checkpoint file name relative to ``out_dir``.

    Returns:
        The loaded ``GPT`` model.
    """
    ckpt_path = os.path.join(out_dir, model_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    model = GPT(GPTConfig(**checkpoint['model_args']))
    state_dict = checkpoint['model']
    # Keys from a torch.compile()-wrapped model typically carry this prefix.
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    # Previously the cleaned weights were never applied, so callers got a
    # randomly initialized model.
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model
def get_model_list(models_dir=None):
    """Return the names of the ``.pt`` checkpoint files in a directory.

    Args:
        models_dir: directory to scan; defaults to the module-level ``out_dir``.

    Returns:
        Bare file names (not full paths), sorted so the order is stable across
        runs. Directories whose names end in ``.pt`` are skipped.
    """
    if models_dir is None:
        models_dir = out_dir
    return sorted(
        name
        for name in os.listdir(models_dir)
        if name.endswith('.pt') and os.path.isfile(os.path.join(models_dir, name))
    )
def gen(input):
    """Generate one reply line for the prompt ``input``.

    Both the prompt and a fixed persona string are case-folded with
    ``add_caseifer`` and encoded. Each gets a leading batch dimension of 1
    before being passed to ``model.generate_instructed_streaming``. Tokens are
    decoded one at a time until a newline is produced or ``max_new_tokens`` is
    exhausted. The result is passed through ``remove_caseifer``.

    Args:
        input: conversation text to continue.

    Returns:
        The generated text, including the terminating newline if one was
        produced. It may be shorter (or empty) if generation stopped first.
    """
    direction = 'You are Emeldar, an AI chatter,'
    x = torch.tensor(encode(add_caseifer(input)), dtype=torch.long, device=device)[None, ...]
    # NOTE(review): y is passed as the model's "instruction" input; its exact
    # semantics are defined by generate_instructed_streaming (not visible here).
    y = torch.tensor(encode(add_caseifer(direction)), dtype=torch.long, device=device)[None, ...]
    pieces = []
    # Inference only: no autograd bookkeeping; ctx applies autocast on GPU.
    with torch.no_grad(), ctx:
        for idx_next in model.generate_instructed_streaming(
            x, y, max_new_tokens, temperature=temperature, top_k=top_k
        ):
            # convert the index to a character
            char = decode([idx_next])
            pieces.append(char)
            if char == '\n':  # a newline ends the reply
                break
    # Previously, if no newline appeared, the function failed (unbound result).
    return remove_caseifer(''.join(pieces))
chat_history = []  # NOTE(review): unused; respond() receives its history from the Chatbot component.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        """Append one (user, bot) exchange to the chat and clear the textbox.

        The whole prior conversation is flattened into the prompt as
        ``"<user> <bot>"`` pairs, followed by the new newline-terminated message.

        Returns:
            A pair: an empty string (clears the textbox) and the updated history.
        """
        chat_history = chat_history or []  # tolerate an empty/None Chatbot value
        temp_str = "".join(f"{t[0]} {t[1]}" for t in chat_history)
        message = message + '\n'
        bot_message = gen(temp_str + message)
        chat_history.append((message, bot_message))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    # The original line fused this call into msg.submit, which was a syntax error.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()