Spaces:
No application file
No application file
manuelrobben
committed on
Commit
•
6850fe2
1
Parent(s):
2367848
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- hf_gradio.py +76 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/added_tokens.json +0 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/generation_config.json +0 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/merges.txt +0 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/vocab.json +0 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/blobs/0204ed10c186a4c7c68f55dff8f26087a45898d6 +5 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/blobs/490234a04b8fc9587db08c7dbc7d73f99152f697 +24 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/blobs/ab496f1c3fd79e3c749a9d5414136a2c8e4224f94eecb261970315cdb0f813fe +3 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/blobs/f1860edb10f80bcaf7b023fce47c68a23b724c23 +9 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/blobs/f74dfbfab8f97770a87769c739fb080c21c8bacc +0 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/refs/main +1 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/config.json +24 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/model.safetensors +3 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/special_tokens_map.json +5 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/tokenizer.json +0 -0
- sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/tokenizer_config.json +9 -0
- utils.py +175 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
sbhatt54/models--EleutherAI--pythia-2.8b/blobs/ab496f1c3fd79e3c749a9d5414136a2c8e4224f94eecb261970315cdb0f813fe filter=lfs diff=lfs merge=lfs -text
|
__pycache__/utils.cpython-39.pyc
ADDED
Binary file (7.57 kB). View file
|
|
hf_gradio.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
import transformers
|
3 |
+
from utils import get_local_dir, pad_to_length
|
4 |
+
|
5 |
+
import gradio as gr
|
6 |
+
import torch
|
7 |
+
|
8 |
+
def load_checkpoint(checkpoint_path):
    """Load a causal language model and its tokenizer from one location.

    Args:
        checkpoint_path: Value forwarded unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is loaded first.
    """
    return (
        AutoModelForCausalLM.from_pretrained(checkpoint_path),
        AutoTokenizer.from_pretrained(checkpoint_path),
    )
|
12 |
+
|
13 |
+
import gradio as gr
|
14 |
+
import torch
|
15 |
+
|
16 |
+
|
17 |
+
# Maps each selectable checkpoint name to the file holding its saved policy.
# NOTE(review): these are absolute paths under /home/sbhatt54 and will only
# resolve on a machine that has that directory layout. The 'full_policy' path
# spells "phythia28" where the others use "pythia28" - confirm the directory
# name really is spelled that way.
checkpoint_paths = {
    'full_policy': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_phythia28/LATEST/policy.pt',
    'reference': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/anthropic_dpo_pythia28_2023-08-06_12-12-25_294354/LATEST/policy.pt',
    'all_but_two_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_two_last/LATEST/policy.pt',
    'all_but_three_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_three_last_2023-08-19_06-44-44_597545/LATEST/policy.pt',
    'all_but_last_basic': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_basic_2023-08-19_06-44-55_606332/LATEST/policy.pt',
    'all_but_last': '/home/sbhatt54/direct-preference-optimization/.cache/sbhatt54/all_but_last_2023-08-19_06-45-07_722235/LATEST/policy.pt',
}

# Checkpoint names in the order they are offered to the user.
options = [
    'reference',
    'full_policy',
    'all_but_two_last',
    'all_but_three_last',
    'all_but_last_basic',
    'all_but_last',
]

policy_dtype = torch.float32

# NOTE(review): get_local_dir is annotated as taking a list of prefixes but is
# handed a plain string here - confirm the cache directory it resolves to is
# the intended one.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    'EleutherAI/pythia-2.8b',
    cache_dir=get_local_dir('.cache'),
)

# Base model; the selected checkpoint's weights are loaded into it on demand.
model = transformers.AutoModelForCausalLM.from_pretrained(
    'EleutherAI/pythia-2.8b',
    cache_dir=get_local_dir('.cache'),
    low_cpu_mem_usage=True,
    torch_dtype=policy_dtype,
)

