Sijuade's picture
Update config.py
4f7bd04 verified
raw
history blame contribute delete
No virus
2.8 kB
import peft
import torch
import whisperx
import torch.nn as nn
from transformers import AutoProcessor, AutoTokenizer
from transformers import CLIPVisionModel, AutoModelForCausalLM
class Projections(nn.Module):
    """Map CLIP image features into the phi language-model embedding space.

    A linear layer first maps ``clip_embed``-sized features to ``phi_embed``
    size. It is followed by ``num_projection_layers`` residual MLP blocks
    (Linear -> GELU -> Linear), each added back onto its own input.

    Args:
        clip_embed: Size of the last dimension of the input features.
        phi_embed: Size of the last dimension of the output features.
        num_projection_layers: Number of residual MLP blocks.
        apply_norm: If True, apply ``self.norm`` (a LayerNorm over
            ``phi_embed``) after the input linear layer. Defaults to False,
            which reproduces the previous behaviour. Earlier code called
            ``self.norm(x)`` but discarded the result, so no normalisation was
            ever applied. NOTE(review): the bundled checkpoint may have been
            trained under that behaviour; confirm before enabling.
    """

    def __init__(
        self,
        clip_embed,
        phi_embed,
        num_projection_layers=6,
        apply_norm=False,
    ):
        super().__init__()
        self.apply_norm = apply_norm
        # Registered even when unused so the state-dict keys ("norm.weight",
        # "norm.bias") keep matching existing checkpoints.
        self.norm = nn.LayerNorm(phi_embed)
        self.output = nn.Linear(clip_embed, phi_embed)
        self.projection_layers = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(phi_embed, phi_embed),
                    nn.GELU(),
                    nn.Linear(phi_embed, phi_embed),
                )
                for _ in range(num_projection_layers)
            ]
        )

    def forward(self, x):
        """Project ``x`` (last dim ``clip_embed``) to last dim ``phi_embed``."""
        x = self.output(x)
        if self.apply_norm:
            x = self.norm(x)
        for layer in self.projection_layers:
            # Residual connection around each MLP block.
            x = layer(x) + x
        return x
def load_projection_model(path, clip_embed, phi_embed, *, map_location="cpu"):
    """Load a Projections model from a checkpoint and return it with weights loaded.

    The checkpoint must be a mapping with a ``'state_dict'`` entry. A leading
    ``'projection.'`` prefix is stripped from each key so the keys match the
    attribute names of ``Projections``.

    Args:
        path (str): Path to the checkpoint file.
        clip_embed (int): Input feature size passed to ``Projections``.
        phi_embed (int): Output feature size passed to ``Projections``.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint saved from a GPU can also be loaded on a CPU-only
            machine. The model itself is constructed on CPU in either case,
            so the returned module is unaffected.

    Returns:
        torch.nn.Module: The loaded Projections model instance.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=map_location)
    prefix = "projection."
    # Strip the prefix only at the start of a key, not anywhere inside it.
    new_state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }
    model = Projections(clip_embed, phi_embed)
    model.load_state_dict(new_state_dict)
    return model
class Config:
    """Shared token ids, model names, and pre-loaded models.

    NOTE: every model below is loaded while this class body executes, i.e.
    when this module is imported.
    """

    # End-of-sequence token id. Presumably phi-2's "<|endoftext|>"; verify
    # against the tokenizer.
    EOS_TOKEN_ID = 50256
    QUESTION_ANSWER_SEPARATOR_ID = 50295  # Special token ID for question-answer separation
    # NOTE(review): token ids marking the image segment of the prompt; the
    # meaning is inferred from the name only, so confirm with the
    # prompt-building code.
    IMAGE_SEPARATOR_TOKENS = [685, 36259, 14041, 60, 220]

    phi_model_name = "microsoft/phi-2"
    model_name = "openai/clip-vit-base-patch32"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # whisperx's CTranslate2 backend rejects float16 on CPU. Fall back to
    # int8 there so importing this module does not fail on CPU-only hosts.
    audio_compute_type = "float16" if device.type == "cuda" else "int8"

    processor = AutoProcessor.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
    # 768 and 2560 are the clip_embed and phi_embed sizes of the saved
    # projection checkpoint.
    projection = load_projection_model(
        "models/MModalGPT-FINETUNE-continued-step=10100-loss=1.16.ckpt", 768, 2560
    )
    clip_model = CLIPVisionModel.from_pretrained(model_name)
    audio_model = whisperx.load_model("small", device.type, compute_type=audio_compute_type)
    # device_map is intentionally not passed (it was previously commented
    # out). NOTE(review): confirm callers move this model to `device`.
    # float16 weights on CPU may also be slow or unsupported for some ops.
    text_model = AutoModelForCausalLM.from_pretrained(
        phi_model_name,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        return_dict=True,
        trust_remote_code=True,
    )
    peft_model = peft.PeftModel.from_pretrained(text_model, 'models/10100')