"""Gradio app that finds the illustrated pages of a book described by IIIF.

The app fetches a IIIF Presentation manifest and downloads every page image
at thumbnail size. Each page is classified with a fine-tuned image
classification model, and pages labelled ``illustrated`` are shown in a
gallery.
"""
import asyncio
import io
from functools import lru_cache

import gradio as gr
import httpx
import PIL
import requests
from PIL import Image
from piffle.image import IIIFImageClient
from toolz import pluck
from transformers import AutoFeatureExtractor, AutoModelForImageClassification, pipeline

DESCRIPTION = """ **tl;dr** this project develops a machine learning model which can tell you whether a page of a historical book contains an illustration or not. [IIIF](https://iiif.io/get-started/) is a standard for sharing images that many libraries, archives and museums use. One part of the IIIF standard is the [presentation API](https://iiif.io/api/presentation/3.0/) which defines a way of representing multiple images which make up a digitised physical item such as a book as a JSON document. This app allows you to pass the URL for one of these manifests and return all of the book pages containing illustrations. For example, [Harvard Library Digital Collections](https://library.harvard.edu/digital-collections) has many digitised books. Many of these books have IIIF manifests available. If we take [this](https://digitalcollections.library.harvard.edu/catalog/990052690260203941) book as an example, we can find a [link](https://iiif.lib.harvard.edu/manifests/drs:49040527) to its manifest. If we pass this manifest to the app, we should get all the pages from that book containing illustrations. This is done by grabbing all the individual page URLs from the book (using the manifest) and passing them through the image classification model we created as part of this project.
You can read more about how we created the model and training data for this project in this [GitHub repository](https://github.com/davanstrien/ImageIN/) """

HF_MODEL_PATH = "ImageIN/levit-192_finetuned_on_unlabelled_IA_with_snorkel_labels"

# Timeout in seconds for every HTTP request (manifest and page images).
REQUEST_TIMEOUT = 30

# Upper bound on simultaneous page downloads; see ``get_images``.
MAX_CONCURRENT_DOWNLOADS = 20

classif_model = AutoModelForImageClassification.from_pretrained(HF_MODEL_PATH)
feature_extractor = AutoFeatureExtractor.from_pretrained(HF_MODEL_PATH)
classif_pipeline = pipeline(
    "image-classification", model=classif_model, feature_extractor=feature_extractor
)


def load_manifest(inputs):
    """Download the IIIF manifest at URL ``inputs`` and return the parsed JSON.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond in REQUEST_TIMEOUT seconds.
    """
    # Without a timeout, a stalled server would block this worker indefinitely.
    with requests.get(inputs, timeout=REQUEST_TIMEOUT) as r:
        # Report the HTTP status rather than an opaque JSON decoding error.
        r.raise_for_status()
        return r.json()


def get_image_urls_from_manifest(data):
    """Return the image URL of every page in a parsed IIIF manifest, in page order.

    Two manifest layouts are supported:

    * Presentation API 2.x:
      ``sequences`` -> ``canvases`` -> ``images`` -> ``resource['@id']``
    * Presentation API 3.0:
      ``items`` (canvases) -> ``items`` (annotation pages) -> ``items``
      (annotations) -> ``body['id']`` (only bodies of type ``Image``)

    Raises:
        ValueError: if ``data`` has neither a ``sequences`` nor an ``items`` key.
    """
    if "sequences" in data:  # Presentation API 2.x
        return [
            image["resource"]["@id"]
            for sequence in data["sequences"]
            for canvas in sequence["canvases"]
            for image in canvas["images"]
        ]
    if "items" in data:  # Presentation API 3.0
        image_urls = []
        for canvas in data["items"]:
            for annotation_page in canvas.get("items", []):
                for annotation in annotation_page.get("items", []):
                    body = annotation.get("body")
                    # A body may be a single resource or a list of resources;
                    # only image resources are of interest.
                    bodies = body if isinstance(body, list) else [body]
                    image_urls.extend(
                        resource["id"]
                        for resource in bodies
                        if isinstance(resource, dict)
                        and resource.get("type") == "Image"
                        and "id" in resource
                    )
        return image_urls
    raise ValueError(
        "Not a IIIF Presentation manifest: expected a 'sequences' (v2) "
        "or 'items' (v3) key"
    )


