import argparse
import os
import os.path as osp

import cv2
import mmcv
import numpy as np
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot

# Random colour table: 104 rows x 3 channels, drawn once at import time.
# `high` is exclusive, so every channel value lies in [0, 254].
# Presumably one row per segmentation class -- confirm against the class list.
DEFAULT_PALETTE = np.random.randint(low=0, high=255, size=(104, 3))

# Lower-case alias bound to the very same array object.
palette = DEFAULT_PALETTE

def parse_args():
    """Parse the command-line arguments of the inference demo.

    Returns:
        argparse.Namespace: carries ``img``, ``config`` and ``checkpoint``
        (required positionals, in that order) plus ``device``, which is
        ``'cpu'`` unless ``--device`` is supplied.
    """
    cli = argparse.ArgumentParser(description='MMSeg Inference')

    # Required positionals; registration order is command-line order.
    positionals = (
        ('img', 'Image file'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)

    cli.add_argument(
        '--device', default='cpu', help='Device used for inference')

    return cli.parse_args()

def compute_all_segments_area(image):
    """
    Computes the pixel area for each segment in the image.

    Args:
    - image: A 2D numpy array (or any 2D array-like, e.g. nested lists)
      representing the segmented image (grayscale), one segment value per
      pixel.

    Returns:
    - A dictionary with segment values as keys and their corresponding areas
      (in pixels) as values. An empty 2D input yields an empty dictionary.

    Raises:
    - AssertionError: if the input is not 2-dimensional.
    """
    # np.asarray hands an existing ndarray back unchanged and converts plain
    # nested sequences, which previously crashed on the `.shape` access.
    label_map = np.asarray(image)

    # Raised explicitly rather than through an `assert` statement so the
    # check is not stripped under `python -O`; the exception type and
    # message are the same as before.
    if label_map.ndim != 2:
        raise AssertionError("Image should be grayscale")

    unique_values, counts = np.unique(label_map, return_counts=True)
    return dict(zip(unique_values, counts))

def _load_class_names(category_file):
    """Return the class names listed in a `<id><TAB><name>` text file, in file order."""
    class_names = []
    with open(category_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split('\t')
            # Skip blank or malformed lines (e.g. a trailing newline at the
            # end of the file) instead of crashing with IndexError.
            if len(fields) < 2:
                continue
            class_names.append(fields[1])
    return class_names


def main():
    """Segment one image, print the pixel area per class and save the overlay.

    Steps:
      1. Build the segmentor from the config/checkpoint given on the CLI.
      2. Attach the class names read from the category file to the model.
      3. Run inference on the input image.
      4. Print the pixel area of every segment value found in the result.
      5. Render the result to ``../output_images/segmented_result.jpg``.

    Both the category file and the output path are relative to the current
    working directory.
    """
    args = parse_args()

    model = init_segmentor(args.config, args.checkpoint, device=args.device)

    # Names are looked up by list position below, so the file order is
    # presumably the label-id order -- verify against the dataset.
    class_names = _load_class_names('../data/FoodSeg103/category_id.txt')
    model.CLASSES = class_names

    result = inference_segmentor(model, args.img)

    print("Inside image_demo DEFAULT_PALETTE.shape:", DEFAULT_PALETTE.shape)

    # result[0] is the label map of the (single) input image.
    segment_areas = compute_all_segments_area(result[0])

    segment_classes = {
        class_names[segment_value]: area
        for segment_value, area in segment_areas.items()
    }

    for class_name, area in segment_classes.items():
        print(f"Segment for class '{class_name}': Area = {area} pixels")

    output_path = '../output_images/segmented_result.jpg'
    # Create the target directory up front so the save is not lost to a
    # missing folder after inference has already run.
    os.makedirs(osp.dirname(output_path), exist_ok=True)

    # NOTE(review): `class_names` is not a parameter of upstream mmseg's
    # show_result_pyplot; presumably this project uses a patched version --
    # confirm.
    show_result_pyplot(
        model,
        args.img,
        result,
        DEFAULT_PALETTE,
        out_file=output_path,
        class_names=class_names,
    )

# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()