|
import math |
|
import unittest |
|
|
|
import torch |
|
from modelscope import Model, Preprocessor |
|
from torch import nn |
|
|
|
from swift import LoRAConfig, Swift |
|
|
|
|
|
class TestMergedLinear(unittest.TestCase):
    """Exercises the MergedLinear LoRA tuner through a full forward cycle."""

    def test_swift_lora_forward(self):
        """Verify adapter activate/deactivate/merge semantics.

        Checks that: deactivating the adapter restores the base model's
        logits, the active adapter changes them, re-activation reproduces
        the adapted logits, and merge_and_unload yields (approximately)
        the same logits as the active adapter.
        """
        from swift.tuners.lora import MergedLinear

        def _patched_reset_parameters(self):
            # Force lora_B to ones (instead of the usual zero init) so the
            # freshly-attached adapter measurably perturbs the base outputs;
            # otherwise the "adapter changes logits" assertion below could
            # not distinguish LoRA on from LoRA off.
            nn.Linear.reset_parameters(self)
            if hasattr(self, 'lora_A'):
                nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
                nn.init.ones_(self.lora_B)

        MergedLinear.reset_parameters = _patched_reset_parameters

        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
        model = Model.from_pretrained(model_id)
        preprocessor = Preprocessor.from_pretrained(model_id)
        inputs = preprocessor('how are you')

        lora_config = LoRAConfig(
            target_modules=['query', 'key', 'value'],
            use_merged_linear=True,
            enable_lora=[True, True, True])

        # Baseline logits before any tuner is attached.
        outputs = model(**inputs)

        model = Swift.prepare_model(model, config=lora_config)
        model.eval()
        outputs_lora = model(**inputs)

        model.deactivate_adapter('default')
        outputs_deactivate = model(**inputs)

        model.activate_adapter('default')
        outputs_reactivate = model(**inputs)

        Swift.merge_and_unload(model)
        outputs_merged = model(**inputs)

        # Deactivated adapter must reproduce the base model exactly.
        self.assertTrue(torch.allclose(outputs.logits, outputs_deactivate.logits))
        # The (ones-initialized) adapter must actually change the logits.
        self.assertFalse(torch.allclose(outputs.logits, outputs_lora.logits))
        # Re-activation restores the adapted outputs.
        self.assertTrue(torch.allclose(outputs_lora.logits, outputs_reactivate.logits))
        # Merged weights approximate the active adapter (small numeric drift).
        self.assertTrue(torch.allclose(outputs_lora.logits, outputs_merged.logits, atol=1e-4))
|
|