import gradio as gr
import transformers
import torch
import yaml
from dearth_config import DearthConfig
from dearth_model import DearthForCausalLM

import random
import time
import threading
import asyncio

# Globals shared between the Gradio handler and the background memory-freeing thread.
tk = None                                  # tokenizer, loaded lazily
model_states = None                        # checkpoint state dict, loaded lazily
lock_using_model = threading.Lock()        # guards tk/model_states across threads
recent_generate_timestamp = time.time()    # time of the most recent generation request

MODEL_LIVE_TIME = 5 * 60  # keep the model in memory for 5 minutes after the last request


def load_model():
    """Load the tokenizer and the raw checkpoint state dict into the module-level globals."""
    global tk, model_states
    tk = transformers.AutoTokenizer.from_pretrained("./tk")
    model_path = "./ts100-re2-h1-4000-model.pt"
    states = torch.load(model_path, map_location="cpu")
    model_states = states

    # Checkpoints saved from wrapped models carry key prefixes; strip them so the plain
    # model can load the weights. The combined DDP+compile prefix is checked first,
    # because 'module._orig_mod.' also starts with 'module.'.
    unwanted_prefix_dueto_compile = '_orig_mod.'
    unwanted_prefix_dueto_ddp = 'module.'
    unwanted_prefix_dueto_ddp_compiled = 'module._orig_mod.'
    for k, v in list(model_states.items()):
        if k.startswith(unwanted_prefix_dueto_ddp_compiled):
            new_key = k[len(unwanted_prefix_dueto_ddp_compiled):]
            model_states[new_key] = model_states.pop(k)
        elif k.startswith(unwanted_prefix_dueto_ddp):
            new_key = k[len(unwanted_prefix_dueto_ddp):]
            model_states[new_key] = model_states.pop(k)
        elif k.startswith(unwanted_prefix_dueto_compile):
            new_key = k[len(unwanted_prefix_dueto_compile):]
            model_states[new_key] = model_states.pop(k)
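
# Note on the prefix stripping above: these prefixes are artifacts of how the checkpoint
# was presumably saved (an assumption; the training script is not shown here). Wrapping a
# model in DistributedDataParallel prepends 'module.' to every parameter name, and
# torch.compile prepends '_orig_mod.', so a checkpoint saved from a compiled DDP model
# might have been produced roughly like this (hypothetical training-side code):
#
#     model = torch.compile(DDP(model))
#     torch.save(model.state_dict(), "ts100-re2-h1-4000-model.pt")
#     # -> keys such as 'module._orig_mod.embed.weight'
#
# Stripping the prefixes lets the plain, unwrapped DearthForCausalLM load the same weights.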


def main_free_mem():
    # Run a private asyncio event loop in a background thread and use it as a timer
    # that periodically tries to release the model from memory.
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    event_loop.call_later(MODEL_LIVE_TIME, free_mem)
    event_loop.run_forever()


def free_mem():
    # Drop the tokenizer and weights if no generation happened for MODEL_LIVE_TIME,
    # then reschedule this check on the same event loop.
    global tk, model_states, recent_generate_timestamp, lock_using_model
    lock_using_model.acquire()
    if time.time() - recent_generate_timestamp >= MODEL_LIVE_TIME and tk is not None:
        tk = None
        model_states = None
        print(f"free mem, {time.time()}")
    lock_using_model.release()
    try:
        event_loop = asyncio.get_event_loop()
        event_loop.call_later(MODEL_LIVE_TIME, free_mem)
    except Exception:
        # If the event loop is no longer available, simply stop rescheduling.
        pass
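
# Design note: the background thread above uses a dedicated asyncio event loop purely as
# a timer (call_later + run_forever). A simpler alternative, sketched here as an
# assumption rather than a drop-in change, would be a self-rescheduling threading.Timer:
#
#     def schedule_free_mem():
#         free_mem()
#         threading.Timer(MODEL_LIVE_TIME, schedule_free_mem).start()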


def generate(input, num_more_tokens):
    global tk, model_states, model, recent_generate_timestamp, lock_using_model
    lock_using_model.acquire()
    time_start = time.time()

    # Reload the tokenizer/weights if they were freed, or if they have gone stale.
    if tk is None:
        load_model()
    elif time.time() - recent_generate_timestamp > MODEL_LIVE_TIME:
        tk = None
        model_states = None
        load_model()

    # Rebuild the model from its YAML config and the cached state dict.
    yml_path = "./ts100-re2-h1.yml"
    with open(yml_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)['model']
    if "vocab_size" not in config:
        config['vocab_size'] = tk.vocab_size
    config["attn_window_size"] = 500
    # print(config)
    config = DearthConfig(**config)
    model = DearthForCausalLM(config)
    model.load_state_dict(model_states)
    model.eval()
    recent_generate_timestamp = time.time()
    print(f"load model time: {time.time() - time_start}")

    # Encode the prompt, prepending the BOS token.
    time_start = time.time()
    num_more_tokens = int(num_more_tokens)
    # print(input)
    input = input.strip()
    input_ids = tk.encode(input)
    input_ids = [tk.bos_token_id] + input_ids
    input_ids = torch.tensor(input_ids, dtype=torch.long).view(1, -1)
    # print(input_ids)
    print(f"encode time: {time.time() - time_start}")

    # Autoregressive top-k (k=5) sampling: feed the growing sequence back into the model,
    # sample the next token from the softmax over the top-5 logits, and stop early on EOS.
    time_start = time.time()
    output_ids = input_ids.squeeze(0).tolist()
    for i in range(num_more_tokens):
        input = torch.tensor(output_ids, dtype=torch.long).view(1, -1)
        with torch.no_grad():
            output = model(input)[0]
        last_token_logits = output[0, -1, :]
        last_token_logits_topk = torch.topk(last_token_logits, k=5, dim=-1)
        probs = torch.softmax(last_token_logits_topk.values, dim=-1)
        new_token = torch.multinomial(probs, num_samples=1).item()
        new_token = last_token_logits_topk.indices[new_token].item()
        if new_token == tk.eos_token_id:
            break
        output_ids.append(new_token)

    # print(output_ids)
    # print(tk.decode(output_ids))
    output_ids = output_ids[1:]  # drop the BOS token before decoding
    print(f"inference time: {time.time() - time_start}\n")
    ret = tk.decode(output_ids)
    lock_using_model.release()
    return ret
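
# Minimal usage sketch (illustration only; assumes the tokenizer directory "./tk", the
# checkpoint, and the YAML config referenced above are present). generate() can be called
# directly, outside the Gradio UI:
#
#     text = generate("Once upon a time, there was a little girl", 16)
#     print(text)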

example_input = ["Once upon a time, there was a little girl",
                 "John and Sarah were playing together in their backyard when",
                 "It was a warm summer day when Billy and",
                 ]

ui_title = "Tinystories LM 11M"

Description = """
This is a small language model with 11M parameters, trained on the TinyStories dataset and distilled from a 28M-parameter teacher model.\n
It has been trained on 512M tokens, which is about 0.9 epoch of the TinyStories dataset.\n
The PPL (perplexity) on the validation set is 1.7; for comparison, the teacher model reaches a PPL of 0.9. Lower PPL means better performance.\n
"""

if __name__ == "__main__":
    load_model()

    # Background housekeeping thread; daemon=True so it does not keep the process alive
    # after the Gradio server shuts down.
    thread_free_mem = threading.Thread(target=main_free_mem, daemon=True)
    thread_free_mem.start()

    with gr.Blocks(
        title="Tinystories LM 11M",
        js="./random_input_example.js"
    ) as demo:
        with gr.Blocks(title="Description"):
            gr.HTML(f"<h1>{ui_title}</h1>")
            gr.Markdown(Description)
        with gr.Row():
            with gr.Column():
                inp = gr.Textbox(lines=5, label="Input Text",
                                 value=example_input[random.randint(0, len(example_input)-1)],
                                 elem_id="input_textbox")
                generate_max_slider = gr.Slider(8, 64, step=1.0, value=16, label="more tokens", info="")
                generate_button = gr.Button(value="Generate")
            with gr.Column():
                # Output box is display-only.
                out = gr.Textbox(lines=5, label="Output Text", value="", interactive=False)

        def generate_inside(input, num_more_tokens):
            return generate(input, num_more_tokens)

        # Wire the button to the generation handler.
        generate_button.click(generate_inside, inputs=[inp, generate_max_slider], outputs=out)

    demo.queue()
    demo.launch()