Spaces:
Build error
Build error
from diffusers import UnCLIPPipeline, DiffusionPipeline | |
import torch | |
import os | |
from lora_diffusion.cli_lora_pti import * | |
from lora_diffusion.lora import * | |
from PIL import Image | |
import numpy as np | |
import json | |
from lora_dataset import PivotalTuningDatasetCapation as PVD | |
# Names of the module classes that receive LoRA adapters.
# The "extended" UNet set is the default set plus the ResNet blocks.
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", *UNET_DEFAULT_TARGET_REPLACE}

# For the text encoder the extended set currently equals the default one;
# it is built as a separate set object so the two can diverge independently.
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = set(TEXT_ENCODER_DEFAULT_TARGET_REPLACE)

# Alias (same object) of the default UNet targets.
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
def _collect_learned_embeds(text_encoder, placeholder_tokens, placeholder_token_ids):
    """Return ``{token: embedding row}`` for each placeholder token, detached on CPU.

    ``None`` for either argument (the defaults of ``save_all``) is treated as
    "no placeholder tokens" instead of crashing inside ``zip``.
    """
    if placeholder_tokens is None or placeholder_token_ids is None:
        return {}

    learned = {}
    for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
        learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
        print(
            f"Current Learned Embeddings for {tok}:, id {tok_id} ",
            learned_embeds[:4],
        )
        learned[tok] = learned_embeds.detach().cpu()
    return learned


def save_all(
    unet,
    text_encoder,
    save_path,
    placeholder_token_ids=None,
    placeholder_tokens=None,
    save_lora=True,
    save_ti=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
    safe_form=True,
):
    """Export the LoRA weights and/or the learned token embeddings.

    Args:
        unet: UNet carrying LoRA layers.
        text_encoder: Text encoder carrying LoRA layers and the token embeddings.
        save_path: Target path. Must end with ``.safetensors`` when
            ``safe_form`` is true.
        placeholder_token_ids: Embedding-row indices of the placeholder tokens.
        placeholder_tokens: Placeholder token strings, parallel to
            ``placeholder_token_ids``.
        save_lora: Include the UNet and text-encoder LoRA weights.
        save_ti: Include the learned placeholder-token embeddings.
        target_replace_module_text: Module class names searched for LoRA in
            the text encoder.
        target_replace_module_unet: Module class names searched for LoRA in
            the UNet.
        safe_form: Selects the export path (see Returns).

    Returns:
        With ``safe_form`` true, whatever ``save_safeloras_with_embeds``
        returns for the collected LoRAs and embeddings. With ``safe_form``
        false, ``None``; the embeddings are written with ``torch.save`` and
        the LoRA weights with ``save_lora_weight``.

    Raises:
        AssertionError: ``safe_form`` is true and ``save_path`` does not end
            with ``.safetensors``.
    """
    if not safe_form:
        # NOTE(review): `ti_lora_path` and `_text_lora_path` are not defined in
        # this file; presumably they come from the lora_diffusion star imports.
        # `import *` skips underscore-prefixed names unless the module lists
        # them in `__all__` — confirm this branch actually resolves them.
        if save_ti:
            ti_path = ti_lora_path(save_path)
            learned_embeds_dict = _collect_learned_embeds(
                text_encoder, placeholder_tokens, placeholder_token_ids
            )
            torch.save(learned_embeds_dict, ti_path)
            print("Ti saved to ", ti_path)

        if save_lora:
            save_lora_weight(
                unet, save_path, target_replace_module=target_replace_module_unet
            )
            print("Unet saved to ", save_path)

            text_path = _text_lora_path(save_path)
            save_lora_weight(
                text_encoder,
                text_path,
                target_replace_module=target_replace_module_text,
            )
            print("Text Encoder saved to ", text_path)
        return None

    assert save_path.endswith(
        ".safetensors"
    ), f"Save path : {save_path} should end with .safetensors"

    loras = {}
    if save_lora:
        loras["unet"] = (unet, target_replace_module_unet)
        loras["text_encoder"] = (text_encoder, target_replace_module_text)

    embeds = {}
    if save_ti:
        embeds = _collect_learned_embeds(
            text_encoder, placeholder_tokens, placeholder_token_ids
        )

    return save_safeloras_with_embeds(loras, embeds, save_path)
