Pulastya0 commited on
Commit
00a3f62
·
verified ·
1 Parent(s): 8560891

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -19
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import os
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
- import torch
6
 
7
  # -------------------------------
8
  # Set Hugging Face cache to writable directory
@@ -11,7 +10,6 @@ os.environ["HF_HOME"] = "/tmp/huggingface"
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
12
  os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface"
13
 
14
-
15
  # -------------------------------
16
  # FastAPI app
17
  # -------------------------------
@@ -24,16 +22,16 @@ class RoutingRequest(BaseModel):
24
  text: str
25
 
26
  # -------------------------------
27
- # Load DeBERTa MNLI Model
28
  # -------------------------------
29
- MODEL_NAME = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
30
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
31
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
32
-
33
- # Departments mapping (example, can adjust for hackathon)
34
  DEPARTMENTS = ['Account', 'Software', 'Network', 'Security', 'Hardware',
35
- 'Infrastructure', 'Licensing', 'Communication', 'RemoteWork',
36
- 'Training', 'Performance']
 
 
 
 
 
37
 
38
  # -------------------------------
39
  # Routing Endpoint
@@ -43,16 +41,11 @@ async def route_ticket(req: RoutingRequest):
43
  text = req.text
44
  if not text:
45
  raise HTTPException(status_code=400, detail="Text cannot be empty")
46
-
47
- inputs = tokenizer(text, return_tensors="pt", truncation=True)
48
- outputs = model(**inputs)
49
- logits = outputs.logits[0]
50
 
51
- # Simple mapping: max logit → department
52
- department_idx = torch.argmax(logits).item() % len(DEPARTMENTS)
53
- department = DEPARTMENTS[department_idx]
54
 
55
- return {"department": department}
56
 
57
  # -------------------------------
58
  # Health Check
 
1
  import os
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
+ from transformers import pipeline
 
5
 
6
  # -------------------------------
7
  # Set Hugging Face cache to writable directory
 
10
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
11
  os.environ["HF_DATASETS_CACHE"] = "/tmp/huggingface"
12
 
 
13
  # -------------------------------
14
  # FastAPI app
15
  # -------------------------------
 
22
  text: str
23
 
24
  # -------------------------------
25
+ # Load Zero-Shot Classifier
26
  # -------------------------------
 
 
 
 
 
27
  DEPARTMENTS = ['Account', 'Software', 'Network', 'Security', 'Hardware',
28
+ 'Infrastructure', 'Licensing', 'Communication', 'RemoteWork',
29
+ 'Training', 'Performance']
30
+
31
+ classifier = pipeline(
32
+ "zero-shot-classification",
33
+ model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
34
+ )
35
 
36
  # -------------------------------
37
  # Routing Endpoint
 
41
  text = req.text
42
  if not text:
43
  raise HTTPException(status_code=400, detail="Text cannot be empty")
 
 
 
 
44
 
45
+ result = classifier(text, DEPARTMENTS)
46
+ department = result["labels"][0] # highest scoring department
 
47
 
48
+ return {"department": department, "scores": result["scores"]}
49
 
50
  # -------------------------------
51
  # Health Check