# google-tpu / app.py
# Last commit: "Update app.py" by florentgbelidji (03f2a71, verified)
# File size: 3.47 kB; scan status: no virus
import gradio as gr
from transformers import BartTokenizer, BartForConditionalGeneration
import datetime
import os
import time
from typing import List
import torch
import torch_xla.core.xla_model as xm
from transformers import AutoTokenizer, StaticCache
from optimum.tpu.modeling import AutoModelForCausalLM
# Set the PJRT_DEVICE environment variable to "TPU" for this process.
# NOTE(review): presumably read by torch_xla to pick its runtime device — confirm
# that setting it after the torch_xla import above is early enough.
os.environ["PJRT_DEVICE"] = "TPU"
def sample_greedy(logits: torch.Tensor) -> torch.Tensor:
    """Pick the highest-scoring token id at the last sequence position.

    Args:
        logits: Scores indexed as ``logits[:, -1]`` and then reduced over the
            last axis, i.e. laid out as (batch, sequence, vocabulary).

    Returns:
        An ``int32`` tensor with one column holding the argmax index for each
        batch row.
    """
    # Only the final position matters when choosing the next token.
    last_step_logits = logits[:, -1]
    best_ids = torch.argmax(last_step_logits, dim=-1)
    # Re-add a length-1 column so the result can be fed back as the next input.
    return best_ids[:, None].int()
def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    """Run one decoding step and return the greedily chosen next token.

    Args:
        model: Callable model; invoked with ``return_dict=False`` so that the
            first element of its output is the logits tensor.
        cur_token: Token ids fed to the model for this step.
        input_pos: Position ids passed to the model as ``position_ids``.
        cache_position: Cache index passed through to the model.
        past_key_values: Key/value cache passed through to the model with
            ``use_cache=True``.

    Returns:
        The next token ids as produced by ``sample_greedy`` on the step's
        logits.
    """
    outputs = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        return_dict=False,
        use_cache=True,
        past_key_values=past_key_values,
    )
    # With return_dict=False the logits are the first element of the output.
    logits = outputs[0]
    return sample_greedy(logits)
def conditional_compile(func):
    """Optionally wrap ``func`` with ``torch.compile`` using the openxla backend.

    Compilation is enabled by the mere presence of the ``DBG_COMPILE``
    environment variable; its value (even an empty string) is not inspected.

    Args:
        func: The callable to (maybe) compile.

    Returns:
        ``torch.compile(func, backend="openxla")`` when ``DBG_COMPILE`` is
        set, otherwise ``func`` itself, unchanged.
    """
    if "DBG_COMPILE" not in os.environ:
        return func
    return torch.compile(func, backend="openxla")
# Checkpoint to load and the dtype passed to the loader.
model_id = "google/gemma-2b"
torch_dtype = torch.bfloat16

# Load the model and put it in eval mode; `device` is taken from the loaded
# model so that inputs can later be moved to the same place.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype).eval()
device = model.device

# Tokenizer loaded from the same checkpoint id as the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)
def summarize(inp: str, model=model, tokenizer=tokenizer, device=device) -> str:
    """Greedily continue ``inp`` with the model and return the decoded text.

    The prompt is tokenized and run through the model once to prefill a
    ``StaticCache``; the sequence is then extended one token at a time with
    greedy (argmax) sampling.  There is no end-of-sequence check: the prefill
    yields one token and the decode loop always yields ``max_new_tokens`` more.

    Args:
        inp: Prompt text.  Newline characters are removed before tokenizing.
        model: Causal language model; defaults to the module-level ``model``.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        device: Device on which the input and output tensors are placed.

    Returns:
        The decoded sequences (prompt tokens followed by generated tokens),
        joined with single spaces.  Special tokens are not stripped.

    Raises:
        ValueError: If the tokenized prompt is too long for the generated
            tokens to fit in the static cache.
    """
    max_cache_length = 1024
    max_new_tokens = 64

    with torch.no_grad():
        # NOTE(review): newlines are removed rather than replaced by spaces, so
        # words on adjacent lines get glued together — confirm this is intended.
        inp = inp.replace('\n','')
        inputs = tokenizer(inp, return_tensors="pt", padding=True).to(device)
        batch_size, sequence_length = inputs["input_ids"].shape

        # The decode loop feeds cache positions sequence_length up to
        # sequence_length + max_new_tokens - 1.  Reject prompts that would push
        # those past the end of the fixed-size cache instead of failing (or
        # misbehaving) deep inside the cache update.
        if sequence_length + max_new_tokens > max_cache_length:
            raise ValueError(
                f"Input is too long: {sequence_length} tokens, but at most "
                f"{max_cache_length - max_new_tokens} tokens fit in the cache "
                f"alongside {max_new_tokens} generated tokens."
            )

        # setup static cache
        past_key_values = StaticCache(
            config=model.config,
            max_batch_size=batch_size,
            max_cache_len=max_cache_length,
            device=model.device,
            dtype=model.dtype,
        )
        cache_position = torch.arange(sequence_length, device=device)

        # Output buffer: prompt, one token from the prefill, and
        # max_new_tokens from the decode loop.
        generated_ids = torch.zeros(
            (batch_size, sequence_length + max_new_tokens + 1),
            dtype=torch.int,
            device=device,
        )
        generated_ids[:, cache_position] = inputs["input_ids"].to(torch.int)

        # prefill here
        attention_mask = inputs["attention_mask"]
        # Position of each real token counted over the mask; masked slots get 0.
        pos_ids = (attention_mask.cumsum(-1) - 1).masked_fill(attention_mask == 0, 0)
        logits = model(
            **inputs,
            cache_position=cache_position,
            return_dict=False,
            use_cache=True,
            position_ids=pos_ids,
            past_key_values=past_key_values,
        )[0]
        next_token = sample_greedy(logits)
        xm.mark_step()
        generated_ids[:, sequence_length] = next_token[:, 0]

        # Continue from one past the largest prompt position in each row.
        pos_ids = pos_ids.max(axis=-1)[0].unsqueeze(1) + 1

        # Rebinds only the local name; the module-level model is untouched.
        model = conditional_compile(model)
        cache_position = torch.tensor([sequence_length], device=device)

        for _ in range(max_new_tokens):
            next_token = decode_one_tokens(
                model, next_token.clone(), pos_ids, cache_position, past_key_values
            )
            # Advance first: the new token is stored one slot after the
            # position that was just fed to the model.
            cache_position += 1
            generated_ids[:, cache_position] = next_token
            pos_ids += 1
            xm.mark_step()

        decoded_texts = tokenizer.batch_decode(generated_ids)
        response = " ".join(decoded_texts)
    return response
# Build a one-textbox UI around `summarize` and launch it.
gr.Interface(
    fn=summarize,
    inputs=gr.Textbox(lines=7, label="Input Text"),
    outputs="text",
    title="gemma-2b simple TPU demo",
).launch(inline=False)