| |
| import random |
| from xtuner.utils import DEFAULT_IMAGE_TOKEN |
|
|
# Question templates that ask for a segmentation of ``{class_name}``.
# They come in pairs of groups: a punctuated group, then the same wording
# with the final punctuation mark removed. The combined order below is kept
# stable because the list is sampled from and iterated over elsewhere.
_SEG_QUESTIONS_SHORT = (
    "Can you segment the {class_name} in this image?",
    "Please segment {class_name} in this image.",
    "What is {class_name} in this image? Please respond with segmentation mask.",
    "What is {class_name} in this image? Please output segmentation mask.",
)
_SEG_QUESTIONS_SHORT_UNPUNCTUATED = (
    "Can you segment the {class_name} in this image",
    "Please segment {class_name} in this image",
    "What is {class_name} in this image? Please respond with segmentation mask",
    "What is {class_name} in this image? Please output segmentation mask",
)
_SEG_QUESTIONS_LONG = (
    "Could you provide a segmentation mask for the {class_name} in this image?",
    "Please identify and segment the {class_name} in this image.",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
    "Can you highlight the {class_name} in this image with a segmentation mask?",
)
_SEG_QUESTIONS_LONG_UNPUNCTUATED = (
    "Could you provide a segmentation mask for the {class_name} in this image",
    "Please identify and segment the {class_name} in this image",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask",
    "Can you highlight the {class_name} in this image with a segmentation mask",
)

SEG_QUESTIONS = [
    *_SEG_QUESTIONS_SHORT,
    *_SEG_QUESTIONS_SHORT_UNPUNCTUATED,
    *_SEG_QUESTIONS_LONG,
    *_SEG_QUESTIONS_LONG_UNPUNCTUATED,
]
|
|
# Reply templates for the 'gpt' side of a segmentation exchange. Every
# variant ends with the literal "[SEG]." marker and none of them contains a
# format placeholder. The order is kept stable for seeded sampling.
ANSWER_LIST = (
    # Plain statement.
    ["It is [SEG]."]
    # Variants opening with "Sure,".
    + [
        "Sure, [SEG].",
        "Sure, it is [SEG].",
        "Sure, the segmentation result is [SEG].",
    ]
    # Marker only.
    + ["[SEG]."]
)
|
|
# Reply templates with one positional ``{}`` slot between <p> ... </p>,
# filled via ``str.format`` by the *_gcg_format_* conversation builders.
ANSWER_LIST_GCG_FORMAT = ["<p> {} </p> [SEG]."]
|
|
def semantic_seg_conversations(labels):
    """Enumerate every question/answer pairing for each class label.

    For each label (whitespace-stripped, lower-cased in the question) one
    sample is produced per combination of a ``SEG_QUESTIONS`` template and
    an ``ANSWER_LIST`` reply.

    Args:
        labels: Iterable of class-name strings. A label must not contain
            the ``"||"`` separator.

    Returns:
        A list of dicts with two keys: ``'conversations'`` (a human turn
        prefixed with ``DEFAULT_IMAGE_TOKEN`` followed by a gpt turn) and
        ``'class_id'`` (the label's position in ``labels``).

    Raises:
        AssertionError: If a stripped label contains ``"||"``.
    """
    samples = []
    for class_id, raw_label in enumerate(labels):
        class_name = raw_label.strip()
        # Labels made of several "||"-joined names are rejected.
        assert "||" not in class_name
        class_name = class_name.lower()
        for question in SEG_QUESTIONS:
            prompt = DEFAULT_IMAGE_TOKEN + question.format(class_name=class_name)
            for reply in ANSWER_LIST:
                samples.append({
                    'conversations': [
                        {'from': 'human', 'value': prompt},
                        {'from': 'gpt', 'value': reply},
                    ],
                    'class_id': class_id,
                })
    return samples
|
|
def semantic_seg_map_fn(example):
    """Derive ``example['conversation']`` from ``example['conversations']``.

    The source turns are dicts with ``'from'`` (``'human'`` or ``'gpt'``)
    and ``'value'``. The result is a list of ``{'input', 'output'}`` dicts,
    one per gpt turn, where ``'input'`` is the concatenation of the human
    turns seen since the previous gpt turn.

    Details:
        * gpt turns at the very start are skipped.
        * A human turn containing ``DEFAULT_IMAGE_TOKEN`` is rewritten in
          place so that the token comes first, followed by a newline and the
          remaining text (all other occurrences of the token are removed).
        * Human text that is never followed by a gpt turn is not emitted.

    Args:
        example: Mapping holding a ``'conversations'`` list.

    Returns:
        The same ``example`` object, updated with ``'conversation'``.

    Raises:
        NotImplementedError: If a turn's ``'from'`` is neither ``'human'``
            nor ``'gpt'``.
    """
    turns = example['conversations']

    # Leading gpt turns have no prompt to pair with; start after them.
    start = 0
    while start < len(turns) and turns[start]['from'] == 'gpt':
        start += 1

    prompt = ''
    pairs = []
    for turn in turns[start:]:
        speaker = turn['from']
        if speaker == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif speaker == 'human':
            text = turn['value']
            if DEFAULT_IMAGE_TOKEN in text:
                remainder = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + remainder).strip()
                # The normalized text is written back onto the turn itself.
                turn['value'] = text
            prompt += text
        else:
            raise NotImplementedError
    example.update({'conversation': pairs})
    return example
