File size: 1,234 Bytes
8f4c005
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Base image: the python:3.9 "slim-bullseye" variant.
FROM python:3.9-slim-bullseye

# Install the OS-level packages (compiler toolchain, curl/wget, nginx, CA
# certificates, npm), then the pm2 process manager as a global npm package,
# and upgrade pip/setuptools -- all in one layer. The apt package lists are
# removed at the end so they do not stay in the layer.
# --no-cache-dir keeps pip's download/wheel cache out of the image.
RUN apt-get -y update && \
   apt-get install -y --no-install-recommends build-essential  \
   curl wget nginx ca-certificates npm \
   && npm install pm2 -g \
   && pip install --no-cache-dir --upgrade pip setuptools \
   && rm -rf /var/lib/apt/lists/*

# Copy only the dependency manifest before installing, so the install layer
# can be reused from the build cache while requirements.txt is unchanged.
COPY requirements.txt .
# --no-cache-dir: do not keep pip's cache inside this layer either.
RUN pip install --no-cache-dir -r requirements.txt

class ZeroShotTextClassifier:
    """Zero-shot text classifier backed by one pipeline shared on the class.

    The pipeline object is created the first time it is needed and then kept
    in the ``classifier`` class attribute for all later calls.
    """

    # Shared pipeline object; stays None until load() creates it.
    classifier = None

    @classmethod
    def load(cls):
        """Create the zero-shot pipeline unless one is already stored."""
        if cls.classifier is not None:
            # Already loaded: nothing to do.
            return
        cls.classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
        )

    @classmethod
    def predict(cls, text, candidate_labels):
        """Classify ``text`` against ``candidate_labels`` and keep the best hit.

        Args:
            text: Input passed through to the pipeline unchanged.
            candidate_labels: Labels passed through to the pipeline unchanged.

        Returns:
            A dict ``{"label": ..., "score": ...}`` holding the entry of the
            pipeline output's "labels" / "scores" lists with the highest score.
        """
        # Make sure the shared pipeline exists before using it.
        cls.load()
        raw_output = cls.classifier(text, candidate_labels)
        all_scores = raw_output["scores"]
        # Position of the highest score; the label list is indexed the same way.
        best_position = np.argmax(all_scores)
        top_label = raw_output["labels"][best_position]
        return {"label": top_label, "score": all_scores[best_position]}

# Instantiate the zero-shot pipeline once at build time. Presumably this is
# meant to fetch the facebook/bart-large-mnli weights into the image so the
# container does not have to download them on first request -- confirm.
# NOTE(review): the Python class located above this instruction is not valid
# Dockerfile syntax; confirm whether it belongs in a separate .py module.
RUN python -c "from transformers import pipeline; classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli')"