# np / torch / st / autocast are used throughout; import them explicitly
# rather than relying on the star import below to re-export them.
import numpy as np
import streamlit as st
import torch
from torch import autocast

from st_keyup import st_keyup
from streamlit_helpers import *  # init_st, load_model, get_batch, get_unique_embedder_keys_from_conditioner, ...
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler
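
# Per-model settings: output size (H, W), latent channels C, autoencoder
# downsampling factor f, and the inference config / checkpoint paths, which
# are resolved relative to the working directory.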
VERSION2SPECS = {
    "SDXL-Turbo": {
        "H": 512,
        "W": 512,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_xl_base.yaml",
        "ckpt": "checkpoints/sd_xl_turbo_1.0.safetensors",
    },
    "SD-Turbo": {
        "H": 512,
        "W": 512,
        "C": 4,
        "f": 8,
        "is_legacy": False,
        "config": "configs/inference/sd_2_1.yaml",
        "ckpt": "checkpoints/sd_turbo.safetensors",
    },
}
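

# SubstepSampler truncates a 1000-step Euler-ancestral schedule to a handful
# of steps: steps_subset[:n_sample_steps] selects from the 1001 descending
# sigmas (index 0 is the noisiest; the trailing index 1000 is the appended
# zero sigma that terminates sampling). This is the 1-4 step regime the turbo
# models are distilled for.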
class SubstepSampler(EulerAncestralSampler):
    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        sigmas = self.discretization(
            self.num_steps if num_steps is None else num_steps, device=self.device
        )
        sigmas = sigmas[
            self.steps_subset[: self.n_sample_steps] + self.steps_subset[-1:]
        ]
        uc = cond  # no separate unconditional branch: turbo skips CFG
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc
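

# Deterministic starting noise: seeding NumPy's RNG per call means the same
# (shape, seed) pair reproduces the same latent across runs and machines.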
def seeded_randn(shape, seed):
    randn = np.random.RandomState(seed).randn(*shape)
    randn = torch.from_numpy(randn).to(device="cuda", dtype=torch.float32)
    return randn


class SeededNoise:
    def __init__(self, seed):
        self.seed = seed

    def __call__(self, x):
        self.seed = self.seed + 1
        return seeded_randn(x.shape, self.seed)
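

# A minimal sketch of the determinism guarantee (assumes a CUDA device):
#
#   a = seeded_randn((1, 4, 64, 64), seed=7)
#   b = seeded_randn((1, 4, 64, 64), seed=7)
#   assert torch.equal(a, b)  # same seed, same latent, same image
#
# SeededNoise bumps its seed on every call, so the ancestral sampler gets
# fresh but still reproducible noise at each step.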
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    # Fill the conditioning values for whichever embedder keys the model
    # declares. negative_prompt is accepted for signature parity but unused:
    # the demo always conditions on an empty negative prompt.
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = ""

        if key == "original_size_as_tuple":
            value_dict["orig_width"] = init_dict["orig_width"]
            value_dict["orig_height"] = init_dict["orig_height"]

        if key == "crop_coords_top_left":
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict
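

# The keys above mirror SDXL's micro-conditioning embedders (original size,
# crop coordinates, target size); get_batch() turns this value dict into the
# tensors the conditioner expects (see sample() below).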
def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    F = 8  # spatial downsampling factor of the autoencoder
    C = 4  # latent channels
    shape = (1, C, H // F, W // F)

    value_dict = init_embedder_options(
        keys=get_unique_embedder_keys_from_conditioner(model.conditioner),
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        seed = torch.seed()
    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            batch, batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1],
            )
            c = model.conditioner(batch)
            uc = None  # SubstepSampler substitutes uc = cond, so no CFG is applied
            randn = seeded_randn(shape, seed)

            def denoiser(input, sigma, c):
                return model.denoiser(model.model, input, sigma, c)

            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            samples_x = model.decode_first_stage(samples_z)
            # map decoded images from [-1, 1] to [0, 1]
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            if filter is not None:
                samples = filter(samples)
            # convert to uint8 NHWC for display
            samples = (
                (255 * samples)
                .to(dtype=torch.uint8)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
    return samples
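

# Hypothetical one-off use outside Streamlit, with model and sampler built as
# in the __main__ block below:
#
#   out = sample(model, sampler, prompt="a red fox in the snow", H=512, W=512)
#   # out is a uint8 array of shape (1, 512, 512, 3)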


def v_spacer(height) -> None:
    for _ in range(height):
        st.write("\n")


if __name__ == "__main__":
    st.title("Turbo")

    head_cols = st.columns([1, 1, 1])
    with head_cols[0]:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]

    with head_cols[1]:
        v_spacer(2)
        if st.checkbox("Load Model"):
            mode = "txt2img"
        else:
            mode = "skip"

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)

    # seed
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        if st.session_state.seed > 0:
            st.session_state.seed -= 1
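
    # increment_counter / decrement_counter are wired to the ↪ / ↩ buttons
    # below: stepping the seed pages through a reproducible sequence of
    # generations for the same prompt.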
    with head_cols[2]:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)
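
    # eta=1.0 keeps the Euler step fully ancestral; LegacyDDPMDiscretization
    # rebuilds the original 1000-step DDPM sigma schedule from which
    # SubstepSampler.steps_subset selects its few entries.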
    sampler = SubstepSampler(
        n_sample_steps=1,
        num_steps=1000,
        eta=1.0,
        discretization_config=dict(
            target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        ),
    )
    sampler.n_sample_steps = n_steps

    default_prompt = (
        "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    )
    prompt = st_keyup(
        "Enter a value", value=default_prompt, debounce=300, key="interactive_text"
    )
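
    # st_keyup reruns the script on each debounced keystroke, so the image
    # regenerates live while the prompt is typed; that stays interactive only
    # because turbo needs 1-4 sampler steps per image.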
    cols = st.columns([1, 5, 1])
    if mode != "skip":
        with cols[0]:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with cols[2]:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=st.session_state.seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with cols[1]:
            st.image(out[0])
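
# A typical launch, assuming this file is saved as turbo.py somewhere the
# config and checkpoint paths above resolve (e.g. the repo root):
#
#   streamlit run turbo.py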