import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel, PeftConfig
import gc
import torch
# Global variables to track loaded models
current_model = None
current_pipe = None
def load_adapter_model(adapter_model_name):
    global current_model, current_pipe

    # If a model is already loaded, drop the references so it can be freed
    if current_model is not None:
        current_model = None
        current_pipe = None
        # Force garbage collection and release cached GPU memory
        gc.collect()
        torch.cuda.empty_cache()

    # Base model that the adapters were trained on
    base_model_name = "unsloth/gemma-3-12b-it"

    # Load the tokenizer from the base model
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)

    try:
        # Method 1: load as a PEFT adapter on top of the base model
        print(f"Loading adapter model {adapter_model_name} on top of {base_model_name}...")

        # Load the adapter config first (this also verifies the repo is a PEFT adapter)
        peft_config = PeftConfig.from_pretrained(adapter_model_name)

        # Then load the base model
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map="auto",
            torch_dtype="auto"
        )

        # Attach the adapter weights to the base model
        model = PeftModel.from_pretrained(base_model, adapter_model_name)
        current_model = model
    except Exception as e:
        print(f"PEFT loading failed: {e}")
        try:
            # Method 2: load the repo directly if it is already merged or in a different format
            print("Trying to load model directly...")
            model = AutoModelForCausalLM.from_pretrained(
                adapter_model_name,
                device_map="auto",
                torch_dtype="auto"
            )
            current_model = model
        except Exception as e2:
            print(f"Direct loading failed: {e2}")
            # Method 3: fall back to letting the pipeline resolve the model name itself
            print("Falling back to using the model name in pipeline...")
            pipe = pipeline("text-generation", model=adapter_model_name)
            current_pipe = pipe
            return pipe

    # Create a pipeline with the loaded model and tokenizer
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    current_pipe = pipe
    return pipe
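# Optional (not used in this app): when the adapter loads via PEFT, the LoRA weights
# could also be merged into the base model so the PeftModel wrapper is dropped before
# inference, roughly like:
#   merged_model = PeftModel.from_pretrained(base_model, adapter_model_name).merge_and_unload()
# This is only a sketch; the function above keeps the adapter attached instead.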
# Default model name
default_model = "Chan-Y/gemma3-12b-1204-seperate"
# Create the initial pipeline
pipe = load_adapter_model(default_model)
pipe.model_name = default_model # Track the current model name
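# Note: the 12B base model is loaded eagerly here, at import time, so the Space needs
# enough accelerator memory before the UI comes up. If device placement is unclear, the
# device map chosen by accelerate (when device_map="auto") can be inspected, e.g.:
#   print(getattr(pipe.model, "hf_device_map", None))  # None if the model was not dispatched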
def generate_response(model_name, prompt, system_prompt, max_length, temperature, top_p, top_k):
    """Generate text with the selected model based on user input and the advanced settings."""
    global pipe

    # Check whether a different model needs to be loaded
    if model_name != getattr(pipe, 'model_name', default_model):
        pipe = load_adapter_model(model_name)
        # Store the model name on the pipeline so we can tell which model is loaded
        pipe.model_name = model_name

    # Chat-format input: one conversation with a system turn and a user turn
    messages = [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": prompt}]
            },
        ],
    ]

    print("Generating response...")

    # Generate text with all sampling parameters
    output = pipe(
        messages,
        max_new_tokens=max_length,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k
    )

    # Extract the assistant's reply from the returned message history
    return output[0][0]["generated_text"][-1]["content"]
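# Note on the indexing above: with chat-style inputs the text-generation pipeline
# returns the full message history under "generated_text", so the last entry is
# (typically) the assistant's reply. The exact nesting depends on how inputs are
# batched; the double [0][0] here assumes one conversation and one returned sequence.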
# Default system prompt in Turkish
#default_system_prompt = """Sana bir problem verildi.
#Problem hakkında düşün ve çalışmanı göster.
#Çalışmanı <start_working_out> ve <end_working_out> arasına yerleştir.
#Sonra, çözümünü <SOLUTION> ve </SOLUTION> arasına yerleştir.
#Lütfen SADECE Türkçe kullan."""
default_system_prompt = """Sen kullanıcıların isteklerine Türkçe cevap veren bir asistansın ve sana bir problem verildi.
Problem hakkında düşün ve çalışmanı göster.
Çalışmanı <start_working_out> ve <end_working_out> arasına yerleştir.
Sonra, çözümünü <SOLUTION> ve </SOLUTION> arasına yerleştir.
Lütfen SADECE Türkçe kullan."""
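# Rough English translation of the system prompt above:
# "You are an assistant that answers users' requests in Turkish, and you have been
#  given a problem. Think about the problem and show your work. Put your work between
#  <start_working_out> and <end_working_out>. Then put your solution between
#  <SOLUTION> and </SOLUTION>. Please use ONLY Turkish."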
# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Gemma 3 Reasoning Model Interface")
    gr.Markdown("Using Gemma 3 12B with Turkish reasoning adapters")

    with gr.Row():
        with gr.Column():
            # Model selection in an accordion
            with gr.Accordion("Model Selection", open=True):
                model_selector = gr.Dropdown(
                    choices=[
                        "Chan-Y/gemma3-12b-1204-seperate",
                    ],
                    value="Chan-Y/gemma3-12b-1204-seperate",
                    label="Select Model",
                    info="Choosing a new model will unload the current one to save memory"
                )

            prompt_input = gr.Textbox(
                lines=5,
                placeholder="Enter your prompt here...",
                label="Prompt"
            )

            # Advanced settings in an accordion
            with gr.Accordion("Advanced Settings", open=False):
                system_prompt = gr.Textbox(
                    lines=5,
                    value=default_system_prompt,
                    label="System Prompt"
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.75,
                    step=0.05,
                    label="Temperature"
                )
                max_tokens = gr.Slider(
                    minimum=64,
                    maximum=1024 * 4,
                    value=512,
                    step=16,
                    label="Max New Tokens"
                )
                top_p_value = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p"
                )
                top_k_value = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=64,
                    step=1,
                    label="Top-k"
                )

            submit_btn = gr.Button("Generate Response")

        with gr.Column():
            output_text = gr.Textbox(lines=15, label="Generated Response")

    # Connect the function to the interface
    submit_btn.click(
        fn=generate_response,
        inputs=[
            model_selector,
            prompt_input,
            system_prompt,
            max_tokens,
            temperature,
            top_p_value,
            top_k_value
        ],
        outputs=output_text
    )
# Launch the interface
if __name__ == "__main__":
    demo.launch()  # Set share=True to create a public link