# Captured from a Hugging Face Spaces page (Space status shown: Sleeping).
import os

import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from facenet_pytorch import MTCNN
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
class ArcaneGANProcessor:
    """Detect the largest face in an image and stylize it with an ArcaneGAN ONNX model.

    Models are downloaded from the ``Arrcttacsrks/ArcaneGanOnnx`` Hub repository.
    One ONNX Runtime session is cached per model version, so repeated requests
    do not re-download the model or rebuild the session.
    """

    # ImageNet statistics. The model input is normalized with these, and the
    # model output is de-normalized with them (same values as ``img_transforms``).
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self):
        """Read HF_TOKEN, choose a torch device and build the MTCNN face detector.

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is not set.
        """
        self.hf_token = os.getenv('HF_TOKEN')
        if not self.hf_token:
            raise ValueError("HF_TOKEN not found in environment variables")
        print("HF_TOKEN found in environment variables")
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        if self.device.type == 'cpu':
            print("Warning: Using CPU, performance may be reduced.")
        # keep_all=False: return only the single face chosen by select_largest,
        # as a [C, H, W] tensor (keep_all=True returns [N, C, H, W]).
        # post_process=False: keep pixel values in 0..255 so that the ImageNet
        # normalization in _preprocess is applied instead of MTCNN's own
        # fixed standardization.
        self.mtcnn = MTCNN(
            image_size=256,
            margin=80,
            keep_all=False,
            device=self.device,
            post_process=False,
            select_largest=True,
        )
        # Not used by process_image, which normalizes the face tensor in
        # _preprocess. Retained as a public attribute.
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self.IMAGENET_MEAN), std=list(self.IMAGENET_STD)),
        ])
        # Maps model version string -> onnxruntime.InferenceSession.
        self._sessions = {}

    def download_model(self, version):
        """Download the ONNX model for ``version`` from the HuggingFace Hub.

        Args:
            version: suffix appended to ``ArcaneGAN`` to form the filename,
                e.g. ``"v0.4"`` -> ``ArcaneGANv0.4.onnx``.

        Returns:
            Local filesystem path of the downloaded model.

        Raises:
            RuntimeError: if the download fails (original error chained).
        """
        model_filename = f"ArcaneGAN{version}.onnx"
        try:
            return hf_hub_download(
                repo_id="Arrcttacsrks/ArcaneGanOnnx",
                filename=model_filename,
                token=self.hf_token,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to download model: {str(e)}") from e

    def _get_session(self, version):
        """Return the cached ONNX Runtime session for ``version``, creating it on first use."""
        session = self._sessions.get(version)
        if session is None:
            model_path = self.download_model(version)
            # Recent onnxruntime GPU builds require providers to be passed
            # explicitly. Use CUDA only when torch also selected a GPU.
            providers = ["CPUExecutionProvider"]
            if self.device.type == "cuda" and "CUDAExecutionProvider" in ort.get_available_providers():
                providers.insert(0, "CUDAExecutionProvider")
            session = ort.InferenceSession(model_path, providers=providers)
            self._sessions[version] = session
        return session

    @classmethod
    def _preprocess(cls, face):
        """Turn an MTCNN face crop into a normalized float32 NCHW batch of one.

        Args:
            face: torch tensor of shape [C, H, W], or [N, C, H, W] (only the
                first face is used), with pixel values in 0..255.

        Returns:
            numpy float32 array of shape [1, C, H, W], normalized with the
            ImageNet mean/std.
        """
        if face.dim() == 3:
            face = face.unsqueeze(0)
        else:
            face = face[:1]
        pixels = face.detach().cpu().numpy().astype(np.float32) / 255.0
        mean = np.asarray(cls.IMAGENET_MEAN, dtype=np.float32).reshape(1, -1, 1, 1)
        std = np.asarray(cls.IMAGENET_STD, dtype=np.float32).reshape(1, -1, 1, 1)
        return (pixels - mean) / std

    @classmethod
    def _postprocess(cls, output):
        """Convert a model output (NCHW or CHW) into an HxWxC uint8 image array.

        The output is de-normalized with the ImageNet mean/std, clipped to
        [0, 1] and scaled to 0..255.
        NOTE(review): this assumes the ONNX export keeps the reference
        ArcaneGAN output convention (ImageNet-normalized); confirm for each
        model version.
        """
        image = np.asarray(output, dtype=np.float32)
        if image.ndim == 4:
            image = image[0]
        image = image.transpose(1, 2, 0)
        mean = np.asarray(cls.IMAGENET_MEAN, dtype=np.float32)
        std = np.asarray(cls.IMAGENET_STD, dtype=np.float32)
        image = np.clip(image * std + mean, 0.0, 1.0)
        return np.rint(image * 255.0).astype(np.uint8)

    def process_image(self, image, version):
        """Stylize the largest face in ``image`` with ArcaneGAN ``version``.

        Args:
            image: a PIL image or an array accepted by ``Image.fromarray``.
            version: model version string, e.g. ``"v0.4"``.

        Returns:
            PIL image of the stylized face crop.

        Raises:
            ValueError: if ``image`` is None or no face is detected.
            RuntimeError: if the model cannot be downloaded.
        """
        if image is None:
            raise ValueError("Input image is None")
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # One MTCNN call both detects and crops. It returns None when no face
        # is found, so a separate detect() pass is unnecessary.
        face = self.mtcnn(image)
        if face is None:
            raise ValueError("No face detected in the image")

        session = self._get_session(version)
        model_input = session.get_inputs()[0]
        batch = self._preprocess(face)
        # Feed half precision when the exported graph declares a float16 input.
        if model_input.type == "tensor(float16)":
            batch = batch.astype(np.float16)
        raw_output = session.run(None, {model_input.name: batch})[0]
        return Image.fromarray(self._postprocess(raw_output))
def create_interface():
    """Build the Gradio Blocks UI for ArcaneGAN face stylization.

    Returns:
        gr.Blocks: the assembled app; call ``.launch()`` to serve it.
    """
    processor = ArcaneGANProcessor()

    def convert(image_in, model_version):
        # Gradio displays the message of a gr.Error in the UI, while other
        # exceptions show up only as a generic error. Translate the
        # processor's expected failures (no face, download failure).
        try:
            return processor.process_image(image_in, model_version)
        except (ValueError, RuntimeError) as exc:
            raise gr.Error(str(exc)) from exc

    with gr.Blocks() as demo:
        gr.Markdown("# ArcaneGAN Converter")
        with gr.Row():
            input_image = gr.Image(type="numpy", label="Input Image")
            output_image = gr.Image(type="pil", label="Output Image")
        version = gr.Radio(
            choices=["v0.4", "v0.3", "v0.2"],
            value="v0.4",
            label="Model Version",
        )
        process_button = gr.Button("Convert")
        process_button.click(
            fn=convert,
            inputs=[input_image, version],
            outputs=output_image,
        )
    return demo
if __name__ == "__main__":
    # Script entry point: assemble the Gradio Blocks app, then start serving it.
    demo = create_interface()
    demo.launch()