Spaces:
Runtime error
Runtime error
RohitGandikota
commited on
Commit
Β·
1e14cf1
1
Parent(s):
68e2466
adding train method dropdown
Browse files
app.py
CHANGED
|
@@ -230,12 +230,13 @@ class Demo:
|
|
| 230 |
self.iterations_input,
|
| 231 |
self.lr_input,
|
| 232 |
self.attributes_input,
|
| 233 |
-
self.is_person
|
|
|
|
| 234 |
],
|
| 235 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
| 236 |
)
|
| 237 |
|
| 238 |
-
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, pbar = gr.Progress(track_tqdm=True)):
|
| 239 |
iterations_input = min(int(iterations_input),1000)
|
| 240 |
if attributes_input == '':
|
| 241 |
attributes_input = None
|
|
@@ -257,13 +258,13 @@ class Demo:
|
|
| 257 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
| 258 |
|
| 259 |
self.training = True
|
| 260 |
-
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), device=self.device, attributes=attributes, save_name=save_name)
|
| 261 |
self.training = False
|
| 262 |
|
| 263 |
torch.cuda.empty_cache()
|
| 264 |
-
model_map['
|
| 265 |
|
| 266 |
-
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value='
|
| 267 |
|
| 268 |
|
| 269 |
def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
|
|
|
|
| 230 |
self.iterations_input,
|
| 231 |
self.lr_input,
|
| 232 |
self.attributes_input,
|
| 233 |
+
self.is_person,
|
| 234 |
+
self.train_method_input
|
| 235 |
],
|
| 236 |
outputs=[self.train_button, self.train_status, self.download, self.model_dropdown]
|
| 237 |
)
|
| 238 |
|
| 239 |
+
def train(self, target_concept,positive_prompt, negative_prompt, rank, iterations_input, lr_input, attributes_input, is_person, train_method_input, pbar = gr.Progress(track_tqdm=True)):
|
| 240 |
iterations_input = min(int(iterations_input),1000)
|
| 241 |
if attributes_input == '':
|
| 242 |
attributes_input = None
|
|
|
|
| 258 |
attributes = 'white, black, asian, hispanic, indian, male, female'
|
| 259 |
|
| 260 |
self.training = True
|
| 261 |
+
train_xl(target=target_concept, positive=positive_prompt, negative=negative_prompt, lr=lr_input, iterations=iterations_input, config_file='trainscripts/textsliders/data/config-xl.yaml', rank=int(rank), train_method=train_method_input, device=self.device, attributes=attributes, save_name=save_name)
|
| 262 |
self.training = False
|
| 263 |
|
| 264 |
torch.cuda.empty_cache()
|
| 265 |
+
model_map[save_name.replace('.pt','')] = f'models/{save_name}'
|
| 266 |
|
| 267 |
+
return [gr.update(interactive=True, value='Train'), gr.update(value='Done Training! \n Try your custom slider in the "Test" tab'), f'models/{save_name}', gr.update(choices=list(model_map.keys()), value=save_name.replace('.pt',''))]
|
| 268 |
|
| 269 |
|
| 270 |
def inference(self, prompt, seed, start_noise, scale, model_name, pbar = gr.Progress(track_tqdm=True)):
|
trainscripts/textsliders/data/config-xl.yaml
CHANGED
|
@@ -7,7 +7,7 @@ network:
|
|
| 7 |
type: "c3lier" # or "c3lier" or "lierla"
|
| 8 |
rank: 4
|
| 9 |
alpha: 1.0
|
| 10 |
-
training_method: "
|
| 11 |
train:
|
| 12 |
precision: "bfloat16"
|
| 13 |
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
|
|
|
| 7 |
type: "c3lier" # or "c3lier" or "lierla"
|
| 8 |
rank: 4
|
| 9 |
alpha: 1.0
|
| 10 |
+
training_method: "noxattn"
|
| 11 |
train:
|
| 12 |
precision: "bfloat16"
|
| 13 |
noise_scheduler: "ddim" # or "ddpm", "lms", "euler_a"
|
trainscripts/textsliders/demotrain.py
CHANGED
|
@@ -411,7 +411,7 @@ def train(
|
|
| 411 |
# train(config, prompts, device)
|
| 412 |
|
| 413 |
|
| 414 |
-
def train_xl(target, positive, negative, lr, iterations, config_file, rank, device, attributes,save_name):
|
| 415 |
|
| 416 |
config = config_util.load_config_from_yaml(config_file)
|
| 417 |
randn = torch.randint(1, 10000000, (1,)).item()
|
|
@@ -427,6 +427,7 @@ def train_xl(target, positive, negative, lr, iterations, config_file, rank, devi
|
|
| 427 |
attributes = []
|
| 428 |
config.network.alpha = 1.0
|
| 429 |
config.network.rank = int(rank)
|
|
|
|
| 430 |
|
| 431 |
# config.save.path += f'/{config.save.name}'
|
| 432 |
|
|
|
|
| 411 |
# train(config, prompts, device)
|
| 412 |
|
| 413 |
|
| 414 |
+
def train_xl(target, positive, negative, lr, iterations, config_file, rank, train_method, device, attributes,save_name):
|
| 415 |
|
| 416 |
config = config_util.load_config_from_yaml(config_file)
|
| 417 |
randn = torch.randint(1, 10000000, (1,)).item()
|
|
|
|
| 427 |
attributes = []
|
| 428 |
config.network.alpha = 1.0
|
| 429 |
config.network.rank = int(rank)
|
| 430 |
+
config.network.training_method = train_method
|
| 431 |
|
| 432 |
# config.save.path += f'/{config.save.name}'
|
| 433 |
|