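# Gradio demo app: interactive text generation with a GPT model (Pythia-160M
# architecture) pretrained on the RedPajama dataset using Lightning Fabric.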
import torch
import torch.nn as nn
import gradio as gr
from tsai_gpt.tokenizer import Tokenizer
import lightning as L
from lightning.fabric.loggers import CSVLogger
from pathlib import Path
from tsai_gpt.utils import num_parameters, load_checkpoint, get_default_supported_precision
from tsai_gpt.model import GPT, Block, Config
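# Run configuration: model architecture, output/checkpoint directory, and
# logging cadence for the CSV logger.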
model_name = "pythia-160m"
name = "redpajama"
out_dir = Path("out") / name
log_interval = 100
precision = get_default_supported_precision(False)
logger = CSVLogger("out", name, flush_logs_every_n_steps=log_interval)
fabric = L.Fabric(devices=1, strategy="auto", precision=precision, loggers=logger)
config = Config.from_name(model_name)
def _init_weights(module: nn.Module) -> None:
    """Meant to be used with `gpt.apply(gpt._init_weights)`."""
    if isinstance(module, nn.Linear):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
with fabric.init_module(empty_init=True):
    model = GPT(config)
    model.apply(_init_weights)
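# Restore the pretrained weights; the random init above is immediately
# overwritten by the 25k-iteration checkpoint.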
checkpoint_path = Path("out/redpajama/iter-025000-ckpt.pth")
load_checkpoint(fabric, model, checkpoint_path)
# print(model.transformer.h[0].mlp.fc.weight)
# fabric.print(f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.")
# fabric.print(f"Total parameters {num_parameters(model):,}")
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
learning_rate = 6e-3
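# Capture every simple (int/float/str) local as a hyperparameter snapshot;
# this dict is stored in the checkpoint state below.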
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")}
model = fabric.setup(model)
optimizer = torch.optim.AdamW(
model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), foreach=False
)
# model_copy = model
optimizer = fabric.setup_optimizers(optimizer)
state = {"model": model, "optimizer": optimizer, "hparams": hparams, "iter_num": 0, "step_count": 0}
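# Resume from the newest checkpoint in out_dir (filenames look like
# iter-025000-ckpt.pth, so the iteration number is the second "-" field).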
resume = max(out_dir.glob("*.pth"), key=lambda p: int(p.name.split("-")[1]), default=None)
if resume:
    fabric.print(f"Loading model from {resume}")
    fabric.load(resume, state)
deviceType = 'cuda' if torch.cuda.is_available() else 'cpu'
m = model.to(deviceType)
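# The tokenizer is loaded from a Llama-2 checkpoint directory, presumably the
# same vocabulary the model was pretrained with.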
tokenizer_gpt = Tokenizer(checkpoint_dir=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf"))
def fn_query_on_load():
    return "Biofuels would disrupt"
def generate_output(prompt, max_new_tokens=200, temperature=0.8, top_k=50):
    m.eval()
    encoded_text = tokenizer_gpt.encode(prompt)
    # Add a batch dimension and move the token ids to the model's device
    reshaped_tensor = torch.unsqueeze(encoded_text, 0).to(deviceType)
    out_text = tokenizer_gpt.decode(
        m.generate(reshaped_tensor, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)[0]
    )
    m.train()
    return {
        output: out_text
    }
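# Build the UI: a prompt textbox (pre-filled via fn_query_on_load), Submit and
# Clear buttons, and a read-only output textbox.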
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
            # MiniGPT - GPT Training on LLaMA with the RedPajama dataset
            ### Enter a context to generate automated text
            """)
    with gr.Row(visible=True):
        search_text = gr.Textbox(value=fn_query_on_load, placeholder='Enter prompt..', label='Enter Prompt')
    with gr.Row():
        submit_btn = gr.Button("Submit", variant='primary')
        clear_btn = gr.ClearButton()
    with gr.Row():
        output = gr.Textbox(lines=15, interactive=False, label='Out Box')

    def clear_data():
        return {
            output: None,
            search_text: None
        }

    clear_btn.click(clear_data, None, [output, search_text])
    submit_btn.click(
        generate_output,
        search_text,
        output
    )
# Launch the app
app.queue().launch()