# References:
import numpy as np
import pandas as pd
import gradio as gr
import torch
from torch import nn
import pickle
from torch import tensor
import torch.nn.functional as F
import pandas as pd
# Load the character vocabulary saved at training time.
# NOTE(review): pickle.load can execute arbitrary code; only load a meta.pkl
# file you trust.
with open("meta.pkl", "rb") as f:
    meta = pickle.load(f)
t2i = meta['t2i']  # character -> token id (see encode)
i2t = meta['i2t']  # token id -> character (see decode)


def encode(x):
    """Return the list of token ids for the characters of string ``x``.

    A character missing from ``t2i`` raises a lookup error (``KeyError`` for
    a plain dict).
    """
    return [t2i[c] for c in x]


def decode(x):
    """Return the string spelled by the sequence of token ids ``x``."""
    return "".join([i2t[i] for i in x])


batch_size = 128  # B, batch size
block_size = 48  # T, context length; poems are short, so 48 is enough
vocab_size = len(t2i)
nn_emb_size = 64  # embedding width (nn_emb)
n_head = 16
n_layers = 8
# device = "cuda"
device = "cpu"
def encode_pad(s):
    """Encode ``s`` into a ``(1, block_size)`` tensor of token ids.

    Only the first ``block_size`` characters of ``s`` are kept. Shorter inputs
    are left-padded with token id 0, so the real characters occupy the end of
    the window (generate() reads the model's prediction at the last position).

    Args:
        s: Input text; every kept character must be present in ``t2i``.

    Returns:
        An integer tensor of shape ``(1, block_size)``.
    """
    # Truncate before encoding: characters that would be discarded anyway are
    # never looked up (and so cannot raise for being out of vocabulary).
    ids = encode(s[:block_size])
    ids = [0] * (block_size - len(ids)) + ids
    return tensor(ids)[None, ...]
class AttentionBlock(nn.Module):
    """One transformer layer: causal multi-head attention followed by an MLP.

    The input is projected by ``emb_proj`` into a concatenated (q, k, v)
    triple that is passed to ``nn.MultiheadAttention``, which applies its
    own input projection on top. Both the attention and the feed-forward
    sub-layers are wrapped in residual connections.

    Args:
        nn_emb: Embedding width of the input and output.
        block_size: Maximum context length (stored on the instance).
        n_head: Number of attention heads; must divide ``nn_emb``.
    """

    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, n_head=n_head):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        # Use the constructor argument (not the module-level default) so the
        # q/k/v split in forward() matches the projection width below.
        self.nn_emb = nn_emb
        self.block_size = block_size
        self.n_head = n_head
        self.emb_proj = nn.Linear(nn_emb, nn_emb * 3)
        # NOTE(review): ln_1 is never applied in forward(). It is kept so that
        # existing checkpoints load unchanged; applying it would alter the
        # outputs of an already-trained model.
        self.ln_1 = nn.LayerNorm(nn_emb)
        self.mult_head = nn.MultiheadAttention(nn_emb, n_head, dropout=0.2, batch_first=True)
        self.ln_2 = nn.LayerNorm(nn_emb)
        self.ff = nn.Sequential(
            nn.Linear(nn_emb, nn_emb * 4),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(nn_emb * 4, nn_emb),
            nn.GELU(),
            nn.Dropout(0.2),
        )

    def forward(self, x):  # (B, T, nn_emb)
        """Apply the block to ``x`` of shape (B, T, nn_emb); output has the same shape."""
        residual = x
        qkv = self.emb_proj(x)  # (B, T, nn_emb*3)
        q, k, v = qkv.split(self.nn_emb, dim=2)
        # The causal mask must be (T, T), i.e. sized by the sequence length
        # rather than the embedding width, and on the same device as the input.
        causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(x.size(1), device=x.device)
        x, _ = self.mult_head(
            q, k, v,
            key_padding_mask=None,
            need_weights=False,
            attn_mask=causal_mask,
            average_attn_weights=True,
            is_causal=True,
        )  # (B, T, nn_emb)
        x = x + residual
        x = self.ff(self.ln_2(x)) + x
        return x
