# Hugging Face Space (status at capture time: Sleeping)
import os | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from repeng import ControlVector, ControlModel | |
import gradio as gr | |
# Hugging Face Hub login helper (used to authenticate the model download below)
from huggingface_hub import login | |
# Initialize model and tokenizer
mistral_path = "mistralai/Mistral-7B-Instruct-v0.3"
# mistral_path = "E:/language_models/models/mistral"  # local copy, kept for offline use

# Hugging Face Hub token, read from the `mistralaccesstoken` environment variable.
access_token = os.getenv("mistralaccesstoken")
if access_token:
    login(token=access_token)
else:
    # login() with no token falls back to an interactive prompt, which cannot
    # be answered on a headless server; continue with whatever credentials the
    # environment already provides instead.
    print("mistralaccesstoken is not set; skipping explicit Hugging Face login.")

tokenizer = AutoTokenizer.from_pretrained(mistral_path)
tokenizer.pad_token_id = 0
model = AutoModelForCausalLM.from_pretrained(
    mistral_path,
    torch_dtype=torch.float16,
    # NOTE(review): Mistral is supported natively by transformers, so
    # trust_remote_code=True may be unnecessary — confirm and drop it to avoid
    # executing code shipped with the model repo.
    trust_remote_code=True,
    use_safetensors=True,
)
model = model.to("cuda:0" if torch.cuda.is_available() else "cpu")
# Wrap the model so control vectors can be applied to layers -5 .. -17.
model = ControlModel(model, list(range(-5, -18, -1)))
# Decoding options passed to every model.generate() call.
generation_settings = dict(
    pad_token_id=tokenizer.eos_token_id,  # pad with EOS to silence the pad-token warning
    do_sample=False,  # greedy decoding: identical prompts give identical output
    max_new_tokens=256,
    repetition_penalty=1.1,  # discourage repeated phrases
)
# Instruction delimiters used when building prompts for the model.
user_tag = "[INST]"
asst_tag = "[/INST]"
# Control vector files (*.gguf) found in the current working directory.
# os.listdir returns entries in arbitrary order; sorting keeps the order of
# the checkboxes/sliders built from this list stable across restarts.
# Directories whose names happen to end in ".gguf" are skipped.
control_vector_files = sorted(
    name
    for name in os.listdir(".")
    if name.endswith(".gguf") and os.path.isfile(name)
)
if not control_vector_files:
    raise FileNotFoundError("No .gguf control vector files found in the current directory.")
# Checkbox change handler: show or hide the matching weight slider.
def toggle_slider(checked):
    """Return a Gradio update that sets the slider's ``visible`` property to ``checked``."""
    slider_update = gr.update(visible=checked)
    return slider_update
# Function to generate the model's response
def _apply_control_vectors(checkboxes, sliders):
    """Reset the model's control, then apply the weighted sum of the enabled vectors.

    ``checkboxes`` and ``sliders`` are parallel to ``control_vector_files``.
    ``ControlModel.set_control`` replaces whatever control is active, so the
    enabled vectors are scaled and summed first and applied in one call —
    calling it once per vector would leave only the last one in effect.
    A file that fails to load is reported and skipped.
    """
    model.reset()
    combined = None
    for cv_file, enabled, weight in zip(control_vector_files, checkboxes, sliders):
        if not enabled:
            continue
        print(f"Setting {cv_file} to {weight}")
        try:
            scaled = ControlVector.import_gguf(cv_file) * weight
        except Exception as e:
            # Best effort: a bad file should not abort the whole request.
            print(f"Failed to set control vector {cv_file}: {e}")
            continue
        combined = scaled if combined is None else combined + scaled
    if combined is None:
        return
    try:
        model.set_control(combined)
    except Exception as e:
        print(f"Failed to set combined control vector: {e}")


def _build_prompt(system_prompt, user_message, history):
    """Render past turns plus the new message with the [INST] ... [/INST] tags.

    Each past turn is closed with ``</s>``. A non-blank system prompt is
    folded into the new user turn (as Mistral's chat template does) instead
    of opening a separate ``[INST]`` block that is never closed.
    """
    turns = [
        f"{user_tag} {past_user} {asst_tag} {past_reply}</s>"
        for past_user, past_reply in history
    ]
    system_text = (system_prompt or "").strip()
    new_turn = f"{system_text}\n\n{user_message}" if system_text else user_message
    turns.append(f"{user_tag} {new_turn} {asst_tag}")
    return "".join(turns)


