"""Serve a llama.cpp GGUF model through a Gradio chat UI and a JSON API."""
import json
import os

import gradio as gr
from llama_cpp import Llama


def _env_flag(name, default=True):
    """Read a boolean flag from the environment variable *name*.

    Unset or blank values yield *default*; "0", "false", "no" and "off"
    (case-insensitive) yield False; any other value yields True.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    return raw.strip().lower() not in {"0", "false", "no", "off"}


def _load_llama_config(raw):
    """Parse *raw* (JSON text) into extra keyword arguments for ``Llama``.

    Returns an empty dict when *raw* is unset or blank. When *raw* is not
    valid JSON, or is JSON but not an object, a warning is printed and an
    empty dict is returned so startup continues with the defaults.
    """
    if raw is None or not raw.strip():
        return {}
    try:
        config = json.loads(raw)
    except json.JSONDecodeError:
        print("Invalid Llama config. Config must be valid JSON. Got:\n", raw)
        return {}
    if not isinstance(config, dict):
        print("Invalid Llama config. Config must be a JSON object. Got:\n", raw)
        return {}
    return config


# Environment values are strings, so a plain `or True` could never switch the
# API page off ("false" is truthy); parse the flag properly instead.
GRADIO_SHOW_API_INFO = _env_flag("AGS_SHOW_API_INFO", default=True)
AGS_REPO = os.getenv("AGS_REPO") or "lmstudio-community/gemma-1.1-2b-it-GGUF"
AGS_FILENAME = os.getenv("AGS_FILENAME") or "gemma-1.1-2b-it-Q4_K_M.gguf"
# User overrides taken from the AGS_LLAMA_CONFIG env var (a JSON object).
AGS_LLAMA_CONFIG = _load_llama_config(os.getenv("AGS_LLAMA_CONFIG"))
# Final kwargs for Llama.from_pretrained: no GPU offload by default, with the
# user overrides layered on top so they actually take effect.
ARGS_LLAMA_CONFIG = {"n_gpu_layers": 0, **AGS_LLAMA_CONFIG}


def _parse_api_request(raw_input, raw_settings):
    """Decode the two text fields of an ``api_chat`` call.

    Returns ``(request, settings)``:
      * *request* is the decoded input if it is a JSON object, else None
        (the caller then treats the raw input text as a plain prompt);
      * *settings* is always a dict (blank settings -> ``{}``).

    Raises:
        gr.Error: if *raw_settings* is non-blank but not a JSON object, so the
            API client gets a readable message instead of a TypeError later.
    """
    try:
        request = json.loads(raw_input)
    except (TypeError, json.JSONDecodeError):
        request = None
    if not isinstance(request, dict):
        request = None

    if raw_settings is None or not raw_settings.strip():
        return request, {}
    try:
        settings = json.loads(raw_settings)
    except json.JSONDecodeError as err:
        raise gr.Error(f"settings must be a JSON object: {err}") from err
    if not isinstance(settings, dict):
        raise gr.Error("settings must be a JSON object")
    return request, settings


def _extract_text(res):
    """Return the generated text from a llama.cpp response, if recognizable.

    Checks ``choices[0]["content"]``, then ``choices[0]["text"]`` (completion),
    then ``choices[0]["message"]["content"]`` (chat completion). Anything else,
    e.g. a streaming iterator, is returned unchanged.
    """
    try:
        choice = res["choices"][0]
    except (TypeError, KeyError, IndexError):
        return res
    if "content" in choice:
        return choice["content"]
    if "text" in choice:
        return choice["text"]
    message = choice.get("message") if isinstance(choice, dict) else None
    if isinstance(message, dict) and "content" in message:
        return message["content"]
    return res


def main():
    """Load the GGUF model and serve the Gradio UI plus the ``api_chat`` endpoint."""
    llm = Llama.from_pretrained(AGS_REPO, filename=AGS_FILENAME, **ARGS_LLAMA_CONFIG)

    def api_chat(inpt, settings):
        """Run the model on an API request and return text or the raw response.

        Args:
            inpt: JSON text describing the request, or plain text used
                verbatim as the prompt.
            settings: JSON object text of completion keyword arguments for
                llama.cpp, plus two control flags that are removed before
                forwarding: ``full_output`` (return the whole response) and
                ``prompt`` (send *inpt* verbatim as a raw prompt).

        Exactly one model call is made, chosen in this order:
          1. ``prompt`` flag set, or *inpt* is not a JSON object -> raw prompt;
          2. request has non-empty ``messages`` -> chat completion;
          3. request has truthy ``@execute`` -> completion on the request JSON
             with ``@execute`` removed;
          4. otherwise -> completion on the request JSON.
        """
        request, options = _parse_api_request(inpt, settings)
        # Wrapper-only flags: forwarding them to llama.cpp raises TypeError.
        full_output = options.pop("full_output", False)
        raw_prompt = options.pop("prompt", False)

        if request is not None:
            print("Request:\n" + json.dumps(request, indent=2) + "\n\n" + ("*_" * 24))

        if raw_prompt or request is None:
            res = llm(inpt, **options)
        elif request.get("messages"):
            # The chat API expects the message list, not the wrapping object.
            res = llm.create_chat_completion(messages=request["messages"], **options)
        elif request.get("@execute"):
            payload = {key: value for key, value in request.items() if key != "@execute"}
            res = llm(json.dumps(payload), **options)
        else:
            res = llm(json.dumps(request), **options)

        if full_output:
            return res
        return _extract_text(res)

    def chat(inpt):
        """Answer a single user message from the UI textbox (no history kept)."""
        if not inpt:
            return ""
        res = llm.create_chat_completion(messages=[{"role": "user", "content": inpt}])
        return res["choices"][0]["message"]["content"]

    prompt_box = gr.Textbox()
    with gr.Interface(fn=chat, inputs=[prompt_box], outputs="text") as interface:
        # Invisible settings textbox: it cannot be submitted from the page, so
        # this event is effectively reachable only via the `api_chat` endpoint.
        with gr.Row(visible=False):
            shadow_input = gr.Textbox(visible=False)
            shadow_input.submit(
                api_chat,
                inputs=[prompt_box, shadow_input],
                api_name="api_chat",
            )
    interface.launch(debug=True, show_api=GRADIO_SHOW_API_INFO)


if __name__ == "__main__":
    main()