How to add a model to 🤗 Transformers?
The 🤗 Transformers library is often able to offer new models thanks to community contributors. But this can be a challenging project and requires an in-depth knowledge of the 🤗 Transformers library and the model to implement. At Hugging Face, we're trying to empower more of the community to actively add models, and we've put together this guide to walk you through the process of adding a PyTorch model (make sure you have PyTorch installed).
Along the way, you'll:
- get insights into open-source best practices
- understand the design principles behind one of the most popular deep learning libraries
- learn how to efficiently test large models
- learn how to integrate Python utilities like black, ruff, and make fix-copies to ensure clean and readable code
A Hugging Face team member will be available to help you, so you'll never be alone. 🤗 ❤️
To get started, open a New model addition issue for the model you want to see in 🤗 Transformers. If you're not especially picky about contributing a specific model, you can filter by the New model label to see if there are any unclaimed model requests and work on one of them.
Once you've opened a new model request, the first step is to get familiar with 🤗 Transformers if you aren't already!
General overview of 🤗 Transformers
First, you should get a general overview of 🤗 Transformers. 🤗 Transformers is a very opinionated library, so there is a chance that you don't agree with some of the library's philosophies or design choices. From our experience, however, we found that the fundamental design choices and philosophies of the library are crucial to efficiently scale 🤗 Transformers while keeping maintenance costs at a reasonable level.
A good first starting point to better understand the library is to read the documentation of our philosophy. As a result of our way of working, there are some choices that we try to apply to all models:
- Composition is generally favored over abstraction.
- Duplicating code is not always bad if it strongly improves the readability or accessibility of a model.
- Model files are as self-contained as possible, so that when you read the code of a specific model, you ideally only have to look into the respective modeling_....py file.
In our opinion, the library's code is not just a means to provide a product, e.g. the ability to use BERT for inference, but is also the product itself.
Overview of models
To successfully add a model, it is important to understand the interaction between your model and its config, PreTrainedModel, and PretrainedConfig. For exemplary purposes, we will call the model to be added to 🤗 Transformers BrandNewBert.
Let's take a look:
As you can see, we do make use of inheritance in 🤗 Transformers, but we keep the level of abstraction to an absolute minimum. There are never more than two levels of abstraction for any model in the library. BrandNewBertModel inherits from BrandNewBertPreTrainedModel, which in turn inherits from PreTrainedModel, and that's it.
As a general rule, we want to make sure that a new model only depends on PreTrainedModel. The important functionalities that are automatically provided to every new model are from_pretrained() and save_pretrained(), which are used for serialization and deserialization.
All of the other important functionalities, such as BrandNewBertModel.forward, should be completely defined in the new modeling_brand_new_bert.py script.
Next, we want to make sure that a model with a specific head layer, such as BrandNewBertForMaskedLM, does not inherit from BrandNewBertModel, but rather uses BrandNewBertModel as a component that can be called in its forward pass, to keep the level of abstraction low.
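The composition rule above can be illustrated with a toy, framework-free sketch. All class names here (ToyPreTrainedModel, ToyBrandNewBertPreTrainedModel, ToyBrandNewBertModel, ToyBrandNewBertForMaskedLM) are hypothetical stand-ins and the "forward" logic is deliberately trivial; the only point is that the head class owns a base-model instance instead of inheriting from the base model.

```python
class ToyPreTrainedModel:
    """Hypothetical stand-in for a PreTrainedModel-like base class."""

    def __init__(self, config: dict) -> None:
        self.config = config


class ToyBrandNewBertPreTrainedModel(ToyPreTrainedModel):
    """Model-specific base class (second and last level of abstraction)."""


class ToyBrandNewBertModel(ToyBrandNewBertPreTrainedModel):
    """Base model: maps token ids to toy "hidden states"."""

    def forward(self, input_ids: list) -> list:
        scale = self.config["hidden_scale"]
        return [token_id * scale for token_id in input_ids]


class ToyBrandNewBertForMaskedLM(ToyBrandNewBertPreTrainedModel):
    """Head model: calls the base model as a component in its forward pass."""

    def __init__(self, config: dict) -> None:
        super().__init__(config)
        # Composition instead of inheritance: the base model is an attribute.
        self.brand_new_bert = ToyBrandNewBertModel(config)

    def forward(self, input_ids: list) -> list:
        hidden_states = self.brand_new_bert.forward(input_ids)
        # Toy "prediction head": shift every hidden state by a bias taken from the config.
        return [value + self.config["head_bias"] for value in hidden_states]


config = {"hidden_scale": 2.0, "head_bias": 1.0}
masked_lm = ToyBrandNewBertForMaskedLM(config)
print(masked_lm.forward([1, 2, 3]))  # [3.0, 5.0, 7.0]
```

Both the head and the base model share the same config object, which mirrors how the config is reachable from every class derived from the model-specific base class.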
Every new model requires a configuration class, called BrandNewBertConfig. This configuration is always stored as an attribute in PreTrainedModel, and thus can be accessed via the config attribute for all classes inheriting from BrandNewBertPreTrainedModel:
model = BrandNewBertModel.from_pretrained("brandy/brand_new_bert")
model.config # model has access to its config
Similar to the model, the configuration inherits basic serialization and deserialization functionalities from PretrainedConfig. Note that the configuration and the model are always serialized into two different formats: the model to a pytorch_model.bin file and the configuration to a config.json file. Calling save_pretrained() on the model automatically calls save_pretrained() on the config as well, so that both the model and the configuration are saved.
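To make the config-plus-weights split concrete, here is a minimal stdlib-only sketch, assuming a dataclass-based config. ToyBrandNewBertConfig, ToyModel and round_trip are hypothetical and do not reproduce what PretrainedConfig actually does (it handles many more fields and options); the sketch only mirrors the idea that saving a model also writes its config to config.json. Writing the weights file is intentionally omitted.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class ToyBrandNewBertConfig:
    """Hypothetical config holding a few typical hyperparameters."""

    vocab_size: int = 30522
    hidden_size: int = 768
    num_hidden_layers: int = 12

    def save_pretrained(self, save_directory: str) -> None:
        os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(save_directory, "config.json")
        with open(path, "w", encoding="utf-8") as file:
            json.dump(asdict(self), file, indent=2)

    @classmethod
    def from_pretrained(cls, save_directory: str) -> "ToyBrandNewBertConfig":
        path = os.path.join(save_directory, "config.json")
        with open(path, encoding="utf-8") as file:
            return cls(**json.load(file))


class ToyModel:
    """Hypothetical model that keeps its config as an attribute."""

    def __init__(self, config: ToyBrandNewBertConfig) -> None:
        self.config = config

    def save_pretrained(self, save_directory: str) -> None:
        # Saving the model also saves its config; a real model would write its weights file here too.
        self.config.save_pretrained(save_directory)


def round_trip(config: ToyBrandNewBertConfig) -> ToyBrandNewBertConfig:
    """Save a model built from `config` to a temporary folder and reload only the config."""
    with tempfile.TemporaryDirectory() as directory:
        ToyModel(config).save_pretrained(directory)
        return ToyBrandNewBertConfig.from_pretrained(directory)
```

Calling round_trip(ToyBrandNewBertConfig(hidden_size=64)) should give back a config equal to the one that was saved.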
Code style
When coding your new model, keep in mind that Transformers is an opinionated library and we have a few quirks of our own regarding how code should be written :-)
- The forward pass of your model should be fully written in the modeling file while being fully independent of other models in the library. If you want to reuse a block from another model, copy the code and paste it with a # Copied from comment on top (see here for a good example, and here for more documentation on Copied from).
- The code should be fully understandable. This means you should pick descriptive variable names and avoid abbreviations. As an example, activation is preferred to act. One-letter variable names are strongly discouraged unless it's an index in a for loop.
- More generally, we prefer longer explicit code to short magical code.
- Avoid subclassing nn.Sequential in PyTorch; instead, subclass nn.Module and write out the forward pass, so that anyone using your code can quickly debug it by adding print statements or breakpoints.
- Your function signatures should be type-annotated. For the rest, good variable names are way more readable and understandable than type annotations.
Overview of tokenizers
Not quite ready yet :-( This section will be added soon!
Step-by-step recipe to add a model to 🤗 Transformers
Everyone has different preferences of how to port a model, so it can be very helpful for you to take a look at summaries of how other contributors ported models to 🤗 Transformers. Here is a list of community blog posts on how to port a model to 🤗 Transformers:
From experience, we can tell you that the most important things to keep in mind when adding a model are:
- Don't reinvent the wheel! Most parts of the code you will add for the new 🤗 Transformers model already exist somewhere in 🤗 Transformers. Take some time to find similar, already existing models and tokenizers you can copy from. grep and rg are your friends. Note that it might very well happen that your model's tokenizer is based on one model implementation, and your model's modeling code on another one. E.g. FSMT's modeling code is based on BART, while FSMT's tokenizer code is based on XLM.
- It's more of an engineering challenge than a scientific challenge. You should spend more time creating an efficient debugging environment rather than trying to understand all theoretical aspects of the model in the paper.
- Ask for help when you're stuck! Models are the core component of 🤗 Transformers, so we at Hugging Face are more than happy to help you at every step of adding your model. Don't hesitate to ask if you notice you are not making progress.
In the following, we try to give you a general recipe that we found most useful when porting a model to 🤗 Transformers.
The following list is a summary of everything that has to be done to add a model and can be used by you as a To-Do List:
- ☐ (Optional) Understand the model's theoretical aspects
- ☐ Prepare the 🤗 Transformers dev environment
- ☐ Set up the debugging environment of the original repository
- ☐ Create a script that successfully runs the forward() pass using the original repository and checkpoint
- ☐ Successfully add the model skeleton to 🤗 Transformers
- ☐ Successfully convert the original checkpoint to a 🤗 Transformers checkpoint
- ☐ Successfully run the forward() pass in 🤗 Transformers, giving output identical to the original checkpoint
- ☐ Finish the model tests in 🤗 Transformers
- ☐ Successfully add the tokenizer in 🤗 Transformers
- ☐ Run end-to-end integration tests
- ☐ Finish the docs
- ☐ Upload the model weights to the Hub
- ☐ Submit the pull request
- ☐ (Optional) Add a demo notebook
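To begin with, we usually recommend starting by getting a good theoretical understanding of BrandNewBert. However, if you prefer to understand the theoretical aspects of the model on the job, then it is totally fine to directly dive into the BrandNewBert code-base. This option might suit you better if your engineering skills are better than your theoretical skills, if you have trouble understanding BrandNewBert's paper, or if you just enjoy programming much more than reading scientific papers.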
1. (Optional) Theoretical aspects of BrandNewBert
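You should take some time to read BrandNewBert's paper, if such descriptive work exists. There might be large sections of the paper that are difficult to understand. If this is the case, that's fine, don't worry! The goal is not to get a deep theoretical understanding of the paper, but to extract the necessary information required to effectively re-implement the model in 🤗 Transformers. That being said, you don't have to spend too much time on the theoretical aspects; rather, focus on the practical ones, namely: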
- What type of model is brand_new_bert? A BERT-like encoder-only model? A GPT2-like decoder-only model? A BART-like encoder-decoder model? Look at the model_summary if you're not familiar with the differences between those.
- What are the applications of brand_new_bert? Text classification? Text generation? Seq2Seq tasks, e.g., summarization?
- What is the novel feature of the model that makes it different from BERT/GPT-2/BART?
- Which of the already existing 🤗 Transformers models is most similar to brand_new_bert?
- What type of tokenizer is used? A SentencePiece tokenizer? A WordPiece tokenizer? Is it the same tokenizer as used for BERT or BART?
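After you feel like you have gotten a good overview of the architecture of the model, you might want to write to the Hugging Face team with any questions you might have. This might include questions regarding the model's architecture, its attention layer, etc. We will be more than happy to help you.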
2. Next prepare your environment
Fork the repository by clicking on the "Fork" button on the repository's page. This creates a copy of the code under your GitHub user account.
Clone your transformers fork to your local disk, and add the base repository as a remote:
git clone https://github.com/[your Github handle]/transformers.git
cd transformers
git remote add upstream https://github.com/huggingface/transformers.git
- To set up a development environment, run the following commands:
python -m venv .env
source .env/bin/activate
pip install -e ".[dev]"
Depending on your OS, and since the number of optional dependencies of Transformers is growing, you might get a failure with this command. If that's the case, make sure to install the Deep Learning framework you are working with (PyTorch, TensorFlow and/or Flax), then do:
pip install -e ".[quality]"
This should be enough for most use cases. You can then return to the parent directory:
cd ..
We recommend adding the PyTorch version of brand_new_bert to Transformers. To install PyTorch, please follow the instructions on https://pytorch.org/get-started/locally/.
Note: You don't need to have CUDA installed. Making the new model work on CPU is sufficient.
To port brand_new_bert, you will also need access to its original repository:
git clone https://github.com/org_that_created_brand_new_bert_org/brand_new_bert.git
cd brand_new_bert
pip install -e .
Now you have set up a development environment to port brand_new_bert to 🤗 Transformers.
3.-4. Run a pretrained checkpoint using the original repository
At first, you will work on the original brand_new_bert repository. Often, the original implementation is very "researchy", meaning that documentation might be lacking and the code can be difficult to understand. But this should be exactly your motivation to reimplement brand_new_bert. At Hugging Face, one of our main goals is to take a working model and rewrite it to make it as accessible, user-friendly, and beautiful as possible. This is the number-one motivation to re-implement models into 🤗 Transformers: trying to make complex new NLP technology accessible to everybody.
You should therefore start by diving into the original repository.
Successfully running the official pretrained model in the original repository is often the most difficult step. From our experience, it is very important to spend some time getting familiar with the original code-base. You need to figure out the following:
- Where to find the pretrained weights?
- How to load the pretrained weights into the corresponding model?
- How to run the tokenizer independently from the model?
- Trace one forward pass so that you know which classes and functions are required for a simple forward pass. Usually, you only have to reimplement those functions.
- Be able to locate the important components of the model: Where is the model's class? Are there model sub-classes, e.g. EncoderModel, DecoderModel? Where is the self-attention layer? Are there multiple different attention layers, e.g. self-attention, cross-attention...?
- How can you debug the model in the original environment of the repo? Do you have to add print statements, can you work with an interactive debugger like ipdb, or should you use an efficient IDE like PyCharm to debug the model?
It is very important that before you start the porting process, you can efficiently debug code in the original repository! Also, remember that you are working with an open-source library, so do not hesitate to open an issue, or even a pull request, in the original repository. The maintainers of this repository are most likely very happy about someone looking into their code!
At this point, it is really up to you which debugging environment and strategy you prefer to use to debug the original model. It is very important to first be able to debug the code of the original repository. Setting up a GPU environment is not recommended: work on CPU first and make sure the model has been successfully ported to 🤗 Transformers. Only at the very end should you verify that the model also works as expected on GPU.
In general, there are two possible debugging environments for running the original model:
- Jupyter notebooks / google colab
- Local Python scripts
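Jupyter notebooks have the advantage that they allow for cell-by-cell execution, which can be helpful to better split logical components from one another and to have faster debugging cycles, as intermediate results can be stored. Also, notebooks are often easier to share with other contributors, which might be very helpful if you want to ask the Hugging Face team for help. If you are already familiar with Jupyter notebooks, …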
# Pseudocode: the actual loading and prediction calls depend on the original repository
model = BrandNewBertModel.load_pretrained_checkpoint("/path/to/checkpoint/")
input_ids = [0, 4, 5, 2, 3, 7, 9]  # vector of input ids
original_output = model.predict(input_ids)
Next, regarding the debugging strategy, there are generally a few options to choose from:
- Decompose the original model into many small testable components and run a forward pass on each of those for verification
- Decompose the original model only into the original tokenizer and the original model, run a forward pass on those, and use intermediate print statements or breakpoints for verification
Again, it is up to you which strategy to choose. Often, one or the other is advantageous depending on the original code base.
If the original code-base allows you to decompose the model into smaller sub-components, e.g. if the original code-base can easily be run in eager mode, it is usually worth the effort to do so. There are some important advantages to taking the more difficult road in the beginning:
- at a later stage, when comparing the original model to the 🤗 Transformers implementation, you can verify automatically for each component individually that the corresponding component of the 🤗 Transformers implementation matches, instead of relying on visual comparison
- it helps you decompose the big problem of porting a model into smaller problems of just porting individual components, and thus to structure your work better
- separating the model into logically meaningful components will help you get a better overview of the model's design and thus to better understand the model
- at a later stage, those component-by-component tests help you to ensure that no regression occurs as you continue changing your code
Lysandre's integration checks for ELECTRA give a nice example of how this can be done.
However, if the original code-base is very complex or only allows intermediate components to be run in a compiled mode, it might be too time-consuming or even impossible to separate the model into smaller testable sub-components. A good example is T5's MeshTensorFlow library, which is very complex and does not offer a simple way to decompose the model into its sub-components. For such libraries, one often relies on verifying print statements.
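No matter which strategy you choose, the recommended procedure is often the same: you should start by debugging the first layers and finish by debugging the last layers.
It is recommended that you retrieve the outputs of the following layers, in the following order:
- Retrieve the input IDs passed to the model
- Retrieve the word embeddings
- Retrieve the input of the first Transformer layer
- Retrieve the output of the first Transformer layer
- Retrieve the output of the following n - 1 Transformer layers
- Retrieve the output of the whole BrandNewBert Model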
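One way to collect outputs in this order is to run the model layer by layer and keep every intermediate result. The sketch below uses plain Python callables as hypothetical stand-ins for the embedding and Transformer layers; with a real PyTorch model you might instead register forward hooks (nn.Module.register_forward_hook) on the corresponding submodules.

```python
from typing import Callable, List, Tuple

# A "layer" is a (name, function) pair; the functions below are hypothetical stand-ins.
Layer = Tuple[str, Callable[[List[float]], List[float]]]


def run_and_record(layers: List[Layer], input_ids: List[int]) -> List[Tuple[str, List[float]]]:
    """Run input_ids through the layers in order and keep every intermediate output."""
    hidden = [float(token_id) for token_id in input_ids]
    recorded: List[Tuple[str, List[float]]] = [("input_ids", list(hidden))]
    for name, layer in layers:
        hidden = layer(hidden)
        recorded.append((name, list(hidden)))
    return recorded


toy_layers: List[Layer] = [
    ("word_embeddings", lambda values: [value * 0.5 for value in values]),
    ("transformer_layer_0", lambda values: [value + 1.0 for value in values]),
]
for name, values in run_and_record(toy_layers, [0, 4, 2]):
    print(name, values)
```

Keeping the names aligned with the list above makes it easy to compare the same positions later in the 🤗 Transformers implementation.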
Input IDs should consist of an array of integers, e.g. input_ids = [0, 4, 4, 3, 2, 4, 1, 7, 19]
The outputs of the following layers often consist of multi-dimensional float arrays and can look like this:
[[[-0.1465, -0.6501,  0.1993, ...,  0.1451,  0.3430,  0.6024],
  [-0.4417, -0.5920,  0.3450, ..., -0.3062,  0.6182,  0.7132],
  [-0.5009, -0.7122,  0.4548, ..., -0.3662,  0.6091,  0.7648],
  ...,
  [-0.5613, -0.6332,  0.4324, ..., -0.3792,  0.7372,  0.9288],
  [-0.5416, -0.6345,  0.4180, ..., -0.3564,  0.6992,  0.9191],
  [-0.5334, -0.6403,  0.4271, ..., -0.3339,  0.6533,  0.8694]]]
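We expect that every model added to 🤗 Transformers passes a couple of integration tests, meaning that the original model and the reimplemented version in 🤗 Transformers have to give the exact same output up to a precision of 0.001! Since it is normal that the exact same model written in different libraries can give a slightly different output depending on the library framework, we accept an error tolerance of 1e-3 (0.001). It is not enough if the model gives nearly the same output; the outputs have to be almost identical. Therefore, you will certainly compare the intermediate outputs of the 🤗 Transformers version multiple times against the intermediate outputs of the original implementation of brand_new_bert, in which case an efficient debugging environment of the original repository is absolutely important. Here is some advice to make your debugging environment as efficient as possible: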
- Find the best way of debugging intermediate results. Is the original repository written in PyTorch? Then you should probably take the time to write a longer script that decomposes the original model into smaller sub-components to retrieve intermediate values. Is the original repository written in TensorFlow 1? Then you might have to rely on TensorFlow print operations like tf.print to output intermediate values. Is the original repository written in Jax? Then make sure that the model is not jitted when running the forward pass, e.g. check out this link.
- Use the smallest pretrained checkpoint you can find. The smaller the checkpoint, the faster your debug cycle becomes. It is not efficient if your pretrained model is so big that your forward pass takes more than 10 seconds. In case only very large checkpoints are available, it might make more sense to create a dummy model in the new environment with randomly initialized weights and save those weights for comparison with the 🤗 Transformers version of your model.
- Make sure you are using the easiest way of calling a forward pass in the original repository. Ideally, you want to find the function in the original repository that only calls a single forward pass, which is often called predict, evaluate, forward or __call__. You don't want to debug a function that calls forward multiple times, e.g. to generate text, like autoregressive_sample or generate.
- Try to separate the tokenization from the model's forward pass. If the original repository shows examples where you have to input a string, try to find out where in the forward call the string input is changed to input ids and start from this point. This might mean that you have to write a small script yourself or change the original code so that you can directly input the ids instead of an input string.
- Make sure that the model in your debugging setup is not in training mode, which often causes the model to yield random outputs due to multiple dropout layers in the model. Make sure that the forward pass in your debugging environment is deterministic so that the dropout layers are not used, or use transformers.utils.set_seed if the old and new implementations are in the same framework.
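The effect of training mode on dropout can be seen in a small stdlib-only sketch. toy_dropout is a hypothetical re-implementation of inverted dropout that loosely follows the contract of PyTorch's torch.nn.functional.dropout(input, p, training): it is the identity when training is False, while in training mode it zeroes values at random, so two runs only agree if they use the same seed.

```python
import random
from typing import List


def toy_dropout(values: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: identity in eval mode, random zeroing plus rescaling in training mode."""
    if not training or p == 0.0:
        return list(values)
    keep_probability = 1.0 - p
    return [value / keep_probability if rng.random() < keep_probability else 0.0 for value in values]


values = [1.0, 2.0, 3.0, 4.0]
# Eval mode: deterministic, exactly the input.
print(toy_dropout(values, p=0.5, training=False, rng=random.Random()))
# Training mode: the two results only match because both generators use the same seed.
print(toy_dropout(values, p=0.5, training=True, rng=random.Random(0)))
print(toy_dropout(values, p=0.5, training=True, rng=random.Random(0)))
```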
以äžã®ã»ã¯ã·ã§ã³ã§ã¯ãbrand_new_bertã«ã€ããŠãããå ·äœçã«ã©ã®ããã«è¡ããã«ã€ããŠã®è©³çŽ°/ãã³ããæäŸããŸãã
5.-14. Port BrandNewBert to 🤗 Transformers
Next, you can finally start adding new code to 🤗 Transformers. Go into the clone of your 🤗 Transformers fork:
cd transformers
In the special case that you are adding a model whose architecture exactly matches the model architecture of an existing model, you only have to add a conversion script as described in this section. In this case, you can just re-use the whole model architecture of the already existing model.
Otherwise, let's start generating a new model. We recommend using the following script to add a model starting from an existing model:
transformers-cli add-new-model-like
You will be prompted with a questionnaire to fill in the basic information of your model.
Open a Pull Request on the main huggingface/transformers repo
Before starting to adapt the automatically generated code, now is the time to open a "Work in progress (WIP)" pull request, e.g. "[WIP] Add brand_new_bert", in 🤗 Transformers, so that you and the Hugging Face team can work side-by-side on integrating the model into 🤗 Transformers.
You should do the following:
- Create a branch with a descriptive name from your main branch:
git checkout -b add_brand_new_bert
- Commit the automatically generated code:
git add .
git commit
- Fetch and rebase to the current main:
git fetch upstream
git rebase upstream/main
- Push the changes to your account using:
git push -u origin add_brand_new_bert
Once you are satisfied, go to the webpage of your fork on GitHub and click on "Pull request". Make sure to add the GitHub handle of some members of the Hugging Face team as reviewers, so that the Hugging Face team gets notified of future changes.
Change the PR into a draft by clicking on "Convert to draft" on the right of the GitHub pull request web page.
In the following, whenever you have made some progress, don't forget to commit your work and push it to your account so that it shows in the pull request. Additionally, you should make sure to update your work with the current main from time to time by doing:
git fetch upstream
git merge upstream/main
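In general, all questions you might have regarding the model or your implementation should be asked in your PR and discussed/solved in the PR. This way, the Hugging Face team will always be notified when you are committing new code or if you have a question. It is often very helpful to point the Hugging Face team to your added code so that the Hugging Face team can efficiently understand your problem or question.
To do so, you can go to the "Files changed" tab where you see all of your changes, go to a line regarding which you want to ask a question, and click on the "+" symbol to add a comment. Whenever a question or problem has been solved, you can click on the "Resolve" button of the created comment.
In the same way, the Hugging Face team will open comments when reviewing your code. We recommend asking most questions on GitHub on your PR. For some very general questions that are not very useful for the public, feel free to ping the Hugging Face team by Slack or email.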
5. Adapt the generated model code for brand_new_bert
At first, we will focus only on the model itself and not care about the tokenizer. All the relevant code should be found in the generated files src/transformers/models/brand_new_bert/modeling_brand_new_bert.py and src/transformers/models/brand_new_bert/configuration_brand_new_bert.py.
Now you can finally start coding 😄. The generated code in src/transformers/models/brand_new_bert/modeling_brand_new_bert.py will either have the same architecture as BERT if it's an encoder-only model, or BART if it's an encoder-decoder model. At this point, you should remind yourself what you've learned in the beginning about the theoretical aspects of the model: How is the model different from BERT or BART?
Implement those changes, which often means changing the self-attention layer, the order of the normalization layer, etc. Again, it is often useful to look at the similar architecture of already existing models in Transformers to get a better feeling of how your model should be implemented.
Note that at this point, you don't have to be very sure that your code is fully correct or clean. Rather, it is advised to add a first unclean, copy-pasted version of the original code to src/transformers/models/brand_new_bert/modeling_brand_new_bert.py until you feel like all the necessary code is added, and then iteratively improve/correct it. From our experience, it is much more efficient to quickly add a first version of the required code and improve/correct the code iteratively with the conversion script as described in the next section. The only thing that has to work at this point is that you can instantiate the 🤗 Transformers implementation of brand_new_bert, i.e. the following command should work:
from transformers import BrandNewBertModel, BrandNewBertConfig
model = BrandNewBertModel(BrandNewBertConfig())
The above command creates a model according to the default parameters as defined in BrandNewBertConfig(), and thereby makes sure that the __init__() methods of all components work.
All the random initialization should happen in the _init_weights method of your BrandNewBertPreTrainedModel class. It should initialize all leaf modules depending on the variables of the config. Here is an example with the BERT _init_weights method:
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
You can have some more custom schemes if you need a special initialization for some modules. For instance, in Wav2Vec2ForPreTraining, the last two linear layers need to have the initialization of the regular PyTorch nn.Linear, but all the other ones should use an initialization as above. This is coded like this:
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Wav2Vec2ForPreTraining):
module.project_hid.reset_parameters()
module.project_q.reset_parameters()
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
The _is_hf_initialized flag is internally used to make sure we only initialize a submodule once. By setting it to True for module.project_q and module.project_hid, we make sure the custom initialization we did is not overridden later on: the _init_weights function won't be applied to them.
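The idea behind the flag can be illustrated with a toy, framework-free sketch: a recursive initializer that skips every module already marked as initialized. ToyModule and init_weights are hypothetical and are not how 🤗 Transformers implements this internally; only the attribute name _is_hf_initialized is taken from the text above.

```python
from typing import Iterator, List, Optional


class ToyModule:
    """Hypothetical stand-in for nn.Module: a single weight value plus child modules."""

    def __init__(self, name: str, children: Optional[List["ToyModule"]] = None) -> None:
        self.name = name
        self.weight: Optional[float] = None
        self.children = children or []
        self._is_hf_initialized = False

    def modules(self) -> Iterator["ToyModule"]:
        """Yield this module and all of its descendants, depth first."""
        yield self
        for child in self.children:
            yield from child.modules()


def init_weights(root: ToyModule, default_value: float) -> None:
    """Apply a default init to every module that has not been marked as initialized yet."""
    for module in root.modules():
        if module._is_hf_initialized:
            continue  # keep any custom initialization untouched
        module.weight = default_value
        module._is_hf_initialized = True


project_q = ToyModule("project_q")
project_q.weight = 0.5  # custom initialization done beforehand
project_q._is_hf_initialized = True
model = ToyModule("model", [ToyModule("encoder"), project_q])
init_weights(model, default_value=0.02)
print([(module.name, module.weight) for module in model.modules()])
# [('model', 0.02), ('encoder', 0.02), ('project_q', 0.5)]
```

Because every module is flagged after its first initialization, calling init_weights again is a no-op.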
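6. Write a conversion script
Next, you should write a conversion script that lets you convert the checkpoint you used to debug brand_new_bert in the original repository to a checkpoint compatible with your just created 🤗 Transformers implementation of brand_new_bert. It is not advised to write the conversion script from scratch, but rather to look through already existing conversion scripts in 🤗 Transformers for one that has been used to convert a similar model that was written in the same framework as brand_new_bert. Usually, it is enough to copy an already existing conversion script and slightly adapt it for your use case. Don't hesitate to ask the Hugging Face team to point you to a similar, already existing conversion script for your model.
- If you are porting a model from TensorFlow to PyTorch, a good starting point might be BERT's conversion script here
- If you are porting a model from PyTorch to PyTorch, a good starting point might be BART's conversion script here
In the following, we'll quickly explain how PyTorch models store layer weights and define layer names.
In PyTorch, the name of a layer is defined by the name of the class attribute you give the layer.
Let's define a dummy model in PyTorch, called SimpleModel, as follows: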
from torch import nn
class SimpleModel(nn.Module):
def __init__(self):
super().__init__()
self.dense = nn.Linear(10, 10)
self.intermediate = nn.Linear(10, 10)
self.layer_norm = nn.LayerNorm(10)
Now we can create an instance of this model definition, which will fill all weights of dense, intermediate, and layer_norm with random weights. We can print the model to see its architecture:
model = SimpleModel()
print(model)
This will print out the following:
SimpleModel(
(dense): Linear(in_features=10, out_features=10, bias=True)
(intermediate): Linear(in_features=10, out_features=10, bias=True)
(layer_norm): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
)
We can see that the layer names are defined by the name of the class attribute in PyTorch. You can print out the weight values of a specific layer:
print(model.dense.weight.data)
to see that the weights were randomly initialized:
tensor([[-0.0818, 0.2207, -0.0749, -0.0030, 0.0045, -0.1569, -0.1598, 0.0212,
-0.2077, 0.2157],
[ 0.1044, 0.0201, 0.0990, 0.2482, 0.3116, 0.2509, 0.2866, -0.2190,
0.2166, -0.0212],
[-0.2000, 0.1107, -0.1999, -0.3119, 0.1559, 0.0993, 0.1776, -0.1950,
-0.1023, -0.0447],
[-0.0888, -0.1092, 0.2281, 0.0336, 0.1817, -0.0115, 0.2096, 0.1415,
-0.1876, -0.2467],
[ 0.2208, -0.2352, -0.1426, -0.2636, -0.2889, -0.2061, -0.2849, -0.0465,
0.2577, 0.0402],
[ 0.1502, 0.2465, 0.2566, 0.0693, 0.2352, -0.0530, 0.1859, -0.0604,
0.2132, 0.1680],
[ 0.1733, -0.2407, -0.1721, 0.1484, 0.0358, -0.0633, -0.0721, -0.0090,
0.2707, -0.2509],
[-0.1173, 0.1561, 0.2945, 0.0595, -0.1996, 0.2988, -0.0802, 0.0407,
0.1829, -0.1568],
[-0.1164, -0.2228, -0.0403, 0.0428, 0.1339, 0.0047, 0.1967, 0.2923,
0.0333, -0.0536],
[-0.1492, -0.1616, 0.1057, 0.1950, -0.2807, -0.2710, -0.1586, 0.0739,
          0.2220,  0.2358]])
In the conversion script, you should fill those randomly initialized weights with the exact weights of the corresponding layer in the checkpoint, e.g.:
import torch

# retrieve matching layer weights, e.g. by a recursive algorithm
layer_name = "dense"
pretrained_weight = array_of_dense_layer  # NumPy array holding this layer's checkpoint weights
model_pointer = getattr(model, layer_name)
model_pointer.weight.data = torch.from_numpy(pretrained_weight)
You have to verify that each randomly initialized weight of your PyTorch model and its corresponding pretrained checkpoint weight exactly match in both shape and name. To do so, it is necessary to add assert statements for the shape and print out the names of the checkpoint weights. E.g. you should add statements like:
assert (
    model_pointer.weight.shape == pretrained_weight.shape
), f"Pointer shape of random weight {model_pointer.weight.shape} and array shape of checkpoint weight {pretrained_weight.shape} mismatched"
Besides, you should also print out the names of both weights to make sure they match, e.g.:
logger.info(f"Initialize PyTorch weight {layer_name} from {pretrained_weight.name}")
If either the shape or the name doesn't match, you probably assigned the wrong checkpoint weight to a randomly initialized layer of the 🤗 Transformers implementation.
An incorrect shape is most likely due to an incorrect setting of the config parameters in BrandNewBertConfig() that do not exactly match those that were used for the checkpoint you want to convert. However, it could also be that PyTorch's implementation of a layer requires the weight to be transposed beforehand.
Finally, you should also check that all required weights are initialized and print out all checkpoint weights that were not used for initialization, to make sure the model is correctly converted. It is completely normal that the conversion trials fail with either a wrong shape statement or a wrong name assignment. This is most likely because either you used incorrect parameters in BrandNewBertConfig(), have a wrong architecture in the 🤗 Transformers implementation, have a bug in the __init__() function of one of the components of the 🤗 Transformers implementation, or need to transpose one of the checkpoint weights.
This step should be iterated with the previous step until all weights of the checkpoint are correctly loaded in the 🤗 Transformers model.
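The checks described above (shape checks, optional transposition, and reporting unused checkpoint weights) can be combined into one loading loop. The sketch below is stdlib-only and uses nested lists instead of tensors; the checkpoint key names, the explicit name_mapping, and the transpose_keys set are hypothetical. It raises ValueError instead of using assert, so the checks are not skipped when Python runs with -O. The transposition example assumes a TensorFlow-style dense kernel stored as (in_features, out_features), whereas PyTorch's nn.Linear stores its weight as (out_features, in_features).

```python
from typing import Dict, List, Set, Tuple

Matrix = List[List[float]]


def shape(matrix: Matrix) -> Tuple[int, int]:
    return (len(matrix), len(matrix[0]) if matrix else 0)


def transpose(matrix: Matrix) -> Matrix:
    return [list(row) for row in zip(*matrix)]


def convert_checkpoint(
    model_weights: Dict[str, Matrix],
    checkpoint: Dict[str, Matrix],
    name_mapping: Dict[str, str],
    transpose_keys: Set[str],
) -> List[str]:
    """Copy checkpoint arrays into model_weights (in place) and return the unused checkpoint keys."""
    missing = [name for name in model_weights if name not in name_mapping]
    if missing:
        raise ValueError(f"Model weights without a checkpoint counterpart: {missing}")
    used_keys: Set[str] = set()
    for model_name, checkpoint_name in name_mapping.items():
        pretrained_weight = checkpoint[checkpoint_name]
        if checkpoint_name in transpose_keys:
            pretrained_weight = transpose(pretrained_weight)
        if shape(model_weights[model_name]) != shape(pretrained_weight):
            raise ValueError(
                f"Shape mismatch for {model_name}: model {shape(model_weights[model_name])} "
                f"vs checkpoint {checkpoint_name} {shape(pretrained_weight)}"
            )
        print(f"Initialize weight {model_name} from {checkpoint_name}")
        model_weights[model_name] = pretrained_weight
        used_keys.add(checkpoint_name)
    return sorted(set(checkpoint) - used_keys)


model_weights = {"dense.weight": [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]}  # (out=2, in=3)
checkpoint = {
    "bert/dense/kernel": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # (in=3, out=2)
    "global_step": [[100.0]],
}
unused = convert_checkpoint(
    model_weights, checkpoint, {"dense.weight": "bert/dense/kernel"}, transpose_keys={"bert/dense/kernel"}
)
print("Unused checkpoint weights:", unused)
```

Dropping "bert/dense/kernel" from transpose_keys in this example makes the shape check fail, which is exactly the kind of error message that hints at a missing transposition.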
Having correctly loaded the checkpoint into the 🤗 Transformers implementation, you can then save the model under a folder of your choice, /path/to/converted/checkpoint/folder, which should then contain both a pytorch_model.bin file and a config.json file:
model.save_pretrained("/path/to/converted/checkpoint/folder")
7. Implement the forward pass
Having managed to correctly load the pretrained weights into the 🤗 Transformers implementation, you should now make sure that the forward pass is correctly implemented. In steps 3-4, you already created a script that runs a forward pass of the model using the original repository. Now you should write an analogous script using the 🤗 Transformers implementation instead of the original one. It should look as follows:
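import torch
from transformers import BrandNewBertModel

model = BrandNewBertModel.from_pretrained("/path/to/converted/checkpoint/folder")
input_ids = torch.tensor([[0, 4, 4, 3, 2, 4, 1, 7, 19]])  # shape (batch_size=1, sequence_length=9)
output = model(input_ids).last_hidden_state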
It is very likely that the 🤗 Transformers implementation and the original model implementation don't give the exact same output the very first time, or that the forward pass throws an error. Don't be disappointed, it's expected!
First, you should make sure that the forward pass doesn't throw any errors. It often happens that the wrong dimensions are used, leading to a dimensionality mismatch error, or that the wrong data type object is used, e.g. torch.long instead of torch.float32. Don't hesitate to ask the Hugging Face team for help if you don't manage to solve certain errors.
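The final part to make sure the 🤗 Transformers implementation works correctly is to ensure that the outputs are equivalent to a precision of 1e-3. First, you should ensure that the output shapes are identical, i.e. outputs.shape should yield the same value for the script of the 🤗 Transformers implementation and the original implementation. Next, you should make sure that the output values are identical as well. This is one of the most difficult parts of adding a new model.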
Common mistakes why the outputs are not identical are:
- Some layers were not added, i.e. an activation layer was not added, or the residual connection was forgotten
- The word embedding matrix was not tied
- The wrong positional embeddings are used because the original implementation uses an offset
- Dropout is applied during the forward pass. To fix this, make sure model.training is False and that no dropout layer is falsely activated during the forward pass, i.e. pass self.training (the module's training flag) to PyTorch's functional dropout
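The best way to fix the problem is usually to look at the forward pass of the original implementation and the 🤗 Transformers implementation side-by-side and check if there are any differences. Ideally, you should debug/print out intermediate outputs of both implementations of the forward pass to find the exact position in the network where the 🤗 Transformers implementation shows a different output than the original implementation. First, make sure that the hard-coded input_ids in both scripts are identical. Next, verify that the outputs of the first transformation of the input_ids (usually the word embeddings) are identical. Then work your way up to the very last layer of the network. At some point, you will notice a difference between the two implementations, which should point you to the bug in the 🤗 Transformers implementation. From our experience, a simple and efficient way is to add many print statements in both the original implementation and the 🤗 Transformers implementation, at the same positions in the network respectively, and to successively remove print statements showing the same values for intermediate representations.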
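When intermediate values from both implementations are recorded at the same positions, finding the first divergence can be automated instead of compared by eye. A minimal sketch, assuming each trace is a list of (position_name, flat_values) pairs collected in the same order from both implementations; the trace format and the function name first_divergence are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Tuple

Trace = Sequence[Tuple[str, List[float]]]


def first_divergence(original: Trace, ported: Trace, atol: float = 1e-3) -> Optional[str]:
    """Return the name of the first recorded position whose values differ by more than atol, or None."""
    if len(original) != len(ported):
        raise ValueError("Both traces must record the same number of positions")
    for (original_name, original_values), (ported_name, ported_values) in zip(original, ported):
        if original_name != ported_name:
            raise ValueError(f"Traces are out of order: {original_name!r} vs {ported_name!r}")
        if len(original_values) != len(ported_values):
            return original_name  # a shape mismatch counts as a divergence
        for a, b in zip(original_values, ported_values):
            if not math.isclose(a, b, rel_tol=0.0, abs_tol=atol):
                return original_name
    return None


original_trace = [("input_ids", [0.0, 4.0]), ("word_embeddings", [0.12, -0.5]), ("layer_0", [0.3, 0.9])]
ported_trace = [("input_ids", [0.0, 4.0]), ("word_embeddings", [0.1204, -0.5]), ("layer_0", [0.35, 0.9])]
print(first_divergence(original_trace, ported_trace))  # layer_0
```

The small difference in word_embeddings stays within the 1e-3 tolerance, so the first reported divergence is layer_0, which is where you would start reading the code side by side.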
When you're confident that both implementations yield the same output, verify the outputs with torch.allclose(original_output, output, atol=1e-3), and you're done with the most difficult part! Congratulations, the work left to be done should be a cakewalk 😊.
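If you want to reason about the tolerance check without PyTorch at hand, the following stdlib-only sketch computes the maximum absolute difference between two nested float arrays and compares it to atol=1e-3. It is a simplification of torch.allclose, which additionally applies a relative tolerance (rtol); the helper names are hypothetical, and only element counts are compared here, so compare outputs.shape separately.

```python
from typing import List, Sequence, Union

Nested = Union[float, Sequence["Nested"]]


def flatten(values: Nested) -> List[float]:
    """Flatten an arbitrarily nested list of numbers into a flat list of floats."""
    if isinstance(values, (int, float)):
        return [float(values)]
    flat: List[float] = []
    for item in values:
        flat.extend(flatten(item))
    return flat


def max_abs_difference(original: Nested, ported: Nested) -> float:
    original_flat, ported_flat = flatten(original), flatten(ported)
    if len(original_flat) != len(ported_flat):
        raise ValueError(f"Element count mismatch: {len(original_flat)} vs {len(ported_flat)}")
    return max((abs(a - b) for a, b in zip(original_flat, ported_flat)), default=0.0)


def outputs_match(original: Nested, ported: Nested, atol: float = 1e-3) -> bool:
    return max_abs_difference(original, ported) <= atol


original_output = [[-0.1465, -0.6501, 0.1993], [-0.4417, -0.5920, 0.3450]]
ported_output = [[-0.1466, -0.6500, 0.1993], [-0.4417, -0.5921, 0.3451]]
print(outputs_match(original_output, ported_output))
```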
8. Adding all necessary model tests
At this point, you have successfully added a new model. However, it is very much possible that the model does not yet fully comply with the required design. To make sure the implementation is fully compatible with 🤗 Transformers, all common tests should pass. The Cookiecutter should have automatically added a test file for your model, probably under tests/models/brand_new_bert/test_modeling_brand_new_bert.py. Run this test file to verify that all common tests pass:
pytest tests/models/brand_new_bert/test_modeling_brand_new_bert.py
Having fixed all common tests, it is now crucial to ensure that all the nice work you have done is well tested, so that
- a) The community can easily understand your work by looking at specific tests of brand_new_bert
- b) Future changes to your model will not break any important feature of the model.
At first, integration tests should be added. Those integration tests essentially do the same as the debugging scripts you used earlier to implement the model in 🤗 Transformers. A template of those model tests has already been added by the Cookiecutter, called BrandNewBertModelIntegrationTests, and only has to be filled out by you. To ensure that those tests are passing, run:
RUN_SLOW=1 pytest -sv tests/models/brand_new_bert/test_modeling_brand_new_bert.py::BrandNewBertModelIntegrationTests
In case you are using Windows, you should replace RUN_SLOW=1 with SET RUN_SLOW=1.
Second, all features that are special to brand_new_bert should be tested additionally in a separate test under BrandNewBertModelTester/BrandNewBertModelTest. This part is often forgotten but is extremely useful in two ways:
- It helps to transfer the knowledge you have acquired during the model addition to the community by showing how the special features of brand_new_bert should work.
- Future contributors can quickly test changes to the model by running those special tests.
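The slow integration tests mentioned above are gated behind the RUN_SLOW environment variable. In the library itself this is handled by its own test utilities (for example the slow decorator in transformers.testing_utils); the sketch below is a hypothetical stdlib-only stand-in built on unittest.skipUnless, with a toy forward function in place of a real model, so that the shape of such a test is visible: run a forward pass on hard-coded input_ids and compare against hard-coded expected values within 1e-3.

```python
import os
import unittest
from typing import List, Mapping


def slow_tests_enabled(environ: Mapping[str, str]) -> bool:
    """Hypothetical helper: treat RUN_SLOW=1/true/yes (case-insensitive) as enabled."""
    return environ.get("RUN_SLOW", "0").strip().lower() in {"1", "true", "yes"}


def toy_forward(input_ids: List[int]) -> List[float]:
    """Hypothetical stand-in for loading the converted checkpoint and running a forward pass."""
    return [token_id * 0.5 for token_id in input_ids]


class ToyBrandNewBertModelIntegrationTests(unittest.TestCase):
    @unittest.skipUnless(slow_tests_enabled(os.environ), "set RUN_SLOW=1 to run slow tests")
    def test_inference_no_head(self) -> None:
        input_ids = [0, 4, 4, 3]
        # A real test would copy these values from the original implementation's output.
        expected_slice = [0.0, 2.0, 2.0, 1.5]
        output = toy_forward(input_ids)
        for expected, actual in zip(expected_slice, output):
            self.assertAlmostEqual(expected, actual, delta=1e-3)
```

Without RUN_SLOW set, the test is reported as skipped rather than failed, which keeps the regular test suite fast.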
9. Implement the tokenizer
Next, we should add the tokenizer of brand_new_bert. Usually, the tokenizer is equivalent to, or very similar to, an already existing tokenizer of 🤗 Transformers.
To ensure that the tokenizer works correctly, it is recommended to first create a script in the original repository that inputs a string and returns the input_ids. It could look similar to this (in pseudocode):
input_str = "This is a long example input string containing special characters .$?-, numbers 2872 234 12 and words."
model = BrandNewBertModel.load_pretrained_checkpoint("/path/to/checkpoint/")
input_ids = model.tokenize(input_str)
You might have to take a deeper look again into the original repository to find the correct tokenizer function, or you might even have to make changes to your clone of the original repository to only output the input_ids. Having written a functional tokenization script that uses the original repository, an analogous script for 🤗 Transformers should be created. It should look similar to this:
from transformers import BrandNewBertTokenizer
input_str = "This is a long example input string containing special characters .$?-, numbers 2872 234 12 and words."
tokenizer = BrandNewBertTokenizer.from_pretrained("/path/to/tokenizer/folder/")
input_ids = tokenizer(input_str).input_ids
When both input_ids yield the same values, as a final step a tokenizer test file should also be added.
Analogous to the modeling test files of brand_new_bert, the tokenization test files of brand_new_bert should contain a couple of hard-coded integration tests.
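Before freezing expected values into a tokenizer test, it can help to compare the original and the ported tokenizer on a batch of strings and list every disagreement. In this stdlib-only sketch, both tokenizers are hypothetical callables mapping text to a list of ids; in practice, original would wrap the original repository's tokenization function and ported could be lambda text: tokenizer(text).input_ids, as in the script above.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Tokenizer = Callable[[str], List[int]]


def compare_tokenizers(
    original: Tokenizer, ported: Tokenizer, texts: Sequence[str]
) -> Dict[str, Tuple[List[int], List[int]]]:
    """Return {text: (original_ids, ported_ids)} for every text on which the tokenizers disagree."""
    mismatches: Dict[str, Tuple[List[int], List[int]]] = {}
    for text in texts:
        original_ids = original(text)
        ported_ids = ported(text)
        if original_ids != ported_ids:
            mismatches[text] = (original_ids, ported_ids)
    return mismatches


# Hypothetical toy tokenizers: the "ported" one forgets to lowercase its input.
toy_vocab = {"this": 1, "is": 2, "a": 3, "test": 4}


def toy_original(text: str) -> List[int]:
    return [toy_vocab.get(word, 0) for word in text.lower().split()]


def toy_ported(text: str) -> List[int]:
    return [toy_vocab.get(word, 0) for word in text.split()]


print(compare_tokenizers(toy_original, toy_ported, ["this is a test", "This is a test"]))
# {'This is a test': ([1, 2, 3, 4], [0, 2, 3, 4])}
```

Inputs with casing, special characters and numbers (like the example string above) tend to surface such differences early.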
10. Run end-to-end integration tests
Having added the tokenizer, you should also add a couple of end-to-end integration tests using both the model and the tokenizer to tests/models/brand_new_bert/test_modeling_brand_new_bert.py in 🤗 Transformers. Such a test should show on a meaningful text-to-text sample that the 🤗 Transformers implementation works as expected. A meaningful text-to-text sample can include, e.g., a source-to-target translation pair, an article-to-summary pair, or a question-to-answer pair. If none of the ported checkpoints has been fine-tuned on a downstream task, it is enough to simply rely on the model tests.
In a final step to ensure that the model is fully functional, it is advised that you also run all tests on GPU. It can happen that you forgot to add some .to(self.device) statements to internal tensors of the model, which would show up as an error in such a test. In case you have no access to a GPU, the Hugging Face team can take care of running those tests for you.
11. Add documentation
Now, all the necessary functionality for brand_new_bert is added, and you're almost done! The only thing left to add is a nice docstring and a doc page. The Cookiecutter should have added a template file called docs/source/model_doc/brand_new_bert.md that you should fill out. Users of your model will usually first look at this page before using your model. Hence, the documentation must be understandable and concise. It is very useful for the community to add some Tips to show how the model should be used. Don't hesitate to ping the Hugging Face team regarding the documentation.
Next, make sure that the docstring added to src/transformers/models/brand_new_bert/modeling_brand_new_bert.py is correct and includes all necessary inputs and outputs. We have a detailed guide about writing documentation and our docstring format here. It is always good to remind oneself that documentation should be treated at least as carefully as the code, since the documentation is usually the first contact point of the community with the model.
Code refactor
Great, now you have added all the necessary code for brand_new_bert. At this point, you should correct some potential incorrect code style by running:
make style
and verify that your coding style passes the quality check:
make quality
There are a couple of other very strict design tests in 🤗 Transformers that might still be failing. This is often because of some missing information in the docstring or some incorrect naming. The Hugging Face team will surely help you if you're stuck here.
Lastly, it is always a good idea to refactor one's code after having ensured that the code works correctly. With all tests passing, now is a good time to go over the added code again and do some refactoring.
You have now finished the coding part, congratulations! 🎉 You are Awesome! 😎
12. Upload the models to the model hub
In this final part, you should convert and upload all checkpoints to the model hub and add a model card for each uploaded model checkpoint. You can get familiar with the hub functionalities by reading our Model sharing and uploading Page. You should work alongside the Hugging Face team here to get the required access rights to upload the model under the organization of the brand_new_bert authors. The push_to_hub method, present in all models in transformers, is a quick and efficient way to push your checkpoint to the hub. A little snippet is pasted below:
brand_new_bert.push_to_hub("brand_new_bert")
# Uncomment the following line to push to an organization.
# brand_new_bert.push_to_hub("<organization>/brand_new_bert")
It is worth spending some time to create fitting model cards for each checkpoint. The model cards should highlight the specific characteristics of this particular checkpoint, e.g. on which dataset the checkpoint was pretrained/fine-tuned, or on which downstream task the model should be used. They should also include some code on how to correctly use the model.
13. (Optional) Add notebook
It is very helpful to add a notebook that showcases in detail how brand_new_bert can be used for inference or fine-tuned on a downstream task. This is not mandatory to merge your PR, but very useful for the community.
14. Submit your finished PR
You're done programming now and can move to the last step, which is getting your PR merged into main. Usually, the Hugging Face team should have helped you already at this point, but it is worth taking some time to give your finished PR a nice description and eventually add comments to your code, if you want to point out certain design choices to your reviewer.
Share your work!!
Now, it's time to get some credit from the community for your work! Having completed a model addition is a major contribution to Transformers and the whole NLP community. Your code and the ported pretrained models will certainly be used by hundreds and possibly even thousands of developers and researchers. You should be proud of your work and share your achievements with the community.
You have made another model that is super easy to access for everyone in the community! 🤯