import json
import os
import torch
import argparse
from PIL import Image
from chameleon.inference.chameleon import ChameleonInferenceModel, Options
from constants import (
    MODEL_7B_PATH,
    TOKENIZER_TEXT_PATH,
    TOKENIZER_IMAGE_CFG_PATH,
    TOKENIZER_IMAGE_PATH,
)
from typing import List, Tuple
import logging

# Set up the logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def split_token_sequence(
    tokens: torch.LongTensor,
    boi: int,
    eoi: int,
) -> List[Tuple[str, torch.LongTensor]]:
    """
    Split a sequence of tokens into text and image segments.

    Args:
        tokens (torch.LongTensor): The token sequence, shape (1, seq_len).
        boi (int): Begin-of-image token id.
        eoi (int): End-of-image token id.

    Returns:
        List[Tuple[str, torch.LongTensor]]: List of (segment type, tokens) tuples,
        where the segment type is either "text_seg" or "image_seg".
    """
    batch_size, _ = tokens.shape
    assert batch_size == 1, "Batch size must be 1"

    device = tokens.device
    dtype = tokens.dtype
    segments = []
    current_segment = []
    in_image_seg = False

    # Work on plain Python ints; tokens[0] drops the batch dimension.
    for token in tokens[0].tolist():
        if token == boi:
            # Entering an image segment: save the current text segment (if any).
            if current_segment:
                segments.append(("text_seg", torch.tensor(current_segment, dtype=dtype, device=device).reshape(1, -1)))
                current_segment = []
            in_image_seg = True
        elif token == eoi and in_image_seg:
            # Exiting an image segment: save the collected image tokens.
            segments.append(("image_seg", torch.tensor(current_segment, dtype=dtype, device=device).reshape(1, -1)))
            current_segment = []
            in_image_seg = False
        else:
            current_segment.append(token)

    # Save any remaining tokens.
    if current_segment:
        seg_type = "image_seg" if in_image_seg else "text_seg"
        segments.append((seg_type, torch.tensor(current_segment, dtype=dtype, device=device).reshape(1, -1)))
    return segments
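
# Illustrative example (hypothetical token ids; the real boi/eoi ids are read
# from the model vocab in main() below):
#   split_token_sequence(torch.tensor([[5, 6, 8197, 70, 71, 8196, 9]]), boi=8197, eoi=8196)
#   -> [("text_seg",  tensor([[5, 6]])),
#       ("image_seg", tensor([[70, 71]])),
#       ("text_seg",  tensor([[9]]))]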


def main(args: argparse.Namespace):
    """Main function to generate and process model output."""
    # Load the Chameleon model.
    model = ChameleonInferenceModel(
        MODEL_7B_PATH.as_posix(),
        TOKENIZER_TEXT_PATH.as_posix(),
        TOKENIZER_IMAGE_CFG_PATH.as_posix(),
        TOKENIZER_IMAGE_PATH.as_posix(),
    )

    # Log the model configuration.
    logging.info(f"Model path: {MODEL_7B_PATH}")
    logging.info(f"Text tokenizer path: {TOKENIZER_TEXT_PATH}")
    logging.info(f"Image tokenizer config path: {TOKENIZER_IMAGE_CFG_PATH}")
    logging.info(f"Image tokenizer path: {TOKENIZER_IMAGE_PATH}")

    # Generation options.
    options = Options()

    # Prepare the prompt. A plain string becomes a text-only prompt; an
    # (instruction, image_path) tuple becomes an image+text prompt.
    instructions = [args.instruction]
    batch_prompt_ui = []
    for instruction in instructions:
        if isinstance(instruction, tuple):
            inst, image_path = instruction
            batch_prompt_ui += [
                [
                    {"type": "image", "value": f"file:{image_path}"},
                    {"type": "text", "value": inst},
                ],
            ]
        else:
            batch_prompt_ui += [
                [
                    {"type": "text", "value": instruction},
                ],
            ]
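
    # For example, a hypothetical mixed prompt (not produced by the CLI below,
    # which only passes a string) would look like:
    #   instructions = [("Describe this image.", "/path/to/input.png")]
    #   -> batch_prompt_ui == [[{"type": "image", "value": "file:/path/to/input.png"},
    #                           {"type": "text", "value": "Describe this image."}]]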

    # Generate interleaved tokens.
    tokens: torch.LongTensor = model.generate(
        batch_prompt_ui=batch_prompt_ui,
        options=options,
    )

    # Split the output into text and image segments.
    boi, eoi = model.vocab.begin_image, model.vocab.end_image  # 8197 (boi), 8196 (eoi)
    segments = split_token_sequence(tokens, boi, eoi)

    # Decode each segment and save the results.
    os.makedirs(args.save_dir, exist_ok=True)
    segments_data = []
    for seg_id, (seg_type, seg_tokens) in enumerate(segments):
        if seg_type == "image_seg":
            # Each image is encoded as a fixed-length sequence of 1024 VQ tokens.
            assert seg_tokens.shape[1] == 1024
            img = model.decode_image(seg_tokens)[0]
            image_path = os.path.join(args.save_dir, f"{seg_id}.png")
            img.save(image_path)
            segments_data.append({"type": "image", "content": image_path})
        else:
            assert seg_type == "text_seg"
            decoded_text = model.decode_text(seg_tokens)[0]
            segments_data.append({"type": "text", "content": decoded_text})

    # Write the interleaved segments to a JSON Lines file, one segment per line.
    jsonl_path = "./segments.jsonl"
    with open(jsonl_path, 'w') as jsonl_file:
        for segment in segments_data:
            jsonl_file.write(json.dumps(segment) + '\n')
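
    # A resulting segments.jsonl might look like this (hypothetical content):
    #   {"type": "text", "content": "Here is a red panda:"}
    #   {"type": "image", "content": "./outputs/interleaved/1.png"}
    #   {"type": "text", "content": "Red pandas are native to the eastern Himalayas."}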


def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Generate interleaved image-text content based on text instructions.")
    parser.add_argument("-i", "--instruction", type=str, required=True, help="The instruction for interleaved image-text generation.")
    parser.add_argument("-s", "--save_dir", type=str, default="./outputs/interleaved/", help="The directory to save the generated images.")
    args: argparse.Namespace = parser.parse_args()
    return args


if __name__ == "__main__":
    args: argparse.Namespace = parse_arguments()
    main(args)
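
# Example invocation (the script filename is assumed here):
#   python interleaved_generation.py \
#       -i "Draw a sunset over the ocean and describe it." \
#       -s ./outputs/interleaved/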