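"""Gradio demo: face-landmark ControlNet with Chinese prompts.

Extracts facial landmarks from an uploaded photo, renders them as a
conditioning image, and runs a Stable Diffusion ControlNet pipeline on a
Chinese text prompt.
"""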
# pred_color is the repo's local helper module; it supplies the
# single_pred_features and produce_center_crop_image functions used below.
from pred_color import *

import os

import gradio as gr
import torch
from PIL import Image

from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)

# Load the face-landmark ControlNet checkpoint.
controlnet_model_name_or_path = "svjack/ControlNet-Face-Zh"
controlnet = ControlNetModel.from_pretrained(controlnet_model_name_or_path)

# Taiyi: IDEA-CCNL's Stable Diffusion variant fine-tuned for Chinese prompts.
base_model_path = "IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1"
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    # torch_dtype=torch.float16,  # optional: halves memory use on GPU
)

# Speed up sampling with the UniPC multistep scheduler.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Use the GPU when available; pipe.enable_model_cpu_offload() is an
# alternative when GPU memory is tight.
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

example_sample = [
    # Prompt translates to "a man wearing a hat and gray clothes".
    ["Protector_Cromwell_style.png", "戴帽子穿灰色衣服的男子"],
]
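# The example references a relative path, so Protector_Cromwell_style.png is
# expected to sit next to this script.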

def pred_func(image, prompt):
    """Extract face landmarks from `image`, render them as a conditioning
    image ("spiga_seg"), and run the ControlNet pipeline with the Chinese
    `prompt`. Returns (conditioning image, generated image).
    """
    # Landmark features -> center-cropped conditioning image.
    features, face_features = single_pred_features(image)
    control_image = produce_center_crop_image(features, face_features)

    # Normalize the input to a 512x512 RGB PIL image, whether Gradio hands
    # us a file path, a numpy array, or a PIL image.
    if isinstance(image, str) and os.path.exists(image):
        image = Image.open(image).convert("RGB")
    elif hasattr(image, "shape"):
        image = Image.fromarray(image).convert("RGB")
    else:
        image = image.convert("RGB")
    image = image.resize((512, 512))

    # Fixed seed keeps the demo reproducible across runs.
    generator = torch.manual_seed(0)
    image = pipe(
        prompt,
        num_inference_steps=50,
        generator=generator,
        image=control_image,
    ).images[0]
    return control_image, image
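# Local smoke test (hypothetical file name; Gradio normally passes the image
# in as a numpy array):
#   import numpy as np
#   arr = np.array(Image.open("face.jpg").convert("RGB"))
#   control_img, gen_img = pred_func(arr, "戴帽子穿灰色衣服的男子")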


# Two outputs: the rendered conditioning image and the generated image.
demo = gr.Interface(
    fn=pred_func,
    inputs=["image", "text"],
    outputs=[
        gr.Image(label="control image").style(height=512),
        gr.Image(label="generated image").style(height=512),
    ],
    examples=example_sample,
    cache_examples=False,
)
demo.launch(share=False)