{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b7e4c1a",
   "metadata": {},
   "source": [
    "# Text-to-3D sampling with Shap-E\n",
    "\n",
    "This notebook samples `BATCH_SIZE` Shap-E latents conditioned on a text prompt, shows a panning GIF render of each one, and exports each latent as a mesh in both `.ply` and `.obj` format.\n",
    "\n",
    "All tunable settings live in the configuration cell. `SEED` makes the sampling step repeatable across runs (GPU kernels may still introduce small numerical differences)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964ccced",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from IPython.display import display\n",
    "\n",
    "from shap_e.diffusion.gaussian_diffusion import diffusion_from_config\n",
    "from shap_e.diffusion.sample import sample_latents\n",
    "from shap_e.models.download import load_config, load_model\n",
    "from shap_e.util.notebooks import (\n",
    "    create_pan_cameras,\n",
    "    decode_latent_images,\n",
    "    decode_latent_mesh,\n",
    "    gif_widget,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eed3a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: every value a reader might want to tune.\n",
    "SEED = 0  # seeds torch's RNG right before sampling so the latents are repeatable\n",
    "BATCH_SIZE = 4  # number of latents sampled for the prompt\n",
    "GUIDANCE_SCALE = 15.0  # passed to sample_latents as guidance_scale\n",
    "PROMPT = \"a shark\"\n",
    "\n",
    "RENDER_MODE = 'nerf'  # you can change this to 'stf'\n",
    "RENDER_SIZE = 64  # size of the renders; higher values take longer to render\n",
    "MESH_PREFIX = 'example_mesh'  # meshes are saved as f'{MESH_PREFIX}_{i}.ply' and '.obj'\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c9a51d7",
   "metadata": {},
   "source": [
    "## Load models\n",
    "\n",
    "`xm` (the transmitter) is used later to decode latents into renders and meshes; `model` (text300M) and `diffusion` drive the text-conditioned sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d922637",
   "metadata": {},
   "outputs": [],
   "source": [
    "xm = load_model('transmitter', device=device)\n",
    "model = load_model('text300M', device=device)\n",
    "diffusion = diffusion_from_config(load_config('diffusion'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e2f8b90",
   "metadata": {},
   "source": [
    "## Sample latents\n",
    "\n",
    "Draws `BATCH_SIZE` latents for the same prompt using the Karras sampler settings below (64 steps)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53d329d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed right before sampling so re-running this cell alone reproduces the same latents.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "latents = sample_latents(\n",
    "    batch_size=BATCH_SIZE,\n",
    "    model=model,\n",
    "    diffusion=diffusion,\n",
    "    guidance_scale=GUIDANCE_SCALE,\n",
    "    model_kwargs=dict(texts=[PROMPT] * BATCH_SIZE),\n",
    "    progress=True,\n",
    "    clip_denoised=True,\n",
    "    use_fp16=True,\n",
    "    use_karras=True,\n",
    "    karras_steps=64,\n",
    "    sigma_min=1e-3,\n",
    "    sigma_max=160,\n",
    "    s_churn=0,\n",
    ")\n",
    "assert len(latents) == BATCH_SIZE, 'expected one latent per sample'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a41d6c3",
   "metadata": {},
   "source": [
    "## Render previews\n",
    "\n",
    "Renders each latent from a ring of panning cameras and shows the frames as a GIF. `RENDER_MODE` can be `'nerf'` or `'stf'`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633da2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "cameras = create_pan_cameras(RENDER_SIZE, device)\n",
    "for latent in latents:\n",
    "    images = decode_latent_images(xm, latent, cameras, rendering_mode=RENDER_MODE)\n",
    "    display(gif_widget(images))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d08e2f4",
   "metadata": {},
   "source": [
    "## Export meshes\n",
    "\n",
    "Each latent is decoded once with `decode_latent_mesh(...).tri_mesh()` and written to the current working directory as `.ply` (binary file handle) and `.obj` (text file handle)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a4dce4",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, latent in enumerate(latents):\n",
    "    tri_mesh = decode_latent_mesh(xm, latent).tri_mesh()\n",
    "    with open(f'{MESH_PREFIX}_{i}.ply', 'wb') as f:\n",
    "        tri_mesh.write_ply(f)\n",
    "    with open(f'{MESH_PREFIX}_{i}.obj', 'w') as f:\n",
    "        tri_mesh.write_obj(f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}