ezequiellopez commited on
Commit
c8df78e
1 Parent(s): c6dd11e

debugging integration tests

Browse files
app/app.py CHANGED
@@ -3,9 +3,8 @@ from fastapi import FastAPI, HTTPException
3
  #import redis
4
  from dotenv import load_dotenv
5
  import os
6
- import torch
7
 
8
- from modules.redistribute import redistribute, insert_element_at_position
9
  #from modules.models.api import Input, Output, NewItem, UUID
10
  from modules.database import BoostDatabase, UserDatabase, User
11
  from _models.request import RankingRequest
@@ -19,10 +18,6 @@ load_dotenv('../.env')
19
  redis_port = os.getenv("REDIS_PORT")
20
  fastapi_port = os.getenv("FASTAPI_PORT")
21
 
22
-
23
- print(f"Is CUDA available: {torch.cuda.is_available()}")
24
- #print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
25
-
26
  #print("Redis port:", redis_port)
27
  print("FastAPI port:", fastapi_port)
28
 
@@ -47,10 +42,8 @@ async def rerank_items(input_data: RankingRequest) -> RankingResponse:
47
  # TODO consider sampling them?
48
 
49
  print(items)
50
- reranked_ids, first_topic, insertion_pos = redistribute(items=items)
51
  #reranked_ids = [ for id_ in reranked_ids]
52
- print("here!")
53
- print(reranked_ids)
54
 
55
  user_in_db = user_db.get_user(user_id=user)
56
 
 
3
  #import redis
4
  from dotenv import load_dotenv
5
  import os
 
6
 
7
+ from modules.redistribute import redistribute, insert_element_at_position, handle_text_content
8
  #from modules.models.api import Input, Output, NewItem, UUID
9
  from modules.database import BoostDatabase, UserDatabase, User
10
  from _models.request import RankingRequest
 
18
  redis_port = os.getenv("REDIS_PORT")
19
  fastapi_port = os.getenv("FASTAPI_PORT")
20
 
 
 
 
 
21
  #print("Redis port:", redis_port)
22
  print("FastAPI port:", fastapi_port)
23
 
 
42
  # TODO consider sampling them?
43
 
44
  print(items)
45
+ reranked_ids, first_topic, insertion_pos = redistribute(platform=platform, items=items)
46
  #reranked_ids = [ for id_ in reranked_ids]
 
 
47
 
48
  user_in_db = user_db.get_user(user_id=user)
49
 
app/modules/classify.py CHANGED
@@ -1,10 +1,17 @@
1
  from transformers import pipeline
2
  from typing import List
3
 
4
- #model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
5
- model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
6
-
 
 
 
 
 
7
 
 
 
8
 
9
  label_map = {
10
  "something else": "non-civic",
@@ -13,6 +20,7 @@ label_map = {
13
  "health are and public health": "health",
14
  "religious": "news" # CONSCIOUS DECISION
15
  }
 
16
 
17
  def map_scores(predicted_labels: List[dict], default_label: str):
18
  mapped_scores = [item['scores'][0] if item['labels'][0]!= default_label else 0 for item in predicted_labels]
@@ -26,7 +34,39 @@ def get_first_relevant_label(predicted_labels, mapped_scores: List[float], defau
26
 
27
 
28
  def classify(texts: List[str], labels: List[str]):
29
- predicted_labels = model(texts, labels, multi_label=False)
30
  print(predicted_labels)
31
  return predicted_labels
32
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from transformers import pipeline
2
  from typing import List
3
 
4
+ try:
5
+ import torch
6
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
7
+ print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
8
+ device = 0
9
+ except:
10
+ print("No GPU available, running on CPU")
11
+ device = None
12
 
13
+ #model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
14
+ model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=device)
15
 
16
  label_map = {
17
  "something else": "non-civic",
 
20
  "health are and public health": "health",
21
  "religious": "news" # CONSCIOUS DECISION
22
  }
23
+ default_label = "something else"
24
 
25
  def map_scores(predicted_labels: List[dict], default_label: str):
26
  mapped_scores = [item['scores'][0] if item['labels'][0]!= default_label else 0 for item in predicted_labels]
 
34
 
35
 
36
  def classify(texts: List[str], labels: List[str]):
37
+ predicted_labels = model(texts, labels, multi_label=False, batch_size=16)
38
  print(predicted_labels)
39
  return predicted_labels
40
+
41
+
42
+ def classify(texts: List[str], labels: List[str]):
43
+ results = []
44
+
45
+ # Lists to hold texts and indices for model processing
46
+ model_texts = []
47
+ model_indices = []
48
+
49
+ # Iterate through each text to check for special cases
50
+ for index, text in enumerate(texts):
51
+ if text == "NON-VALID":
52
+ # If text is "X", directly assign the label and score
53
+ results.append({
54
+ "sequence": text,
55
+ "labels": [default_label], # Assuming the first label is the correct one for "X"
56
+ "scores": [1.0] # Assign a full score
57
+ })
58
+ else:
59
+ # Otherwise, prepare for model processing
60
+ model_texts.append(text)
61
+ model_indices.append(index)
62
+
63
+ if model_texts:
64
+ # Process texts through the model if there are any
65
+ predicted_labels = model(model_texts, labels, multi_label=False, batch_size=16)
66
+
67
+ # Insert model results into the correct positions
68
+ for pred, idx in zip(predicted_labels, model_indices):
69
+ results.insert(idx, pred)
70
+
71
+ print(results)
72
+ return results
app/modules/redistribute.py CHANGED
@@ -4,12 +4,27 @@ from modules.classify import classify, map_scores, get_first_relevant_label
4
  labels = ["something else", "headlines, news channels, news articles, breaking news", "politics, policy and politicians", "health care and public health", "religious"]
5
 
6
 
7
- def redistribute(items):
8
- predicted_labels = classify(texts=[item.text for item in items], labels=labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
10
  first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
11
  # TODO include parent linking
12
- print("OK?")
13
  reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
14
  print(reranked_ids)
15
  return reranked_ids, first_topic, insertion_pos
 
4
  labels = ["something else", "headlines, news channels, news articles, breaking news", "politics, policy and politicians", "health care and public health", "religious"]
5
 
6
 
7
+ def handle_text_content(platform, items):
8
+ texts = []
9
+ for item in items:
10
+ if platform == "reddit" and item.title:
11
+ text = item.title +"\n"+ item.text
12
+ else:
13
+ text = item.text
14
+
15
+ if len(text) <=5:
16
+ text = "NON-VALID"
17
+
18
+ texts.append(text)
19
+ return texts
20
+
21
+
22
+ def redistribute(platform, items):
23
+ predicted_labels = classify(texts=handle_text_content(platform=platform, items=items), labels=labels)
24
  mapped_scores = map_scores(predicted_labels=predicted_labels, default_label="something else")
25
  first_topic, insertion_pos = get_first_relevant_label(predicted_labels=predicted_labels, mapped_scores=mapped_scores, default_label="something else")
26
  # TODO include parent linking
27
+ print("OK--", predicted_labels)
28
  reranked_ids, _ = distribute_evenly(ids=[item.id for item in items], scores=mapped_scores)
29
  print(reranked_ids)
30
  return reranked_ids, first_topic, insertion_pos