File size: 934 Bytes
d3cd5c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch

from moondream.hf import Moondream
from moondream.hf.configuration_moondream import MoondreamConfig

# Hugging Face Hub repository that the averaged model is pushed to.
OUT_MODEL = "vikhyatk/moondream-next"

# Directories whose checkpoint files are averaged by get_ckpt(). Left empty
# here and must be filled in before running: with no directories there is
# nothing to average and loading fails.
CKPT_DIRS = []

# Register the project's custom config/model classes with the transformers
# Auto* machinery (the model under "AutoModelForCausalLM"). This presumably
# lets the pushed repo be loaded through the Auto classes with remote code.
# Confirm against the transformers custom-model documentation.
MoondreamConfig.register_for_auto_class()
Moondream.register_for_auto_class("AutoModelForCausalLM")


def get_ckpt(filename, ckpt_dirs=None):
    """Return the element-wise mean of the state dict ``filename`` across runs.

    Every directory in ``ckpt_dirs`` (default: the module-level ``CKPT_DIRS``)
    must contain a file called ``filename`` that ``torch.load`` can read as a
    mapping of names to tensors.

    Checkpoints are loaded onto the CPU one at a time and added into a running
    accumulator. Peak memory is therefore one checkpoint plus the accumulator,
    rather than every checkpoint at once.

    Args:
        filename: Name of the state-dict file inside each checkpoint directory.
        ckpt_dirs: Optional iterable of directories to average over. ``None``
            falls back to ``CKPT_DIRS``.

    Returns:
        dict mapping each key to its averaged tensor. Floating-point tensors
        keep their original dtype. Non-floating tensors come back as floating
        point, because averaging is a true division.

    Raises:
        ValueError: if there are no checkpoint directories to average.
        KeyError: if the checkpoints do not all share the same set of keys.
    """
    dirs = list(CKPT_DIRS if ckpt_dirs is None else ckpt_dirs)
    if not dirs:
        raise ValueError(
            f"No checkpoint directories to average for {filename!r}; populate CKPT_DIRS."
        )

    sums = {}
    out_dtypes = {}
    for index, ckpt_dir in enumerate(dirs):
        path = f"{ckpt_dir}/{filename}"
        # NOTE(review): torch.load unpickles arbitrary objects. Only point this
        # at trusted checkpoints, or consider weights_only=True if the
        # installed torch version supports it.
        state = torch.load(path, map_location="cpu")
        if index == 0:
            for key, tensor in state.items():
                # Accumulate in at least float32 so fp16/bf16 weights are not
                # rounded after every addition. copy=True matters: tied weights
                # can load as the same tensor under two keys, and in-place
                # accumulation must not alias them.
                acc_dtype = torch.promote_types(tensor.dtype, torch.float32)
                sums[key] = tensor.to(dtype=acc_dtype, copy=True)
                out_dtypes[key] = tensor.dtype if tensor.is_floating_point() else acc_dtype
        else:
            if state.keys() != sums.keys():
                mismatched = sorted(set(state) ^ set(sums))
                raise KeyError(f"{path} has mismatched state-dict keys: {mismatched}")
            for key, tensor in state.items():
                sums[key] += tensor.to(dtype=sums[key].dtype)
        del state  # release this checkpoint before loading the next one

    count = len(dirs)
    return {key: (total / count).to(dtype=out_dtypes[key]) for key, total in sums.items()}


# Build the model from a default MoondreamConfig, then restore each
# sub-module from the averaged checkpoint stored under its own file name.
config = MoondreamConfig()
model = Moondream(config)
for _submodule, _ckpt_name in (
    (model.vision_encoder.encoder, "vision_encoder.final.pt"),
    (model.vision_encoder.projection, "vision_projection.final.pt"),
    (model.text_model, "text_model.final.pt"),
    (model.region_model, "region_model.final.pt"),
):
    _submodule.load_state_dict(get_ckpt(_ckpt_name))
# Keep the module namespace identical to the unrolled version.
del _submodule, _ckpt_name

# Convert the model to float16 before uploading.
model = model.to(dtype=torch.float16)

model.push_to_hub(OUT_MODEL, config=config)