# Generation needs a pad token id; fall back to the end-of-sequence id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
|
38 |
+
|
39 |
+
def load_selected_checkpoint(options):
    """Load the weights of the chosen checkpoint into the shared ``model``.

    Args:
        options: Name of the checkpoint to load; must be a key of
            ``checkpoint_paths``.

    Returns:
        The module-level ``model`` with the checkpoint's ``'state'`` entry
        loaded into it.

    Raises:
        KeyError: If ``options`` is not a known checkpoint name, or the
            checkpoint file has no ``'state'`` entry.
    """
    selected_path = checkpoint_paths[options]

    # The checkpoint file is large, so skip the reload when the shared model
    # already holds these weights. The path is remembered on the function
    # itself; a file replaced on disk under the same path is therefore not
    # picked up until a different checkpoint has been selected in between.
    if getattr(load_selected_checkpoint, '_loaded_path', None) == selected_path:
        return model

    # NOTE(review): torch.load unpickles the file - only point this at
    # checkpoints from a trusted source.
    policy_state_dict = torch.load(selected_path, map_location='cpu')
    model.load_state_dict(policy_state_dict['state'])
    load_selected_checkpoint._loaded_path = selected_path
    return model
|
46 |
+
|
47 |
+
|
48 |
+
def generate_response(prompt, options):
    """Sample a reply to ``prompt`` from the selected checkpoint.

    The prompt is wrapped as ``'\\n\\nHuman: <prompt>\\n\\nAssistant:'`` before
    it is tokenized.

    Args:
        prompt: The user's message.
        options: Name of the checkpoint to generate with (a key of
            ``checkpoint_paths``).

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated ids, i.e. a
        list of decoded strings.
        NOTE(review): the interface declares a "text" output while this hands
        back a list - confirm whether the first element should be returned.
    """
    max_length = 512  # used both as the generation limit and the padding width

    policy = load_selected_checkpoint(options)
    full_prompt = '\n\nHuman: ' + prompt + '\n\nAssistant:'

    # Let the tokenizer build the batched tensors itself instead of converting
    # each field by hand (which also shadowed the builtin ``input``).
    encoded = tokenizer(full_prompt, add_special_tokens=False, return_tensors='pt')

    policy_output = policy.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    policy_output = pad_to_length(policy_output, max_length, tokenizer.pad_token_id)

    return tokenizer.batch_decode(policy_output, skip_special_tokens=True)
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
|
69 |
+
# Web UI: a prompt box plus a checkpoint selector, wired to generate_response.
# Components come from the top-level gradio namespace; the older
# ``gr.inputs.*`` aliases are deprecated and absent from newer releases.
iface = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Dropdown(choices=options, label="Select Checkpoint"),
    ],
    outputs="text",
)
iface.launch(share=True)
|
75 |
+
|
76 |
+
|
sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/added_tokens.json
ADDED
File without changes
|
sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/generation_config.json
ADDED
File without changes
|
sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/merges.txt
ADDED
File without changes
|
sbhatt54/models--EleutherAI--pythia-2.8b/.no_exist/2a259cdd96a4beb1cdf467512e3904197345f6a9/vocab.json
ADDED
File without changes
|
sbhatt54/models--EleutherAI--pythia-2.8b/blobs/0204ed10c186a4c7c68f55dff8f26087a45898d6
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<|endoftext|>",
|
3 |
+
"eos_token": "<|endoftext|>",
|
4 |
+
"unk_token": "<|endoftext|>"
|
5 |
+
}
|
sbhatt54/models--EleutherAI--pythia-2.8b/blobs/490234a04b8fc9587db08c7dbc7d73f99152f697
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"GPTNeoXForCausalLM"
|
4 |
+
],
|
5 |
+
"bos_token_id": 0,
|
6 |
+
"eos_token_id": 0,
|
7 |
+
"hidden_act": "gelu",
|
8 |
+
"hidden_size": 2560,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 10240,
|
11 |
+
"layer_norm_eps": 1e-05,
|
12 |
+
"max_position_embeddings": 2048,
|
13 |
+
"model_type": "gpt_neox",
|
14 |
+
"num_attention_heads": 32,
|
15 |
+
"num_hidden_layers": 32,
|
16 |
+
"rotary_emb_base": 10000,
|
17 |
+
"rotary_pct": 0.25,
|
18 |
+
"tie_word_embeddings": false,
|
19 |
+
"torch_dtype": "float16",
|
20 |
+
"transformers_version": "4.24.0",
|
21 |
+
"use_cache": true,
|
22 |
+
"use_parallel_residual": true,
|
23 |
+
"vocab_size": 50304
|
24 |
+
}
|
sbhatt54/models--EleutherAI--pythia-2.8b/blobs/ab496f1c3fd79e3c749a9d5414136a2c8e4224f94eecb261970315cdb0f813fe
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ab496f1c3fd79e3c749a9d5414136a2c8e4224f94eecb261970315cdb0f813fe
|
3 |
+
size 5684693096
|
sbhatt54/models--EleutherAI--pythia-2.8b/blobs/f1860edb10f80bcaf7b023fce47c68a23b724c23
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"bos_token": "<|endoftext|>",
|
4 |
+
"eos_token": "<|endoftext|>",
|
5 |
+
"name_or_path": "EleutherAI/gpt-neox-20b",
|
6 |
+
"special_tokens_map_file": "/admin/home-hailey/.cache/huggingface/hub/models--EleutherAI--gpt-neox-20b/snapshots/4e49eadb5d14bd22f314ec3f45b69a87b88c7691/special_tokens_map.json",
|
7 |
+
"tokenizer_class": "GPTNeoXTokenizer",
|
8 |
+
"unk_token": "<|endoftext|>"
|
9 |
+
}
|
sbhatt54/models--EleutherAI--pythia-2.8b/blobs/f74dfbfab8f97770a87769c739fb080c21c8bacc
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sbhatt54/models--EleutherAI--pythia-2.8b/refs/main
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
2a259cdd96a4beb1cdf467512e3904197345f6a9
|
sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/config.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"GPTNeoXForCausalLM"
|
4 |
+
],
|
5 |
+
"bos_token_id": 0,
|
6 |
+
"eos_token_id": 0,
|
7 |
+
"hidden_act": "gelu",
|
8 |
+
"hidden_size": 2560,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 10240,
|
11 |
+
"layer_norm_eps": 1e-05,
|
12 |
+
"max_position_embeddings": 2048,
|
13 |
+
"model_type": "gpt_neox",
|
14 |
+
"num_attention_heads": 32,
|
15 |
+
"num_hidden_layers": 32,
|
16 |
+
"rotary_emb_base": 10000,
|
17 |
+
"rotary_pct": 0.25,
|
18 |
+
"tie_word_embeddings": false,
|
19 |
+
"torch_dtype": "float16",
|
20 |
+
"transformers_version": "4.24.0",
|
21 |
+
"use_cache": true,
|
22 |
+
"use_parallel_residual": true,
|
23 |
+
"vocab_size": 50304
|
24 |
+
}
|
sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ab496f1c3fd79e3c749a9d5414136a2c8e4224f94eecb261970315cdb0f813fe
|
3 |
+
size 5684693096
|
sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/special_tokens_map.json
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": "<|endoftext|>",
|
3 |
+
"eos_token": "<|endoftext|>",
|
4 |
+
"unk_token": "<|endoftext|>"
|
5 |
+
}
|
sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sbhatt54/models--EleutherAI--pythia-2.8b/snapshots/2a259cdd96a4beb1cdf467512e3904197345f6a9/tokenizer_config.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_prefix_space": false,
|
3 |
+
"bos_token": "<|endoftext|>",
|
4 |
+
"eos_token": "<|endoftext|>",
|
5 |
+
"name_or_path": "EleutherAI/gpt-neox-20b",
|
6 |
+
"special_tokens_map_file": "/admin/home-hailey/.cache/huggingface/hub/models--EleutherAI--gpt-neox-20b/snapshots/4e49eadb5d14bd22f314ec3f45b69a87b88c7691/special_tokens_map.json",
|
7 |
+
"tokenizer_class": "GPTNeoXTokenizer",
|
8 |
+
"unk_token": "<|endoftext|>"
|
9 |
+
}
|
utils.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import getpass
|
3 |
+
from datetime import datetime
|
4 |
+
import torch
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
import torch.distributed as dist
|
8 |
+
import inspect
|
9 |
+
import importlib.util
|
10 |
+
import socket
|
11 |
+
import os
|
12 |
+
from typing import Dict, Union, Type, List
|
13 |
+
|
14 |
+
|
15 |
+
def get_open_port():
    """Return a TCP port number handed out by the operating system.

    A socket is bound to port 0 on all interfaces so the OS chooses the port.
    The socket is closed again before returning, so the port is not reserved
    for the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(('', 0))
        address = probe.getsockname()
        return address[1]  # (host, port) -> port
|
19 |
+
|
20 |
+
|
21 |
+
def get_remote_file(remote_path, local_path=None):
    """Return a local path for a ``'host:path'`` file, copying it with scp if needed.

    Args:
        remote_path: Location in ``'hostname:path'`` form. Only the first
            ``':'`` separates host from path, so the path may contain colons.
        local_path: Where to place the copy. Defaults to the remote path.

    Returns:
        ``path`` itself when ``hostname`` names this machine (full or short
        hostname); otherwise ``local_path``, which is copied over first unless
        it already exists.

    Raises:
        ValueError: If ``remote_path`` contains no ``':'``.
    """
    import subprocess  # local import: not among this module's top-level imports

    hostname, path = remote_path.split(':', 1)
    local_hostname = socket.gethostname()
    # Short name is everything before the first dot; split() leaves a dot-less
    # hostname intact, where slicing with find() == -1 dropped its last char.
    short_hostname = local_hostname.split('.', 1)[0]
    if hostname in (local_hostname, short_hostname):
        return path

    if local_path is None:
        local_path = path
    if os.path.exists(local_path):
        return local_path

    local_dir = os.path.dirname(local_path)
    if local_dir:  # a bare filename has no directory to create
        os.makedirs(local_dir, exist_ok=True)

    print(f'Copying {hostname}:{path} to {local_path}')
    # Argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed to scp verbatim. The copy stays best-effort,
    # so a failure is reported but does not raise.
    try:
        result = subprocess.run(['scp', remote_path, local_path], check=False)
    except OSError as err:
        print(f'Could not run scp: {err}')
    else:
        if result.returncode != 0:
            print(f'scp exited with status {result.returncode}')
    return local_path
|
38 |
+
|
39 |
+
|
40 |
+
def rank0_print(*args, **kwargs):
    """Forward to ``print`` unless this is a non-zero distributed rank.

    When ``torch.distributed`` has not been initialized the message is always
    printed.
    """
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(*args, **kwargs)
|
44 |
+
|
45 |
+
|
46 |
+
def get_local_dir(prefixes_to_resolve: Union[str, List[str]]) -> str:
    """Return the per-user cache directory path under the first usable prefix.

    Args:
        prefixes_to_resolve: Candidate base directories, in priority order. A
            single string is treated as one prefix; previously it was iterated
            character by character, so ``'.cache'`` resolved against ``'.'``.

    Returns:
        ``'<prefix>/<username>'`` for the first prefix that exists. If none
        exists, the last prefix is created and used. The per-user directory
        itself is not created here.

    Raises:
        ValueError: If no prefix is given.
    """
    if isinstance(prefixes_to_resolve, str):
        prefixes_to_resolve = [prefixes_to_resolve]
    prefixes = list(prefixes_to_resolve)
    if not prefixes:
        raise ValueError("prefixes_to_resolve must contain at least one path")

    user = getpass.getuser()
    for prefix in prefixes:
        if os.path.exists(prefix):
            return f"{prefix}/{user}"

    # Nothing exists yet: fall back to creating the last candidate.
    fallback = prefixes[-1]
    os.makedirs(fallback, exist_ok=True)
    return f"{fallback}/{user}"
|
53 |
+
|
54 |
+
|
55 |
+
def get_local_run_dir(exp_name: str, local_dirs: List[str]) -> str:
    """Create and return a fresh, timestamped output directory for a run.

    The directory is ``<get_local_dir(local_dirs)>/<exp_name>_<timestamp>``,
    where the timestamp is the current local time down to microseconds.
    """
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S_%f")
    base_dir = get_local_dir(local_dirs)
    run_dir = f"{base_dir}/{exp_name}_{stamp}"
    os.makedirs(run_dir, exist_ok=True)
    return run_dir
|
62 |
+
|
63 |
+
|
64 |
+
def slice_and_move_batch_for_device(batch: Dict, rank: int, world_size: int, device: str) -> Dict:
    """Return this rank's contiguous share of ``batch``, with tensors on ``device``.

    The chunk size is the length of the first value floor-divided by
    ``world_size``, so trailing elements that do not fill a whole chunk are
    not assigned to any rank. Non-tensor values are sliced but not moved.
    """
    first_value = tuple(batch.values())[0]
    per_rank = len(first_value) // world_size
    lo = per_rank * rank
    hi = per_rank * (rank + 1)

    result = {}
    for key, value in batch.items():
        piece = value[lo:hi]
        if isinstance(piece, torch.Tensor):
            piece = piece.to(device)
        result[key] = piece
    return result
|
72 |
+
|
73 |
+
|
74 |
+
def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor:
    """Right-pad ``tensor`` along ``dim`` with ``pad_value`` up to ``length``.

    A tensor that is already at least ``length`` long along ``dim`` is
    returned as-is (the same object, not truncated).
    """
    current = tensor.size(dim)
    if current >= length:
        return tensor

    filler_shape = list(tensor.shape)
    filler_shape[dim] = length - current
    filler = pad_value * torch.ones(*filler_shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, filler], dim=dim)
|
81 |
+
|
82 |
+
|
83 |
+
def all_gather_if_needed(values: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Collect ``values`` from every process and join them along dim 0.

    With a single process the input is returned untouched. Otherwise one
    receive buffer per process is allocated via ``.to(rank)``, filled by
    ``dist.all_gather``, and the results are concatenated - or stacked when
    ``values`` is zero-dimensional.
    """
    if world_size == 1:
        return values

    gathered = []
    for _ in range(world_size):
        gathered.append(torch.empty_like(values).to(rank))
    dist.all_gather(gathered, values)

    if values.dim() > 0:
        return torch.cat(gathered, dim=0)
    return torch.stack(gathered, dim=0)