def save_safeloras_with_embeds(
    modelmap=None,
    embeds=None,
    outpath="./lora.safetensors",
):
    """
    Gather the LoRA tensors of several modules, plus learned embeddings, into
    one in-memory state dictionary. Nothing is written to disk.

    modelmap is a dictionary of {
        "module name": (module, target_replace_module)
    }
    embeds is a dictionary of { "token": embedding tensor }.
    Either may be None, which is treated as empty.

    outpath is accepted for interface compatibility only; it is not used.

    Returns {"weights": ..., "metadata": ...}:
      * weights maps "<name>:<i>:up" / "<name>:<i>:down" to the LoRA tensors
        and each token to its embedding, ordered by key;
      * metadata maps each module name to the JSON list of its target module
        names, "<name>:<i>:rank" to the rank as a string, and each token to
        EMBED_FLAG.
    """
    # Avoid mutable default arguments; None means "nothing to export".
    modelmap = {} if modelmap is None else modelmap
    embeds = {} if embeds is None else embeds

    weights = {}
    metadata = {}

    for name, (model, target_replace_module) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))

        for i, (_up, _down) in enumerate(
            extract_lora_as_tensor(model, target_replace_module)
        ):
            # The first dimension of the "down" tensor is recorded as the rank.
            rank = _down.shape[0]

            metadata[f"{name}:{i}:rank"] = str(rank)
            weights[f"{name}:{i}:up"] = _up
            weights[f"{name}:{i}:down"] = _down

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    # Order by key only, so the tensors themselves are never compared.
    sorted_weights = {key: weights[key] for key in sorted(weights)}

    return {"weights": sorted_weights, "metadata": metadata}
def _mean_lora_movement(model):
    """Mean, as a float, of all values that ``inspect_lora`` reports for *model*."""
    return (
        torch.tensor(list(itertools.chain(*inspect_lora(model).values())))
        .mean()
        .item()
    )


