menouar
committed on
Commit
•
09bdd6c
1
Parent(s):
611507d
Use FalconForCausalLM for falcon finetuning instead of AutoModelForCausalLM
Browse files
utils/notebook_generator.py
CHANGED
@@ -2,7 +2,7 @@ from typing import Optional
|
|
2 |
|
3 |
import nbformat as nbf
|
4 |
|
5 |
-
from utils import FTDataSet
|
6 |
|
7 |
|
8 |
def create_install_libraries_cells(cells: list):
|
@@ -130,9 +130,13 @@ def create_model_cells(cells: list, model_id: str, version: str, flash_attention
|
|
130 |
if pad_value is None:
|
131 |
pad_value_str = ""
|
132 |
|
|
|
|
|
|
|
|
|
133 |
code = f"""
|
134 |
import torch
|
135 |
-
from transformers import AutoTokenizer,
|
136 |
from trl import setup_chat_format
|
137 |
|
138 |
# Hugging Face model id
|
@@ -145,7 +149,7 @@ bnb_config = BitsAndBytesConfig(
|
|
145 |
)
|
146 |
|
147 |
# Load model and tokenizer
|
148 |
-
model =
|
149 |
model_id,
|
150 |
device_map="auto",
|
151 |
trust_remote_code=True,
|
|
|
2 |
|
3 |
import nbformat as nbf
|
4 |
|
5 |
+
from utils import FTDataSet, falcon
|
6 |
|
7 |
|
8 |
def create_install_libraries_cells(cells: list):
|
|
|
130 |
if pad_value is None:
|
131 |
pad_value_str = ""
|
132 |
|
133 |
+
auto_model_import = "AutoModelForCausalLM"
|
134 |
+
if model_id == falcon.name:
|
135 |
+
auto_model_import = "FalconForCausalLM"
|
136 |
+
|
137 |
code = f"""
|
138 |
import torch
|
139 |
+
from transformers import AutoTokenizer, {auto_model_import}, BitsAndBytesConfig
|
140 |
from trl import setup_chat_format
|
141 |
|
142 |
# Hugging Face model id
|
|
|
149 |
)
|
150 |
|
151 |
# Load model and tokenizer
|
152 |
+
model = {auto_model_import}.from_pretrained(
|
153 |
model_id,
|
154 |
device_map="auto",
|
155 |
trust_remote_code=True,
|