{"guide": {"name": "create-your-own-friends-with-a-gan", "category": "other-tutorials", "pretty_category": "Other Tutorials", "guide_index": null, "absolute_index": 59, "pretty_name": "Create Your Own Friends With A Gan", "content": "# Create Your Own Friends with a GAN\n\n\n\n\n\n\n## Introduction\n\nIt seems that cryptocurrencies, [NFTs](https://www.nytimes.com/interactive/2022/03/18/technology/nft-guide.html), and the web3 movement are all the rage these days! Digital assets are being listed on marketplaces for astounding amounts of money, and just about every celebrity is debuting their own NFT collection. While your crypto assets [may be taxable, such as in Canada](https://www.canada.ca/en/revenue-agency/programs/about-canada-revenue-agency-cra/compliance/digital-currency/cryptocurrency-guide.html), today we'll explore some fun and tax-free ways to generate your own assortment of procedurally generated [CryptoPunks](https://www.larvalabs.com/cryptopunks).\n\nGenerative Adversarial Networks, often known just as _GANs_, are a specific class of deep-learning models that are designed to learn from an input dataset to create (_generate!_) new material that is convincingly similar to elements of the original training set. Famously, the website [thispersondoesnotexist.com](https://thispersondoesnotexist.com/) went viral with lifelike, yet synthetic, images of people generated with a model called StyleGAN2. GANs have gained traction in the machine learning world, and are now being used to generate all sorts of images, text, and even [music](https://salu133445.github.io/musegan/)!\n\nToday we'll briefly look at the high-level intuition behind GANs, and then we'll build a small demo around a pre-trained GAN to see what all the fuss is about. Here's a [peek](https://nimaboscarino-cryptopunks.hf.space) at what we're going to be putting together.\n\n### Prerequisites\n\nMake sure you have the `gradio` Python package already [installed](/getting_started). 
To use the pretrained model, also install `torch` and `torchvision`.\n\n## GANs: a very brief introduction\n\nOriginally proposed in [Goodfellow et al. 2014](https://arxiv.org/abs/1406.2661), GANs are made up of neural networks which compete with the intention of outsmarting each other. One network, known as the _generator_, is responsible for generating images. The other network, the _discriminator_, receives an image at a time from the generator along with a **real** image from the training data set. The discriminator then has to guess: which image is the fake?\n\nThe generator is constantly training to create images which are trickier for the discriminator to identify, while the discriminator raises the bar for the generator every time it correctly detects a fake. As the networks engage in this competitive (_adversarial!_) relationship, the images that get generated improve to the point where they become indistinguishable to human eyes!\n\nFor a more in-depth look at GANs, you can take a look at [this excellent post on Analytics Vidhya](https://www.analyticsvidhya.com/blog/2021/06/a-detailed-explanation-of-gan-with-implementation-using-tensorflow-and-keras/) or this [PyTorch tutorial](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html). For now, though, we'll dive into a demo!\n\n## Step 1 \u2014 Create the Generator model\n\nTo generate new images with a GAN, you only need the generator model. 
There are many different architectures that the generator could use, but for this demo we'll use a pretrained GAN generator model with the following architecture:\n\n```python\nfrom torch import nn\n\nclass Generator(nn.Module):\n    # Refer to the link below for explanations about nc, nz, and ngf\n    # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs\n    def __init__(self, nc=4, nz=100, ngf=64):\n        super(Generator, self).__init__()\n        self.network = nn.Sequential(\n            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),\n            nn.BatchNorm2d(ngf * 4),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf * 2),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),\n            nn.BatchNorm2d(ngf),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),\n            nn.Tanh(),\n        )\n\n    def forward(self, input):\n        output = self.network(input)\n        return output\n```\n\nWe're taking the generator from [this repo by @teddykoker](https://github.com/teddykoker/cryptopunks-gan/blob/main/train.py#L90), where you can also see the original discriminator model structure.\n\nAfter instantiating the model, we'll load in the weights from the Hugging Face Hub, stored at [nateraw/cryptopunks-gan](https://huggingface.co/nateraw/cryptopunks-gan):\n\n```python\nfrom huggingface_hub import hf_hub_download\nimport torch\n\nmodel = Generator()\nweights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')\nmodel.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available\n```\n\n## Step 2 \u2014 Defining a `predict` function\n\nThe `predict` function is the key to making Gradio work! Whatever inputs we choose through the Gradio interface will get passed through our `predict` function, which should operate on the inputs and generate outputs that we can display with Gradio output components. 
For GANs it's common to pass random noise into our model as the input, so we'll generate a tensor of random numbers and pass that through the model. We can then use `torchvision`'s `save_image` function to save the output of the model as a `png` file, and return the file name:\n\n```python\nfrom torchvision.utils import save_image\n\ndef predict(seed):\n    num_punks = 4\n    torch.manual_seed(seed)\n    z = torch.randn(num_punks, 100, 1, 1)\n    punks = model(z)\n    save_image(punks, \"punks.png\", normalize=True)\n    return 'punks.png'\n```\n\nWe're giving our `predict` function a `seed` parameter, so that we can fix the random tensor generation with a seed. We'll then be able to reproduce punks if we want to see them again by passing in the same seed.\n\n_Note!_ Our model needs an input tensor of dimensions (BatchSize)x100x1x1, so generating a single image takes a 1x100x1x1 tensor. In this demo we'll start by generating 4 punks at a time.\n\n## Step 3 \u2014 Creating a Gradio interface\n\nAt this point you can already run the code you have by calling `predict` with a seed, for example `predict(42)`, and you'll find your freshly generated punks in your file system at `./punks.png`. To make a truly interactive demo, though, we'll build out a simple interface with Gradio. Our goals here are to:\n\n- Set a slider input so users can choose the \"seed\" value\n- Use an image component for our output to showcase the generated punks\n- Use our `predict()` to take the seed and generate the images\n\nWith `gr.Interface()`, we can define all of that with a single function call:\n\n```python\nimport gradio as gr\n\ngr.Interface(\n    predict,\n    inputs=[\n        gr.Slider(0, 1000, label='Seed', value=42),\n    ],\n    outputs=\"image\",\n).launch()\n```\n\n\n## Step 4 \u2014 Even more punks!\n\nGenerating 4 punks at a time is a good start, but maybe we'd like to control how many we want to make each time. 
Adding more inputs to our Gradio interface is as simple as adding another item to the `inputs` list that we pass to `gr.Interface`:\n\n```python\ngr.Interface(\n    predict,\n    inputs=[\n        gr.Slider(0, 1000, label='Seed', value=42),\n        gr.Slider(4, 64, label='Number of Punks', step=1, value=10), # Adding another slider!\n    ],\n    outputs=\"image\",\n).launch()\n```\n\nThe new input will be passed to our `predict()` function, so we have to make some changes to that function to accept a new parameter:\n\n```python\ndef predict(seed, num_punks):\n    torch.manual_seed(seed)\n    z = torch.randn(num_punks, 100, 1, 1)\n    punks = model(z)\n    save_image(punks, \"punks.png\", normalize=True)\n    return 'punks.png'\n```\n\nWhen you relaunch your interface, you should see a second slider that'll let you control the number of punks!\n\n## Step 5 \u2014 Polishing it up\n\nYour Gradio app is pretty much good to go, but you can add a few extra things to really make it ready for the spotlight \u2728\n\nWe can add some examples that users can easily try out by adding this to the `gr.Interface`:\n\n```python\ngr.Interface(\n    # ...\n    # keep everything as it is, and then add\n    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],\n    cache_examples=True, # cache_examples is optional\n).launch()\n```\n\nThe `examples` parameter takes a list of lists, where the items in each sublist follow the same order as the `inputs` we've listed. So in our case, `[seed, num_punks]`. Give it a try!\n\nYou can also try adding a `title`, `description`, and `article` to the `gr.Interface`. 
Each of those parameters accepts a string, so try it out and see what happens \ud83d\udc40 `article` will also accept HTML, as [explored in a previous guide](/guides/key-features/#descriptive-content)!\n\nWhen you're all done, you may end up with something like [this](https://nimaboscarino-cryptopunks.hf.space).\n\nFor reference, here is our full code:\n\n```python\nimport torch\nfrom torch import nn\nfrom huggingface_hub import hf_hub_download\nfrom torchvision.utils import save_image\nimport gradio as gr\n\nclass Generator(nn.Module):\n    # Refer to the link below for explanations about nc, nz, and ngf\n    # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html#inputs\n    def __init__(self, nc=4, nz=100, ngf=64):\n        super(Generator, self).__init__()\n        self.network = nn.Sequential(\n            nn.ConvTranspose2d(nz, ngf * 4, 3, 1, 0, bias=False),\n            nn.BatchNorm2d(ngf * 4),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(ngf * 4, ngf * 2, 3, 2, 1, bias=False),\n            nn.BatchNorm2d(ngf * 2),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 0, bias=False),\n            nn.BatchNorm2d(ngf),\n            nn.ReLU(True),\n            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),\n            nn.Tanh(),\n        )\n\n    def forward(self, input):\n        output = self.network(input)\n        return output\n\nmodel = Generator()\nweights_path = hf_hub_download('nateraw/cryptopunks-gan', 'generator.pth')\nmodel.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))) # Use 'cuda' if you have a GPU available\n\ndef predict(seed, num_punks):\n    torch.manual_seed(seed)\n    z = torch.randn(num_punks, 100, 1, 1)\n    punks = model(z)\n    save_image(punks, \"punks.png\", normalize=True)\n    return 'punks.png'\n\ngr.Interface(\n    predict,\n    inputs=[\n        gr.Slider(0, 1000, label='Seed', value=42),\n        gr.Slider(4, 64, label='Number of Punks', step=1, value=10),\n    ],\n    outputs=\"image\",\n    examples=[[123, 15], [42, 29], [456, 8], [1337, 35]],\n    cache_examples=True, # cache_examples is optional\n).launch()\n```\n\n---\n\nCongratulations! 
You've built out your very own GAN-powered CryptoPunks generator, with a fancy Gradio interface that makes it easy for anyone to use. Now you can [scour the Hub for more GANs](https://huggingface.co/models?other=gan) (or train your own) and continue making even more awesome demos \ud83e\udd17\n", "tags": ["GAN", "IMAGE", "HUB"], "spaces": ["https://huggingface.co/spaces/NimaBoscarino/cryptopunks", "https://huggingface.co/spaces/nateraw/cryptopunks-generator"], "url": "/guides/create-your-own-friends-with-a-gan/", "contributor": "Nima Boscarino and Nate Raw"}}