def perform_tuning(
    unet,
    vae,
    text_encoder,
    dataloader,
    num_steps,
    scheduler,
    optimizer,
    save_steps: int,
    placeholder_token_ids,
    placeholder_tokens,
    save_path,
    lr_scheduler_lora,
    lora_unet_target_modules,
    lora_clip_target_modules,
    mask_temperature,
    out_name: str,
    tokenizer,
    test_image_path: str,
    cached_latents: bool,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    class_token: str = "person",
    train_inpainting: bool = False,
):
    """Run the LoRA tuning loop and return the final exported state.

    Iterates over ``dataloader`` until ``num_steps`` optimizer steps have been
    taken. Every ``save_steps`` steps it calls ``save_all`` for a
    ``step_<n>.safetensors`` path, prints how far the LoRA weights moved, and
    optionally logs the loss and an evaluation to wandb.

    Args:
        unet, vae, text_encoder: Models forwarded to ``loss_step``.
        dataloader: Yields training batches; ``len(dataloader)`` must be > 0.
        num_steps: Number of optimizer steps to run.
        scheduler: Noise scheduler forwarded to ``loss_step`` and, for wandb
            evaluation, to the pipeline.
        optimizer: Optimizer holding the trainable parameters.
        save_steps: Checkpoint/log interval in steps; must be non-zero.
        placeholder_token_ids, placeholder_tokens: Forwarded to ``save_all``.
        save_path: Directory part of the checkpoint paths.
        lr_scheduler_lora: Learning-rate scheduler stepped once per batch.
        lora_unet_target_modules, lora_clip_target_modules: Forwarded to
            ``save_all``.
        mask_temperature: Forwarded to ``loss_step``.
        out_name: Base name of the final ``<out_name>.safetensors`` path.
        tokenizer: Used only to build the evaluation pipeline.
        test_image_path: Directory of ``.png``/``.jpg`` evaluation targets.
        cached_latents: Forwarded to ``loss_step``.
        log_wandb: Enable wandb loss/evaluation logging.
        wandb_log_prompt_cnt: ``n_test`` passed to ``evaluate_pipe``.
        class_token: Class token passed to ``evaluate_pipe``.
        train_inpainting: Forwarded to ``loss_step``.

    Returns:
        The return value of the final ``save_all`` call.
    """
    progress_bar = tqdm(range(num_steps))
    progress_bar.set_description("Steps")
    global_step = 0

    unet.train()
    text_encoder.train()

    if log_wandb:
        preped_clip = prepare_clip_model_sets()

    loss_sum = 0.0

    for _ in range(math.ceil(num_steps / len(dataloader))):
        for batch in dataloader:
            # NOTE(review): the scheduler is stepped before the optimizer;
            # PyTorch documents the opposite order. Kept as-is so the
            # effective learning-rate schedule does not change.
            lr_scheduler_lora.step()

            optimizer.zero_grad()

            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                train_inpainting=train_inpainting,
                t_mutliplier=0.8,
                mixed_precision=True,
                mask_temperature=mask_temperature,
                cached_latents=cached_latents,
            )
            # Read the scalar once; each .item() forces a host sync.
            loss_value = loss.detach().item()
            loss_sum += loss_value

            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
            )
            optimizer.step()

            progress_bar.update(1)
            progress_bar.set_postfix(
                loss=loss_value,
                lr=lr_scheduler_lora.get_last_lr()[0],
            )

            global_step += 1

            if global_step % save_steps == 0:
                # NOTE(review): the value returned by save_all is discarded
                # for these intermediate checkpoints.
                save_all(
                    unet,
                    text_encoder,
                    placeholder_token_ids=placeholder_token_ids,
                    placeholder_tokens=placeholder_tokens,
                    save_path=os.path.join(
                        save_path, f"step_{global_step}.safetensors"
                    ),
                    target_replace_module_text=lora_clip_target_modules,
                    target_replace_module_unet=lora_unet_target_modules,
                )
                print("LORA Unet Moved", _mean_lora_movement(unet))
                print("LORA CLIP Moved", _mean_lora_movement(text_encoder))

                if log_wandb:
                    with torch.no_grad():
                        pipe = StableDiffusionPipeline(
                            vae=vae,
                            text_encoder=text_encoder,
                            tokenizer=tokenizer,
                            unet=unet,
                            scheduler=scheduler,
                            safety_checker=None,
                            feature_extractor=None,
                        )

                        # open all images in test_image_path
                        images = [
                            Image.open(os.path.join(test_image_path, file))
                            for file in os.listdir(test_image_path)
                            if file.endswith((".png", ".jpg"))
                        ]

                        # Mean loss since the last wandb log, then reset.
                        wandb.log({"loss": loss_sum / save_steps})
                        loss_sum = 0.0
                        wandb.log(
                            evaluate_pipe(
                                pipe,
                                target_images=images,
                                class_token=class_token,
                                learnt_token="".join(placeholder_tokens),
                                n_test=wandb_log_prompt_cnt,
                                n_step=50,
                                clip_model_sets=preped_clip,
                            )
                        )

            # Only reachable in the last epoch, so leaving the inner loop
            # also ends the outer one.
            if global_step >= num_steps:
                break

    return save_all(
        unet,
        text_encoder,
        placeholder_token_ids=placeholder_token_ids,
        placeholder_tokens=placeholder_tokens,
        save_path=os.path.join(save_path, f"{out_name}.safetensors"),
        target_replace_module_text=lora_clip_target_modules,
        target_replace_module_unet=lora_unet_target_modules,
    )
def _freeze_text_encoder_backbone(text_encoder):
    """Disable gradients for the encoder stack, final layer norm and position embeddings."""
    frozen = itertools.chain(
        text_encoder.text_model.encoder.parameters(),
        text_encoder.text_model.final_layer_norm.parameters(),
        text_encoder.text_model.embeddings.position_embedding.parameters(),
    )
    for param in frozen:
        param.requires_grad = False


