# RamblingGPT / app.py

import os
import time
import pickle
from contextlib import nullcontext

import torch
import gradio as gr

from model import GPTConfig, GPT

def remove_caseifer(text):
    """Invert add_caseifer: "↨x" becomes "X"; other characters pass through."""
    new_text = ""
    i = 0
    while i < len(text):
        if text[i] == "↨":
            if i + 1 < len(text):
                new_text += text[i + 1].upper()
                i += 1  # also consume the character after the marker
            # a trailing marker with nothing after it is dropped
        else:
            new_text += text[i]
        i += 1
    return new_text
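
# Quick sanity check (illustrative, not part of the app): the "↨" marker
# upper-cases the character that follows it, e.g.
#   remove_caseifer("↨hello ↨world")  ->  "Hello World"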

def add_caseifer(text):
    """Prepare text for the char-level model: capitals become "↨" + lowercase,
    a few characters are substituted, and anything outside the vocab is dropped."""
    tokenlist = set("\n\" !$&'#,/+=-<>*@.:;[]{}()^_?0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzèé")
    replace_map = {  # characters substituted before encoding
        "{": "[",
        "(": "[",
        "}": "]",
        ")": "]",
        "&": "and",
    }
    upperlist = set("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
    new_text = ""
    for char in text:
        if char in tokenlist:
            if char in upperlist:
                new_text += "↨" + char.lower()
            elif char in replace_map:
                new_text += replace_map[char]
            else:
                new_text += char
    return new_text
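
# Companion example: add_caseifer("A & B") -> "↨a and ↨b". The replace_map
# substitutions ("&" -> "and", braces and parens -> brackets) make the round
# trip through remove_caseifer lossy for those characters.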

max_new_tokens = 88  # number of tokens generated in each sample
temperature = 0.8    # 1.0 = no change, < 1.0 = less random, > 1.0 = more random
top_k = None         # retain only the top_k most likely tokens; clamp others to zero probability
device = 'cpu'       # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16'   # 'float32', 'bfloat16', or 'float16'
out_dir = 'Eml'      # directory holding ckpt.pt and meta.pkl

torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True        # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# init the model from the checkpoint saved in out_dir
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'  # prefix added to keys by torch.compile
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval()
model.to(device)

# load the character-level vocab written at training time
meta_path = os.path.join(out_dir, 'meta.pkl')
with open(meta_path, 'rb') as f:  # required: the app has no fallback tokenizer
    meta = pickle.load(f)
# TODO: make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
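
# Sanity check for the codec (the actual ids depend on the vocab in meta.pkl):
#   decode(encode("hi")) -> "hi"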

def load_model(model_name):
    """Load a checkpoint from out_dir; mirrors the start-up load above."""
    ckpt_path = os.path.join(out_dir, model_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model
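
# Example (illustrative; 'ckpt.pt' is the only checkpoint this app assumes):
#   model = load_model('ckpt.pt')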

def get_model_list():
    """List the .pt checkpoint files available in out_dir."""
    return [f for f in os.listdir(out_dir) if f.endswith('.pt')]
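
# Hypothetical wiring (not used in the UI below): the list could populate a
# dropdown so users switch checkpoints at runtime, e.g.
#   gr.Dropdown(choices=get_model_list(), label="Checkpoint")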

def gen(prompt):
    """Stream one reply from the model, stopping at the first newline."""
    direction = 'You are Emeldar, an AI chatter,'
    generated_text = ''
    x = torch.tensor(encode(add_caseifer(prompt)), dtype=torch.long, device=device)[None, ...]
    y = torch.tensor(encode(add_caseifer(direction)), dtype=torch.long, device=device)[None, ...]
    with torch.no_grad(), ctx:  # inference only; ctx is the autocast context set up above
        for idx_next in model.generate_instructed_streaming(x, y, max_new_tokens, temperature=temperature, top_k=top_k):
            # convert the index back to a character
            char = decode([idx_next])
            generated_text += char
            if char == '\n':  # stop at the first newline
                break
    return remove_caseifer(generated_text)
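
# e.g. gen("Hello\n") returns the reply decoded back to mixed case, cut at the
# first newline (or after max_new_tokens characters if none appears).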

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        # flatten the previous turns into a single prompt string
        temp_str = "".join(f"{t[0]} {t[1]}" for t in chat_history)
        message = message + '\n'
        bot_message = gen(temp_str + message)
        chat_history.append((message, bot_message))
        time.sleep(1)
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()