ezequiellopez commited on
Commit
e145e85
β€’
1 Parent(s): b0a3f00

setting up

Browse files
Files changed (4) hide show
  1. .env +2 -0
  2. README.md +1 -1
  3. app/main.py +127 -0
  4. requirements.txt +3 -1
.env ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ REDIS_PORT=6379
2
+ FASTAPI_PORT=7860
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Mpib Prosocial
3
  emoji: πŸ†
4
  colorFrom: indigo
5
  colorTo: blue
 
1
  ---
2
+ title: MPIB Prosocial
3
  emoji: πŸ†
4
  colorFrom: indigo
5
  colorTo: blue
app/main.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import required libraries
2
+ from fastapi import FastAPI, HTTPException
3
+ from pydantic import BaseModel
4
+ from typing import List
5
+ import redis
6
+ from transformers import BartForSequenceClassification, BartTokenizer, AutoTokenizer, AutoConfig, pipeline
7
+ from dotenv import load_dotenv
8
+ import os
9
+
10
+ # Load environment variables from .env file
11
+ load_dotenv('../.env')
12
+
13
+ # Access environment variables
14
+ redis_port = os.getenv("REDIS_PORT")
15
+ fastapi_port = os.getenv("FASTAPI_PORT")
16
+
17
+
18
+ print("Redis port:", redis_port)
19
+ print("FastAPI port:", fastapi_port)
20
+
21
+
22
+ # Initialize FastAPI app and Redis client
23
+ app = FastAPI()
24
+ redis_client = redis.Redis(host='redis', port=6379)
25
+
26
+ # Load BART model and tokenizer
27
+ #model = BartForSequenceClassification.from_pretrained("facebook/bart-large-mnli")
28
+ #tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-mnli")
29
+
30
+ model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
31
+
32
+ def score_text_with_labels(model, text: list, labels: list, multi: bool=True):
33
+ #candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
34
+ results = [result['scores'] for result in model(text, labels, multi_label=multi)]
35
+ #return dict(zip(labels, results['scores']))
36
+ return results
37
+
38
+ def smooth_sequence(tweets_scores, window_size):
39
+ # Calculate the sum of scores for both labels for each tweet
40
+ tweet_sum_scores = [(sum(scores), index) for index, scores in enumerate(tweets_scores)]
41
+ # Sort tweets based on their sum scores, then by their original index to stabilize
42
+ sorted_tweets = sorted(tweet_sum_scores, key=lambda x: (x[0], x[1]))
43
+ # Extract the original indices of tweets after sorting
44
+ sorted_indices = [index for _, index in sorted_tweets]
45
+ # Create a new sequence based on sorted indices
46
+ smoothed_sequence = [tweets_scores[index] for index in sorted_indices]
47
+ return smoothed_sequence
48
+
49
+ def rerank_on_label(label: str):
50
+ return 200
51
+
52
+
53
+ # Define Pydantic models
54
+ class Item(BaseModel):
55
+ #id: str
56
+ #title: str = None
57
+ text: str
58
+ #type: str
59
+ #engagements: dict
60
+
61
+ class RerankedItems(BaseModel):
62
+ ranked_ids: List[str]
63
+ new_items: List[dict]
64
+
65
+ # Define a health check endpoint
66
+ @app.get("/")
67
+ async def health_check():
68
+ return {"status": "ok"}
69
+
70
+ # Define FastAPI routes and logic
71
+ @app.post("/rerank/")
72
+ async def rerank_items(items: List[Item]) -> RerankedItems:
73
+ reranked_ids = []
74
+
75
+ # Process each item
76
+ for item in items:
77
+ # Classify the item using Hugging Face BART model
78
+ labels = classify_item(item.text)
79
+
80
+ # Save the item with labels in Redis
81
+ redis_client.hset(item.id, mapping={"title": item.title, "text": item.text, "labels": ",".join(labels)})
82
+
83
+ # Add the item id to the reranked list
84
+ reranked_ids.append(item.id)
85
+
86
+ # Sort the items based on model confidence
87
+ reranked_ids.sort(key=lambda x: redis_client.zscore("classified_items", x), reverse=True)
88
+
89
+ # Return the reranked items
90
+ return {"ranked_ids": reranked_ids, "new_items": []} # Ignore "new_items" for now
91
+
92
+ # Define an endpoint to classify items and save them in Redis
93
+ @app.post("/classify/")
94
+ async def classify_and_save(items: List[Item]) -> None:
95
+ print("new 1")
96
+ #labels = ["factful", "civic", "constructive", "politics", "health", "news"]
97
+ #labels = ["factful", "politics"]
98
+ labels = ["something else", "news feed, news articles, breaking news", "politics and polititians", "healthcare and health"]
99
+ #labels = ["health", "politics", "news", "non-health non-politics non-news"]
100
+ texts = [item.text for item in items]
101
+ print(texts)
102
+
103
+ labels = score_text_with_labels(model=model, text=texts, labels=labels, multi=True)
104
+ print(labels)
105
+ return labels
106
+ #for item in items:
107
+ # print(item)
108
+ # Classify the item using Hugging Face BART model
109
+ #labels = classify_item(item.text)
110
+ #return score_text_with_labels(model, item.text, labels)
111
+ # Save the item with labels in Redis
112
+ #redis_client.hset(item.id, mapping={"title": item.title, "text": item.text, "labels": ",".join(labels)})
113
+ #return labels
114
+ #return None
115
+
116
+ # Function to classify item text using Hugging Face BART model
117
+ def classify_item(text: str) -> List[str]:
118
+ # Tokenize input text
119
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
120
+ print(1)
121
+ # Perform inference
122
+ outputs = model(**inputs)
123
+ print(2)
124
+ # Get predicted label
125
+ predicted_label = tokenizer.decode(outputs.logits.argmax())
126
+
127
+ return [predicted_label]
requirements.txt CHANGED
@@ -4,4 +4,6 @@ transformers
4
  python-dotenv
5
  dotenv-cli
6
  pandas
7
- uvicorn
 
 
 
4
  python-dotenv
5
  dotenv-cli
6
  pandas
7
+ uvicorn
8
+ pydantic
9
+ redis