#! /bin/sh
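# Build KoichiYasuoka/Swallow-7b-plus-upos: extend Swallow-7b-plus-hf with single-kanji tokens,
# then fine-tune the extended model for UPOS token classification on ja_gsd_modern.conllu.
# Fetch the training treebank and the Japanese core kanji list unless they are already present.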
test -f ja_gsd_modern.conllu || curl -LO https://github.com/KoichiYasuoka/SuPar-UniDic/raw/main/suparunidic/suparmodels/ja_gsd_modern.conllu
test -f JapaneseCoreKanji.txt || curl -LO https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt

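# Step 1: build exSwallow-7b-plus-hf, a copy of Swallow-7b-plus-hf whose tokenizer vocabulary
# and embeddings are extended so that each kanji in the training data becomes a single token.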
if [ ! -d exSwallow-7b-plus-hf ]
then TMPA=./maker$$a.py
     cat << 'EOF' > $TMPA
#! /usr/bin/python3
src="tokyotech-llm/Swallow-7b-plus-hf"
tgt="exSwallow-7b-plus-hf"
import json,torch,unicodedata
from transformers import LlamaTokenizerFast,LlamaForCausalLM
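# Collect kanji: the core-kanji list plus every CJK character that appears in ja_gsd_modern.conllu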
with open("JapaneseCoreKanji.txt","r",encoding="utf-8") as r:
  cjk=[chr(int(t,16)) for t in r.read().strip().split("\n") if not t.startswith("#")]
with open("ja_gsd_modern.conllu","r",encoding="utf-8") as r:
  for s in r:
    t=s.split("\t")
    if len(t)==10:
      for c in t[1]:
        if unicodedata.name(c).startswith("CJK "):
          cjk.append(c)
cjk=list(set(cjk))
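# Keep the kanji that the base tokenizer cannot encode as a single piece and
# register each of them as a new entry in the vocabulary of tokenizer.json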
tkz=LlamaTokenizerFast.from_pretrained(src,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
c={i:j[2:] for i,j in zip(cjk,tkz(cjk)["input_ids"]) if len(j)>3}
d=json.loads(tkz.backend_tokenizer.to_str())
for i,j in enumerate(c,len(tkz)):
  d["model"]["vocab"][j]=i
tkz.backend_tokenizer.from_str(json.dumps(d)).save("tokenizer.json")
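# Resize the embeddings and initialize each new token (input and output sides)
# as the sum of the embeddings of its original pieces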
mdl=LlamaForCausalLM.from_pretrained(src)
tkz=LlamaTokenizerFast(tokenizer_file="tokenizer.json",model_max_length=mdl.config.max_position_embeddings,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
e=mdl.resize_token_embeddings(len(tkz))
f=mdl.get_output_embeddings()
with torch.no_grad():
  for k,v in c.items():
    e.weight[d["model"]["vocab"][k],:]=e.weight[v,:].sum(0)
    f.weight[d["model"]["vocab"][k],:]=f.weight[v,:].sum(0)
mdl.set_input_embeddings(e)
mdl.set_output_embeddings(f)
mdl.save_pretrained(tgt)
tkz.save_pretrained(tgt)
EOF
     chmod 755 $TMPA
     $TMPA
fi

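# Step 2: fine-tune the extended model as a UPOS token classifier on ja_gsd_modern.conllu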
TMPB=./maker$$b.py
cat << 'EOF' > $TMPB
#! /usr/bin/env deepspeed
src="exSwallow-7b-plus-hf"
tgt="KoichiYasuoka/Swallow-7b-plus-upos"
from transformers import LlamaTokenizerFast,LlamaModel,LlamaPreTrainedModel,AutoConfig,DataCollatorForTokenClassification,TrainingArguments,Trainer
from transformers.modeling_outputs import TokenClassifierOutput
from tokenizers.normalizers import Replace

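# LlamaModel with a token-classification head (dropout + linear layer over the hidden states)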
class LlamaForTokenClassification(LlamaPreTrainedModel):
  def __init__(self,config):
    from torch import nn
    super().__init__(config)
    self.num_labels=config.num_labels
    self.model=LlamaModel(config)
    if hasattr(config,"classifier_dropout") and config.classifier_dropout is not None:
      classifier_dropout=config.classifier_dropout
    elif hasattr(config,"hidden_dropout") and config.hidden_dropout is not None:
      classifier_dropout=config.hidden_dropout
    else:
      classifier_dropout=0.1
    self.dropout=nn.Dropout(classifier_dropout)
    self.classifier=nn.Linear(config.hidden_size,config.num_labels)
    self.post_init()
  def get_input_embeddings(self):
    return self.model.embed_tokens
  def set_input_embeddings(self,value):
    self.model.embed_tokens=value
  def forward(self,input_ids=None,past_key_values=None,attention_mask=None,position_ids=None,inputs_embeds=None,labels=None,use_cache=None,output_attentions=None,output_hidden_states=None,return_dict=None):
    return_dict=return_dict if return_dict is not None else self.config.use_return_dict
    transformer_outputs=self.model(input_ids,past_key_values=past_key_values,attention_mask=attention_mask,position_ids=position_ids,inputs_embeds=inputs_embeds,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict)
    hidden_states=transformer_outputs[0]
    hidden_states=self.dropout(hidden_states)
    logits=self.classifier(hidden_states)
    loss=None
    if labels is not None:
      from torch import nn
      loss_fct=nn.CrossEntropyLoss()
      loss=loss_fct(logits.view(-1,self.num_labels),labels.view(-1))
    if not return_dict:
      output=(logits,)+transformer_outputs[2:]
      return ((loss,)+output) if loss is not None else output
    return TokenClassifierOutput(loss=loss,logits=logits,hidden_states=transformer_outputs.hidden_states,attentions=transformer_outputs.attentions)

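# Dataset that streams sentences from a CoNLL-U file, labeling each word with UPOS (plus FEATS when present);
# subword continuations get B-/I- prefixes and multiword tokens get "+"-joined labels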
class UPOSFileDataset(object):
  def __init__(self,conllu,tokenizer):
    self.conllu=open(conllu,"r",encoding="utf-8")
    self.tokenizer=tokenizer
    self.seeks=[0]
    self.multiword={}
    label=set(["SYM"])
    s=self.conllu.readline()
    while s!="":
      if s=="\n":
        self.seeks.append(self.conllu.tell())
      else:
        w=s.split("\t")
        if len(w)==10:
          if w[0].isdecimal():
            label.add(w[3] if w[5]=="_" else w[3]+"|"+w[5])
          elif w[0].find("-")>0:
            t=w[0].split("-")
            f,j,k=w[1],[],[]
            for i in range(int(t[0]),int(t[1])+1):
              w=self.conllu.readline().split("\t")
              j.append(w[3] if w[5]=="_" else w[3]+"|"+w[5])
              k.append(w[1])
            p="+".join(j)
            label.add(p)
            if p in self.multiword:
              self.multiword[p][f]=list(k)
            else:
              self.multiword[p]={f:list(k)}
      s=self.conllu.readline()
    lid={}
    for i,l in enumerate(sorted(label)):
      lid[l],lid["B-"+l],lid["I-"+l]=i*3,i*3+1,i*3+2
    self.label2id=lid
  def __call__(*args):
    lid={l:i for i,l in enumerate(sorted(set(sum([list(t.label2id) for t in args],[]))))}
    for t in args:
      t.label2id=lid
    return lid
  def __del__(self):
    self.conllu.close()
  __len__=lambda self:len(self.seeks)-1
  def __getitem__(self,i):
    self.conllu.seek(self.seeks[i])
    form,upos=[],[]
    while self.conllu.tell()<self.seeks[i+1]:
      w=self.conllu.readline().split("\t")
      if len(w)==10:
        form.append(w[1])
        if w[0].isdecimal():
          upos.append(w[3] if w[5]=="_" else w[3]+"|"+w[5])
        elif w[0].find("-")>0:
          t=w[0].split("-")
          u=[]
          for j in range(int(t[0]),int(t[1])+1):
            k=self.conllu.readline().split("\t")
            u.append(k[3] if k[5]=="_" else k[3]+"|"+k[5])
          upos.append("+".join(u))
    v=self.tokenizer(form,add_special_tokens=False)
    i,u=[],[]
    for j,(x,y) in enumerate(zip(v["input_ids"],upos)):
      if x!=[]:
        i+=x
        u+=[y] if len(x)==1 else ["B-"+y]+["I-"+y]*(len(x)-1)
    if len(i)<self.tokenizer.model_max_length-3:
      ids=[self.tokenizer.cls_token_id]+i+[self.tokenizer.sep_token_id]
      upos=["SYM"]+u+["SYM"]
    else:
      ids=i[0:self.tokenizer.model_max_length-2]
      upos=u[0:self.tokenizer.model_max_length-2]
    return {"input_ids":ids,"labels":[self.label2id[t] for t in upos]}

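# Prepare the tokenizer (map spaces to U+2581, disable byte fallback), build the dataset,
# and fine-tune with Trainer under DeepSpeed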
tkz=LlamaTokenizerFast.from_pretrained(src)
tkz.backend_tokenizer.normalizer=Replace(" ","\u2581")
tkz.backend_tokenizer.model.byte_fallback=False
trainDS=UPOSFileDataset("ja_gsd_modern.conllu",tkz)
lid=trainDS.label2id
cfg=AutoConfig.from_pretrained(src,num_labels=len(lid),label2id=lid,id2label={i:l for l,i in lid.items()},ignore_mismatched_sizes=True)
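# DeepSpeed config: ZeRO stage 3 with optimizer and parameter offload to CPU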
dsp={"fp16":{"enabled":"auto"},"optimizer":{"type":"AdamW"},"scheduler":{"type":"WarmupLR","params":{}},"train_batch_size":"auto","train_micro_batch_size_per_gpu":"auto","zero_optimization":{"stage":3,"offload_optimizer":{"device":"cpu","pin_memory":True},"offload_param":{"device":"cpu","pin_memory":True},"overlap_comm":True,"contiguous_gradients":True,"reduce_bucket_size":"auto","stage3_prefetch_bucket_size":"auto","stage3_param_persistence_threshold":"auto","stage3_gather_16bit_weights_on_model_save":True}}
arg=TrainingArguments(num_train_epochs=3,per_device_train_batch_size=8,deepspeed=dsp,output_dir=tgt,overwrite_output_dir=True,save_total_limit=2,learning_rate=5e-05,warmup_ratio=0.1,save_safetensors=False)
trn=Trainer(args=arg,data_collator=DataCollatorForTokenClassification(tkz),model=LlamaForTokenClassification.from_pretrained(src,config=cfg,ignore_mismatched_sizes=True),train_dataset=trainDS)
trn.train()
trn.save_model(tgt)
tkz.save_pretrained(tgt)
EOF
chmod 755 $TMPB
$TMPB
exit