|
import time |
|
import pytest |
|
import replicate |
|
|
|
|
|
@pytest.fixture(scope="module")
def model_name(request):
    """Return the ``owner/name`` identifier of the base model trained in this module.

    ``request`` is accepted but not used.
    """
    base_model_id = "stability-ai/sdxl"
    return base_model_id
|
|
|
|
|
@pytest.fixture(scope="module")
def model(model_name):
    """Look up the base model on Replicate once for the whole module."""
    fetched_model = replicate.models.get(model_name)
    return fetched_model
|
|
|
|
|
@pytest.fixture(scope="module")
def version(model):
    """Return the first entry of the base model's version listing.

    Ordering is whatever the API returns (presumably newest first -- confirm).
    Raises if the listing is empty.
    """
    return model.versions.list()[0]
|
|
|
|
|
@pytest.fixture(scope="module")
def training(model_name, version):
    """Start a fine-tuning job on ``model_name:version.id`` and return it without waiting.

    The trained weights are pushed to ``replicate-internal/training-scratch``.
    """
    print(f"Training on {model_name}:{version.id}")

    # Replicate addresses a specific version as "owner/name:version_id".
    base_version_ref = ":".join((model_name, version.id))
    dataset = {
        "input_images": "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar"
    }

    return replicate.trainings.create(
        version=base_version_ref,
        input=dataset,
        destination="replicate-internal/training-scratch",
    )
|
|
|
|
|
@pytest.fixture(scope="module")
def prediction_tests():
    """Return the prediction inputs to run against the freshly trained version.

    Each element is passed verbatim as ``input=`` to ``replicate.predictions.create``.
    """
    # "TOK" is used as the prompt token for the trained subject.
    beach_scene = {
        "prompt": "A photo of TOK at the beach",
        "refine": "expert_ensemble_refiner",
    }
    return [beach_scene]
|
|
|
|
|
def test_training(training): |
|
while training.completed_at is None: |
|
time.sleep(60) |
|
training.reload() |
|
assert training.status == "succeeded" |
|
|
|
|
|
@pytest.fixture(scope="module")
def trained_model_and_version(training):
    """Return ``(model, version_id)`` for the version produced by the training job.

    The ``training`` fixture only creates the job; it does not wait for it. So
    this fixture waits for completion itself instead of relying on
    ``test_training`` having run first and polled the shared object. Otherwise,
    selecting only the prediction test (e.g. with ``-k``) would read
    ``training.output`` before the job finished.

    Raises:
        AssertionError: if the job does not complete within the deadline or
            does not succeed.
    """
    # Same generous deadline and poll cadence as the training test.
    deadline = time.monotonic() + 2 * 60 * 60
    while training.completed_at is None:
        assert time.monotonic() < deadline, (
            f"training {training.id} did not complete in time "
            f"(last status: {training.status!r})"
        )
        time.sleep(60)
        training.reload()

    assert training.status == "succeeded", (
        f"training {training.id} finished with status {training.status!r}: "
        f"{training.error}"
    )

    # output["version"] is expected to look like "owner/name:version_id".
    trained_model, trained_version = training.output["version"].split(":")
    return trained_model, trained_version
|
|
|
|
|
def test_post_training_predictions(trained_model_and_version, prediction_tests):
    """Run every prediction input against the trained version; all must succeed."""
    model_ref, version_id = trained_model_and_version
    trained_version = replicate.models.get(model_ref).versions.get(version_id)

    # Submit all predictions before waiting on any, so submitting one never
    # waits for an earlier one to finish.
    submitted = []
    for prediction_input in prediction_tests:
        submitted.append(
            replicate.predictions.create(version=trained_version, input=prediction_input)
        )

    for prediction in submitted:
        prediction.wait()
        assert prediction.status == "succeeded"
|
|