{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "1eccc83e-bc68-4082-a3cc-b055779b6ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# References:\n",
    "# https://www.tanishq.ai/blog/posts/2021-11-16-gradio-huggingface.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b74867e-7ec1-4cda-9d96-0f5cd9cd4810",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import gradio as gr\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn, tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7d6e9e70-83fe-4209-8f06-6542cf6ba11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary lookup tables (t2i / i2t) stored in meta.pkl.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load files you trust.\n",
    "with open(\"meta.pkl\", \"rb\") as f:\n",
    "    meta = pickle.load(f)\n",
    "t2i = meta['t2i']  # character -> integer id\n",
    "i2t = meta['i2t']  # integer id -> character\n",
    "\n",
    "\n",
    "def encode(text):\n",
    "    \"\"\"Map each character of `text` to its id (raises KeyError for unknown characters).\"\"\"\n",
    "    return [t2i[c] for c in text]\n",
    "\n",
    "\n",
    "def decode(ids):\n",
    "    \"\"\"Map a sequence of ids back to a string.\"\"\"\n",
    "    return \"\".join(i2t[i] for i in ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c4a0b480-6775-4d82-9395-9b5a455012ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128  # B, batch size (not used for inference in this notebook)\n",
    "block_size = 48  # T, context length; poems are short, so 48 characters suffice\n",
    "vocab_size = len(t2i)\n",
    "nn_emb_size = 64  # nn_emb, embedding width\n",
    "n_head = 16  # attention heads per block\n",
    "n_layers = 8  # number of attention blocks\n",
    "\n",
    "# device = \"cuda\"\n",
    "device = \"cpu\"  # must be named `device`: m.to(device) and generate() read it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0e4e72ce-5f61-4831-b7e8-703ed171936b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_pad(s):\n",
    "    \"\"\"Encode `s` as a (1, block_size) tensor of token ids, left-padded with 0.\n",
    "\n",
    "    Characters missing from the vocabulary are dropped so free-form user input\n",
    "    cannot raise KeyError. Only the last `block_size` ids are kept, so the most\n",
    "    recent context sits at the right edge of the window (previously the\n",
    "    truncation was computed and then discarded, and the first ids were kept).\n",
    "    NOTE(review): 0 is presumably the padding id used in training -- confirm it\n",
    "    is not also the id of a real character.\n",
    "    \"\"\"\n",
    "    ids = encode([c for c in s if c in t2i])[-block_size:]\n",
    "    ids = [0] * (block_size - len(ids)) + ids\n",
    "    return tensor(ids)[None, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a9bc886f-4ec8-458a-b847-c9996df57fa9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Model(\n",
       " (tk_emb): Embedding(7475, 64)\n",
       " (pos_emb): Embedding(48, 64)\n",
       " (ln): 
LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       " (attention_blocks): ModuleList(\n",
       " (0-7): 8 x AttentionBlock(\n",
       " (emb_proj): Linear(in_features=64, out_features=192, bias=True)\n",
       " (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       " (mult_head): MultiheadAttention(\n",
       " (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)\n",
       " )\n",
       " (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
       " (ff): Sequential(\n",
       " (0): Linear(in_features=64, out_features=256, bias=True)\n",
       " (1): GELU(approximate='none')\n",
       " (2): Dropout(p=0.2, inplace=False)\n",
       " (3): Linear(in_features=256, out_features=64, bias=True)\n",
       " (4): GELU(approximate='none')\n",
       " (5): Dropout(p=0.2, inplace=False)\n",
       " )\n",
       " )\n",
       " )\n",
       " (ln_h): Linear(in_features=64, out_features=7475, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class AttentionBlock(nn.Module):\n",
    "    \"\"\"Transformer block: causal multi-head self-attention followed by a feed-forward MLP.\n",
    "\n",
    "    Both sub-layers are wrapped in residual connections; `ln_2` normalises the\n",
    "    feed-forward input.\n",
    "\n",
    "    NOTE(review): `ln_1` is constructed but never applied. Applying it now would\n",
    "    change the outputs of the weights loaded from `model_v4.pkl` (presumably trained\n",
    "    without it), so it is deliberately left unused -- confirm before enabling it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, n_head=n_head):\n",
    "        super().__init__()\n",
    "        self.nn_emb = nn_emb  # previously the global nn_emb_size, ignoring the argument\n",
    "        self.block_size = block_size\n",
    "        self.n_head = n_head\n",
    "\n",
    "        # Projects each position to concatenated (q, k, v), each nn_emb wide.\n",
    "        self.emb_proj = nn.Linear(nn_emb, nn_emb * 3)\n",
    "        self.ln_1 = nn.LayerNorm(nn_emb)\n",
    "        self.mult_head = nn.MultiheadAttention(nn_emb, n_head, dropout=0.2, batch_first=True)\n",
    "        self.ln_2 = nn.LayerNorm(nn_emb)\n",
    "        self.ff = nn.Sequential(\n",
    "            nn.Linear(nn_emb, nn_emb * 4), nn.GELU(), nn.Dropout(0.2),\n",
    "            nn.Linear(nn_emb * 4, nn_emb), nn.GELU(), nn.Dropout(0.2),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):  # x: (B, T, nn_emb)\n",
    "        seq_len = x.size(1)\n",
    "        residual = x\n",
    "        q, k, v = self.emb_proj(x).split(self.nn_emb, dim=2)  # each (B, T, nn_emb)\n",
    "        # The causal mask is (T, T); it was previously sized by nn_emb instead of T.\n",
    "        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)\n",
    "        x, _ = self.mult_head(\n",
    "            q, k, v, key_padding_mask=None, need_weights=False, attn_mask=causal_mask,\n",
    "            average_attn_weights=True, is_causal=True,\n",
    "        )  # (B, T, nn_emb)\n",
    "        x = x + residual\n",
    "        x = self.ff(self.ln_2(x)) + x\n",
    "        return x\n",
    "\n",
    "\n",
    "class Model(nn.Module):\n",
    "    \"\"\"Character language model: token + position embeddings -> attention blocks -> logits.\"\"\"\n",
    "\n",
    "    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, vocab_size=vocab_size,\n",
    "                 n_head=n_head, n_layers=n_layers):\n",
    "        super().__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.block_size = block_size\n",
    "        self.nn_emb = nn_emb\n",
    "        self.n_head = n_head\n",
    "        self.n_layers = n_layers\n",
    "\n",
    "        self.tk_emb = nn.Embedding(vocab_size, nn_emb)\n",
    "        self.pos_emb = nn.Embedding(block_size, nn_emb)\n",
    "        self.ln = nn.LayerNorm(nn_emb)\n",
    "        # One independent block per layer: `[AttentionBlock(...)] * n_layers` repeated\n",
    "        # a single instance, so every layer shared the same weights.\n",
    "        self.attention_blocks = nn.ModuleList(\n",
    "            [AttentionBlock(nn_emb, block_size, n_head) for _ in range(n_layers)]\n",
    "        )\n",
    "        self.ln_h = nn.Linear(nn_emb, self.vocab_size)\n",
    "\n",
    "    def forward(self, inp, targ=None):\n",
    "        \"\"\"Compute logits for token ids `inp` of shape (B, T), with T <= block_size.\n",
    "\n",
    "        Returns `(logits, loss)`: logits are (B, T, vocab_size); loss is the\n",
    "        cross-entropy against `targ` (B, T) when given, otherwise None.\n",
    "        \"\"\"\n",
    "        # Tensor.to() is not in-place, so its result must be reassigned.\n",
    "        dev = self.tk_emb.weight.device\n",
    "        inp = inp.to(dev)\n",
    "        tk = self.tk_emb(inp)  # (B, T, nn_emb)\n",
    "        pos = self.pos_emb(torch.arange(inp.size(1), device=dev))  # (T, nn_emb)\n",
    "        x = tk + pos  # (B, T, nn_emb)\n",
    "        for blk in self.attention_blocks:\n",
    "            x = blk(x)\n",
    "        x = self.ln(x)  # (B, T, nn_emb)\n",
    "        logits = self.ln_h(x)  # (B, T, vocab_size)\n",
    "        if targ is None:\n",
    "            return logits, None\n",
    "        targ = targ.to(dev)\n",
    "        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targ.reshape(-1))\n",
    "        return logits, loss\n",
    "\n",
    "\n",
    "m = Model()\n",
    "m.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "95545bf7-51fa-45a8-b34d-0231aa95e300",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: pickle.load can execute arbitrary code -- only load model files you trust.\n",
    "with open(\"model_v4.pkl\", \"rb\") as f:\n",
    "    m = pickle.load(f)\n",
    "m = m.to(device)\n",
    "# Inference only: switch off dropout (the pickle keeps whatever train/eval mode it was saved in).\n",
    "m.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c2393e78-a1c6-4671-9170-4ea33cdb50d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "top_k = 20  # sample only among the top_k most likely characters (None disables the filter)\n",
    "\n",
    "\n",
    "def generate(s, num=60):\n",
    "    \"\"\"Extend the prompt `s` with characters sampled from the model `m`.\n",
    "\n",
    "    Runs at most 2 * num sampling steps. A sampled character already present in\n",
    "    the text is discarded (except '。' and ','), and sampling stops early once the\n",
    "    text is longer than `num` characters and ends with '。'. Returns the prompt\n",
    "    followed by the generated characters.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():  # inference only: skip building the autograd graph\n",
    "        for _ in range(2 * num):\n",
    "            inp = encode_pad(s[-block_size:]).to(device)\n",
    "            logits = m(inp)[0][:, -1, :]  # (1, vocab_size): scores for the next character\n",
    "            if top_k is not None:\n",
    "                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
    "                logits[logits < v[:, [-1]]] = -float('Inf')\n",
    "            prob = torch.softmax(logits, dim=-1)\n",
    "            g = torch.multinomial(prob, num_samples=1)\n",
    "            next_c = i2t[g[0].item()]\n",
    "            # Avoid repeated characters, but let punctuation recur.\n",
    "            if next_c in s and next_c != '。' and next_c != ',':\n",
    "                continue\n",
    "            s = s + next_c\n",
    "            if len(s) > num and s[-1] == \"。\":\n",
    "                break\n",
    "    return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "170b95ca-74b9-4360-84cc-6a8dfa3f8c42",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'终南。若问黄云一路在,更有东城上去时。不须为别故园庐,独坐江山半夜凉。此地无馀春树晚,今朝日暮向来迟。西北天津长望后,三湘月下烟中。'"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generate('终南。')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "edca19ab-087b-4368-84d0-8eee7388c200",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7867\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "
"
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Gradio UI: a single text box in, the generated poem out.\n",
    "inputs = [\n",
    "    gr.Textbox(\n",
    "        label=\"Input\",\n",
    "        info=\"Enter some Chinese text to start generate\",\n",
    "        lines=3,\n",
    "        value=\"终南。\",\n",
    "    )\n",
    "]\n",
    "outputs = [\n",
    "    gr.Textbox(\n",
    "        label=\"Output\",\n",
    "        info=\"Generated Poem\",\n",
    "        lines=3,\n",
    "        value=\"\",\n",
    "    )\n",
    "]\n",
    "demo = gr.Interface(\n",
    "    fn=generate,\n",
    "    inputs=inputs,\n",
    "    outputs=outputs,\n",
    "    title=\"Enter Chinese text to generate Chinese Poem.\",\n",
    ")\n",
    "demo.launch(share=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6112eaea-16d6-4d43-8b95-3999c605643b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}