def train(
    images,
    caption,
    pretrained_model_name_or_path: str,
    train_text_encoder: bool = True,
    pretrained_vae_name_or_path: str = None,
    revision: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Literal[None, "object", "style"] = None,
    train_inpainting: bool = False,
    placeholder_tokens: str = "",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: Optional[str] = None,
    seed: int = 42,
    resolution: int = 512,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
    max_train_steps_tuning: int = 1000,
    max_train_steps_ti: int = 1000,
    save_steps: int = 100,
    gradient_accumulation_steps: int = 4,
    gradient_checkpointing: bool = False,
    lora_rank: int = 4,
    lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
    lora_clip_target_modules={"CLIPAttention"},
    lora_dropout_p: float = 0.0,
    lora_scale: float = 1.0,
    use_extended_lora: bool = False,
    clip_ti_decay: bool = True,
    learning_rate_unet: float = 1e-4,
    learning_rate_text: float = 1e-5,
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = False,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    cached_latents: bool = True,
    use_mask_captioned_data: bool = False,
    mask_temperature: float = 1.0,
    scale_lr: bool = False,
    lr_scheduler: str = "linear",
    lr_warmup_steps: int = 0,
    lr_scheduler_lora: str = "linear",
    lr_warmup_steps_lora: int = 0,
    weight_decay_ti: float = 0.00,
    weight_decay_lora: float = 0.001,
    use_8bit_adam: bool = False,
    device="cuda:0",
    extra_args: Optional[dict] = None,
    log_wandb: bool = False,
    wandb_log_prompt_cnt: int = 10,
    wandb_project_name: str = "new_pti_project",
    wandb_entity: str = "new_pti_entity",
    proxy_token: str = "person",
    enable_xformers_memory_efficient_attention: bool = False,
    out_name: str = "final_lora",
):
    """Run pivotal tuning: optional textual inversion followed by LoRA tuning.

    Args:
        images, caption: Training data, forwarded to the dataset.
        pretrained_model_name_or_path, pretrained_vae_name_or_path, revision,
            device: Forwarded to ``get_models``; the first is also used to
            load the noise scheduler.
        train_text_encoder: Also inject and train LoRA in the text encoder.
        perform_inversion: Run the textual-inversion stage first.
        use_template, resolution, color_jitter,
            use_face_segmentation_condition, use_mask_captioned_data,
            train_inpainting: Dataset options.
        placeholder_tokens: ``|``-separated placeholder tokens, in sorted
            order; empty means no placeholder tokens.
        placeholder_token_at_data: Optional ``"token|pattern"`` pair used as
            the dataset token map.
        initializer_tokens: ``|``-separated initializer tokens, one per
            placeholder token; ``None`` uses ``"<rand-0.017>"`` for each.
        seed: Seed for ``torch.manual_seed``.
        train_batch_size: Dataloader batch size.
        max_train_steps_ti, max_train_steps_tuning: Step counts of the two
            stages.
        save_steps: Checkpoint interval forwarded to both stages.
        gradient_accumulation_steps: Accumulation steps for the inversion
            stage; also part of the ``scale_lr`` factor.
        gradient_checkpointing: Enable UNet gradient checkpointing.
        lora_rank, lora_dropout_p, lora_scale: LoRA injection options.
        lora_unet_target_modules, lora_clip_target_modules: Module class
            names that receive LoRA.
        use_extended_lora: Add ``UNET_EXTENDED_TARGET_REPLACE`` to the UNet
            targets and use the extended injector.
        clip_ti_decay: Forwarded to ``train_inversion``.
        learning_rate_unet, learning_rate_text, learning_rate_ti: Base
            learning rates.
        continue_inversion: Keep training the token embeddings during the
            LoRA stage.
        continue_inversion_lr: Learning rate for that; ``None`` falls back to
            the inversion learning rate.
        cached_latents: Use cached latents; incompatible with
            ``train_inpainting``.
        mask_temperature: Forwarded to ``perform_tuning``.
        scale_lr: Multiply the learning rates by
            ``gradient_accumulation_steps * train_batch_size``.
        lr_scheduler, lr_warmup_steps: Scheduler settings for inversion.
        lr_scheduler_lora, lr_warmup_steps_lora: Scheduler settings for LoRA.
        weight_decay_ti, weight_decay_lora: AdamW weight decay per stage.
        log_wandb, wandb_log_prompt_cnt: wandb logging options.
        proxy_token: Class token used for evaluation; when ``None`` the
            joined initializer tokens are used instead.
        enable_xformers_memory_efficient_attention: Enable xformers attention
            on the UNet.
        out_name: Base name of the final export.
        sample_batch_size, use_8bit_adam, extra_args, wandb_project_name,
            wandb_entity: Accepted but not used by this function.

    Returns:
        The return value of ``perform_tuning``.

    Raises:
        AssertionError: Placeholder tokens are not sorted, the token counts
            differ, or cached latents are requested for inpainting.
        ValueError: xformers was requested but is not available.
    """
    torch.manual_seed(seed)

    if not placeholder_tokens:
        placeholder_tokens = []
        print("PTI : Placeholder Tokens not given, using null token")
    else:
        placeholder_tokens = placeholder_tokens.split("|")

        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"

    if initializer_tokens is None:
        print("PTI : Initializer Tokens not given, doing random inits")
        initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
    else:
        initializer_tokens = initializer_tokens.split("|")

    assert len(initializer_tokens) == len(
        placeholder_tokens
    ), "Unequal Initializer token for Placeholder tokens."

    # Previously the proxy token was assigned and then unconditionally
    # overwritten, so it never took effect.
    if proxy_token is not None:
        class_token = proxy_token
    else:
        class_token = "".join(initializer_tokens)

    if placeholder_token_at_data is not None:
        tok, pat = placeholder_token_at_data.split("|")
        token_map = {tok: pat}
    else:
        token_map = {"DUMMY": "".join(placeholder_tokens)}

    print("PTI : Placeholder Tokens", placeholder_tokens)
    print("PTI : Initializer Tokens", initializer_tokens)

    # get the models
    text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
        pretrained_model_name_or_path,
        pretrained_vae_name_or_path,
        revision,
        placeholder_tokens,
        initializer_tokens,
        device=device,
    )

    # NOTE(review): loading a scheduler from a model path via `from_config`
    # may be deprecated in newer diffusers releases in favour of
    # `from_pretrained` — confirm against the pinned version.
    noise_scheduler = DDPMScheduler.from_config(
        pretrained_model_name_or_path, subfolder="scheduler"
    )

    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if enable_xformers_memory_efficient_attention:
        from diffusers.utils.import_utils import is_xformers_available

        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if scale_lr:
        unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
        text_encoder_lr = (
            learning_rate_text * gradient_accumulation_steps * train_batch_size
        )
        ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
    else:
        unet_lr = learning_rate_unet
        text_encoder_lr = learning_rate_text
        ti_lr = learning_rate_ti

    train_dataset = PVD(
        images=images,
        caption=caption,
        token_map=token_map,
        use_template=use_template,
        tokenizer=tokenizer,
        size=resolution,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
        use_mask_captioned_data=use_mask_captioned_data,
        train_inpainting=train_inpainting,
    )

    # Blur amount for the inversion stage; lowered to 70 before LoRA tuning.
    train_dataset.blur_amount = 200

    if train_inpainting:
        assert not cached_latents, "Cached latents not supported for inpainting"

        train_dataloader = inpainting_dataloader(
            train_dataset, train_batch_size, tokenizer, vae, text_encoder
        )
    else:
        print(cached_latents)
        train_dataloader = text2img_dataloader(
            train_dataset,
            train_batch_size,
            tokenizer,
            vae,
            text_encoder,
            cached_latents=cached_latents,
        )

    # Boolean mask over the vocabulary: True everywhere except at the
    # placeholder token ids.
    index_no_updates = torch.arange(len(tokenizer)) != -1
    for tok_id in placeholder_token_ids:
        index_no_updates[tok_id] = False

    unet.requires_grad_(False)
    vae.requires_grad_(False)
    _freeze_text_encoder_backbone(text_encoder)

    if cached_latents:
        # The VAE reference is dropped once the dataloader has been built
        # with cached latents.
        vae = None

    # STEP 1 : Perform Inversion
    if perform_inversion:
        ti_optimizer = optim.AdamW(
            text_encoder.get_input_embeddings().parameters(),
            lr=ti_lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=weight_decay_ti,
        )

        ti_lr_scheduler = get_scheduler(
            lr_scheduler,
            optimizer=ti_optimizer,
            num_warmup_steps=lr_warmup_steps,
            num_training_steps=max_train_steps_ti,
        )

        train_inversion(
            unet,
            vae,
            text_encoder,
            train_dataloader,
            max_train_steps_ti,
            cached_latents=cached_latents,
            accum_iter=gradient_accumulation_steps,
            scheduler=noise_scheduler,
            index_no_updates=index_no_updates,
            optimizer=ti_optimizer,
            lr_scheduler=ti_lr_scheduler,
            save_steps=save_steps,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            save_path="./tmps",
            test_image_path="./tmps",
            log_wandb=log_wandb,
            wandb_log_prompt_cnt=wandb_log_prompt_cnt,
            class_token=class_token,
            train_inpainting=train_inpainting,
            mixed_precision=False,
            tokenizer=tokenizer,
            clip_ti_decay=clip_ti_decay,
        )

        del ti_optimizer

    # Next perform Tuning with LoRA:
    if not use_extended_lora:
        unet_lora_params, _ = inject_trainable_lora(
            unet,
            r=lora_rank,
            target_replace_module=lora_unet_target_modules,
            dropout_p=lora_dropout_p,
            scale=lora_scale,
        )
    else:
        print("PTI : USING EXTENDED UNET!!!")
        # `|` builds a new set, so the default argument is not mutated.
        lora_unet_target_modules = (
            lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
        )
        print("PTI : Will replace modules: ", lora_unet_target_modules)

        unet_lora_params, _ = inject_trainable_lora_extended(
            unet, r=lora_rank, target_replace_module=lora_unet_target_modules
        )
    print(f"PTI : has {len(unet_lora_params)} lora")

    print("PTI : Before training:")
    inspect_lora(unet)

    params_to_optimize = [
        {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
    ]

    text_encoder.requires_grad_(False)

    if continue_inversion:
        params_to_optimize += [
            {
                "params": text_encoder.get_input_embeddings().parameters(),
                "lr": continue_inversion_lr
                if continue_inversion_lr is not None
                else ti_lr,
            }
        ]
        # Re-enable gradients, then freeze the backbone again.
        text_encoder.requires_grad_(True)
        _freeze_text_encoder_backbone(text_encoder)

    if train_text_encoder:
        text_encoder_lora_params, _ = inject_trainable_lora(
            text_encoder,
            target_replace_module=lora_clip_target_modules,
            r=lora_rank,
        )
        params_to_optimize += [
            {
                "params": itertools.chain(*text_encoder_lora_params),
                "lr": text_encoder_lr,
            }
        ]
        inspect_lora(text_encoder)

    lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)

    unet.train()
    if train_text_encoder:
        text_encoder.train()

    train_dataset.blur_amount = 70

    lora_lr_scheduler = get_scheduler(
        lr_scheduler_lora,
        optimizer=lora_optimizers,
        num_warmup_steps=lr_warmup_steps_lora,
        num_training_steps=max_train_steps_tuning,
    )

    return perform_tuning(
        unet,
        vae,
        text_encoder,
        train_dataloader,
        max_train_steps_tuning,
        cached_latents=cached_latents,
        scheduler=noise_scheduler,
        optimizer=lora_optimizers,
        save_steps=save_steps,
        placeholder_tokens=placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        save_path="./tmps",
        lr_scheduler_lora=lora_lr_scheduler,
        lora_unet_target_modules=lora_unet_target_modules,
        lora_clip_target_modules=lora_clip_target_modules,
        mask_temperature=mask_temperature,
        tokenizer=tokenizer,
        out_name=out_name,
        test_image_path="./tmps",
        log_wandb=log_wandb,
        wandb_log_prompt_cnt=wandb_log_prompt_cnt,
        class_token=class_token,
        train_inpainting=train_inpainting,
    )
