# Anime-to-Sketch / app.py
# Author: p1atdev — commit ebf9292 ("chore: zero gpu test")
import gradio as gr
from setup import setup
import torch
import gc
from PIL import Image
from transformers import AutoModel, AutoImageProcessor
from anime2sketch.model import Anime2Sketch
import spaces
# Run the project's one-time setup before any model below is loaded.
# NOTE(review): setup() comes from the local `setup` module, which is not
# visible here; presumably it fetches assets such as ./models/netG.pth — confirm.
setup()
print("Setup finished")
# Hugging Face Hub repo holding the 🤗 transformers port of MangaLineExtraction.
MLE_MODEL_REPO = "p1atdev/MangaLineExtraction-hf"
class MangaLineExtractor:
    """Callable wrapper around the MangaLineExtraction model from the HF Hub.

    The model and its image processor are loaded once, when this class body
    executes (i.e. at import time), and are shared by every instance as class
    attributes.
    """

    # trust_remote_code=True lets transformers run the modeling/processing code
    # shipped inside the Hub repo (presumably required for this custom model).
    model = AutoModel.from_pretrained(MLE_MODEL_REPO, trust_remote_code=True)
    processor = AutoImageProcessor.from_pretrained(
        MLE_MODEL_REPO, trust_remote_code=True
    )

    # NOTE(review): nothing visible moves `model` to CUDA, so even under
    # @spaces.GPU inference runs on the device the model was loaded on (CPU by
    # default). __call__ follows the model's device, so moving the model here
    # (e.g. `.to("cuda")`) is the only change needed to actually use the GPU.

    @spaces.GPU
    @torch.no_grad()
    def __call__(self, image: Image.Image) -> Image.Image:
        """Extract manga-style line art from *image*.

        Args:
            image: Any image the processor accepts. Note that start() passes the
                raw ``gr.Image`` value (a NumPy array), not a PIL image.

        Returns:
            A single-channel (mode ``"L"``) PIL image built from the first item
            of the model's ``pixel_values`` output.
        """
        inputs = self.processor(image, return_tensors="pt")
        # Send the input to wherever the weights live so the call keeps working
        # if the model is ever moved off the CPU.
        device = next(self.model.parameters()).device
        outputs = self.model(inputs.pixel_values.to(device))
        # .cpu() is mandatory before .numpy() for CUDA tensors (no-op on CPU).
        # Clip before the uint8 cast so out-of-range values saturate instead of
        # wrapping around.
        pixels = outputs.pixel_values[0].cpu().numpy().clip(0, 255)
        return Image.fromarray(pixels.astype("uint8"), mode="L")
# Module-level model instances shared by every request handled by start().
mle_model = MangaLineExtractor()
# Anime2Sketch generator loaded from ./models/netG.pth on the CPU.
# NOTE(review): argument meanings are inferred from their values — confirm
# against anime2sketch.model.Anime2Sketch.
a2s_model = Anime2Sketch("./models/netG.pth", "cpu")
def flush():
    """Free unreferenced objects, then release PyTorch's cached CUDA memory.

    A full garbage-collection pass runs first so that tensors nobody refers
    to any more are actually freed before the CUDA cache is emptied. Not
    called anywhere in this module's visible code.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
@torch.no_grad()
def extract(image):
    """Return the MangaLineExtraction line art for *image* (a PIL image)."""
    return mle_model(image)
@torch.no_grad()
def convert_to_sketch(image):
    """Return whatever the shared Anime2Sketch model predicts for *image*."""
    return a2s_model.predict(image)
def start(image):
    """Run both line-extraction models on the uploaded image.

    Args:
        image: The value of the ``gr.Image`` input (an array that
            ``Image.fromarray`` accepts), or ``None`` when nothing has been
            uploaded yet.

    Returns:
        ``[manga_line_extraction_result, anime2sketch_result]`` — one value per
        output component, in the order the UI wires them.

    Raises:
        gr.Error: If no image was provided. The user then sees a readable
            message instead of a traceback from ``Image.fromarray(None)``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    line_art = extract(image)
    # Anime2Sketch gets a 3-channel PIL image; convert() also drops any alpha.
    sketch = convert_to_sketch(Image.fromarray(image).convert("RGB"))
    return [line_art, sketch]
def clear():
    """Return a fresh pair of ``None`` values that blanks both output images."""
    return [None] * 2
def ui():
    """Build the Gradio interface for the demo.

    Layout: the input image with Start/Clear buttons in one column, the two
    model outputs in a second column, then a set of bundled example images.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``launch()``.
    """
    with gr.Blocks() as blocks:
        # Blank lines ("\n\n") separate the Markdown blocks. Without them the
        # intro merges with "Original repos:", and "Using with 🤗 transformers:"
        # is parsed as a lazy continuation of the previous list item.
        gr.Markdown(
            "# Anime to Sketch\n"
            "\n"
            "Unofficial demo for converting illustrations into sketches.\n"
            "\n"
            "Original repos:\n"
            "\n"
            "- [MangaLineExtraction_PyTorch](https://github.com/ljsabc/MangaLineExtraction_PyTorch)\n"
            "- [Anime2Sketch](https://github.com/Mukosame/Anime2Sketch)\n"
            "\n"
            "Using with 🤗 transformers:\n"
            "\n"
            "- [MangaLineExtraction-hf](https://huggingface.co/p1atdev/MangaLineExtraction-hf)\n"
        )
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(label="Input", interactive=True)
                extract_btn = gr.Button("Start", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
            with gr.Column():
                extract_output_img = gr.Image(
                    label="MangaLineExtraction", interactive=False
                )
                to_sketch_output_img = gr.Image(label="Anime2Sketch", interactive=False)
        gr.Examples(
            fn=start,
            examples=[
                ["./examples/0.jpg"],
                ["./examples/1.jpg"],
                ["./examples/2.jpg"],
            ],
            inputs=[input_img],
            # Same order as the list start() returns.
            outputs=[extract_output_img, to_sketch_output_img],
            label="Examples",
            # cache_examples=True,
        )
        gr.Markdown("Images are from nijijourney.")
        extract_btn.click(
            fn=start,
            inputs=[input_img],
            outputs=[extract_output_img, to_sketch_output_img],
        )
        # Clear blanks only the two result images; the input image is kept.
        clear_btn.click(
            fn=clear,
            inputs=[],
            outputs=[extract_output_img, to_sketch_output_img],
        )
    return blocks
# Build the UI and start the Gradio server only when executed as a script.
if __name__ == "__main__":
    ui().launch()