omkar56's picture
Update main.py
cece2a6
raw
history blame contribute delete
No virus
3.48 kB
import os
import secrets
from typing import Optional

from fastapi import FastAPI, File, Request, UploadFile, Body, Depends, HTTPException
from fastapi.security.api_key import APIKeyHeader
from gpt4all import GPT4All
from huggingface_hub import hf_hub_download
# Model from TheBloke/Mistral-7B-Instruct-v0.1-GGUF:
#   https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF
# Mistral-instruct v0.1 system prompt guidance:
#   https://docs.mistral.ai/usage/guardrailing
model_name = "mistral-7b-instruct-v0.1.Q4_K_M.gguf"
# NOTE: despite its name, model_path is used as a *directory*. hf_hub_download
# places the weights at <model_path>/<model_name>, and GPT4All is given the
# same directory as its model_path so it loads that exact file.
model_path = "models/mistral-7b-instruct-v0.1.Q4_K_M.gguf"
hf_hub_download(
    repo_id="TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
    filename=model_name,
    local_dir=model_path,
    local_dir_use_symlinks=False,
)
print("Start the model init process")
# allow_download=False: only the file fetched above is used; GPT4All must not
# try to download a model itself.
model = GPT4All(model_name, model_path, allow_download=False, device="cpu")
print("Finish the model init process")
# Mistral-instruct wraps each user turn in [INST] ... [/INST]; generater()
# formats every user message with this template.
model.config["promptTemplate"] = "[INST] {0} [/INST]"
model.config["systemPrompt"] = (
    "You are an helpful AI content generator. your name is 'ErTake AI'. "
    "Whenever user greets you, you always have to introduce yourself as 'ErTake AI'. "
    "Don't introduce yourself as Bing, I will punish you if you intorduce yourself other than 'ErTake AI'. "
    "Carefully listen to the user's message/query and provide correct answers/content. "
    "Don't hallucinate if you don't know the answer. "
    "Give your output only in english language. "
    "If the user asks to create images you need to simply deny the request saying "
    "\"I'm sorry, but I don't have the capability to create images. "
    "I can assist you with generating text based content. "
    "If you have any text-based content requests, feel free to let me know!\""
)
# Private GPT4All attribute. Presumably this keeps generate() out of
# chat-session mode so the hand-built prompt is used verbatim.
# TODO(review): confirm against the installed gpt4all version.
model._is_chat_session_activated = False
# Default upper bound on tokens generated per request.
max_new_tokens = 2048
def generater(message, history, temperature, top_p, top_k, *, llm=None, max_tokens=None):
    """Stream a Mistral-instruct completion for ``message``.

    The prompt is ``"<s>"`` + the model's system prompt, then each prior
    ``(user, assistant)`` turn from ``history`` (the user part formatted with
    the prompt template, the assistant part terminated by ``"</s>"``), then
    the new ``message`` formatted with the template.

    Args:
        message: The new user message.
        history: Iterable of ``(user_message, assistant_message)`` pairs.
        temperature: Sampling temperature (passed to ``generate`` as ``temp``).
        top_p: Nucleus-sampling probability mass.
        top_k: Top-k sampling cutoff.
        llm: Model object exposing ``config`` and ``generate``; defaults to
            the module-level ``model``.
        max_tokens: Generation limit; defaults to the module-level
            ``max_new_tokens``.

    Yields:
        The cumulative generated text after each token, so the last value is
        the full reply. At least one value is always yielded (``""`` when the
        model produces nothing), so callers can safely take the last item.
    """
    llm = model if llm is None else llm
    max_tokens = max_new_tokens if max_tokens is None else max_tokens
    template = llm.config["promptTemplate"]

    parts = ["<s>", llm.config["systemPrompt"]]
    for user_message, assistant_message in history:
        parts.append(template.format(user_message))
        parts.append(assistant_message + "</s>")
    parts.append(template.format(message))
    prompt = "".join(parts)
    print("[prompt]", prompt)

    # streaming=True makes generate() return an iterator of tokens. With
    # streaming=False it returns a single str, and iterating that walks the
    # reply one character at a time, rebuilding every prefix (quadratic).
    tokens = llm.generate(
        prompt=prompt,
        temp=temperature,
        top_k=top_k,
        top_p=top_p,
        max_tokens=max_tokens,
        streaming=True,
    )
    text = ""
    yielded = False
    for token in tokens:
        text += token
        yielded = True
        yield text
    if not yielded:
        # Guarantee one value so consumers taking the last item never fail.
        yield text
    print("[outputs]", text)
# Shared secret that clients must send in the "api_key" header. It is None
# when the environment variable is unset, in which case get_api_key rejects
# every request.
API_KEY = os.getenv("API_KEY")

app = FastAPI()

# auto_error=False: a missing header is handed to get_api_key as None, so
# that dependency alone decides how unauthorized requests are rejected.
api_key_header = APIKeyHeader(name="api_key", auto_error=False)
def get_api_key(api_key: Optional[str] = Depends(api_key_header)):
    """FastAPI dependency that validates the ``api_key`` request header.

    Args:
        api_key: Header value, or None when the header is absent.

    Returns:
        The supplied key when it matches the ``API_KEY`` environment variable.

    Raises:
        HTTPException: 401 if the header is missing, ``API_KEY`` is unset,
            or the key does not match.
    """
    if api_key is None or API_KEY is None:
        raise HTTPException(status_code=401, detail="Unauthorized access")
    # Constant-time comparison so response timing does not reveal how much
    # of the key matched. Compare bytes: compare_digest raises TypeError for
    # str arguments containing non-ASCII characters.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return api_key
@app.post("/api/v1/generate_text", response_model=dict)
def generate_text(
    request: Request,
    body: dict = Body(...),
    api_key: str = Depends(get_api_key)
):
    """Generate a single-turn completion for ``body["prompt"]``.

    Optional JSON fields: ``temperature`` (default 0.5), ``top_p``
    (default 0.95) and ``top_k`` (default 40). The ``api_key`` header is
    checked by the ``get_api_key`` dependency before this runs.

    Returns:
        ``{"generated_text": reply}``; ``reply`` is ``""`` when the model
        produces no output.

    Raises:
        HTTPException: 422 if a sampling field cannot be converted to a
            number (``top_k`` to an integer).
    """
    message = body.get("prompt", "")
    try:
        temperature = float(body.get("temperature", 0.5))
        top_p = float(body.get("top_p", 0.95))
        top_k = int(body.get("top_k", 40))
    except (TypeError, ValueError) as err:
        raise HTTPException(
            status_code=422,
            detail="temperature and top_p must be numbers; top_k must be an integer",
        ) from err
    print("[request details]", message, temperature, top_p, top_k)
    # This endpoint accepts no conversation history: every call is one turn.
    history = []
    # generater yields the cumulative reply. Keep only the final value rather
    # than materializing every prefix with list(...), and fall back to "" so
    # an empty generation does not raise IndexError.
    generated_text = ""
    for generated_text in generater(message, history, temperature, top_p, top_k):
        pass
    return {"generated_text": generated_text}