File size: 648 Bytes
bc912c5
4967dbc
 
 
 
 
 
 
 
 
 
bc912c5
1bd1511
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers.utils import load_image
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel



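# Alternative (disabled): load the model from the Hugging Face Hub via
# gr.Interface.load instead of serving the local infer() function below.
#gr.Interface.load("models/tsungtao/controlnet-mlsd-202305011046").launch()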
def infer(prompt, negative_prompt, image):
    """Produce an output image for the Gradio UI (placeholder, not implemented).

    Args:
        prompt: Text from the first Gradio ``"text"`` input.
        negative_prompt: Text from the second Gradio ``"text"`` input.
        image: Conditioning image from the Gradio ``"image"`` input
            (presumably a NumPy array, Gradio's default for that
            component -- TODO confirm).

    Returns:
        The generated image to display in the ``"image"`` output, once
        implemented.

    Raises:
        NotImplementedError: Always, until the inference pipeline is wired in.
            Raising explicitly gives a clear error instead of a confusing
            NameError from an undefined result variable.
    """
    # TODO: implement inference. The file's imports (FlaxControlNetModel,
    # FlaxStableDiffusionControlNetPipeline, replicate, shard) suggest a
    # Flax ControlNet pipeline was intended -- confirm the base model and
    # ControlNet checkpoint ids before wiring it in.
    raise NotImplementedError(
        "infer() is a placeholder: implement the ControlNet inference "
        "pipeline before launching the app."
    )

# Wire the UI to infer(): the two text boxes (prompt, negative prompt) and the
# image input map positionally onto infer's three parameters, and its return
# value is shown as an image. Keep these lists in sync with infer's signature.
demo = gr.Interface(
    fn=infer,
    inputs=["text", "text", "image"],
    outputs="image",
)
demo.launch()