
How to load the model?

#2
by jvhoffbauer - opened

Loading the model with

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

reward_model = AutoModelForSequenceClassification.from_pretrained(
    "trl-lib/llama-7b-se-rm-peft",
    num_labels=1,
    torch_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained("trl-lib/llama-7b-se-rm-peft")

yields the following error

---------------------------------------------------------------------------
HTTPError                                 Traceback (most recent call last)
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py:259, in hf_raise_for_status(response, endpoint_name)
    258 try:
--> 259     response.raise_for_status()
    260 except HTTPError as e:

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/requests/models.py:1021, in Response.raise_for_status(self)
   1020 if http_error_msg:
-> 1021     raise HTTPError(http_error_msg, response=self)

HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/trl-lib/llama-7b-se-rm-peft/resolve/main/config.json

The above exception was the direct cause of the following exception:

EntryNotFoundError                        Traceback (most recent call last)
File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/utils/hub.py:427, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
    425 try:
    426     # Load from URL or cache if already cached
--> 427     resolved_file = hf_hub_download(
    428         path_or_repo_id,
    429         filename,
    430         subfolder=None if len(subfolder) == 0 else subfolder,
    431         repo_type=repo_type,
    432         revision=revision,
    433         cache_dir=cache_dir,
    434         user_agent=user_agent,
    435         force_download=force_download,
    436         proxies=proxies,
    437         resume_download=resume_download,
    438         token=token,
    439         local_files_only=local_files_only,
    440     )
    441 except GatedRepoError as e:

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:120, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    118     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 120 return fn(*args, **kwargs)

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1195, in hf_hub_download(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, legacy_cache_layout)
   1194 try:
-> 1195     metadata = get_hf_file_metadata(
   1196         url=url,
   1197         token=token,
   1198         proxies=proxies,
   1199         timeout=etag_timeout,
   1200     )
   1201 except EntryNotFoundError as http_error:
   1202     # Cache the non-existence of the file and raise

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:120, in validate_hf_hub_args.<locals>._inner_fn(*args, **kwargs)
    118     kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
--> 120 return fn(*args, **kwargs)

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/file_download.py:1541, in get_hf_file_metadata(url, token, proxies, timeout)
   1532 r = _request_wrapper(
   1533     method="HEAD",
   1534     url=url,
   (...)
   1539     timeout=timeout,
   1540 )
-> 1541 hf_raise_for_status(r)
   1543 # Return

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/huggingface_hub/utils/_errors.py:269, in hf_raise_for_status(response, endpoint_name)
    268     message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}."
--> 269     raise EntryNotFoundError(message, response) from e
    271 elif error_code == "GatedRepo":

EntryNotFoundError: 404 Client Error. (Request ID: Root=1-64e60e7f-642a805a39a142330e405e81)

Entry Not Found for url: https://huggingface.co/trl-lib/llama-7b-se-rm-peft/resolve/main/config.json.

The above exception was the direct cause of the following exception:

OSError                                   Traceback (most recent call last)
Cell In[10], line 1
----> 1 reward_model = AutoModelForSequenceClassification.from_pretrained(
      2     "trl-lib/llama-7b-se-rm-peft", 
      3     num_labels=1, 
      4     torch_dtype=torch.bfloat16
      5 )
      7 tokenizer = AutoTokenizer.from_pretrained("trl-lib/llama-7b-se-rm-peft")

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:479, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    476 if kwargs.get("torch_dtype", None) == "auto":
    477     _ = kwargs.pop("torch_dtype")
--> 479 config, kwargs = AutoConfig.from_pretrained(
    480     pretrained_model_name_or_path,
    481     return_unused_kwargs=True,
    482     trust_remote_code=trust_remote_code,
    483     **hub_kwargs,
    484     **kwargs,
    485 )
    487 # if torch_dtype=auto was passed here, ensure to pass it on
    488 if kwargs_orig.get("torch_dtype", None) == "auto":

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/models/auto/configuration_auto.py:1004, in AutoConfig.from_pretrained(cls, pretrained_model_name_or_path, **kwargs)
   1002 kwargs["name_or_path"] = pretrained_model_name_or_path
   1003 trust_remote_code = kwargs.pop("trust_remote_code", None)
-> 1004 config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
   1005 has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
   1006 has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/configuration_utils.py:620, in PretrainedConfig.get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    618 original_kwargs = copy.deepcopy(kwargs)
    619 # Get config dict associated with the base config file
--> 620 config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
    621 if "_commit_hash" in config_dict:
    622     original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/configuration_utils.py:675, in PretrainedConfig._get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    671 configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
    673 try:
    674     # Load from local folder or from cache or download from model Hub and cache
--> 675     resolved_config_file = cached_file(
    676         pretrained_model_name_or_path,
    677         configuration_file,
    678         cache_dir=cache_dir,
    679         force_download=force_download,
    680         proxies=proxies,
    681         resume_download=resume_download,
    682         local_files_only=local_files_only,
    683         token=token,
    684         user_agent=user_agent,
    685         revision=revision,
    686         subfolder=subfolder,
    687         _commit_hash=commit_hash,
    688     )
    689     commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
    690 except EnvironmentError:
    691     # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
    692     # the original exception.

File /workspaces/reddit_qa/venv/lib/python3.10/site-packages/transformers/utils/hub.py:478, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
    476     if revision is None:
    477         revision = "main"
--> 478     raise EnvironmentError(
    479         f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
    480         f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
    481     ) from e
    482 except HTTPError as err:
    483     # First we try to see if we have a cached version (not up to date):
    484     resolved_file = try_to_load_from_cache(path_or_repo_id, full_filename, cache_dir=cache_dir, revision=revision)

OSError: trl-lib/llama-7b-se-rm-peft does not appear to have a file named config.json. Checkout 'https://huggingface.co/trl-lib/llama-7b-se-rm-peft/main' for available files.
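A quick way to confirm what the 404 is complaining about is to list the files the repo actually hosts (a minimal check, assuming huggingface_hub is installed in the environment):

from huggingface_hub import list_repo_files

# Print every file hosted in the repo; the error above means config.json
# is not among them.
print(list_repo_files("trl-lib/llama-7b-se-rm-peft"))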

This is an adapter-only repo that contains just the LoRA layer parameters, which is why there is no config.json for AutoModelForSequenceClassification to load. You should first merge the PEFT adapter layers into the base model, using:

python examples/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ

You can find merge_peft_adapter.py in the TRL repository. Good luck!
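For reference, here is a minimal sketch of what that merge step does, assuming the adapter was trained as a single-label sequence-classification reward model on top of a LLaMA base checkpoint. The YYY/ZZZ placeholders stand for your base model name and output directory, as in the command above:

import torch
from peft import PeftModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

adapter_name = "trl-lib/llama-7b-se-rm-peft"
base_model_name = "YYY"  # the base checkpoint the adapter was trained on
output_name = "ZZZ"      # where to save the merged reward model

# Load the full base model with a single-label classification head
base_model = AutoModelForSequenceClassification.from_pretrained(
    base_model_name, num_labels=1, torch_dtype=torch.bfloat16
)

# Apply the LoRA adapter on top of the base weights, then fold it into them
model = PeftModel.from_pretrained(base_model, adapter_name)
model = model.merge_and_unload()

# Save the merged model and tokenizer; the output directory now contains a
# config.json and full weights, so from_pretrained works on it as usual
model.save_pretrained(output_name)
AutoTokenizer.from_pretrained(base_model_name).save_pretrained(output_name)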

Thanks a lot!

jvhoffbauer changed discussion status to closed
