ECOFRI committed
Commit: 1182539
1 Parent: 7e8c90a

Upload model

Files changed (2):
  1. CXR_LLAVA_HF.py +12 -4
  2. config.json +3 -3
CXR_LLAVA_HF.py CHANGED
@@ -8,7 +8,8 @@ from transformers import TextIteratorStreamer
 from transformers import StoppingCriteria, GenerationConfig
 from threading import Thread
 from dataclasses import dataclass
-
+import numpy as np
+from PIL import Image
 # Model Constants
 IGNORE_INDEX = -100
 IMAGE_TOKEN_INDEX = -200
@@ -596,8 +597,16 @@ class CXRLLAVAModel(PreTrainedModel):
     def generate_cxr_repsonse(self, chat, image, temperature=0.2, top_p=0.8):
         with torch.no_grad():
             streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-            import numpy as np
-            image = np.expand_dims(image, axis=-1)
+
+            if np.array(image).max() > 255:
+                raise Exception("WARNING. 16-bit image is not supported.")
+
+            image = image.convert('L')  # convert to grayscale
+            image = np.array(image)
+
+            if len(image.shape) == 2:
+                image = np.expand_dims(image, axis=-1)  # (width, height) --> (width, height, 1)
+
             prompt = self.apply_chat_template(chat)
             images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
             images = images.to(self.device)
@@ -610,7 +619,6 @@ class CXRLLAVAModel(PreTrainedModel):
             max_context_length = getattr(self.config, 'max_position_embeddings', 2048)

             max_new_tokens = min(512, max_context_length - input_ids.shape[-1] - num_image_tokens)
-
             thread = Thread(target=self.generate, kwargs=dict(
                 inputs=input_ids,
                 do_sample=do_sample,
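This hunk changes the input handling of generate_cxr_repsonse: the method now expects an 8-bit PIL image, rejects 16-bit data, and collapses the input to a single grayscale channel before the vision tower's image processor sees it. Below is a minimal sketch of that preprocessing in isolation; the file name sample_cxr.png is illustrative and not part of the commit.

import numpy as np
from PIL import Image

image = Image.open("sample_cxr.png")        # hypothetical 8-bit chest radiograph

# Same guard as the new code: 16-bit pixel data is rejected outright.
if np.array(image).max() > 255:
    raise Exception("WARNING. 16-bit image is not supported.")

image = image.convert('L')                  # collapse to a single grayscale channel
image = np.array(image)                     # PIL Image -> numpy array of shape (H, W)

if len(image.shape) == 2:
    image = np.expand_dims(image, axis=-1)  # (H, W) -> (H, W, 1), the shape handed to the image processor

print(image.shape, image.dtype)             # e.g. (512, 512, 1) uint8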
config.json CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "CXR-LLAVA-v2",
+  "_name_or_path": "G:\\Temp\\finetune_result\\LLAMA2-7B-CHAT_ViT-L-16-512_MOREKEYWORD_LN_PATCH_FINETUNE_ChexpertJSON_POSTTRAIN_25000_DIST",
   "architectures": [
     "CXRLLAVAModel"
   ],
@@ -26,7 +26,7 @@
     "std": 0.3821719215686275
   },
   "llama": {
-    "_name_or_path": "CXR-LLAVA-v2",
+    "_name_or_path": "/home/jovyan/llava/SW_LLAVA/LLAMA2-7B-CHAT_ViT-L-16-512_MOREKEYWORD_LN_PATCH_FINETUNE_ChexpertJSON_POSTTRAIN",
     "add_cross_attention": false,
     "architectures": [
       "LlamaForCausalLM"
@@ -105,7 +105,7 @@
     "vocab_size": 32000
   },
   "llama_model_dtype": "bf16",
-  "llama_model_path": "CXR-LLAVA-v2",
+  "llama_model_path": "/home/jovyan/llava/SW_LLAVA/LLAMA2-7B-CHAT_ViT-L-16-512_MOREKEYWORD_LN_PATCH_FINETUNE_ChexpertJSON_POSTTRAIN",
   "mm_projector_dim": 1024,
   "mm_projector_dtype": "fp32",
   "mm_projector_path": null,