Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import PIL.Image | |
import torch | |
from diffusers import StableDiffusionAttendAndExcitePipeline, StableDiffusionPipeline | |
class Model:
    """Hold a Stable Diffusion pipeline and its Attend-and-Excite counterpart.

    Both pipelines are loaded from the same checkpoint and moved to the first
    CUDA device when one is available, otherwise to the CPU.
    """

    # Default ``thresholds`` forwarded to the Attend-and-Excite pipeline.
    # Stored as immutable pairs so that every call to ``run`` can build its
    # own dict instead of sharing one mutable default between calls.
    _DEFAULT_THRESHOLDS = ((10, 0.5), (20, 0.8))

    def __init__(self):
        """Load both pipelines and place them on the selected device."""
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model_id = "CompVis/stable-diffusion-v1-4"
        self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(model_id)
        self.ax_pipe.to(self.device)
        self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
        self.sd_pipe.to(self.device)

    def get_token_table(self, prompt: str) -> list[tuple[int, str]]:
        """Return ``(index, token)`` pairs for the tokens of *prompt*.

        The prompt is tokenized with the Attend-and-Excite pipeline's
        tokenizer and every id is decoded individually. The first and last
        entries are dropped (presumably the start/end special tokens -- confirm
        against the tokenizer), and numbering starts at 1 so that each index
        equals the token's position in the full ``input_ids`` sequence.
        """
        tokenizer = self.ax_pipe.tokenizer
        input_ids = tokenizer(prompt)["input_ids"]
        tokens = [tokenizer.decode(token_id) for token_id in input_ids]
        return list(enumerate(tokens[1:-1], start=1))

    @staticmethod
    def _parse_token_indices(indices_to_alter_str: str) -> list[int]:
        """Parse a comma-separated string such as ``"2,5"`` into integers.

        Raises:
            ValueError: If the value is not a string of comma-separated
                integers (an empty string is rejected as well).
        """
        try:
            return [int(piece) for piece in indices_to_alter_str.split(",")]
        except (AttributeError, TypeError, ValueError) as err:
            # Keep the original message, but chain the cause for debugging.
            raise ValueError("Invalid token indices.") from err

    def run(
        self,
        prompt: str,
        indices_to_alter_str: str,
        seed: int = 0,
        apply_attend_and_excite: bool = True,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> PIL.Image.Image:
        """Generate one image for *prompt* and return it.

        Args:
            prompt: Text prompt passed to the pipeline.
            indices_to_alter_str: Comma-separated token indices, e.g.
                ``"2,5"``. Only parsed when ``apply_attend_and_excite`` is
                true.
            seed: Seed for the ``torch.Generator`` created on ``self.device``.
            apply_attend_and_excite: When true, use the Attend-and-Excite
                pipeline; otherwise use the plain Stable Diffusion pipeline.
            num_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded unchanged to the pipeline.
            scale_factor: Forwarded to the Attend-and-Excite pipeline only.
            thresholds: Forwarded to the Attend-and-Excite pipeline only.
                ``None`` selects ``{10: 0.5, 20: 0.8}``, built fresh for each
                call.
            max_iter_to_alter: Forwarded to the Attend-and-Excite pipeline
                only.

        Returns:
            The first image of the pipeline output.

        Raises:
            ValueError: If Attend-and-Excite is requested and
                ``indices_to_alter_str`` cannot be parsed.
        """
        generator = torch.Generator(device=self.device).manual_seed(seed)

        if not apply_attend_and_excite:
            out = self.sd_pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
            )
            return out.images[0]

        token_indices = self._parse_token_indices(indices_to_alter_str)
        if thresholds is None:
            # Fresh dict per call: a mutation made downstream must not leak
            # into later calls through a shared default object.
            thresholds = dict(self._DEFAULT_THRESHOLDS)

        out = self.ax_pipe(
            prompt=prompt,
            token_indices=token_indices,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_steps,
            max_iter_to_alter=max_iter_to_alter,
            thresholds=thresholds,
            scale_factor=scale_factor,
        )
        return out.images[0]