|
|
def pascal_part_conversation(selected_labels):
    """Build a multi-turn segmentation dialogue, one exchange per label.

    For every label a question template is drawn at random from
    ``SEG_QUESTIONS`` (filled with the lower-cased label) and a reply is
    drawn at random from ``ANSWER_LIST``. Only the first question is
    prefixed with ``DEFAULT_IMAGE_TOKEN``.

    Args:
        selected_labels: Iterable of label strings.

    Returns:
        A flat list of alternating ``{'from': 'human', ...}`` and
        ``{'from': 'gpt', ...}`` turn dicts.
    """
    turns = []
    for position, label in enumerate(selected_labels):
        # Draw the question before the reply so the sequence of random
        # draws is the same for a given seed.
        prompt = random.choice(SEG_QUESTIONS).format(class_name=label.lower()).strip()
        reply = random.choice(ANSWER_LIST)
        if position == 0:
            prompt = DEFAULT_IMAGE_TOKEN + prompt
        turns.extend((
            {'from': 'human', 'value': prompt},
            {'from': 'gpt', 'value': reply},
        ))
    return turns
|
|
def pascal_part_preprocess(example):
    """Attach a generated dialogue to ``example`` under ``'conversations'``.

    The dialogue is produced by ``pascal_part_conversation`` from
    ``example["selected_labels"]``. The same ``example`` object is returned.
    """
    example['conversations'] = pascal_part_conversation(
        example["selected_labels"])
    return example
|
|
def pascal_part_map_fn(example):
    """Generate a dialogue for ``example`` and convert it to input/output pairs.

    Steps:
        1. ``pascal_part_preprocess`` fills ``example['conversations']``.
        2. ``example['image']`` is set to ``example["file_name"]``.
        3. The turns are folded into ``example['conversation']``: a list of
           ``{'input', 'output'}`` dicts, one per gpt turn, where
           ``'input'`` concatenates the human turns since the previous gpt
           turn.

    Leading gpt turns are skipped. A human turn containing
    ``DEFAULT_IMAGE_TOKEN`` is rewritten in place so that the token comes
    first, followed by a newline and the remaining text.

    Args:
        example: Mapping with ``"selected_labels"`` and ``"file_name"``.

    Returns:
        The same ``example`` object, updated as described above.

    Raises:
        NotImplementedError: If a turn's ``'from'`` is neither ``'human'``
            nor ``'gpt'``.
    """
    example = pascal_part_preprocess(example)
    example['image'] = example["file_name"]

    turns = example['conversations']

    # Leading gpt turns have no prompt to pair with; start after them.
    start = 0
    while start < len(turns) and turns[start]['from'] == 'gpt':
        start += 1

    prompt = ''
    pairs = []
    for turn in turns[start:]:
        speaker = turn['from']
        if speaker == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif speaker == 'human':
            text = turn['value']
            if DEFAULT_IMAGE_TOKEN in text:
                remainder = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + remainder).strip()
                # The normalized text is written back onto the turn itself.
                turn['value'] = text
            prompt += text
        else:
            raise NotImplementedError
    example.update({'conversation': pairs})
    return example
|
|
|
|
def semantic_seg_gcg_format_conversations(labels):
    """Enumerate question/answer pairings whose answers embed the label.

    For each label (whitespace-stripped) one sample is produced per
    combination of a ``SEG_QUESTIONS`` template and an
    ``ANSWER_LIST_GCG_FORMAT`` template. The question receives the
    lower-cased label; the answer receives the lower-cased label with its
    first character capitalized.

    Args:
        labels: Iterable of class-name strings. A label must not contain
            the ``"||"`` separator.

    Returns:
        A list of dicts with two keys: ``'conversations'`` (a human turn
        prefixed with ``DEFAULT_IMAGE_TOKEN`` followed by a gpt turn) and
        ``'class_id'`` (the label's position in ``labels``).

    Raises:
        AssertionError: If a stripped label contains ``"||"``.
    """
    samples = []
    for class_id, raw_label in enumerate(labels):
        class_name = raw_label.strip()
        # Labels made of several "||"-joined names are rejected.
        assert "||" not in class_name
        for question in SEG_QUESTIONS:
            prompt = DEFAULT_IMAGE_TOKEN + question.format(
                class_name=class_name.lower())
            for reply_template in ANSWER_LIST_GCG_FORMAT:
                reply = reply_template.format(class_name.lower().capitalize())
                samples.append({
                    'conversations': [
                        {'from': 'human', 'value': prompt},
                        {'from': 'gpt', 'value': reply},
                    ],
                    'class_id': class_id,
                })
    return samples
