File size: 2,219 Bytes
779acf3
 
 
 
93fd2ea
 
779acf3
 
 
 
 
 
93fd2ea
f5e83aa
 
 
 
 
779acf3
93fd2ea
779acf3
93fd2ea
 
779acf3
 
 
 
 
 
 
 
93fd2ea
 
 
 
779acf3
 
 
93fd2ea
779acf3
 
e3c9822
779acf3
 
93fd2ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from __future__ import annotations

import PIL.Image
import torch
from diffusers import (StableDiffusionAttendAndExcitePipeline,
                       StableDiffusionPipeline)


class Model:
    """Wrap a plain Stable Diffusion pipeline and its Attend-and-Excite variant.

    Both pipelines are loaded from the same checkpoint and moved to the same
    device, so ``run`` can generate an image with or without Attend-and-Excite
    for a given prompt and seed.
    """

    # Defaults for ``run(thresholds=...)``. Kept here rather than as a dict
    # literal in the signature so that every call gets its own copy.
    _DEFAULT_THRESHOLDS = {
        10: 0.5,
        20: 0.8,
    }

    def __init__(self):
        """Load both pipelines and place them on CUDA when it is available."""
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        model_id = 'CompVis/stable-diffusion-v1-4'
        self.ax_pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
            model_id)
        self.ax_pipe.to(self.device)
        self.sd_pipe = StableDiffusionPipeline.from_pretrained(model_id)
        self.sd_pipe.to(self.device)

    def get_token_table(self, prompt: str) -> list[tuple[int, str]]:
        """Return ``(index, token)`` pairs for the tokens of *prompt*.

        The prompt is tokenized with the Attend-and-Excite pipeline's
        tokenizer and every id is decoded back to text. The first and last
        entries are dropped (presumably the tokenizer's start/end markers --
        confirm against the tokenizer in use), and the remaining tokens are
        numbered from 1.
        """
        tokenizer = self.ax_pipe.tokenizer
        tokens = [
            tokenizer.decode(token_id)
            for token_id in tokenizer(prompt)['input_ids']
        ]
        return list(enumerate(tokens[1:-1], start=1))

    @staticmethod
    def _parse_token_indices(indices_to_alter_str: str) -> list[int]:
        """Parse a comma-separated string such as ``'2,5'`` into ints.

        Raises:
            ValueError: If the value is not a string of comma-separated
                integers. The underlying error is attached as ``__cause__``.
        """
        try:
            return [int(part) for part in indices_to_alter_str.split(',')]
        except (AttributeError, TypeError, ValueError) as err:
            # AttributeError/TypeError cover non-string input (e.g. ``None``),
            # which previously surfaced as ValueError as well.
            raise ValueError('Invalid token indices.') from err

    def run(
        self,
        prompt: str,
        indices_to_alter_str: str,
        seed: int = 0,
        apply_attend_and_excite: bool = True,
        num_steps: int = 50,
        guidance_scale: float = 7.5,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> PIL.Image.Image:
        """Generate one image for *prompt* and return it.

        Args:
            prompt: Text prompt passed to the selected pipeline.
            indices_to_alter_str: Comma-separated token indices, e.g.
                ``'2,5'``. Only parsed when ``apply_attend_and_excite`` is
                true.
            seed: Seed for the ``torch.Generator`` created on ``self.device``.
            apply_attend_and_excite: Use the Attend-and-Excite pipeline when
                true, the plain Stable Diffusion pipeline otherwise.
            num_steps: Passed to the pipeline as ``num_inference_steps``.
            guidance_scale: Passed to the pipeline unchanged.
            scale_factor: Passed to the Attend-and-Excite pipeline unchanged.
            thresholds: Passed to the Attend-and-Excite pipeline unchanged.
                ``None`` selects a fresh copy of ``{10: 0.5, 20: 0.8}``.
            max_iter_to_alter: Passed to the Attend-and-Excite pipeline
                unchanged.

        Returns:
            The first image of the pipeline output.

        Raises:
            ValueError: If Attend-and-Excite is requested and
                ``indices_to_alter_str`` cannot be parsed.
        """
        generator = torch.Generator(device=self.device).manual_seed(seed)

        if not apply_attend_and_excite:
            out = self.sd_pipe(
                prompt=prompt,
                guidance_scale=guidance_scale,
                generator=generator,
                num_inference_steps=num_steps,
            )
            return out.images[0]

        token_indices = self._parse_token_indices(indices_to_alter_str)
        if thresholds is None:
            # Copy so nothing downstream can alter the shared defaults.
            thresholds = dict(self._DEFAULT_THRESHOLDS)
        out = self.ax_pipe(
            prompt=prompt,
            token_indices=token_indices,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_steps,
            max_iter_to_alter=max_iter_to_alter,
            thresholds=thresholds,
            scale_factor=scale_factor,
        )
        return out.images[0]