import math
import unittest

import torch
from modelscope import Model, Preprocessor
from torch import nn

from swift import LoRAConfig, Swift

class TestMergedLinear(unittest.TestCase):

    def test_swift_lora_forward(self):
        from swift.tuners.lora import MergedLinear

        def reset_parameters(self):
            nn.Linear.reset_parameters(self)
            if hasattr(self, 'lora_A'):
                # Initialize A the same way as the default for nn.Linear,
                # and B to ones (instead of the usual zeros) so the adapter
                # visibly changes the model output.
                nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
                nn.init.ones_(self.lora_B)

        # Patch MergedLinear so the freshly attached adapter is non-trivial.
        MergedLinear.reset_parameters = reset_parameters
        model = Model.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base')
        preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_sentence-similarity_chinese-base')
        inputs = preprocessor('how are you')
        lora_config = LoRAConfig(
            target_modules=['query', 'key', 'value'],
            use_merged_linear=True,
            enable_lora=[True, True, True])
        # Baseline forward pass before any adapter is attached.
        outputs = model(**inputs)
        model = Swift.prepare_model(model, config=lora_config)
        model.eval()
        outputs_lora = model(**inputs)
        model.deactivate_adapter('default')
        outputs_deactivate = model(**inputs)
        model.activate_adapter('default')
        outputs_reactivate = model(**inputs)
        Swift.merge_and_unload(model)
        outputs_merged = model(**inputs)
        # Deactivating the adapter must reproduce the baseline output.
        self.assertTrue(torch.allclose(outputs.logits, outputs_deactivate.logits))
        # An active adapter must change the output.
        self.assertTrue(not torch.allclose(outputs.logits, outputs_lora.logits))
        # Reactivating the adapter restores its output exactly.
        self.assertTrue(torch.allclose(outputs_lora.logits, outputs_reactivate.logits))
        # Merging the adapter weights preserves the output within tolerance.
        self.assertTrue(torch.allclose(outputs_lora.logits, outputs_merged.logits, atol=1e-4))
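
# Standard unittest entry point so the file can also be run directly
# (e.g. `python <this_file>.py`) instead of only via a test runner.
if __name__ == '__main__':
    unittest.main()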