import os
import sys
import spaces
import gradio as gr
import torch
import math
from PIL import Image, ImageDraw
from transformers import AutoTokenizer
from model import LLaDAForMultiModalGeneration
from utils.image_utils import (
    decode_vq_to_image, calculate_vq_params,
    generate_crop_size_list, var_center_crop, add_break_line,
    encode_img_with_breaks, encode_img_with_paint
)
from utils.prompt_utils import generate_text_image_to_text_image_prompt
import torch.nn.functional as F

# Module-level cache so the model is loaded once per process.
MODEL = None
TOKENIZER = None
VQVAE = None
DEVICE = None
CURRENT_MODEL_PATH = None

# Special token ids in the unified text+image vocabulary. Image VQ codes are
# stored as token id = code + image_token_offset.
SPECIAL_TOKENS = {
    "mask_token": 126336,
    "newline_token": 126084,
    "image_token_offset": 126356,
    "answer_start": 126354,
    "answer_end": 126355,
    "boi": 126349,
    "eoi": 126350,
    "uncondition": 126351
}

SYSTEM_PROMPT = "Generate an image applying the following editing instruction based on the original image."
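
# The cosine schedule below is the MaskGIT-style masking schedule: at progress
# t in [0, 1] it returns cos(t * pi / 2), the fraction of image tokens that
# should still be masked. For example t=0 -> 1.0 (all masked), t=0.5 -> ~0.71,
# t=1 -> 0.0 (fully revealed), so most tokens are committed late in decoding.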
def cosine_schedule(t):
    return torch.cos(t * math.pi / 2)
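
# Gumbel-max sampling: adding independent Gumbel(0, 1) noise to logits and
# taking the argmax draws a sample from softmax(logits); scaling the noise by
# `temperature` interpolates between greedy decoding (0) and full sampling (1).
# The noise is computed as -log(-log(U)) for U ~ Uniform(0, 1).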
def add_gumbel_noise(logits, temperature=1.0, generator=None):
    if temperature == 0:
        return logits
    if generator is not None:
        uniform_noise = torch.rand(logits.shape, dtype=logits.dtype, device=logits.device, generator=generator)
    else:
        uniform_noise = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(uniform_noise + 1e-10) + 1e-10)
    return logits + temperature * gumbel_noise
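
# Confidence-based re-masking: given per-token probabilities, re-mask the
# `mask_len` lowest-confidence positions for the next refinement step.
# Gaussian noise scaled by `temperature` randomizes the ranking early on and
# is annealed away as decoding progresses.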
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    if generator is not None:
        noise = torch.randn(probs.shape, dtype=probs.dtype, device=probs.device, generator=generator)
    else:
        noise = torch.randn_like(probs)
    confidence = torch.log(probs + 1e-10) + temperature * noise
    sorted_confidence, sorted_indices = torch.sort(confidence, dim=-1, descending=False)
    masking = torch.zeros_like(probs, dtype=torch.bool)
    if isinstance(mask_len, torch.Tensor):
        # Per-sample budget: clamp to a valid range, then mark each sample's
        # k lowest-confidence positions.
        mask_len = torch.clamp(mask_len, 0, probs.shape[-1] - 1).long().squeeze(-1)
        for b in range(probs.shape[0]):
            k = mask_len[b].item()
            if k > 0:
                masking[b, sorted_indices[b, :k]] = True
    else:
        # Scalar budget shared by the whole batch.
        k = int(mask_len)
        if k > 0:
            for b in range(probs.shape[0]):
                masking[b, sorted_indices[b, :k]] = True
    return masking
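
# Linear unmasking budget: split each sample's masked-token count evenly
# across `text_steps`, absorbing rounding step by step. For example, 10 masked
# tokens over 4 steps yields the per-step budget [3, 2, 3, 2].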
def get_num_transfer_tokens(text_masked_indices, text_steps):
    batch_size = text_masked_indices.shape[0]
    initial_masks = text_masked_indices.sum(dim=1)
    num_transfer = torch.zeros(batch_size, text_steps, dtype=torch.long, device=text_masked_indices.device)
    for b in range(batch_size):
        total_masks = initial_masks[b].item()
        remaining = total_masks
        for step in range(text_steps):
            ratio = (step + 1) / text_steps
            target_remaining = int(total_masks * (1 - ratio))
            tokens_to_unmask = max(0, remaining - target_remaining)
            num_transfer[b, step] = tokens_to_unmask
            remaining -= tokens_to_unmask
    return num_transfer
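
# Renders the partially decoded text span for the live preview: decoded tokens
# appear as-is, while still-masked positions are drawn as '░', with long runs
# collapsed to '░░░░░[...N more]' to keep the display compact.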
def decode_text_with_masks(combined_input_ids, text_start, text_end, tokenizer, mask_token):
    text_ids = combined_input_ids[0, text_start:text_end].cpu().tolist()
    result_parts = []
    consecutive_masks = 0
    for token_id in text_ids:
        if token_id == mask_token:
            consecutive_masks += 1
        else:
            if consecutive_masks > 0:
                if consecutive_masks <= 10:
                    result_parts.append("░" * consecutive_masks)
                else:
                    result_parts.append(f"░░░░░[...{consecutive_masks - 5} more]")
                consecutive_masks = 0
            try:
                token_text = tokenizer.decode([token_id], skip_special_tokens=False, clean_up_tokenization_spaces=False)
                if token_text.strip() or token_text in [' ', '\n', '\t']:
                    result_parts.append(token_text)
            except Exception:
                result_parts.append(f"[{token_id}]")
    if consecutive_masks > 0:
        if consecutive_masks <= 10:
            result_parts.append("░" * consecutive_masks)
        else:
            result_parts.append(f"░░░░░[...{consecutive_masks - 5} more]")
    return "".join(result_parts)
def generate_ti2ti_stepwise(
    model, input_ids, text_start, text_end, image_start, seq_len, newline_every,
    text_steps=100, temperature=1.0, text_temperature=0.7, cfg_scale=0.0, cfg_img=4.0,
    uncon_text=None, uncon_image=None, tokenizer=None, remasking='low_confidence',
    noise_schedule=cosine_schedule, generator=None, text_vocab_size=126356,
    codebook_size=8192, vqvae=None, image_height=512, image_width=512,
):
    device = input_ids.device
    MASK_TOKEN = SPECIAL_TOKENS["mask_token"]
    NEW_LINE = SPECIAL_TOKENS["newline_token"]
    combined_input_ids = input_ids.clone()
    num_vq_tokens = seq_len
    total_image_len = seq_len + seq_len // newline_every
    image_end = image_start + total_image_len
    text_masked_indices = combined_input_ids[:, text_start:text_end] == MASK_TOKEN
    num_transfer_tokens = get_num_transfer_tokens(text_masked_indices, text_steps)
    # Image tokens are refreshed on an evenly spaced ~30% subset of the steps.
    image_generation_step_indices = torch.linspace(
        0, text_steps - 1, int(text_steps * 0.3)
    ).round().int().tolist()
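    # Positions of real VQ tokens inside the image block, skipping the newline
    # tokens that close each row of the token grid.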
    image_position_mapping = []
    for i in range(image_start, image_end):
        if combined_input_ids[0, i] != NEW_LINE:
            image_position_mapping.append(i)
    batch_size = combined_input_ids.shape[0]
    initial_text_display = decode_text_with_masks(combined_input_ids, text_start, text_end, tokenizer, MASK_TOKEN)
    last_generated_image = None
    yield 0, initial_text_display, None, f"Step 0/{text_steps}"
    for step in range(text_steps):
        cond_logits = model(combined_input_ids, infer=True, use_cache=False).logits
        text_masked_indices = combined_input_ids[:, text_start:text_end] == MASK_TOKEN
        if text_masked_indices.sum() > 0:
            text_logits = cond_logits[:, text_start:text_end, :]
            logits_with_noise = add_gumbel_noise(text_logits, temperature=text_temperature, generator=generator)
            x0 = torch.argmax(logits_with_noise, dim=-1)
            if remasking == 'low_confidence':
                p = F.softmax(text_logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
            elif remasking == 'random':
                if generator is not None:
                    # The confidence noise must be float; x0 is an integer tensor,
                    # so its dtype cannot be reused here.
                    x0_p = torch.rand(x0.shape, device=x0.device, generator=generator)
                else:
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                x0_p = torch.ones_like(x0, dtype=torch.float)
            x0 = torch.where(text_masked_indices, x0, combined_input_ids[:, text_start:text_end])
            confidence = torch.where(text_masked_indices, x0_p, float('-inf'))
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                k = num_transfer_tokens[j, step].item()
                if k > 0:
                    _, select_index = torch.topk(confidence[j], k=k)
                    transfer_index[j, select_index] = True
            combined_input_ids[:, text_start:text_end][transfer_index] = x0[transfer_index]
        if step in image_generation_step_indices:
            vq_tokens_list = []
            mask_positions = []
            for idx, pos in enumerate(image_position_mapping):
                token = combined_input_ids[0, pos].item()
                if token == MASK_TOKEN:
                    vq_tokens_list.append(-1)
                    mask_positions.append(idx)
                else:
                    vq_token = token - text_vocab_size
                    vq_token = max(0, min(vq_token, codebook_size - 1))
                    vq_tokens_list.append(vq_token)
            vq_tokens_tensor = torch.tensor(vq_tokens_list, device=device).unsqueeze(0)
            unknown_map = vq_tokens_tensor == -1
            cond_image_logits_list = []
            for pos in image_position_mapping:
                cond_image_logits_list.append(
                    cond_logits[:, pos:pos+1, text_vocab_size:text_vocab_size+codebook_size]
                )
            cond_vq_logits = torch.cat(cond_image_logits_list, dim=1)
            if (cfg_scale > 0.0 and uncon_text is not None) or (cfg_img > 0.0 and uncon_image is not None):
                combined_uncond_text = combined_input_ids.clone()
                if uncon_text is not None:
                    prefix_len = uncon_text.shape[1]
                    combined_uncond_text[:, :prefix_len] = uncon_text.to(device)
                combined_uncond_img = combined_input_ids.clone()
                if uncon_image is not None:
                    prefix_len_img = uncon_image.shape[1]
                    combined_uncond_img[:, :prefix_len_img] = uncon_image.to(device)
                uncond_text_logits_full = model(combined_uncond_text, infer=True, use_cache=False).logits
                uncond_img_logits_full = model(combined_uncond_img, infer=True, use_cache=False).logits
                uncond_text_vq_list = []
                uncond_img_vq_list = []
                for pos in image_position_mapping:
                    uncond_text_vq_list.append(
                        uncond_text_logits_full[:, pos:pos+1, text_vocab_size:text_vocab_size+codebook_size]
                    )
                    uncond_img_vq_list.append(
                        uncond_img_logits_full[:, pos:pos+1, text_vocab_size:text_vocab_size+codebook_size]
                    )
                uncond_text_vq_logits = torch.cat(uncond_text_vq_list, dim=1)
                uncond_img_vq_logits = torch.cat(uncond_img_vq_list, dim=1)
            else:
                uncond_text_vq_logits = torch.zeros_like(cond_vq_logits)
                uncond_img_vq_logits = torch.zeros_like(cond_vq_logits)
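            # With both guidance terms active this computes
            #   image_logits = cond + cfg_scale * (cond - uncond_text) + cfg_img * (cond - uncond_img)
            # i.e. separate classifier-free guidance for the text and image conditions.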
            image_logits = cond_vq_logits
            if cfg_scale != 0.0:
                image_logits = image_logits + cfg_scale * (cond_vq_logits - uncond_text_vq_logits)
            if cfg_img != 0.0:
                image_logits = image_logits + cfg_img * (cond_vq_logits - uncond_img_vq_logits)
            probs = F.softmax(image_logits, dim=-1)
            if temperature == 0:
                sampled_ids = probs.argmax(dim=-1)
            else:
                sampled = probs.reshape(-1, image_logits.size(-1))
                if generator is not None:
                    sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*image_logits.shape[:-1])
                else:
                    sampled_ids = torch.multinomial(sampled, 1)[:, 0].view(*image_logits.shape[:-1])
            sampled_ids = torch.where(unknown_map, sampled_ids, vq_tokens_tensor)
            sampled_ids = torch.clamp(sampled_ids, 0, codebook_size - 1)
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None]).squeeze(-1)
            high_val = torch.finfo(selected_probs.dtype).max
            selected_probs = torch.where(unknown_map, selected_probs, high_val)
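            # The cosine schedule decides how many image tokens remain masked
            # after this step: floor(N * cos(pi/2 * progress)), clamped to
            # [1, unknown - 1] so each refresh always commits some tokens.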
            ratio = 1.0 * (step + 1) / text_steps
            mask_ratio = noise_schedule(torch.tensor(ratio, device=device))
            unknown_counts = unknown_map.sum(dim=-1, keepdim=True)
            mask_len = (num_vq_tokens * mask_ratio).floor().unsqueeze(0).to(device)
            mask_len = torch.max(torch.tensor([1], device=device), torch.min(unknown_counts - 1, mask_len.to(device).long()))
            if mask_len.ndim == 1:
                mask_len = mask_len.unsqueeze(1)
            img_temp = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, img_temp, generator=generator)
            final_vq_tokens = torch.where(masking, torch.tensor(-1, device=device), sampled_ids)
            for idx, pos in enumerate(image_position_mapping):
                v = final_vq_tokens[0, idx].item()
                if v == -1:
                    combined_input_ids[0, pos] = MASK_TOKEN
                else:
                    combined_input_ids[0, pos] = int(v + text_vocab_size)
            try:
                decoded_image = decode_vq_to_image(
                    sampled_ids, None, None, image_height, image_width, vqvae
                )
                masked_positions_bool = masking[0]
                if masked_positions_bool.sum() > 0:
                    # Overlay translucent gray squares on the patches that are
                    # still masked so the preview shows decoding progress.
                    decoded_image = decoded_image.copy()
                    draw = ImageDraw.Draw(decoded_image, 'RGBA')
                    vae_scale = 2 ** (len(vqvae.config.block_out_channels) - 1)
                    token_h = image_height // vae_scale
                    token_w = image_width // vae_scale
                    pixel_h = image_height // token_h
                    pixel_w = image_width // token_w
                    masked_indices = torch.where(masked_positions_bool)[0].cpu().tolist()
                    for masked_idx in masked_indices:
                        token_row = masked_idx // token_w
                        token_col = masked_idx % token_w
                        y1 = token_row * pixel_h
                        x1 = token_col * pixel_w
                        y2 = y1 + pixel_h
                        x2 = x1 + pixel_w
                        draw.rectangle([x1, y1, x2, y2], fill=(128, 128, 128, 120))
                last_generated_image = decoded_image
            except Exception:
                pass  # keep the previous preview if decoding fails mid-run
        text_display = decode_text_with_masks(combined_input_ids, text_start, text_end, tokenizer, MASK_TOKEN)
        text_masks_remaining = (combined_input_ids[:, text_start:text_end] == MASK_TOKEN).sum().item()
        text_progress = (1 - text_masks_remaining / (text_end - text_start)) * 100
        status_msg = f"Step {step + 1}/{text_steps} | Text: {text_progress:.1f}%"
        if step in image_generation_step_indices:
            image_masks_remaining = sum(1 for pos in image_position_mapping if combined_input_ids[0, pos] == MASK_TOKEN)
            image_progress = (1 - image_masks_remaining / num_vq_tokens) * 100
            status_msg += f" | Image: {image_progress:.1f}%"
        if step % 5 == 0 or step in image_generation_step_indices or step == text_steps - 1:
            yield step + 1, text_display, last_generated_image, status_msg
    final_text_display = decode_text_with_masks(combined_input_ids, text_start, text_end, tokenizer, MASK_TOKEN)
    if last_generated_image is not None:
        final_image = last_generated_image
    else:
        # No preview was decoded during the loop: decode the final tokens,
        # filling still-masked positions with a placeholder code and marking
        # them with the same gray overlay used for previews.
        final_vq_tokens = []
        final_mask_positions = []
        for idx, pos in enumerate(image_position_mapping):
            token = combined_input_ids[0, pos].item()
            if token != MASK_TOKEN:
                vq_token = token - text_vocab_size
                vq_token = max(0, min(vq_token, codebook_size - 1))
                final_vq_tokens.append(vq_token)
            else:
                final_vq_tokens.append(codebook_size // 2)
                final_mask_positions.append(idx)
        vq_tensor = torch.tensor(final_vq_tokens, dtype=torch.long, device=device).unsqueeze(0)
        final_image = decode_vq_to_image(vq_tensor, None, None, image_height, image_width, vqvae)
        if final_mask_positions:
            final_image = final_image.copy()
            draw = ImageDraw.Draw(final_image, 'RGBA')
            vae_scale = 2 ** (len(vqvae.config.block_out_channels) - 1)
            token_h = image_height // vae_scale
            token_w = image_width // vae_scale
            pixel_h = image_height // token_h
            pixel_w = image_width // token_w
            for masked_idx in final_mask_positions:
                token_row = masked_idx // token_w
                token_col = masked_idx % token_w
                y1 = token_row * pixel_h
                x1 = token_col * pixel_w
                y2 = y1 + pixel_h
                x2 = x1 + pixel_w
                draw.rectangle([x1, y1, x2, y2], fill=(128, 128, 128, 120))
    yield text_steps, final_text_display, final_image, "✅ Complete"
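
# Loads and caches the model, tokenizer, and VQ-VAE in module globals so
# repeated generations skip reloading.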
def load_model_and_vae(model_path, vae_path):
    global MODEL, TOKENIZER, VQVAE, DEVICE, CURRENT_MODEL_PATH
    if MODEL is not None and CURRENT_MODEL_PATH == model_path:
        return f"Model already loaded: {model_path}"
    try:
        DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
        TOKENIZER = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        MODEL = LLaDAForMultiModalGeneration.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="auto"
        )
        MODEL.eval()
        from diffusers import VQModel
        VQVAE = VQModel.from_pretrained(vae_path, subfolder="vqvae").to(DEVICE)
        CURRENT_MODEL_PATH = model_path
        return f"✅ Model loaded | Device: {DEVICE}"
    except Exception as e:
        MODEL = None
        TOKENIZER = None
        VQVAE = None
        CURRENT_MODEL_PATH = None
        return f"❌ Failed: {str(e)}"
@spaces.GPU  # acquire a ZeroGPU device per call (inferred from the otherwise unused `spaces` import)
def generate_wrapper(
    input_image, prompt_text, model_path, vae_path, height, width,
    text_steps, text_gen_length, text_block_length, cfg_scale, cfg_img,
    temperature, text_temperature, remasking_strategy, painting_mode,
    mask_h_ratio, mask_w_ratio, seed,
):
    global MODEL, TOKENIZER, VQVAE, DEVICE
    if MODEL is None or TOKENIZER is None or VQVAE is None:
        load_status = load_model_and_vae(model_path, vae_path)
        if "Failed" in load_status:
            yield "", None, load_status
            return
    if input_image is None:
        yield "", None, "❌ No input image"
        return
    if seed != 0:
        torch.manual_seed(seed)
        generator = torch.Generator(device=DEVICE).manual_seed(seed)
    else:
        generator = None
    MASK = SPECIAL_TOKENS["mask_token"]
    NEW_LINE = SPECIAL_TOKENS["newline_token"]
    BOA = SPECIAL_TOKENS["answer_start"]
    EOA = SPECIAL_TOKENS["answer_end"]
    BOI = SPECIAL_TOKENS["boi"]
    EOI = SPECIAL_TOKENS["eoi"]
    try:
        input_prompt, uncon_text = generate_text_image_to_text_image_prompt(
            prompt_text, SYSTEM_PROMPT
        )
        prompt_ids = TOKENIZER(input_prompt)["input_ids"]
        uncon_text_ids = TOKENIZER(uncon_text)["input_ids"]
        img = input_image.convert("RGB")
        crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32)
        img = var_center_crop(img, crop_size_list=crop_size_list)
        input_img_token = encode_img_with_breaks(img, VQVAE)
        con_input_list = prompt_ids[:-1] + input_img_token + prompt_ids[-1:]
        uncon_input_text = uncon_text_ids[:-1] + input_img_token + uncon_text_ids[-1:]
        uncon_input_image = prompt_ids
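        # The VQ-VAE downsamples by a factor of 2^(num blocks - 1);
        # calculate_vq_params maps the target resolution to the flat token
        # count, the newline interval, and the token-grid shape.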
        vae_scale = 2 ** (len(VQVAE.config.block_out_channels) - 1)
        seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
            height, width, vae_scale
        )
        text_mask_tokens = [MASK] * text_gen_length
        if painting_mode:
            img_mask_token, img_vis = encode_img_with_paint(
                img, vqvae=VQVAE, mask_h_ratio=mask_h_ratio,
                mask_w_ratio=mask_w_ratio, mask_mode=painting_mode
            )
        else:
            img_mask_token = add_break_line(
                [MASK] * seq_len, token_grid_height, token_grid_width,
                new_number=NEW_LINE
            )
        end_token_ids = TOKENIZER("</answer>", add_special_tokens=False).input_ids
        pred_token = [BOA] + [BOI] + img_mask_token + [EOI] + text_mask_tokens + end_token_ids
        code_start = len(con_input_list)
        image_start = len(con_input_list) + 2
        image_end = image_start + len(img_mask_token)
        text_start = image_end + 1
        text_end = text_start + text_gen_length
        full_input_ids = con_input_list + pred_token
        con_input = torch.tensor(full_input_ids, device=DEVICE).unsqueeze(0)
        uncon_input_text_tensor = torch.tensor(uncon_input_text, device=DEVICE).unsqueeze(0)
        uncon_input_image_tensor = torch.tensor(uncon_input_image, device=DEVICE).unsqueeze(0)
        config = MODEL.config
        text_vocab_size = getattr(config, 'text_vocab_size', 126356)
        codebook_size = getattr(config, 'codebook_size', 8192)
        for step, text_display, image, status in generate_ti2ti_stepwise(
            model=MODEL, input_ids=con_input, text_start=text_start, text_end=text_end,
            image_start=image_start, seq_len=seq_len, newline_every=newline_every,
            text_steps=text_steps, temperature=temperature, text_temperature=text_temperature,
            cfg_scale=cfg_scale, cfg_img=cfg_img, uncon_text=uncon_input_text_tensor,
            uncon_image=uncon_input_image_tensor, tokenizer=TOKENIZER,
            remasking=remasking_strategy, noise_schedule=cosine_schedule,
            generator=generator, text_vocab_size=text_vocab_size,
            codebook_size=codebook_size, vqvae=VQVAE,
            image_height=height, image_width=width,
        ):
            yield text_display, image, status
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield "", None, f"❌ Error: {str(e)}"
| css_styles = """ | |
| .gradio-container { | |
| font-family: 'IBM Plex Sans', sans-serif; | |
| max-width: 1400px !important; | |
| margin: auto; | |
| } | |
| .gr-button-primary { | |
| background: linear-gradient(90deg, #7c3aed 0%, #a855f7 100%) !important; | |
| border: none !important; | |
| color: white !important; | |
| } | |
| .gr-button-primary:hover { | |
| transform: scale(1.02); | |
| box-shadow: 0 4px 12px rgba(124, 58, 237, 0.4) !important; | |
| } | |
| .output-markdown { | |
| min-height: 400px !important; | |
| max-height: 600px !important; | |
| overflow-y: auto !important; | |
| padding: 12px !important; | |
| background: #fafafa !important; | |
| border-radius: 8px !important; | |
| border: 1px solid #e0e0e0 !important; | |
| font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important; | |
| font-size: 13px !important; | |
| line-height: 1.5 !important; | |
| } | |
| .output-markdown .prose, | |
| .output-markdown .prose * { | |
| font-size: 10px !important; | |
| line-height: 1.4 !important; | |
| } | |
| .output-markdown h1 { | |
| font-size: 1.4em !important; | |
| margin-top: 0.8em !important; | |
| margin-bottom: 0.4em !important; | |
| color: #333 !important; | |
| } | |
| .output-markdown h2 { | |
| font-size: 1.2em !important; | |
| margin-top: 0.8em !important; | |
| margin-bottom: 0.4em !important; | |
| color: #333 !important; | |
| } | |
| .output-markdown h3 { | |
| font-size: 1.1em !important; | |
| margin-top: 0.8em !important; | |
| margin-bottom: 0.4em !important; | |
| color: #333 !important; | |
| } | |
| .output-markdown code { | |
| background: #f0f0f0 !important; | |
| padding: 2px 4px !important; | |
| border-radius: 3px !important; | |
| font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important; | |
| font-size: 12px !important; | |
| } | |
| .output-markdown pre { | |
| background: #f5f5f5 !important; | |
| padding: 8px !important; | |
| border-radius: 5px !important; | |
| overflow-x: auto !important; | |
| font-size: 12px !important; | |
| } | |
| .output-markdown ul, .output-markdown ol { | |
| padding-left: 18px !important; | |
| margin: 8px 0 !important; | |
| } | |
| .output-markdown li { | |
| margin: 4px 0 !important; | |
| } | |
| .output-markdown p { | |
| margin: 6px 0 !important; | |
| } | |
| .output-markdown strong { | |
| font-weight: 600 !important; | |
| } | |
| footer {display: none !important} | |
| """ | |

with gr.Blocks(css=css_styles, theme=gr.themes.Soft(primary_hue="purple")) as demo:
    gr.Markdown(
        """
        # 🎨 MMaDA-Parallel: Text+Image to Text+Image Generation
        Real-time parallel generation with step-by-step visualization.

        **Github:** [tyfeld/MMaDA-Parallel-A](https://github.com/tyfeld/MMaDA-Parallel-A)
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Input")
            input_image = gr.Image(type="pil", label="Input Image")
            prompt_text = gr.Textbox(
                label="Editing Instruction",
                lines=3,
                value="Make the sky more dramatic with sunset colors",
                placeholder="Enter your editing instruction..."
            )
            with gr.Accordion("Model", open=False):
                model_path = gr.Textbox(
                    label="Model Path",
                    value="tyfeld/MMaDA-Parallel-A",
                    info="HuggingFace path or local directory"
                )
                vae_path = gr.Textbox(
                    label="VAE Path",
                    value="tyfeld/MMaDA-Parallel-A",
                    info="VQ-VAE checkpoint path"
                )
            with gr.Accordion("Parameters", open=False):
                with gr.Row():
                    height = gr.Slider(256, 768, value=512, step=64, label="Height")
                    width = gr.Slider(256, 768, value=512, step=64, label="Width")
                text_steps = gr.Slider(32, 512, value=128, step=32, label="Steps")
                text_gen_length = gr.Slider(64, 512, value=256, step=32, label="Text Length")
                text_block_length = gr.Slider(16, 128, value=32, step=16, label="Block Length")
                with gr.Row():
                    cfg_scale = gr.Slider(0, 5, value=2.5, step=0.5, label="Text CFG")
                    cfg_img = gr.Slider(0, 8, value=4.0, step=0.5, label="Image CFG")
                with gr.Row():
                    temperature = gr.Slider(0, 2, value=1.0, step=0.1, label="Image Temp")
                    text_temperature = gr.Slider(0, 2, value=0.7, step=0.1, label="Text Temp")
                remasking_strategy = gr.Dropdown(
                    choices=["low_confidence", "random"],
                    value="low_confidence",
                    label="Remasking"
                )
                seed = gr.Slider(0, 10000, value=0, step=1, label="Seed (0=random)")
            with gr.Accordion("Painting Mode", open=False):
                painting_mode = gr.Dropdown(
                    choices=[None, "inpainting", "outpainting"],
                    value=None,
                    label="Mode"
                )
                with gr.Row():
                    mask_h_ratio = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Mask H")
                    mask_w_ratio = gr.Slider(0.1, 1.0, value=0.5, step=0.1, label="Mask W")
            generate_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
        with gr.Column(scale=2):
            gr.Markdown("### Output")
            status_text = gr.Textbox(label="Status", lines=2, interactive=False)
            with gr.Row():
                with gr.Column(scale=1.2):
                    output_text = gr.Markdown(
                        value="*Waiting...*",
                        label="Generated Text (░ = masked)",
                        show_label=True,
                        container=True,
                        elem_classes=["output-markdown"]
                    )
                with gr.Column(scale=1):
                    output_image = gr.Image(label="Generated Image", type="pil", interactive=False)
    generate_btn.click(
        fn=generate_wrapper,
        inputs=[
            input_image, prompt_text, model_path, vae_path,
            height, width, text_steps, text_gen_length, text_block_length,
            cfg_scale, cfg_img, temperature, text_temperature,
            remasking_strategy, painting_mode, mask_h_ratio, mask_w_ratio, seed
        ],
        outputs=[output_text, output_image, status_text]
    )
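
# Standalone launch, e.g.:
#   python app.py --model_path tyfeld/MMaDA-Parallel-A --port 7860 --share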
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="MMaDA-Parallel Gradio Demo")
    parser.add_argument("--model_path", type=str, default="tyfeld/MMaDA-Parallel-A")
    parser.add_argument("--vae_path", type=str, default="tyfeld/MMaDA-Parallel-A")
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()
    print("Loading model...")
    load_status = load_model_and_vae(args.model_path, args.vae_path)
    print(load_status)
    demo.launch(share=args.share, server_name="0.0.0.0", server_port=args.port)