File size: 6,381 Bytes
e988662
8dd0bda
 
e988662
 
36b0cce
 
e988662
 
25b1295
e988662
 
 
25b1295
e988662
 
8dd0bda
 
 
 
 
 
87d4598
4c9345a
de7d0eb
 
 
 
6d62697
684e4b9
 
6d62697
2a2fa4c
f92c4cd
 
 
19a2293
09b7848
 
 
 
 
615e1c0
2dc80b5
0beba8c
fb29fe3
 
9f92f81
c1b6961
f92c4cd
1042f70
 
 
 
 
eb50505
 
 
1042f70
78839d9
 
 
 
 
bd17032
 
 
 
 
076b130
bd17032
 
 
c4f5519
 
f8215e3
f92c4cd
f8215e3
9c304c0
1be8ad2
 
 
 
c1b6961
1be8ad2
 
c1b6961
1be8ad2
 
b715851
 
1be8ad2
 
9c304c0
1be8ad2
25b1295
cea6ee2
 
25b1295
f92c4cd
f8215e3
5dd3dde
 
 
43ff508
 
1be8ad2
43ff508
 
 
 
 
c9314ca
d4af306
f92c4cd
 
 
 
9d32382
f92c4cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a43544
4f843fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os
from huggingface_hub import snapshot_download

# Disable the optional hf_transfer download backend.
# NOTE(review): huggingface_hub is imported above this line and appears to read
# HF_HUB_ENABLE_HF_TRANSFER when it is first imported, so this assignment may
# not affect this process — confirm, and move it above the import if it matters.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

# Florence-2 is pinned to one commit, presumably so its trust_remote_code
# modeling files cannot change between restarts.
REVISION = "ceaf371f01ef66192264811b390bccad475a4f02"

# Download (or reuse from the local cache) the Florence-2 captioning model.
LOCAL_FLORENCE = snapshot_download(
    repo_id="microsoft/Florence-2-base",
    revision=REVISION,
)

# Pre-fetch the SD3.5-Large TurboX weights. No revision is pinned, so this
# follows the repo's default branch.
LOCAL_TURBOX = snapshot_download(
    repo_id="tensorart/stable-diffusion-3.5-large-TurboX",
)

# Same repo at the same revision as LOCAL_FLORENCE: reuse the resolved snapshot
# path instead of issuing a second, redundant hub request.
LOCAL_FLORENCE_DIR = LOCAL_FLORENCE

import sys, types, importlib.machinery, importlib

# Register an empty placeholder module called ``flash_attn`` (with a real
# ModuleSpec, loader=None) so that an ``import flash_attn`` / spec lookup
# succeeds without the package installed — presumably to get past transformers'
# remote-code import check for Florence-2; confirm. The placeholder defines no
# attributes, so anything that actually calls into flash_attn will still fail.
spec = importlib.machinery.ModuleSpec(name='flash_attn', loader=None)
mod = types.ModuleType(spec.name)
mod.__spec__ = spec
sys.modules[spec.name] = mod

import huggingface_hub as _hf_hub
# Compatibility shim: publish `hf_hub_download` under the legacy name
# `cached_download`, presumably for a dependency (e.g. an older diffusers
# release) that still does `from huggingface_hub import cached_download` — confirm.
# NOTE(review): the two functions presumably differ in call signature, so a
# caller that invokes `cached_download` the old (URL-based) way may still fail.
setattr(_hf_hub, "cached_download", getattr(_hf_hub, "hf_hub_download"))

import gradio as gr
import torch
import random
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    CLIPFeatureExtractor,
)
import diffusers
from diffusers import StableDiffusionPipeline
from diffusers import DiffusionPipeline
from diffusers import EulerDiscreteScheduler as FlowMatchEulerDiscreteScheduler
from diffusers import UNet2DConditionModel 
# from diffusers import FlowMatchEulerDiscreteScheduler
# diffusers.FlowMatchEulerDiscreteScheduler = EulerDiscreteScheduler

import transformers.utils.import_utils as _import_utils
from transformers.utils import is_flash_attn_2_available
# Make transformers report optional packages (FlashAttention 2 in particular) as
# unavailable, presumably so nothing tries to use the empty `flash_attn`
# placeholder module installed earlier in this file.
# NOTE(review): this makes *every* later `_is_package_available` lookup return
# False, not just flash_attn. The replacement also accepts a `return_version`
# keyword and then returns (False, "N/A"), so callers that pass it do not die
# with a TypeError — assumes the real helper is
# `_is_package_available(pkg_name, return_version=False)`; confirm against the
# installed transformers.
# NOTE(review): the `is_flash_attn_2_available` import above is not used here.
_import_utils._is_package_available = lambda pkg, return_version=False: (False, "N/A") if return_version else False
_import_utils.is_flash_attn_2_available = lambda *a, **k: False

