# Hugging Face file-viewer header captured with the source (not part of the program):
#   NDugar's picture | Create app.py | 6ed6d81 | raw | history blame | 2.79 kB
from cgitb import enable
from ctypes.wintypes import HFONT
import os
import sys
import torch
import gradio as gr
import numpy as np
import torchvision.transforms as transforms
from torch.autograd import Variable
from network.Transformer import Transformer
from huggingface_hub import hf_hub_download
from PIL import Image
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
MODEL_PATH = "models"
COLOUR_MODEL = "RGB"
MODEL_REPO = "NDugar/horse_to_zebra_cycle_GAN"
MODEL_FILE = "h2z-85epoch.pth"

# Model initialisation
# Fetch the checkpoint file from the Hugging Face Hub; the call returns a
# local file path that torch.load can read.
model_hfhub = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)

model = Transformer()
enable_gpu = torch.cuda.is_available()
map_location = torch.device("cuda") if enable_gpu else "cpu"
model.load_state_dict(torch.load(model_hfhub, map_location=map_location))

# load_state_dict copies weights into the existing parameters, so the module
# itself must be moved explicitly for it to match the GPU-resident inputs that
# inference() produces when CUDA is available.
if enable_gpu:
    model.cuda()

# Only one network is loaded in this app; put it in inference mode.
model.eval()
# Functions
def get_model():
    """Return the module-level network instance used for inference."""
    network = model
    return network
def adjust_image_for_model(img):
    """Log the image's height and width, then hand the image back unchanged."""
    size_report = f"Image Height: {img.height}, Image Width: {img.width}"
    logger.info(size_report)
    return img
def inference(img, style=None):
    """Run the loaded network on one input photo and return the result.

    Args:
        img: PIL image supplied by the Gradio input component.
        style: Unused. Kept for backward compatibility and made optional,
            because the interface declares a single input and therefore
            calls this function with the image alone.

    Returns:
        A PIL image built from the network's output.
    """
    img = adjust_image_for_model(img)

    # H x W x C array with the channel axis reversed (RGB -> BGR).
    pixels = np.asarray(img.convert(COLOUR_MODEL))[:, :, [2, 1, 0]]
    # ToTensor yields floats in [0, 1]; rescale to [-1, 1] and add a batch axis.
    input_image = -1 + 2 * transforms.ToTensor()(pixels).unsqueeze(0)

    model = get_model()
    if enable_gpu:
        logger.info("CUDA found. Using GPU.")
        # Module.cuda() moves parameters in place and is harmless when they
        # are already on the GPU; it keeps model and input on one device.
        model = model.cuda()
        input_image = input_image.cuda()
    else:
        logger.info("CUDA not found. Using CPU.")
        input_image = input_image.float()

    # No gradients are needed for a forward-only pass.
    with torch.no_grad():
        output_image = model(input_image)[0]

    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    # Inverse of the input rescaling (assumes the network emits values in
    # [-1, 1] -- TODO confirm against the Transformer definition).
    output_image = output_image.detach().cpu().float() * 0.5 + 0.5
    return transforms.ToPILImage()(output_image)
# Gradio setup
title = "Horse 2 Zebra GAN"
description = "Gradio Demo for CycleGAN"

# The previous `article=article` and `examples=examples` arguments referred to
# names that are never defined in this file and raised NameError at start-up,
# so they are omitted here.
# NOTE(review): gr.inputs / gr.outputs, allow_screenshot and enable_queue look
# like the legacy Gradio API -- confirm against the pinned gradio version.
gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Image(
            type="pil",
            label="Input Photo",
        ),
    ],
    outputs=gr.outputs.Image(
        type="pil",
        label="Output Image",
    ),
    title=title,
    description=description,
    allow_flagging="never",
    allow_screenshot=False,
).launch(enable_queue=True)