Spaces:
Runtime error
Runtime error
File size: 765 Bytes
387a8a0 1c1b26d 387a8a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import os

import keras_nlp
# Name of the keras_nlp preset loaded as the base model by load_model_with_lora().
MODEL_NAME: str = "gemma2_instruct_2b_en"

# LoRA adapter applied on top of the base model. This is a relative path, so it is
# resolved against the process's current working directory at load time.
# Earlier adapter version (no longer the default):
#   "ice_breaking_challenge/models/gemma2_it_2b_icebreaking_quiz_v2_3.lora.h5"
LORA_WEIGHT_PATH: str = "ice_breaking_challenge/models/gemma2_it_2b_icebreaking_quiz_v2_5.lora.h5"
def load_model_with_lora(model_name: str = MODEL_NAME, lora_weight_path: str = LORA_WEIGHT_PATH):
    """
    Load a Keras-based Gemma causal LM preset and apply LoRA weights to it.

    The LoRA weight file is checked before the base model is loaded, so a wrong
    path fails immediately instead of after the (slow) preset load.

    Args:
        model_name (str): Name of the keras_nlp preset to load.
        lora_weight_path (str): Path to the LoRA weight file to apply. Relative
            paths are resolved against the current working directory.

    Returns:
        keras_nlp.models.GemmaCausalLM: The loaded model with the LoRA weights
        applied to its backbone.

    Raises:
        FileNotFoundError: If ``lora_weight_path`` is a local path that does not
            point to an existing file.
    """
    path = os.fspath(lora_weight_path)
    # Validate before from_preset(): loading the base model is expensive, and a
    # relative default path silently depends on the working directory.
    # URI-style paths (containing "://") are not checked here and are left for
    # load_lora_weights() to handle.
    if "://" not in path and not os.path.isfile(path):
        raise FileNotFoundError(
            f"LoRA weight file not found: {path!r} "
            f"(resolved to {os.path.abspath(path)!r})"
        )

    model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
    model.backbone.load_lora_weights(path)
    return model
|