File size: 1,804 Bytes
3ab8901 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import time
import pytest
import replicate
@pytest.fixture(scope="module")
def model_name(request):
    """Return the "owner/name" identifier of the base model being fine-tuned.

    ``request`` is accepted (pytest injects it) but is not used.
    """
    base_model = "stability-ai/sdxl"
    return base_model
@pytest.fixture(scope="module")
def model(model_name):
    """Fetch the base model object named by ``model_name`` from Replicate."""
    fetched = replicate.models.get(model_name)
    return fetched
@pytest.fixture(scope="module")
def version(model):
    """Return the first entry of the base model's version listing.

    NOTE(review): this presumes the API lists the newest version first;
    confirm against the Replicate versions endpoint documentation.
    """
    return model.versions.list()[0]
@pytest.fixture(scope="module")
def training(model_name, version):
    """Start a fine-tuning job on the pinned base-model version.

    The job trains on the tarball referenced by ``input_images`` and is asked
    to publish its result under ``replicate-internal/training-scratch``.
    Because this fixture is module-scoped, every test in the module shares
    the same training object (including any ``reload()`` done by a test).
    """
    # Pin the exact version as "owner/name:version_id".
    pinned_version = f"{model_name}:{version.id}"
    print(f"Training on {pinned_version}")
    return replicate.trainings.create(
        version=pinned_version,
        input={
            "input_images": "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar",
        },
        destination="replicate-internal/training-scratch",
    )
@pytest.fixture(scope="module")
def prediction_tests():
    """Return the input payloads to run against the fine-tuned model.

    Each dict becomes one prediction. ``TOK`` in the prompt is presumably the
    trigger token the training associates with the new concept; confirm
    against the SDXL trainer documentation.
    """
    beach_prompt = {
        "prompt": "A photo of TOK at the beach",
        "refine": "expert_ensemble_refiner",
    }
    return [beach_prompt]
def test_training(training):
    """Poll the training job until it completes and assert that it succeeded.

    The job is reloaded once per poll interval. Polling is bounded by a
    deadline so that a stuck job fails the test instead of hanging the run
    forever.
    """
    # NOTE(review): generous upper bound picked for safety; tune it to the
    # observed duration of this training job.
    timeout_s = 3 * 60 * 60
    poll_interval_s = 60
    # monotonic() is immune to wall-clock adjustments during a long wait.
    deadline = time.monotonic() + timeout_s
    while training.completed_at is None:
        if time.monotonic() >= deadline:
            pytest.fail(
                f"training {training.id} did not complete within {timeout_s}s "
                f"(last status: {training.status!r})"
            )
        time.sleep(poll_interval_s)
        training.reload()
    assert training.status == "succeeded", (
        f"training {training.id} finished with status {training.status!r}: "
        f"{getattr(training, 'error', None)}"
    )
@pytest.fixture(scope="module")
def trained_model_and_version(training):
    """Return ``(model, version_id)`` of the model produced by ``training``.

    The fixture waits for the training to complete itself, so it does not
    depend on ``test_training`` having run first (e.g. when a single test is
    selected with ``-k``). If the training did not succeed, it fails with a
    clear message instead of crashing on a missing ``output``.
    """
    # NOTE(review): same generous bound as the training test; tune as needed.
    timeout_s = 3 * 60 * 60
    poll_interval_s = 60
    deadline = time.monotonic() + timeout_s
    while training.completed_at is None:
        if time.monotonic() >= deadline:
            pytest.fail(
                f"training {training.id} did not complete within {timeout_s}s "
                f"(last status: {training.status!r})"
            )
        time.sleep(poll_interval_s)
        training.reload()
    if training.status != "succeeded":
        pytest.fail(
            f"training {training.id} finished with status {training.status!r}; "
            "no trained version to run predictions against"
        )
    # output["version"] is expected in the same "owner/name:version_id" form
    # used to start the training.
    trained_model, trained_version = training.output["version"].split(":")
    return trained_model, trained_version
def test_post_training_predictions(trained_model_and_version, prediction_tests):
    """Run every payload in ``prediction_tests`` against the trained version.

    All predictions are created before any is awaited, so they can run
    concurrently on the service. Each one is then awaited, and every failure
    (with its input and error) is reported together rather than stopping at
    the first one.
    """
    trained_model, trained_version = trained_model_and_version
    model = replicate.models.get(trained_model)
    version = model.versions.get(trained_version)
    predictions = [
        replicate.predictions.create(version=version, input=payload)
        for payload in prediction_tests
    ]
    failures = []
    for payload, prediction in zip(prediction_tests, predictions):
        prediction.wait()
        if prediction.status != "succeeded":
            failures.append(
                f"{prediction.id} ({payload!r}): status={prediction.status!r}, "
                f"error={getattr(prediction, 'error', None)!r}"
            )
    assert not failures, "predictions failed:\n" + "\n".join(failures)
|