Ziad Meligy committed
Commit 8390b91 · Parent: 446860f

Initial commit: FastAPI + RadDINO report generator

Files changed (4):
  1. Dockerfile +13 -0
  2. inference_service.py +161 -0
  3. main.py +18 -0
  4. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.12.2-slim
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
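With this Dockerfile, the service can be built and run locally with, for example, `docker build -t raddino-report .` and `docker run -p 7860:7860 raddino-report` (the `raddino-report` tag is an arbitrary choice); the `CMD` above then serves the API on port 7860.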
inference_service.py ADDED
@@ -0,0 +1,115 @@
+ import os
+ os.environ["TRANSFORMERS_NO_TF"] = "1"  # keep Transformers from importing TensorFlow
+
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import hf_hub_download
+ from transformers import AutoImageProcessor, AutoModel, GPT2LMHeadModel, GPT2Tokenizer
+
+ PREFIX_LENGTH = 10
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+ class RadDINOEncoder(nn.Module):
+     """Wraps the RadDINO vision encoder and returns a CLS embedding per image."""
+
+     def __init__(self, model_name="microsoft/rad-dino"):
+         super().__init__()
+         self.processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
+         self.encoder = AutoModel.from_pretrained(model_name)
+
+     def forward(self, image):
+         inputs = self.processor(images=image, return_tensors="pt").to(self.encoder.device)
+         outputs = self.encoder(**inputs)
+         cls_embedding = outputs.last_hidden_state[:, 0, :]  # CLS token
+         return cls_embedding.squeeze(0)  # shape: (768,)
+
+
+ class GPT2WithImagePrefix(nn.Module):
+     """GPT-2 conditioned on an image via a learned prefix of embedding vectors."""
+
+     def __init__(self, gpt2_model, prefix_length=10, embed_dim=768):
+         super().__init__()
+         self.gpt2 = gpt2_model
+         self.prefix_length = prefix_length
+
+         # Project the image embedding into prefix_length GPT-2 embedding slots
+         self.image_projector = nn.Linear(embed_dim, prefix_length * gpt2_model.config.n_embd)
+
+     def forward(self, image_embeds, input_ids, attention_mask, labels=None):
+         batch_size = input_ids.size(0)
+
+         # Project the image embedding to prefix token embeddings
+         prefix = self.image_projector(image_embeds).view(batch_size, self.prefix_length, -1).to(input_ids.device)
+
+         # Get GPT-2 token embeddings and prepend the image prefix
+         token_embeds = self.gpt2.transformer.wte(input_ids)
+         inputs_embeds = torch.cat((prefix, token_embeds), dim=1)
+
+         # Extend the attention mask to cover the prefix positions
+         extended_attention_mask = torch.cat([
+             torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype, device=attention_mask.device),
+             attention_mask
+         ], dim=1)
+
+         return self.gpt2(
+             inputs_embeds=inputs_embeds,
+             attention_mask=extended_attention_mask,
+             labels=labels
+         )
+
+
+ # -------------------- Load Tokenizer, Rebuild the Model --------------------
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token
+
+ gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
+ gpt2.resize_token_embeddings(len(tokenizer))
+ model = GPT2WithImagePrefix(gpt2, prefix_length=PREFIX_LENGTH).to(DEVICE)
+
+ # Checkpoint repo and filename are configurable through environment variables
+ CHECKPOINT_REPO = os.getenv("CHECKPOINT_REPO", "TransformingBerry/Raddino-vision-language-gpt2-CHEXMED")
+ CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "Gpt2_checkpoint.pt")
+ CHECKPOINT_PATH = hf_hub_download(repo_id=CHECKPOINT_REPO, filename=CHECKPOINT_FILENAME, cache_dir="/app/cache")
+
+ try:
+     checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
+     model.load_state_dict(checkpoint["model_state_dict"])
+ except FileNotFoundError:
+     raise FileNotFoundError(f"Checkpoint file not found at {CHECKPOINT_PATH}")
+
+ image_encoder = RadDINOEncoder()
+ model.eval()
+ image_encoder.eval()
+
+
+ def generate_report_serviceFn(image):
+     with torch.no_grad():
+         # Encode the image into a single RadDINO CLS embedding
+         image_embeds = image_encoder(image).to(DEVICE)
+
+         # Start from an empty prompt: only the image prefix conditions the model
+         empty_input_ids = tokenizer.encode("", return_tensors="pt").to(DEVICE).long()
+         empty_attention_mask = torch.ones_like(empty_input_ids)
+
+         # Build the image-prefix embeddings and a matching attention mask
+         prefix = model.image_projector(image_embeds).view(1, model.prefix_length, -1)
+         token_embeds = model.gpt2.transformer.wte(empty_input_ids)
+         inputs_embeds = torch.cat((prefix, token_embeds), dim=1)
+
+         extended_attention_mask = torch.cat([
+             torch.ones((1, model.prefix_length), dtype=empty_attention_mask.dtype, device=DEVICE),
+             empty_attention_mask
+         ], dim=1)
+
+         generated_ids = model.gpt2.generate(
+             inputs_embeds=inputs_embeds,
+             attention_mask=extended_attention_mask,
+             max_length=model.prefix_length + 60,
+             pad_token_id=tokenizer.eos_token_id,
+             eos_token_id=tokenizer.eos_token_id
+         )
+
+         # With inputs_embeds-only generation, generate() returns only the newly
+         # generated token ids, so the decode needs no prefix stripping
+         return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
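For reference, a minimal sketch of exercising the service in-process, outside FastAPI; the image path is a hypothetical placeholder:

# In-process smoke test for the inference service (hypothetical image path).
from PIL import Image
from inference_service import generate_report_serviceFn

image = Image.open("sample_cxr.png").convert("RGB")  # any chest X-ray file
print(generate_report_serviceFn(image))

Note that importing inference_service downloads the models and checkpoint at import time, so the first call is slow.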
main.py ADDED
@@ -0,0 +1,18 @@
+ from fastapi import FastAPI, File, UploadFile
+ from fastapi.responses import JSONResponse
+ from inference_service import generate_report_serviceFn
+ from PIL import Image
+ app = FastAPI()
+ @app.post("/generate_report")
+ async def generate_report(file: UploadFile = File(...)):
+     try:
+         # Read the image file from the request
+         image = Image.open(file.file).convert("RGB")
+
+         # Generate the report using the service function
+         report = generate_report_serviceFn(image)
+
+         return JSONResponse({"generated_report": report})
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)
+
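A client-side usage sketch for this endpoint, assuming the server runs on port 7860 (per the Dockerfile) and that `requests` is installed on the client (it is not in requirements.txt, which only covers the server); the file name is a placeholder:

# Hypothetical client example for POST /generate_report.
import requests

with open("chest_xray.png", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/generate_report",
        files={"file": ("chest_xray.png", f, "image/png")},
    )
resp.raise_for_status()
print(resp.json()["generated_report"])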
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ fastapi
+ pillow
+ torch
+ transformers
+ python-multipart
+ uvicorn
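For local development without Docker, these same dependencies support `pip install -r requirements.txt` followed by `uvicorn main:app --host 0.0.0.0 --port 7860`, mirroring the container's `CMD`.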