# TAPA / app.py - Hugging Face Space source (uploaded by xuxw98, commit c9eccf3 "Upload 2 files", 7.38 kB)
import sys
import time
import warnings
from pathlib import Path
from typing import Optional
import lightning as L
import torch
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from generate import generate
from lit_llama import Tokenizer
from lit_llama.adapter import LLaMA
from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
from scripts.prepare_alpaca import generate_prompt
# 配置hugface环境
from huggingface_hub import hf_hub_download
import gradio as gr
import os
import glob
import json
# Uncomment to make CUDA kernel launches synchronous, which gives accurate
# stack traces when debugging device-side errors:
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# "high" lets PyTorch use faster reduced-precision internal math (e.g. TF32)
# for float32 matrix multiplications.
_FLOAT32_MATMUL_PRECISION = "high"
torch.set_float32_matmul_precision(_FLOAT32_MATMUL_PRECISION)
# quantize: Optional[str] = "llm.int8",
def model_load(
adapter_path: Path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned_15k.pth"),
pretrained_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
quantize: Optional[str] = None,
):
fabric = L.Fabric(devices=1)
dtype = torch.bfloat16 if fabric.device.type == "cuda" and torch.cuda.is_bf16_supported() else torch.float32
with lazy_load(pretrained_path) as pretrained_checkpoint, lazy_load(adapter_path) as adapter_checkpoint:
name = llama_model_lookup(pretrained_checkpoint)
with EmptyInitOnDevice(
device=fabric.device, dtype=dtype, quantization_mode=quantize
):
model = LLaMA.from_name(name)
# 1. Load the pretrained weights
model.load_state_dict(pretrained_checkpoint, strict=False)
# 2. Load the fine-tuned adapter weights
model.load_state_dict(adapter_checkpoint, strict=False)
model.eval()
model = fabric.setup_module(model)
return model
def instruct_generate(
    img_path: str = " ",
    prompt: str = "What food do lamas eat?",
    input: str = "",
    max_new_tokens: int = 100,
    temperature: float = 0.8,
    top_k: int = 200,
) -> str:
    """Generate a response for an instruction and an optional input.

    This is the Gradio callback of the demo. It uses the module-level ``model``,
    ``tokenizer`` and ``input_value_2_real`` globals. It only works with
    checkpoints from the instruction-tuned LLaMA-Adapter model (see
    ``finetune_adapter.py``).

    Args:
        img_path: Path of the scene image shown in the UI. It is accepted so the
            signature matches the UI inputs, but it is not used for generation.
        prompt: The prompt/instruction (Alpaca style).
        input: Optional input (Alpaca style); in this demo, the scene's object
            list. If it equals an example's display text, it is replaced by that
            example's full input. Every ``"..."`` is then removed.
        max_new_tokens: The number of generation steps to take.
        temperature: A value controlling the randomness of the sampling
            process. Higher values result in more random samples.
        top_k: The number of top most probable tokens to consider in the
            sampling process.

    Returns:
        The decoded text following the ``### Response:`` marker, stripped.
        If the marker is missing, the whole decoded text is returned.
    """
    if input is None:
        input = ""
    # Map an example's displayed object list back to its full input text.
    input = input_value_2_real.get(input, input)
    input = input.replace("...", "")

    # Coerce to int in case the UI sliders hand over floats (e.g. 200.0);
    # these values are used as token counts / top-k sizes.
    max_new_tokens = int(max_new_tokens)
    if top_k is not None:
        top_k = int(top_k)

    sample = {"instruction": prompt, "input": input}
    full_prompt = generate_prompt(sample)
    encoded = tokenizer.encode(full_prompt, bos=True, eos=False, device=model.device)

    # NOTE(review): max_seq_length is set to the "Max length" slider value (the
    # same as max_new_tokens), so it does not account for the prompt length
    # (encoded.size(0)). Confirm this matches what `generate` expects.
    y = generate(
        model,
        idx=encoded,
        max_seq_length=max_new_tokens,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        eos_id=tokenizer.eos_id,
    )
    output = tokenizer.decode(y)

    # The decoded sequence presumably echoes the prompt, which contains the
    # "### Response:" marker. Keep the text between the first marker and any
    # second one. Fall back to the whole text instead of raising IndexError
    # when the marker is absent.
    parts = output.split("### Response:")
    output = (parts[1] if len(parts) > 1 else parts[0]).strip()
    print(output)
    return output
