from __future__ import annotations

import sys

import gradio as gr
import PIL.Image
import torch

sys.path.append('Attend-and-Excite')

from config import RunConfig
from pipeline_attend_and_excite import AttendAndExcitePipeline
from run import run_on_prompt
from utils.ptp_utils import AttentionStore


class Model:
    """Wraps the Attend-and-Excite pipeline for use from the Gradio demo."""

    def __init__(self):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model_id = ''
        self.model = None
        self.tokenizer = None

        # Load the default Stable Diffusion checkpoint up front.
        self.load_model('CompVis/stable-diffusion-v1-4')

    def load_model(self, model_id: str) -> None:
        # Skip reloading when the requested checkpoint is already active.
        if model_id == self.model_id:
            return
        self.model = AttendAndExcitePipeline.from_pretrained(model_id).to(
            self.device)
        self.tokenizer = self.model.tokenizer
        self.model_id = model_id

    def get_token_table(self, model_id: str,
                        prompt: str) -> list[tuple[int, str]]:
        self.load_model(model_id)
        tokens = [
            self.tokenizer.decode(t)
            for t in self.tokenizer(prompt)['input_ids']
        ]
        # Drop the BOS/EOS tokens and number the rest starting at 1,
        # matching the token indices expected by Attend-and-Excite.
        tokens = tokens[1:-1]
        return list(enumerate(tokens, start=1))

    def run(
        self,
        model_id: str,
        prompt: str,
        indices_to_alter_str: str,
        seed: int,
        apply_attend_and_excite: bool,
        num_steps: int,
        guidance_scale: float,
        scale_factor: int = 20,
        thresholds: dict[int, float] | None = None,
        max_iter_to_alter: int = 25,
    ) -> tuple[list[tuple[int, str]], PIL.Image.Image]:
        # Avoid a mutable default argument; these are the demo's default thresholds.
        if thresholds is None:
            thresholds = {10: 0.5, 20: 0.8}
        generator = torch.Generator(device=self.device).manual_seed(seed)
        try:
            indices_to_alter = list(map(int, indices_to_alter_str.split(',')))
        except ValueError as e:
            raise gr.Error('Invalid token indices.') from e

        self.load_model(model_id)

        token_table = self.get_token_table(model_id, prompt)

        controller = AttentionStore()
        config = RunConfig(prompt=prompt,
                           n_inference_steps=num_steps,
                           guidance_scale=guidance_scale,
                           run_standard_sd=not apply_attend_and_excite,
                           scale_factor=scale_factor,
                           thresholds=thresholds,
                           max_iter_to_alter=max_iter_to_alter)
        image = run_on_prompt(model=self.model,
                              prompt=[prompt],
                              controller=controller,
                              token_indices=indices_to_alter,
                              seed=generator,
                              config=config)

        return token_table, image
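

# Minimal usage sketch (an assumption, not part of the original demo wiring):
# it shows how the Model class above might be exercised directly, first to
# inspect the 1-based token indices and then to generate an image. The prompt,
# indices, and output filename are illustrative placeholders.
if __name__ == '__main__':
    model = Model()
    demo_prompt = 'a cat and a frog'
    # Print the token table so the indices passed below can be verified.
    print(model.get_token_table('CompVis/stable-diffusion-v1-4', demo_prompt))
    table, image = model.run(model_id='CompVis/stable-diffusion-v1-4',
                             prompt=demo_prompt,
                             indices_to_alter_str='2,5',
                             seed=0,
                             apply_attend_and_excite=True,
                             num_steps=50,
                             guidance_scale=7.5)
    image.save('result.png')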