Jose Benitez commited on
Commit
e99d2e7
·
1 Parent(s): b16249c

fix version model and lowercase model name in train

Browse files
services/image_generation.py CHANGED
@@ -19,7 +19,12 @@ def generate_image(model_name, prompt, steps, cfg_scale, width, height, lora_sca
19
  }
20
  )
21
  else:
22
- model_name = model_name.lower().replace(' ', '_')
 
 
 
 
 
23
  img_url = replicate.run(
24
  model_name,
25
  input={
 
19
  }
20
  )
21
  else:
22
+ # check if the model has a version. the model is something like model='user/model-name:version' but sometimes we just got model='user/model-name' in this case, let get and add the model version
23
+ if ':' not in model_name:
24
+ model_version = replicate.models.get(model_name).latest_version.id
25
+ print(f"Model version: {model_version}")
26
+ model_name = f"{model_name}:{model_version}"
27
+
28
  img_url = replicate.run(
29
  model_name,
30
  input={
services/train_lora.py CHANGED
@@ -7,6 +7,7 @@ REPLICATE_OWNER = "josebenitezg"
7
 
8
  def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
9
  print(f'Creating dataset for {model_name}')
 
10
  hf_repo_name = f"joselobenitezg/flux-dev-{model_name}"
11
  replicate_repo_name = f"josebenitezg/flux-dev-{model_name}"
12
  create_repo(hf_repo_name, repo_type='model')
 
7
 
8
  def lora_pipeline(user_id, zip_path, model_name, trigger_word="TOK", steps=1000, lora_rank=16, batch_size=1, autocaption=True, learning_rate=0.0004):
9
  print(f'Creating dataset for {model_name}')
10
+ model_name = model_name.lower().replace(' ', '_')
11
  hf_repo_name = f"joselobenitezg/flux-dev-{model_name}"
12
  replicate_repo_name = f"josebenitezg/flux-dev-{model_name}"
13
  create_repo(hf_repo_name, repo_type='model')