# Swallow-MS-7b-upos / maker.sh
# Author: KoichiYasuoka
# Commit 6cf32b2: "bug fix"
#! /bin/sh
# Fetch the UD Japanese GSD training data and Unicode's Japanese core kanji
# list, unless a copy is already present in the current directory.
# -f: fail on HTTP errors instead of saving the server's error page under the
#     expected name (the "test -f" guard would then accept that page forever).
# "|| exit 1": both files are required by the steps below, so stop early.
test -f ja_gsd_modern.conllu || curl -fLO https://github.com/KoichiYasuoka/SuPar-UniDic/raw/main/suparunidic/suparmodels/ja_gsd_modern.conllu || exit 1
test -f JapaneseCoreKanji.txt || curl -fLO https://www.unicode.org/wg2/iso10646/edition6/data/JapaneseCoreKanji.txt || exit 1
# Stage 1 (skipped when the output directory already exists): derive
# exSwallow-MS-7b-v0.1 from tokyotech-llm/Swallow-MS-7b-v0.1 by adding a
# single-character token for every candidate CJK character that the original
# tokenizer splits into several pieces.
if [ ! -d exSwallow-MS-7b-v0.1 ]
then TMPA=./maker$$a.py
cat << 'EOF' > $TMPA
#! /usr/bin/python3
src="tokyotech-llm/Swallow-MS-7b-v0.1"
tgt="exSwallow-MS-7b-v0.1"
import json,torch,unicodedata
from transformers import LlamaTokenizerFast,LlamaForCausalLM
# Candidates: the kanji list (one hexadecimal code point per line, "#" lines
# are comments; blank lines are skipped) ...
with open("JapaneseCoreKanji.txt","r",encoding="utf-8") as r:
  cjk=[chr(int(t,16)) for t in r.read().strip().split("\n") if t.strip() and not t.startswith("#")]
# ... plus every character whose Unicode name starts with "CJK " found in the
# FORM column of the 10-column CoNLL-U lines.
with open("ja_gsd_modern.conllu","r",encoding="utf-8") as r:
  for s in r:
    t=s.split("\t")
    if len(t)==10:
      for c in t[1]:
        # name() raises ValueError for code points without a name; the ""
        # default simply treats them as non-CJK
        if unicodedata.name(c,"").startswith("CJK "):
          cjk.append(c)
