from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.dinov2.configuration_dinov2 import Dinov2Config


class VisionConfig(PretrainedConfig):
    """Vision-tower configuration, built from an experiment config dict."""

    def __init__(
        self,
        **kwargs,
    ):
        super().__init__(**kwargs)

    @staticmethod
    def from_exp_config(vision_config: dict):
        model_type = vision_config["model_type"]

        if model_type in [
            "siglip_vision_model",
            "clip_vision_model",
            "dinov2",
            "sam",
            "raddino",
        ]:
            # Pull the full backbone configuration from the pretrained
            # checkpoint and merge it into the experiment-provided dict.
            config = AutoConfig.from_pretrained(
                vision_config["pretrained_name_or_path"]
            )
            config = config.to_dict()
            vision_config.update(config)
        elif model_type == "xrayclip":
            # Same as above, but keep "xrayclip" as the model type rather than
            # the one reported by the pretrained checkpoint.
            config = AutoConfig.from_pretrained(
                vision_config["pretrained_name_or_path"]
            )
            config = config.to_dict()
            config["model_type"] = "xrayclip"
            vision_config.update(config)
        elif model_type == "biomedclip":
            # No pretrained Hugging Face config to merge; use the dict as-is.
            pass
        elif model_type == "m3ae":
            pass
        else:
            raise NotImplementedError(f"Unsupported vision model type: {model_type}")

        vision_config = VisionConfig(**vision_config)

        return vision_config


class TextConfig(PretrainedConfig):
    """Text-tower configuration, built from an experiment config dict."""

    def __init__(
        self,
        model_type,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type

    @staticmethod
    def from_exp_config(
        text_config: dict,
    ):
        model_type = text_config["model_type"]

        if model_type in [
            "siglip_text_model",
            "clip_text_model",
            "mpnet",
            "biomedclip",
            "bioclinicalmpbert",
        ]:
            text_config = TextConfig(**text_config)
        else:
            raise NotImplementedError(f"Unsupported text model type: {model_type}")

        return text_config


class AlignTransformerConfig(PretrainedConfig):
    """Alignment-transformer configuration, built from an experiment config dict."""

    def __init__(
        self,
        model_type: str = "align_transformer",
        projector_config=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_type = model_type
        self.projector_config = projector_config

    @staticmethod
    def from_exp_config(
        align_transformer_config: dict,
    ):
        projector_config = align_transformer_config.pop("projector_config", None)

        # Fill in any unspecified transformer fields with Dinov2Config defaults;
        # experiment-provided values take precedence in the merge below.
        config = Dinov2Config(**align_transformer_config)
        config = config.to_dict()

        align_transformer_config = AlignTransformerConfig(
            **(config | align_transformer_config),
            projector_config=projector_config,
        )

        return align_transformer_config


class CxrAlignConfig(PretrainedConfig):
    """Composite configuration holding the vision, text, and alignment sub-configs."""

    is_composition = True

    def __init__(
        self,
        vision_config: dict,
        text_config: dict,
        align_transformer_config: dict,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Build each sub-config from its experiment dict.
        self.vision_config = VisionConfig.from_exp_config(vision_config)
        self.text_config = TextConfig.from_exp_config(text_config)
        self.align_transformer_config = AlignTransformerConfig.from_exp_config(
            align_transformer_config
        )

        # Preserve any remaining experiment kwargs on the config.
        self.kwargs = kwargs
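

# Illustrative usage sketch (not part of the original module): builds a
# CxrAlignConfig from hypothetical experiment dicts. The model types come from
# the branches handled above; the field values themselves are made-up examples.
if __name__ == "__main__":
    exp_vision_config = {
        # "biomedclip" is handled without querying the Hugging Face Hub, so no
        # checkpoint download is needed for this sketch.
        "model_type": "biomedclip",
    }
    exp_text_config = {
        "model_type": "bioclinicalmpbert",
    }
    exp_align_transformer_config = {
        # Fields not listed here fall back to Dinov2Config defaults.
        "hidden_size": 768,
        "num_hidden_layers": 4,
        "projector_config": {"hidden_size": 768},
    }

    config = CxrAlignConfig(
        vision_config=exp_vision_config,
        text_config=exp_text_config,
        align_transformer_config=exp_align_transformer_config,
    )
    print(config.vision_config.model_type)  # biomedclip
    print(config.align_transformer_config.projector_config)  # {'hidden_size': 768}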