import gradio as gr
from huggingface_hub import InferenceClient
import os
import pandas as pd
from typing import List, Tuple
# Available LLM models
LLM_MODELS = {
    "Default": "CohereForAI/c4ai-command-r-plus-08-2024",  # default model
    "Mistral": "mistralai/Mistral-7B-Instruct-v0.2",
    "Zephyr": "HuggingFaceH4/zephyr-7b-beta",
    "OpenChat": "openchat/openchat-3.5",
    "Llama2": "meta-llama/Llama-2-7b-chat-hf",
    "Phi": "microsoft/phi-2",
    "Neural": "nvidia/neural-chat-7b-v3-1",
    "Starling": "HuggingFaceH4/starling-lm-7b-alpha"
}
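# Note: some of these checkpoints (e.g. meta-llama/Llama-2-7b-chat-hf) are
# gated on the Hub, and serverless Inference API availability varies per
# model, so a given choice may fail unless the token's account has access.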
def get_client(model_name):
    return InferenceClient(LLM_MODELS[model_name], token=os.getenv("HF_TOKEN"))
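# Assumes the HF_TOKEN environment variable is set (e.g. as a Space secret);
# it authenticates Inference API calls and grants access to gated models.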
def read_uploaded_file(file):
    # gr.File(type="filepath") passes the upload to callbacks as a path string.
    if file is None:
        return ""
    try:
        if file.endswith('.parquet'):
            df = pd.read_parquet(file, engine='pyarrow')
            return df.head(10).to_markdown(index=False)
        with open(file, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        return f"An error occurred while reading the file: {str(e)}"
def format_history(history):
    formatted_history = []
    for user_msg, assistant_msg in history:
        formatted_history.append({"role": "user", "content": user_msg})
        if assistant_msg:
            formatted_history.append({"role": "assistant", "content": assistant_msg})
    return formatted_history
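# Converts Gradio's pair-based Chatbot history into the role/content message
# list that chat_completion() expects, e.g.:
#   format_history([["hi", "hello!"]])
#   -> [{"role": "user", "content": "hi"},
#       {"role": "assistant", "content": "hello!"}]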
def chat(message, history, uploaded_file, model_name, system_message="", max_tokens=4000, temperature=0.7, top_p=0.9):
    system_prefix = """You must answer in Korean. Based on the given source code or data, "your role is to explain the service, guide its use, and handle Q&A." Respond very kindly and in detail, in Markdown format, using at least 4,000 tokens. Explain usage and answer questions based on the provided content, and be helpful to the user. Kindly cover anything the user is likely to be curious about. Keep the underlying content confidential: do not reveal key values, endpoints, or the specific model."""

    if uploaded_file:
        content = read_uploaded_file(uploaded_file)
        file_extension = os.path.splitext(uploaded_file)[1].lower()
        if file_extension == '.parquet':
            system_message += f"\n\nFile content:\n```markdown\n{content}\n```"
        else:
            system_message += f"\n\nFile content:\n```python\n{content}\n```"

    if message == "Starting file analysis.":
        message = """Analyze the uploaded file and explain it in detail, covering:
1. The file's main purpose and functionality
2. Key features and components
3. How to use it and usage scenarios
4. Cautions and limitations
5. Expected benefits and advantages"""

    messages = [{"role": "system", "content": f"{system_prefix} {system_message}"}]
    messages.extend(format_history(history))
    messages.append({"role": "user", "content": message})
response = "" | |
try: | |
client = get_client(model_name) | |
for msg in client.chat_completion( | |
messages, | |
max_tokens=max_tokens, | |
stream=True, | |
temperature=temperature, | |
top_p=top_p, | |
): | |
token = msg.choices[0].delta.get('content', None) | |
if token: | |
response += token | |
history = history + [[message, response]] | |
return "", history | |
except Exception as e: | |
error_msg = f"μΆλ‘ μ€ μ€λ₯κ° λ°μνμ΅λλ€: {str(e)}" | |
history = history + [[message, error_msg]] | |
return "", history | |
css = """ | |
footer {visibility: hidden} | |
""" | |
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=600)
            msg = gr.Textbox(
                label="Enter your message",
                show_label=False,
                placeholder="Enter your message...",
                container=False
            )
            clear = gr.ClearButton([msg, chatbot])
        with gr.Column(scale=1):
            model_name = gr.Dropdown(
                choices=list(LLM_MODELS.keys()),
                value="Default",
                label="Select LLM model",
                info="Choose the LLM model to use"
            )
            file_upload = gr.File(
                label="Upload file (.csv, .txt, .py, .parquet)",
                file_types=[".csv", ".txt", ".py", ".parquet"],
                type="filepath"
            )
            with gr.Accordion("Advanced settings", open=False):
                system_message = gr.Textbox(label="System Message", value="")
                max_tokens = gr.Slider(minimum=1, maximum=8000, value=4000, label="Max Tokens")
                temperature = gr.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
                top_p = gr.Slider(minimum=0, maximum=1, value=0.9, label="Top P")
    # Event bindings
    msg.submit(
        chat,
        inputs=[msg, chatbot, file_upload, model_name, system_message, max_tokens, temperature, top_p],
        outputs=[msg, chatbot]
    )
    # Automatic analysis when a file is uploaded. A lambda injects the trigger
    # phrase that chat() matches, rather than instantiating a throwaway
    # gr.Textbox inside `inputs`, which would render as a stray component.
    file_upload.change(
        lambda history, file, model, sys_msg, tokens, temp, tp: chat(
            "Starting file analysis.", history, file, model, sys_msg, tokens, temp, tp
        ),
        inputs=[chatbot, file_upload, model_name, system_message, max_tokens, temperature, top_p],
        outputs=[msg, chatbot]
    )
    # Example prompts
    gr.Examples(
        examples=[
            ["Explain the detailed usage in at least 4,000 tokens, as if walking through the screen"],
            ["Write 20 detailed FAQ entries; use at least 4,000 tokens."],
            ["Write a YouTube video script of at least 4,000 tokens, focusing on usage, differentiators, features, and strengths"],
            ["Write an SEO-optimized blog post about this service, at least 4,000 tokens"],
            ["Write it in the format of a patent application, covering the technical and business-model aspects of a patent filing"],
            ["Continue the answer from where you left off"],
        ],
        inputs=msg,
    )
if __name__ == "__main__":
    demo.launch()