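# Gradio demo for weblab-GENIAC/Tanuki-8B-vision, a LLaVA-JP-based
# Japanese vision-language model loaded with 8-bit quantization.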
import gradio as gr
import torch
import transformers
from transformers import BitsAndBytesConfig
from llavajp.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llavajp.conversation import conv_templates
from llavajp.model.llava_llama import LlavaLlamaForCausalLM
from llavajp.train.dataset import tokenizer_image_token
import spaces
# import subprocess
# import sys
# def install_package():
# subprocess.check_call([sys.executable, "-m", "pip", "install", "https://github.com/hibikaze-git/LLaVA-JP@feature/tanuki-moe"])
model_path = "weblab-GENIAC/Tanuki-8B-vision"
# load model
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
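
# Quantize the language model to 8-bit with bitsandbytes; mm_projector and
# vision_tower are excluded from int8 so they keep their original precision.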
bnb_model_from_pretrained_args = {}
bnb_model_from_pretrained_args.update(dict(
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_skip_modules=["mm_projector", "vision_tower"],
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )
))
model = LlavaLlamaForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    **bnb_model_from_pretrained_args
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_path,
    model_max_length=8192,
    padding_side="right",
    use_fast=False,
)
model.eval()
conv_mode = "v1"
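

# On ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of each call
# (here capped at 120 seconds); torch.inference_mode() disables gradient tracking.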
@spaces.GPU(duration=120)
@torch.inference_mode()
def inference_fn(
    image,
    prompt,
    max_len,
    temperature,
    top_p,
    no_repeat_ngram_size
):
    # prepare inputs
    # image pre-process
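    # If the vision tower exposes multiple scales, the processor is asked for a
    # proportionally larger image (base height multiplied by the number of scales).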
    image_size = model.get_model().vision_tower.image_processor.size["height"]
    if model.get_model().vision_tower.scales is not None:
        image_size = image_size * len(model.get_model().vision_tower.scales)
    # preprocess the image and cast the pixel values to the model dtype
    image_tensor = (
        model.get_model()
        .vision_tower.image_processor(
            image,
            return_tensors="pt",
            size={"height": image_size, "width": image_size},
        )["pixel_values"]
        .to(torch_dtype)
    )
    if device == "cuda":
        image_tensor = image_tensor.cuda()
    # create prompt
    inp = DEFAULT_IMAGE_TOKEN + "\n" + prompt
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], inp)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0)
    if device == "cuda":
        input_ids = input_ids.to(device)
    input_ids = input_ids[:, :-1]  # </sep> ends up at the end of the input, so remove it
    # generate
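    # temperature == 0.0 selects greedy decoding (do_sample=False); otherwise
    # nucleus sampling with the given temperature and top_p is used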
    output_ids = model.generate(
        inputs=input_ids,
        images=image_tensor,
        do_sample=temperature != 0.0,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_len,
        repetition_penalty=1.0,
        use_cache=False,
        no_repeat_ngram_size=no_repeat_ngram_size
    )
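    # drop the image-token placeholder ids (IMAGE_TOKEN_INDEX) before decoding,
    # since they are not real vocabulary entries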
    output_ids = [
        token_id for token_id in output_ids.tolist()[0] if token_id != IMAGE_TOKEN_INDEX
    ]
    print(output_ids)
    output = tokenizer.decode(output_ids, skip_special_tokens=True)
    print(output)
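    # the assistant turn follows the "システム: " ("System: ") role marker in the
    # decoded conversation, so keep only the text after it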
    target = "システム: "
    idx = output.find(target)
    output = output[idx + len(target) :]
    return output
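

# Gradio UI: image and prompt inputs plus generation settings on the left,
# the model's answer on the right.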
with gr.Blocks() as demo:
    gr.Markdown("# Tanuki-8B-vision Demo")
    with gr.Row():
        with gr.Column():
            # input_instruction = gr.TextArea(label="instruction", value=DEFAULT_INSTRUCTION)
            input_image = gr.Image(type="pil", label="image")
            prompt = gr.Textbox(label="prompt (optional)", value="")
            with gr.Accordion(label="Configs", open=False):
                max_len = gr.Slider(
                    minimum=10,
                    maximum=256,
                    value=200,
                    step=5,
                    interactive=True,
                    label="Max New Tokens",
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.0,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Top p",
                )
                no_repeat_ngram_size = gr.Slider(
                    minimum=0,
                    maximum=4,
                    value=3,
                    step=1,
                    interactive=True,
                    label="No Repeat Ngram Size (values of 1 or 2 break the output)",
                )
            # button
            input_button = gr.Button(value="Submit")
        with gr.Column():
            output = gr.Textbox(label="Output")
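    # run inference on button click or when the prompt textbox is submitted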
    inputs = [input_image, prompt, max_len, temperature, top_p, no_repeat_ngram_size]
    input_button.click(inference_fn, inputs=inputs, outputs=[output])
    prompt.submit(inference_fn, inputs=inputs, outputs=[output])
    img2txt_examples = gr.Examples(
        examples=[
            [
                "https://raw.githubusercontent.com/hibikaze-git/LLaVA-JP/feature/package/imgs/sample1.jpg",
                "猫の隣には何がありますか?",  # "What is next to the cat?"
                128,
                0.0,
                1.0,
                3
            ],
        ],
        inputs=inputs,
    )

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0")