import os
import gradio as gr
import logging

# --- Logging config ---
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

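# Optional fail-fast sketch (not part of the original app): warn at startup
# when an API key is missing, so a provider failure is easier to diagnose.
for _key_name in ("OPENAI_API_KEY", "MISTRAL_API_KEY"):
    if not os.getenv(_key_name):
        logger.warning("%s is not set; the matching provider will fail.", _key_name)
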
# --- OpenAI (official SDK) ---
from openai import OpenAI
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# --- Mistral (official SDK) ---
from mistralai import Mistral
mistral_client = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))

# Default provider
DEFAULT_PROVIDER = "openai"  # or "mistral"


def llm_chat(messages, max_tokens, temperature, top_p, provider=DEFAULT_PROVIDER):
    provider = (provider or "").strip().lower()
    logger.info(f"Appel LLM provider={provider}, max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
    logger.info(f"Messages envoyés: {messages}")

    if provider == "openai":
        logger.info("→ Appel OpenAI Chat Completions")
        stream = openai_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            stream=True,
        )
        for chunk in stream:
            delta = chunk.choices[0].delta
            if delta and delta.content:
                logger.debug(f"OpenAI renvoie token: {delta.content!r}")
                yield delta.content

    elif provider == "mistral":
        logger.info("→ Appel Mistral Chat Completions")
        stream = mistral_client.chat.stream(
            model="mistral-large-latest",
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        # Each streamed event wraps a CompletionChunk; the text lives in
        # event.data.choices[0].delta.content (not event.data.delta).
        with stream as events:
            for event in events:
                piece = event.data.choices[0].delta.content or ""
                if piece:
                    logger.debug(f"Mistral streamed token: {piece!r}")
                    yield piece

    else:
        logger.error(f"Provider inconnu: {provider}")
        yield "[Erreur] Provider inconnu (utilise 'openai' ou 'mistral')."

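# Standalone usage sketch (hypothetical), bypassing the Gradio UI; prints
# streamed tokens as they arrive:
#
#   for token in llm_chat(
#       [{"role": "user", "content": "Say hi"}],
#       max_tokens=64, temperature=0.7, top_p=0.95, provider="openai",
#   ):
#       print(token, end="", flush=True)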

def respond(
    message,
    history: list[dict[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    provider,  # "openai" ou "mistral"
):
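    # With type="messages", Gradio supplies history as a list of
    # {"role": ..., "content": ...} dicts, so it can be forwarded as-is.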
    messages = [{"role": "system", "content": system_message}]
    messages.extend(history)
    messages.append({"role": "user", "content": message})

    logger.info(f"Nouvelle requête utilisateur: {message}")
    response = ""
    for token in llm_chat(
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        provider=provider,
    ):
        response += token
        yield response
    logger.info(f"Réponse finale générée: {response}")

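# Gradio re-renders the chatbot message with each yielded value, which is why
# respond() yields the accumulated text rather than individual tokens.
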

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        gr.Dropdown(choices=["openai", "mistral"], value=DEFAULT_PROVIDER, label="Provider"),
    ],
)
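
# Note: additional_inputs are passed to respond() positionally, in order,
# after (message, history): system_message, max_tokens, temperature, top_p,
# provider.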

with gr.Blocks() as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch()