import os

import gradio as gr
from openai import OpenAI
# Configuration is read from environment variables
model_url = os.getenv('MODEL_URL', 'http://localhost:8000/v1')
model_name = os.getenv('MODEL_NAME', 'default-model-name')  # Set MODEL_NAME to the model served by vLLM
temperature = float(os.getenv('TEMPERATURE', '0.8'))
stop_token_ids = os.getenv('STOP_TOKEN_IDS', '')  # Comma-separated stop token IDs
host = os.getenv('HOST', '0.0.0.0')
try:
    port = int(os.getenv('PORT', '8001'))
except ValueError:
    port = 8001
# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = model_url

# Create an OpenAI client to interact with the API server
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)
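
# Optional connectivity check (a sketch, not part of the original flow): list
# the models the server exposes to confirm it is reachable before launching
# the UI.
# for served_model in client.models.list().data:
#     print(served_model.id)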
def predict(message, history):
    # Convert Gradio's (user, assistant) history pairs to OpenAI's message format
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({
            "role": "assistant",
            "content": assistant
        })
    history_openai_format.append({"role": "user", "content": message})
    # Create a chat completion request and send it to the API server
    stream = client.chat.completions.create(
        model=model_name,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=temperature,  # Temperature for text generation
        stream=True,  # Stream the response
        extra_body={
            'repetition_penalty': 1,
            'stop_token_ids': [
                int(id.strip()) for id in stop_token_ids.split(',')
                if id.strip()
            ] if stop_token_ids else []
        })
    # Read and yield generated text from the response stream
    partial_message = ""
    for chunk in stream:
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message
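
# Quick smoke test of predict() without the UI (a sketch; assumes the vLLM
# server is reachable at MODEL_URL and MODEL_NAME is set):
# final = ""
# for final in predict("Say hi in one word.", []):
#     pass
# print(final)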
with gr.Blocks(title="MethodAI 0.15", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # TODO: wire this button to a PDF-ingestion handler
            upload_btn = gr.UploadButton("Click to upload PDFs", file_types=[".pdf"])
        with gr.Column(scale=4):
            # Create a chat interface backed by predict()
            gr.ChatInterface(predict)

demo.queue().launch(server_name=host, server_port=port, share=True)
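
# Example invocation (a usage sketch; the script filename is an assumption,
# and the model name is the default from the original configuration):
#   MODEL_URL=http://localhost:8000/v1 \
#   MODEL_NAME=TheBloke/Mistral-7B-Instruct-v0.2-AWQ \
#   TEMPERATURE=0.8 PORT=8001 python app.py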