# References:
# https://www.tanishq.ai/blog/posts/2021-11-16-gradio-huggingface.html

import pickle

import gradio as gr
import torch
import torch.nn.functional as F
from torch import nn, tensor

# Character-level vocabulary loaded from meta.pkl:
# t2i maps character -> index, i2t maps index -> character.
with open("meta.pkl", "rb") as f:
    meta = pickle.load(f)
t2i = meta['t2i']
i2t = meta['i2t']
encode = lambda x: [t2i[c] for c in x]            # string -> list of token ids
decode = lambda x: "".join([i2t[i] for i in x])   # list of token ids -> string
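# Round-trip sketch (hypothetical seed text; its characters must exist in t2i):
# ids = encode("终南。")
# assert decode(ids) == "终南。"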

batch_size = 128   # B, batch size (training-time setting; not used at inference)
block_size = 48    # T, context length; poems are short, so 48 is enough
vocab_size = len(t2i)
nn_emb_size = 64   # embedding dimension (nn_emb)
n_head = 16        # attention heads per block
n_layers = 8       # number of attention blocks

#device = "cuda"
device = "cpu"
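# Optional: select the device automatically instead of hard-coding it
# (a sketch; this script pins CPU, matching map_location further below):
# device = "cuda" if torch.cuda.is_available() else "cpu"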

def encode_pad(s):
    # Truncate to block_size characters, encode, then left-pad with token 0
    # so the model always sees an input of exactly block_size tokens.
    if len(s) >= block_size:
        sample = s[:block_size]
    else:
        sample = s
    sample = encode(sample)
    sample = [0] * (block_size - len(sample)) + sample
    inp = tensor(sample[:block_size])[None, ...]   # (1, T)
    return inp
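# Shape sketch (hypothetical seed text; its characters must be in t2i):
# encode_pad("终南。").shape  ->  torch.Size([1, 48]), i.e. (1, block_size), left-padded with 0s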

class AttentionBlock(nn.Module):
    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, n_head=n_head):
        super().__init__()
        self.nn_emb = nn_emb
        self.block_size = block_size
        self.n_head = n_head

        # A single linear layer produces q, k and v in one projection.
        self.emb_proj = nn.Linear(nn_emb, nn_emb * 3)
        self.ln_1 = nn.LayerNorm(nn_emb)  # defined but not used in forward
        self.mult_head = nn.MultiheadAttention(nn_emb, n_head, dropout=0.2, batch_first=True)
        self.ln_2 = nn.LayerNorm(nn_emb)
        self.ff = nn.Sequential(
            nn.Linear(nn_emb, nn_emb * 4), nn.GELU(), nn.Dropout(0.2),
            nn.Linear(nn_emb * 4, nn_emb), nn.GELU(), nn.Dropout(0.2),
        )

    def forward(self, x):  # x: (B, T, nn_emb)
        x1 = x
        x = self.emb_proj(x)                   # (B, T, nn_emb*3)
        q, k, v = x.split(self.nn_emb, dim=2)  # each (B, T, nn_emb)
        # Causal self-attention; the mask must be (T, T), sized by sequence length.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(q.size(1)).to(q.device)
        x, _ = self.mult_head(q, k, v, key_padding_mask=None, need_weights=False,
                              attn_mask=causal_mask, is_causal=True)  # (B, T, nn_emb)
        x = x + x1                              # residual connection
        x = self.ff(self.ln_2(x)) + x           # feed-forward block + residual
        return x
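
# Standalone smoke test for the block (a sketch; uncomment to run):
# blk = AttentionBlock()
# y = blk(torch.randn(2, block_size, nn_emb_size))
# y.shape  ->  torch.Size([2, 48, 64]); the block preserves (B, T, nn_emb)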
        
        
class Model(nn.Module):
    def __init__(self, nn_emb=nn_emb_size, block_size=block_size, vocab_size=vocab_size,
                 n_head=n_head, n_layers=n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.nn_emb = nn_emb
        self.n_head = n_head
        self.n_layers = n_layers

        self.tk_emb = nn.Embedding(vocab_size, nn_emb)    # token embedding
        self.pos_emb = nn.Embedding(block_size, nn_emb)   # learned positional embedding
        self.ln = nn.LayerNorm(nn_emb)
        # Build the blocks in a comprehension so each layer gets its own parameters;
        # multiplying a one-element list would reuse a single block n_layers times.
        self.attention_blocks = nn.ModuleList(
            [AttentionBlock(nn_emb, block_size, n_head) for _ in range(n_layers)]
        )
        self.ln_h = nn.Linear(nn_emb, self.vocab_size)    # language-model head

    def forward(self, inp, targ=None):  # inp: (B, T) token ids, targ: (B, T) or None
        inp = inp.to(device)                    # .to() returns a new tensor; assign it
        tk = self.tk_emb(inp)                   # (B, T, nn_emb)
        positions = torch.arange(self.block_size).to(device)
        pos = self.pos_emb(positions)           # (T, nn_emb)
        x = tk + pos                            # (B, T, nn_emb), broadcast over the batch
        for blk in self.attention_blocks:
            x = blk(x)
        x = self.ln(x)                          # (B, T, nn_emb)
        x = self.ln_h(x)                        # (B, T, vocab_size) logits
        if targ is None:
            loss = None
        else:
            targ = targ.to(device)
            loss = F.cross_entropy(x.view(-1, x.shape[-1]), targ.view(-1))
        return x, loss

# Load the fully pickled model on CPU.
# Note: PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses
# to unpickle a whole nn.Module; pass weights_only=False on those versions.
m = torch.load("model_v4t.pkl", map_location=torch.device("cpu"))
m.eval()
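# Optional sanity check (a sketch): report the size of the loaded model.
# print(sum(p.numel() for p in m.parameters()), "parameters")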

top_k = 20

def generate(s, num=60):
    # Grow the poem one character at a time and yield each partial string so
    # Gradio can stream the output. Loop up to 2*num steps because repeated
    # characters are skipped below.
    for _ in range(2 * num):
        inp = s[-block_size:]                  # keep the last block_size characters as context
        inp = encode_pad(inp).to(device)
        with torch.no_grad():                  # inference only, no gradients needed
            out, loss = m(inp)
        out = out[:, -1, :]                    # logits for the next character
        if top_k is not None:
            # Top-k sampling: mask out everything below the k-th largest logit.
            v, _ = torch.topk(out, min(top_k, out.size(-1)))
            out[out < v[:, [-1]]] = -float('Inf')
        prob = torch.softmax(out, dim=-1)
        g = torch.multinomial(prob, num_samples=1)
        next_c = i2t[g[0].item()]
        # Simple repetition filter: skip characters already used, except punctuation.
        if next_c in s and next_c != '。' and next_c != ',':
            continue
        s = s + next_c
        yield s

        if len(s) > num and s[-1] == "。":
            break
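# Consuming the generator directly, outside Gradio (a sketch; the seed is hypothetical):
# poem = ""
# for poem in generate("终南。", num=40):
#     pass
# print(poem)   # final generated text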

inputs = [gr.Textbox(label="Input",
                     info="Enter some Chinese text to start the generation",
                     lines=3,
                     value="终南。")]

outputs = [gr.Textbox(label="Output",
                      info="Generated poem",
                      lines=3,
                      value="")]

gr.Interface(fn=generate, inputs=inputs, outputs=outputs,
             title="Enter Chinese text to generate a Chinese poem").launch(share=True)