Spaces:
Sleeping
Sleeping
Commit
·
56f8c1c
1
Parent(s):
988393e
added main and test codes
Browse files
main.py
CHANGED
|
@@ -1,6 +1,126 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
import torch
|
| 3 |
+
import asyncio
|
| 4 |
+
import transformers
|
| 5 |
+
from typing import Dict
|
| 6 |
+
from fastapi import FastAPI
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
from contextlib import asynccontextmanager
|
| 9 |
+
from transformers import (
|
| 10 |
+
pipeline,
|
| 11 |
+
AutoTokenizer,
|
| 12 |
+
AutoModelForSequenceClassification,
|
| 13 |
+
BitsAndBytesConfig,
|
| 14 |
+
)
|
| 15 |
|
| 16 |
+
# ----------------------------- #
|
| 17 |
+
# Configurations #
|
| 18 |
+
# ----------------------------- #
|
| 19 |
+
transformers.set_seed(42)
|
| 20 |
+
torch.set_default_dtype(torch.bfloat16)
|
| 21 |
|
| 22 |
+
MODEL_NAME = "climatebert/distilroberta-base-climate-sentiment"
|
| 23 |
+
BATCH_PROCESS_INTERVAL = 0.01
|
| 24 |
+
MAX_BATCH_SIZE = 128
|
| 25 |
+
|
| 26 |
+
# ----------------------------- #
|
| 27 |
+
# Shared Storage #
|
| 28 |
+
# ----------------------------- #
|
| 29 |
+
query_queue: asyncio.Queue = asyncio.Queue()
|
| 30 |
+
results: Dict[str, Dict] = {}
|
| 31 |
+
classifier = None # will be initialized in lifespan
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ----------------------------- #
|
| 35 |
+
# Model Initialization #
|
| 36 |
+
# ----------------------------- #
|
| 37 |
+
def load_classifier(model_name: str):
|
| 38 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 39 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
| 40 |
+
model_name,
|
| 41 |
+
device_map="auto",
|
| 42 |
+
quantization_config=BitsAndBytesConfig(
|
| 43 |
+
load_in_4bit=True,
|
| 44 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 45 |
+
),
|
| 46 |
+
)
|
| 47 |
+
return pipeline(
|
| 48 |
+
"text-classification", model=model, tokenizer=tokenizer, framework="pt"
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ----------------------------- #
|
| 53 |
+
# Pydantic Schema #
|
| 54 |
+
# ----------------------------- #
|
| 55 |
+
class Query(BaseModel):
|
| 56 |
+
sentence: str
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ----------------------------- #
|
| 60 |
+
# Queue Processing Task #
|
| 61 |
+
# ----------------------------- #
|
| 62 |
+
async def process_queue():
|
| 63 |
+
while True:
|
| 64 |
+
await asyncio.sleep(BATCH_PROCESS_INTERVAL)
|
| 65 |
+
|
| 66 |
+
batch = []
|
| 67 |
+
while not query_queue.empty() and len(batch) < MAX_BATCH_SIZE:
|
| 68 |
+
batch.append(await query_queue.get())
|
| 69 |
+
|
| 70 |
+
if not batch:
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
sentences = [item["sentence"] for item in batch]
|
| 74 |
+
ids = [item["id"] for item in batch]
|
| 75 |
+
predictions = classifier(sentences, batch_size=len(sentences))
|
| 76 |
+
|
| 77 |
+
for query_id, pred, sentence in zip(ids, predictions, sentences):
|
| 78 |
+
results[query_id] = {
|
| 79 |
+
"sentence": sentence,
|
| 80 |
+
"label": pred["label"],
|
| 81 |
+
"score": pred["score"],
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ----------------------------- #
|
| 86 |
+
# Lifespan Handler #
|
| 87 |
+
# ----------------------------- #
|
| 88 |
+
@asynccontextmanager
|
| 89 |
+
async def lifespan(app: FastAPI):
|
| 90 |
+
global classifier
|
| 91 |
+
classifier = load_classifier(MODEL_NAME)
|
| 92 |
+
_ = classifier("Startup warm-up sentence.")
|
| 93 |
+
queue_task = asyncio.create_task(process_queue())
|
| 94 |
+
yield
|
| 95 |
+
queue_task.cancel()
|
| 96 |
+
try:
|
| 97 |
+
await queue_task
|
| 98 |
+
except asyncio.CancelledError:
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ----------------------------- #
|
| 103 |
+
# FastAPI Setup #
|
| 104 |
+
# ----------------------------- #
|
| 105 |
+
app = FastAPI(lifespan=lifespan)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ----------------------------- #
|
| 109 |
+
# API Endpoints #
|
| 110 |
+
# ----------------------------- #
|
| 111 |
+
@app.post("/classify")
|
| 112 |
+
async def classify(query: Query):
|
| 113 |
+
query_id = str(uuid.uuid4())
|
| 114 |
+
await query_queue.put({"id": query_id, "sentence": query.sentence})
|
| 115 |
+
|
| 116 |
+
while query_id not in results:
|
| 117 |
+
await asyncio.sleep(0.001)
|
| 118 |
+
|
| 119 |
+
return {"id": query_id, "result": results.pop(query_id)}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@app.get("/")
|
| 123 |
+
def read_root():
|
| 124 |
+
return {
|
| 125 |
+
"message": "Welcome to the Sentiment Classification API with Query Batching"
|
| 126 |
+
}
|
test.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import random
|
| 4 |
+
import requests
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Union, Tuple
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 8 |
+
from tqdm import tqdm # optional: install via `pip install tqdm`
|
| 9 |
+
|
| 10 |
+
# ----------------------------- #
|
| 11 |
+
# Configuration #
|
| 12 |
+
# ----------------------------- #
|
| 13 |
+
URL = "http://localhost:8000/classify"
|
| 14 |
+
NUM_REQUESTS = 4000
|
| 15 |
+
MAX_WORKERS = os.cpu_count() * 8 or 2
|
| 16 |
+
TIMEOUT = 20
|
| 17 |
+
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
|
| 20 |
+
# Sample text tokens
|
| 21 |
+
SAMPLE_POPULATION = """Limes have higher contents of sugars and acids than lemons do.[1] Lime juice may be squeezed from fresh limes, or purchased in bottles in both unsweetened and sweetened varieties. Lime juice is used to make limeade, and as an ingredient (typically as sour mix) in many cocktails.
|
| 22 |
+
|
| 23 |
+
Lime pickles are an integral part of Indian cuisine, especially in South India. In Kerala, the Onam Sadhya usually includes either lemon pickle or lime pickle. Other Indian preparations of limes include sweetened lime pickle, salted pickle, and lime chutney.
|
| 24 |
+
|
| 25 |
+
In cooking, lime is valued both for the acidity of its juice and the floral aroma of its zest. It is a common ingredient in authentic Mexican, Vietnamese and Thai dishes. Lime soup is a traditional dish from the Mexican state of Yucatan. It is also used for its pickling properties in ceviche. Some guacamole recipes call for lime juice.
|
| 26 |
+
|
| 27 |
+
The use of dried limes (called black lime or limoo) as a flavouring is typical of Persian cuisine, Iraqi cuisine, as well as in Eastern Arabian cuisine baharat (a spice mixture that is also called kabsa or kebsa).
|
| 28 |
+
|
| 29 |
+
Key lime gives the character flavouring to the American dessert known as Key lime pie. In Australia, desert lime is used for making marmalade.
|
| 30 |
+
|
| 31 |
+
Lime is an ingredient in several highball cocktails, often based on gin, such as gin and tonic, the gimlet and the Rickey. Freshly squeezed lime juice is also considered a key ingredient in margaritas, although sometimes lemon juice is substituted. It is also found in many rum cocktails such as the daiquiri, and other tropical drinks.
|
| 32 |
+
|
| 33 |
+
Lime extracts and lime essential oils are frequently used in perfumes, cleaning products, and aromatherapy.""".split()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ----------------------------- #
|
| 37 |
+
# Request Builder #
|
| 38 |
+
# ----------------------------- #
|
| 39 |
+
def build_payload() -> dict:
|
| 40 |
+
sentence = " ".join(
|
| 41 |
+
random.choices(SAMPLE_POPULATION, k=random.randint(20, len(SAMPLE_POPULATION)))
|
| 42 |
+
)
|
| 43 |
+
return {"sentence": sentence}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# ----------------------------- #
|
| 47 |
+
# Request Sender Logic #
|
| 48 |
+
# ----------------------------- #
|
| 49 |
+
def send_request() -> Union[int, str]:
|
| 50 |
+
try:
|
| 51 |
+
response = requests.post(
|
| 52 |
+
URL,
|
| 53 |
+
json=build_payload(),
|
| 54 |
+
headers={"Content-Type": "application/json"},
|
| 55 |
+
timeout=TIMEOUT,
|
| 56 |
+
)
|
| 57 |
+
return response.status_code
|
| 58 |
+
except requests.RequestException as e:
|
| 59 |
+
return f"Error: {e}"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ----------------------------- #
|
| 63 |
+
# Test Runner #
|
| 64 |
+
# ----------------------------- #
|
| 65 |
+
def test_endpoint():
|
| 66 |
+
print(f"Sending {NUM_REQUESTS} requests to {URL} with {MAX_WORKERS} workers")
|
| 67 |
+
start_time = time.time()
|
| 68 |
+
|
| 69 |
+
successful = 0
|
| 70 |
+
failed = 0
|
| 71 |
+
status_distribution = {}
|
| 72 |
+
|
| 73 |
+
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
|
| 74 |
+
futures = [executor.submit(send_request) for _ in range(NUM_REQUESTS)]
|
| 75 |
+
for future in tqdm(
|
| 76 |
+
as_completed(futures), total=NUM_REQUESTS, desc="Processing"
|
| 77 |
+
):
|
| 78 |
+
result = future.result()
|
| 79 |
+
if isinstance(result, int):
|
| 80 |
+
status_distribution[result] = status_distribution.get(result, 0) + 1
|
| 81 |
+
if 200 <= result < 300:
|
| 82 |
+
successful += 1
|
| 83 |
+
else:
|
| 84 |
+
failed += 1
|
| 85 |
+
else:
|
| 86 |
+
failed += 1
|
| 87 |
+
logging.warning(result)
|
| 88 |
+
|
| 89 |
+
duration = time.time() - start_time
|
| 90 |
+
print("\n--- Test Summary ---")
|
| 91 |
+
print(f"Elapsed Time : {duration:.2f} seconds")
|
| 92 |
+
print(f"Total Requests Sent : {NUM_REQUESTS}")
|
| 93 |
+
print(f"Successful Requests : {successful}")
|
| 94 |
+
print(f"Failed Requests : {failed}")
|
| 95 |
+
print(f"Status Code Summary : {status_distribution}")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
test_endpoint()
|