File size: 765 Bytes
387a8a0
 
 
1c1b26d
 
387a8a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import keras_nlp

# Preset name understood by keras_nlp.models.GemmaCausalLM.from_preset.
MODEL_NAME = "gemma2_instruct_2b_en"
# Path to the fine-tuned LoRA adapter weights applied on top of the base model.
LORA_WEIGHT_PATH = "ice_breaking_challenge/models/gemma2_it_2b_icebreaking_quiz_v2_5.lora.h5"

def load_model_with_lora(model_name: str = MODEL_NAME,
                         lora_weight_path: str = LORA_WEIGHT_PATH,
                         lora_rank: int = 4):
    """Load a Keras Gemma model and apply LoRA adapter weights.

    Args:
        model_name (str): Preset name passed to
            ``GemmaCausalLM.from_preset`` (e.g. ``"gemma2_instruct_2b_en"``).
        lora_weight_path (str): Path to the ``.lora.h5`` adapter weight file.
        lora_rank (int): Rank used when enabling LoRA on the backbone.
            NOTE(review): this must match the rank used when the adapter was
            fine-tuned — confirm against the training script.

    Returns:
        keras_nlp.models.GemmaCausalLM: The base model with LoRA weights
        applied to its backbone.
    """
    model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)

    # Keras 3 requires LoRA to be enabled on the model before adapter
    # weights can be loaded; loading into a plain preset raises an error.
    model.backbone.enable_lora(rank=lora_rank)
    model.backbone.load_lora_weights(lora_weight_path)

    return model