|
92 |
+
|
93 |
+
|
94 |
+
def formatted_dict(d: Dict) -> Dict:
    """Return a copy of ``d`` with plain ``float`` values rendered as ``.5g`` strings.

    Only values whose type is exactly ``float`` are converted; every other
    value is carried over unchanged.
    """
    rendered = {}
    for key, value in d.items():
        rendered[key] = f"{value:.5g}" if type(value) is float else value
    return rendered
|
97 |
+
|
98 |
+
|
99 |
+
def disable_dropout(model: torch.nn.Module):
    """Set the drop probability of every ``torch.nn.Dropout`` in ``model`` to 0.

    The model is modified in place; nothing is returned.
    """
    dropout_layers = (m for m in model.modules() if isinstance(m, torch.nn.Dropout))
    for layer in dropout_layers:
        layer.p = 0
|
104 |
+
|
105 |
+
|
106 |
+
def print_gpu_memory(rank: int = None, message: str = ''):
    """Print allocated memory, in MB, for each CUDA device that has any.

    Does nothing when CUDA is unavailable. Devices reporting zero allocated
    bytes are skipped. ``rank`` and ``message`` only appear in the printed
    line.
    """
    if not torch.cuda.is_available():
        return

    banner = '*' * 40
    for gpu_index in range(torch.cuda.device_count()):
        used_bytes = torch.cuda.memory_allocated(torch.device(f'cuda:{gpu_index}'))
        if used_bytes == 0:
            continue
        print(banner)
        print(f'[{message} rank {rank} ] GPU {gpu_index}: {used_bytes / 1024**2:.2f} MB')
        print(banner)