# Modules that did `from transformers.utils import ...` read these names from
# `transformers.utils`, so patch them there as well.
hf_utils = importlib.import_module('transformers.utils')
hf_utils.is_flash_attn_2_available            = lambda *a, **k: False
hf_utils.is_flash_attn_greater_or_equal_2_10 = lambda *a, **k: False

# Add no-op stand-ins for the SDPA mask helpers only when this transformers
# version lacks them (presumably referenced by Florence-2's remote code — confirm).
mask_utils = importlib.import_module("transformers.modeling_attn_mask_utils")
for fn in ("_prepare_4d_attention_mask_for_sdpa", "_prepare_4d_causal_attention_mask_for_sdpa"):
    if not hasattr(mask_utils, fn):
        setattr(mask_utils, fn, lambda *a, **k: None)

# Force every PretrainedConfig to report `_attn_implementation == "sdpa"` on
# read, whatever value was requested or stored, presumably to keep all models
# off the flash_attention_2 code path.
# NOTE: this wrapper runs on every attribute read of every config object.
cfg_mod = importlib.import_module("transformers.configuration_utils")
_PrC = cfg_mod.PretrainedConfig
_orig_getattr = _PrC.__getattribute__
def _getattr(self, name):
    if name == "_attn_implementation":
        return "sdpa"
    return _orig_getattr(self, name)
_PrC.__getattribute__ = _getattr

model_repo = "tensorart/stable-diffusion-3.5-large-TurboX"

# Standalone pipeline components loaded from the TurboX repo (plus an SD 1.5
# feature extractor), followed by the Florence-2 captioner.
# NOTE(review): scheduler, text_encoder, tokenizer, feature_extractor and unet
# are not referenced anywhere else in the visible code (the pipeline is loaded
# separately below), so they appear to cost only download time and memory —
# confirm before removing them.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): `FlowMatchEulerDiscreteScheduler` is imported at the top of the
# file as an alias of EulerDiscreteScheduler, so this builds an
# EulerDiscreteScheduler from a flow-matching config; torch_dtype is presumably
# ignored by schedulers — confirm.
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_repo,
    subfolder="scheduler",
    torch_dtype=torch.float16,
)
text_encoder      = CLIPTextModel.from_pretrained(
    model_repo, subfolder="text_encoder", torch_dtype=torch.float16
)
tokenizer         = CLIPTokenizer.from_pretrained(
    model_repo, subfolder="tokenizer"
)
# Taken from the SD 1.5 repo, not from model_repo.
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="feature_extractor"
)
# NOTE(review): assumes model_repo has a "unet" subfolder; SD3.5-class repos
# usually ship a "transformer" instead — confirm this load succeeds.
unet = UNet2DConditionModel.from_pretrained(
    model_repo, subfolder="unet", torch_dtype=torch.float16
)
# Florence-2 captioner: loaded in float16, then kept on the CPU (not on
# `device`), presumably to leave GPU memory for the diffusion pipeline.
# NOTE(review): half precision on CPU is slow, and on older PyTorch builds some
# ops have no Half kernel — confirm on the deployed torch version.
florence_model = AutoModelForCausalLM.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True, torch_dtype=torch.float16)
florence_model.to("cpu")
florence_model.eval()
florence_processor = AutoProcessor.from_pretrained(LOCAL_FLORENCE, trust_remote_code=True)

# Stable Diffusion 3.5 Large TurboX text-to-image pipeline.
# NOTE(review): this rebinds `diffusers.StableDiffusion3Pipeline` to the SD 1.x
# `StableDiffusionPipeline` class. DiffusionPipeline.from_pretrained presumably
# resolves the pipeline class named in the repo's model_index.json through the
# diffusers namespace, so the repo may be loaded with the SD 1.x pipeline —
# confirm this is intended (it looks like a workaround for a diffusers version
# without SD3 support).

diffusers.StableDiffusion3Pipeline = StableDiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "tensorart/stable-diffusion-3.5-large-TurboX",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    safety_checker=None,
    feature_extractor=None
)
# Hard-coded to CUDA (unlike `device`): startup will presumably fail on a
# machine without a GPU.
pipe = pipe.to("cuda")

# Swap in a scheduler built from the repo's scheduler config, passing shift=5
# (presumably overriding the config value). local_files_only=True relies on the
# earlier snapshot_download of this repo.
# NOTE(review): `FlowMatchEulerDiscreteScheduler` here is the EulerDiscreteScheduler
# alias from the top-of-file import; confirm it honours `shift`.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(model_repo, subfolder="scheduler", local_files_only=True, trust_remote_code = True, shift=5)

MAX_SEED = (1 << 31) - 1  # upper bound (inclusive) for randomized seeds: largest signed 32-bit int