class Model(nn.Module):
    """Character-level causal transformer language model.

    Token embeddings plus learned position embeddings are passed through
    ``n_layers`` ``AttentionBlock`` layers, layer-normalised, and projected
    to per-position vocabulary logits.
    """

    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, vocab_size=vocab_size, n_head=n_head, n_layers=n_layers):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.nn_emb = nn_emb
        self.n_head = n_head
        self.n_layers = n_layers
        self.tk_emb = nn.Embedding(vocab_size, nn_emb)
        self.pos_emb = nn.Embedding(block_size, nn_emb)
        self.ln = nn.LayerNorm(nn_emb)
        # Build a separate block per layer: ``[block] * n_layers`` repeats one
        # instance, which silently ties every layer to the same weights.
        self.attention_blocks = nn.ModuleList(
            [AttentionBlock(nn_emb, block_size, n_head) for _ in range(n_layers)]
        )
        self.ln_h = nn.Linear(nn_emb, self.vocab_size)

    def forward(self, inp, targ=None):  # inp is (B, T), targ is (B, T)
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            inp: Token ids of shape (B, T) with T <= block_size.
            targ: Optional target ids of shape (B, T).

        Returns:
            ``(logits, loss)``: ``logits`` has shape (B, T, vocab_size);
            ``loss`` is the cross-entropy over all positions, or ``None``
            when ``targ`` is ``None``.
        """
        tk = self.tk_emb(inp)  # (B, T, nn_emb)
        # Size positions by the actual sequence length (so T < block_size also
        # works) and create them on the input's device.
        positions = torch.arange(inp.size(1), device=inp.device)
        pos = self.pos_emb(positions)  # (T, nn_emb)
        x = tk + pos  # (B, T, nn_emb)
        for blk in self.attention_blocks:
            x = blk(x)
        x = self.ln(x)  # (B, T, nn_emb)
        logits = self.ln_h(x)  # (B, T, vocab_size)
        if targ is None:
            return logits, None
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targ.view(-1))
        return logits, loss
# NOTE(review): ``m``, the trained model used by generate(), is not created in
# the visible code; confirm it is defined at module level before the UI runs.
top_k = 20  # sample only among the top_k most likely next characters


def generate(s, num=60):
    """Stream a poem that continues the prompt ``s``, one character at a time.

    Makes up to ``2 * num`` top-k sampling attempts. A sampled character
    that already occurs in the text is rejected (except '。' and ','), which
    discourages repetition. Generation stops early once the text is longer
    than ``num`` characters and ends with '。'.

    Args:
        s: Prompt text; characters in its last ``block_size`` positions must
            be in ``t2i``.
        num: Target length; also bounds the number of sampling attempts.

    Yields:
        The full text so far, after each accepted character.
    """
    for _attempt in range(2 * num):
        inp = encode_pad(s[-block_size:]).to(device)
        # Inference only: skip autograd. no_grad is scoped to the model call
        # so grad mode is not left disabled while the generator is suspended.
        with torch.no_grad():
            logits, _loss = m(inp)
        logits = logits[:, -1, :]  # prediction for the next position
        if top_k is not None:
            top_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < top_vals[:, [-1]]] = -float('Inf')
        prob = torch.softmax(logits, dim=-1)
        g = torch.multinomial(prob, num_samples=1)
        next_c = i2t[g[0].item()]
        if next_c in s and next_c != '。' and next_c != ',':
            continue  # reject repeated non-punctuation characters
        s = s + next_c
        yield s
        if len(s) > num and s[-1] == "。":
            return
# Gradio UI. generate() is a generator, so Gradio streams each yielded value
# into the output text box.
inputs = [
    gr.Textbox(
        label="Input",
        info="Enter some Chinese text to start generate",
    )
]
outputs = [
    gr.Textbox(
        info="Generated Poem",
        value="",
    )
]
gr.Interface(fn=generate, inputs=inputs, outputs=outputs, title="Enter Chinese text to generate Chinese Poem.").launch(share=True)