EAGLE: ETRI's Advanced-lightweight Generative Language Engine
(๊ณผ๊ฑฐ์ eGPT๋ก ๋ถ๋ ธ์ผ๋ฉฐ, 2024.11.14 ์ ์ด๋ฆ์ ๋ณ๊ฒฝํ์์ต๋๋ค. ์ถํ ๋ฆด๋ฆฌ์ฆ๋๋ ๋ชจ๋ธ์ prefix๋ egpt- ๋์ eagle-๋ก ๋ณ๊ฒฝ๋ฉ๋๋ค.)
๋ณธ ๋ชจ๋ธ์ ์ฌ์ ํ์ต๋ง ์ํ๋ ๋ชจ๋ธ์ด๋ฉฐ, ๋ณ๋์ Instruction Tuning ๋ฑ์ด ์ ์ฉ๋์ง ์์ ๊ธฐ์ด ๋ชจ๋ธ์ ๋๋ค. ์ฑ๋ด ์คํ์ผ์ ์ ์ถ๋ ฅ์ด ํ์ํ ๊ฒฝ์ฐ, ๋ณ๋์ ๋ฏธ์ธ์กฐ์ ์ ๋ฐ๋์ ์ํํด์ผ ํฉ๋๋ค.
๋ชจ๋ธ ์ ๋ณด
1.3B Decoder-only, Causal ์ธ์ด๋ชจ๋ธ. ์ํ, ์ ๋ ์ถ๋ก ์ ๋น๋กฏํ STEM ๋ถ์ผ์ ํนํ๋ ์๊ท๋ชจ ์ธ์ด๋ชจ๋ธ์ ์งํฅํฉ๋๋ค. ๋ฒ์ฉ ์ธ์ด๋ชจ๋ธ์ ์ญํ ์ ๋ชฉํ๋กํ์ง๋ ์๊ธฐ์, ํต์์ ์ดํด ๊ด๋ จ ๋ฒ์ฉ ํ์คํฌ ํ๊ฐ(e.g. hellaswag, sentineg ๋ฑ)์๋ ๋ฎ์ ์ฑ๋ฅ์ด ๋ํ๋ ์ ์์ต๋๋ค. ํ์ต ๋ฐ์ดํฐ ๋ณ๊ฒฝ ๋ฐ ํ์ต ๋ฐฉ๋ฒ ์์ , ๊ฐ์ ์ผ๋ก ์ธํด ๋ณธ ๋ชจ๋ธ์ ๋น์ ๊ธฐ์ ์ผ๋ก ์ ๋ฐ์ดํธ ๋ ์ ์์์ ๋ฏธ๋ฆฌ ์๋ ค๋๋ฆฝ๋๋ค.
Tokenizer๋ LLaMa์ ๊ตฌ์ฑ๊ณผ ์ ์ฌํ๊ฒ byte-fallbacked BPE + digit ๋ถ๋ฆฌ ๊ตฌ์ฑ์ ๊ฐ์ง๋, BOS/EOS(e.g. <s>,</s>
) ํ ํฐ์ด ๋ชจ๋ EOS(</s>
)๋ก ํต์ผ๋์ด ์์ต๋๋ค. ํ ํฌ๋์ด์ ์ค์ ์์ PAD ํ ํฐ์ ๋ณ๋๋ก ์ง์ ๋์ด ์์ง ์์ผ๋, Byte-level BPE์ ํน์ฑ์ <unk>
์ฌ๋ณผ์ด ์ฌ์ฉ๋์ง ์์ผ๋ฏ๋ก, ๋ฏธ์ธ์กฐ์ ๋จ๊ณ์์๋ <unk>
ํ ํฐ์ PAD ํ ํฐ์ผ๋ก ์ง์ ํ์ฌ ํ์ฉํ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
EleutherAI/gptneox ์ํคํ
์ณ๋ก ๊ตฌ์ฑ๋์ด ์์ผ๋ฉฐ, A100 80GB PCIE * 8์ฅ์์ ์ฝ 12์ฃผ๊ฐ ํ์ต(4์ฃผ์ฉ v1, v2, v3 ์ง์ ์ฌ์ ํ์ต; ์ฝ 500B tokens ํ์ต)ํ์ฌ ํ๋๋ ์ฌ์ ํ์ต ๋ชจ๋ธ์
๋๋ค.
ํต์ง์ฌํญ/Acknowledgement
- ์ด ๋ชจ๋ธ์ 2023๋ ๋ ์ ๋ถ(๊ณผํ๊ธฐ์ ์ ๋ณดํต์ ๋ถ)์ ์ฌ์์ผ๋ก ์ ๋ณดํต์ ๊ธฐํํ๊ฐ์์ ์ง์์ ๋ฐ์ ์ํ๋ ์ฐ๊ตฌ์ (RS-2023-00216011, ์ฌ๋์ฒ๋ผ ๊ฐ๋ ์ ์ผ๋ก ์ดํด/์ถ๋ก ์ด ๊ฐ๋ฅํ ๋ณตํฉ์ธ๊ณต์ง๋ฅ ์์ฒ๊ธฐ์ ์ฐ๊ตฌ)
- This work was supported by Institute of Information & Communications Technology Planning & Evaluation(IITP) grant funded by the Korea government(MSIT) (RS-2023-00216011, Development of artificial complex intelligence for conceptually understanding and inferring like human)
์ ํ์ ๋ชจ๋ธ ์ ๊ทผ ๋ฐ, ๋ชจ๋ธ ์ ๊ทผ ํ๊ฐ์ ๊ด๋ จํ ๊ฐ์ธ์ ๋ณด ์์ง ๋ฐ ์ฌ์ฉ ์๋ด/Information on Collection and Use of Personal Information for Gated Model Access
๋ณธ ๋ชจ๋ธ์ ์ฐ๊ตฌ์ ๊ต์ก ๋ชฉ์ ์ผ๋ก๋ง ์ฌ์ฉ ๋ ์ ์์ผ๋ฉฐ, ํ์ฌ ์ ํ์ ๊ณต๊ฐ ์ํ๋ก, ๋ณธ ์ธ์ด ๋ชจ๋ธ์ ๋ค์ด๋ก๋์๋ ๋ด๋น์ ์ฌ์ ์น์ธ์ด ํ์ํฉ๋๋ค. ์ฌ์ ์น์ธ๊ณผ ๊ด๋ จ๋ ๋ฌธ์์ฌํญ์ ๋ณ๋์ ๋ฉ์ผ(jhshin82 at etri.re.kr)๋ก ์์ฒญ ๋ถํ๋๋ฆฝ๋๋ค.
๋ณธ ๋ชจ๋ธ๊ณผ ๊ด๋ จํด ์ฌํ์ , ๋ฒ์ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ๊ฒฝ์ฐ ๋ชจ๋ธ์ ์ฌ์ฉ์ ์ ํํ๊ณ , ๋ฐฐํฌ๋ฅผ ์ฒ ํํ ์ ์์ต๋๋ค. ์ด๋ฅผ ์ํด ๋ชจ๋ธ ์ ๊ทผ ํ๊ฐ์ ์ฌ์ฉ๋ ์ด๋ฉ์ผ ์ฃผ์๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ์์ง, ๋ณด์ , ์ด์ฉํ ์ ์์ต๋๋ค.
๊ฐ์ธ์ ๋ณด ์์ง๋์/Concent to collection of Personal Information
๋ณธ ๋ชจ๋ธ์ ์ฌ์ฉ๊ณผ ๊ด๋ จ, ๋ฐฐํฌ/์ฌ์ฉ ์ ํ/์ฒ ํ, ๊ทธ ์ธ ์ฌ์ฉ์์ ์ด์ต์ ๊ด๊ณ๋ ๋ผ์ด์ ์ค ๋ณ๊ฒฝ ์ ์ด๋ฅผ ํต์งํ๊ธฐ ์ํด, ์๋์ ๊ฐ์ด ๊ฐ์ธ์ ๋ณด๋ฅผ ์์ง, ์ด์ฉํฉ๋๋ค.
์์ง ๋ชฉ์ | ์์ง ํญ๋ชฉ | ๋ณด์ , ์ด์ฉ๊ธฐ๊ฐ |
---|---|---|
๋ชจ๋ธ์ ์ฌ์ฉ์ ํ/์ฒ ํ ์์ฒญ ๋ชฉ์ | ์ด๋ฉ์ผ ์ฃผ์, huggingface hub ID | ๋ณธ ๋ชจ๋ธ์ ๊ณต๊ฐ ๊ธฐ๊ฐ ๋ฐ ์ด์ฉ ๋ชฉ์ ๋ฌ์ฑ ์ |
๋ชจ๋ธ์ ์ฌ์ฉ ๋ผ์ด์ ์ค ๋ฑ ๋ณ๊ฒฝ ์๋ด | ์ด๋ฉ์ผ ์ฃผ์, huggingface hub ID | ๋ณธ ๋ชจ๋ธ์ ๊ณต๊ฐ ๊ธฐ๊ฐ ๋ฐ ์ด์ฉ ๋ชฉ์ ๋ฌ์ฑ ์ |
๋ณธ ๋ชจ๋ธ์ ๋ํ ์ ๊ทผ ์์ฒญ์ ์ํํ๊ณ , ๋ชจ๋ธ์ ์ ๊ทผํ์๋ ํ์๋ ์๋์ ์๋ด๋ ์๋ด์ฌํญ, ๋ณธ ๋ชจ๋ธ์ ํ๊ณ, ์ฑ ์์๋ AI ์ฐ๊ตฌ์ ๋ํ ์ ๋ณด, ๊ฐ์ธ์ ๋ณด ์์ง/์ด์ฉ์ ๋์ํ์ ๊ฒ์ผ๋ก ๊ฐ์ฃผํฉ๋๋ค. ์ฌ์ฉ์๋ ๋์๋ฅผ ๊ฑฐ๋ถํ์ค ๊ถ๋ฆฌ๊ฐ ์์ผ๋ฉฐ, ๋์๋ฅผ ๊ฑฐ๋ถํ์ค ๊ฒฝ์ฐ ๋ชจ๋ธ ์ฌ์ฉ์ด ์ ํ๋๋ฉฐ, ์ด์ ๊ด๋ จํ ์ฌ์ฉ, ๊ฒฐ๊ณผ์ ๋ํ ์ฑ ์์ ์ฌ์ฉ์์๊ฒ ์์์ ์๋ ค๋๋ฆฝ๋๋ค. ์ฌ์ฉ ํ ๋์ ์ฒ ํ, ๊ฐ์ธ์ ๋ณด ํ๊ธฐ์ ๋ํ ์ฌํญ์ ์๊ธฐ ์๋ด๋ ๋ฉ์ผ ์ฃผ์ ๋๋ Community tab์ ํตํด์ ์์ฒญํ์ค ์ ์์ต๋๋ค.
๋ชจ๋ธ์ ํ๊ณ, ์ฑ ์์๋ AI ์ฐ๊ตฌ๋ฅผ ์ํ ๊ด๋ จ ์ ๋ณด ์๋ด
๋ณธ ๋ชจ๋ธ์ ๊ฐ๋ฐ๊ณผ ๊ด๋ จํ ๊ฐ๋ฐ์ ๋ฐ ์กฐ์ง์ ์ฑ ์์๋ AI ์ฐ๊ตฌ๋ฅผ ์ค์ํ๊ณ ์ ๋ ธ๋ ฅํ๊ณ ์์ผ๋ฉฐ, ์ด์ ๊ด๋ จํด AI ์ฐ๊ตฌ์ ์ฌ์ฉ๋๋ ์ ์ถ๋ ฅ ๋ฐ์ดํฐ ๋ด ํฌํจ๋ ์์ค, ์๋, ์ ์น์ ๋ด์ฉ ๋ฐ ๊ธฐํ ๊ฑฐ์น ์ธ์ด์ ๋ํ ์ฒ๋ฆฌ๋ฅผ ์ํํ๊ณ ์ ๋ ธ๋ ฅํ๊ณ ์์ต๋๋ค. ๊ทธ๋ผ์๋ ๋ถ๊ตฌํ๊ณ , ์์ ์น ํ ์คํธ ๋ฐ์ดํฐ์ ํน์ฑ ์ ์ด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํด ํ์ต๋ ๋ณธ ์์ฑ ์ธ์ด ๋ชจ๋ธ์ ๊ฒฝ๋๋ ์ฌ์์ ํฌํจํ๊ฑฐ๋, ์ฌํ์ ์ผ๋ก ์ฉ์ธ๋ ์ ์๋ ํ ์คํธ๋ฅผ ์์ฑํ ์ ์์ผ๋ฉฐ, ๋ค๋ฅธ ์ธ์ด ๋ชจ๋ธ๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก ํน์ ํ๋กฌํํธ์ ๊ณต๊ฒฉ์ ์ธ ์ฝํ ์ธ ๊ฐ ๋ฐํ๋ ์ ์์ต๋๋ค. ์ด๋ฅผ ํฌํจ, ๋ณธ ๋ชจ๋ธ์ ์ถ๋ ฅ/์์ฑ ๊ฒฐ๊ณผ์ ๊ด๋ จํ ๋ด์ฉ์ ๊ฐ๋ฐ์ ๋ฐ ๊ฐ๋ฐ์๊ฐ ์ํ ์กฐ์ง์ ์ฌ์, ์๋์ ์ ํ ๊ด๋ จ์ด ์์์ ์๋ ค๋๋ฆฝ๋๋ค.
ํ ์คํธ์ค์ ๋ฐ์ํ ๋น์ ์์ ์ธ ํน์ ์ฌํ์ ์ผ๋ก ์ฉ์ธ๋์ง ์๋ ํ ์คํธ๊ฐ ์์ฑ๋ ๊ฒฝ์ฐ jhshin82 at etri.re.kr๋ก (__at__์ @๋ก ์นํ) ์ถ๋ ฅ ์ ๋์ ์ฌ์ฉ๋ ์ ๋ ฅ๋ฌธ(ํ๋กฌํํธ), ์ฌ์ฉ๋ ์ํ๋ง ๊ธฐ๋ฒ ๋ฐ ํ์ดํผํ๋ผ๋ฏธํฐ(์: top-p=0.8, temperature, repetition-penalty ๋ฑ), ์ด๋ฅผ ํตํด ์์ฑ๋ ์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ํจ๊ป ๋ณด๋ด์ฃผ์๋ฉด, ์ด๋ฅผ ์ต์ ํ๊ธฐ ์ํ ๋ ธ๋ ฅ์ ๊ธฐ์ธ์ด๋๋ก ํ๊ฒ ์ต๋๋ค.
ํ๊ฐ/Evaluations
์ฌ์ ํ์ต ๋ชจ๋ธ์ KOBEST ํ๊ฐ
ํ๊ฐ๋ EleutherAI/lm-evaluation-harness, polyglot branch ๋ฅผ ์ฌ์ฉํ์ฌ, KoBEST(Kim et al., 2022) ํ๊ฐ์ ์ผ๋ก fine-tuning ์์ด zero-shot, 5-shot ํ ์คํธ๋ฅผ ์ํํ์ต๋๋ค. (lm-evaluation-harness์ KOBEST ํ๊ฐ๋ ๋ฒ์ ์ ๋ฐ๋ผ ๋ค๋ฅด๊ฒ ๋ํ๋ ๋ฌธ์ ๊ฐ ์์ด, ์ต์ lm-evaluation-harness(๋ฒ์ 0.4.2 ์ดํ)๋ฅผ ํตํ ํ๊ฐ๋ฅผ ์๋ ๋ณ๋๋ก ์ ์ํ์์ต๋๋ค.)
Zero-shot ์ฑ๋ฅ | KB-BOOLQ (F1) | KB-COPA (F1) | KB-HELLASWAG (F1) | KB-SENTINEG (F1) | KB-WIC (F1) |
---|---|---|---|---|---|
Polyglot-ko-1.3b | 0.3552ยฑ0.0087 | 0.7196ยฑ0.0142 | 0.4013ยฑ0.0217 | 0.6790ยฑ0.0239 | 0.3276ยฑ0.0064 |
egpt-1.3b (23/07) | 0.4903ยฑ0.0134 | 0.6612ยฑ0.0149 | 0.3925ยฑ0.0217 | 0.3383ยฑ0.0112 | 0.3280ยฑ0.0063 |
egpt-1.3b (23/11) | 0.3969ยฑ0.0112 | 0.6470ยฑ0.0151 | 0.3746ยฑ0.0214 | 0.3350ยฑ0.0111 | 0.3297ยฑ0.0066 |
egpt-1.3b (24/03) | 0.4034ยฑ0.0118 | 0.6438ยฑ0.0152 | 0.4150ยฑ0.0218 | 0.5272ยฑ0.0255 | 0.3294ยฑ0.0066 |
5-shot ์ฑ๋ฅ | KB-BOOLQ (F1) | KB-COPA (F1) | KB-HELLASWAG (F1) | KB-SENTINEG (F1) | KB-WIC (F1) |
---|---|---|---|---|---|
Polyglot-ko-1.3b | 0.4751ยฑ0.0133 | 0.7193ยฑ0.0142 | 0.3984ยฑ0.0218 | 0.6257ยฑ0.0244 | 0.4559ยฑ0.0138 |
egpt-1.3b (23/07) | 0.4829ยฑ0.0133 | 0.6558ยฑ0.0150 | 0.3846ยฑ0.0216 | 0.5715ยฑ0.0249 | 0.5108ยฑ0.0141 |
egpt-1.3b (23/11) | 0.4762ยฑ0.0133 | 0.6499ยฑ0.0151 | 0.3689ยฑ0.0214 | 0.5607ยฑ0.0249 | 0.4776ยฑ0.0141 |
egpt-1.3b (24/03) | 0.4944ยฑ0.0134 | 0.6643ยฑ0.0149 | 0.3862ยฑ0.0216 | 0.5232ยฑ0.0251 | 0.4947ยฑ0.0141 |
LM-Evaluation-Harness 0.4.2 ๋ฒ์ ์ด์(์ดํ LEH 0.4.2dev, commit id b1777c82) ์ผ๋ก ํ๊ฐ ์, KB-SENTINEG๋ ๋ ๋ฎ์ ์ ์๋ฅผ, ๋๋จธ์ง 4๊ฐ ํ๊ฐ ํญ๋ชฉ์ ๋ ๋์ ์ ์๋ก ๋ํ๋ฉ๋๋ค. polyglot branch์ ํ๊ฐ ์ค๋ฅ๊ฐ ์์ ๋ ๊ฒ์ผ๋ก ๋ณด์ฌ ์ต์ ๋ฒ์ ์ ํตํด ํ๊ฐํ๋ ๊ฒ์ด ์ข์ ๊ฒ์ผ๋ก ํ๋จ๋๋, ํ๊ฐ ์ผ๊ด์ฑ์ ์ํด polyglot branch์ ํ๊ฐ ์ ์๋ฅผ ๋ณ๋๋ก ์ ์งํ์์ต๋๋ค.
Zero-shot ์ฑ๋ฅ | KB-BOOLQ (F1) | KB-COPA (F1) | KB-HELLASWAG (F1) | KB-SENTINEG (F1) | KB-WIC (F1) |
---|---|---|---|---|---|
egpt-1.3b (23/11) - LEH v0.4.2dev ํ๊ฐ | 0.4926 | 0.6530 | 0.3933 | 0.3350 | 0.3280 |
egpt-1.3b (24/03) - LEH v0.4.2dev ํ๊ฐ | 0.4391 | 0.6497 | 0.4222 | 0.3733 | 0.3412 |
egpt-1.3b (24/03) - LEH polyglot branch ํ๊ฐ(์ฐธ๊ณ ) | 0.4034 | 0.6438 | 0.4150 | 0.5272 | 0.3294 |
์ ์ดํ์ต ๋ฅ๋ ฅ ํ๊ฐ
MetaMathQA๋ฅผ ํตํ ์์ด GSM8k ํ๊ฐ ์ ์๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค. ํ์ต ํ๊ฒฝ: LR 8e-5, FP16, TF32 ์ฌ์ฉ, ์ ํจ ๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ๋์ผํ๊ฒ 128๋ก ์ค์ (์: GPU=4 * batch size=GPU ๋น 8 * Gradient Accumulation=4). LR Scheduler๋ Cosine Decaying, Warmup ratio 0.03. no weight decay.
๋ชจ๋ธ | GSM8k test | ๋น๊ณ |
---|---|---|
polyglot-ko-1.3b | 0.2160 | |
polyglot-ko-12.8b | 0.3646 | LR์ 5e-5๋ก ์ธํ , Beomi/polyglot-ko-alpaca-12.8b์ hparam์ ๋ณด๊ณ LR์ ๊ฒฐ์ ํจ |
egpt-1.3b (23/11) | 0.4443 | |
egpt-1.3b (24/03) | 0.4147 |
์ ๋ฐ์ดํธ ๊ธฐ๋ก/Update log
- (23/7/27 ๋ชจ๋ธ) ์ด๊ธฐ ๋ชจ๋ธ. Polyglot 1.3b ๋๋น BOOLQ/WIC์์ ๋ ๋์ ์ฑ๋ฅ, ๊ทธ๋ฆฌ๊ณ COPA/HELLASWAG/SENTINEG์์ ์ด์ธ.
- (23/11/22 ๋ชจ๋ธ) ์ ์ฌํ ๋ฐ์ดํฐ ๊ตฌ์ฑ, ์ผ๋ถ ๋ฐ์ดํฐ ์ถ๊ฐํ์ฌ 23/7/27 ๋ชจ๋ธ๋ก ๋ถํฐ ์ถ๊ฐ ์ฌ์ ํ์ตํ ๊ฒ. ์ง์ ๋ชฉํ๋ฅผ ์ํ ๋ค๋ฅธ ํ๊ฐ ์ฒด๊ณ์์ ๋ ๋์ ์ฑ๋ฅ(39 vs 44)์ ๋ณด์ฌ ์ ๋ฐ์ดํธ ํจ
- (24/03/21 ๋ชจ๋ธ) AIHUB ๋ฐ์ดํฐ์ , ํ๊ตญ์ด ์ํค ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํด ์ถ๊ฐ ์ฌ์ ํ์ต
์ฌ์ ํ์ต์ ์ฐธ์ฌํ ๋ฐ์ดํฐ์ ์ ๋ณด/Datasets
์๋์ ํ์ต ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ฌ ํ์ตํ์์ต๋๋ค:
- AIHub ๋ฐ์ดํฐ์ , MRC, RAW, ๋ํ, ๋ฒ์ญ, ์์ฝ
- KISTI ๊ตญ๋ด๋ ผ๋ฌธ EN, KR ๋ฐ์ดํฐ์
- KcBERT v2022.3q ๋ค์ด๋ฒ ๋ด์ค ๋๊ธ ๋ฐ์ดํฐ์
- ๊ตญ๋ฆฝ๊ตญ์ด์ ๋ชจ๋์ ๋ง๋ญ์น(๋ฌธ์ด, ๊ตฌ์ด, ์ ๋ฌธ, ๋น์ถํ๋ฌผ, ๊ตญํํ์๋ก, ์ผ์๋ํ, ์จ๋ผ์ธ๋ํ, ๋ฉ์ ์ ๋ง๋ญ์น)
- ํ๊ตญ์ด ์ํคํผ๋์ด ๋คํ, lovit/ko-wikitext ๋ฐ์ดํฐ์ . 20200920.v3 ๋ฑ korpora ๋ฐ์ดํฐ์ ์ ์ฌ์ ํ์ต์ฉ ๋ง๋ญ์น ์ผ๋ถ
- (์) stack exchange ๋ฐ์ดํฐ์
- (์) OpenWebText2
(์) books3 corpus(๋ผ์ด์ ์ค ๋ฌธ์ ๋ก 2024/03์์ ์ ๊ฑฐ๋จ. removed on v3(2024/03) due to licensing issues)- (์) 2020-09-08-arXiv-extracts
- (์) PUBMED title abstracts 2019
- THUDM/MathGLM Arithmetic Text Corpus (applied from 23/11/22, https://github.com/THUDM/MathGLM)
์ฌ์ฉ ์๋ น/How to use
์๋ ์ฝ๋๋ฅผ ํตํด, transformers>=4.28 ๋ฒ์ ์์ ์ถ๋ก ๊ฐ๋ฅํฉ๋๋ค.
import sys
from transformers import (
AutoTokenizer, AutoModelForCausalLM, GenerationConfig
)
def load_model(mdl_path):
tokenizer = AutoTokenizer.from_pretrained(mdl_path, use_fast=True, legacy=False,)
# device_map ์ธ์๋ฅผ ์ฌ์ฉํ๊ธฐ ์ํด์๋ accelerator ๋ชจ๋ ์ค์น ํ์.
model = AutoModelForCausalLM.from_pretrained(mdl_path, device_map="auto",
torch_dtype="auto")
return tokenizer, model
if __name__ == '__main__':
# FIXME: ๋ชจ๋ธ ๊ฒฝ๋ก ์์ !
tokenizer, model = load_model("../egpt-1.3b-test-230720/")
# print(model.hf_device_map)
# ํ์์ ๋ฐ๋ผ ์๋ ์์ฑ ์ต์
์ ์ ์ด
gen_cfg = GenerationConfig(max_new_tokens=256, min_length=0,
max_time=10.0, do_sample=True,
top_p=0.9, epsilon_cutoff=3e-4,)
print("** Now Ready to input from stdin.")
for aline in sys.stdin:
aline = aline.rstrip("\n\r\t")
input_cond = tokenizer(aline, add_special_tokens=False, return_tensors="pt").to("cuda")
outs = model.generate(**input_cond, generation_config=gen_cfg)
out_str = tokenizer.batch_decode(outs, skip_special_tokens=True,
clean_up_tokenization_spaces=True)
print(">> " + ' '.join(out_str))
- Downloads last month
- 132