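# Gradio demo for the BAAI Emu3 models: text-to-image generation with Emu3-Gen
# and vision-language understanding with Emu3-Chat. Assumes a CUDA GPU, since
# both models are loaded with flash_attention_2.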
import base64
import io

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoImageProcessor,
    AutoModel,
)
from transformers.generation import (
    LogitsProcessorList,
    PrefixConstrainedLogitsProcessor,
    UnbatchedClassifierFreeGuidanceLogitsProcessor,
)
from transformers.generation.configuration_utils import GenerationConfig

from emu3.mllm.processing_emu3 import Emu3Processor


def image2str(image):
    """Embed a PIL image as an inline base64 <img> tag for the chat history."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    i_str = base64.b64encode(buf.getvalue()).decode()
    return f'<div style="float:left"><img src="data:image/png;base64, {i_str}"></div>'


print(gr.__version__)

device = "cuda" if torch.cuda.is_available() else "cpu"

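# Hugging Face Hub IDs for the Emu3 checkpoints.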
EMU_GEN_HUB = "BAAI/Emu3-Gen"
EMU_CHAT_HUB = "BAAI/Emu3-Chat"
VQ_HUB = "BAAI/Emu3-VisionTokenizer"

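# Both LLMs are loaded on CPU in bfloat16 and moved onto the GPU only while
# they are generating, so the two large models can share a single device.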
gen_model = AutoModelForCausalLM.from_pretrained(
    EMU_GEN_HUB,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
).eval()

chat_model = AutoModelForCausalLM.from_pretrained(
    EMU_CHAT_HUB,
    device_map="cpu",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
).eval()

tokenizer = AutoTokenizer.from_pretrained(EMU_CHAT_HUB, trust_remote_code=True)
image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
image_tokenizer = AutoModel.from_pretrained(
    VQ_HUB, device_map="cpu", trust_remote_code=True
).eval()

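# The VQ vision tokenizer is small compared to the LLMs and stays resident on
# the device for the whole session.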
print(device)
image_tokenizer.to(device)

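# Emu3Processor ties the text tokenizer and the vision tokenizer together,
# translating between text, images, and the model's discrete token space.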
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)


def generate_image(prompt):
    """Generate a single image from a text prompt with Emu3-Gen."""
    POSITIVE_PROMPT = " masterpiece, film grained, best quality."
    NEGATIVE_PROMPT = (
        "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, "
        "fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, "
        "signature, watermark, username, blurry."
    )

    classifier_free_guidance = 3.0
    full_prompt = prompt + POSITIVE_PROMPT

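    # mode="G" builds an image-generation prompt; the target image area and
    # 1:1 aspect ratio fix the size of the vision-token grid to be produced.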
    kwargs = dict(
        mode="G",
        ratio="1:1",
        image_area=gen_model.config.image_area,
        return_tensors="pt",
    )
    pos_inputs = processor(text=full_prompt, **kwargs)
    neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)

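    # An image is tens of thousands of discrete vision tokens, hence the large
    # max_new_tokens budget and top-k sampling over the vision vocabulary.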
    GENERATION_CONFIG = GenerationConfig(
        use_cache=True,
        eos_token_id=gen_model.config.eos_token_id,
        pad_token_id=gen_model.config.pad_token_id,
        max_new_tokens=40960,
        do_sample=True,
        top_k=2048,
    )

    torch.cuda.empty_cache()
    gen_model.to(device)

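    # Two logits processors steer sampling: classifier-free guidance contrasts
    # the positive prompt with the negative one, and the prefix constraint
    # only allows tokens that form a valid h x w vision-token grid.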
    h, w = pos_inputs.image_size[0]
    constrained_fn = processor.build_prefix_constrained_fn(h, w)
    logits_processor = LogitsProcessorList(
        [
            UnbatchedClassifierFreeGuidanceLogitsProcessor(
                classifier_free_guidance,
                gen_model,
                unconditional_ids=neg_inputs.input_ids.to(device),
            ),
            PrefixConstrainedLogitsProcessor(
                constrained_fn,
                num_beams=1,
            ),
        ]
    )

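    # The slow step: the image is sampled autoregressively, token by token.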
    outputs = gen_model.generate(
        pos_inputs.input_ids.to(device),
        generation_config=GENERATION_CONFIG,
        logits_processor=logits_processor,
        attention_mask=pos_inputs.attention_mask.to(device),
    )

    # The decoded sequence interleaves text and images; keep the first image.
    mm_list = processor.decode(outputs[0])
    result = None
    for im in mm_list:
        if isinstance(im, Image.Image):
            result = im
            break

    gen_model.cpu()
    torch.cuda.empty_cache()

    return result


def vision_language_understanding(image, text):
    """Answer a text question about an image with Emu3-Chat."""
    inputs = processor(
        text=text,
        image=image,
        mode="U",
        padding_side="left",
        padding="longest",
        return_tensors="pt",
    )

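    # Plain greedy decoding (do_sample defaults to False); 320 new tokens is
    # plenty for a short answer.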
    GENERATION_CONFIG = GenerationConfig(
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=320,
    )

    torch.cuda.empty_cache()
    chat_model.to(device)

    outputs = chat_model.generate(
        inputs.input_ids.to(device),
        generation_config=GENERATION_CONFIG,
        attention_mask=inputs.attention_mask.to(device),
    )

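    # Drop the prompt tokens so only the newly generated answer is decoded.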
    outputs = outputs[:, inputs.input_ids.shape[-1]:]
    response = processor.batch_decode(outputs, skip_special_tokens=True)[0]

    chat_model.cpu()
    torch.cuda.empty_cache()

    return response


def chat(history, user_input, user_image):
    """Route one turn: image + text -> understanding, text only -> generation."""
    if user_image is not None:
        # An image was uploaded: answer the question about it.
        response = vision_language_understanding(user_image, user_input)
        history = history + [(image2str(user_image) + "<br>" + user_input, response)]
    else:
        # Text only: treat the input as an image-generation prompt.
        generated_image = generate_image(user_input)
        if generated_image is not None:
            history = history + [(user_input, image2str(generated_image))]
        else:
            history = history + [
                (user_input, "Sorry, I could not generate an image.")
            ]

    # Update the visible Chatbot and the state, and clear the image upload box.
    return history, history, gr.update(value=None)


def clear_input():
    return gr.update(value="")


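# Two-column UI: inputs on the left, chat transcript on the right.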
with gr.Blocks() as demo:
    gr.Markdown("# Emu3 Chatbot Demo")
    gr.Markdown(
        "This is a chatbot demo for image generation and vision-language understanding using the Emu3 models."
    )
    gr.Markdown(
        "Please provide <b>only text input</b> for image generation (<b>~600s</b>) and <b>both image and text</b> for vision-language understanding (<b>~20s</b>)."
    )

    state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                lines=10,
                container=False,
            )
            user_image = gr.Image(
                sources=["upload"], type="pil", label="Upload an image (optional)"
            )
            submit_btn = gr.Button("Send")

        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=800)

    submit_btn.click(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_image],
    ).then(fn=clear_input, inputs=[], outputs=user_input, queue=False)
    user_input.submit(
        chat,
        inputs=[state, user_input, user_image],
        outputs=[chatbot, state, user_image],
    ).then(fn=clear_input, inputs=[], outputs=user_input, queue=False)

# queue() must be enabled before launch(); max_threads caps concurrent workers.
demo.queue().launch(max_threads=1)