from __future__ import annotations import sys import gradio as gr import PIL.Image import torch sys.path.append('Attend-and-Excite') from config import RunConfig from pipeline_attend_and_excite import AttendAndExcitePipeline from run import run_on_prompt from utils.ptp_utils import AttentionStore class Model: def __init__(self): self.device = torch.device( 'cuda:0' if torch.cuda.is_available() else 'cpu') self.model_id = '' self.model = None self.tokenizer = None self.load_model('CompVis/stable-diffusion-v1-4') def load_model(self, model_id: str) -> None: if model_id == self.model_id: return self.model = AttendAndExcitePipeline.from_pretrained(model_id).to( self.device) self.tokenizer = self.model.tokenizer self.model_id = model_id def get_token_table(self, model_id: str, prompt: str): self.load_model(model_id) tokens = [ self.tokenizer.decode(t) for t in self.tokenizer(prompt)['input_ids'] ] tokens = tokens[1:-1] return list(enumerate(tokens, start=1)) def run( self, model_id: str, prompt: str, indices_to_alter_str: str, seed: int, apply_attend_and_excite: bool, num_steps: int, guidance_scale: float, scale_factor: int = 20, thresholds: dict[int, float] = { 10: 0.5, 20: 0.8 }, max_iter_to_alter: int = 25, ) -> tuple[list[tuple[int, str]], PIL.Image.Image]: generator = torch.Generator(device=self.device).manual_seed(seed) try: indices_to_alter = list(map(int, indices_to_alter_str.split(','))) except: raise gr.Error('Invalid token indices.') self.load_model(model_id) token_table = self.get_token_table(model_id, prompt) controller = AttentionStore() config = RunConfig(prompt=prompt, n_inference_steps=num_steps, guidance_scale=guidance_scale, run_standard_sd=not apply_attend_and_excite, scale_factor=scale_factor, thresholds=thresholds, max_iter_to_alter=max_iter_to_alter) image = run_on_prompt(model=self.model, prompt=[prompt], controller=controller, token_indices=indices_to_alter, seed=generator, config=config) return token_table, image