bsny commited on
Commit
cdf6424
·
verified ·
1 Parent(s): 57939ca

Added Groq endpoint for threat assessment

Browse files
Files changed (1) hide show
  1. app.py +53 -46
app.py CHANGED
@@ -1,58 +1,65 @@
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import torch
5
- import uuid
6
  import os
 
7
 
8
- # FastAPI app setup
9
  app = FastAPI()
10
 
11
- # Use HF cache location that's safe in HF Spaces
12
- os.environ["HF_HOME"] = "/data/huggingface"
13
 
14
- # Use a CPU-compatible model (non-GPTQ)
15
- model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
16
- hf_token = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
 
 
 
17
 
18
- # Load model and tokenizer (no GPU-specific args)
19
- tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
20
- model = AutoModelForCausalLM.from_pretrained(
21
- model_id,
22
- token=hf_token
23
- ).to("cpu")
24
 
25
- # In-memory store for system prompts per session
26
- session_prompts = {}
 
 
 
27
 
28
- # Request body models
29
- class SystemPrompt(BaseModel):
30
- prompt: str
 
 
 
 
31
 
32
- class UserMessage(BaseModel):
33
- session_id: str
34
- message: str
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- @app.post("/start")
37
- def start_chat(system_prompt: SystemPrompt):
38
- session_id = str(uuid.uuid4())
39
- session_prompts[session_id] = system_prompt.prompt
40
- return {"session_id": session_id}
41
-
42
- @app.post("/chat")
43
- def chat(message: UserMessage):
44
- system = session_prompts.get(message.session_id)
45
- if not system:
46
- return {"error": "Invalid session_id. Call /start first."}
47
-
48
- full_prompt = f"<|system|>\n{system}\n<|user|>\n{message.message}\n<|assistant|>\n"
49
-
50
- inputs = tokenizer(full_prompt, return_tensors="pt").to("cpu")
51
- outputs = model.generate(
52
- **inputs,
53
- max_new_tokens=200,
54
- pad_token_id=tokenizer.eos_token_id,
55
- )
56
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
57
- answer = response.replace(full_prompt.strip(), "").strip()
58
- return {"response": answer}
 
# app.py
import json
import os

import openai
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

# Environment Variables
# Groq API key, read once at import time. If the secret is missing this is
# None — NOTE(review): confirm the GROQ_API_KEY secret is configured in the
# Space, otherwise every Groq call will fail with an auth error.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
11
 
12
# Model Setup
def generate_response(system_prompt: str, user_message: str) -> str:
    """Send one system+user prompt pair to Groq's OpenAI-compatible chat API.

    Args:
        system_prompt: Instructions framing the assistant's role.
        user_message: The end-user's input text.

    Returns:
        The assistant's reply text (first completion choice).

    Raises:
        RuntimeError: If the GROQ_API_KEY environment variable is not set.
    """
    # Fail fast with a clear message instead of an opaque auth error from
    # the client library deep inside the request handler.
    if not GROQ_API_KEY:
        raise RuntimeError("GROQ_API_KEY environment variable is not set")
    # NOTE(review): the client could be hoisted to module level and reused
    # across requests; building it per call works but is wasteful.
    client = openai.OpenAI(api_key=GROQ_API_KEY, base_url="https://api.groq.com/openai/v1")
    response = client.chat.completions.create(
        # NOTE(review): mixtral-8x7b-32768 has been deprecated/decommissioned
        # on Groq — verify against Groq's current model list and migrate if so.
        model="mixtral-8x7b-32768",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        # Low temperature keeps the structured (JSON-style) output stable.
        temperature=0.4,
    )
    return response.choices[0].message.content
24
 
25
# Request model
class Message(BaseModel):
    """Request body: a single free-text paragraph supplied by the caller."""
    # The input paragraph to analyze.
    message: str
 
 
 
28
 
29
@app.post("/bia/threat-assessment")
def bia_threat_assessment(req: Message):
    """Run a BIA threat assessment over the caller's paragraph via Groq.

    The model is instructed to reply in strict JSON; the reply is parsed so
    the endpoint returns a real JSON object. If the model output is not valid
    JSON, the raw text is returned under "raw_response" instead of failing.
    """
    prompt = """
    You are a cybersecurity and geopolitical risk analyst AI working on Business Impact Assessment (BIA).
    Given a paragraph, do the following:

    1. Identify the **place** mentioned in the text.
    2. List likely **threats** specific to that place and context.
    3. For each threat:
    - Give a **likelihood rating (1–5)**.
    - Give a **severity rating (1–5)**.
    - Describe the **potential impact**.
    - Compute **threat rating = likelihood × severity**.

    Respond strictly in this JSON format:
    {
    "place": "<place>",
    "threats": [
    {
    "name": "<threat name>",
    "likelihood": <1-5>,
    "severity": <1-5>,
    "impact": "<impact statement>",
    "threat_rating": <likelihood * severity>
    }
    ]
    }
    """
    result = generate_response(prompt, req.message)
    # The prompt demands strict JSON, but LLM output is not guaranteed to
    # comply — previously the raw string was returned, so clients received a
    # JSON-quoted string rather than an object. Parse defensively instead.
    try:
        return json.loads(result)
    except json.JSONDecodeError:
        return {"raw_response": result}
59
 
60
@app.post("/bia/impact-analysis")
def bia_impact_analysis(req: Message):
    """Placeholder endpoint for the upcoming BIA impact-analysis step."""
    payload = {
        "status": "placeholder",
        "note": "This endpoint is reserved for BIA impact analysis logic.",
    }
    return payload