|
|
|
"""S22.ipynb |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1pq0UO46D0emoqF8rPuD4cUznmYVSMESO |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
torch.cuda.is_available() |
|
|
|
import glob |
|
import math |
|
import sys |
|
import time |
|
from pathlib import Path |
|
from typing import Optional, Tuple, Union |
|
|
|
import lightning as L |
|
from lightning.fabric.loggers import CSVLogger |
|
from lightning.fabric.strategies import FSDPStrategy |
|
|
|
from tsai_gpt.model import GPT, Block, Config |
|
from tsai_gpt.packed_dataset import CombinedDataset, PackedDataset |
|
from tsai_gpt.speed_monitor import SpeedMonitorBase, estimate_flops, measure_flops |
|
from tsai_gpt.speed_monitor import SpeedMonitorFabric as SpeedMonitor |
|
from tsai_gpt.utils import chunked_cross_entropy, get_default_supported_precision, num_parameters, load_checkpoint |
|
import os |
|
import pickle |
|
from contextlib import nullcontext |
|
from torch.utils.data import DataLoader |
|
import torch.nn.functional as F |
|
from tsai_gpt.tokenizer import Tokenizer |
|
import gradio as gr |
|
|
|
model_name = "pythia-160m" |
|
name = "redpajama" |
|
out_dir = Path("out") / name |
|
|
|
hparams = {k: v for k, v in locals().items() if isinstance(v, (int, float, str)) and not k.startswith("_")} |
|
log_interval = 100  # logging frequency; assumed value, not defined earlier in this notebook
logger = CSVLogger("out", name, flush_logs_every_n_steps=log_interval)
|
|
|
fabric = L.Fabric(devices=1, strategy='auto', precision=None, loggers=logger) |
|
|
|
checkpoint_path = Path("out/redpajama/iter-023999-ckpt.pth") |
|
config = Config.from_name(model_name) |
|
model = GPT(config) |
|
|
|
load_checkpoint(fabric, model, checkpoint_path) |
|
|
|
|
|
|
|
def generate(model, config, idx, max_new_tokens, temperature=1.0, top_k=None):
|
""" |
|
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete |
|
the sequence max_new_tokens times, feeding the predictions back into the model each time. |
|
Most likely you'll want to make sure to be in model.eval() mode of operation for this. |
|
|
|
""" |
|
    idx = idx.unsqueeze(dim=0)  # add a batch dimension: (t,) -> (1, t)
|
for _ in range(max_new_tokens): |
|
|
|
|
|
        # If the running sequence is longer than the model's context window,
        # crop it to the last block_size tokens.
        idx_cond = idx if idx.size(1) <= config.block_size else idx[:, -config.block_size:]

        # Forward the cropped context through the model to get next-token logits.
        logits = model(idx_cond)
|
|
|
        # Scale the logits at the final position by the temperature.
        logits = logits[:, -1, :] / temperature

        # Optionally restrict sampling to the top_k most likely tokens.
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float('Inf')

        # Convert logits to probabilities and sample the next token.
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)

        # Append the sampled token to the running sequence.
        idx = torch.cat((idx, idx_next), dim=1)
|
|
|
return idx |
|
|
|
|
|
|
|
checkpoint_dir = Path('./checkpoints/meta-llama/Llama-2-7b-chat-hf') |
|
token = Tokenizer(checkpoint_dir=checkpoint_dir)
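# A minimal sketch (left commented out so the script's behaviour is unchanged) of how
# generate() composes with the tokenizer: encode a prompt to token ids, sample
# autoregressively, then decode the batch of ids back to text. The prompt string and
# sampling settings below are illustrative, not taken from the original notebook.
#
# prompt_ids = token.encode("To be, or not to be")  # LongTensor of shape (t,)
# model.eval()
# with torch.no_grad():
#     out_ids = generate(model, config, prompt_ids, max_new_tokens=50, temperature=0.8, top_k=200)
# print(token.decode(out_ids[0]))  # out_ids has shape (1, t + 50)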
|
|
|
def tsaigpt(start: str, max_new_tokens: int = 300, model=model, tokeniser=token):

    max_new_tokens = int(max_new_tokens)  # the Gradio slider may deliver a float value
|
|
|
|
|
|
|
|
|
    temperature = 0.8   # sampling temperature: values < 1.0 sharpen the distribution, > 1.0 flatten it
    top_k = 200         # keep only the 200 most likely tokens at each sampling step
    seed = 1337
    device = 'cpu'      # inference device; change to 'cuda' if a GPU is available
    dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
    compile = False     # set True to compile the model with torch.compile (PyTorch 2.0+)
|
|
|
|
|
|
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
torch.backends.cudnn.allow_tf32 = True |
|
device_type = 'cuda' if 'cuda' in device else 'cpu' |
|
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] |
|
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) |
|
|
|
model.eval() |
|
model.to(device) |
|
if compile: |
|
model = torch.compile(model) |
|
|
|
|
|
|
|
start_ids = tokeniser.encode(start).to(device) |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
with ctx: |
|
|
|
            y = generate(model=model, config=config, idx=start_ids, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
|
|
|
output = tokeniser.decode(y[0]) |
|
return output |
|
|
|
INTERFACE = gr.Interface(
    fn=tsaigpt,
    inputs=[
        gr.Textbox(label="Prompt", value='All that glisters is not gold.'),
        gr.Slider(minimum=300, maximum=500, value=300, label="Maximum number of tokens to be generated"),
    ],
    outputs=gr.Text(label="Generated Text"),
    title="TSAI_GPT",
    description="TSAIGPT is a transformer-based language model with only 0.16 billion parameters, trained on RedPajama 1T Sample.",
    examples=[
        ['We know what we are, but know not what we may be', 300],
        ['Sweet are the uses of adversity which, like the toad, ugly and venomous, wears yet a precious jewel in his head', 300],
    ],
).launch(debug=True)
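# Note: launch(debug=True) blocks the process and serves the UI on a local port; passing
# share=True to launch() (a standard Gradio option) additionally creates a temporary
# public link, which is how the demo is typically exposed from a Colab runtime.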