hackergeek committed · verified
Commit 30edd6a · 1 Parent(s): ff11f9b

Create app.py

Files changed (1):
  1. app.py +201 -0

app.py ADDED
@@ -0,0 +1,201 @@
+ # ============================================================
+ # DELCAP — Medical Image Captioning (Hugging Face Space)
+ # ============================================================
+
+ # ------------------------------
+ # Dependencies
+ # ------------------------------
+ # `!pip install ...` is notebook-only syntax and raises a SyntaxError in a
+ # plain Python file; on a Hugging Face Space these packages belong in
+ # requirements.txt instead (see the sketch below).
+
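+ # A plausible requirements.txt for this Space (package list inferred from
+ # the imports below; nothing here is pinned by the repo itself):
+ #   torch
+ #   torchvision
+ #   huggingface_hub
+ #   nltk
+ #   gradio
+ #   Pillow
+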
+ import torch
+ import torch.nn as nn
+ import torchvision.models as models
+ import torchvision.transforms as transforms
+
+ import json
+ import nltk
+ from PIL import Image
+ from huggingface_hub import hf_hub_download
+
+ import gradio as gr
+
+ # Ensure the punkt tokenizer data used by nltk.word_tokenize is available
+ nltk.download("punkt", quiet=True)
+ nltk.download("punkt_tab", quiet=True)  # required by newer NLTK releases
+
+ # ============================================================
+ # Configuration
+ # ============================================================
+ class Config:
+     IMG_SIZE = 224
+     EMBED_SIZE = 256
+     HIDDEN_SIZE = 512
+     NUM_LSTM_LAYERS = 1
+     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     MAX_CAPTION_LENGTH = 50
+
+ config = Config()
+
+ # ============================================================
+ # Tokenization
+ # ============================================================
+ def tokenize_caption(text):
+     return nltk.word_tokenize(text.lower())
+
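+ # Example: tokenize_caption("No acute cardiopulmonary abnormality.")
+ # -> ['no', 'acute', 'cardiopulmonary', 'abnormality', '.']
+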
+ # ============================================================
+ # Vocabulary
+ # ============================================================
+ class Vocabulary:
+     def __init__(self, freq_threshold=1):
+         self.itos = {
+             0: "<pad>",
+             1: "<unk>",
+             2: "<sos>",
+             3: "<eos>"
+         }
+         self.stoi = {v: k for k, v in self.itos.items()}
+         self.freq_threshold = freq_threshold
+         self.vocab_size = len(self.itos)
+
+     def __len__(self):
+         return self.vocab_size
+
+     @classmethod
+     def from_json(cls, json_data):
+         vocab_obj = cls()
+         vocab_obj.stoi = json_data['stoi']
+         vocab_obj.itos = {int(k): v for k, v in json_data['itos'].items()}
+         vocab_obj.vocab_size = len(vocab_obj.stoi)
+         return vocab_obj
+
+     def idx_to_word(self, idx):
+         return self.itos.get(idx, "<unk>")
+
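+ # For reference, from_json expects vocab.json shaped like this toy example
+ # (the real contents come from the hackergeek/delcap repo):
+ # {
+ #   "stoi": {"<pad>": 0, "<unk>": 1, "<sos>": 2, "<eos>": 3, "lung": 4},
+ #   "itos": {"0": "<pad>", "1": "<unk>", "2": "<sos>", "3": "<eos>", "4": "lung"}
+ # }
+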
+ # ============================================================
+ # Encoder
+ # ============================================================
+ class EncoderCNN(nn.Module):
+     def __init__(self, embed_size):
+         super().__init__()
+         densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
+         self.densenet_features = densenet.features
+
+         # Freeze the pretrained DenseNet backbone
+         for param in self.densenet_features.parameters():
+             param.requires_grad_(False)
+
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         # DenseNet-121 produces 1024 feature channels
+         self.embed = nn.Linear(1024, embed_size)
+
+     def forward(self, images):
+         features = self.densenet_features(images)
+         features = self.avgpool(features)
+         features = features.view(features.size(0), -1)
+         features = self.embed(features)
+         return features
+
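+ # Shape walkthrough for a 224×224 input (batch size B), assuming the
+ # standard DenseNet-121 feature map:
+ #   images             (B, 3, 224, 224)
+ #   densenet_features  (B, 1024, 7, 7)
+ #   avgpool + view     (B, 1024)
+ #   embed              (B, EMBED_SIZE) = (B, 256)
+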
+ # ============================================================
+ # Decoder
+ # ============================================================
+ class DecoderRNN(nn.Module):
+     def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
+         super().__init__()
+         self.embed = nn.Embedding(vocab_size, embed_size)
+         self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
+         self.linear = nn.Linear(hidden_size, vocab_size)
+         self.dropout = nn.Dropout(0.5)
+         self.num_layers = num_layers
+         self.hidden_size = hidden_size
+         # Projects the image embedding into the LSTM's initial hidden state
+         self.feature_to_hidden_state = nn.Linear(embed_size, hidden_size)
+
+     def sample(self, features, max_len=20, vocab=None):
+         """Greedy decoding: emit the argmax token at each step until <eos>."""
+         self.eval()
+         with torch.no_grad():
+             sampled_ids = []
+             initial_hidden = self.feature_to_hidden_state(features)
+             h = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
+             c = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
+             hidden = (h, c)
+
+             start_token = torch.tensor([vocab.stoi["<sos>"]], device=features.device)
+             inputs = self.embed(start_token).unsqueeze(1)
+
+             for _ in range(max_len):
+                 output, hidden = self.lstm(inputs, hidden)
+                 logits = self.linear(self.dropout(output.squeeze(1)))
+                 _, predicted = logits.max(1)
+                 sampled_ids.append(predicted)
+
+                 if predicted.item() == vocab.stoi["<eos>"]:
+                     break
+
+                 inputs = self.embed(predicted).unsqueeze(1)
+
+             return torch.stack(sampled_ids)
+
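+ # Minimal smoke-test sketch (hypothetical; `decoder` and `vocab` are only
+ # created in the loading step below):
+ #   dummy_feats = torch.randn(1, config.EMBED_SIZE, device=config.DEVICE)
+ #   token_ids = decoder.sample(dummy_feats, max_len=5, vocab=vocab)
+ #   print(token_ids.shape)  # (T, 1) tensor of token ids, T <= 5
+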
+ # ============================================================
+ # Load Vocabulary & Models
+ # ============================================================
+ vocab_path = hf_hub_download("hackergeek/delcap", "vocab.json")
+ with open(vocab_path, "r") as f:
+     vocab_data = json.load(f)
+ vocab = Vocabulary.from_json(vocab_data)
+
+ encoder_path = hf_hub_download("hackergeek/delcap", "encoder.pth")
+ decoder_path = hf_hub_download("hackergeek/delcap", "decoder.pth")
+
+ encoder = EncoderCNN(config.EMBED_SIZE).to(config.DEVICE)
+ encoder.load_state_dict(torch.load(encoder_path, map_location=config.DEVICE))
+
+ # Infer the vocabulary size from the checkpoint's output layer so the
+ # decoder's shape always matches the trained weights
+ decoder_state = torch.load(decoder_path, map_location=config.DEVICE)
+ vocab_size = decoder_state["linear.weight"].shape[0]
+
+ decoder = DecoderRNN(config.EMBED_SIZE, config.HIDDEN_SIZE, vocab_size,
+                      num_layers=config.NUM_LSTM_LAYERS).to(config.DEVICE)
+ decoder.load_state_dict(decoder_state)
+
+ encoder.eval()
+ decoder.eval()
+
+ # ============================================================
+ # Image Preprocessing
+ # ============================================================
+ transform = transforms.Compose([
+     transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
+     transforms.ToTensor(),
+     # ImageNet statistics, matching the pretrained DenseNet backbone
+     transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                          std=[0.229, 0.224, 0.225]),
+ ])
+
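+ # For a single PIL image, the pipeline yields a normalized tensor:
+ #   transform(img)               -> (3, 224, 224)
+ #   transform(img).unsqueeze(0)  -> (1, 3, 224, 224), as the encoder expects
+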
+ # ============================================================
+ # Caption Generation
+ # ============================================================
+ def generate_caption(image: Image.Image):
+     # Convert to 3-channel RGB: many medical images are grayscale, and
+     # the DenseNet encoder expects 3 input channels
+     image = image.convert("RGB")
+     image_tensor = transform(image).unsqueeze(0).to(config.DEVICE)
+     with torch.no_grad():
+         features = encoder(image_tensor)
+         sampled_ids = decoder.sample(features, max_len=config.MAX_CAPTION_LENGTH, vocab=vocab)
+
+     caption = []
+     for token in sampled_ids.cpu().numpy():
+         word = vocab.idx_to_word(token.item())
+         if word in ["<sos>", "<pad>"]:
+             continue
+         if word == "<eos>":
+             break
+         caption.append(word)
+     return " ".join(caption)
+
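+ # Example programmatic use (hypothetical file name):
+ #   img = Image.open("chest_xray.png")
+ #   print(generate_caption(img))
+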
+ # ============================================================
+ # Gradio Interface
+ # ============================================================
+ iface = gr.Interface(
+     fn=generate_caption,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Textbox(label="Generated Caption"),
+     title="DELCAP — Medical Image Captioning",
+     description="Upload a medical image and get a generated caption.",
+ )
+
+ iface.launch()