Baptiste Canton commited on
Commit
2056905
0 Parent(s):

initial commit

Browse files
Files changed (11) hide show
  1. .gitattributes +35 -0
  2. .gitignore +3 -0
  3. Dockerfile +32 -0
  4. LICENSE +21 -0
  5. Makefile +14 -0
  6. README.md +11 -0
  7. captioner.py +118 -0
  8. data.json +3 -0
  9. gg.py +18 -0
  10. grapi.py +15 -0
  11. requirements.txt +10 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .envrc
2
+ .venv
3
+ __pycache__/
Dockerfile ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM debian:11-slim AS build
2
+ RUN apt-get update && apt-get upgrade -yy && \
3
+ apt-get install --no-install-suggests --no-install-recommends --yes python3-venv gcc libpython3-dev && \
4
+ python3 -m venv /venv && \
5
+ /venv/bin/pip install --no-cache-dir --upgrade pip setuptools wheel
6
+
7
+ FROM build AS build-venv
8
+ COPY requirements.txt /requirements.txt
9
+ RUN /venv/bin/pip install --no-cache-dir --upgrade --disable-pip-version-check -r /requirements.txt
10
+
11
+ # Copy the virtualenv into a distroless image
12
+ FROM gcr.io/distroless/python3-debian11
13
+ COPY --from=build-venv /venv /venv
14
+
15
+ #RUN useradd -m -u 1000 user
16
+ USER nonroot:nonroot
17
+ ENV HOME=/tmp \
18
+ PATH=/venv/bin:$PATH
19
+ WORKDIR $HOME/app
20
+
21
+ COPY --chown=nonroot captioner.py $HOME/app/captioner.py
22
+
23
+ ARG MODEL="Salesforce/blip-image-captioning-base"
24
+ ENV MODEL=${MODEL}
25
+ ENTRYPOINT ["/venv/bin/python3", "/venv/bin/uvicorn", "--host", "0.0.0.0", "captioner:app"]
26
+ EXPOSE 8000
27
+ USER nonroot:nonroot
28
+
29
+
30
+
31
+
32
+
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Baptiste Canton
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Makefile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ .PHONY: run dockerbuild dockerrun
3
+
4
+ run:
5
+ @uvicorn captioner:app
6
+
7
+ dockerbuild:
8
+ @docker build -t captioner .
9
+
10
+ dockerrun:
11
+ @docker run --rm -p 8000:8000 captioner
12
+
13
+ test:
14
+ @time curl -X POST -H "Content-Type: application/json" -d '{"url": "https://cataas.com/cat" }' http://127.0.0.1:8000/caption/
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Truc
3
+ emoji: 🌍
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
captioner.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import os
4
+ import time
5
+ from typing import List, Union
6
+
7
+ from pillow_heif import register_heif_opener
8
+
9
+ register_heif_opener()
10
+
11
+ import gradio as gr
12
+ from fastapi import FastAPI, HTTPException
13
+ from pydantic import BaseModel, HttpUrl
14
+ from transformers import pipeline
15
+
16
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
17
+ MAX_URLS = int(os.getenv("MAX_URLS", 5))
18
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 200))
19
+ # https://huggingface.co/models?pipeline_tag=image-to-text&sort=likes
20
+ MODEL = os.getenv("MODEL", "../models/Salesforce/blip-image-captioning-large")
21
+
22
+
23
+ logging.basicConfig(level=LOG_LEVEL)
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ app = FastAPI()
28
+
29
+
30
+ captioner = None # Placeholder for the captioner pipeline
31
+ is_initialized = asyncio.Event() # Event to track initialization status
32
+ lock = asyncio.Lock()
33
+
34
+
35
+ def load_model():
36
+ global captioner
37
+ logger.info("Loading model...")
38
+ # simpler model: "ydshieh/vit-gpt2-coco-en"
39
+ captioner = pipeline(
40
+ "image-to-text",
41
+ model=MODEL,
42
+ max_new_tokens=MAX_NEW_TOKENS,
43
+ )
44
+ logger.info("Done loading model.")
45
+ is_initialized.set()
46
+
47
+
48
+ class Image(BaseModel):
49
+ url: Union[HttpUrl, List[HttpUrl]] # url can be a string or a list of strings
50
+
51
+
52
+ @app.on_event("startup")
53
+ async def startup_event():
54
+ global app
55
+ asyncio.create_task(asyncio.to_thread(load_model))
56
+ # add gradio interface
57
+ iface = gr.Interface(fn=captioner_gradapter, inputs="text", outputs=["text"], allow_flagging="never")
58
+ app = gr.mount_gradio_app(app, iface, path="/gradio")
59
+
60
+
61
+ async def captioner_gradapter(image_url):
62
+ await is_initialized.wait()
63
+ async with lock:
64
+ result = await asyncio.to_thread(captioner, image_url)
65
+ caption = result[0]["generated_text"]
66
+ return caption
67
+
68
+
69
+ @app.get("/")
70
+ async def root():
71
+ return {"message": "Hello World"}
72
+
73
+
74
+ # the image url is passed in as a "url" tag in the json body
75
+ @app.post("/caption/")
76
+ async def create_caption(image: Image):
77
+ if isinstance(image.url, list) and len(image.url) > MAX_URLS:
78
+ logger.debug(
79
+ f"Request with more than {MAX_URLS} URLs received. Refusing the request."
80
+ )
81
+
82
+ raise HTTPException(
83
+ status_code=400,
84
+ detail=f"Maximum of {MAX_URLS} URLs can be processed at once",
85
+ )
86
+ async with lock:
87
+ await is_initialized.wait() # Wait until initialization is completed
88
+
89
+ start_time = time.time()
90
+ # get the image url from the json body
91
+ image_url = image.url
92
+ try:
93
+ caption = await asyncio.to_thread(captioner, image_url)
94
+ except Exception as e:
95
+ logger.error("Error during caption generation: %s", str(e))
96
+ raise HTTPException(
97
+ status_code=500,
98
+ detail="An error occurred during caption generation. Please try again later.",
99
+ )
100
+ end_time = time.time()
101
+ duration = end_time - start_time
102
+ logger.debug("Captioning completed. Time taken: %s seconds.", duration)
103
+
104
+ return {"caption": caption, "duration": duration}
105
+
106
+
107
+ # add liveness probe
108
+ @app.get("/healthz")
109
+ async def healthz():
110
+ return {"status": "ok"}
111
+
112
+
113
+ # add readiness probe
114
+ @app.get("/readyz")
115
+ async def readyz():
116
+ if not is_initialized.is_set():
117
+ raise HTTPException(status_code=503, detail="Initialization in progress")
118
+ return {"status": "ok"}
data.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "url": "https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png"
3
+ }
gg.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+ import logging
3
+ import os
4
+
5
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
6
+
7
+ logging.basicConfig(level=LOG_LEVEL)
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+
12
+ client = Client("https://batmac-captioner.hf.space/")
13
+ print(client.view_api())
14
+ result = client.predict(
15
+ "https://images.pexels.com/photos/58997/pexels-photo-58997.jpeg", # str in 'image_url' Textbox component
16
+ api_name="/predict"
17
+ )
18
+ print(result)
grapi.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+
5
+ from gradio_client import Client
6
+
7
+ LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG")
8
+
9
+ logging.basicConfig(level=LOG_LEVEL)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ url = sys.argv[1]
13
+
14
+ client = Client(url)
15
+ print(client.view_api())
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ pillow
3
+ transformers
4
+ torch
5
+ accelerate
6
+ pillow-heif
7
+ gradio
8
+
9
+ uvicorn
10
+