def resize_iiif_urls(image_url, size='224'):
    """Rewrite IIIF image URL ``image_url`` to request a ``size`` x ``size`` image.

    The IIIF server does the downscaling, so only a thumbnail is transferred
    per page. Both dimensions are fixed, so the aspect ratio is not preserved.
    """
    iiif_image = IIIFImageClient.init_from_url(image_url)
    return str(iiif_image.size(width=size, height=size))


async def get_image(client, url):
    """Download ``url`` with ``client`` and decode it as a PIL image.

    Returns:
        The decoded image, or ``None`` if the download or the decoding failed.
        A missing page is skipped rather than failing the whole book.
    """
    try:
        resp = await client.get(url, timeout=REQUEST_TIMEOUT)
        resp.raise_for_status()
        image = Image.open(io.BytesIO(resp.content))
        # Image.open is lazy; decode now so corrupt or truncated files are
        # rejected here instead of crashing the classification pipeline later.
        image.load()
    except (OSError, httpx.HTTPError):
        # OSError includes PIL.UnidentifiedImageError and truncated-image errors.
        # httpx.HTTPError covers every timeout, connection, protocol and status
        # error, not just ReadTimeout/ConnectError.
        return None
    return image


async def get_images(urls, max_concurrency=MAX_CONCURRENT_DOWNLOADS):
    """Concurrently download every URL in ``urls``.

    Args:
        urls: list of image URLs.
        max_concurrency: maximum number of requests in flight at once.

    Returns:
        A list of ``(url, image)`` pairs, in input order, for the downloads
        that succeeded. Failed downloads are omitted.
    """
    semaphore = asyncio.Semaphore(max_concurrency)

    async def fetch(client, url):
        # Starting one request per page at once overwhelms the IIIF server and
        # exhausts httpx's connection pool on long books, which leads to
        # PoolTimeout errors. The semaphore caps how many run in parallel.
        async with semaphore:
            return await get_image(client, url)

    async with httpx.AsyncClient() as client:
        images = await asyncio.gather(*(fetch(client, url) for url in urls))
    return [(url, image) for url, image in zip(urls, images) if image is not None]


def predict(inputs):
    """Gradio entry point: normalise the manifest URL, then run the cached worker."""
    # Strip whitespace that is often pasted along with the URL. This avoids a
    # failed request and keeps one cache entry per manifest.
    return _predict(str(inputs).strip())


@lru_cache(maxsize=100)
def _predict(inputs):
    """Return ``(image_url, caption)`` pairs for the illustrated pages of manifest ``inputs``.

    Results are cached per manifest URL. A result computed while some page
    downloads failed is cached as well.
    """
    data = load_manifest(inputs)
    page_urls = get_image_urls_from_manifest(data)
    resized_urls = [resize_iiif_urls(url) for url in page_urls]
    url_image_pairs = asyncio.run(get_images(resized_urls))
    if not url_image_pairs:
        # Nothing was downloaded, so there is nothing to classify.
        return []
    fetched_urls = list(pluck(0, url_image_pairs))
    images = list(pluck(1, url_image_pairs))
    predictions = classif_pipeline(images, top_k=1, num_workers=2)
    predicted_images = []
    for url, pred in zip(fetched_urls, predictions):
        top_pred = pred[0]
        if top_pred['label'] != 'illustrated':
            continue
        # Show a 500px-wide version; height='' leaves the height unspecified
        # instead of keeping the 224px thumbnail height.
        display_url = str(
            IIIFImageClient.init_from_url(url).size(width=500, height='')
        )
        predicted_images.append(
            (
                display_url,
                f"Confidence: {top_pred['score']}, \n image url: {display_url}",
            )
        )
    return predicted_images


gallery = gr.Gallery()
gallery.style(grid=3)

demo = gr.Interface(
    fn=predict,
    inputs=gr.Text(label="IIIF manifest url"),
    outputs=gallery,
    title="IIIF book manifest illustration detection",
    description=DESCRIPTION,
).queue()
demo.launch()