import os
import ast  # Import ast for safer evaluation of literals
from typing import List, Tuple, Any # Import Any for better type hint after literal_eval
from fastapi import FastAPI, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# No Pydantic BaseModel is needed: the endpoint reads form fields rather than a JSON body
from text_generation import Client
from deep_translator import GoogleTranslator

# Ensure the HF_TOKEN environment variable is set
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("Please set the HF_TOKEN environment variable.")

# Model and API setup
model_id = 'NousResearch/Hermes-3-Llama-3.1-8B'
API_URL = "https://api-inference.huggingface.co/models/" + model_id

client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    timeout=120  # Request timeout for the inference client
)

# Correct End Of Text token for Llama 3 / 3.1
EOT_TOKEN = "<|eot_id|>"
# Expected header before the assistant's response starts in the Llama 3 format
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"

app = FastAPI()

# Allow CORS for your frontend application
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change this to your frontend's URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
"""

# get_prompt formats prompts using the Llama 3.1 instruction format
def get_prompt(message: str, chat_history: List[Tuple[str, str]],
               system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    """
    Formats the chat history and current message into the Llama 3.1 instruction format.
    """
    prompt_parts = []
    prompt_parts.append("<|begin_of_text|>")

    # Add system prompt if provided
    if system_prompt:
         prompt_parts.append(f"<|start_header_id|>system<|end_header_id|>\n\n{system_prompt}{EOT_TOKEN}")

    # Add previous chat turns
    for user_input, response in chat_history:
        # Ensure inputs/responses are strings before including them
        user_input_str = str(user_input).strip()
        response_str = str(response).strip() if response is not None else "" # Handle potential None in history

        prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{user_input_str}{EOT_TOKEN}")
        # Ensure response is not empty before adding assistant turn
        if response_str:
             prompt_parts.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n{response_str}{EOT_TOKEN}")

    # Add current user message and prepare for assistant response
    message_str = str(message).strip()
    prompt_parts.append(f"<|start_header_id|>user<|end_header_id|>\n\n{message_str}{EOT_TOKEN}")
    prompt_parts.append(ASSISTANT_HEADER) # This is where the model starts generating

    return "".join(prompt_parts)

# The endpoint accepts form-encoded fields (prompt and a serialized history)
@app.post("/generate/")
async def generate_response(prompt: str = Form(...), history: str = Form(...)):
    try:
        # --- Safely Parse History ---
        # ast.literal_eval() is used instead of eval() for safety: it only evaluates
        # Python literals (lists, tuples, strings, numbers, dicts, booleans, None)
        # and cannot execute arbitrary code.
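        # Illustrative example of the expected form (hypothetical values):
        #   history='[("Hello", "Hi! How can I help you?"), ("Thanks", "You are welcome.")]'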
        try:
            parsed_history: Any = ast.literal_eval(history)

            # Basic validation to ensure it looks like the expected format
            if not isinstance(parsed_history, list):
                raise ValueError("History is not a list.")
            # You could add more checks, e.g., if items are tuples of strings

            chat_history: List[Tuple[str, str]] = [(str(u), str(a)) for u, a in parsed_history] # Ensure elements are strings

        except (ValueError, SyntaxError, TypeError) as e:
            # Catch errors if the history string is not a valid literal or not the right structure
            raise HTTPException(status_code=400, detail=f"Invalid history format: {e}")
        # --- End Safely Parse History ---

        system_prompt = DEFAULT_SYSTEM_PROMPT # You could make this configurable too
        message = prompt

        prompt_text = get_prompt(message, chat_history, system_prompt)

        generate_kwargs = dict(
            max_new_tokens=1024,
            do_sample=True,
            top_p=0.9,
            top_k=50,
            temperature=0.1, # Keep temperature low for more predictable output
            stop_sequences=[EOT_TOKEN], # Explicitly tell the API to stop at EOT
            # return_full_text=False # Might need to experiment with this depending on API behavior
        )

        # Non-streaming generate() keeps the post-processing below simple; if streaming
        # is needed, the cleanup logic must be adapted to run inside the stream loop.
        output_obj = client.generate(prompt_text, **generate_kwargs)
        output = output_obj.generated_text # Get the final generated text

        # --- Post-processing the output ---
        # client.generate() should return only the newly generated text, so the
        # ASSISTANT_HEADER used in the prompt normally does not appear here.
        # Strip it defensively in case the model echoes it, then trim whitespace.
        expected_start = ASSISTANT_HEADER.strip()
        if output.startswith(expected_start):
            output = output[len(expected_start):].strip()
        else:
            # No echoed header; just strip leading/trailing whitespace
            output = output.strip()

        # Remove the EOT token if it's still present at the end
        if output.endswith(EOT_TOKEN):
            output = output[:-len(EOT_TOKEN)].strip()
        # --- End Post-processing ---

        # Ensure the output is not empty before translating
        if not output:
            # If the model produced no output after cleaning, return a default or error
            translated_output = "Error: Model did not produce a response or response was filtered."
        else:
            # Translate the cleaned response to Arabic
            translator = GoogleTranslator(source='auto', target='ar')
            translated_output = translator.translate(output)

        return {"response": translated_output}

    except HTTPException as he:
        # Re-raise HTTPExceptions (like the 400 for invalid history)
        raise he
    except Exception as e:
        # Log other errors for debugging purposes on the server
        print(f"An error occurred during generation or translation: {e}")
        raise HTTPException(status_code=500, detail=f"An internal error occurred: {e}")
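

# Local run sketch (assumes the `uvicorn` package is installed; host and port are
# illustrative and should be adjusted for your deployment).
if __name__ == "__main__":
    import uvicorn

    # Start the API locally. The /generate/ endpoint then accepts form-encoded
    # `prompt` and `history` fields, e.g. with the `requests` package:
    #   requests.post("http://127.0.0.1:8000/generate/",
    #                 data={"prompt": "Hello", "history": "[]"})
    uvicorn.run(app, host="0.0.0.0", port=8000)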