def semantic_karlo(prompt, output_dir, num_initial_image, bg_preprocess=False):
    """Generate initial images with the Karlo unCLIP pipeline and save them as PNGs.

    Image ``i`` is generated from one of four view prefixes (cycled with
    ``i % 4``) followed by ``prompt``; only the first image additionally gets
    the ``", white background"`` suffix. Files are written to
    ``<output_dir>/instance<i>.png``.

    Args:
        prompt: Text prompt describing the object.
        output_dir: Directory for the images; created if missing (also when
            ``num_initial_image`` is 0).
        num_initial_image: Number of images to generate.
        bg_preprocess: When true, pixels outside the carvekit foreground mask
            are painted white before saving.
    """
    pipe = UnCLIPPipeline.from_pretrained(
        "kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16
    )
    pipe = pipe.to('cuda')

    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]

    if bg_preprocess:
        # Please refer to the code at https://github.com/Ir1d/image-background-remove-tool.
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",
            batch_size_seg=5,
            batch_size_matting=1,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    # The output directory does not depend on the loop, so create it once.
    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_initial_image):
        prompt_ = f"{view_prompt[i % 4]}{prompt}"
        if i == 0:
            prompt_ += ", white background"

        image = pipe(prompt_).images[0]

        if bg_preprocess:
            # motivated by NeuralLift-360 (removing bg), and Zero-1-to-3 (removing bg and object-centering)
            # NOTE: This option was added during the code organization process.
            # The results reported in the paper were obtained with [bg_preprocess: False] setting.
            img_without_background = interface([image])
            # NOTE(review): assumes the carvekit output converts to an array
            # whose boolean mask can index the generated image — confirm.
            mask = np.array(img_without_background[0]) > 127
            image = np.array(image)
            image[~mask] = [255., 255., 255.]
            image = Image.fromarray(image)

        image.save(os.path.join(output_dir, f"instance{i}.png"))
