Spaces:
Build error
Build error
File size: 8,356 Bytes
8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c cb1513b 8cf3f1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
import warnings
warnings.filterwarnings("ignore")
import os
import sys
from typing import List, Tuple
from llama_cpp import Llama
from llama_cpp_agent import LlamaCppAgent
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
from llama_cpp_agent.messages_formatter import MessagesFormatter, PromptMarkers
from huggingface_hub import hf_hub_download
import gradio as gr
# Local imports (assure-toi que ces fichiers sont dans le même dossier)
from logger import logging
from exception import CustomExceptionHandling
# Download gguf model files
if not os.path.exists("./models"):
os.makedirs("./models")
MODEL_REPO_ID = "bartowski/google_gemma-3-1b-it-GGUF"
MODEL_FILENAME_Q4 = "google_gemma-3-1b-it-Q4_K_M.gguf"
if not os.path.exists(f"./models/{MODEL_FILENAME_Q4}"):
logging.info(f"Téléchargement du modèle {MODEL_FILENAME_Q4} depuis {MODEL_REPO_ID}...")
hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=MODEL_FILENAME_Q4,
local_dir="./models",
)
logging.info("Téléchargement terminé.")
else:
logging.info(f"Modèle {MODEL_FILENAME_Q4} déjà présent localement.")
# Define the prompt markers for Gemma 3
gemma_3_prompt_markers = {
Roles.system: PromptMarkers("", "\n"),
Roles.user: PromptMarkers("<start_of_turn>user\n", "<end_of_turn>\n"),
Roles.assistant: PromptMarkers("<start_of_turn>model\n", "<end_of_turn>\n"),
Roles.tool: PromptMarkers("", ""),
}
gemma_3_formatter = MessagesFormatter(
pre_prompt="",
prompt_markers=gemma_3_prompt_markers,
include_sys_prompt_in_first_user_message=True,
default_stop_sequences=["<end_of_turn>", "<start_of_turn>"],
strip_prompt=False,
bos_token="<bos>",
eos_token="<eos>",
)
# Global variables to cache the model
llm = None
current_model_name = None
def answer(
message: str,
historical_information: List[Tuple[str, str]],
model_filename: str,
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
repeat_penalty: float,
):
global llm
global current_model_name
try:
model_path = f"./models/{model_filename}"
if not os.path.exists(model_path):
yield f"Erreur : Fichier modèle non trouvé à {model_path}. Vérifiez le chemin."
return
if llm is None or current_model_name != model_filename:
logging.info(f"Chargement du modèle : {model_path}")
# Ajuste les n_threads en fonction de ton CPU
cpu_count = os.cpu_count()
threads_to_use = max(1, cpu_count // 2 if cpu_count else 4)
llm = Llama(
model_path=model_path,
flash_attn=False,
n_gpu_layers=0,
n_batch=512,
n_ctx=2048,
n_threads=threads_to_use,
n_threads_batch=threads_to_use,
verbose=False
)
current_model_name = model_filename
logging.info(f"Modèle {current_model_name} chargé avec {threads_to_use} threads.")
provider = LlamaCppPythonProvider(llm)
agent = LlamaCppAgent(
provider,
system_prompt=system_message,
custom_messages_formatter=gemma_3_formatter,
debug_output=False,
)
settings = provider.get_provider_default_settings()
settings.temperature = temperature
settings.top_k = top_k
settings.top_p = top_p
settings.max_tokens = max_tokens
settings.repeat_penalty = repeat_penalty
settings.stream = True
chat_history_for_agent = BasicChatHistory()
for user_msg, assistant_msg in historical_information:
if user_msg:
chat_history_for_agent.add_message({"role": Roles.user, "content": user_msg})
if assistant_msg:
chat_history_for_agent.add_message({"role": Roles.assistant, "content": assistant_msg})
logging.info(f"Envoi du message à l'agent: {message}")
stream = agent.get_chat_response(
message,
llm_sampling_settings=settings,
chat_history=chat_history_for_agent,
returns_streaming_generator=True,
print_output=False,
)
response_so_far = ""
for token in stream:
response_so_far += token
yield response_so_far
logging.info("Réponse générée.")
except Exception as e:
logging.error(f"Erreur lors de la génération de la réponse: {e}")
# Si tu utilises CustomExceptionHandling
# raise CustomExceptionHandling(e, sys) from e
yield f"Une erreur est survenue: {str(e)}"
available_models = [MODEL_FILENAME_Q4]
# --- Définition du Thème ---
# Tu peux décommenter et tester différents thèmes
# current_theme = gr.themes.Glass()
# current_theme = gr.themes.Monochrome()
# current_theme = gr.themes.Seafoam()
# current_theme = "gradio/dracula_revamped"
# current_theme = "NoCrypt/Miku"
current_theme = gr.themes.Soft(
primary_hue=gr.themes.colors.indigo, # Couleur principale (boutons, sliders actifs)
secondary_hue=gr.themes.colors.pink, # Couleur secondaire
neutral_hue=gr.themes.colors.slate, # Couleur neutre (texte, bordures)
font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"] # Police
).set(
# Tu peux surcharger des éléments spécifiques du thème ici si besoin
# Exemple: body_background_fill="linear-gradient(to right, #DCE35B, #45B649)"
)
app_title = "OpenGemma3 Chat"
app_description = """Discutez avec **Gemma 3 1B-IT**, un modèle de langage avancé de Google, exécuté localement grâce à `llama.cpp`.
Explorez ses capacités en ajustant les paramètres de génération ci-dessous."""
demo = gr.ChatInterface(
answer,
chatbot=gr.Chatbot(
label="Conversation", # Label du composant chatbot
height=600,
scale=1,
show_copy_button=True,
resizable=True,
# Pour les avatars, crée un dossier 'avatars' et place des images dedans
# avatar_images=("./avatars/user_avatar.png", "./avatars/bot_avatar.png")
bubble_full_width=False # Pour que les bulles ne prennent pas toute la largeur
),
additional_inputs=[
gr.Dropdown(
choices=available_models,
value=available_models[0],
label="Modèle GGUF",
info="Sélectionnez le modèle GGUF à utiliser.",
),
gr.Textbox(value="You are a helpful and friendly AI assistant named Gemma. You are concise and provide accurate information.", label="System message", lines=3, info="Définissez la personnalité et le rôle de l'assistant."),
gr.Slider(minimum=128, maximum=3072, value=1024, step=128, label="Max Tokens", info="Nombre maximum de tokens à générer pour la réponse."),
gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="Temperature", info="Contrôle la créativité (plus haut = plus créatif)."),
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (Nucleus Sampling)", info="Considère les tokens dont la probabilité cumulative atteint top-p."),
gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k", info="Considère les k tokens les plus probables."),
gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition Penalty", info="Pénalise la répétition de tokens (plus haut = moins de répétition)."),
],
title=app_title,
description=app_description,
examples=[
["Explique le concept de trou noir de manière simple."],
["Quelle est la recette des crêpes ?"],
["Raconte-moi une histoire courte et amusante."]
],
submit_btn="Envoyer",
stop_btn="Arrêter",
theme=current_theme,
)
if __name__ == "__main__":
logging.info("Lancement de l'interface Gradio...")
demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False) |