# References:
# https://www.tanishq.ai/blog/posts/2021-11-16-gradio-huggingface.html
"""Gradio demo that generates Chinese poems with a small character-level transformer.

At import time this script reads the vocabulary from ``meta.pkl`` and a fully
pickled ``Model`` from ``model_v4t.pkl`` (both relative to the working
directory), then builds a Gradio interface around :func:`generate`.
"""
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch import tensor

# NOTE: pickle.load can execute arbitrary code stored in the file; only load
# trusted vocabulary files.
with open("meta.pkl", "rb") as f:
    meta = pickle.load(f)
t2i = meta['t2i']  # character -> token id
i2t = meta['i2t']  # token id -> character


def encode(x):
    """Map each character of ``x`` to its token id (``KeyError`` if unknown)."""
    return [t2i[c] for c in x]


def decode(x):
    """Join the characters for the token ids in ``x`` into a string."""
    return "".join([i2t[i] for i in x])


batch_size = 128  # B, batch size (not used in this script)
block_size = 48  # T, context length; poems are short, so 48 is used
vocab_size = len(t2i.keys())
nn_emb_size = 64  # embedding width
n_head = 16
n_layers = 8
# device = "cuda"
device = "cpu"


def encode_pad(s):
    """Encode the first ``block_size`` characters of ``s``, left-padded with id 0.

    Characters that are not in the vocabulary are skipped, so arbitrary user
    input cannot raise ``KeyError``.

    Returns:
        An integer tensor of shape ``(1, block_size)``.
    """
    # Truncate *before* encoding (the old code encoded the untruncated string).
    known = [c for c in s[:block_size] if c in t2i]
    ids = encode(known)
    # Left-pad so the last real character sits at the last position.
    ids = [0] * (block_size - len(ids)) + ids
    return tensor(ids[:block_size])[None, ...]


class AttentionBlock(nn.Module):
    """Transformer block: causal multi-head self-attention, then a feed-forward MLP.

    Both sub-layers are wrapped in residual connections.
    """

    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, n_head=n_head):
        super().__init__()
        # Use the argument (previously the global nn_emb_size was stored,
        # silently ignoring a non-default nn_emb).
        self.nn_emb = nn_emb
        self.block_size = block_size
        self.n_head = n_head
        # Produces concatenated (q, k, v) inputs for the attention layer.
        self.emb_proj = nn.Linear(nn_emb, nn_emb * 3)
        # NOTE: ln_1 is not applied in forward(); it is kept so the module
        # structure stays identical to existing checkpoints.
        self.ln_1 = nn.LayerNorm(nn_emb)
        self.mult_head = nn.MultiheadAttention(
            nn_emb, n_head, dropout=0.2, batch_first=True
        )
        self.ln_2 = nn.LayerNorm(nn_emb)
        self.ff = nn.Sequential(
            nn.Linear(nn_emb, nn_emb * 4),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(nn_emb * 4, nn_emb),
            nn.GELU(),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, T, nn_emb); output has the same shape."""
        residual = x
        q, k, v = self.emb_proj(x).split(self.nn_emb, dim=2)  # each (B, T, nn_emb)
        # The causal mask must be (T, T): sized by the sequence length, not by
        # the embedding width as before.
        seq_len = x.size(1)
        causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(
            seq_len
        ).to(x.device)
        x, _ = self.mult_head(
            q,
            k,
            v,
            key_padding_mask=None,
            need_weights=False,
            attn_mask=causal_mask,
            average_attn_weights=True,
            is_causal=True,
        )  # (B, T, nn_emb)
        x = x + residual
        x = self.ff(self.ln_2(x)) + x
        return x


class Model(nn.Module):
    """Character-level transformer language model returning next-token logits."""

    def __init__(
        self,
        nn_emb=nn_emb_size,
        block_size=block_size,
        vocab_size=vocab_size,
        n_head=n_head,
        n_layers=n_layers,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.nn_emb = nn_emb
        self.n_head = n_head
        self.n_layers = n_layers
        self.tk_emb = nn.Embedding(vocab_size, nn_emb)
        self.pos_emb = nn.Embedding(block_size, nn_emb)
        self.ln = nn.LayerNorm(nn_emb)
        # One independent block per layer. ``[AttentionBlock(...)] * n_layers``
        # repeated the *same* module object, so all layers shared weights.
        self.attention_blocks = nn.ModuleList(
            [AttentionBlock(nn_emb, block_size, n_head) for _ in range(n_layers)]
        )
        self.ln_h = nn.Linear(nn_emb, self.vocab_size)

    def forward(self, inp, targ=None):
        """Compute logits (and optionally the loss) for a batch of token ids.

        Args:
            inp: integer tensor of shape (B, T), with T <= block_size.
            targ: optional integer tensor of shape (B, T) with target ids.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, vocab_size), and
            ``loss`` is the cross-entropy, or ``None`` when ``targ`` is None.

        Raises:
            ValueError: if T exceeds the positional-embedding size.
        """
        inp = inp.to(device)  # .to() is not in-place; keep the result
        seq_len = inp.size(1)
        if seq_len > self.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.block_size}"
            )
        tk = self.tk_emb(inp)  # (B, T, nn_emb)
        positions = torch.arange(seq_len, device=inp.device)
        pos = self.pos_emb(positions)  # (T, nn_emb)
        x = tk + pos  # (B, T, nn_emb)
        for blk in self.attention_blocks:
            x = blk(x)
        x = self.ln(x)  # (B, T, nn_emb)
        logits = self.ln_h(x)  # (B, T, vocab_size)

        if targ is None:
            loss = None
        else:
            targ = targ.to(device)
            loss = F.cross_entropy(
                logits.view(-1, logits.shape[-1]), targ.reshape(-1)
            )
        return logits, loss


# The checkpoint is a fully pickled Model (not a state_dict). torch >= 2.6
# defaults to weights_only=True, which rejects pickled classes, so opt out
# explicitly.
# NOTE: this executes pickle code from the file; only load trusted checkpoints.
m = torch.load("model_v4t.pkl", map_location=device, weights_only=False)
m.eval()

top_k = 20  # sample only among the top_k most likely next characters


def generate(s, num=60):
    """Stream a poem continuing the prompt ``s``.

    Yields the growing text after each accepted character. Generation stops
    once the text is longer than ``num`` characters and ends with '。', or
    after ``2 * num`` sampling steps. A sampled character that already occurs
    in the text is rejected (unless it is '。' or ','), to discourage
    repetition.
    """
    for _ in range(2 * num):
        # Inference only: no autograd graph needed. The no_grad scope closes
        # before the yield, so grad mode is not left disabled for the caller.
        with torch.no_grad():
            inp = encode_pad(s[-block_size:]).to(device)
            logits, _ = m(inp)
            # encode_pad left-pads, so the last position predicts the next char.
            logits = logits[:, -1, :]
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            prob = torch.softmax(logits, dim=-1)
            g = torch.multinomial(prob, num_samples=1)

        next_c = i2t[g[0].item()]
        if next_c in s and next_c not in ('。', ','):
            continue
        s = s + next_c
        yield s
        if len(s) > num and s[-1] == "。":
            break


inputs = [
    gr.Textbox(
        label="Input",
        info="Enter some Chinese text to start generate",
        lines=3,
        value="终南。",
    )
]
outputs = [
    gr.Textbox(
        label="Output",
        info="Generated Poem",
        lines=3,
        value="",
    )
]
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    title="Enter Chinese text to generate Chinese Poem.",
)

if __name__ == "__main__":
    demo.launch(share=True)