Midnightar commited on
Commit
889f133
·
verified ·
1 Parent(s): b718163

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -5,12 +5,21 @@ from fastapi.responses import JSONResponse, HTMLResponse
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  from PIL import Image
7
 
8
- # Set Hugging Face cache to avoid permission issues
 
9
  os.environ["HF_HOME"] = "/tmp/hf_cache"
 
 
 
 
10
 
11
  # Load processor + model
12
- processor = AutoImageProcessor.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
13
- model = AutoModelForImageClassification.from_pretrained("prithivMLmods/Realistic-Gender-Classification")
 
 
 
 
14
 
15
  # Create FastAPI app
16
  app = FastAPI()
@@ -43,10 +52,10 @@ async def predict(file: UploadFile = File(...)):
43
  outputs = model(**inputs)
44
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
45
 
46
- # Get labels (ensure consistent order)
47
  labels = list(model.config.id2label.values())
48
 
49
- # Fix keys: return "male" and "female" only
50
  result = {
51
  "female": float(probs[labels.index("female portrait")]),
52
  "male": float(probs[labels.index("male portrait")])
 
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  from PIL import Image
7
 
8
+ # Force cache to /tmp/hf_cache before anything else
9
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
10
  os.environ["HF_HOME"] = "/tmp/hf_cache"
11
+ os.environ["HF_DATASETS_CACHE"] = "/tmp/hf_cache"
12
+
13
+ # Create cache directory if missing
14
+ os.makedirs("/tmp/hf_cache", exist_ok=True)
15
 
16
  # Load processor + model
17
+ processor = AutoImageProcessor.from_pretrained(
18
+ "prithivMLmods/Realistic-Gender-Classification", cache_dir="/tmp/hf_cache"
19
+ )
20
+ model = AutoModelForImageClassification.from_pretrained(
21
+ "prithivMLmods/Realistic-Gender-Classification", cache_dir="/tmp/hf_cache"
22
+ )
23
 
24
  # Create FastAPI app
25
  app = FastAPI()
 
52
  outputs = model(**inputs)
53
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
54
 
55
+ # Get labels
56
  labels = list(model.config.id2label.values())
57
 
58
+ # Clean result for FlutterFlow
59
  result = {
60
  "female": float(probs[labels.index("female portrait")]),
61
  "male": float(probs[labels.index("male portrait")])