# sorted() instead of list(): iteration order of a set of str varies between
# runs (hash randomization), and this order decides the ids of the new tokens
cjk=sorted(set(cjk))
tkz=LlamaTokenizerFast.from_pretrained(src,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
# Each character is tokenized on its own with special tokens; j[2:] drops the
# first two ids (presumably "<s>" and the word-boundary prefix piece - TODO
# confirm) and keeps the pieces of the character.  Only characters needing
# more than one piece get a new token.
c={i:j[2:] for i,j in zip(cjk,tkz(cjk)["input_ids"]) if len(j)>3}
d=json.loads(tkz.backend_tokenizer.to_str())
# append the new characters to the model vocabulary, numbered from len(tkz)
for i,j in enumerate(c,len(tkz)):
  d["model"]["vocab"][j]=i
tkz.backend_tokenizer.from_str(json.dumps(d)).save("tokenizer.json")
mdl=LlamaForCausalLM.from_pretrained(src)
tkz=LlamaTokenizerFast(tokenizer_file="tokenizer.json",model_max_length=mdl.config.max_position_embeddings,cls_token="<s>",sep_token="<s>",mask_token="<unk>",pad_token="</s>")
e=mdl.resize_token_embeddings(len(tkz))
f=mdl.get_output_embeddings()
with torch.no_grad():
  # initialize each new input/output row with the sum of the rows of the
  # pieces the character used to be split into
  for k,v in c.items():
    e.weight[d["model"]["vocab"][k],:]=e.weight[v,:].sum(0)
    f.weight[d["model"]["vocab"][k],:]=f.weight[v,:].sum(0)
mdl.set_input_embeddings(e)
mdl.set_output_embeddings(f)
mdl.save_pretrained(tgt)
tkz.save_pretrained(tgt)
EOF
chmod 755 $TMPA
$TMPA
rc=$?
# remove the generated script; stage 2 cannot run without stage 1's output
rm -f $TMPA
test $rc -eq 0 || exit $rc
fi
# Stage 2: write the fine-tuning program (run by the deepspeed launcher) to a
# temporary file; it trains exSwallow-MS-7b-v0.1 on UPOS(+FEATS) labels.
TMPB="./maker$$b.py"
cat > "$TMPB" <<'EOF'
#! /usr/bin/env deepspeed
tgt="KoichiYasuoka/Swallow-MS-7b-upos"
src="exSwallow-MS-7b-v0.1"
from transformers import LlamaTokenizerFast,MistralModel,MistralPreTrainedModel,AutoConfig,DataCollatorForTokenClassification,TrainingArguments,Trainer
from transformers.modeling_outputs import TokenClassifierOutput
from tokenizers.normalizers import Replace
class MistralForTokenClassification(MistralPreTrainedModel):
  """Mistral decoder with a per-token classification head.

  The last hidden state goes through dropout and a Linear layer mapping
  config.hidden_size to config.num_labels.  When labels are given, the loss
  is torch.nn.CrossEntropyLoss with its defaults (ignore_index=-100).
  """
  def __init__(self,config):
    from torch import nn
    super().__init__(config)
    self.num_labels=config.num_labels
    self.model=MistralModel(config)
    # dropout rate: first non-None of classifier_dropout, hidden_dropout; else 0.1
    for attr in ("classifier_dropout","hidden_dropout"):
      rate=getattr(config,attr,None)
      if rate is not None:
        break
    else:
      rate=0.1
    self.dropout=nn.Dropout(rate)
    self.classifier=nn.Linear(config.hidden_size,config.num_labels)
    self.post_init()
  def get_input_embeddings(self):
    return self.model.embed_tokens
  def set_input_embeddings(self,value):
    self.model.embed_tokens=value
  def forward(self,input_ids=None,past_key_values=None,attention_mask=None,
              position_ids=None,inputs_embeds=None,labels=None,use_cache=None,
              output_attentions=None,output_hidden_states=None,return_dict=None):
    """Return per-token logits (and the loss when labels is given)."""
    if return_dict is None:
      return_dict=self.config.use_return_dict
    outputs=self.model(input_ids,past_key_values=past_key_values,
                       attention_mask=attention_mask,position_ids=position_ids,
                       inputs_embeds=inputs_embeds,use_cache=use_cache,
                       output_attentions=output_attentions,
                       output_hidden_states=output_hidden_states,
                       return_dict=return_dict)
    logits=self.classifier(self.dropout(outputs[0]))
    loss=None
    if labels is not None:
      from torch import nn
      criterion=nn.CrossEntropyLoss()
      loss=criterion(logits.view(-1,self.num_labels),labels.view(-1))
    if return_dict:
      return TokenClassifierOutput(loss=loss,logits=logits,
                                   hidden_states=outputs.hidden_states,
                                   attentions=outputs.attentions)
    rest=(logits,)+outputs[2:]
    return rest if loss is None else (loss,)+rest
class UPOSFileDataset(object):
  """Token-classification dataset read lazily from a CoNLL-U file.

  The constructor scans the file once and records three things: the offset
  where each sentence starts (seeks), the label inventory (label2id), and the
  multiword-token table (multiword).  __getitem__ re-reads one sentence from
  disk, tokenizes its surface forms and returns
  {"input_ids": [...], "labels": [...]}.

  A word's label is its UPOS, with "|" + FEATS appended when FEATS is not
  "_".  A multiword token (range ID such as "2-3") is labelled with the
  "+"-joined labels of its parts.  A word split into several tokens gets "B-"
  on its first token and "I-" on the rest.  "SYM" labels the CLS/SEP tokens
  added around a sentence.  Lines with other IDs (empty nodes such as "8.1")
  are ignored.

  Args:
    conllu: path of the CoNLL-U file; kept open until the object is deleted.
    tokenizer: callable as tokenizer(list_of_forms, add_special_tokens=False)
      returning {"input_ids": [...]}, with cls_token_id, sep_token_id and
      model_max_length attributes.
  """
  @staticmethod
  def _label(w):
    """Label of a split CoNLL-U line: UPOS, plus "|FEATS" when FEATS is set."""
    return w[3] if w[5]=="_" else w[3]+"|"+w[5]
  def __init__(self,conllu,tokenizer):
    self.conllu=open(conllu,"r",encoding="utf-8")
    self.tokenizer=tokenizer
    self.seeks=[0]
    # label of a multiword token -> {surface form: [forms of its parts]}
    self.multiword={}
    label={"SYM"}
    s=self.conllu.readline()
    while s!="":
      if s=="\n":
        # a blank line ends a sentence; the next one starts here
        self.seeks.append(self.conllu.tell())
      else:
        w=s.split("\t")
        if len(w)==10:
          if w[0].isdecimal():
            label.add(self._label(w))
          elif w[0].find("-")>0:
            # multiword token: consume its component lines right away
            t=w[0].split("-")
            f,j,k=w[1],[],[]
            for _ in range(int(t[0]),int(t[1])+1):
              w=self.conllu.readline().split("\t")
              j.append(self._label(w))
              k.append(w[1])
            p="+".join(j)
            label.add(p)
            self.multiword.setdefault(p,{})[f]=list(k)
      s=self.conllu.readline()
    # keep a last sentence that is not followed by a blank line
    if self.conllu.tell()>self.seeks[-1]:
      self.seeks.append(self.conllu.tell())
    lid={}
    for i,l in enumerate(sorted(label)):
      lid[l],lid["B-"+l],lid["I-"+l]=i*3,i*3+1,i*3+2
    self.label2id=lid
  def __call__(*args):
    """Give every dataset in args (args[0] is the instance) one shared label2id.

    The shared table numbers the union of all label names in sorted order;
    it is assigned to each dataset and returned.
    """
    lid={l:i for i,l in enumerate(sorted(set().union(*(t.label2id for t in args))))}
    for t in args:
      t.label2id=lid
    return lid
  def __del__(self):
    # __init__ may have raised before the file was opened
    conllu=getattr(self,"conllu",None)
    if conllu is not None:
      conllu.close()
  def __len__(self):
    return len(self.seeks)-1
  def __getitem__(self,i):
    self.conllu.seek(self.seeks[i])
    end=self.seeks[i+1]
    form,upos=[],[]
    while self.conllu.tell()<end:
      w=self.conllu.readline().split("\t")
      if len(w)!=10:
        continue
      # FORM is appended only together with a label: taking FORM from an
      # unlabelled line (e.g. an empty node "8.1") would shift every later
      # word against its label
      if w[0].isdecimal():
        form.append(w[1])
        upos.append(self._label(w))
      elif w[0].find("-")>0:
        form.append(w[1])
        t=w[0].split("-")
        upos.append("+".join(self._label(self.conllu.readline().split("\t")) for _ in range(int(t[0]),int(t[1])+1)))
    v=self.tokenizer(form,add_special_tokens=False)
    ids,labels=[],[]
    for x,y in zip(v["input_ids"],upos):
      if x:  # a word that tokenizes to nothing is dropped with its label
        ids+=x
        labels+=[y] if len(x)==1 else ["B-"+y]+["I-"+y]*(len(x)-1)
    n=self.tokenizer.model_max_length
    if len(ids)<n-3:
      ids=[self.tokenizer.cls_token_id]+ids+[self.tokenizer.sep_token_id]
      labels=["SYM"]+labels+["SYM"]
    else:
      ids=ids[0:n-2]
      labels=labels[0:n-2]
    return {"input_ids":ids,"labels":[self.label2id[t] for t in labels]}
# Training tokenizer: the backend normalizer is replaced by a plain
# " " -> U+2581 substitution and byte fallback is switched off, so characters
# missing from the vocabulary are not split into byte pieces.
tkz=LlamaTokenizerFast.from_pretrained(src)
backend=tkz.backend_tokenizer
backend.normalizer=Replace(" ","\u2581")
backend.model.byte_fallback=False
trainDS=UPOSFileDataset("ja_gsd_modern.conllu",tkz)
lid=trainDS.label2id
id2label={v:k for k,v in lid.items()}
cfg=AutoConfig.from_pretrained(src,num_labels=len(lid),label2id=lid,id2label=id2label,ignore_mismatched_sizes=True)
# DeepSpeed ZeRO stage 3, optimizer state and parameters offloaded to CPU
zero3={
  "stage":3,
  "offload_optimizer":{"device":"cpu","pin_memory":True},
  "offload_param":{"device":"cpu","pin_memory":True},
  "overlap_comm":True,
  "contiguous_gradients":True,
  "reduce_bucket_size":"auto",
  "stage3_prefetch_bucket_size":"auto",
  "stage3_param_persistence_threshold":"auto",
  "stage3_gather_16bit_weights_on_model_save":True,
}
dsp={
  "fp16":{"enabled":"auto"},
  "optimizer":{"type":"AdamW"},
  "scheduler":{"type":"WarmupLR","params":{}},
  "train_batch_size":"auto",
  "train_micro_batch_size_per_gpu":"auto",
  "zero_optimization":zero3,
}
# NOTE: TrainingArguments (which registers the DeepSpeed config) is created
# before the model is loaded, as HF's ZeRO-3 integration expects.
arg=TrainingArguments(
  num_train_epochs=3,
  per_device_train_batch_size=8,
  deepspeed=dsp,
  output_dir=tgt,
  overwrite_output_dir=True,
  save_total_limit=2,
  learning_rate=5e-05,
  warmup_ratio=0.1,
  save_safetensors=False,
)
collator=DataCollatorForTokenClassification(tkz)
mdl=MistralForTokenClassification.from_pretrained(src,config=cfg,ignore_mismatched_sizes=True)
trn=Trainer(args=arg,data_collator=collator,model=mdl,train_dataset=trainDS)
trn.train()
trn.save_model(tgt)
tkz.save_pretrained(tgt)
EOF
# Run the fine-tuning program, delete the temporary script, and exit with the
# program's own status (the bare "exit" used to return exactly that status).
chmod 755 $TMPB
$TMPB
rc=$?
rm -f $TMPB
exit $rc