|
|
def semantic_seg_gcg_format_map_fn(example):
    """Derive ``example['conversation']`` from ``example['conversations']``.

    Each gpt turn yields one ``{'input', 'output'}`` dict. ``'input'`` is
    the concatenation of the human turns since the previous gpt turn and
    ``'output'`` is the gpt turn's text.

    Details:
        * gpt turns at the very start are skipped.
        * A human turn containing ``DEFAULT_IMAGE_TOKEN`` is rewritten in
          place so that the token comes first, followed by a newline and the
          remaining text (all other occurrences of the token are removed).
        * Human text that is never followed by a gpt turn is not emitted.

    Args:
        example: Mapping holding a ``'conversations'`` list of
            ``{'from', 'value'}`` dicts.

    Returns:
        The same ``example`` object, updated with ``'conversation'``.

    Raises:
        NotImplementedError: If a turn's ``'from'`` is neither ``'human'``
            nor ``'gpt'``.
    """
    turns = example['conversations']

    # Leading gpt turns have no prompt to pair with; start after them.
    start = 0
    while start < len(turns) and turns[start]['from'] == 'gpt':
        start += 1

    prompt = ''
    pairs = []
    for turn in turns[start:]:
        speaker = turn['from']
        if speaker == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif speaker == 'human':
            text = turn['value']
            if DEFAULT_IMAGE_TOKEN in text:
                remainder = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + remainder).strip()
                # The normalized text is written back onto the turn itself.
                turn['value'] = text
            prompt += text
        else:
            raise NotImplementedError
    example.update({'conversation': pairs})
    return example
|
|
def pascal_part_gcg_format_conversation(selected_labels):
    """Build a multi-turn dialogue whose answers embed the label text.

    For every label a question template is drawn at random from
    ``SEG_QUESTIONS`` (filled with the lower-cased label) and an answer
    template is drawn at random from ``ANSWER_LIST_GCG_FORMAT`` (filled with
    the lower-cased label, first character capitalized). Only the first
    question is prefixed with ``DEFAULT_IMAGE_TOKEN``.

    Args:
        selected_labels: Iterable of label strings.

    Returns:
        A flat list of alternating ``{'from': 'human', ...}`` and
        ``{'from': 'gpt', ...}`` turn dicts.
    """
    conversations = []
    for i, selected_label in enumerate(selected_labels):
        class_name = selected_label.lower()
        question = random.choice(SEG_QUESTIONS).format(
            class_name=class_name).strip()
        # Sample from the GCG-format templates: they contain the ``{}`` slot
        # that ``.format`` fills. The plain ANSWER_LIST templates have no
        # placeholder, so formatting them silently dropped the label. This
        # matches semantic_seg_gcg_format_conversations.
        answer = random.choice(ANSWER_LIST_GCG_FORMAT).format(
            class_name.capitalize())
        if i == 0:
            # The image token is attached to the first question only.
            question = DEFAULT_IMAGE_TOKEN + question
        conversations.append({'from': 'human', 'value': question})
        conversations.append({'from': 'gpt', 'value': answer})
    return conversations
|
|
def pascal_part_gcg_format_preprocess(example):
    """Attach a generated dialogue to ``example`` under ``'conversations'``.

    The dialogue is produced by ``pascal_part_gcg_format_conversation`` from
    ``example["selected_labels"]``. The same ``example`` object is returned.
    """
    example['conversations'] = pascal_part_gcg_format_conversation(
        example["selected_labels"])
    return example
|
|
def pascal_part_gcg_format_map_fn(example):
    """Generate a dialogue for ``example`` and convert it to input/output pairs.

    Steps:
        1. ``pascal_part_gcg_format_preprocess`` fills
           ``example['conversations']``.
        2. ``example['image']`` is set to ``example["file_name"]``.
        3. The turns are folded into ``example['conversation']``: a list of
           ``{'input', 'output'}`` dicts, one per gpt turn, where
           ``'input'`` concatenates the human turns since the previous gpt
           turn.

    Leading gpt turns are skipped. A human turn containing
    ``DEFAULT_IMAGE_TOKEN`` is rewritten in place so that the token comes
    first, followed by a newline and the remaining text.

    Args:
        example: Mapping with ``"selected_labels"`` and ``"file_name"``.

    Returns:
        The same ``example`` object, updated as described above.

    Raises:
        NotImplementedError: If a turn's ``'from'`` is neither ``'human'``
            nor ``'gpt'``.
    """
    example = pascal_part_gcg_format_preprocess(example)
    example['image'] = example["file_name"]

    turns = example['conversations']

    # Leading gpt turns have no prompt to pair with; start after them.
    start = 0
    while start < len(turns) and turns[start]['from'] == 'gpt':
        start += 1

    prompt = ''
    pairs = []
    for turn in turns[start:]:
        speaker = turn['from']
        if speaker == 'gpt':
            pairs.append({'input': prompt, 'output': turn['value']})
            prompt = ''
        elif speaker == 'human':
            text = turn['value']
            if DEFAULT_IMAGE_TOKEN in text:
                remainder = text.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                text = (DEFAULT_IMAGE_TOKEN + '\n' + remainder).strip()
                # The normalized text is written back onto the turn itself.
                turn['value'] = text
            prompt += text
        else:
            raise NotImplementedError
    example.update({'conversation': pairs})
    return example
|
|
|
|
|
|