def semantic_sd(prompt, output_dir, num_initial_image, bg_preprocess=False):
    """Generate initial images with Stable Diffusion v1.5 and save them as PNGs.

    Image ``i`` is generated from one of four view prefixes (cycled with
    ``i % 4``) followed by ``prompt``; only the first image additionally gets
    the ``", white background"`` suffix. Files are written to
    ``<output_dir>/instance<i>.png``.

    Args:
        prompt: Text prompt describing the object.
        output_dir: Directory for the images; created if missing (also when
            ``num_initial_image`` is 0).
        num_initial_image: Number of images to generate.
        bg_preprocess: When true, pixels outside the carvekit foreground mask
            are painted white before saving.
    """
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to('cuda')

    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]

    if bg_preprocess:
        # Please refer to the code at https://github.com/Ir1d/image-background-remove-tool.
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",
            batch_size_seg=5,
            batch_size_matting=1,
            device='cuda' if torch.cuda.is_available() else 'cpu',
            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_dilation=30,
            trimap_erosion_iters=5,
            fp16=False,
        )

    # The output directory does not depend on the loop, so create it once.
    os.makedirs(output_dir, exist_ok=True)

    for i in range(num_initial_image):
        prompt_ = f"{view_prompt[i % 4]}{prompt}"
        if i == 0:
            prompt_ += ", white background"

        image = pipe(prompt_).images[0]

        if bg_preprocess:
            # motivated by NeuralLift-360 (removing bg), and Zero-1-to-3 (removing bg and object-centering)
            # NOTE: This option was added during the code organization process.
            # The results reported in the paper were obtained with [bg_preprocess: False] setting.
            img_without_background = interface([image])
            # NOTE(review): assumes the carvekit output converts to an array
            # whose boolean mask can index the generated image — confirm.
            mask = np.array(img_without_background[0]) > 127
            image = np.array(image)
            image[~mask] = [255., 255., 255.]
            image = Image.fromarray(image)

        image.save(os.path.join(output_dir, f"instance{i}.png"))
