# LLMnBiasV2 / app.py — Hugging Face Space by Woziii
# Last commit: "Update app.py" (ea35578, verified) · file size 5.04 kB
import gc
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
# Log in to the Hugging Face Hub with the token stored in the HF_TOKEN
# environment variable (e.g. a Space secret). The token is needed to download
# gated/private repositories such as meta-llama/*. When it is absent the app
# still starts, and only public checkpoints can be loaded.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    print("HF_TOKEN non défini : seuls les modèles publics pourront être chargés.")
# Checkpoints offered in the model dropdown. load_model() opens them with
# AutoModelForCausalLM, so every entry must be a Transformers-format repo.
# The Llama-2 repos without the "-hf" suffix hold Meta's original checkpoints,
# so their "-hf" conversions are listed instead.
model_list = [
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]
# Handles to the currently loaded checkpoint, shared by the UI callbacks.
# Both stay None until load_model() assigns them.
model = tokenizer = None
def load_model(model_name):
    """Load a causal LM and its tokenizer into the module-level globals.

    Any previously loaded model is released first, so switching models does
    not hold two sets of weights in memory at once. If loading fails, both
    globals are left as None instead of a mismatched old/new pair.

    Args:
        model_name: Hugging Face Hub repository id (e.g. an entry of
            ``model_list``).

    Returns:
        A status message for the UI confirming which model was loaded.
    """
    global model, tokenizer
    # Drop the old references and release cached GPU blocks before
    # allocating the new weights; otherwise peak memory is old + new.
    model = None
    tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"Chargement du modèle {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Llama/Mistral tokenizers ship without a pad token; reuse EOS so that
    # padding requests and generate() always have a valid pad id.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # "eager" attention is requested because generate_text() asks for
    # output_attentions, which fused attention kernels cannot return.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        attn_implementation="eager",
    )
    print("Modèle chargé avec succès.")
    return f"Modèle {model_name} chargé."
def plot_attention(attention_data, output_path='attention_plot.png'):
    """Render an attention matrix as a heatmap PNG and return its path.

    Args:
        attention_data: dict with ``'attention'`` (a 2-D array, rows are
            query positions and columns are key positions) and ``'tokens'``
            (token strings; entry i labels row/column i).
        output_path: file to write; defaults to the original file name.

    Returns:
        ``output_path``.
    """
    tokens = list(attention_data['tokens'])
    attention = np.asarray(attention_data['attention'])
    n_rows, n_cols = attention.shape
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        cax = ax.matshow(attention, cmap='viridis')
        fig.colorbar(cax)
        # Pin exactly one tick per cell. matshow's default locator places
        # only a few ticks, so assigning a full label list without fixed
        # ticks misaligns labels with cells. Labels are also clipped to the
        # matrix size.
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(tokens[:n_cols], rotation=90)
        ax.set_yticks(range(n_rows))
        ax.set_yticklabels(tokens[:n_rows])
        ax.set_xlabel("Tokens")
        ax.set_ylabel("Tokens")
        ax.set_title("Attention Heatmap")
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # pyplot keeps every figure alive until closed; without this each
        # request would leak a 10x10-inch figure.
        plt.close(fig)
    return output_path
