---
language:
  - en
license: mit
library_name: transformers
tags:
  - code
base_model:
  - google/gemma-1.1-2b-it
datasets:
  - kreimben/leetcode_with_youtube_captions
  - kreimben/leetcode_user_submissions
widget:
  - text: >-
      explain about two sum problem. from brute force approach to the most
      advanced algorithms.
    example_title: two sum example
  - text: explain about leetcode 72 edit distance. i don't get even the approach.
    example_title: edit distance example
  - text: explain about leetcode 139 Word Break. please give me the approach.
    example_title: word break example
inference:
  parameters:
    max_new_tokens: 250
    temperature: 0.3
pipeline_tag: text-generation
---

# CodeMind

## Introduction

CodeMind is a language model that helps users solve coding-test problems and study for them. It was fine-tuned on captions from LeetCode walkthrough videos and on user-submitted write-ups, so it can give answers that are more specialized for coding tests.

๋ชจ๋ธ ์„ธ๋ถ€ ์ •๋ณด

  • ๋ชจ๋ธ ์ด๋ฆ„: CodeMind
  • ๊ธฐ๋ณธ ๋ชจ๋ธ: google/gemma-1.1-2b-it
  • ํ›ˆ๋ จ ์–ธ์–ด: ์˜์–ด
  • ๋ชจ๋ธ ํฌ๊ธฐ: 2.51B ํŒŒ๋ผ๋ฏธํ„ฐ

ํŒ€์› ๊ตฌ์„ฑ

  • NLP 3๋ช…
  • SRE 2๋ช…

## Key Features

- Explains problem types and solution approaches
- Generates correct solution code

ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ

์‚ฌ์šฉ๋œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ

  • transformers: ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ ๋ชจ๋ธ์„ ์œ„ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
  • datasets: ๋ฐ์ดํ„ฐ์…‹ ์ฒ˜๋ฆฌ ๋ฐ ๊ด€๋ฆฌ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
  • bitsandbytes: ์ตœ์ ํ™”๋œ ์—ฐ์‚ฐ์„ ์œ„ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
  • peft: ํŒŒ์ธ ํŠœ๋‹์„ ์œ„ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
  • trl: ์–ธ์–ด ๋ชจ๋ธ ํŠœ๋‹์„ ์œ„ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ
  • pandas: ๋ฐ์ดํ„ฐ ์กฐ์ž‘์„ ์œ„ํ•œ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ

ํŒŒ์ผ ๊ตฌ์กฐ

  • dataset/: ๋ฐ์ดํ„ฐ์…‹ ํŒŒ์ผ์„ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
  • eval/: ํ‰๊ฐ€ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
  • fine-tuning/: fine tuning ๊ด€๋ จ ๋…ธํŠธ๋ถ ๋ฐ ์Šคํฌ๋ฆฝํŠธ๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค.
    • gemma-1.1-2b-it peft qlora.ipynb: fine tuning ๊ณผ์ •์— ๋Œ€ํ•œ ์„ธ๋ถ€ ์‚ฌํ•ญ์ด ํฌํ•จ๋œ ๋…ธํŠธ๋ถ์ž…๋‹ˆ๋‹ค.
  • demo.ipynb: ๋ฐ๋ชจ ๋…ธํŠธ๋ถ์œผ๋กœ ๋ชจ๋ธ ์‚ฌ์šฉ ์˜ˆ์ œ๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
  • requirements.txt: ํ”„๋กœ์ ํŠธ ์˜์กด์„ฑ ๋ชฉ๋ก์ด ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.
  • utils.py: ์œ ํ‹ธ๋ฆฌํ‹ฐ ํ•จ์ˆ˜๋“ค์ด ํฌํ•จ๋˜์–ด ์žˆ์Šต๋‹ˆ๋‹ค.

## Usage

The model is available on the Hugging Face Model Hub and can be integrated into applications through its API. Given a coding problem or a programming-related question, the model generates a relevant explanation, code snippet, or guidance.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("kreimben/CodeMind-gemma-2b")
model = AutoModelForCausalLM.from_pretrained("kreimben/CodeMind-gemma-2b")

inputs = tokenizer("Enter your coding problem or question here", return_tensors="pt")
outputs = model.generate(inputs.input_ids)
print(tokenizer.decode(outputs[0]))
```
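
For chat-style prompting it can help to format the input with Gemma's turn markers, which is the format used during fine-tuning. The sketch below is illustrative rather than part of the original card; it assumes the tokenizer ships Gemma's chat template and reuses the widget's inference parameters (max_new_tokens=250, temperature=0.3).

```python
# Format the question with the Gemma chat template before generating.
chat = [{"role": "user", "content": "explain about leetcode 139 Word Break. please give me the approach."}]
prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=250,   # matches the widget configuration above
    do_sample=True,
    temperature=0.3,
)
# Print only the newly generated tokens.
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True))
```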

## Training Process

### Loading the Model and Tokenizer

```python
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit NF4 quantization so the 2B model fits comfortably on a single GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_id = 'google/gemma-1.1-2b-it'
token = os.getenv('HF_READ')

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"": 0}, token=token)
model.config.use_cache = False  # KV caching is unnecessary during training
model.gradient_checkpointing_enable()

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
```
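
As a quick sanity check (not part of the original notebook), the quantized model's footprint can be inspected before training:

```python
# Rough memory footprint of the 4-bit model in GB.
print(f"{model.get_memory_footprint() / 1e9:.2f} GB")
```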

LoRA ๊ตฌ์„ฑ ๋ฐ ๋ชจ๋ธ ์ค€๋น„

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import bitsandbytes as bnb

model = prepare_model_for_kbit_training(model)

def find_all_linear_names(model):
    """Collect the names of all 4-bit linear modules to target with LoRA."""
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # The output head should not receive LoRA adapters.
    if 'lm_head' in lora_module_names:
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(model)
lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
```
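
A small check, not from the original notebook, to confirm which modules received adapters and how few parameters are actually trainable under this QLoRA setup:

```python
print(modules)                      # the 4-bit linear layers targeted by LoRA
model.print_trainable_parameters()  # LoRA leaves only a small fraction trainable
```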

๋ฐ์ดํ„ฐ ์ค€๋น„

```python
import pandas as pd
import datasets
from datasets import Dataset

# Combine user-submission write-ups and YouTube caption transcripts into one table.
submission_dataset = datasets.load_dataset('kreimben/leetcode_user_submissions_only_python', split='train').to_pandas()
submission_dataset = submission_dataset[['title', 'question_hints', 'question_content', 'content']]
captions_dataset = datasets.load_dataset('kreimben/leetcode_with_youtube_captions', split='train').to_pandas()
captions_dataset = captions_dataset[['title', 'question_hints', 'question_content', 'cc_content']]
captions_dataset.rename(columns={'cc_content': 'content'}, inplace=True)

dataset = pd.concat([submission_dataset, captions_dataset])
del submission_dataset, captions_dataset

dataset = Dataset.from_pandas(dataset)

GEMMA_2B_IT_MODEL_PREFIX_TEXT = "Below is an coding test problem. Solve the question."

# Wrap each example in Gemma's chat format: the user turn asks for the approach,
# the model turn contains the reference explanation.
def generate_prompt(data_point):
    return f"""<bos><start_of_turn>user {GEMMA_2B_IT_MODEL_PREFIX_TEXT}

I don't know {data_point['title']} problem. give me the insight or appoach.

this is problem's hint.
{data_point['question_hints']}

here are some content of question.
{data_point['question_content']}<end_of_turn>
<start_of_turn>model {data_point['content']}<end_of_turn><eos>"""

text_column = [generate_prompt(data_point) for data_point in dataset]
dataset = dataset.add_column("prompt", text_column)
```
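
Before training, it is worth spot-checking one formatted example. This snippet is a sanity check added for illustration, not part of the original script:

```python
# Inspect the first few hundred characters of a formatted training prompt.
print(dataset[0]["prompt"][:500])
```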

### Training

```python
from trl import SFTTrainer
import transformers
import torch

tokenizer.pad_token = tokenizer.eos_token
torch.cuda.empty_cache()

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    # Column created in the data-preparation step. Note: passing dataset_text_field
    # directly matches older trl releases; newer ones expect it via SFTConfig.
    dataset_text_field="prompt",
    peft_config=lora_config,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    args=transformers.TrainingArguments(
        output_dir='out',
        bf16=True,
        max_steps=100,
        warmup_steps=50,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        optim="paged_adamw_8bit",
        logging_steps=20,
        report_to='wandb',
    ),
)

trainer.train()
```
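
The original card stops at training. A minimal follow-up sketch for persisting the result: save the LoRA adapter, then optionally merge it into the base weights for standalone inference. The output paths are placeholders, not from the original project, and reloading the base model requires access to google/gemma-1.1-2b-it.

```python
from peft import AutoPeftModelForCausalLM

# Save only the LoRA adapter weights produced by the trainer.
trainer.save_model("out/codemind-adapter")  # placeholder path

# Reload the adapter on top of the base model in bf16 and merge for standalone use.
merged = AutoPeftModelForCausalLM.from_pretrained(
    "out/codemind-adapter", torch_dtype=torch.bfloat16
).merge_and_unload()
merged.save_pretrained("out/codemind-merged")    # placeholder path
tokenizer.save_pretrained("out/codemind-merged")
```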

## Evaluation

The model's performance was evaluated as follows:

| Metric | Value |
|---|---|
| Average | 41.62 |
| ARC | 41.81 |
| HellaSwag | 59.03 |
| MMLU | 37.26 |
| TruthfulQA | 43.45 |
| Winogrande | 59.91 |
| GSM8K | 8.26 |

์ œํ•œ ์‚ฌํ•ญ ๋ฐ ์œค๋ฆฌ์  ๊ณ ๋ ค์‚ฌํ•ญ

  • ๋ชจ๋ธ์˜ ์ถœ๋ ฅ์€ ํ•™์Šต ๋ฐ์ดํ„ฐ์— ๊ธฐ๋ฐ˜ํ•˜๋ฏ€๋กœ ํ•ญ์ƒ ์ •ํ™•ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  • ์ค‘์š”ํ•œ ๊ฒฐ์ •์ด๋‚˜ ์‹ค์„ธ๊ณ„ ๋ฌธ์ œ ํ•ด๊ฒฐ์— ๋ชจ๋ธ ์ถœ๋ ฅ์„ ์‚ฌ์šฉํ•˜๊ธฐ ์ „์— ๋ฐ˜๋“œ์‹œ ๊ฒ€์ฆ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.