# Hugging Face Space status: Sleeping
import os | |
from huggingface_hub import model_info | |
import torch | |
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler | |
def main(
    *,
    repo_id: str = "tom_cruise_no_cap",
    n_images: int = 50,
    prompt: str = "<tom_cruise> showing #thumbsup",
) -> None:
    """Generate images with a LoRA-adapted Stable Diffusion pipeline.

    Looks up ``repo_id`` in the table of known LoRA repositories, reads the
    base model name from the hub model card metadata, loads that base model
    in half precision on CUDA, applies the LoRA attention processors and
    writes ``n_images`` PNG files to ``./results/<repo_id>/``.

    Image ``i`` is sampled with a CUDA generator seeded with ``i``.

    Args:
        repo_id: Key into the table of known LoRA repositories.
        n_images: Number of images to generate.
        prompt: Text prompt passed to the pipeline for every image.

    Raises:
        KeyError: If ``repo_id`` is not one of the known repositories.
    """
    repos = {
        "tom_cruise_plain": {
            "hub_model_id": "asrimanth/person-thumbs-up-plain-lora",
            "model_dir": "/l/vision/v5/sragas/easel_ai/models_plain/",
        },
        "tom_cruise": {
            "hub_model_id": "asrimanth/person-thumbs-up-lora",
            "model_dir": "/l/vision/v5/sragas/easel_ai/models/",
        },
        "tom_cruise_no_cap": {
            "hub_model_id": "asrimanth/person-thumbs-up-lora-no-cap",
            "model_dir": "/l/vision/v5/sragas/easel_ai/models_no_cap/",
        },
        "srimanth_plain": {
            "hub_model_id": "asrimanth/srimanth-thumbs-up-lora-plain",
            "model_dir": "/l/vision/v5/sragas/easel_ai/models_srimanth_plain/",
        },
    }

    # Validate the key before touching the filesystem so a typo does not
    # leave an empty results directory behind.
    try:
        repo = repos[repo_id]
    except KeyError:
        raise KeyError(
            f"Unknown repo id {repo_id!r}; expected one of {sorted(repos)}"
        ) from None

    save_dir = f"./results/{repo_id}/"
    os.makedirs(save_dir, exist_ok=True)

    print(f"{'-'*20} CURRENT REPO: {repo_id} {'-'*20}")
    hub_model_id = repo["hub_model_id"]

    # The LoRA repo's model card records which base checkpoint it was
    # trained against.
    # NOTE(review): newer huggingface_hub releases may expose this metadata
    # as `card_data` rather than `cardData` -- confirm against the installed
    # version.
    info = model_info(hub_model_id)
    model_base = info.cardData["base_model"]
    print(f"Base model is: {model_base}")

    pipe = StableDiffusionPipeline.from_pretrained(
        model_base,
        torch_dtype=torch.float16,
        cache_dir="/l/vision/v5/sragas/hf_models/",
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.load_attn_procs(hub_model_id)
    pipe.to("cuda")

    print(f"Inferencing '{prompt}' for {n_images} images.")
    for seed in range(n_images):
        # One generator per image, seeded with the image index.
        generator = torch.Generator("cuda").manual_seed(seed)
        image = pipe(prompt, generator=generator, num_inference_steps=25).images[0]
        image.save(os.path.join(save_dir, f"out_{seed}.png"))
# Run the inference job only when executed as a script, not on import.
if __name__ == "__main__":
    main()