"""Gradio Space: text generation with an attention map and next-token view.

The user picks a causal language model from ``models`` and loads it on CPU.
The app then generates a continuation of an input prompt. Besides the
generated text, it shows:

- the last layer's attention over the prompt, averaged over heads;
- the most probable candidates for the last sampled token.
"""

import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np  # noqa: F401  (currently unused)
import spaces
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer

# Gated repositories (Llama, Gemma) need a Hugging Face token. Authenticate
# only when one is configured instead of crashing with KeyError at startup.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Selectable models. The Llama 2 entries point at the "-hf" repositories,
# which hold the Transformers-format weights. The bare "Llama-2-*"
# repositories only ship Meta's original checkpoint format, which
# AutoModelForCausalLM cannot load.
models = [
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]

# Generation settings shared by the UI sliders and the reset handler.
MAX_NEW_TOKENS = 50
DEFAULT_TEMPERATURE = 1.0
DEFAULT_TOP_P = 1.0
DEFAULT_TOP_K = 50
TOP_WORDS_COUNT = 5
# Above this many prompt tokens, axis labels on the attention map become
# unreadable, so they are omitted.
MAX_TICK_LABELS = 40

# Currently loaded model and tokenizer (None until load_model succeeds).
model = None
tokenizer = None


def load_model(model_name):
    """Load ``model_name``'s tokenizer and weights on CPU and make them current.

    Args:
        model_name: Hugging Face Hub repository id selected in the dropdown.

    Returns:
        A status message for the "Statut du chargement" textbox.

    Raises:
        gr.Error: if loading the tokenizer or the model fails.
    """
    global model, tokenizer
    if not model_name:
        return "Veuillez choisir un modèle."
    try:
        new_tokenizer = AutoTokenizer.from_pretrained(model_name)
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="cpu",
            # Eager attention materialises the attention weights. The fused
            # kernels (SDPA / flash) do not support output_attentions, which
            # generate_text relies on.
            attn_implementation="eager",
        )
    except Exception as err:  # UI boundary: surface the cause to the user.
        raise gr.Error(f"Échec du chargement de {model_name} : {err}") from err
    # Swap both only once everything has loaded. A failure then never leaves
    # one model's tokenizer paired with another model's weights.
    tokenizer, model = new_tokenizer, new_model
    return f"Modèle {model_name} chargé avec succès sur CPU."


