e-hossam96 commited on
Commit
56f8c1c
·
1 Parent(s): 988393e

added main and test codes

Browse files
Files changed (2) hide show
  1. main.py +124 -4
  2. test.py +99 -0
main.py CHANGED
@@ -1,6 +1,126 @@
1
- def main():
2
- print("Hello from sentiment-classification!")
 
 
 
 
 
 
 
 
 
 
 
 
3
 
 
 
 
 
 
4
 
5
- if __name__ == "__main__":
6
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import torch
3
+ import asyncio
4
+ import transformers
5
+ from typing import Dict
6
+ from fastapi import FastAPI
7
+ from pydantic import BaseModel
8
+ from contextlib import asynccontextmanager
9
+ from transformers import (
10
+ pipeline,
11
+ AutoTokenizer,
12
+ AutoModelForSequenceClassification,
13
+ BitsAndBytesConfig,
14
+ )
15
 
16
+ # ----------------------------- #
17
+ # Configurations #
18
+ # ----------------------------- #
19
+ transformers.set_seed(42)
20
+ torch.set_default_dtype(torch.bfloat16)
21
 
22
+ MODEL_NAME = "climatebert/distilroberta-base-climate-sentiment"
23
+ BATCH_PROCESS_INTERVAL = 0.01
24
+ MAX_BATCH_SIZE = 128
25
+
26
+ # ----------------------------- #
27
+ # Shared Storage #
28
+ # ----------------------------- #
29
+ query_queue: asyncio.Queue = asyncio.Queue()
30
+ results: Dict[str, Dict] = {}
31
+ classifier = None # will be initialized in lifespan
32
+
33
+
34
+ # ----------------------------- #
35
+ # Model Initialization #
36
+ # ----------------------------- #
37
+ def load_classifier(model_name: str):
38
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
39
+ model = AutoModelForSequenceClassification.from_pretrained(
40
+ model_name,
41
+ device_map="auto",
42
+ quantization_config=BitsAndBytesConfig(
43
+ load_in_4bit=True,
44
+ bnb_4bit_compute_dtype=torch.bfloat16,
45
+ ),
46
+ )
47
+ return pipeline(
48
+ "text-classification", model=model, tokenizer=tokenizer, framework="pt"
49
+ )
50
+
51
+
52
+ # ----------------------------- #
53
+ # Pydantic Schema #
54
+ # ----------------------------- #
55
+ class Query(BaseModel):
56
+ sentence: str
57
+
58
+
59
+ # ----------------------------- #
60
+ # Queue Processing Task #
61
+ # ----------------------------- #
62
+ async def process_queue():
63
+ while True:
64
+ await asyncio.sleep(BATCH_PROCESS_INTERVAL)
65
+
66
+ batch = []
67
+ while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
68
+ batch.append(await query_queue.get())
69
+
70
+ if not batch:
71
+ continue
72
+
73
+ sentences = [item["sentence"] for item in batch]
74
+ ids = [item["id"] for item in batch]
75
+ predictions = classifier(sentences, batch_size=len(sentences))
76
+
77
+ for query_id, pred, sentence in zip(ids, predictions, sentences):
78
+ results[query_id] = {
79
+ "sentence": sentence,
80
+ "label": pred["label"],
81
+ "score": pred["score"],
82
+ }
83
+
84
+
85
+ # ----------------------------- #
86
+ # Lifespan Handler #
87
+ # ----------------------------- #
88
+ @asynccontextmanager
89
+ async def lifespan(app: FastAPI):
90
+ global classifier
91
+ classifier = load_classifier(MODEL_NAME)
92
+ _ = classifier("Startup warm-up sentence.")
93
+ queue_task = asyncio.create_task(process_queue())
94
+ yield
95
+ queue_task.cancel()
96
+ try:
97
+ await queue_task
98
+ except asyncio.CancelledError:
99
+ pass
100
+
101
+
102
+ # ----------------------------- #
103
+ # FastAPI Setup #
104
+ # ----------------------------- #
105
+ app = FastAPI(lifespan=lifespan)
106
+
107
+
108
+ # ----------------------------- #
109
+ # API Endpoints #
110
+ # ----------------------------- #
111
+ @app.post("/classify")
112
+ async def classify(query: Query):
113
+ query_id = str(uuid.uuid4())
114
+ await query_queue.put({"id": query_id, "sentence": query.sentence})
115
+
116
+ while query_id not in results:
117
+ await asyncio.sleep(0.001)
118
+
119
+ return {"id": query_id, "result": results.pop(query_id)}
120
+
121
+
122
+ @app.get("/")
123
+ def read_root():
124
+ return {
125
+ "message": "Welcome to the Sentiment Classification API with Query Batching"
126
+ }
test.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import random
4
+ import requests
5
+ import logging
6
+ from typing import Union, Tuple
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from tqdm import tqdm # optional: install via `pip install tqdm`
9
+
10
+ # ----------------------------- #
11
+ # Configuration #
12
+ # ----------------------------- #
13
+ URL = "http://localhost:8000/classify"
14
+ NUM_REQUESTS = 4000
15
+ MAX_WORKERS = os.cpu_count() * 8 or 2
16
+ TIMEOUT = 20
17
+
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+ # Sample text tokens
21
+ SAMPLE_POPULATION = """Limes have higher contents of sugars and acids than lemons do.[1] Lime juice may be squeezed from fresh limes, or purchased in bottles in both unsweetened and sweetened varieties. Lime juice is used to make limeade, and as an ingredient (typically as sour mix) in many cocktails.
22
+
23
+ Lime pickles are an integral part of Indian cuisine, especially in South India. In Kerala, the Onam Sadhya usually includes either lemon pickle or lime pickle. Other Indian preparations of limes include sweetened lime pickle, salted pickle, and lime chutney.
24
+
25
+ In cooking, lime is valued both for the acidity of its juice and the floral aroma of its zest. It is a common ingredient in authentic Mexican, Vietnamese and Thai dishes. Lime soup is a traditional dish from the Mexican state of Yucatan. It is also used for its pickling properties in ceviche. Some guacamole recipes call for lime juice.
26
+
27
+ The use of dried limes (called black lime or limoo) as a flavouring is typical of Persian cuisine, Iraqi cuisine, as well as in Eastern Arabian cuisine baharat (a spice mixture that is also called kabsa or kebsa).
28
+
29
+ Key lime gives the character flavouring to the American dessert known as Key lime pie. In Australia, desert lime is used for making marmalade.
30
+
31
+ Lime is an ingredient in several highball cocktails, often based on gin, such as gin and tonic, the gimlet and the Rickey. Freshly squeezed lime juice is also considered a key ingredient in margaritas, although sometimes lemon juice is substituted. It is also found in many rum cocktails such as the daiquiri, and other tropical drinks.
32
+
33
+ Lime extracts and lime essential oils are frequently used in perfumes, cleaning products, and aromatherapy.""".split()
34
+
35
+
36
+ # ----------------------------- #
37
+ # Request Builder #
38
+ # ----------------------------- #
39
+ def build_payload() -> dict:
40
+ sentence = " ".join(
41
+ random.choices(SAMPLE_POPULATION, k=random.randint(20, len(SAMPLE_POPULATION)))
42
+ )
43
+ return {"sentence": sentence}
44
+
45
+
46
+ # ----------------------------- #
47
+ # Request Sender Logic #
48
+ # ----------------------------- #
49
+ def send_request() -> Union[int, str]:
50
+ try:
51
+ response = requests.post(
52
+ URL,
53
+ json=build_payload(),
54
+ headers={"Content-Type": "application/json"},
55
+ timeout=TIMEOUT,
56
+ )
57
+ return response.status_code
58
+ except requests.RequestException as e:
59
+ return f"Error: {e}"
60
+
61
+
62
+ # ----------------------------- #
63
+ # Test Runner #
64
+ # ----------------------------- #
65
+ def test_endpoint():
66
+ print(f"Sending {NUM_REQUESTS} requests to {URL} with {MAX_WORKERS} workers")
67
+ start_time = time.time()
68
+
69
+ successful = 0
70
+ failed = 0
71
+ status_distribution = {}
72
+
73
+ with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
74
+ futures = [executor.submit(send_request) for _ in range(NUM_REQUESTS)]
75
+ for future in tqdm(
76
+ as_completed(futures), total=NUM_REQUESTS, desc="Processing"
77
+ ):
78
+ result = future.result()
79
+ if isinstance(result, int):
80
+ status_distribution[result] = status_distribution.get(result, 0) + 1
81
+ if 200 <= result < 300:
82
+ successful += 1
83
+ else:
84
+ failed += 1
85
+ else:
86
+ failed += 1
87
+ logging.warning(result)
88
+
89
+ duration = time.time() - start_time
90
+ print("\n--- Test Summary ---")
91
+ print(f"Elapsed Time : {duration:.2f} seconds")
92
+ print(f"Total Requests Sent : {NUM_REQUESTS}")
93
+ print(f"Successful Requests : {successful}")
94
+ print(f"Failed Requests : {failed}")
95
+ print(f"Status Code Summary : {status_distribution}")
96
+
97
+
98
+ if __name__ == "__main__":
99
+ test_endpoint()