import os
import pickle
import time
from contextlib import nullcontext

import gradio as gr
import torch

from model import GPTConfig, GPT
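# The two "caseifer" helpers below let a character-level model work with a
# lowercase-only vocabulary: uppercase letters are escaped with a "^" marker
# on the way into the model and restored on the way out.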
def remove_caseifer(text):
    # Decode the caseifer scheme: "^x" becomes "X"; everything else is copied through.
    new_text = ""
    i = 0
    while i < len(text):
        if text[i] == "^":
            if i + 1 < len(text):
                new_text += text[i + 1].upper()
                i += 1  # consume the character after the marker as well
            # a trailing "^" with nothing after it is simply dropped
        else:
            new_text += text[i]
        i += 1
    return new_text
def add_caseifer(text):
    # Encode for the model: each uppercase letter becomes "^" + its lowercase form.
    new_text = ""
    for char in text:
        if char.isupper():
            new_text += "^" + char.lower()
        else:
            new_text += char
    return new_text
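# Round-trip example: add_caseifer("Hi There") -> "^hi ^there", and
# remove_caseifer("^hi ^there") -> "Hi There".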
max_new_tokens = 88  # number of tokens generated in each sample
temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random
top_k = None  # retain only the top_k most likely tokens; all others get zero probability
device = 'cpu'  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16'  # 'float32', 'bfloat16', or 'float16'
out_dir = 'Eml'  # directory holding ckpt.pt and meta.pkl
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
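# On CPU, inference runs without autocast; on CUDA it runs under mixed
# precision using the dtype selected above.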
# init from a model saved in a specific directory
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'  # prefix that torch.compile adds to state-dict keys
for k, v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.eval()
model.to(device)
meta_path = os.path.join(out_dir, 'meta.pkl')
assert os.path.exists(meta_path), f"meta.pkl not found in {out_dir}"
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
# TODO want to make this more general to arbitrary encoder/decoder schemes
stoi, itos = meta['stoi'], meta['itos']
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
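# e.g. decode(encode("^hello")) == "^hello"; characters absent from the
# training vocabulary raise a KeyError in stoi.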
def load_model(model_name):
    # Load an arbitrary checkpoint from out_dir; mirrors the startup loading above.
    ckpt_path = os.path.join(out_dir, model_name)
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model
def get_model_list():
    models_dir = out_dir
    model_files = os.listdir(models_dir)
    model_files = [f for f in model_files if f.endswith('.pt')]
    return model_files
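# Note: load_model and get_model_list are defined but never wired into the UI
# below. A hypothetical dropdown for switching checkpoints could look like:
#
#   model_picker = gr.Dropdown(choices=get_model_list(), value='ckpt.pt')
#   model_picker.change(load_model, model_picker, None)
#
# (sketch only; the global `model` would need to be rebound for it to take effect)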
def gen(prompt):
    # Encode the prompt, stream tokens from the model, and stop at the first newline.
    generated_text = ''
    start_ids = encode(add_caseifer(prompt))
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    with torch.no_grad(), ctx:
        for idx_next in model.generate_streaming(x, max_new_tokens, temperature=temperature, top_k=top_k):
            # convert the sampled index back to a character
            char = decode([idx_next])
            generated_text += char
            if char == '\n':
                break
    return prompt + remove_caseifer(generated_text)
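# Example: gen("dear john\n") returns the prompt followed by the generated
# completion, truncated at the first newline (or after max_new_tokens tokens).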
md = """This is some code:
hello
```py
def fn(x, y, z):
print(x, y, z)
"""
chat_history = []
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def respond(message, chat_history):
        # Generate a reply for the user message and append the (user, bot) pair.
        bot_message = gen(message)
        chat_history.append((message, bot_message))
        time.sleep(1)
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()
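# launch() starts the local web server; on Hugging Face Spaces this same
# call serves as the app's entry point.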