Llama3[[llama3]]
import transformers
import torch
model_id = "meta-llama/Meta-Llama-3-8B"
pipeline = transformers.pipeline("text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto")
pipeline("Hey how are you doing today?")
๊ฐ์[[overview]]
๋ผ๋ง3 ๋ชจ๋ธ์ Meta AI ํ์ด ์ ์ํ ๋ฉํ ๋ผ๋ง3 ์๊ฐ: ํ์ฌ๊น์ง ๊ฐ์ฅ ์ ๋ฅํ ๊ณต๊ฐ ๊ฐ๋ฅ LLM์์ ์๊ฐ๋์์ต๋๋ค.
ํด๋น ๋ธ๋ก๊ทธ ํฌ์คํธ์ ์ด๋ก์ ๋๋ค:
์ค๋, ๊ด๋ฒ์ํ ์ฌ์ฉ์ ์ํด ์ด์ฉ ๊ฐ๋ฅํ ๋ผ๋ง์ ์ฐจ์ธ๋ ๋ชจ๋ธ์ธ ๋ฉํ ๋ผ๋ง3์ ์ฒซ ๋ ๋ชจ๋ธ์ ๊ณต์ ํ๊ฒ ๋์ด ๊ธฐ์ฉ๋๋ค. ์ด๋ฒ ์ถ์๋ 8B์ 70B ๋งค๊ฐ๋ณ์๋ฅผ ๊ฐ์ง ์ฌ์ ํ๋ จ ๋ฐ ์ง์ ๋ฏธ์ธ ์กฐ์ ๋ ์ธ์ด ๋ชจ๋ธ์ ํน์ง์ผ๋ก ํ๋ฉฐ, ๊ด๋ฒ์ํ ์ฌ์ฉ ์ฌ๋ก๋ฅผ ์ง์ํ ์ ์์ต๋๋ค. ๋ผ๋ง์ ์ด ์ฐจ์ธ๋ ๋ชจ๋ธ์ ๋ค์ํ ์ฐ์ ๋ฒค์น๋งํฌ์์ ์ต์ฒจ๋จ์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ฉฐ, ๊ฐ์ ๋ ์ถ๋ก ๋ฅ๋ ฅ์ ํฌํจํ ์๋ก์ด ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค. ์ฐ๋ฆฌ๋ ์ด๊ฒ๋ค์ด ๋จ์ฐ์ฝ ํด๋น ํด๋์ค์์ ์ต๊ณ ์ ์คํ ์์ค ๋ชจ๋ธ์ด๋ผ๊ณ ๋ฏฟ์ต๋๋ค. ์ค๋ ๊ฐ๋ฐฉ์ ์ ๊ทผ ๋ฐฉ์์ ์ง์งํ๋ฉฐ, ์ฐ๋ฆฌ๋ ๋ผ๋ง3๋ฅผ ์ปค๋ฎค๋ํฐ ๊ธฐ์ฌ์๋ค์๊ฒ ๋งก๊ธฐ๊ณ ์์ต๋๋ค. ์ ํ๋ฆฌ์ผ์ด์ ์์ ๊ฐ๋ฐ์ ๋๊ตฌ, ํ๊ฐ, ์ถ๋ก ์ต์ ํ ๋ฑ์ ์ด๋ฅด๊ธฐ๊น์ง AI ์คํ ์ ๋ฐ์ ๊ฑธ์น ๋ค์ ํ์ ์ ๋ฌผ๊ฒฐ์ ์ด๋ฐํ๊ธธ ํฌ๋งํฉ๋๋ค. ์ฌ๋ฌ๋ถ์ด ๋ฌด์์ ๋ง๋ค์ง ๊ธฐ๋ํ๋ฉฐ ์ฌ๋ฌ๋ถ์ ํผ๋๋ฐฑ์ ๊ณ ๋ํฉ๋๋ค.
๋ผ๋ง3 ๋ชจ๋ธ์ ๋ชจ๋ ์ฒดํฌํฌ์ธํธ๋ ์ด๊ณณ์์ ํ์ธํ์ธ์. ์๋ณธ ์ฝ๋๋ ์ด๊ณณ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ฌ์ฉ ํ[[usage-tips]]
๋ผ๋ง3 ๋ชจ๋ธ๋ค์ bfloat16๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ จ๋์์ง๋ง, ์๋์ ์ถ๋ก ์ float16์ ์ฌ์ฉํฉ๋๋ค. Hub์ ์
๋ก๋๋ ์ฒดํฌํฌ์ธํธ๋ค์ torch_dtype = 'float16'์ ์ฌ์ฉํ๋๋ฐ, ์ด๋ AutoModel API๊ฐ ์ฒดํฌํฌ์ธํธ๋ฅผ torch.float32์์ torch.float16์ผ๋ก ๋ณํํ๋๋ฐ ์ด์ฉ๋ฉ๋๋ค.
model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ด๊ธฐํํ ๋, ์จ๋ผ์ธ ๊ฐ์ค์น์ dtype๋ torch_dtype="auto"๋ฅผ ์ฌ์ฉํ์ง ์๋ ํ ๋๋ถ๋ถ ๋ฌด๊ดํฉ๋๋ค. ๊ทธ ์ด์ ๋ ๋ชจ๋ธ์ด ๋จผ์ ๋ค์ด๋ก๋๋๊ณ (์จ๋ผ์ธ ์ฒดํฌํฌ์ธํธ์ dtype๋ฅผ ์ฌ์ฉ), ๊ทธ ๋ค์ torch์ dtype์ผ๋ก ๋ณํ๋์ด(torch.float32๊ฐ ๋จ), ๋ง์ง๋ง์ผ๋ก config์ torch_dtype์ด ์ ๊ณต๋ ๊ฒฝ์ฐ ๊ฐ์ค์น๊ฐ ์ฌ์ฉ๋๊ธฐ ๋๋ฌธ์
๋๋ค.
float16์ผ๋ก ๋ชจ๋ธ์ ํ๋ จํ๋ ๊ฒ์ ๊ถ์ฅ๋์ง ์์ผ๋ฉฐ nan์ ์์ฑํ๋ ๊ฒ์ผ๋ก ์๋ ค์ ธ ์์ต๋๋ค. ๋ฐ๋ผ์ ๋ชจ๋ ๋ชจ๋ธ์ bfloat16์ผ๋ก ํ๋ จ๋์ด์ผ ํฉ๋๋ค.
ํ:
๋ผ๋ง3 ๋ชจ๋ธ์ ์ํ ๊ฐ์ค์น๋ ์ด ํผ์ ์ฑ์ฐ๋ฉด์ ์ป์ด์ ธ์ผ ํฉ๋๋ค.
์ํคํ ์ฒ๋ ๋ผ๋ง2์ ์ ํํ ๊ฐ์ต๋๋ค.
ํ ํฌ๋์ด์ ๋ tiktoken (sentencepiece ๊ตฌํ์ ๊ธฐ๋ฐํ ๋ผ๋ง2 ์๋ ๋ค๋ฅด๊ฒ)์ ๊ธฐ๋ฐํ BPE ๋ชจ๋ธ์ ๋๋ค. tiktoken ๊ธฐ๋ฐ ํ ํฌ๋์ด์ ๊ฐ sebtencepiece ๊ธฐ๋ฐ ๋ฐฉ์๊ณผ ๋ค๋ฅธ์ ์ ์ ๋ ฅ ํ ํฐ์ด vocab์ ์ด๋ฏธ ์กด์ฌํ ๋ BPE ๋ณํฉ ๋ฃฐ์ ๋ฌด์ํ๊ณ ์ฑ๊ธ ํ ํฐ์ผ๋ก ํ ํฌ๋์ด์งํ๋ค๋ ์ ์์ ๊ฐ์ฅ ํฐ ์ฐจ์ด๋ฅผ ๋ณด์ ๋๋ค. ์์ธํ ๋งํ๋ฉด
"hugging"์ด vocab์ ์กด์ฌํ๊ณ ๊ธฐ์กด์ ๋ณํฉ์ด ์กด์ฌํ์ง ์์ผ๋ฉด,["hug","ging"]์ฒ๋ผ ๋ ํ ํฐ์ผ๋ก ๋ ์์ ๋จ์์ ๋จ์ด๋ฅผ ๊ฐ์ง๋ ๊ฒ์ด ์๋๋ผ, ํ๋์ ํ ํฐ๋ง์ ์๋์ผ๋ก ๋ฆฌํดํ๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.๊ธฐ๋ณธ ๋ชจ๋ธ์ ํจ๋ฉ ํ ํฐ์ด ์๋ค๋ ๊ฒ์ ์๋ฏธํ๋
pad_id = -1์ ์ฌ์ฉํฉ๋๋ค. ๊ฐ์ ๋ก์ง์ ์ฌ์ฉํ ์ ์์ผ๋tokenizer.add_special_tokens({"pad_token":"<pad>"})๋ฅผ ์ฌ์ฉํ์ฌ ํ ํฐ์ ์ถ๊ฐํ๊ณ ์๋ฒ ๋ฉ ํฌ๊ธฐ๋ ํ์คํ ์กฐ์ ํด์ผ ํฉ๋๋ค.model.config.pad_token_id๋ ์ค์ ์ด ํ์ํฉ๋๋ค. ๋ชจ๋ธ์embed_tokens๋ ์ด์ด๋self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.config.padding_idx)๋ก ์ด๊ธฐํ๋๋ฉฐ, ํจ๋ฉ ํ ํฐ์ ์ธ์ฝ๋ฉํ๋ ๊ฒ์ด 0(zero)๋ฅผ ์ถ๋ ฅํ๊ฒ ํ ๊ฒ์ธ์ง ๊ทธ๋์ ์ด๊ธฐํ๊ฐ ์ถ์ฒ๋ ๋ ์ด๋ฅผ ํตํ์ํฌ ๊ฒ์ธ์ง๋ฅผ ์ ํ๊ฒ ํฉ๋๋ค.์๋ณธ ์ฒดํฌํฌ์ธํธ๋ ์ด ๋ณํ ์คํฌ๋ฆฝํธ๋ฅผ ์ด์ฉํด์ ๋ณํ ๊ฐ๋ฅํฉ๋๋ค. ์คํฌ๋ฆฝํธ๋ ๋ค์ ๋ช ๋ น์ด๋ก ํธ์ถํ ์ ์์ต๋๋ค:
python src/transformers/models/llama/convert_llama_weights_to_hf.py \ --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path --llama_version 3๋ณํ ํ, ๋ชจ๋ธ๊ณผ ํ ํฌ๋์ด์ ๋ ๋ค์์ ํตํด ๋ก๋๋๋ค.
from transformers import AutoModelForCausalLM, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("/output/path") model = AutoModelForCausalLM.from_pretrained("/output/path")์ด ์คํฌ๋ฆฝํธ๋ฅผ ์คํ์ํค๋ ค๋ฉด ๋ชจ๋ธ ์ ์ฒด๋ฅผ float16 ์ ๋ฐ๋๋ก ํธ์คํ ํ ์ ์๋ ์ถฉ๋ถํ ๋ฉ์ธ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ๋ค๋ ์ ์ ์ ์ํ์ธ์. ๊ฐ์ฅ ํฐ ๋ฒ์ ์ด ์ฌ๋ฌ ์ฒดํฌํฌ์ธํธ๋ก ๋๋์ด ์๋๋ผ๋, ๊ฐ ์ฒดํฌํฌ์ธํธ๊ฐ ๋ชจ๋ธ์ ๊ฐ์ค์น ์ผ๋ถ๋ฅผ ํฌํจํ๊ณ ์๊ธฐ ๋๋ฌธ์ ์ด๋ฅผ ๋ชจ๋ RAM์ ๋ก๋ํด์ผ ํฉ๋๋ค. 75B ๋ชจ๋ธ์ ์๋ก ๋ค๋ฉด ๋๋ต 145GB์ RAM์ด ํ์ํฉ๋๋ค.
attn_implementation="flash_attention_2"๋ฅผ ํตํด์ ํ๋์ ์ดํ ์ 2๋ฅผ ์ฌ์ฉํ ๋,from_pretrainedํด๋์ค ๋ฉ์๋์torch_dtype๋ฅผ ์ ๋ฌํ์ง ๋ง๊ณ ์๋ ํผํฉ ์ ๋ฐ๋(Automatic Mixed-Precision) ํ์ต์ ์ฌ์ฉํ์ธ์.Trainer๋ฅผ ์ฌ์ฉํ ๋๋ ๋จ์ํfp16๋๋bf16์True๋ก ์ค์ ํ๋ฉด ๋ฉ๋๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด ๋ฐ๋์torch.autocast๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. ํ๋์ ์ดํ ์ ์fp16๊ณผbf16๋ฐ์ดํฐ ์ ํ๋ง ์ง์ํ๊ธฐ ๋๋ฌธ์ ๋๋ค.
์๋ฃ[[resources]]
๋ผ๋ง2 ๋ฌธ์ ํ์ด์ง์์๋ ์ด๋ฏธ ์ ๋ง์ ๋ฉ์ง๊ณ ์ ์ตํ ์๋ฃ๋ค์ ์ ๊ณตํ๊ณ ์์ต๋๋ค. ์ด๊ณณ์ ๋ผ๋ง3์ ๋ํ ์๋ก์ด ์๋ฃ๋ฅผ ๋ํด์ฃผ์ค ์ปจํธ๋ฆฌ๋ทฐํฐ๋ค์ ์ด๋ํฉ๋๋ค! ๐ค