---
library_name: transformers
pipeline_tag: image-segmentation
tags:
- vision
- image-segmentation
- dit
datasets:
- ds4sd/DocLayNet-v1.1
widget:
- src: >-
    https://upload.wikimedia.org/wikipedia/commons/c/c3/LibreOffice_Writer_6.3.png
  example_title: Wiki
---

Trained for 4 epochs.

Usage:
```
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoImageProcessor, BeitForSemanticSegmentation

image_processor = AutoImageProcessor.from_pretrained("microsoft/dit-large")
model = BeitForSemanticSegmentation.from_pretrained("jzju/dit-doclaynet")
image = Image.open('img.png').convert('RGB')
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
# logits are of shape (batch_size, num_labels, height, width)
logits = outputs.logits
out = logits[0].detach()
print(out.size())
# show one score map per label
for i in range(11):
    plt.imshow(out[i])
    plt.show()
```

Labels:
```
1: Caption
2: Footnote
3: Formula
4: List-item
5: Page-footer
6: Page-header
7: Picture
8: Section-header
9: Table
10: Text
11: Title
```

Data label conversion:
```
import cv2
import numpy as np
from datasets import load_dataset
from transformers import BeitForSemanticSegmentation

model = BeitForSemanticSegmentation.from_pretrained("microsoft/dit-base", num_labels=11)
ds = load_dataset("ds4sd/DocLayNet-v1.1")

def add_label(d):  # d is one dataset example; can be applied with ds.map(add_label)
    mask = np.zeros([11, 1025, 1025])  # one binary mask per label
    for b, c in zip(d["bboxes"], d["category_id"]):
        b = [np.clip(int(bb), 0, 1025) for bb in b]  # box is [x, y, width, height]
        mask[c - 1][b[1]:b[1]+b[3], b[0]:b[0]+b[2]] = 1
    mask = [cv2.resize(a, dsize=(56, 56), interpolation=cv2.INTER_AREA) for a in mask]
    d["label"] = np.stack(mask)
    return d
```
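
The per-label score maps from the usage snippet can be reduced to a single label map. The following is a minimal sketch, not part of the original training or inference code. It assumes that channel `i` of the logits corresponds to label id `i + 1`, which the `mask[c - 1]` indexing in the conversion code suggests. `logits_to_label_ids` and `label_name` are hypothetical helper names, and the dummy array stands in for `out.numpy()`. Since the model has no background channel, a plain argmax assigns every pixel to some label; thresholding each channel separately may suit overlapping regions better.

```python
import numpy as np

# Label names copied from the "Labels" list; channel i is assumed to
# correspond to label id i + 1 (the conversion code writes to mask[c - 1]).
LABELS = ["Caption", "Footnote", "Formula", "List-item", "Page-footer",
          "Page-header", "Picture", "Section-header", "Table", "Text", "Title"]

def logits_to_label_ids(scores: np.ndarray) -> np.ndarray:
    """Return an (H, W) array of 1-based label ids from (11, H, W) scores."""
    if scores.ndim != 3 or scores.shape[0] != len(LABELS):
        raise ValueError(f"expected shape (11, H, W), got {scores.shape}")
    # argmax picks the highest-scoring channel per pixel; +1 makes ids 1-based
    return scores.argmax(axis=0) + 1

def label_name(label_id: int) -> str:
    """Map a 1-based label id to its name."""
    return LABELS[label_id - 1]

# Dummy scores standing in for `out.numpy()` from the usage snippet.
demo = np.zeros((11, 2, 2))
demo[8, 0, 0] = 5.0   # channel 8 -> label 9 (Table)
demo[9, 1, 1] = 3.0   # channel 9 -> label 10 (Text)
ids = logits_to_label_ids(demo)
print(label_name(ids[0, 0]), label_name(ids[1, 1]))
```

With real model output, `ids` has the spatial size of the logits, so it would still need resizing to overlay it on the input page.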