# obtu-ai / services/train_lora.py
# Last commit (e99d2e7) by Jose Benitez:
#   fix version model and lowercase model name in train
import replicate
import os
from huggingface_hub import create_repo
from database import create_lora_models
# Replicate account name passed as `owner` when the fine-tuned model is created.
REPLICATE_OWNER = "josebenitezg"
def lora_pipeline(
    user_id,
    zip_path,
    model_name,
    trigger_word="TOK",
    steps=1000,
    lora_rank=16,
    batch_size=1,
    autocaption=True,
    learning_rate=0.0004,
):
    """Create the destination repos, start a LoRA training and record it.

    ``model_name`` is lower-cased and its spaces are replaced by underscores.
    The normalized name is used to create a Hugging Face model repo and a
    private Replicate model, a Replicate training is launched with the zipped
    images, and the run is stored through ``create_lora_models``.

    Args:
        user_id: Forwarded unchanged to ``create_lora_models``.
        zip_path: Path of the archive uploaded as ``input_images``.
        model_name: Display name of the model; normalized before use.
        trigger_word: Sent to the trainer as ``trigger_word``.
        steps: Sent to the trainer as ``steps``.
        lora_rank: Sent to the trainer as ``lora_rank``.
        batch_size: Sent to the trainer as ``batch_size``.
        autocaption: Sent to the trainer as ``autocaption``.
        learning_rate: Sent to the trainer as ``learning_rate``.

    Raises:
        OSError: If ``zip_path`` cannot be opened. The archive is opened
            before any remote call, so nothing has been created remotely
            when this is raised.
    """
    print(f'Creating dataset for {model_name}')
    model_name = model_name.lower().replace(' ', '_')
    lora_name = f"flux-dev-{model_name}"
    hf_repo_name = f"joselobenitezg/{lora_name}"

    # Open the archive first: a bad path then fails before the Hugging Face
    # repo or the Replicate model exist, and the `with` block guarantees the
    # handle is closed once the training request has been sent.
    with open(zip_path, "rb") as images_zip:
        create_repo(hf_repo_name, repo_type='model')

        model = replicate.models.create(
            owner=REPLICATE_OWNER,
            name=lora_name,
            visibility="private",  # use "public" to expose the model to everyone
            hardware="gpu-t4",  # Replicate will override this for fine-tuned models
            description="A fine-tuned FLUX.1 model"
        )
        print(f"Model created: {model.name}")
        print(f"Model URL: https://replicate.com/{model.owner}/{model.name}")

        # Now use this model as the destination for your training
        print("[INFO] Starting training")
        print(f'\n[INFO] Parametros a entrenar: \n Trigger word: {trigger_word}\n steps: {steps} \n lora_rank: {lora_rank}\n autocaption: {autocaption}\n learning_rate: {learning_rate}\n')

        training = replicate.trainings.create(
            version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
            input={
                "input_images": images_zip,
                "steps": steps,
                "lora_rank": lora_rank,
                "batch_size": batch_size,
                "autocaption": autocaption,
                "trigger_word": trigger_word,
                "learning_rate": learning_rate,
                "hf_token": os.getenv('HF_TOKEN'),  # optional
                "hf_repo_id": hf_repo_name,  # optional
            },
            destination=f"{model.owner}/{model.name}"
        )

    # NOTE(review): the object returned by the client may not offer a
    # dict-style `.keys()`; guard the debug dump so it cannot abort the
    # pipeline before the run is written to the database.
    training_fields = training.keys() if hasattr(training, "keys") else training
    print(f"training: {training_fields}")
    print(f"Training started: {training.status}")

    training_url = f"https://replicate.com/p/{training.id}"
    print(f"Training URL: {training_url}")

    print("Creating model in Database")
    # Same owner constant that was used to create the Replicate model above.
    replicate_repo_name = f"{REPLICATE_OWNER}/{lora_name}"
    create_lora_models(user_id, replicate_repo_name, trigger_word, steps, lora_rank, batch_size, learning_rate, hf_repo_name, training_url)