Spaces:
Sleeping
Sleeping
File size: 1,769 Bytes
e47c7c5 6bbdd23 e47c7c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import argparse
import os
import torch
from syngen_diffusion_pipeline import SynGenDiffusionPipeline
def main(prompt, seed, output_directory, model_path):
    """Load the pipeline, generate one image for *prompt*, and save it.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the torch generator used during sampling.
        output_directory: Directory the resulting PNG is written to.
        model_path: Model location handed to ``load_model``.
    """
    pipeline = load_model(model_path)
    save_image(generate(pipeline, prompt, seed), prompt, seed, output_directory)
def load_model(model_path):
    """Load a SynGenDiffusionPipeline from *model_path* onto the best device.

    The pipeline is placed on ``cuda:0`` when CUDA is available, otherwise on
    the CPU. xFormers memory-efficient attention is explicitly disabled.

    Args:
        model_path: Value forwarded to ``SynGenDiffusionPipeline.from_pretrained``.

    Returns:
        The loaded pipeline, already moved to the selected device.
    """
    use_gpu = torch.cuda.is_available()
    target = torch.device('cuda:0' if use_gpu else 'cpu')
    pipeline = SynGenDiffusionPipeline.from_pretrained(model_path)
    pipeline = pipeline.to(target)
    # NOTE(review): presumably done so the pipeline's own attention handling
    # runs on the standard attention path rather than xFormers — confirm.
    pipeline.disable_xformers_memory_efficient_attention()
    return pipeline
def generate(pipe, prompt, seed):
    """Run *pipe* once on *prompt* with a seeded generator; return the first image.

    Args:
        pipe: Callable pipeline accepting ``prompt=`` and ``generator=`` keywords
            and returning a mapping with an ``'images'`` sequence.
        prompt: Text prompt to generate from.
        seed: Integer seed for the ``torch.Generator``.

    Returns:
        ``pipe(...)['images'][0]``.
    """
    # Same device choice as at load time: CUDA when available, else CPU.
    device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = torch.Generator(device_type).manual_seed(seed)
    output = pipe(prompt=prompt, generator=generator)
    images = output['images']
    return images[0]
def save_image(image, prompt, seed, output_directory, max_prompt_bytes=200):
    """Save *image* as ``<output_directory>/<prompt>_<seed>.png``.

    The output directory is created if it does not exist. An empty value
    means the current directory. The prompt is made safe for use in a file
    name:
      - path separators, Windows-reserved characters (``:*?"<>|``) and
        control characters are replaced with ``_``;
      - the prompt part is truncated to *max_prompt_bytes* UTF-8 bytes, so
        the full name stays under the common 255-byte file-name limit.
    Ordinary prompts (such as the CLI default) produce exactly the same
    name as before.

    Args:
        image: Object with a ``save(path)`` method (presumably a PIL image
            returned by the pipeline).
        prompt: Text prompt the image was generated from.
        seed: Seed the image was generated with.
        output_directory: Directory to write the file into.
        max_prompt_bytes: Maximum UTF-8 byte length of the prompt portion of
            the file name.

    Returns:
        The path of the written file.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs;
    # os.makedirs("") would raise, so fall back to the current directory.
    os.makedirs(output_directory or '.', exist_ok=True)

    # A raw '/' (or '\\' on Windows) in the prompt would otherwise point into a
    # nonexistent subdirectory and make image.save fail.
    unsafe_chars = '\\/:*?"<>|' + ''.join(map(chr, range(32)))
    safe_prompt = prompt.translate({ord(ch): '_' for ch in unsafe_chars})
    # Truncate on bytes, not characters; 'ignore' drops a multi-byte sequence
    # cut in half at the boundary.
    safe_prompt = safe_prompt.encode('utf-8')[:max_prompt_bytes].decode('utf-8', 'ignore')

    file_name = os.path.join(output_directory, f"{safe_prompt}_{seed}.png")
    image.save(file_name)
    return file_name
if __name__ == "__main__":
    # Command-line entry point: parse options and run a single generation.
    parser = argparse.ArgumentParser(
        description="Generate one image with SynGenDiffusionPipeline and save it as a PNG."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="a checkered bowl on a red and blue table",
        help="Text prompt to generate an image for (default: %(default)r).",
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=1924,
        help="Seed for the torch random generator (default: %(default)s).",
    )
    parser.add_argument(
        '--output_directory',
        type=str,
        default='./output',
        help="Directory to save the image in; created if missing (default: %(default)s).",
    )
    parser.add_argument(
        '--model_path',
        type=str,
        default='CompVis/stable-diffusion-v1-4',
        help='The path to the model (this will download the model if the path doesn\'t exist)',
    )
    args = parser.parse_args()
    main(args.prompt, args.seed, args.output_directory, args.model_path)
|