# Open-Sora: tests/test_t5_shardformer.py
# (uploaded by kadirnar, "Upload 98 files", commit e7d5680, 2.33 kB)
import time
from copy import deepcopy
import colossalai
import torch
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import spawn
from opensora.acceleration.shardformer.policy.t5_encoder import T5EncoderPolicy
from opensora.models.text_encoder.t5 import T5Embedder
def _time_embeddings(embedder, texts, iters):
    """Return wall-clock seconds spent on ``iters`` calls to ``embedder.get_text_embeddings(texts)``.

    CUDA work is launched asynchronously, so the device is synchronized before the
    clock starts and before it stops; otherwise only launch overhead would be
    measured. Synchronization is skipped when CUDA is unavailable.

    Args:
        embedder: any object exposing ``get_text_embeddings(texts)``.
        texts: the batch of prompts passed on every call.
        iters: number of calls to time (0 yields a near-zero duration).

    Returns:
        Elapsed seconds as a float, measured with ``time.perf_counter``.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        embedder.get_text_embeddings(texts)
    if use_cuda:
        torch.cuda.synchronize()
    return time.perf_counter() - start


def run_t5_encoder(rank, world_size, port):
    """Check a ShardFormer-optimized T5 encoder against the stock model, then time both.

    Intended to be run in each process started by ``test_t5_encoder``.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
        world_size: total number of processes, forwarded to ``colossalai.launch``.
        port: rendezvous port on ``localhost``, forwarded to ``colossalai.launch``.

    Raises:
        AssertionError: if embeddings or masks of the two models diverge.
    """
    colossalai.launch({}, rank=rank, world_size=world_size, port=port, host="localhost")

    # Reference embedder from the local checkpoint cache; a deep copy is the one
    # that gets optimized so both can be compared side by side.
    t5_path = "./pretrained_models/t5_ckpts"
    hf_t5 = T5Embedder(device="cuda", local_cache=True, cache_dir=t5_path, torch_dtype=torch.float)
    sf_t5 = deepcopy(hf_t5)

    # Optimize only the copy: no tensor/sequence parallelism, only JIT-fused ops.
    shard_config = ShardConfig(
        tensor_parallel_process_group=None,
        pipeline_stage_manager=None,
        enable_tensor_parallelism=False,
        enable_fused_normalization=False,
        enable_flash_attention=False,
        enable_jit_fused=True,
        enable_sequence_parallelism=False,
        enable_sequence_overlap=False,
    )
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, _ = shard_former.optimize(sf_t5.model, policy=T5EncoderPolicy())
    sf_t5.model = sharded_model

    # Accuracy: outputs must agree on every repetition. These calls also act as
    # warm-up for both models (presumably including JIT compilation) before timing.
    texts = ["Who is the best player in the history of NBA?", "How to study computer science?"]
    for _ in range(5):
        hf_embs, hf_masks = hf_t5.get_text_embeddings(texts)
        sf_embs, sf_masks = sf_t5.get_text_embeddings(texts)
        assert torch.allclose(hf_embs, sf_embs, rtol=1e-4, atol=1e-5), f"{hf_embs} \nvs\n{sf_embs}"
        assert torch.allclose(hf_masks, sf_masks), f"{hf_masks} \nvs\n{sf_masks}"

    # Performance: time the native and the sharded embedder on identical input at
    # the same (fp32) precision, so the two numbers are directly comparable.
    num_iters = 20
    hf_elapsed = _time_embeddings(hf_t5, texts, num_iters)
    sf_elapsed = _time_embeddings(sf_t5, texts, num_iters)
    print(f"[Performance] native: {hf_elapsed:.4f}s, shardformer: {sf_elapsed:.4f}s")
def test_t5_encoder():
    """Run ``run_t5_encoder`` through ColossalAI's ``spawn`` test launcher."""
    worker = run_t5_encoder
    spawn(worker)
if __name__ == "__main__":
    # Allow running this check directly as a script, outside of pytest.
    test_t5_encoder()