Commit f5ae994 by ga89tiy
Parent(s): 9de5cd4

last update

Files changed:
- LLAVA_Biovil/biovil_t/transformer.py +1 -7
- LLAVA_Biovil/llava/model/builder.py +0 -13
- LLAVA_Biovil/llava/model/language_model/llava_llama.py +5 -8
- LLAVA_Biovil/llava/model/llava_arch.py +9 -111
- __pycache__/utils.cpython-310.pyc +0 -0
- example_code.py +3 -25
- findings_classifier/__pycache__/chexpert_train.cpython-310.pyc +0 -0
LLAVA_Biovil/biovil_t/transformer.py
CHANGED
@@ -93,17 +93,11 @@ def forward(self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor]

     def forward_after_reshape(self,
                               x: torch.Tensor,
-                              pos_embed: torch.Tensor,
-                              x_previous: Optional[torch.Tensor] = None) -> torch.Tensor:
+                              pos_embed: torch.Tensor) -> torch.Tensor:
         B, L, _ = x.shape  # Batch, Sequence length, Feature dimension

         # Positional and type embeddings
         type_embed = self.type_embed[0].expand(B, L, -1)
-        if x_previous is not None:
-            x = torch.cat((x, x_previous), dim=1)
-            pos_embed = torch.cat((pos_embed, pos_embed), dim=1)
-            prev_type_embed = self.type_embed[1].expand(B, L, -1)
-            type_embed = torch.cat((type_embed, prev_type_embed), dim=1)

         # Add positional and type embeddings (used in query and key matching)
         pos_and_type_embed = pos_embed + type_embed
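Note on this change: forward_after_reshape no longer accepts x_previous, so it only adds positional and type embeddings for the current image. A minimal standalone sketch of the remaining logic, with made-up shapes and a stand-in for self.type_embed (not the repo's actual module):

import torch
import torch.nn as nn

B, L, D = 2, 196, 128                                      # assumed batch size, patch count, feature dim
x = torch.randn(B, L, D)                                   # current-image patch features
pos_embed = torch.randn(1, L, D).expand(B, L, D)           # learned positional embeddings in the real model

type_embed_table = nn.Parameter(torch.zeros(2, 1, 1, D))   # stand-in for self.type_embed
type_embed = type_embed_table[0].expand(B, L, -1)          # only the "current image" type is used now

# the combination the simplified method keeps (used in query and key matching)
pos_and_type_embed = pos_embed + type_embed
print(pos_and_type_embed.shape)                            # torch.Size([2, 196, 128])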
LLAVA_Biovil/llava/model/builder.py
CHANGED
@@ -184,19 +184,6 @@ def load_from_hf(repo_id, filename, subfolder=None):
                 new_vision_tower_state_dict[new_k] = v
             print('Loaded additional vision tower weights...')
             vision_tower.load_state_dict(new_vision_tower_state_dict, strict=False)
-            # weight difference sum([torch.norm(value-vision_tower.state_dict()[key].cpu()) for key,value in new_vision_tower_state_dict.items()])
-
-        image_pooler = model.get_image_pooler()
-        if image_pooler is not None:
-            image_pooler.to(device=device, dtype=torch.float16)
-            if non_lora_trainables is not None and any(k.startswith('model.image_pooler.') for k in non_lora_trainables):
-                new_image_pooler_state_dict = {}
-                for k, v in non_lora_trainables.items(): # we need remapping, because state_dict from model is always like model.vision_tower. It should be vision_tower.
-                    if 'model.image_pooler.' in k:
-                        new_k = k.replace('model.image_pooler.', '')
-                        new_image_pooler_state_dict[new_k] = v
-                print('Loading additional image pooler weights...')
-                image_pooler.load_state_dict(new_image_pooler_state_dict, strict=True)

     if hasattr(model.config, "max_sequence_length"):
         context_len = model.config.max_sequence_length
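Note on this change: the block that loaded optional image-pooler weights from non_lora_trainables is gone; only the vision-tower key remapping above survives. A small sketch of that remapping pattern with a dummy state dict (placeholder keys, not the real checkpoint):

import torch

# hypothetical extra weights keyed with the 'model.vision_tower.' prefix, as saved from the full model
non_lora_trainables = {
    'model.vision_tower.encoder.weight': torch.zeros(4, 4),
    'model.mm_projector.bias': torch.zeros(4),
}

# keep only vision-tower entries and strip the prefix so the bare vision tower can load them
new_vision_tower_state_dict = {
    k.replace('model.vision_tower.', ''): v
    for k, v in non_lora_trainables.items()
    if 'model.vision_tower.' in k
}
print(list(new_vision_tower_state_dict))    # ['encoder.weight']
# vision_tower.load_state_dict(new_vision_tower_state_dict, strict=False)  # as in the builder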
LLAVA_Biovil/llava/model/language_model/llava_llama.py
CHANGED
@@ -35,20 +35,19 @@ class LlavaConfig(LlamaConfig):
 class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
     config_class = LlavaConfig

-    def __init__(self, config: LlamaConfig
-        super(LlavaLlamaModel, self).__init__(config
+    def __init__(self, config: LlamaConfig):
+        super(LlavaLlamaModel, self).__init__(config)


 class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
     config_class = LlavaConfig

-    def __init__(self, config
+    def __init__(self, config):
         super(LlamaForCausalLM, self).__init__(config)
-        self.model = LlavaLlamaModel(config
+        self.model = LlavaLlamaModel(config)
         self.pretraining_tp = config.pretraining_tp
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-        self.mv_type = mv_type

         # Initialize weights and apply final processing
         self.post_init()

@@ -68,7 +67,6 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         images: Optional[torch.FloatTensor] = None,
-        prev_images: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         if inputs_embeds is None:

@@ -85,8 +83,7 @@ def forward(
                 attention_mask,
                 past_key_values,
                 labels,
-                images,
-                prev_images
+                images
             )
             output = super().forward(
                 input_ids=input_ids,
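Note on this change: both constructors drop the multi-view argument, and prev_images disappears from forward and from the prepare_inputs_labels_for_multimodal call. A sketch of how the slimmed-down model would be constructed; the tiny config values are placeholders and the import assumes this repo is on the Python path:

from transformers import LlamaConfig
from LLAVA_Biovil.llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

config = LlamaConfig(hidden_size=512, intermediate_size=1024, num_hidden_layers=2,
                     num_attention_heads=8, vocab_size=32000)   # placeholder sizes only
model = LlavaLlamaForCausalLM(config)   # no mv_type argument anymore
# model(..., images=...) is now called without prev_images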
LLAVA_Biovil/llava/model/llava_arch.py
CHANGED
@@ -27,13 +27,12 @@

 class LlavaMetaModel:

-    def __init__(self, config
+    def __init__(self, config):
         super(LlavaMetaModel, self).__init__(config)

         if hasattr(config, "mm_vision_tower"):
             self.vision_tower = build_vision_tower(config, delay_load=True)
             self.mm_projector = build_vision_projector(config)
-            self.image_pooler = build_image_pooler(config) if "pool" in mv_type else None

     def get_vision_tower(self):
         vision_tower = getattr(self, 'vision_tower', None)

@@ -51,7 +50,6 @@ def initialize_vision_modules(self, model_args, fsdp=None):
         pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

         self.config.mm_vision_tower = vision_tower
-        self.config.mv_type = getattr(model_args, 'mv_type', False)

         if self.get_vision_tower() is None:
             if self.config.mm_vision_tower == 'biovil':

@@ -188,87 +186,8 @@ def pad_embeddings_mv(self, embeddings, padding_value=0):

         return padded_embeddings.flatten(1,2), mask

-    def encode_images_pooled(self, images, split_sizes, num_imgs_present, num_imgs_past, mv_type="pool_all"):
-        image_pooler = self.get_image_pooler()
-        image_features = self.get_model().get_vision_tower()(images)
-        if self.get_model().config.mm_vision_tower == 'biovil':
-            image_features = image_features.patch_embeddings
-            # flatten
-            image_features = image_features.flatten(2).transpose(1,2)
-        if split_sizes is not None:
-            image_features = torch.split(image_features, split_sizes, dim=0)
-
-            if mv_type == "pool_all":
-                # merge present and past per batch
-                present_features = [image_features[i] for i in range(len(num_imgs_present))]
-                past_features = []
-                i = 0
-                for num_imgs_elem in num_imgs_past:
-                    if num_imgs_elem != 0:
-                        past_features.append(image_features[i+len(num_imgs_present)])
-                        i += 1
-                    else:
-                        past_features.append(None)
-
-                all_img_features = []
-                for idx, (batch_num_present, batch_num_past) in enumerate(zip(num_imgs_present, num_imgs_past)):
-                    if batch_num_past == 0:
-                        all_img_features.append(present_features[idx])
-                    else:
-                        all_img_features.append(torch.cat((present_features[idx], past_features[idx]), dim=0))
-
-                all_img_features, mask, token_type_ids = self.pad_embeddings(all_img_features, num_imgs_present, num_imgs_past)
-                all_img_features = image_pooler(all_img_features, mask, token_type_ids)
-
-            elif mv_type == "pool_concat":
-                present_features = [image_features[i] for i in range(len(num_imgs_present))]
-                past_features = [image_features[i+len(num_imgs_present)] for i in range(len(image_features)-len(num_imgs_present))]
-                present_features, mask_present, _ = self.pad_embeddings(present_features)
-                past_features, mask_past, _ = self.pad_embeddings(past_features)
-                present_features = image_pooler(present_features, mask_present)
-                past_features = image_pooler(past_features, mask_past)
-                # TODO maybe max pool on past features to save tokens
-                # concat present and past per batch if past is not empty
-                all_img_features = []
-                idx_present = 0
-                idx_past = 0
-                for batch_num_present, batch_num_past in zip(num_imgs_present, num_imgs_past):
-                    if batch_num_past == 0:
-                        all_img_features.append(present_features[idx_present])
-                        idx_present += 1
-                    else:
-                        all_img_features.append(torch.cat((present_features[idx_present], past_features[idx_past]), dim=0))
-                        idx_present += 1
-                        idx_past += 1
-            else:
-                raise NotImplementedError
-            if type(all_img_features) is list:
-                split_sizes = [image.shape[0] for image in all_img_features]
-                all_img_features = self.get_model().mm_projector(torch.cat(all_img_features, dim=0))
-                all_img_features = torch.split(all_img_features, split_sizes, dim=0)
-
-            else:
-                all_img_features = self.get_model().mm_projector(all_img_features)
-            return all_img_features
-
-    def encode_images_pooled_mv(self, images, split_sizes):
-        image_pooler = self.get_image_pooler()
-        image_features = self.get_model().get_vision_tower()(images)
-        if split_sizes is not None:
-            image_features = torch.split(image_features, split_sizes, dim=0)
-            image_features, mask = self.pad_embeddings_mv(image_features)
-            image_features = image_pooler(image_features, mask)
-        else:
-            mask = torch.ones((image_features.shape[0], image_features.shape[1]), dtype=torch.bool, device=image_features[0].device)
-            image_features = image_pooler(image_features, mask)
-        image_features = self.get_model().mm_projector(image_features)
-        return image_features
-
-    def get_image_pooler(self):
-        return self.get_model().get_image_pooler()
-
     def prepare_inputs_labels_for_multimodal(
-        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
+        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
     ):
         vision_tower = self.get_vision_tower()
         if vision_tower is None or images is None or input_ids.shape[1] == 1:

@@ -283,35 +202,14 @@ def prepare_inputs_labels_for_multimodal(
             return input_ids, position_ids, attention_mask, past_key_values, None, labels

         if type(images) is list or images.ndim == 5:
-
-
-
-
-
-
-            if getattr(self.config, 'mv_type') == "pool_all":
-                concat_images = torch.cat((torch.cat([image for image in images], dim=0), torch.cat([image for image in prev_images if image is not None], dim=0))) # first present, then past, all will be merged
-                split_sizes = [image.shape[0] for image in images] + [image.shape[0] for image in prev_images if image is not None]
-                num_imgs_present = [image.shape[0] if image is not None else 0 for image in images]
-                num_imgs_past = [image.shape[0] if image is not None else 0 for image in prev_images]
-                image_features = self.encode_images_pooled(concat_images, split_sizes, num_imgs_present, num_imgs_past, "pool_all")
-            if getattr(self.config, 'mv_type') == "pool_concat": # TODO make sure to allow empty past -> shorter sequence
-                concat_images = torch.cat((torch.cat([image for image in images], dim=0), torch.cat([image for image in prev_images if image is not None], dim=0))) # first present, then past, all will be merged
-                split_sizes = [image.shape[0] for image in images] + [image.shape[0] for image in prev_images if image is not None]
-                num_imgs_present = [image.shape[0] if image is not None else 0 for image in images]
-                num_imgs_past = [image.shape[0] if image is not None else 0 for image in prev_images]
-                image_features = self.encode_images_pooled(concat_images, split_sizes, num_imgs_present, num_imgs_past, "pool_concat")
-            if getattr(self.config, 'mv_type') == "pool": #no past images
-                concat_images = torch.cat([image for image in images], dim=0)
-                split_sizes = [image.shape[0] for image in images]
-                image_features = self.encode_images_pooled_mv(concat_images, split_sizes)
+            concat_images = torch.cat([image for image in images], dim=0)
+            image_features = self.encode_images(concat_images)
+            split_sizes = [image.shape[0] for image in images]
+            image_features = torch.split(image_features, split_sizes, dim=0)
+            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
+
         else:
-
-            image_features = self.encode_images_pooled(images, None).to(self.device)
-        elif hasattr(self.config, 'mv_type') and getattr(self.config, 'mv_type') == "pool":
-            image_features = self.encode_images_pooled_mv(images, None).to(self.device)
-        else:
-            image_features = self.encode_images(images).to(self.device)
+            image_features = self.encode_images(images).to(self.device)

         # TODO: image start / end is not implemented here to support pretraining.
         if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
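Note on this change: the multi-view pooling paths (encode_images_pooled, encode_images_pooled_mv, get_image_pooler) are removed, and the list/5-D branch now simply concatenates all images, encodes them once, and splits the features back per sample. A minimal sketch of that concat/split/flatten pattern, with a dummy encoder standing in for the model's encode_images and made-up feature sizes:

import torch

def encode_images(batch):                       # stand-in for vision tower + projector
    n = batch.shape[0]
    return torch.randn(n, 16, 64)               # (num_images, patches, hidden), sizes assumed

# two samples with different numbers of images each (the list / ndim == 5 case)
images = [torch.randn(2, 3, 224, 224), torch.randn(1, 3, 224, 224)]

concat_images = torch.cat([image for image in images], dim=0)      # (3, 3, 224, 224)
image_features = encode_images(concat_images)                      # (3, 16, 64)
split_sizes = [image.shape[0] for image in images]                 # [2, 1]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]         # per-sample (num_images*16, 64)
print([f.shape for f in image_features])    # [torch.Size([32, 64]), torch.Size([16, 64])]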
__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.68 kB).
example_code.py
CHANGED
@@ -6,7 +6,7 @@ import requests
 import torch
 from PIL import Image
 import numpy as np
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, hf_hub_download

 from LLAVA_Biovil.llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, remap_to_uint8
 from LLAVA_Biovil.llava.model.builder import load_pretrained_model

@@ -18,13 +18,12 @@ from utils import create_chest_xray_transform_for_inference, init_chexpert_predictor

 def load_model_from_huggingface(repo_id):
     # Download model files
-    model_path = snapshot_download(repo_id=repo_id, revision="main"
+    model_path = snapshot_download(repo_id=repo_id, revision="main")
     model_path = Path(model_path)

     tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base='liuhaotian/llava-v1.5-7b',
                                                                            model_name="llava-v1.5-7b-task-lora_radialog_instruct_llava_biovil_unfrozen_2e-5_5epochs_v5_checkpoint-21000", load_8bit=False, load_4bit=False)

-
     return tokenizer, model, image_processor, context_len


@@ -37,7 +36,7 @@ if __name__ == '__main__':
     image = remap_to_uint8(np.array(image))
     image = Image.fromarray(image).convert("L")

-    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="
+    tokenizer, model, image_processor, context_len = load_model_from_huggingface(repo_id="ChantalPellegrini/RaDialog-interactive-radiology-report-generation")
     cp_model, cp_class_names, cp_transforms = init_chexpert_predictor()

     model.config.tokenizer_padding_side = "left"

@@ -82,27 +81,6 @@ if __name__ == '__main__':
     pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
     print("ASSISTANT: ", pred)

-    # add prediction to conversation
-    conv.messages.pop()
-    conv.append_message("ASSISTANT", pred)
-    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
-    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
-
-    # generate a report
-    with torch.inference_mode():
-        output_ids = model.generate(
-            input_ids,
-            images=image_tensor,
-            do_sample=False,
-            use_cache=True,
-            max_new_tokens=300,
-            stopping_criteria=[stopping_criteria],
-            pad_token_id=tokenizer.pad_token_id
-        )
-
-    pred = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip().replace("</s>", "")
-    print("ASSISTANT: ", pred)
-
     # add prediction to conversation
     conv.messages.pop()
     conv.append_message("ASSISTANT", pred)
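Note on this change: the accidentally duplicated second report-generation block is removed, and the snapshot_download call is properly closed. A short sketch of the download step that remains (requires huggingface_hub; the repo id matches the updated script):

from pathlib import Path
from huggingface_hub import snapshot_download

# fetch the model repo once; subsequent calls reuse the local cache
model_path = Path(snapshot_download(
    repo_id="ChantalPellegrini/RaDialog-interactive-radiology-report-generation",
    revision="main",
))
print(model_path)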
findings_classifier/__pycache__/chexpert_train.cpython-310.pyc
CHANGED
Binary files a/findings_classifier/__pycache__/chexpert_train.cpython-310.pyc and b/findings_classifier/__pycache__/chexpert_train.cpython-310.pyc differ