Tar-7B / app.py
Jiaming Han
update
ab067b5
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
import os
import gradio as gr
from torchvision.transforms.functional import to_tensor
from huggingface_hub import hf_hub_download, snapshot_download, login
import spaces
from tok.ar_dtok.ar_model import ARModel
from t2i_inference import T2IConfig, TextToImageInference
def generate_text(self, image, prompt: str) -> str:
    """Answer ``prompt`` about ``image`` with the model held by ``self``.

    The image is encoded by ``self.visual_tokenizer.encoder``; the resulting
    codes are written as ``<I{code}>`` tokens in front of the prompt inside a
    chat-formatted message, and the reply is sampled from ``self.model``.

    Args:
        self: Inference object exposing ``device``, ``config.dtype``,
            ``visual_tokenizer``, ``tokenizer`` and ``model``.
        image: Image object providing ``convert('RGB')`` (the UI supplies a
            PIL image) -- not a file path.
        prompt: Instruction or question about the image.

    Returns:
        The decoded reply with the prompt tokens and special tokens removed.

    Raises:
        ValueError: If ``image`` is ``None``, e.g. when the image input has
            been cleared before the request is submitted.
    """
    if image is None:
        # Fail with an actionable message instead of an AttributeError on
        # ``None.convert``.
        raise ValueError("An input image is required for image understanding.")

    # Function-scope import: this file does not import torch at module level.
    import torch

    # PIL image -> float tensor with a leading batch dimension of 1.
    pixels = to_tensor(image.convert('RGB')).unsqueeze(0).to(self.device)

    # Pure inference: skip autograd bookkeeping for the encoder forward pass.
    with torch.no_grad():
        encoded = self.visual_tokenizer.encoder(pixels.to(self.config.dtype))
    # NOTE(review): presumably 'bottleneck_rep' holds integer code indices
    # with a leading batch dimension -- confirm against the tokenizer.
    image_code = encoded['bottleneck_rep']
    image_text = "".join(f"<I{x}>" for x in image_code[0].cpu().tolist())

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{image_text}\n{prompt}"},
    ]
    input_text = self.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = self.tokenizer(input_text, return_tensors="pt")
    gen_ids = self.model.generate(
        inputs.input_ids.to(self.device),
        max_new_tokens=512,
        do_sample=True,
    )

    # Decode only what follows the prompt: the output is sliced past the
    # prompt length before decoding.
    prompt_len = inputs.input_ids.shape[1]
    return self.tokenizer.batch_decode(
        gen_ids[:, prompt_len:], skip_special_tokens=True
    )[0]
# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login(token=None) falls back to an interactive prompt, which cannot
# be answered when the app runs headless.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)

# Fetch all model assets up front and collect their local paths in the config.
config = T2IConfig()
config.model = snapshot_download("ByteDance-Seed/Tar-7B")
# One checkpoint per output resolution; the keys match the choices of the
# "Resolution" radio button in the UI.
config.ar_path = {
    "1024px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ar_dtok_lp_512px.pth"),
}
config.encoder_path = hf_hub_download("ByteDance-Seed/Tar-TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")

# Single inference object shared by every request handler in this file.
inference = TextToImageInference(config)
@spaces.GPU(duration=120)
def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
    """Gradio handler: render an image for ``prompt``.

    All arguments are forwarded positionally, in order, to
    ``inference.generate_image``; its result is returned unchanged.
    """
    return inference.generate_image(
        prompt,
        resolution,
        top_p,
        top_k,
        cfg_scale,
    )
def clear_inputs_t2i():
    """Return reset values for the image-generation tab.

    Returns:
        A ``(prompt, image)`` pair: an empty prompt string and ``None`` for
        the output image.
    """
    blank_prompt = ""
    no_image = None
    return blank_prompt, no_image
@spaces.GPU(duration=120)
def understand_image(image, prompt):
    """Gradio handler: answer ``prompt`` about the uploaded ``image``.

    Delegates to ``generate_text`` using the module-level ``inference``
    object and returns the generated reply.
    """
    answer = generate_text(inference, image, prompt)
    return answer
def clear_inputs_i2t():
    """Return reset values for the image-understanding tab.

    Returns:
        An ``(image, instruction, response)`` triple: ``None`` for the image
        and empty strings for both text boxes.
    """
    no_image = None
    blank_instruction = ""
    blank_response = ""
    return no_image, blank_instruction, blank_response
# Two-tab Gradio UI: text-to-image generation and image understanding.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header. The emoji and bullet separators in the link row are plain
    # UTF-8 characters (previously mis-decoded into Greek-codepage mojibake).
    gr.Markdown(
        """
<div align="center">
### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations
[🕸️ Project Page](http://tar.csuhan.com) • [📄 Paper](http://arxiv.org/abs/2506.18898) • [💻 Code](https://github.com/csuhan/Tar) • [📦 Model](https://huggingface.co/collections/ByteDance-Seed/tar-6864cf0d9fe59a3b91cc4260)
</div>
""",
        elem_id="title",
    )

    # --- Tab 1: prompt -> image -------------------------------------------
    with gr.Tab("Image Generation"):
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt", value="A photo of a macaw")
                with gr.Accordion("Advanced Settings", open=False):
                    # Choices must match the keys of config.ar_path.
                    resolution = gr.Radio(
                        ["512px", "1024px"], value="1024px", label="Resolution"
                    )
                    top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                    top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                    cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
                    clear_btn = gr.Button("Clear")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Generated Image")

        # Input order must match the generate_image signature.
        generate_btn.click(
            generate_image,
            inputs=[prompt, resolution, top_p, top_k, cfg_scale],
            outputs=output_image
        )
        # Output order must match the tuple returned by clear_inputs_t2i.
        clear_btn.click(
            clear_inputs_t2i,
            outputs=[prompt, output_image]
        )

    # --- Tab 2: image + instruction -> text -------------------------------
    with gr.Tab("Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                # type="pil": the handler receives a PIL image, which
                # generate_text converts with .convert('RGB').
                image_input = gr.Image(label="Upload Image", type="pil", value="https://raw.githubusercontent.com/csuhan/Tar/refs/heads/main/asset/dog_cat.jpg")
                question_input = gr.Textbox(label="Instruction", value="Describe the image shortly.")
                with gr.Row():
                    qa_btn = gr.Button("Generate")
                    clear_btn_i2t = gr.Button("Clear")
            with gr.Column(scale=1):
                answer_output = gr.Textbox(label="Response", lines=4)

        qa_btn.click(
            understand_image,
            inputs=[image_input, question_input],
            outputs=answer_output
        )
        # Output order must match the tuple returned by clear_inputs_i2t.
        clear_btn_i2t.click(
            clear_inputs_i2t,
            outputs=[image_input, question_input, answer_output]
        )

demo.launch(share=True)