Kortikov Mikhail committed on
Commit 46dcd1f
1 Parent(s): e6112a3
Files changed (6)
  1. Dockerfile +14 -7
  2. endpoint.py +8 -0
  3. main.py +4 -83
  4. main2.py +87 -0
  5. static/index.html +36 -0
  6. static/script.js +17 -0
Dockerfile CHANGED
@@ -1,13 +1,20 @@
- FROM python:3.10
+ FROM python:3.9

  WORKDIR ./code

- COPY requirements.txt ./code/requirements.txt
+ COPY ./requirements.txt /code/requirements.txt

- RUN pip install --no-cache-dir --upgrade -r ./code/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

- COPY main.py models/fruit_recognition_model.pth ./models/
- COPY templates/ ./templates
- COPY metadata.json .
+ RUN useradd -m -u 1000 user

- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
endpoint.py ADDED
@@ -0,0 +1,8 @@
+ from transformers import pipeline
+
+ pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
+
+ @app.get("/infer_t5")
+ def t5(input):
+     output = pipe_flan(input)
+     return {"output": output[0]["generated_text"]}
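Note: as committed, endpoint.py decorates t5 with @app.get but never imports FastAPI or defines app, so on its own the module raises a NameError; it only works pasted into a module that already creates the app (e.g. main.py). A minimal self-contained sketch of the same endpoint, assuming fastapi, uvicorn, transformers, and torch are listed in the (unshown) requirements.txt:

from fastapi import FastAPI
from transformers import pipeline

app = FastAPI()

# Load the model once at import time, not per request.
pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")

@app.get("/infer_t5")
def t5(input: str):
    # `input` arrives as a query parameter: GET /infer_t5?input=...
    output = pipe_flan(input)
    return {"output": output[0]["generated_text"]}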
main.py CHANGED
@@ -1,87 +1,8 @@
- # main.py
- import json
- from fastapi import FastAPI, File, UploadFile, Request
- from fastapi.responses import HTMLResponse, JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.templating import Jinja2Templates
- from PIL import Image
- from io import BytesIO
- import torch
- import torch.nn as nn
- import torchvision.transforms as transforms
- from torchvision import models
-
- class FruitRecognizer:
-     def __init__(self, model_path, num_classes):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.model = models.resnet18(pretrained=False)
-         self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
-         self.model.load_state_dict(torch.load(model_path, map_location=self.device))
-         self.model.to(self.device)
-         self.model.eval()
-
-         self.transform = transforms.Compose([
-             transforms.Resize(256),
-             transforms.CenterCrop(224),
-             transforms.ToTensor(),
-             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-         ])
-
-     def recognize_fruit_from_path(self, image_path, class_names):
-         img = Image.open(image_path).convert("RGB")
-         img = self.transform(img)
-         img = img.unsqueeze(0).to(self.device)
-
-         with torch.no_grad():
-             outputs = self.model(img)
-             _, predicted = torch.max(outputs.data, 1)
-             predicted_class = class_names[predicted.item()]
-
-         return predicted_class
-
-     def recognize_fruit(self, image, class_names):
-         img = self.transform(image)
-         img = img.unsqueeze(0).to(self.device)
-
-         with torch.no_grad():
-             outputs = self.model(img)
-             _, predicted = torch.max(outputs.data, 1)
-             predicted_class = class_names[predicted.item()]
-
-         return predicted_class
-

  app = FastAPI()

- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- with open('metadata.json', 'r') as f:
-     metadata = json.load(f)
-
- class_names = metadata['classes']
-
- model_path = "models/fruit_recognition_model.pth"
- recognizer = FruitRecognizer(model_path, len(class_names))
-
- templates = Jinja2Templates(directory="templates")
-
- @app.get("/", response_class=HTMLResponse)
- async def root(request: Request):
-     return templates.TemplateResponse("index.html", {"request": request})
-
- @app.post("/predict/")
- async def predict_fruit(request: Request, file: UploadFile = File(...)):
-     try:
-         print("request")
-         img = Image.open(BytesIO(await file.read())).convert("RGB")
-         predicted_class = recognizer.recognize_fruit(img, class_names)
-         return JSONResponse({"predicted_class": predicted_class})
-     except Exception as e:
-         return JSONResponse({"error": str(e)}, status_code=400)
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")

+ @app.get("/")
+ def index() -> FileResponse:
+     return FileResponse(path="/app/static/index.html", media_type="text/html")
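Note: the rewritten main.py keeps only static-file serving, yet the hunk shows no imports for StaticFiles or FileResponse, and the hard-coded path /app/static/index.html does not match the Dockerfile's working directory of /home/user/app. A minimal runnable sketch under those observations; the import lines and the relative path are editor assumptions, not part of the commit:

from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

app = FastAPI()

# With html=True, the mount itself answers GET / with static/index.html,
# and also serves static/script.js for the page.
app.mount("/", StaticFiles(directory="static", html=True), name="static")

@app.get("/")
def index() -> FileResponse:
    # Relative path, resolved against the container's working directory.
    return FileResponse(path="static/index.html", media_type="text/html")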
main2.py ADDED
@@ -0,0 +1,87 @@
+ # main.py
+ import json
+ from fastapi import FastAPI, File, UploadFile, Request
+ from fastapi.responses import HTMLResponse, JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.templating import Jinja2Templates
+ from PIL import Image
+ from io import BytesIO
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ from torchvision import models
+
+ class FruitRecognizer:
+     def __init__(self, model_path, num_classes):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = models.resnet18(pretrained=False)
+         self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
+         self.model.load_state_dict(torch.load(model_path, map_location=self.device))
+         self.model.to(self.device)
+         self.model.eval()
+
+         self.transform = transforms.Compose([
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         ])
+
+     def recognize_fruit_from_path(self, image_path, class_names):
+         img = Image.open(image_path).convert("RGB")
+         img = self.transform(img)
+         img = img.unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model(img)
+             _, predicted = torch.max(outputs.data, 1)
+             predicted_class = class_names[predicted.item()]
+
+         return predicted_class
+
+     def recognize_fruit(self, image, class_names):
+         img = self.transform(image)
+         img = img.unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model(img)
+             _, predicted = torch.max(outputs.data, 1)
+             predicted_class = class_names[predicted.item()]
+
+         return predicted_class
+
+
+ app = FastAPI()
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ with open('metadata.json', 'r') as f:
+     metadata = json.load(f)
+
+ class_names = metadata['classes']
+
+ model_path = "models/fruit_recognition_model.pth"
+ recognizer = FruitRecognizer(model_path, len(class_names))
+
+ templates = Jinja2Templates(directory="templates")
+
+ @app.get("/", response_class=HTMLResponse)
+ async def root(request: Request):
+     return templates.TemplateResponse("index.html", {"request": request})
+
+ @app.post("/predict/")
+ async def predict_fruit(request: Request, file: UploadFile = File(...)):
+     try:
+         print("request")
+         img = Image.open(BytesIO(await file.read())).convert("RGB")
+         predicted_class = recognizer.recognize_fruit(img, class_names)
+         return JSONResponse({"predicted_class": predicted_class})
+     except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=400)
+
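For reference, a quick client-side check of main2.py's /predict/ endpoint; the requests dependency, local port, and image filename are illustrative assumptions, with the app started as uvicorn main2:app --host 0.0.0.0 --port 7860:

import requests

# Upload an image as multipart/form-data under the field name "file",
# matching the UploadFile parameter of predict_fruit().
with open("sample_fruit.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/predict/",
        files={"file": ("sample_fruit.jpg", f, "image/jpeg")},
    )

# Expected: {"predicted_class": "..."} on success,
# or {"error": "..."} with HTTP 400 on failure.
print(resp.status_code, resp.json())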
static/index.html ADDED
@@ -0,0 +1,36 @@
+ <main>
+   <section id="text-gen">
+     <h2 class="relative group">
+       <a
+         id="text-generation-using-flan-t5"
+         class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full"
+         href="#text-generation-using-flan-t5"
+       >
+         <span><IconCopyLink/></span>
+       </a>
+       <span>
+         Text generation using Flan T5
+       </span>
+     </h2>
+
+     <p>
+       Model:
+       <a
+         href="https://huggingface.co/google/flan-t5-small"
+         rel="noreferrer"
+         target="_blank"
+         >google/flan-t5-small
+       </a>
+     </p>
+     <form class="text-gen-form">
+       <label for="text-gen-input">Text prompt</label>
+       <input
+         id="text-gen-input"
+         type="text"
+         value="German: There are many ducks"
+       />
+       <button id="text-gen-submit">Submit</button>
+       <p class="text-gen-output"></p>
+     </form>
+   </section>
+ </main>
static/script.js ADDED
@@ -0,0 +1,17 @@
+ const textGenForm = document.querySelector(".text-gen-form");
+
+ const translateText = async (text) => {
+   const inferResponse = await fetch(`infer_t5?input=${text}`);
+   const inferJson = await inferResponse.json();
+
+   return inferJson.output;
+ };
+
+ textGenForm.addEventListener("submit", async (event) => {
+   event.preventDefault();
+
+   const textGenInput = document.getElementById("text-gen-input");
+   const textGenParagraph = document.querySelector(".text-gen-output");
+
+   textGenParagraph.textContent = await translateText(textGenInput.value);
+ });
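Note: translateText interpolates raw text into the query string, so a prompt containing &, #, or + would be truncated or mangled; wrapping the value in encodeURIComponent(text) is the usual client-side fix. The same endpoint can also be exercised without the page; a sketch in Python, assuming the Space is reachable on localhost:7860 and requests is available:

import requests

# params= URL-encodes the prompt, unlike the raw template string in script.js.
resp = requests.get(
    "http://localhost:7860/infer_t5",
    params={"input": "German: There are many ducks"},
)
print(resp.json()["output"])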