# app.py by tsungtao -- last commit 07a1273 ("Update app.py").
# Gradio demo for the ControlNet model tsungtao/controlnet-mlsd-202305011046.
import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers.utils import load_image
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
# Hub id of the fine-tuned ControlNet this app serves.
CONTROLNET_MODEL_ID = "tsungtao/controlnet-mlsd-202305011046"
# NOTE(review): the base checkpoint is an assumption -- SD v1.5 (Flax revision) is
# what the diffusers Flax ControlNet example uses; confirm it matches the model the
# ControlNet above was trained against.
BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Lazily populated (pipeline, device-replicated params) pair; see _load_pipeline().
_PIPELINE_CACHE = None


def _load_pipeline():
    """Load the ControlNet pipeline on first use and cache it.

    Loading is deferred to the first request so importing this module stays cheap,
    and the params are replicated across local devices once instead of per call.

    Returns:
        A ``(pipe, replicated_params)`` tuple.
    """
    global _PIPELINE_CACHE
    if _PIPELINE_CACHE is None:
        controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
            CONTROLNET_MODEL_ID, dtype=jnp.float32
        )
        pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
            BASE_MODEL_ID, controlnet=controlnet, revision="flax", dtype=jnp.float32
        )
        # The pipeline looks up the ControlNet weights under this key of its params.
        params["controlnet"] = controlnet_params
        _PIPELINE_CACHE = (pipe, replicate(params))
    return _PIPELINE_CACHE


def infer(prompt, negative_prompt, image, seed=0, num_inference_steps=50):
    """Generate an image from a text prompt, conditioned on a control image.

    Args:
        prompt: Text describing the desired output.
        negative_prompt: Text describing what to steer away from ("" for none).
        image: Control image from the Gradio image input -- a numpy array with
            Gradio's default settings, or a PIL image; ``None`` if nothing was
            uploaded.
        seed: Seed for the JAX PRNG; identical inputs and seed reproduce the output.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        The first generated sample as a PIL image.

    Raises:
        gr.Error: If no control image was supplied (Gradio shows it in the UI).
    """
    if image is None:
        raise gr.Error("Please upload a control image.")
    if isinstance(image, np.ndarray):
        # prepare_image_inputs only accepts PIL images (or lists of them).
        image = Image.fromarray(image)

    pipe, p_params = _load_pipeline()

    # One sample per local device: inputs are batched to device_count() and then
    # sharded so that each device gets its slice when jit=True.
    num_samples = jax.device_count()
    rng = jax.random.split(jax.random.PRNGKey(seed), num_samples)

    prompt_ids = shard(pipe.prepare_text_inputs([prompt] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_text_inputs([negative_prompt] * num_samples)
    )
    processed_image = shard(pipe.prepare_image_inputs([image] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Flatten the leading sharded axes into one batch axis, keeping the last three
    # (image) axes, before converting to PIL.
    output_images = pipe.numpy_to_pil(
        np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    )
    return output_images[0]
# Build the UI: prompt, negative prompt and control image in, one image out.
# The input order must match infer's positional parameters.
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative prompt"),
        gr.Image(label="Control image"),
    ],
    outputs=gr.Image(label="Result"),
)

# Start the server only when run as a script, so importing this module
# (tests, tooling) does not block on a running server.
if __name__ == "__main__":
    demo.launch()