def pseudo_translate_to_korean_style(en_prompt: str) -> str:
    """Wrap an English caption in the fixed cartoon-style prompt template.

    Despite the name, nothing is translated: *en_prompt* is inserted verbatim
    between a fixed prefix and suffix.
    """
    template = "Cartoon styled {} handsome or pretty people"
    return template.format(en_prompt)

def generate_prompt(image):
    """Caption *image* with Florence-2 and turn the caption into a styled prompt.

    Args:
        image: A ``PIL.Image.Image``, or an array accepted by ``Image.fromarray``
            (Gradio's ``gr.Image`` passes a numpy array by default).

    Returns:
        The Florence-2 ``<MORE_DETAILED_CAPTION>`` text wrapped by
        ``pseudo_translate_to_korean_style``.

    Raises:
        ValueError: If *image* is ``None`` (e.g. nothing was uploaded).
    """
    if image is None:
        raise ValueError("generate_prompt() requires an image, got None")
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    if image.mode != "RGB":
        # Normalise RGBA / palette / greyscale uploads to 3 channels, which the
        # Florence-2 processor presumably expects.
        image = image.convert("RGB")

    task = "<MORE_DETAILED_CAPTION>"
    # Put the inputs on the device the Florence model actually lives on (it need
    # not be the module-level `device`) and cast floating-point tensors such as
    # pixel_values to the model's dtype; integer input_ids keep their dtype.
    inputs = florence_processor(text=task, images=image, return_tensors="pt").to(
        florence_model.device, florence_model.dtype
    )
    generated_ids = florence_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=512,
        num_beams=3,
    )
    # Special tokens are kept, as in the upstream Florence-2 example;
    # post_process_generation presumably needs them to parse the task output.
    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = florence_processor.post_process_generation(
        generated_text,
        task=task,
        image_size=(image.width, image.height),
    )
    prompt_en = parsed_answer[task]

    # Apply the style template directly; no translation model is involved.
    return pseudo_translate_to_korean_style(prompt_en)

def generate_image(prompt, seed=42, randomize_seed=False):
    """Render one 768x768 image for *prompt* with the TurboX pipeline.

    Uses 8 inference steps and a guidance scale of 1.5. When *randomize_seed*
    is true, *seed* is ignored and a random integer in ``[0, MAX_SEED]`` is
    drawn instead.

    Returns:
        ``(image, seed)``: the first image produced by the pipeline and the
        seed that was actually used.
    """
    used_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    # Default (CPU) torch generator seeded for reproducibility.
    rng = torch.Generator().manual_seed(used_seed)
    output = pipe(
        prompt=prompt,
        guidance_scale=1.5,
        num_inference_steps=8,
        width=768,
        height=768,
        generator=rng,
    )
    return output.images[0], used_seed

# Gradio UI ๊ตฌ์„ฑ
with gr.Blocks() as demo:
    gr.Markdown("# ๐Ÿ–ผ ์ด๋ฏธ์ง€ โ†’ ์„ค๋ช… ์ƒ์„ฑ โ†’ ์นดํˆฐ ์ด๋ฏธ์ง€ ์ž๋™ ์ƒ์„ฑ๊ธฐ")
    
    gr.Markdown("**๐Ÿ“Œ ์‚ฌ์šฉ๋ฒ• ์•ˆ๋‚ด (ํ•œ๊ตญ์–ด)**\n"
                "- ์™ผ์ชฝ์— ์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.\n"
                "- AI๊ฐ€ ์˜์–ด ์„ค๋ช…์„ ๋งŒ๋“ค๊ณ , ๋‚ด๋ถ€์—์„œ ํ•œ๊ตญ์–ด ์Šคํƒ€์ผ ํ”„๋กฌํ”„ํŠธ๋กœ ์žฌ๊ตฌ์„ฑํ•ฉ๋‹ˆ๋‹ค.\n"
                "- ์˜ค๋ฅธ์ชฝ์— ๊ฒฐ๊ณผ ์ด๋ฏธ์ง€๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.")

    with gr.Row():
        with gr.Column():
            input_img = gr.Image(label="๐ŸŽจ ์›๋ณธ ์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
            run_button = gr.Button("โœจ ์ƒ์„ฑ ์‹œ์ž‘")

        with gr.Column():
            prompt_out = gr.Textbox(label="๐Ÿ“ ์Šคํƒ€์ผ ์ ์šฉ๋œ ํ”„๋กฌํ”„ํŠธ", lines=3, show_copy_button=True)
            output_img = gr.Image(label="๐ŸŽ‰ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€")

    def full_process(img):
        prompt = generate_prompt(img)
        image, seed = generate_image(prompt, randomize_seed=True)
        return prompt, image

    run_button.click(fn=full_process, inputs=[input_img], outputs=[prompt_out, output_img])

demo.launch()