Spaces:
Sleeping
Sleeping
File size: 934 Bytes
d3cd5c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import os

import torch

from moondream.hf import Moondream
from moondream.hf.configuration_moondream import MoondreamConfig
# Hub repository that receives the averaged, float16-cast model.
OUT_MODEL = "vikhyatk/moondream-next"

# Directories whose checkpoint files are averaged together; each one must
# contain the *.final.pt files loaded further down. Left empty here, so it has
# to be filled in before running -- otherwise there is nothing to average.
CKPT_DIRS = []

# Register the custom config/model classes with the Auto* machinery,
# presumably so the pushed repo can be loaded through AutoModelForCausalLM
# (TODO confirm the intended loading path, e.g. trust_remote_code).
MoondreamConfig.register_for_auto_class()
Moondream.register_for_auto_class("AutoModelForCausalLM")
def get_ckpt(filename, dirs=None):
    """Load ``filename`` from each checkpoint directory and return the average.

    Checkpoints are loaded one at a time and folded into a running sum, so peak
    memory is roughly one checkpoint plus the accumulator instead of every
    checkpoint held simultaneously.

    Args:
        filename: Name of the checkpoint file to read inside each directory.
        dirs: Iterable of directories to read from. Defaults to the
            module-level ``CKPT_DIRS`` when ``None``.

    Returns:
        A dict with the first checkpoint's keys (in its order), each mapped to
        the element-wise uniform mean of that entry across all checkpoints.
        float16/bfloat16 tensors are summed in float32 and cast back to their
        original dtype; other floating tensors keep their dtype; non-floating
        entries are averaged with true division (so integer and bool tensors
        come back as floating point).

    Raises:
        ValueError: if no directories are given, or if a checkpoint's keys
            differ from the first checkpoint's keys.
    """
    if dirs is None:
        dirs = CKPT_DIRS
    dirs = list(dirs)
    if not dirs:
        raise ValueError(
            f"cannot average {filename!r}: no checkpoint directories given "
            "(is CKPT_DIRS empty?)"
        )

    sums = {}
    dtypes = {}
    for index, ckpt_dir in enumerate(dirs):
        path = os.path.join(ckpt_dir, filename)
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        ckpt = torch.load(path, map_location="cpu")
        if index == 0:
            for key, value in ckpt.items():
                is_tensor = isinstance(value, torch.Tensor)
                # Remember the original dtype so floating averages can be cast back.
                dtypes[key] = value.dtype if is_tensor else None
                # Summing several half-precision tensors in their own dtype
                # loses precision, so accumulate those in float32 instead.
                if is_tensor and value.dtype in (torch.float16, torch.bfloat16):
                    value = value.float()
                # Start from 0 like the built-in ``sum`` did, so integer and
                # bool entries are promoted exactly as before.
                sums[key] = 0 + value
        else:
            if ckpt.keys() != sums.keys():
                missing = sorted(sums.keys() - ckpt.keys(), key=str)
                unexpected = sorted(ckpt.keys() - sums.keys(), key=str)
                raise ValueError(
                    f"{path}: keys differ from the first checkpoint "
                    f"(missing {missing}, unexpected {unexpected})"
                )
            for key in sums:
                sums[key] = sums[key] + ckpt[key]
        # Release this checkpoint before the next one is loaded.
        del ckpt

    count = len(dirs)
    averaged = {}
    for key, total in sums.items():
        mean = total / count
        dtype = dtypes[key]
        if dtype is not None and dtype.is_floating_point:
            mean = mean.to(dtype)
        averaged[key] = mean
    return averaged
# Assemble a new model instance from the default configuration.
config = MoondreamConfig()
model = Moondream(config)

# Pair each sub-module with the checkpoint file whose averaged weights it
# receives; get_ckpt averages that file across every directory in CKPT_DIRS.
for submodule, ckpt_name in (
    (model.vision_encoder.encoder, "vision_encoder.final.pt"),
    (model.vision_encoder.projection, "vision_projection.final.pt"),
    (model.text_model, "text_model.final.pt"),
    (model.region_model, "region_model.final.pt"),
):
    submodule.load_state_dict(get_ckpt(ckpt_name))

# Cast the model's floating-point parameters and buffers to float16 for upload.
model = model.to(dtype=torch.float16)
# NOTE(review): confirm push_to_hub actually uses `config`; transformers'
# PreTrainedModel.push_to_hub may silently ignore unknown keyword arguments.
model.push_to_hub(OUT_MODEL, config=config)
|