blasisd commited on
Commit
2b27c2d
·
1 Parent(s): 9e35b9e

Deleted unnecessary files, updated Dockerfile

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -1
  2. src/backup_services.py +0 -91
Dockerfile CHANGED
@@ -7,7 +7,8 @@ RUN useradd -m -u 1000 user
7
  WORKDIR /app
8
 
9
  COPY --chown=user ./requirements.txt requirements.txt
 
10
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
 
12
  COPY --chown=user ./src /app
13
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", $APP_PORT]
 
7
  WORKDIR /app
8
 
9
  COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN python -m pip install --upgrade pip
11
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
12
 
13
  COPY --chown=user ./src /app
14
+ CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", $APP_PORT]
src/backup_services.py DELETED
@@ -1,91 +0,0 @@
1
- import logging
2
-
3
- from pathlib import Path
4
-
5
- import torch
6
-
7
- from fastapi import HTTPException, status
8
- from PIL import Image
9
- from torchvision import models
10
- from typing import Tuple
11
-
12
- import src.config as config
13
-
14
-
15
- logger = logging.getLogger(__name__)
16
-
17
-
18
- async def classify_mushroom_in_image_svc(img: Image.Image) -> Tuple[str, str, str]:
19
- """Service used to classify a mushroom shown in an image.
20
- The mushroom is classified to one of many well known mushroom classes/types,
21
- as well as according to its toxicity profile (i.e. edible or poisonous).
22
- Additionally, a probability is returned showing confidence of classification.
23
-
24
- :param img: the image of the mushroom to be classified
25
- :type img: Image.Image
26
- :return: mushroom_type, toxicity_profile, classification_confidence
27
- :rtype: Tuple[str, str, str]
28
- """
29
-
30
- try:
31
- # Device agnostic
32
- device = "cuda" if torch.cuda.is_available() else "cpu"
33
-
34
- logger.debug("Loading classification model.")
35
-
36
- model_path = config.MODEL_PATH
37
-
38
- # Load saved model checkpoint
39
- model_state_dict = torch.load(model_path, map_location=device)
40
-
41
- # Get class_names from saved model checkpoint
42
- model_dirname = Path(model_path).resolve().parent
43
- with open(model_dirname / "labels.txt", "r") as labels_fp:
44
- class_names = [line.strip() for line in labels_fp]
45
-
46
- model = models.get_model(config.BASE_MODEL_NAME, num_classes=len(class_names))
47
-
48
- # Load state_dict of saved model
49
- model.load_state_dict(model_state_dict)
50
-
51
- weights_enum = models.get_model_weights(config.BASE_MODEL_NAME)
52
-
53
- # Get the model's default transforms
54
- image_transform = weights_enum.DEFAULT.transforms()
55
-
56
- # Make sure the model is on the target device
57
- model.to(device)
58
-
59
- # Turn on model evaluation mode and inference mode
60
- model.eval()
61
- with torch.inference_mode():
62
- logger.debug("Adapting input image by applying necessary transforms!")
63
- # Transform and add an extra dimension to image (model requires samples in [batch_size, color_channels, height, width])
64
- transformed_image = image_transform(img).unsqueeze(dim=0)
65
-
66
- # Make a prediction on image with an extra dimension and send it to the target device
67
- target_image_pred = model(transformed_image.to(device))
68
-
69
- logger.debug("Starting classification process...")
70
- # Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
71
- target_image_pred_probs = torch.softmax(target_image_pred, dim=1)
72
-
73
- # Convert prediction probabilities -> prediction labels
74
- target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)
75
-
76
- class_name = class_names[target_image_pred_label]
77
-
78
- # Split class_name to mushroom type and toxicity profile
79
- class_type, toxicity = class_name.rsplit("_", 1)
80
-
81
- # 4 decimal points precision
82
- prob = round(target_image_pred_probs.max().item(), 4)
83
-
84
- return class_type, toxicity, prob
85
-
86
- except Exception as e:
87
- logger.error("Classification process error: {e}")
88
- raise HTTPException(
89
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
90
- detail="Classification process failed due to an internal error. Contact support if this persists.",
91
- )