Source: Hugging Face Space by S-Dreamer — commit 054953c ("Update app.py"), verified.
```python
# app.py
import time
import zipfile
from pathlib import Path
import gradio as gr
from src.train import finetune_lora
from src.infer import load_generator, generate_text
def _default_output_root() -> Path:
# On Hugging Face Spaces, /data exists if Persistent Storage is enabled.
# Otherwise fall back to a repo-local outputs/ directory.
return Path("/data/outputs") if Path("/data").exists() else Path("outputs")
def _zip_dir(src_dir: Path, zip_path: Path) -> Path:
zip_path.parent.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
for p in src_dir.rglob("*"):
if p.is_file():
zf.write(p, arcname=p.relative_to(src_dir))
return zip_path
def run_train(
    base_model: str,
    dataset_id: str,
    max_train_samples: int,
    max_steps: int,
    lr: float,
    batch_size: int,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
):
    """Kick off a LoRA fine-tune run from the Gradio Train tab.

    Creates a timestamped run directory, calls ``finetune_lora``, zips the
    resulting adapter (when one was produced), and returns the tuple the UI
    binds to: (status message, zip path or None, run dir, adapter dir).
    """
    run_dir = _default_output_root() / time.strftime("%Y%m%d-%H%M%S")
    run_dir.mkdir(parents=True, exist_ok=True)

    # Gradio Number components may deliver floats; coerce explicitly so the
    # trainer receives the types it expects.
    status = finetune_lora(
        base_model=base_model.strip(),
        dataset_id=dataset_id.strip(),
        output_dir=str(run_dir),
        max_train_samples=int(max_train_samples),
        max_steps=int(max_steps),
        learning_rate=float(lr),
        batch_size=int(batch_size),
        lora_r=int(lora_r),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
    )

    adapter_dir = run_dir / "adapter"
    zip_file = None
    # Only produce a downloadable zip when training actually wrote an adapter.
    if adapter_dir.exists():
        zip_file = str(_zip_dir(adapter_dir, run_dir / "adapter.zip"))

    msg = (
        f"Done.\n\n"
        f"Run dir: {run_dir}\n"
        f"Adapter dir: {adapter_dir}\n\n"
        f"{status}"
    )
    return msg, zip_file, str(run_dir), str(adapter_dir)
def run_generate(
    base_model: str,
    adapter_dir: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
):
    """Generate text from the Gradio Generate tab.

    Loads the base model with the given LoRA adapter directory applied and
    returns the generated completion for *prompt*.
    """
    # Slider values can arrive as floats; normalize before calling the backend.
    token_budget = int(max_new_tokens)
    temp = float(temperature)
    generator = load_generator(base_model.strip(), adapter_dir.strip())
    return generate_text(
        generator,
        prompt,
        max_new_tokens=token_budget,
        temperature=temp,
    )
# Gradio UI. NOTE: inside gr.Blocks, component creation order determines the
# rendered layout, so statements here must not be reordered.
with gr.Blocks(title="Fine-tune Pipeline (Docker)") as demo:
    gr.Markdown("# Fine-tuning pipeline (LoRA) — Docker Space\nUsing Trendyol cybersecurity instruction dataset.")
    # --- Train tab: collects hyperparameters and launches run_train ---
    with gr.Tab("Train"):
        base_model = gr.Textbox(
            value="sshleifer/tiny-gpt2",
            label="Base model (HF Hub id)",
            info="Tip: for best chat behavior, use a small instruct/chat model and update LoRA target_modules in src/train.py accordingly.",
        )
        dataset_id = gr.Textbox(
            value="Trendyol/Trendyol-Cybersecurity-Instruction-Tuning-Dataset",
            label="Dataset (HF Hub id)",
            info="Expected columns: system, user, assistant (this dataset has them).",
        )
        with gr.Row():
            max_train_samples = gr.Number(value=2000, precision=0, label="Max train samples")
            max_steps = gr.Number(value=100, precision=0, label="Max steps")
        with gr.Row():
            lr = gr.Number(value=2e-4, label="Learning rate")
            batch_size = gr.Number(value=2, precision=0, label="Batch size")
        with gr.Row():
            lora_r = gr.Number(value=8, precision=0, label="LoRA r")
            lora_alpha = gr.Number(value=16, precision=0, label="LoRA alpha")
            lora_dropout = gr.Number(value=0.05, label="LoRA dropout")
        train_btn = gr.Button("Start fine-tune")
        # Outputs bound to run_train's 4-tuple return:
        # (status text, adapter zip file, run dir, adapter dir).
        train_out = gr.Textbox(lines=12, label="Status / logs")
        adapter_zip = gr.File(label="Download trained adapter (zip)")
        out_dir_box = gr.Textbox(label="Run output directory")
        adapter_dir_box = gr.Textbox(label="Adapter directory (use this in Generate tab)")
        train_btn.click(
            fn=run_train,
            inputs=[
                base_model,
                dataset_id,
                max_train_samples,
                max_steps,
                lr,
                batch_size,
                lora_r,
                lora_alpha,
                lora_dropout,
            ],
            outputs=[train_out, adapter_zip, out_dir_box, adapter_dir_box],
            # Training is long-running; queue it so the request isn't dropped.
            queue=True,
        )
    # --- Generate tab: runs inference with a previously trained adapter ---
    with gr.Tab("Generate"):
        base_model2 = gr.Textbox(
            value="sshleifer/tiny-gpt2",
            label="Base model (must match training)",
        )
        # The adapter path is hand-copied from the Train tab's output box.
        adapter_dir = gr.Textbox(
            placeholder="Paste adapter dir path from Train tab (e.g., outputs/20260306-120000/adapter)",
            label="Adapter directory",
        )
        prompt = gr.Textbox(
            value="Explain the difference between phishing and spear phishing.",
            lines=4,
            label="Prompt",
        )
        with gr.Row():
            max_new_tokens = gr.Slider(16, 256, value=120, step=1, label="Max new tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature")
        gen_btn = gr.Button("Generate")
        gen_out = gr.Textbox(lines=12, label="Output")
        gen_btn.click(
            fn=run_generate,
            inputs=[base_model2, adapter_dir, prompt, max_new_tokens, temperature],
            outputs=[gen_out],
            # Inference is comparatively quick; runs without the queue.
            queue=False,
        )
# Entry point: start the Gradio server (Spaces runs this module directly).
demo.launch()
```