APITemplateGGUF / app.py
from llama_cpp import Llama
import gradio as gr
import os
import json
# Show the auto-generated API docs unless AGS_SHOW_API_INFO is explicitly set to "false".
GRADIO_SHOW_API_INFO = os.getenv("AGS_SHOW_API_INFO", "true").lower() != "false"
AGS_REPO = os.getenv("AGS_REPO") or "lmstudio-community/gemma-1.1-2b-it-GGUF"
AGS_FILENAME = os.getenv("AGS_FILENAME") or "gemma-1.1-2b-it-Q4_K_M.gguf"
# Default Llama config (CPU-only); can be overridden via the AGS_LLAMA_CONFIG env var below.
AGS_LLAMA_CONFIG = {
    "n_gpu_layers": 0,
}
AGS_TITLE = os.getenv("AGS_TITLE") or "API GGUF Space"
_ags_llama_config_env = os.getenv("AGS_LLAMA_CONFIG")
if _ags_llama_config_env:
    try:
        AGS_LLAMA_CONFIG = json.loads(_ags_llama_config_env)
    except Exception:
        print("Invalid Llama config. Config must be valid JSON. Got:\n", _ags_llama_config_env)
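# Illustrative environment settings (optional; the values below are examples, with
# AGS_LLAMA_CONFIG keys being ordinary llama-cpp-python constructor options):
#   AGS_REPO="lmstudio-community/gemma-1.1-2b-it-GGUF"
#   AGS_FILENAME="gemma-1.1-2b-it-Q4_K_M.gguf"
#   AGS_LLAMA_CONFIG='{"n_gpu_layers": 0, "n_ctx": 2048}'
#   AGS_TITLE="API GGUF Space"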
def main():
    # Download the GGUF weights from the Hugging Face Hub and load them with llama.cpp.
    llm = Llama.from_pretrained(AGS_REPO, filename=AGS_FILENAME, **AGS_LLAMA_CONFIG)
    def api_chat(inpt, settings):
        res = None
        # Both arguments arrive as JSON strings from the hidden API textboxes.
        try:
            settings = json.loads(settings)
        except Exception:
            settings = {}
        # "full_output" is a control flag for this endpoint, not a llama.cpp parameter.
        full_output = bool(settings.pop("full_output", False))
        try:
            inpt = json.loads(inpt)
            print("Request:\n" + json.dumps(inpt, indent=2) + "\n\n" + ("*_" * 24))
        except Exception:
            # Not JSON: treat the raw string as a plain completion prompt.
            res = llm(inpt, **settings)
        if isinstance(inpt, dict):
            # "@execute": strip the flag and complete the serialized request as plain text.
            if inpt.get("@execute"):
                inpt.pop("@execute", None)
                res = llm(json.dumps(inpt), **settings)
            # "messages": run an OpenAI-style chat completion.
            if inpt.get("messages"):
                res = llm.create_chat_completion(messages=inpt["messages"], **settings)
            # A "prompt" supplied in the settings takes precedence as the completion prompt.
            if settings.get("prompt"):
                res = llm(settings.pop("prompt"), **settings)
        if res is None:
            # Fallback: serialize whatever was received and complete it as text.
            res = llm(json.dumps(inpt), **settings)
        if full_output:
            return res
        # Otherwise return only the generated text, whatever shape the response takes.
        choice = res["choices"][0]
        if "content" in choice:
            return choice["content"]
        if "text" in choice:
            return choice["text"]
        if "message" in choice and "content" in choice["message"]:
            return choice["message"]["content"]
        return res
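    # Illustrative api_chat request (both arguments are JSON strings; parameter names
    # such as max_tokens and temperature are standard llama-cpp-python options):
    #   inpt     = '{"messages": [{"role": "user", "content": "Hello"}]}'
    #   settings = '{"max_tokens": 128, "temperature": 0.7, "full_output": false}'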
    def chat(inpt):
        if not inpt:
            return ""
        return llm.create_chat_completion(messages=[{
            "role": "user",
            "content": inpt
        }])['choices'][0]['message']['content']
    # The visible UI is a simple chat box; a hidden textbox pair exposes the JSON API endpoint.
    with gr.Interface(fn=chat, inputs=[inpt := gr.Textbox()], outputs="text", title=AGS_TITLE) as interface:
        with gr.Row(visible=False):
            shadow_input = gr.Textbox(visible=False)
            shadow_output = gr.Textbox(visible=False)
            shadow_input.submit(api_chat, inputs=[inpt, shadow_input], outputs=[shadow_output], api_name="api_chat")
    interface.launch(debug=True, show_api=GRADIO_SHOW_API_INFO)
if __name__ == "__main__":
main()
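# A minimal client sketch (illustrative): assumes the gradio_client package is
# installed and the Space is running locally on Gradio's default port.
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   reply = client.predict(
#       '{"messages": [{"role": "user", "content": "Hello"}]}',  # inpt
#       '{"max_tokens": 128}',                                   # settings
#       api_name="/api_chat",
#   )
#   print(reply)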