def generate_response(system_prompt, user_message, *args, history=None):
    """Generate the assistant's reply and return the updated chat history.

    Args:
        system_prompt: Optional instructions placed before the new user message.
        user_message: The new user message.
        *args: One checkbox value per entry of ``control_vector_files``,
            followed by one slider weight per entry, in the same order. When
            ``history`` is not given by keyword, the history may be passed as
            one extra trailing positional value (how a Gradio event listing
            the ``gr.State`` last would call this function).
        history: List of ``(user_message, assistant_message)`` pairs, or None.

    Returns:
        ``history`` with ``(user_message, response)`` appended, or ``history``
        unchanged when the number of control inputs does not match
        ``control_vector_files``.
    """
    print(f"Generating response to {user_message}")
    n_files = len(control_vector_files)
    values = list(args)
    if history is None and len(values) == 2 * n_files + 1:
        # History arrived positionally as the last input.
        history = values.pop()
    # Split by position, not by isinstance: bool is a subclass of int, so a
    # numeric type filter would also pick up every checkbox value.
    if len(values) != 2 * n_files:
        print(f"Expected {2 * n_files} control inputs, got {len(values)}; skipping generation.")
        return history
    checkboxes, sliders = values[:n_files], values[n_files:]

    _apply_control_vectors(checkboxes, sliders)

    history = history or []
    prompt = _build_prompt(system_prompt, user_message, history)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, **generation_settings)
    # generate() returns the prompt followed by the continuation; decode only
    # the newly generated tokens so the reply does not echo the prompt.
    prompt_length = inputs["input_ids"].shape[1]
    response = tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
    # Defensive: drop anything after a literal end-of-sequence marker.
    response = response.split("</s>", 1)[0].strip()

    history.append((user_message, response))
    return history
# "New Chat" handler: clear both the visible conversation and the stored state.
def reset_chat():
    """Return two fresh empty lists: one for the chatbot, one for the history state."""
    cleared_chatbot = []
    cleared_state = []
    return cleared_chatbot, cleared_state
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Mistral v3 Language Model Interface")
    with gr.Row():
        # Left Column: Settings and Control Vectors
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            # System Prompt Input
            system_prompt = gr.Textbox(
                label="System Prompt",
                lines=2,
                placeholder="Enter system-level instructions here...",
            )
            gr.Markdown("### 📊 Control Vectors")
            # One checkbox plus one weight slider per control vector file; the
            # slider stays hidden until its checkbox is ticked.
            control_checks = []
            control_sliders = []
            for cv_file in control_vector_files:
                with gr.Row():
                    # Checkbox to select the control vector
                    checkbox = gr.Checkbox(label=cv_file, value=False)
                    control_checks.append(checkbox)
                    # Slider to adjust the control vector's weight
                    slider = gr.Slider(
                        minimum=-2.5,
                        maximum=2.5,
                        value=0.0,
                        step=0.1,
                        label=f"{cv_file} Weight",
                        visible=False,
                    )
                    control_sliders.append(slider)
                    # Link the checkbox to toggle slider visibility
                    checkbox.change(
                        toggle_slider,
                        inputs=checkbox,
                        outputs=slider,
                    )
        # Right Column: Chat Interface
        with gr.Column(scale=2):
            gr.Markdown("### 🗨️ Conversation")
            # Chatbot to display conversation
            chatbot = gr.Chatbot(label="Conversation")
            # User Message Input
            user_input = gr.Textbox(
                label="Your Message",
                lines=2,
                placeholder="Type your message here...",
            )
            with gr.Row():
                # Submit and New Chat buttons
                submit_button = gr.Button("💬 Submit")
                new_chat_button = gr.Button("🆕 New Chat")
    # State to keep track of conversation history
    state = gr.State([])

    def _submit(*values):
        """Adapt the submit event to generate_response.

        Gradio passes the inputs positionally in the order listed below, with
        the history state last. generate_response receives the history by
        keyword and returns one history list, while this event has two
        outputs (chatbot and state), so the same list is returned for both.
        """
        *ui_values, history = values
        updated_history = generate_response(*ui_values, history=history)
        return updated_history, updated_history

    # Define button actions
    submit_button.click(
        _submit,
        inputs=[system_prompt, user_input] + control_checks + control_sliders + [state],
        outputs=[chatbot, state],
    )
    new_chat_button.click(
        reset_chat,
        inputs=[],
        outputs=[chatbot, state],
    )
def _main():
    """Start the Gradio server for the interface defined above."""
    demo.launch()


# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    _main()