# References:
# https://www.tanishq.ai/blog/posts/2021-11-16-gradio-huggingface.html
import pickle

import gradio as gr
import torch
import torch.nn.functional as F
from torch import nn, tensor
# Load the tokenizer metadata (character <-> id maps) built at training time.
with open("meta.pkl", "rb") as f:
    meta = pickle.load(f)
t2i = meta['t2i']
i2t = meta['i2t']
encode = lambda x: [t2i[c] for c in x]
decode = lambda x: "".join([i2t[i] for i in x])
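# Round-trip sanity check for the character-level codec (a sketch; assumes the
# example characters appear in the training vocabulary):
#   assert decode(encode("明月")) == "明月"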
batch_size = 128  # B, batch size (training-time value; unused at inference)
block_size = 48   # T, context length; poems are short, so 48 is enough
vocab_size = len(t2i)
nn_emb_size = 64  # embedding dimension (nn_emb)
n_head = 16
n_layers = 8
#device = "cuda"
device = "cpu"
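# A common alternative is to auto-detect the device; CPU is forced above on
# the assumption that this demo runs on CPU-only hosting:
#   device = "cuda" if torch.cuda.is_available() else "cpu"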
def encode_pad(s):
    # Truncate to block_size, encode, then left-pad with token id 0.
    if len(s) >= block_size:
        sample = s[:block_size]
    else:
        sample = s
    sample = encode(sample)  # encode the truncated string, not the original
    sample = [0] * (block_size - len(sample)) + sample
    inp = tensor(sample[:block_size])[None, ...]
    return inp
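# Example: with block_size = 48, a 3-character prompt is encoded and then
# left-padded with 45 zero tokens, giving a (1, 48) input tensor:
#   encode_pad("春江花").shape  ->  torch.Size([1, 48])   (assumes chars in vocab)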
class AttentionBlock(nn.Module):
    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, n_head=n_head):
        super().__init__()
        self.nn_emb = nn_emb  # use the argument, not the global nn_emb_size
        self.block_size = block_size
        self.n_head = n_head
        self.emb_proj = nn.Linear(nn_emb, nn_emb * 3)  # joint Q,K,V projection
        # NOTE: ln_1 is defined but never applied in forward(); kept so the
        # module layout matches the trained checkpoint.
        self.ln_1 = nn.LayerNorm(nn_emb)
        self.mult_head = nn.MultiheadAttention(nn_emb, n_head, dropout=0.2, batch_first=True)
        self.ln_2 = nn.LayerNorm(nn_emb)
        self.ff = nn.Sequential(
            nn.Linear(nn_emb, nn_emb * 4), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(nn_emb * 4, nn_emb), nn.GELU(), nn.Dropout(0.2),
        )

    def forward(self, x):  # (B, T, nn_emb)
        x1 = x
        x = self.emb_proj(x)  # (B, T, nn_emb*3)
        q, k, v = x.split(self.nn_emb, dim=2)
        # The causal mask must match the sequence length T, not the embedding
        # size: the original passed self.nn_emb (64), which does not match T=48.
        mask = torch.nn.Transformer.generate_square_subsequent_mask(q.shape[1]).to(q.device)
        x, _ = self.mult_head(q, k, v, key_padding_mask=None, need_weights=False,
                              attn_mask=mask, is_causal=True)  # (B, T, nn_emb)
        x = x + x1
        x = self.ff(self.ln_2(x)) + x
        return x
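# Design note: this block differs from the standard GPT-2 pre-norm layout
# (x = x + attn(ln(x))): the attention output is added to the raw input
# without ln_1, and only the feed-forward branch is layer-normalized.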
class Model(nn.Module):
    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, vocab_size=vocab_size,
                 n_head=n_head, n_layers=n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.nn_emb = nn_emb
        self.n_head = n_head
        self.n_layers = n_layers
        self.tk_emb = nn.Embedding(vocab_size, nn_emb)   # token embedding
        self.pos_emb = nn.Embedding(block_size, nn_emb)  # learned positional embedding
        self.ln = nn.LayerNorm(nn_emb)
        # A list comprehension is required here: `[AttentionBlock(...)] * n_layers`
        # would register the *same* block n_layers times (shared weights).
        self.attention_blocks = nn.ModuleList(
            [AttentionBlock(nn_emb, block_size, n_head) for _ in range(n_layers)]
        )
        self.ln_h = nn.Linear(nn_emb, self.vocab_size)  # LM head
    def forward(self, inp, targ=None):  # inp is (B, T), targ is (B, T)
        inp = inp.to(device)  # .to() is not in-place; the result must be reassigned
        tk = self.tk_emb(inp)  # (B, T, nn_emb)
        positions = torch.arange(inp.shape[1], device=inp.device)
        pos = self.pos_emb(positions)  # (T, nn_emb)
        x = tk + pos  # (B, T, nn_emb), positional embedding broadcast over B
        for blk in self.attention_blocks:
            x = blk(x)
        x = self.ln(x)    # (B, T, nn_emb)
        x = self.ln_h(x)  # (B, T, vocab_size)
        if targ is None:
            loss = None
        else:
            targ = targ.to(device)
            loss = F.cross_entropy(x.view(-1, x.shape[-1]), targ.view(-1))
        return x, loss
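# Quick shape smoke test (commented out; assumes meta.pkl has been loaded so
# vocab_size is defined):
#   _m = Model()
#   _logits, _loss = _m(torch.zeros(2, block_size, dtype=torch.long))
#   assert _logits.shape == (2, block_size, vocab_size)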
# Load the fully pickled model (architecture + weights) onto the CPU.
# NOTE: on PyTorch >= 2.6, torch.load defaults to weights_only=True and needs
# an explicit weights_only=False to unpickle a whole nn.Module.
m = torch.load("model_v4t.pkl", map_location=torch.device("cpu"))
m.eval()  # disable dropout for inference
top_k = 20  # sample from only the 20 most likely next characters
@torch.no_grad()
def generate(s, num=60):
    # Generator: yields the growing string so Gradio can stream partial output.
    for i in range(num + num):  # allow up to 2*num sampling steps
        inp = s[-block_size:]
        inp = encode_pad(inp).to(device)
        out, loss = m(inp)
        out = out[:, -1, :]  # logits for the next character
        if top_k is not None:
            v, _ = torch.topk(out, min(top_k, out.size(-1)))
            out[out < v[:, [-1]]] = -float('Inf')  # mask everything below the top k
        prob = torch.softmax(out, dim=-1)
        g = torch.multinomial(prob, num_samples=1)
        next_c = i2t[g[0].item()]
        # Skip characters already used, except the punctuation marks 。 and ，.
        if next_c in s and next_c != '。' and next_c != '，':
            continue
        s = s + next_c
        yield s
        if len(s) > num and s[-1] == "。":
            break
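# Console usage sketch (assumes the prompt characters are in the vocabulary):
#   for partial in generate("绿", num=40):
#       print(partial)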
inputs = [gr.Textbox(label="Input",
                     info="Enter some Chinese text to start generating",
                     lines=3,
                     # The original default value was mojibake in the source;
                     # "绿叶。" is an assumed reconstruction of the short prompt.
                     value="绿叶。")]
outputs = [gr.Textbox(label="Output",
                      info="Generated poem",
                      lines=3,
                      value="")]
gr.Interface(fn=generate, inputs=inputs, outputs=outputs,
             title="Enter Chinese text to generate a Chinese poem.").launch(share=True)
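# Because `generate` is a Python generator, recent Gradio versions stream each
# yielded string into the output textbox as the poem grows.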