Makhinur committed on
Commit
4c8d5ac
·
verified ·
1 Parent(s): 7a1f650

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +90 -0
main.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import os

from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
# from huggingface_hub import InferenceClient  # replaced by the Groq client below
from groq import Groq  # Groq API client
8
+
9
# FastAPI application instance; CORS middleware and routes are registered below.
app = FastAPI()

# Groq API client. The key is read from the GROQ_API_KEY environment variable;
# os.environ.get returns None when unset, in which case the first API call fails.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

# System prompt sent as the first message of every chat completion.
# NOTE: adjacent string literals are concatenated with no implicit separator.
# The original ran "...false information." and "Always respond..." together
# into one word; a trailing space is added below to fix that.
SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    "don't share false information. "
    "Always respond in the language of the user prompt."
)

# Sampling parameters for every chat-completion request.
MAX_TOKENS = 2000   # upper bound on generated tokens per reply
TEMPERATURE = 0.7   # sampling temperature
TOP_P = 0.95        # nucleus-sampling cutoff

# Groq-hosted model used for all completions.
GROQ_MODEL_NAME = "llama3-8b-8192"
32
+
33
def respond(message, history: list[tuple[str, str]]):
    """Stream an assistant reply for *message*, given prior (user, bot) turns.

    Yields the response text chunk by chunk as it arrives from the Groq API.
    Empty entries on either side of a history pair are skipped.
    """
    # Start from the system prompt, then replay the conversation so far.
    conversation = [{"role": "system", "content": SYSTEM_MESSAGE}]
    for turn in history:
        if turn[0]:
            conversation.append({"role": "user", "content": turn[0]})
        if turn[1]:
            conversation.append({"role": "assistant", "content": turn[1]})
    conversation.append({"role": "user", "content": message})

    # Streaming chat completion: the client returns an iterator of chunks.
    stream = client.chat.completions.create(
        messages=conversation,
        model=GROQ_MODEL_NAME,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )

    # Forward every non-empty text delta to the caller as it arrives.
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content is not None:
            yield chunk.choices[0].delta.content
59
+
60
+
61
# Allow the static front-end (hosted on its own Hugging Face Space origin) to
# call this API from the browser. Only that single origin is trusted.
# The CORSMiddleware import lives with the other imports at the top of the file.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://artixiban-ll3.static.hf.space"],  # single trusted origin
    allow_credentials=True,
    allow_methods=["*"],  # GET, POST, OPTIONS, ...
    allow_headers=["*"],
)
70
+
71
@app.post("/generate/")
async def generate(request: Request):
    """Generate a full (non-streamed) chat reply for a form-posted prompt.

    Expected form fields:
      * ``prompt``  — required user message (string).
      * ``history`` — optional JSON-encoded list of [user, assistant] pairs.

    Returns ``{"response": <text>}``.

    Raises:
        HTTPException 403: request Origin header is not the trusted front-end.
        HTTPException 400: missing/non-string prompt, or malformed history.
    """
    # Defense-in-depth origin check on top of the CORS middleware: browsers
    # enforce CORS, but non-browser clients ignore it entirely.
    allowed_origin = "https://artixiban-ll3.static.hf.space"
    if request.headers.get("origin") != allowed_origin:
        raise HTTPException(status_code=403, detail="Origin not allowed")

    form = await request.form()

    # form.get may return an UploadFile for file fields; require a plain string.
    prompt = form.get("prompt")
    if not prompt or not isinstance(prompt, str):
        raise HTTPException(status_code=400, detail="Prompt is required")

    # Malformed history JSON is a client error (400), not a server crash (500).
    try:
        history = json.loads(form.get("history", "[]"))
    except (json.JSONDecodeError, TypeError):
        raise HTTPException(status_code=400, detail="history must be valid JSON")
    if not isinstance(history, list):
        raise HTTPException(status_code=400, detail="history must be a JSON list")

    # Drain the streaming generator into one string for a single JSON response.
    final_response = "".join(respond(prompt, history))
    return JSONResponse(content={"response": final_response})