{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5c1e7a2b-3f4d-4e6a-9b8c-1d2e3f4a5b6c",
   "metadata": {},
   "source": [
    "# DALL-E mini: inference on training checkpoints\n",
    "\n",
    "Loads the next model checkpoint (`bart_model` artifact) logged by a wandb training run, i.e. the version right after the last `_step` recorded by the matching `inference-<run_id>` run, together with its tokenizer and the VQGAN and CLIP models, inside a resumable wandb inference run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import tempfile\n",
    "from pathlib import Path\n",
    "\n",
    "import wandb\n",
    "from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel\n",
    "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
    "\n",
    "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
    "from dalle_mini.text import TextNormalizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# wandb training run(s) whose checkpoints are evaluated\n",
    "wandb_runs = ['rjf3rycy']\n",
    "# VQGAN model on the Hugging Face Hub and the revision passed to `from_pretrained`\n",
    "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
    "# NOTE(review): not used yet in this notebook\n",
    "normalize_text = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
   "metadata": {},
   "outputs": [],
   "source": [
    "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
    "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
    "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load evaluation prompts (empty rows are skipped)\n",
    "# NOTE(review): assumes the prompt is the first column of samples.csv; confirm\n",
    "with open('samples.csv', newline='', encoding='utf8') as f:\n",
    "    samples = [row[0] for row in csv.reader(f) if row]\n",
    "len(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb_run = wandb_runs[0]\n",
    "api = wandb.Api()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list checkpoint artifacts logged by the training run (empty list if they cannot be retrieved)\n",
    "try:\n",
    "    versions = api.artifact_versions(type_name='bart_model', name=f'dalle-mini/dalle-mini/model-{wandb_run}', per_page=10000)\n",
    "except Exception as e:\n",
    "    print(f'Could not retrieve model artifacts for run {wandb_run}: {e}')\n",
    "    versions = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8026e63-9e73-472c-9440-5e742c614901",
   "metadata": {},
   "outputs": [],
   "source": [
    "versions, len(versions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29613a9d-de7e-44e3-94f1-650085039204",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sort by numeric version ('v12' -> 12) so that 'v10' comes after 'v9'\n",
    "versions = sorted(versions, key=lambda x: int(x.version[1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d77159df-1a16-4996-aafd-1df82c5a3509",
   "metadata": {},
   "outputs": [],
   "source": [
    "versions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
   "metadata": {},
   "outputs": [],
   "source": [
    "# retrieve training run\n",
    "training_run = api.run(f'dalle-mini/dalle-mini/{wandb_run}')\n",
    "config = training_run.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b9393c6-0a3c-46a8-ba27-ba37982b0009",
   "metadata": {},
   "outputs": [],
   "source": [
    "# see summary metrics\n",
    "training_run.summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# retrieve inference run details\n",
    "def get_last_version_inference(run_id):\n",
    "    \"\"\"Return the last `_step` logged by the `inference-<run_id>` wandb run.\n",
    "\n",
    "    Returns None when that run has no `_step` in its summary or cannot be\n",
    "    retrieved (presumably because it does not exist yet).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        inference_run = api.run(f'dalle-mini/dalle-mini/inference-{run_id}')\n",
    "        return inference_run.summary.get('_step', None)\n",
    "    except Exception as e:\n",
    "        print(f'Could not retrieve inference run for {run_id}: {e}')\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93b8d869-1658-4fa4-a401-2b91f8ac7a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "last_version_inference = get_last_version_inference(wandb_run)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfd48de9-6022-444f-8b12-05cba8fad071",
   "metadata": {},
   "outputs": [],
   "source": [
    "# select the checkpoint right after the last one processed by the inference run\n",
    "next_version = 0 if last_version_inference is None else last_version_inference + 1\n",
    "matching = [a for a in versions if int(a.version[1:]) == next_version]\n",
    "assert matching, f'checkpoint v{next_version} not found for run {wandb_run} ({len(versions)} versions available)'\n",
    "artifact = matching[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4db848c1-2bb5-432c-a732-1c6d0636e172",
   "metadata": {},
   "outputs": [],
   "source": [
    "version = int(artifact.version[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25fac577-146d-4e62-a3ea-f0baea79ef83",
   "metadata": {},
   "outputs": [],
   "source": [
    "version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ce9d2d3-aea3-4d5e-834a-c5caf85dd117",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = wandb.init(job_type='inference', config=config, id=f'inference-{wandb_run}', resume='allow')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffe392c9-36d2-4aaa-a1b3-a827e348c1ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# local directory for the checkpoint files; cleaned up after the model is loaded\n",
    "# TODO: use context manager\n",
    "tmp_f = tempfile.TemporaryDirectory()\n",
    "tmp = tmp_f.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "299db1bb-fbe6-4d79-a48f-89893f8ed809",
   "metadata": {},
   "outputs": [],
   "source": [
    "artifact = run.use_artifact(artifact)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71481bf-98aa-42cb-b7e2-545d13ae4309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# only download required files\n",
    "required_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
    "for filename in required_files:\n",
    "    artifact.get_path(filename).download(tmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f8ad8dd-da8f-40f9-b438-e43b779d637c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# we verify all the files are present\n",
    "downloaded = sorted(p.name for p in Path(tmp).glob('*'))\n",
    "missing = sorted(set(required_files) - set(downloaded))\n",
    "assert not missing, f'missing checkpoint files: {missing}'\n",
    "downloaded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b715c32-e757-4cb0-9912-ff90238b9f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BartTokenizer.from_pretrained(tmp)\n",
    "model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "562036ed-dc86-48af-90b1-9c18383b3552",
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove tmp once the tokenizer and model have been loaded from it\n",
    "tmp_f.cleanup()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1c04761-1016-47e9-925c-3a9ec6fec95a",
   "metadata": {},
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}