jonathanjordan21 commited on
Commit
b01c113
1 Parent(s): bccc14d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py CHANGED
@@ -1,7 +1,57 @@
1
  from fastapi import FastAPI
 
 
2
 
3
  app = FastAPI()
4
 
 
 
 
 
 
 
 
5
  @app.get("/")
6
  def greet_json():
7
  return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
3
+ import torch
4
 
5
  app = FastAPI()
6
 
7
+
8
+
9
+ model_name = "cardiffnlp/twitter-xlm-roberta-base-sentiment"
10
+ sentiment_model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
+ sentiment_tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ sentiment_model.config.id2label[3] = "mixed"
13
+
14
  @app.get("/")
15
  def greet_json():
16
  return {"Hello": "World!"}
17
+
18
+
19
+ @app.post(/sentiment_score)
20
+ async def sentiment_score(text: str):
21
+ inputs = sentiment_tokenizer(text[:2500], return_tensors='pt')
22
+
23
+ with torch.no_grad():
24
+ logits = sentiment_model(**inputs).logits #+ 1
25
+
26
+
27
+ print(logits)
28
+
29
+ logits = logits + logits[0,1].abs()
30
+
31
+ # print(torch.nn.functional.sigmoid(logits))
32
+
33
+ # logits = logits / 10
34
+
35
+ # print(logits)
36
+
37
+ # print(torch.abs(logits[0,0] - logits[0,-1]))
38
+ # print(logits[0,1]//torch.max(torch.abs(logits[0,::2])))
39
+
40
+ logits = torch.cat(
41
+ (
42
+ logits, (
43
+ # ( logits[0,1] + torch.sign(logits[0,0] - logits[0,-1]) * (logits[0,0] - logits[0,-1])/2 )/2 +
44
+ # (logits[0,0] + logits[0,-1])/20
45
+ (1 - torch.abs(logits[0,0] - logits[0,-1])*(2+(logits[0,1]//torch.max(torch.abs(logits[0,::2])))))
46
+ ).unsqueeze(0).unsqueeze(0)
47
+ ), dim=-1
48
+ )
49
+
50
+ softmax = torch.nn.functional.softmax(
51
+ logits,
52
+ dim=-1
53
+ )
54
+
55
+ return [{"label":model.config.id2label[predicted_class_id.tolist()], "score":softmax[0, predicted_class_id].tolist()} for predicted_class_id in softmax.argsort(dim=-1, descending=True)[0]]
56
+
57
+