import gradio as gr
import transformers
import torch
import yaml
from dearth_config import DearthConfig
from dearth_model import DearthForCausalLM

import random
import time
import threading
import asyncio


tk = None
model_states = None
lock_using_model = threading.Lock()
recent_generate_timestamp = time.time()
MODEL_LIVE_TIME = 5 * 60  # 5 minutes


def load_model():
    """Load the tokenizer and the raw model state dict from disk."""
    global tk, model_states
    tk = transformers.AutoTokenizer.from_pretrained("./tk")

    model_path = "./ts100-re2-h1-4000-model.pt"
    states = torch.load(model_path, map_location="cpu")
    model_states = states

    # Strip key prefixes added by torch.compile and/or DistributedDataParallel.
    unwanted_prefix_dueto_compile = '_orig_mod.'
    unwanted_prefix_dueto_ddp = 'module.'
    unwanted_prefix_dueto_ddp_compiled = 'module._orig_mod.'
    for k, v in list(model_states.items()):
        if k.startswith(unwanted_prefix_dueto_ddp_compiled):
            new_key = k[len(unwanted_prefix_dueto_ddp_compiled):]
            model_states[new_key] = model_states.pop(k)
        elif k.startswith(unwanted_prefix_dueto_ddp):
            new_key = k[len(unwanted_prefix_dueto_ddp):]
            model_states[new_key] = model_states.pop(k)
        elif k.startswith(unwanted_prefix_dueto_compile):
            new_key = k[len(unwanted_prefix_dueto_compile):]
            model_states[new_key] = model_states.pop(k)


def main_free_mem():
    """Run an event loop on a background thread that periodically frees the model."""
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    event_loop.call_later(MODEL_LIVE_TIME, free_mem)
    event_loop.run_forever()


def free_mem():
    """Drop the tokenizer and weights after MODEL_LIVE_TIME of inactivity, then reschedule."""
    global tk, model_states, recent_generate_timestamp, lock_using_model
    lock_using_model.acquire()
    if time.time() - recent_generate_timestamp >= MODEL_LIVE_TIME and tk is not None:
        tk = None
        model_states = None
        print(f"free mem, {time.time()}")
    lock_using_model.release()
    try:
        event_loop = asyncio.get_event_loop()
        event_loop.call_later(MODEL_LIVE_TIME, free_mem)
    except RuntimeError:
        # No event loop available on this thread; stop rescheduling.
        pass


def generate(input, num_more_tokens):
    global tk, model_states, model, recent_generate_timestamp, lock_using_model
    lock_using_model.acquire()
    time_start = time.time()

    # (Re)load the tokenizer and weights if they were freed or have gone stale.
    if tk is None:
        load_model()
    elif time.time() - recent_generate_timestamp > MODEL_LIVE_TIME:
        tk = None
        model_states = None
        load_model()

    # Build the model from its YAML config and the loaded state dict.
    yml_path = "./ts100-re2-h1.yml"
    with open(yml_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)['model']
    if "vocab_size" not in config:
        config['vocab_size'] = tk.vocab_size
    config["attn_window_size"] = 500
    config = DearthConfig(**config)
    model = DearthForCausalLM(config)
    model.load_state_dict(model_states)
    model.eval()
    recent_generate_timestamp = time.time()
    print(f"load model time: {time.time() - time_start}")

    time_start = time.time()
    num_more_tokens = int(num_more_tokens)
    input = input.strip()
    input_ids = tk.encode(input)
    input_ids = [tk.bos_token_id] + input_ids
    input_ids = torch.tensor(input_ids, dtype=torch.long).view(1, -1)
    print(f"encode time: {time.time() - time_start}")

    # Autoregressive generation with top-k (k=5) sampling, one token per step.
    time_start = time.time()
    output_ids = input_ids.squeeze(0).tolist()
    for i in range(num_more_tokens):
        input = torch.tensor(output_ids, dtype=torch.long).view(1, -1)
        with torch.no_grad():
            output = model(input)[0]
        last_token_logits = output[0, -1, :]
        last_token_logits_topk = torch.topk(last_token_logits, k=5, dim=-1)
        probs = torch.softmax(last_token_logits_topk.values, dim=-1)
        new_token = torch.multinomial(probs, num_samples=1).item()
        new_token = last_token_logits_topk.indices[new_token].item()
        if new_token == tk.eos_token_id:
            break
        output_ids.append(new_token)

    output_ids = output_ids[1:]  # drop the BOS token before decoding
    print(f"inference time: {time.time() - time_start}\n")
    ret = tk.decode(output_ids)
    lock_using_model.release()
    return ret


example_input = [
    "Once upon a time, there was a little girl",
    "John and Sarah were playing together in their backyard when",
    "It was a warm summer day when Billy and",
]

ui_title = "Tinystories LM 11M"

Description = """
This is a small language model with 11M parameters, trained on the TinyStories dataset and distilled from a 28M-parameter teacher model.\n
It has been trained on 512M tokens, which is roughly 0.9 epochs of the TinyStories dataset.\n
Its perplexity (PPL) on the validation set is 1.7; for comparison, the teacher model reaches a PPL of 0.9. Lower PPL means better performance.\n
"""


if __name__ == "__main__":
    load_model()
    thread_free_mem = threading.Thread(target=main_free_mem)
    thread_free_mem.start()

    with gr.Blocks(
        title="Tinystories LM 11M",
        js="./random_input_example.js"
    ) as demo:
        with gr.Blocks(title="Description"):
            gr.HTML(f"<h1>{ui_title}</h1>")
            gr.Markdown(Description)
        with gr.Row():
            with gr.Column():
                inp = gr.Textbox(lines=5, label="Input Text",
                                 value=example_input[random.randint(0, len(example_input) - 1)],
                                 elem_id="input_textbox")
                generate_max_slider = gr.Slider(8, 64, step=1.0, value=16,
                                                label="more tokens", info="")
                generate_button = gr.Button(value="Generate")
            with gr.Column():
                out = gr.Textbox(lines=5, label="Output Text", value="")
                out.readonly = True

        def generate_inside(input, num_more_tokens):
            return generate(input, num_more_tokens)
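        # The listing above never wires the Generate button to a handler, so clicking it
        # would do nothing. A minimal hook-up sketch, assuming generate_inside is meant to
        # fill the output textbox from the input textbox and the token slider
        # (gr.Button.click with fn/inputs/outputs is standard Gradio API):
        generate_button.click(fn=generate_inside,
                              inputs=[inp, generate_max_slider],
                              outputs=out)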
    demo.queue()
    demo.launch()