|
118 |
+
|
119 |
+
|
120 |
+
def get_block_class_from_model(model: torch.nn.Module, block_class_name: str) -> torch.nn.Module:
    """Return the class of the first submodule whose class name matches.

    Submodules are visited in ``model.modules()`` order, which includes the
    model itself.

    Raises:
        ValueError: If no submodule's class is named ``block_class_name``.
    """
    matches = (
        submodule.__class__
        for submodule in model.modules()
        if submodule.__class__.__name__ == block_class_name
    )
    found = next(matches, None)
    if found is None:
        raise ValueError(f"Could not find block class {block_class_name} in model {model}")
    return found
|
126 |
+
|
127 |
+
|
128 |
+
def get_block_class_from_model_class_and_block_name(model_class: Type, block_class_name: str) -> Type:
    """Look up a class by name in the source file that defines ``model_class``.

    The defining ``.py`` file is executed again as a new module object and
    ``block_class_name`` is read from it, so the returned class comes from
    that fresh execution. The file path must exist, end in ``.py`` and
    contain ``'transformers'``; these conditions are enforced with ``assert``.

    Raises:
        AttributeError: If the file defines no ``block_class_name``.
    """
    source_file = inspect.getfile(model_class)
    assert source_file.endswith('.py'), f"Expected a .py file, got {source_file}"
    assert os.path.exists(source_file), f"File {source_file} does not exist"
    assert "transformers" in source_file, f"Expected a transformers model, got {source_file}"

    # Dotted module name: path from the first 'transformers' onward, with
    # '/' turned into '.' and the trailing '.py' removed.
    package_relative = source_file[source_file.find('transformers'):]
    module_name = package_relative.replace('/', '.')[:-3]
    print(f"Searching in file {source_file}, module {module_name} for class {block_class_name}")

    # Execute the file as a new module object.
    module_spec = importlib.util.spec_from_file_location(module_name, source_file)
    loaded_module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded_module)

    block_class = getattr(loaded_module, block_class_name)
    print(f"Found class {block_class} in module {module_name}")
    return block_class