def semantic_coding(images, cfgs, sd, initial):
    """Train the semantic code for ``images`` and set the prompt on ``sd``.

    Removes ``'ti_step'`` and ``'pt_step'`` from ``cfgs`` (the caller's dict is
    modified), reads the prompt from ``cfgs['sd']['prompt']``, runs ``train``
    with the placeholder token ``"<0>"``, and then assigns ``sd.prompt``:
    the configured prompt with ``initial`` replaced by ``'<0>'``, or
    ``"a <0>"`` when no initial token was given.

    Args:
        images: Training images forwarded to ``train``.
        cfgs: Configuration mapping with ``'ti_step'``, ``'pt_step'`` and
            ``cfgs['sd']['prompt']``.
        sd: Object whose ``prompt`` attribute is overwritten.
        initial: Initializer token / caption; ``""`` is treated as ``None``.

    Returns:
        The value returned by ``train``.
    """
    inversion_steps = cfgs.pop('ti_step')
    tuning_steps = cfgs.pop('pt_step')
    base_prompt = cfgs['sd']['prompt']

    # An empty string means "no initializer token".
    init_token = None if initial == "" else initial

    train_kwargs = dict(
        images=images,
        caption=init_token,
        pretrained_model_name_or_path='runwayml/stable-diffusion-v1-5',
        gradient_checkpointing=True,
        scale_lr=True,
        lora_rank=1,
        cached_latents=False,
        # One more than either stage's step count, so the periodic
        # `step % save_steps == 0` checkpoint never fires.
        save_steps=max(inversion_steps, tuning_steps) + 1,
        max_train_steps_ti=inversion_steps,
        max_train_steps_tuning=tuning_steps,
        use_template="object",
        lr_warmup_steps=0,
        lr_warmup_steps_lora=100,
        placeholder_tokens="<0>",
        initializer_tokens=init_token,
        continue_inversion=True,
        continue_inversion_lr=1e-4,
        device="cuda:0",
    )
    state = train(**train_kwargs)

    sd.prompt = "a <0>" if init_token is None else base_prompt.replace(init_token, '<0>')
    return state