K00B404 committed on
Commit
7721cb1
·
verified ·
1 Parent(s): 86fc7b8

Update gguf_loader.py

Browse files
Files changed (1) hide show
  1. gguf_loader.py +144 -0
gguf_loader.py CHANGED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional, Union, Dict, Any
5
+
6
class GGUFUNetLoader:
    """
    Load and manage a GGUF-formatted UNet model and the patches applied to it.

    The loader keeps three pieces of state:

    * ``model``   - the loaded model object (``None`` until ``load_model`` succeeds)
    * ``patches`` - mapping of parameter key -> list of patches to add to that weight
    * ``backup``  - mapping of parameter key -> copy of the unpatched weight, kept on
      ``offload_device``

    A tensor is treated as "quantized" when it carries a ``patches`` attribute; for
    such tensors patches are attached lazily instead of being added immediately.
    """

    def __init__(self):
        self.model = None
        self.patches = {}
        self.backup = {}
        self.load_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.offload_device = "cpu"

    @staticmethod
    def is_quantized(weight: torch.Tensor) -> bool:
        """Return True when ``weight`` carries a ``patches`` attribute."""
        return hasattr(weight, "patches")

    @staticmethod
    def _apply_patches(target: torch.Tensor, patches: list) -> torch.Tensor:
        """Add every patch to ``target`` in place and return ``target``."""
        for patch in patches:
            # A patch registered on another device would make the in-place add
            # fail, so bring it to the target's device first.
            if isinstance(patch, torch.Tensor) and patch.device != target.device:
                patch = patch.to(target.device)
            target += patch
        return target

    def patch_weight(self, key: str, weight: torch.Tensor, device_to: Optional[str] = None) -> torch.Tensor:
        """
        Apply the patches registered for ``key`` to ``weight``.

        Args:
            key: The parameter key to patch
            weight: The weight tensor to patch
            device_to: Target device for the patched weight

        Returns:
            ``weight`` itself when no patches are registered for ``key``.
            For quantized weights: the weight moved to ``device_to`` (or
            ``load_device``) with a deferred ``(calculate_weight, patches, key)``
            entry stored in its ``patches`` attribute.
            For regular weights: a new tensor in the weight's original dtype with
            all patches added; the input tensor is left untouched.
        """
        if key not in self.patches:
            return weight

        if self.is_quantized(weight):
            # Quantized weights are not modified here; the patch computation is
            # attached to the tensor for later evaluation.
            out_weight = weight.to(device_to if device_to else self.load_device)
            out_weight.patches = [(self.calculate_weight, self.patches[key], key)]
            return out_weight

        if key not in self.backup:
            # copy=True: Tensor.to returns the *same* tensor when it already lives
            # on the requested device, which would make the backup an alias of the
            # live weight instead of a snapshot.
            self.backup[key] = weight.to(device=self.offload_device, copy=True)

        # copy=True for the same reason: without it a float32 weight would be
        # returned as-is and the in-place adds below would mutate the caller's
        # tensor (and any aliased backup).
        if device_to:
            patched = weight.to(device=device_to, dtype=torch.float32, copy=True)
        else:
            patched = weight.to(dtype=torch.float32, copy=True)

        self._apply_patches(patched, self.patches[key])
        return patched.to(weight.dtype)

    def load_model(self,
                   model_path: Union[str, Path],
                   config: Optional[Dict[str, Any]] = None) -> None:
        """
        Load a GGUF model from disk into ``self.model``.

        Args:
            model_path: Path to the GGUF model file
            config: Optional configuration dictionary for model loading

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            ValueError: If ``model_path`` does not end with ``.gguf``.
            Exception: Anything raised by the underlying loader is logged and
                re-raised unchanged.
        """
        try:
            model_path = Path(model_path)
            if not model_path.exists():
                raise FileNotFoundError(f"Model file not found: {model_path}")

            if not str(model_path).endswith('.gguf'):
                raise ValueError("Not a GGUF model file")

            # NOTE(review): this imports ``load_gguf_model`` from a module named
            # ``gguf_loader`` - the same name as this file - and no such function
            # is defined here. It must be provided before this method can succeed.
            from .gguf_loader import load_gguf_model  # You'd need to implement this
            self.model = load_gguf_model(
                model_path,
                device=self.load_device,
                config=config or {},
            )

            logging.info("Successfully loaded GGUF model from %s", model_path)

        except Exception as e:
            logging.error("Error loading model: %s", e)
            raise

    def add_patch(self, key: str, patch: torch.Tensor) -> None:
        """
        Register a patch for a specific model parameter.

        Args:
            key: Parameter key to patch
            patch: The patch tensor to apply
        """
        self.patches.setdefault(key, []).append(patch)

    def clear_patches(self) -> None:
        """Remove all registered patches and reset ``patches`` on quantized parameters."""
        self.patches.clear()

        # Compare against None explicitly: container modules define __len__, so an
        # empty one would be falsy and silently skipped by a truthiness check.
        if self.model is not None:
            for param in self.model.parameters():
                if self.is_quantized(param):
                    param.patches = []

    def to(self, device: str) -> 'GGUFUNetLoader':
        """
        Move the model (if one is loaded) to ``device`` and remember it as ``load_device``.

        Args:
            device: Target device ("cuda", "cpu", etc.)

        Returns:
            Self for method chaining
        """
        if self.model is not None:
            self.model.to(device)
        self.load_device = device
        return self

    @staticmethod
    def calculate_weight(patches: list, base_weight: torch.Tensor, key: str) -> torch.Tensor:
        """
        Calculate the final weight by adding every patch to a copy of the base weight.

        Args:
            patches: List of patches to apply
            base_weight: Base weight tensor (not modified)
            key: Parameter key (unused here; kept for the callback signature)

        Returns:
            Patched weight tensor
        """
        return GGUFUNetLoader._apply_patches(base_weight.clone(), patches)