In [None]:
import csv
import tempfile
from pathlib import Path

import wandb
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
from vqgan_jax.modeling_flax_vqgan import VQModel

from dalle_mini.model import CustomFlaxBartForConditionalGeneration
from dalle_mini.text import TextNormalizer

In [None]:
# Ids of the W&B training runs to process (the first one is picked further down)
wandb_runs = ['rjf3rycy']

# VQGAN checkpoint repository and the revision handed to `from_pretrained`
VQGAN_REPO = 'dalle-mini/vqgan_imagenet_f16_16384'
VQGAN_COMMIT_ID = None

# Text normalization switch (not read by any of the cells visible here)
normalize_text = True

In [None]:
# VQGAN weights come from VQGAN_REPO at the VQGAN_COMMIT_ID revision.
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)

# The CLIP model and its processor are loaded from one shared checkpoint name,
# so the two cannot silently end up pointing at different checkpoints.
CLIP_REPO = "openai/clip-vit-base-patch32"
clip = FlaxCLIPModel.from_pretrained(CLIP_REPO)
processor = CLIPProcessor.from_pretrained(CLIP_REPO)

In [None]:
# Read every row of samples.csv into `samples` (a list of lists of strings).
# The loop body used to be a leftover `breakpoint()`, which dropped into the
# debugger on every row and stalled any non-interactive "Run All".
with open('samples.csv', newline='', encoding='utf8') as f:
    samples = list(csv.reader(f))

In [None]:
# Work on the first configured training run.
wandb_run = wandb_runs[0]
# W&B API client; later cells use it to look up runs and artifact versions.
api = wandb.Api()

In [None]:
# List the model artifact versions attached to the training run.
# Only `Exception` subclasses fall back to an empty list: a bare `except:`
# would also swallow KeyboardInterrupt / SystemExit. The error is printed so an
# empty list caused by a failed lookup is not mistaken for "nothing logged yet".
try:
    versions = api.artifact_versions(
        type_name='bart_model',
        name=f'dalle-mini/dalle-mini/model-{wandb_run}',
        per_page=10000,
    )
except Exception as err:
    print(f'Could not list artifact versions for run {wandb_run}: {err!r}')
    versions = []

In [None]:
# Display the artifact versions together with how many there are.
versions, len(versions)

In [None]:
# Order the artifacts by the integer that follows the first character of each
# version string (presumably of the form "v<N>" — confirm against W&B).
versions = list(versions)
versions.sort(key=lambda art: int(art.version[1:]))

In [None]:
# Display the versions again, now in sorted order.
versions

In [None]:
# Take the first entry of `versions`. When the lookup found nothing (or failed
# and fell back to an empty list), fail with an explicit message instead of a
# bare "list index out of range".
if not versions:
    raise IndexError(f'No model artifact versions available for run {wandb_run}')
artifact = versions[0]

In [None]:
# Integer version number: the artifact's version string with its first character dropped.
version = int(artifact.version[1:])

In [None]:
# Display the numeric version of the selected artifact.
version

In [None]:
# Retrieve the training run through the W&B API and keep its config;
# the config is passed on to `wandb.init` when the inference run is created.
training_run = api.run(f'dalle-mini/dalle-mini/{wandb_run}')
config = training_run.config

In [None]:
# Display the summary metrics recorded for the training run.
training_run.summary

In [None]:
# retrieve inference run details
def get_last_version_inference(run_id):
    """Return the last step recorded by the inference run paired with `run_id`.

    Looks up the run ``dalle-mini/dalle-mini/inference-{run_id}`` through the
    module-level ``api`` client and reads ``_step`` from its summary.

    Args:
        run_id: id of the training run the inference run is derived from.

    Returns:
        The ``_step`` summary value, or ``None`` when the summary has no
        ``_step`` entry or the run could not be retrieved.
    """
    try:
        inference_run = api.run(f'dalle-mini/dalle-mini/inference-{run_id}')
        return inference_run.summary.get('_step', None)
    except Exception:
        # Best effort: any lookup failure is treated as "nothing recorded yet".
        # Catching Exception rather than using a bare `except:` lets
        # KeyboardInterrupt / SystemExit propagate.
        return None

In [None]:
# Last step recorded by the matching inference run, or None when there is none.
last_version_inference = get_last_version_inference(wandb_run)

In [None]:
# The selected artifact must be exactly one past the last step recorded by the
# inference run, or version 0 when no inference step has been recorded.
expected_version = 0 if last_version_inference is None else last_version_inference + 1
assert version == expected_version

In [None]:
# Start (or resume) the inference run. Its id is 'inference-<training run id>',
# the same id that get_last_version_inference looks up; the training run's
# config is reused as this run's config.
run = wandb.init(
    job_type='inference',
    config=config,
    id=f'inference-{wandb_run}',
    resume='allow',
)

In [None]:
# When this cell is re-run, remove the directory created by the previous run.
# On a fresh kernel `tmp_f` does not exist yet: the old bare `tmp_f.cleanup`
# raised NameError there, and never actually called the method anyway.
try:
    tmp_f.cleanup()
except NameError:
    pass

# The TemporaryDirectory object is kept in a variable because the directory
# must outlive this cell: later cells download into `tmp` and load from it.
tmp_f = tempfile.TemporaryDirectory()
tmp = tmp_f.name
# TODO: use context manager

In [None]:
# Remove the temporary directory.
# NOTE(review): in top-to-bottom order this cell runs before the cells that
# download into `tmp` and load from it — presumably it is meant to be run by
# hand once the model is loaded; confirm and consider moving it to the end.
tmp_f.cleanup()

In [None]:
# Register the selected artifact with the inference run; `artifact` is rebound
# to whatever `use_artifact` returns, so re-running this cell passes that value back in.
artifact = run.use_artifact(artifact)

In [None]:
# Fetch only the listed files from the artifact into the temporary directory
required_files = (
    'config.json',
    'flax_model.msgpack',
    'merges.txt',
    'special_tokens_map.json',
    'tokenizer.json',
    'tokenizer_config.json',
    'vocab.json',
)
for f in required_files:
    artifact.get_path(f).download(tmp)

In [None]:
# List what is in the temporary directory so the downloaded files can be
# checked by eye (sorted for a stable display). `Path` is imported in the
# notebook's top import cell.
sorted(Path(tmp).glob('*'))

In [None]:
# Load the tokenizer and the model from the files downloaded into `tmp`.
tokenizer = BartTokenizer.from_pretrained(tmp)
model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)

In [None]:
# Close the W&B run opened by `wandb.init` above.
wandb.finish()