Bagi4 commited on
Commit
f791980
1 Parent(s): e155850

feat: new model

Browse files
Files changed (1) hide show
  1. main.py +47 -3
main.py CHANGED
@@ -1,5 +1,8 @@
1
  import logging
2
  import uvicorn
 
 
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from transformers import pipeline
@@ -13,6 +16,17 @@ logging.basicConfig(
13
  datefmt='%Y-%m-%d %H:%M:%S'
14
  )
15
  classifier = pipeline("zero-shot-classification", model="models/classificator", use_fast=False)
 
 
 
 
 
 
 
 
 
 
 
16
  app = FastAPI()
17
 
18
 
@@ -28,16 +42,46 @@ class ResponseData(BaseModel):
28
  scores: list[float]
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  @app.post("/classify", response_model=ResponseData, tags=["Classificator"])
32
  async def classify_text(data: RequestData):
33
- result = classifier(data.sequence, data.labels, multi_label=data.multiLabel)
34
  logging.info(result)
35
-
36
  return result
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
39
  @app.get("/ping", tags=["TEST"])
40
- def ping():
41
  return "pong"
42
 
43
 
 
1
  import logging
2
  import uvicorn
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch
5
+ import torch.nn.functional as F
6
  from fastapi import FastAPI
7
  from pydantic import BaseModel
8
  from transformers import pipeline
 
16
  datefmt='%Y-%m-%d %H:%M:%S'
17
  )
18
  classifier = pipeline("zero-shot-classification", model="models/classificator", use_fast=False)
19
+
20
+
21
+ def mean_pooling(model_output, attention_mask):
22
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
23
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
24
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
25
+
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
28
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
29
+
30
  app = FastAPI()
31
 
32
 
 
42
  scores: list[float]
43
 
44
 
45
+ def classify(data: RequestData):
46
+ return classifier(data.sequence, data.labels, multi_label=data.multiLabel)
47
+
48
+
49
+ def similarity(data: RequestData):
50
+ sentences = [data.sequence]
51
+ sentences.extend(data.labels)
52
+ encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
53
+
54
+ with torch.no_grad():
55
+ model_output = model(**encoded_input)
56
+
57
+ sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
58
+
59
+ sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
60
+
61
+ text_probs = sentence_embeddings[:1] @ sentence_embeddings[1:].T
62
+ return text_probs.tolist()[0]
63
+
64
+
65
  @app.post("/classify", response_model=ResponseData, tags=["Classificator"])
66
  async def classify_text(data: RequestData):
67
+ result = classify(data)
68
  logging.info(result)
 
69
  return result
70
 
71
 
72
+ @app.post("/similarity", response_model=ResponseData, tags=["Similarity"])
73
+ async def classify_text(data: RequestData):
74
+ result = similarity(data)
75
+ logging.info(result)
76
+ return ResponseData.model_validate({
77
+ "sequence": data.sequence,
78
+ "labels": data.labels,
79
+ "scores": result
80
+ })
81
+
82
+
83
  @app.get("/ping", tags=["TEST"])
84
+ async def ping():
85
  return "pong"
86
 
87