File size: 3,358 Bytes
30ed707
 
54bb5af
ee2cec4
30ed707
 
 
54bb5af
 
 
1b42456
54bb5af
30ed707
 
 
 
 
 
 
 
 
 
54bb5af
 
 
 
 
 
 
 
 
fd55bc4
ee2cec4
 
 
30ed707
 
 
 
 
 
 
ee2cec4
 
 
30ed707
 
54bb5af
 
30ed707
 
ee2cec4
54bb5af
 
 
 
 
 
ee2cec4
54bb5af
 
ee2cec4
 
 
 
54bb5af
 
 
 
 
 
30ed707
ee2cec4
 
 
 
 
30ed707
ee2cec4
 
 
 
 
30ed707
ee2cec4
30ed707
ee2cec4
 
 
 
 
30ed707
 
 
 
 
 
ee2cec4
1b42456
54bb5af
ee2cec4
30ed707
ee2cec4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import uuid

from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import FileResponse
import uvicorn

app = FastAPI()

# Constants
# Allowlist of base Stable Diffusion checkpoints, keyed by the style name
# clients send in GenerateImageRequest.base.
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2"
}
# Allowlist of AnimateDiff camera-motion LoRA repos, keyed by the motion
# name clients send in GenerateImageRequest.motion ("" means no motion LoRA).
motions = {
    "Zoom in": "guoyww/animatediff-motion-lora-zoom-in",
    "Zoom out": "guoyww/animatediff-motion-lora-zoom-out",
    "Tilt up": "guoyww/animatediff-motion-lora-tilt-up",
    "Tilt down": "guoyww/animatediff-motion-lora-tilt-down",
    "Pan left": "guoyww/animatediff-motion-lora-pan-left",
    "Pan right": "guoyww/animatediff-motion-lora-pan-right",
    "Roll left": "guoyww/animatediff-motion-lora-rolling-anticlockwise",
    "Roll right": "guoyww/animatediff-motion-lora-rolling-clockwise",
}
# Module-level cache of what is currently loaded into `pipe`, so repeated
# requests with the same settings skip the expensive weight downloads.
# NOTE(review): shared mutable state with no lock — assumes requests are
# handled one at a time (single uvicorn worker); confirm before scaling out.
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None

# Ensure model and scheduler are initialized in GPU-enabled function
# Fail fast at import time: the pipeline below is moved to CUDA and the
# service cannot run on CPU.
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")

device = "cuda"
dtype = torch.float16
# Build the AnimateDiff pipeline once at startup from the default base
# (matches the initial value of `base_loaded` above).
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
# "trailing" timestep spacing + linear betas per the AnimateDiff-Lightning
# model card's recommended scheduler configuration.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing",
                                                    beta_schedule="linear")

# Safety checkers
# NOTE(review): mid-file import and `feature_extractor` is never referenced
# anywhere in this file — presumably intended to feed a safety checker that
# was never wired up; confirm whether this can be removed.
from transformers import CLIPFeatureExtractor

feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")


class GenerateImageRequest(BaseModel):
    """Request body for POST /generate-image.

    Attributes mirror the knobs of the AnimateDiff-Lightning pipeline:
    """

    # Text prompt describing the video to generate (required).
    prompt: str
    # Key into the module-level `bases` dict; defaults to the checkpoint
    # preloaded at startup.
    base: str = "Realistic"
    # Key into the module-level `motions` dict; "" disables the motion LoRA.
    motion: str = ""
    # Number of inference steps; also selects the matching
    # animatediff_lightning_{step}step checkpoint.
    step: int = 8


@app.post("/generate-image")
def generate_image(request: GenerateImageRequest):
    """Generate a short MP4 video from a text prompt with AnimateDiff-Lightning.

    Swaps the base checkpoint, lightning step weights, and motion LoRA only
    when they differ from what is currently loaded into the module-level
    ``pipe``, renders the clip, writes it under /tmp, and returns it as a
    file response.

    Raises:
        HTTPException(400): if ``base`` or a non-empty ``motion`` is not in
            the module-level allowlists.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    prompt = request.prompt
    base = request.base
    motion = request.motion
    step = request.step

    # Validate against the allowlists up front: an unknown base previously
    # raised a raw KeyError (HTTP 500) inside the weight-loading code, and
    # an unknown non-empty motion was silently ignored.
    if base not in bases:
        raise HTTPException(status_code=400, detail=f"Unknown base model: {base!r}")
    if motion and motion not in motions:
        raise HTTPException(status_code=400, detail=f"Unknown motion: {motion!r}")

    print(prompt, base, step)

    # Load the base UNet BEFORE the lightning step checkpoint. The original
    # order (step first, base second) let the base state dict overwrite the
    # distilled lightning weights whenever the base changed, and never
    # invalidated step_loaded, so later requests ran on clobbered weights.
    if base_loaded != base:
        # NOTE(security): torch.load unpickles the downloaded .bin; this is
        # acceptable only because repo ids come from the fixed `bases`
        # allowlist above — never feed it a client-supplied repo id.
        pipe.unet.load_state_dict(
            torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
            strict=False)
        base_loaded = base
        step_loaded = None  # force re-applying the lightning weights below

    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    if motion_loaded != motion:
        # Always drop the previous LoRA first; "" (no motion) then leaves
        # the pipeline LoRA-free.
        pipe.unload_lora_weights()
        if motion in motions:
            pipe.load_lora_weights(motions[motion], adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion

    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step)

    # Unique filename so concurrent requests cannot collide on disk.
    # NOTE(review): the temp file is never deleted after the response is
    # sent — consider FileResponse's background-task cleanup.
    name = uuid.uuid4().hex
    path = f"/tmp/{name}.mp4"
    export_to_video(output.frames[0], path, fps=10)

    return FileResponse(path, media_type="video/mp4", filename=f"{name}.mp4")


# Script entry point: serve the app directly on all interfaces, port 7860
# (the conventional Hugging Face Spaces port).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)