def _plot_attention(attention_matrix, token_labels):
    """Render a 2-D attention matrix as a heat map and return the figure.

    Args:
        attention_matrix: 2-D array (query positions x key positions).
        token_labels: One label per position, used for the axis ticks when
            there are at most MAX_TICK_LABELS of them.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    image = ax.imshow(attention_matrix, cmap="viridis")
    fig.colorbar(image, ax=ax)
    ax.set_title("Carte d'attention")
    if len(token_labels) <= MAX_TICK_LABELS:
        # Escape '$' so matplotlib does not interpret token text as mathtext.
        labels = [label.replace("$", r"\$") for label in token_labels]
        positions = list(range(len(labels)))
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=90)
        ax.set_yticks(positions)
        ax.set_yticklabels(labels)
    # Drop the figure from pyplot's registry so repeated calls do not pile up
    # open figures. The returned Figure object itself stays renderable.
    plt.close(fig)
    return fig


def _top_words(step_scores, tok, k=TOP_WORDS_COUNT):
    """Decode the (at most) ``k`` most probable tokens of one generation step.

    ``step_scores`` are the processed scores of a single sequence, after the
    temperature / top-k / top-p warpers. Tokens removed by those filters have
    probability 0 and are left out, so the list only contains candidates that
    could actually have been sampled.
    """
    probs = torch.softmax(step_scores.float(), dim=-1)
    top_probs, top_ids = torch.topk(probs, k=min(k, probs.shape[-1]))
    return [
        tok.decode([token_id])
        for prob, token_id in zip(top_probs.tolist(), top_ids.tolist())
        if prob > 0
    ]


@spaces.GPU(duration=300)
def generate_text(input_text, temperature, top_p, top_k):
    """Generate a continuation of ``input_text`` and build its visualisations.

    Args:
        input_text: Prompt to continue.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability threshold.
        top_k: Number of highest-scoring tokens kept when sampling.

    Returns:
        A tuple ``(generated_text, attention_figure, top_words)``:

        - ``generated_text``: the decoded sequence, prompt included.
        - ``attention_figure``: the last layer's head-averaged attention over
          the prompt.
        - ``top_words``: the most probable candidates for the last generated
          token.

    Raises:
        gr.Error: if no model is loaded or the prompt tokenizes to nothing.
    """
    # Snapshot the globals so a concurrent load_model cannot swap the
    # model/tokenizer pair halfway through this call.
    current_model, current_tokenizer = model, tokenizer
    if current_model is None or current_tokenizer is None:
        raise gr.Error("Veuillez d'abord charger un modèle.")

    inputs = current_tokenizer(input_text, return_tensors="pt").to(current_model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    if prompt_len == 0:
        raise gr.Error("Le texte d'entrée est vide.")

    # Some tokenizers (e.g. Llama's) define no pad token. Fall back to EOS so
    # generate() does not have to guess.
    pad_token_id = current_tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = current_tokenizer.eos_token_id

    with torch.no_grad():
        outputs = current_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            # Without do_sample=True, generate() decodes greedily and ignores
            # temperature / top_p / top_k, so the sliders would do nothing.
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            # The top-k warper requires a strictly positive int; the slider
            # value may arrive as a float.
            top_k=int(top_k),
            pad_token_id=pad_token_id,
            output_attentions=True,
            # Required to populate outputs.scores (it is None otherwise).
            output_scores=True,
            return_dict_in_generate=True,
        )
    generated_text = current_tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    # outputs.attentions has one entry per generation step. Each entry is a
    # tuple over layers of (batch, heads, query_len, key_len) tensors.
    # - The first step processes the whole prompt, so its last layer gives a
    #   prompt-by-prompt map.
    # - Averaging over heads yields the 2-D matrix imshow needs.
    # - Slicing keeps only real prompt positions, in case a pre-allocated
    #   cache pads the key axis.
    prompt_attention = outputs.attentions[0][-1][0].mean(dim=0)[:prompt_len, :prompt_len]
    attention_matrix = prompt_attention.float().cpu().numpy()
    token_labels = current_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    attention_plot = _plot_attention(attention_matrix, token_labels)

    # outputs.scores[-1] holds the processed scores for the last generated
    # token; [0] selects the only sequence in the batch.
    top_words = _top_words(outputs.scores[-1][0], current_tokenizer)
    return generated_text, attention_plot, top_words


def reset():
    """Restore the default inputs and clear every output component."""
    return "", DEFAULT_TEMPERATURE, DEFAULT_TOP_P, DEFAULT_TOP_K, None, None, None


with gr.Blocks() as demo:
    gr.Markdown("# Générateur de texte avec visualisation d'attention")

    with gr.Accordion("Sélection du modèle"):
        model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")

    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=DEFAULT_TEMPERATURE, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, label="Top-p")
        top_k = gr.Slider(1, 100, value=DEFAULT_TOP_K, step=1, label="Top-k")

    input_text = gr.Textbox(label="Texte d'entrée")
    generate_button = gr.Button("Générer")
    output_text = gr.Textbox(label="Texte généré")

    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        top_words = gr.JSON(label="Mots les plus probables")

    reset_button = gr.Button("Réinitialiser")

    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(
        generate_text,
        inputs=[input_text, temperature, top_p, top_k],
        outputs=[output_text, attention_plot, top_words],
    )
    reset_button.click(
        reset,
        outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, top_words],
    )

if __name__ == "__main__":
    demo.launch()