|
146 |
+
|
147 |
+
|
148 |
+
def init_distributed(rank: int, world_size: int, master_addr: str = 'localhost', port: int = 12355, backend: str = 'nccl'):
    """Join the ``torch.distributed`` process group as ``rank``.

    Publishes ``MASTER_ADDR`` and ``MASTER_PORT`` through the environment,
    initializes the process group with ``backend``, then selects CUDA device
    ``rank`` for this process.
    """
    print(rank, 'initializing distributed')
    os.environ.update({"MASTER_ADDR": master_addr, "MASTER_PORT": str(port)})
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
|
154 |
+
|
155 |
+
|
156 |
+
class TemporarilySeededRandom:
    """Context manager that seeds ``random`` and ``numpy.random`` for a block.

    On entry the current state of both generators is saved and both are seeded
    with ``seed``; on exit the saved states are restored, so code after the
    block continues the random sequence it was on before. Only these two
    generators are touched.
    """

    def __init__(self, seed):
        """Remember ``seed``; nothing is changed until the context is entered."""
        self.seed = seed
        self.stored_state = None
        self.stored_np_state = None

    def __enter__(self):
        # Save the current state of both generators.
        self.stored_state = random.getstate()
        self.stored_np_state = np.random.get_state()

        # Seed both generators.
        random.seed(self.seed)
        np.random.seed(self.seed)

        # Return the manager so ``with ... as ctx`` binds it rather than None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore both generators; exceptions from the block propagate.
        random.setstate(self.stored_state)
        np.random.set_state(self.stored_np_state)
|