jebin511 committed
Commit 8e8354c · verified · 1 Parent(s): 4dfdaf5

Rename app.py to main.py

Files changed (2)
  1. app.py +0 -0
  2. main.py +96 -0
app.py DELETED
File without changes
main.py ADDED
@@ -0,0 +1,96 @@
+ from fastapi import FastAPI, UploadFile, File, HTTPException
+ from fastapi.responses import JSONResponse
+ from PIL import Image
+ import tempfile
+ import os
+ import base64
+ import cv2
+ import io
+ import re
+ from together import Together
+ import releaf_ai  # this should still contain your SYSTEM_PROMPT
+
+ app = FastAPI()
+
+ API_KEY = "your_api_key_here"
+ client = Together(api_key=API_KEY)
+ MODEL_NAME = "meta-llama/Llama-Vision-Free"
+
+ SYSTEM_PROMPT = releaf_ai.SYSTEM_PROMPT
+
+ def encode_image_to_base64(image: Image.Image) -> str:
+     buffered = io.BytesIO()
+     image.save(buffered, format="JPEG")
+     return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ def extract_score(text: str):
+     match = re.search(r"(?i)Score:\s*(\d+)", text)
+     return int(match.group(1)) if match else None
+
+ def extract_activity(text: str):
+     # Accept a newline or end-of-string after the activity name.
+     match = re.search(r"(?i)Detected Activity:\s*(.+?)(?:\n|$)", text)
+     return match.group(1).strip() if match else "Unknown"
+
+ @app.post("/predict")
+ async def predict(file: UploadFile = File(...)):
+     try:
+         if file.content_type.startswith("image"):
+             image = Image.open(io.BytesIO(await file.read())).convert("RGB")
+
+         elif file.content_type.startswith("video"):
+             # Write the upload to disk so OpenCV can seek through it.
+             temp_path = tempfile.NamedTemporaryFile(delete=False).name
+             with open(temp_path, "wb") as f:
+                 f.write(await file.read())
+
+             # Sample up to 9 evenly spaced frames across the video.
+             cap = cv2.VideoCapture(temp_path)
+             total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+             interval = max(total // 9, 1)
+
+             frames = []
+             for i in range(9):
+                 cap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
+                 ret, frame = cap.read()
+                 if ret:
+                     frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                     img = Image.fromarray(frame).resize((256, 256))
+                     frames.append(img)
+             cap.release()
+             os.remove(temp_path)
+
+             if not frames:
+                 raise HTTPException(status_code=400, detail="Could not read frames from video")
+
+             # Tile the sampled frames into a single 3x3 grid image.
+             w, h = frames[0].size
+             grid = Image.new("RGB", (3 * w, 3 * h))
+             for idx, frame in enumerate(frames):
+                 grid.paste(frame, ((idx % 3) * w, (idx // 3) * h))
+             image = grid
+
+         else:
+             raise HTTPException(status_code=400, detail="Unsupported file type")
+
+         b64_img = encode_image_to_base64(image)
+         messages = [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": [
+                 {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
+             ]}
+         ]
+         res = client.chat.completions.create(model=MODEL_NAME, messages=messages)
+         reply = res.choices[0].message.content
+
+         return JSONResponse({
+             "points": extract_score(reply),
+             "task": extract_activity(reply),
+             "raw": reply
+         })
+
+     except HTTPException:
+         # Let deliberate 4xx responses pass through instead of becoming 500s.
+         raise
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
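
For quick manual testing, a minimal client sketch for the new /predict endpoint could look like the snippet below. It assumes the app is served locally (for example with `uvicorn main:app --port 8000`) and uses a placeholder image file name; neither assumption is part of this commit.

# Hypothetical client call against a locally running instance of main.py.
# Host, port, and "sample.jpg" are assumptions for illustration only.
import requests

with open("sample.jpg", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/predict",
        files={"file": ("sample.jpg", fh, "image/jpeg")},
    )

resp.raise_for_status()
data = resp.json()
print(data["task"], data["points"])  # parsed activity and score
print(data["raw"])                   # full model reply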
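
The two regex helpers assume the model reply contains lines of the form `Score: <number>` and `Detected Activity: <text>`. A small, self-contained check of that parsing is sketched below; the sample reply string is invented for illustration and is not real model output.

# Sanity check of the reply parsing in main.py; the reply text is made up.
from main import extract_score, extract_activity

sample_reply = "Detected Activity: Tree planting\nScore: 8\nKeep it up!"
assert extract_score(sample_reply) == 8
assert extract_activity(sample_reply) == "Tree planting"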