File size: 3,255 Bytes
1634927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
import os

from fastapi import FastAPI, Request, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
# from huggingface_hub import InferenceClient # Remove this line
from groq import Groq # Import the Groq client

# Application object; the CORS middleware and the /generate/ route are
# registered on it further down in this module.
app = FastAPI()

# Shared Groq SDK client used for every chat completion. The API key is read
# from the GROQ_API_KEY environment variable (None is forwarded when unset).
# NOTE(review): presumably the SDK refuses to construct without a key — confirm.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

# System prompt prepended to every conversation sent to the model.
SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    # Trailing space is required: adjacent literals are concatenated with no
    # separator, so without it the model received "...information.Always...".
    "don't share false information. "
    "Always respond in the language of user prompt for each prompt ."
)
# Generation parameters forwarded to the Groq chat-completions call.
MAX_TOKENS = 2000
TEMPERATURE = 0.7
TOP_P = 0.95
# Groq model ID. It can be overridden with the GROQ_MODEL_NAME environment
# variable so the model can be swapped without a code change; an unset or
# empty variable falls back to the original default.
# NOTE(review): Groq retires older model IDs over time — confirm this default
# is still served.
GROQ_MODEL_NAME = os.environ.get("GROQ_MODEL_NAME") or "llama3-8b-8192"

def respond(message, history: list[tuple[str, str]]):
    """Stream an assistant reply from Groq for *message* given prior turns.

    The request's message list is: the system prompt, then for each entry in
    *history* its user text and its assistant text (each only if truthy),
    then *message* as the final user turn.

    Args:
        message: The new user prompt.
        history: Prior turns; only indices 0 (user text) and 1 (assistant
            text) of each entry are read.

    Yields:
        Each non-None text delta of the streamed completion, in order.
    """
    conversation = [{"role": "system", "content": SYSTEM_MESSAGE}]
    for turn in history:
        user_text = turn[0]
        if user_text:
            conversation.append({"role": "user", "content": user_text})
        assistant_text = turn[1]
        if assistant_text:
            conversation.append({"role": "assistant", "content": assistant_text})
    conversation.append({"role": "user", "content": message})

    # stream=True makes the SDK return an iterable of incremental chunks.
    stream = client.chat.completions.create(
        model=GROQ_MODEL_NAME,
        messages=conversation,
        max_tokens=MAX_TOKENS,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        stream=True,
    )

    for chunk in stream:
        # Chunks without choices (or with an empty delta) carry no text.
        if not chunk.choices:
            continue
        delta_text = chunk.choices[0].delta.content
        if delta_text is not None:
            yield delta_text


from fastapi.middleware.cors import CORSMiddleware

# The single browser origin allowed to call this API. It is shared by the CORS
# middleware and the explicit check in generate() so the two cannot drift.
ALLOWED_ORIGIN = "https://artixiban-ll3.static.hf.space"

app.add_middleware(
    CORSMiddleware,
    allow_origins=[ALLOWED_ORIGIN],  # Allow only this origin
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods (GET, POST, etc.)
    allow_headers=["*"],  # Allow all headers
)


def _parse_history(raw_history):
    """Decode the ``history`` form field into a list of conversation turns.

    Args:
        raw_history: Raw form value; expected to be a JSON string encoding a
            list of ``[user_text, assistant_text]`` pairs.

    Returns:
        The decoded list, in the shape ``respond`` indexes.

    Raises:
        HTTPException: 400 if the value is not a string, is not valid JSON,
            or is not a list of lists that each hold at least two items.
    """
    # Form values can be UploadFile objects when the client sends a file part.
    if not isinstance(raw_history, str):
        raise HTTPException(status_code=400, detail="History must be a JSON string")
    try:
        history = json.loads(raw_history)
    except json.JSONDecodeError as err:
        raise HTTPException(status_code=400, detail="History is not valid JSON") from err
    # respond() reads turn[0] and turn[1] of every entry. Anything else would
    # otherwise surface later as a TypeError/IndexError and a 500.
    if not isinstance(history, list) or not all(
        isinstance(turn, list) and len(turn) >= 2 for turn in history
    ):
        raise HTTPException(
            status_code=400,
            detail="History must be a list of [user, assistant] pairs",
        )
    return history


@app.post("/generate/")
async def generate(request: Request):
    """Return the complete model reply for a form-encoded prompt.

    Form fields:
        prompt: Required, non-empty user message.
        history: Optional JSON list of ``[user, assistant]`` pairs;
            defaults to ``"[]"``.

    Returns:
        JSONResponse of the form ``{"response": <full reply text>}``.

    Raises:
        HTTPException: 403 when the Origin header is not ALLOWED_ORIGIN;
            400 when the prompt is missing or the history is malformed.
    """
    # NOTE: Browsers set Origin, but any non-browser client can forge it. This
    # check only keeps other websites out; it is not authentication.
    if request.headers.get("origin") != ALLOWED_ORIGIN:
        raise HTTPException(status_code=403, detail="Origin not allowed")

    form = await request.form()
    prompt = form.get("prompt")
    if not isinstance(prompt, str) or not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")
    history = _parse_history(form.get("history", "[]"))  # Default to empty history

    # respond() is a generator, so calling it here does no work yet. The
    # blocking Groq request runs inside "".join on a worker thread, which keeps
    # the event loop free to serve other requests in the meantime.
    final_response = await run_in_threadpool("".join, respond(prompt, history))
    return JSONResponse(content={"response": final_response})