"""Gradio demo that turns an illustration into line art with two models.

The input image is run through MangaLineExtraction (loaded via transformers)
and Anime2Sketch, and both results are shown side by side.
"""

import gradio as gr
from setup import setup

import torch
import gc
from PIL import Image

from transformers import AutoModel, AutoImageProcessor
from anime2sketch.model import Anime2Sketch

import spaces

# Project-local preparation step; must finish before the models are loaded.
setup()
print("Setup finished")

MLE_MODEL_REPO = "p1atdev/MangaLineExtraction-hf"


class MangaLineExtractor:
    """Callable wrapper around the MangaLineExtraction model and processor."""

    # Loaded once, when the class body is executed, and shared by instances.
    model = AutoModel.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(
        MLE_MODEL_REPO, trust_remote_code=True
    )

    @spaces.GPU
    @torch.no_grad()
    def __call__(self, image: Image.Image) -> Image.Image:
        """Run the model on ``image`` and return an 8-bit grayscale image."""
        inputs = self.processor(image, return_tensors="pt")
        outputs = self.model(inputs.pixel_values)

        # Only the first item of the output batch is converted.
        line_image = Image.fromarray(
            outputs.pixel_values[0].numpy().astype("uint8"), mode="L"
        )
        return line_image


mle_model = MangaLineExtractor()
a2s_model = Anime2Sketch("./models/netG.pth", "cpu")


def flush():
    """Run the garbage collector and empty the CUDA allocator cache."""
    gc.collect()
    torch.cuda.empty_cache()


@torch.no_grad()
def extract(image):
    """Return the MangaLineExtraction result for ``image``."""
    result = mle_model(image)
    return result


@torch.no_grad()
def convert_to_sketch(image):
    """Return the Anime2Sketch prediction for ``image``."""
    result = a2s_model.predict(image)
    return result


def start(image):
    """Run both models on ``image`` and return ``[line_art, sketch]``.

    ``image`` is whatever the input ``gr.Image`` component delivers; it is
    passed to ``Image.fromarray`` below, so it is presumably a numpy array
    (verify against the installed Gradio version's default).

    Raises:
        gr.Error: if no image has been provided (``image`` is ``None``), so
            the user sees a readable message instead of an opaque failure
            from inside one of the models.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    return [
        extract(image),
        convert_to_sketch(Image.fromarray(image).convert("RGB")),
    ]


def clear():
    """Return empty values for the two output image components."""
    return [None, None]


def ui():
    """Build and return the Gradio ``Blocks`` layout of the demo."""
    with gr.Blocks() as blocks:
        gr.Markdown(
            """
        # Anime to Sketch

        Unofficial demo for converting illustrations into sketches.

        Original repos:
        - [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)
        - [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch)

        Using with 🤗 transformers:
        - [MangaLineExtraction-hf](https://huggingface.co/p1atdev/MangaLineExtraction-hf)
        """
        )

        with gr.Row():
            # Left column: input image and controls.
            with gr.Column():
                input_img = gr.Image(label="Input", interactive=True)

                extract_btn = gr.Button("Start", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

            # Right column: one output image per model.
            with gr.Column():
                extract_output_img = gr.Image(
                    label="MangaLineExtraction", interactive=False
                )
                to_sketch_output_img = gr.Image(
                    label="Anime2Sketch", interactive=False
                )

        gr.Examples(
            fn=start,
            examples=[
                ["./examples/0.jpg"],
                ["./examples/1.jpg"],
                ["./examples/2.jpg"],
            ],
            inputs=[input_img],
            outputs=[extract_output_img, to_sketch_output_img],
            label="Examples",
            # cache_examples=True,
        )

        gr.Markdown("Images are from nijijourney.")

        extract_btn.click(
            fn=start,
            inputs=[input_img],
            outputs=[extract_output_img, to_sketch_output_img],
        )
        # "Clear" resets only the two outputs; the input image is kept.
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[extract_output_img, to_sketch_output_img],
        )

    return blocks


if __name__ == "__main__":
    ui().launch()