def plot_probabilities(prob_data, output_path='probabilities_plot.png'):
    """Render candidate words and their probabilities as a horizontal bar chart.

    Args:
        prob_data: mapping of word -> probability. Entries are drawn top to
            bottom in the mapping's insertion order (generate_text inserts
            them from most to least likely).
        output_path: file to write; defaults to the original file name.

    Returns:
        ``output_path``. An empty mapping yields a chart with no bars.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        if prob_data:
            words, probs = zip(*prob_data.items())
            ax.barh(words, probs, color='skyblue')
            # barh draws the first entry at the bottom; flip the axis so the
            # first (most likely) entry appears at the top.
            ax.invert_yaxis()
        ax.set_xlabel('Probabilities')
        ax.set_title('Top Probable Words')
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure; pyplot would otherwise keep it alive forever.
        plt.close(fig)
    return output_path
def _stack_step_attentions(step_attentions):
    """Merge generate()'s per-step attentions into one square matrix.

    ``step_attentions`` is ``outputs.attentions`` from ``generate()``. It has
    one entry per generated token, each a tuple over layers of tensors shaped
    (batch, heads, query_len, key_len). The first step spans the whole
    prompt, while each later step has one query over a growing key range.
    The shapes therefore differ, and the tensors cannot be concatenated
    directly.

    Returns:
        A float32 array of shape (n, n), where n is the total number of query
        positions across steps. It holds the head-averaged last-layer
        attention of batch item 0. Row i is query position i; cells that no
        step covers stay 0.
    """
    # Last layer, batch item 0, mean over heads -> (query_len, key_len).
    blocks = [step[-1][0].float().mean(dim=0).cpu().numpy() for step in step_attentions]
    size = sum(block.shape[0] for block in blocks)
    matrix = np.zeros((size, size), dtype=np.float32)
    row = 0
    for block in blocks:
        q_len, k_len = block.shape
        # Clip in case a cache reports more key slots than there are tokens
        # (e.g. a preallocated cache) -- TODO confirm which models do this.
        width = min(k_len, size)
        matrix[row:row + q_len, :width] = block[:, :width]
        row += q_len
    return matrix


def generate_text(input_text, temperature, top_p, top_k):
    """Continue ``input_text`` with the loaded model and build the two plots.

    Args:
        input_text: prompt to continue (truncated to 512 tokens).
        temperature: sampling temperature. 0 selects greedy decoding, in
            which case top_p and top_k are not used.
        top_p: nucleus-sampling threshold.
        top_k: number of highest-scoring tokens kept when sampling.

    Returns:
        ``(generated_text, attention_plot_path, probabilities_plot_path)``.
        If no model is loaded, returns a message and two ``None`` images.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, None

    # A single prompt needs no padding, and requesting it fails on tokenizers
    # that have no pad token.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)

    # The sliders only take effect when sampling. Temperature 0 is rejected
    # in sampling mode, so it is treated as greedy decoding. Slider values
    # may arrive as int/float, so they are normalised.
    do_sample = float(temperature) > 0
    sampling_kwargs = (
        {"temperature": float(temperature), "top_p": float(top_p), "top_k": int(top_k)}
        if do_sample
        else {}
    )
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=do_sample,
            pad_token_id=pad_token_id,
            output_scores=True,
            output_attentions=True,
            return_dict_in_generate=True,
            **sampling_kwargs,
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

    # Scores of the last generation step (after logits processing), batch item 0.
    last_token_logits = outputs.scores[-1][0].float()
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)

    # Five most probable next tokens, in descending order of probability.
    top_probs, top_indices = torch.topk(probabilities, 5)
    top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
    prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}

    # The last generated token is never fed back as a query, so the matrix
    # covers all but the final token of the sequence.
    attention = _stack_step_attentions(outputs.attentions)
    attention_data = {
        'attention': attention,
        'tokens': tokenizer.convert_ids_to_tokens(outputs.sequences[0][:attention.shape[0]].tolist()),
    }
    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
def reset_app():
    """Unload the current model and tokenizer and release their memory.

    Returns:
        A status message for the UI.
    """
    global model, tokenizer
    model = None
    tokenizer = None
    # Dropping the references alone leaves the weights' CUDA blocks reserved
    # by PyTorch's caching allocator; collect and release them explicitly.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return "Application réinitialisée."
# Gradio user interface
with gr.Blocks() as demo:
    with gr.Row():
        model_selection = gr.Accordion("Sélection du modèle", open=True)
        with model_selection:
            model_name = gr.Dropdown(choices=model_list, label="Choisir un modèle", value=model_list[0])
            load_model_button = gr.Button("Charger le modèle")
            load_status = gr.Textbox(label="Statut du modèle", interactive=False)
    with gr.Row():
        temperature = gr.Slider(0.0, 1.0, value=0.7, label="Température")
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
        # top_k must be an integer number of tokens.
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    with gr.Row():
        input_text = gr.Textbox(label="Entrer le texte")
        generate_button = gr.Button("Générer")
    with gr.Row():
        output_text = gr.Textbox(label="Texte généré", interactive=False)
    with gr.Row():
        attention_plot = gr.Image(label="Carte de chaleur des attentions")
        prob_plot = gr.Image(label="Probabilités des mots les plus probables")
    with gr.Row():
        reset_button = gr.Button("Réinitialiser l'application")

    load_model_button.click(load_model, inputs=[model_name], outputs=[load_status])
    generate_button.click(generate_text, inputs=[input_text, temperature, top_p, top_k], outputs=[output_text, attention_plot, prob_plot])
    # Route reset_app's status message to the status box; without outputs it was discarded.
    reset_button.click(reset_app, outputs=[load_status])

demo.launch()