# Configure the checkpoint and data paths used by the demo.
# To fetch the checkpoints from the Hugging Face Hub instead of using local
# files, replace the assignments below with:
# pretrained_path = hf_hub_download(
#     repo_id="Gary3410/pretrain_lit_llama", filename="lit-llama.pth")
# tokenizer_path = hf_hub_download(
#     repo_id="Gary3410/pretrain_lit_llama", filename="tokenizer.model")
# adapter_path = hf_hub_download(
#     repo_id="Gary3410/pretrain_lit_llama", filename="lit-llama-adapter-finetuned_15k.pth")
adapter_path = "lit-llama-adapter-finetuned_15k.pth"
tokenizer_path = "tokenizer.model"
pretrained_path = "lit-llama.pth"
example_path = "example.json"

# If 1024 is not enough (presumably for available memory), lower it to 512.
# NOTE(review): max_seq_len and max_batch_size are not referenced elsewhere in
# this file; they look like leftovers. Confirm before relying on them.
max_seq_len = 1024
max_batch_size = 1

# Load the model and tokenizer once; instruct_generate reuses these globals.
model = model_load(adapter_path, pretrained_path)
tokenizer = Tokenizer(tokenizer_path)

# example.json maps scene ids to dicts that include "input", "input_display"
# and "instruction" keys. JSON text is UTF-8, so decode it explicitly instead
# of relying on the platform's default encoding.
with open(example_path, "r", encoding="utf-8") as f:
    example_dict = json.load(f)

# Lookup from the object-list text displayed for an example to the full input
# that is actually fed to the model.
input_value_2_real = {
    scene_dict["input_display"]: scene_dict["input"]
    for scene_dict in example_dict.values()
}
def create_instruct_demo():
    """Build the "Instruction-Following" tab of the demo.

    Lays out the scene image, object list, instruction and generation controls.
    It registers one example per ``caption_demo/*.png`` image, looked up in
    ``example_dict`` by the file name before the first dot. It also wires the
    Run button to ``instruct_generate``.

    Returns:
        The ``gr.Blocks`` containing the tab's components.
    """
    with gr.Blocks() as instruct_demo:
        with gr.Row():
            with gr.Column():
                scene_img = gr.Image(label='Scene', type='filepath', shape=(1024, 320),
                                     height=320, width=1024, interactive=False)
                object_list = gr.Textbox(
                    lines=5, label="Object List", placeholder="Please click one from the examples below", interactive=False)
                instruction = gr.Textbox(
                    lines=2, label="Instruction", placeholder="Please input the instruction. E.g.Please turn on the lamp")
                max_len = gr.Slider(minimum=256, maximum=1024,
                                    value=1024, label="Max length")
                with gr.Accordion(label='Advanced options', open=False):
                    temp = gr.Slider(minimum=0, maximum=1,
                                     value=0.8, label="Temperature")
                    top_k = gr.Slider(minimum=100, maximum=300,
                                      value=200, label="Top k")
                run_button = gr.Button("Run")
            with gr.Column():
                outputs = gr.Textbox(lines=20, label="Output")

        # Order must match instruct_generate's positional parameters:
        # (img_path, prompt, input, max_new_tokens, temperature, top_k).
        inputs = [scene_img, instruction, object_list, max_len, temp, top_k]

        # glob's result order depends on the file system; sort so the examples
        # are listed in a stable order. Each example fills only the first three
        # inputs (image, instruction, object list).
        examples = []
        for example_img in sorted(glob.glob("caption_demo/*.png")):
            scene_name = os.path.basename(example_img).split(".")[0]
            scene = example_dict[scene_name]
            examples.append([example_img, scene["instruction"], scene["input"]])

        gr.Examples(
            examples=examples,
            inputs=inputs,
            outputs=outputs,
            fn=instruct_generate,
            # Example outputs are cached only when running on HF Spaces.
            cache_examples=os.getenv('SYSTEM') == 'spaces'
        )
        run_button.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
    return instruct_demo
# NOTE(review): the line below references the upstream LLaMA-Adapter project
# this demo appears to be adapted from. It is not rendered in the UI.
# Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.

# Markdown header rendered at the top of the page.
description = """
# TaPA
The official demo for **Embodied Task Planning with Large Language Models**.
"""

with gr.Blocks(css='style.css') as demo:
    gr.Markdown(description)
    with gr.TabItem("Instruction-Following"):
        create_instruct_demo()

# Queue requests with a single worker (concurrency_count=1) and keep the
# queue API open, then start the server.
queued_demo = demo.queue(concurrency_count=1, api_open=True)
queued_demo.launch()