File size: 5,435 Bytes
c5799f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# -*- coding: utf-8 -*-
# Author: ximing
# Description: the main func of this project.
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
import os
import sys
import argparse
from datetime import datetime
import random
from typing import Any, List
from functools import partial
from accelerate.utils import set_seed
import omegaconf
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
from libs.engine import merge_and_update_config
from libs.utils.argparse import accelerate_parser, base_data_parser
def render_batch_wrap(args: "omegaconf.DictConfig",
                      seed_range: List,
                      pipeline: Any,
                      **pipe_args):
    """Render one result per seed, building a fresh pipeline for each seed.

    Args:
        args: run configuration. ``args.seed`` is overwritten with each seed
            in turn, so after the call it holds the last seed rendered.
        seed_range: seeds to render, processed in order.
        pipeline: pipeline class (or factory) invoked as ``pipeline(args)``.
        **pipe_args: forwarded verbatim to ``painterly_rendering``.
    """
    start_time = datetime.now()
    total = len(seed_range)
    # start=1 so the progress counter reads [1/N] ... [N/N]
    for idx, seed in enumerate(seed_range, start=1):
        args.seed = seed  # update seed before the pipeline is constructed from args
        # "current time" is the elapsed wall time since the batch started
        print(f"\n-> [{idx}/{total}], "
              f"current seed: {seed}, "
              f"current time: {datetime.now() - start_time}\n")
        pipe = pipeline(args)
        pipe.painterly_rendering(**pipe_args)
def main(args, seed_range):
    """Dispatch to the rendering pipeline selected by ``args.task``.

    Args:
        args: merged run configuration. Mutated in place: ``batch_size`` is
            forced to 1 and ``width`` is cast to float.
        seed_range: seeds to render when ``args.render_batch`` is set; not used
            otherwise (the script passes None in that case).

    Raises:
        ValueError: if ``args.task`` is not a supported task name.
    """
    args.batch_size = 1  # rendering one SVG at a time
    args.width = float(args.width)

    render_batch_fn = partial(render_batch_wrap, args=args, seed_range=seed_range)

    if args.task == "diffsketcher":  # text2sketch
        from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline
        if not args.render_batch:
            pipe = DiffSketcherPipeline(args)
            pipe.painterly_rendering(args.prompt)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=DiffSketcherPipeline, prompt=args.prompt)
    elif args.task == "style-diffsketcher":  # text2sketch + style transfer
        from pipelines.painter.diffsketcher_stylized_pipeline import StylizedDiffSketcherPipeline
        if not args.render_batch:
            pipe = StylizedDiffSketcherPipeline(args)
            pipe.painterly_rendering(args.prompt, args.style_file)
        else:  # generate many SVG at once
            render_batch_fn(pipeline=StylizedDiffSketcherPipeline,
                            prompt=args.prompt,
                            style_fpath=args.style_file)
    else:
        # previously an unknown task silently did nothing; fail loudly instead
        raise ValueError(f"unsupported task: {args.task!r}")
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description="vary style and content painterly rendering",
parents=[accelerate_parser(), base_data_parser()]
)
# flag
parser.add_argument("-tk", "--task",
default="diffsketcher", type=str,
choices=['diffsketcher', 'style-diffsketcher'],
help="choose a method.")
# config
parser.add_argument("-c", "--config",
required=True, type=str,
default="",
help="YAML/YML file for configuration.")
parser.add_argument("-style", "--style_file",
default="", type=str,
help="the path of style img place.")
# prompt
parser.add_argument("-pt", "--prompt", default="A horse is drinking water by the lake", type=str)
parser.add_argument("-npt", "--negative_prompt", default="", type=str)
# DiffSVG
parser.add_argument("--print_timing", "-timing", action="store_true",
help="set print svg rendering timing.")
# diffuser
parser.add_argument("--download", action="store_true",
help="download models from huggingface automatically.")
parser.add_argument("--force_download", "-download", action="store_true",
help="force the models to be downloaded from huggingface.")
parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
help="download the models again from the breakpoint.")
# rendering quantity
# like: python main.py -rdbz -srange 100 200
parser.add_argument("--render_batch", "-rdbz", action="store_true")
parser.add_argument("-srange", "--seed_range",
required=False, nargs='+',
help="Sampling quantity.")
# visual rendering process
parser.add_argument("-mv", "--make_video", action="store_true",
help="make a video of the rendering process.")
parser.add_argument("-frame_freq", "--video_frame_freq",
default=1, type=int,
help="video frame control.")
parser.add_argument("-framerate", "--video_frame_rate",
default=36, type=int,
help="by adjusting the frame rate, you can control the playback speed of the output video.")
args = parser.parse_args()
# set the random seed range
seed_range = None
if args.render_batch:
# random sampling without specifying a range
start_, end_ = 1, 1000000
if args.seed_range is not None: # specify range sequential sampling
seed_range_ = list(args.seed_range)
assert len(seed_range_) == 2 and int(seed_range_[1]) > int(seed_range_[0])
start_, end_ = int(seed_range_[0]), int(seed_range_[1])
seed_range = [i for i in range(start_, end_)]
else:
# a list of lengths 1000 sampled from the range start_ to end_ (e.g.: [1, 1000000])
numbers = list(range(start_, end_))
seed_range = random.sample(numbers, k=1000)
args = merge_and_update_config(args)
set_seed(args.seed)
main(args, seed_range)
|