stephenebert committed on
Commit
cd3cf65
·
verified ·
1 Parent(s): 1f1ea35

Upload 5 files

Browse files
Files changed (5) hide show
  1. Dockerfile +23 -0
  2. app.py +53 -0
  3. main.py +63 -0
  4. requirements.txt +7 -0
  5. tagger.py +91 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# HF_HOME is the current cache variable; TRANSFORMERS_CACHE is deprecated but
# kept so both old and new transformers versions resolve the same cache dir.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    HF_HOME=/root/.cache/huggingface \
    TRANSFORMERS_CACHE=/root/.cache/huggingface/transformers

WORKDIR /app

# system basics (tiny)
RUN apt-get update && apt-get install -y --no-install-recommends git && \
    rm -rf /var/lib/apt/lists/*

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Pre-download tiny NLTK bits so the first request is warm.
# nltk >= 3.9 renamed the resources ("punkt" -> "punkt_tab",
# "averaged_perceptron_tagger" -> "averaged_perceptron_tagger_eng"), and
# requirements.txt allows any nltk >= 3.8 — fetch both spellings; downloading
# a name the installed nltk does not know is a harmless no-op.
RUN python -c "import nltk; [nltk.download(r) for r in ('punkt', 'punkt_tab', 'averaged_perceptron_tagger', 'averaged_perceptron_tagger_eng')]"

COPY . .

EXPOSE 7860
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from pydantic import BaseModel
from typing import List
from pathlib import Path
from PIL import Image
import io, json

from tagger import tag_pil_image, CAP_TAG_DIR

app = FastAPI(title="Image Tagger API", version="0.3.0")


class TagOut(BaseModel):
    """Response schema for /upload."""
    filename: str
    caption: str
    tags: List[str]


@app.get("/healthz")
def healthz():
    """Liveness probe."""
    return {"ok": True}


@app.post("/upload", response_model=TagOut)
async def upload(
    file: UploadFile = File(..., description="PNG or JPEG image"),
    top_k: int = Query(5, ge=1, le=20, description="Maximum number of tags"),
    nouns: bool = Query(True, description="Include noun tags"),
    adjs: bool = Query(True, description="Include adjective tags"),
    verbs: bool = Query(True, description="Include verb tags"),
):
    """Caption and tag an uploaded PNG/JPEG image.

    Raises HTTP 415 for unsupported content types and HTTP 400 when the
    payload cannot be decoded as an image.
    """
    if file.content_type not in {"image/png", "image/jpeg"}:
        raise HTTPException(415, "Only PNG or JPEG supported")

    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception:
        raise HTTPException(400, "Could not decode image")

    stem = Path(file.filename).stem or "upload"
    tags = tag_pil_image(
        img, stem,
        top_k=top_k, keep_nouns=nouns, keep_adjs=adjs, keep_verbs=verbs
    )

    # tag_pil_image persists a {stem}.json side-car; read the caption back.
    caption = ""
    meta = CAP_TAG_DIR / f"{stem}.json"
    if meta.exists():
        try:
            caption = json.loads(meta.read_text())["caption"]
        except (OSError, json.JSONDecodeError, KeyError):
            # best-effort: a missing/corrupt side-car just yields an empty caption
            pass

    # Bug fix: the original returned a JSONResponse, which bypasses the
    # declared response_model validation/serialization. Returning a plain dict
    # lets FastAPI validate it against TagOut; the JSON on the wire is identical.
    return {"filename": file.filename, "caption": caption, "tags": tags}
main.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pathlib import Path
from typing import List
import io, json
from PIL import Image

# CAP_TAG_DIR is the single source of truth for the side-car directory;
# previously this module re-hardcoded ~/Desktop/image_tags and could diverge.
from .tagger import tag_pil_image, CAP_TAG_DIR

app = FastAPI(title="Image Tagger API", version="0.3.0")


class TagOut(BaseModel):
    """Response schema for /upload."""
    filename: str
    caption: str
    tags: List[str]


@app.get("/healthz")
def healthz():
    """Liveness probe."""
    return {"ok": True}


@app.post("/upload", response_model=TagOut)
async def upload(
    file: UploadFile = File(..., description="PNG or JPEG image"),
    top_k: int = Query(5, ge=1, le=20, description="Maximum number of tags"),
    nouns: bool = Query(True, description="Include noun tags"),
    adjs: bool = Query(True, description="Include adjective tags"),
    verbs: bool = Query(True, description="Include verb tags"),
):
    """Caption and tag an uploaded PNG/JPEG image.

    Raises HTTP 415 for unsupported content types and HTTP 400 when the
    payload cannot be decoded as an image.
    """
    if file.content_type not in {"image/png", "image/jpeg"}:
        raise HTTPException(415, "Only PNG or JPEG supported")

    try:
        data = await file.read()
        img = Image.open(io.BytesIO(data)).convert("RGB")
    except Exception:  # was a bare `except:` — never swallow SystemExit/KeyboardInterrupt
        raise HTTPException(400, "Could not decode image")

    stem = Path(file.filename).stem or "upload"
    tags = tag_pil_image(
        img,
        stem,
        top_k=top_k,
        keep_nouns=nouns,
        keep_adjs=adjs,
        keep_verbs=verbs,
    )

    # pull the caption back out of the side-car JSON written by tag_pil_image
    caption = ""
    meta = CAP_TAG_DIR / f"{stem}.json"
    if meta.exists():
        try:
            caption = json.loads(meta.read_text())["caption"]
        except (OSError, json.JSONDecodeError, KeyError):
            # was a bare `except:`; a missing/corrupt side-car just yields ""
            pass

    return JSONResponse(
        {"filename": file.filename, "caption": caption, "tags": tags}
    )
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi==0.111.0
2
+ uvicorn[standard]==0.30.0
3
+ pillow>=9.5
4
+ transformers>=4.41
5
+ torch>=2.2
6
+ nltk>=3.8
7
+ pydantic>=2.7
tagger.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime as _dt
4
+ import json as _json
5
+ import pathlib as _pl
6
+ import re as _re
7
+ import sys as _sys
8
+ from typing import List
9
+
10
+ import nltk
11
+ from PIL import Image
12
+ from transformers import BlipForConditionalGeneration, BlipProcessor
13
+
14
+ # ─── ensure punkt + perceptron tagger are downloaded ──────────────────────────
15
+ for res, subdir in [
16
+ ("punkt", "tokenizers"),
17
+ ("averaged_perceptron_tagger", "taggers"),
18
+ ]:
19
+ try:
20
+ nltk.data.find(f"{subdir}/{res}")
21
+ except LookupError:
22
+ nltk.download(res, quiet=True)
23
+
24
+ # ─── where we dump the caption+tags JSON sidecars ──────────────────────────────
25
+ CAP_TAG_DIR = _pl.Path.home() / "Desktop" / "image_tags"
26
+ CAP_TAG_DIR.mkdir(exist_ok=True, parents=True)
27
+
28
+ # ─── load the BLIP model once ──────────────────────────────────────────────────
29
+ _processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
30
+ _model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
31
+
32
+ # ─── allowed POS prefixes ──────────────────────────────────────────────────────
33
+ _POS = {"nouns": ("NN",), "adjs": ("JJ",), "verbs": ("VB",)}
34
+
35
def _caption_to_tags(
    caption: str,
    k: int,
    keep_nouns: bool,
    keep_adjs: bool,
    keep_verbs: bool,
) -> List[str]:
    """Distill *caption* into at most *k* unique lower-case tags.

    Tokens are POS-tagged and kept only when their Penn-Treebank tag starts
    with one of the enabled prefixes (NN / JJ / VB); every character outside
    [a-z0-9-] is stripped from a kept token, and empty results are dropped.
    """
    from nltk.tokenize import wordpunct_tokenize

    prefixes: tuple = ()
    for enabled, category in (
        (keep_nouns, "nouns"),
        (keep_adjs, "adjs"),
        (keep_verbs, "verbs"),
    ):
        if enabled:
            prefixes += _POS[category]

    tags: List[str] = []
    seen: set = set()
    for token, pos in nltk.pos_tag(wordpunct_tokenize(caption.lower())):
        # str.startswith with an empty tuple is always False, so disabling
        # every category yields an empty tag list — same as the original.
        if not pos.startswith(prefixes):
            continue
        cleaned = _re.sub(r"[^a-z0-9-]", "", token)
        if not cleaned or cleaned in seen:
            continue
        seen.add(cleaned)
        tags.append(cleaned)
        if len(tags) >= k:
            break
    return tags
59
+
60
def tag_pil_image(
    img: Image.Image,
    stem: str,
    *,
    top_k: int = 5,
    keep_nouns: bool = True,
    keep_adjs: bool = True,
    keep_verbs: bool = True,
) -> List[str]:
    """Caption *img* with BLIP, extract tags, and persist a JSON side-car.

    The side-car ``CAP_TAG_DIR / f"{stem}.json"`` holds the caption, the
    tags, and a UTC timestamp so callers (app.py / main.py) can read the
    caption back.  Returns the tag list.
    """
    # 1) caption the image with the module-level BLIP model
    inputs = _processor(images=img, return_tensors="pt")
    generated = _model.generate(**inputs, max_length=30)
    caption = _processor.decode(generated[0], skip_special_tokens=True)

    # 2) caption -> filtered tag list
    tags = _caption_to_tags(caption, top_k, keep_nouns, keep_adjs, keep_verbs)

    # 3) persist the side-car JSON
    sidecar = CAP_TAG_DIR / f"{stem}.json"
    sidecar.write_text(
        _json.dumps(
            {
                "caption": caption,
                "tags": tags,
                "timestamp": _dt.datetime.now(_dt.timezone.utc).isoformat(),
            },
            indent=2,
        )
    )
    return tags
82
+
83
if __name__ == "__main__":
    # Tiny CLI: python tagger.py <image_path> [top_k]
    if len(_sys.argv) < 2:
        _sys.exit("Usage: python tagger.py <image_path> [top_k]")
    path = _pl.Path(_sys.argv[1])
    if not path.exists():
        _sys.exit(f"File not found: {path}")
    k = int(_sys.argv[2]) if len(_sys.argv) > 2 else 5
    # Bug fix: the original `with Image.open(path).convert("RGB") as im:`
    # managed only the converted copy, leaving the image that actually holds
    # the open file handle to the GC.  Manage the opened image itself;
    # convert() returns an independent copy, so closing the original is safe.
    with Image.open(path) as raw:
        img = raw.convert("RGB")
    print("tags:", ", ".join(tag_pil_image(img, path.stem, top_k=k)))