import os
import zipfile

import coremltools as ct
import numpy as np
import streamlit as st
import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

MODEL_NAME = "distilbert/distilgpt2"
DATASET_NAME = "roneneldan/TinyStories"
ADAPTER_PATH = "distilgpt2-lora-tinystories"
def load_base_model_and_tokenizer():
    """Loads the base model and tokenizer."""
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # GPT-2 ships without a pad token, so reuse EOS for padding.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    return model, tokenizer
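
# Hedged, optional sanity check (an addition, not part of the original app):
# generate a few tokens with the raw base model to confirm the weights load
# correctly before spending time on fine-tuning.
#
#     m, tok = load_base_model_and_tokenizer()
#     ids = tok("Once upon a time", return_tensors="pt").input_ids
#     print(tok.decode(m.generate(ids, max_new_tokens=20)[0]))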
def load_and_prepare_dataset(tokenizer, split="train"):
    """Loads and tokenizes the dataset."""
    dataset = load_dataset(DATASET_NAME, split=split)

    def tokenize_function(examples):
        tokenized = tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=256,
        )
        # For causal language modeling, labels are the same as input_ids.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    # Drop the raw text columns after tokenization when the dataset exposes
    # them; some dataset types (e.g. streaming) may not.
    try:
        remove_cols = dataset.column_names if hasattr(dataset, "column_names") else None
    except Exception:
        remove_cols = None

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=remove_cols,
    )
    return tokenized_dataset
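
# The full TinyStories train split holds on the order of two million stories,
# so tokenizing all of it is slow for a demo. With max_steps=100 and an
# effective batch size of 4 (see fine_tune_model below), training consumes
# only about 400 examples; a hedged shortcut is a sliced split (the slice
# size here is an illustrative assumption, not part of the original app):
#
#     small_dataset = load_dataset(DATASET_NAME, split="train[:2000]")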
def fine_tune_model(model, tokenizer, tokenized_dataset):
    """Fine-tunes the model using LoRA."""
    lora_config = LoraConfig(
        r=4,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, lora_config)
    peft_model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        logging_steps=10,
        save_steps=100,
        warmup_steps=10,
        fp16=torch.cuda.is_available(),
        dataloader_pin_memory=False,
        remove_unused_columns=False,
        # max_steps caps the run at 100 optimizer steps and takes precedence
        # over any epoch count, keeping the demo short.
        max_steps=100,
    )
    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )
    trainer.train()
    peft_model.save_pretrained(ADAPTER_PATH)
    return peft_model
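
# PEFT infers LoRA target modules from the model architecture; for the GPT-2
# family it injects adapters into the fused attention projection "c_attn".
# A hedged, more explicit equivalent of the config above (same effect, with
# the targets spelled out) would be:
#
#     explicit_config = LoraConfig(
#         r=4,
#         lora_alpha=16,
#         lora_dropout=0.1,
#         bias="none",
#         task_type="CAUSAL_LM",
#         target_modules=["c_attn"],
#     )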
def convert_to_coreml(model, tokenizer):
    """Converts the model to Core ML format."""
    st.info("Merging LoRA adapter...")
    merged_model = model.merge_and_unload()
    st.success("Adapter merged.")

    st.info("Moving model to CPU for Core ML conversion...")
    merged_model = merged_model.cpu()
    merged_model.eval()
    st.success("Model moved to CPU.")

    # Wrap the model so the traced graph returns only the logits tensor.
    class SimpleModel(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids):
            outputs = self.model(input_ids)
            return outputs.logits

    simple_model = SimpleModel(merged_model)
    st.info("Created simple model wrapper.")

    st.info("Tracing the model...")
    # Pad the example to the fixed (1, 512) shape declared below, so the
    # traced graph matches the Core ML input description.
    example_input = tokenizer(
        "Once upon a time",
        return_tensors="pt",
        padding="max_length",
        max_length=512,
    )
    input_ids = example_input.input_ids.cpu()
    with torch.no_grad():
        traced_model = torch.jit.trace(simple_model, input_ids)
    st.success("Model traced.")

    st.info("Converting to Core ML ML Program...")
    coreml_model = ct.convert(
        traced_model,
        convert_to="mlprogram",
        inputs=[ct.TensorType(name="input_ids", shape=(1, 512), dtype=np.int32)],
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )
    st.success("Conversion to Core ML complete.")

    output_path = f"{ADAPTER_PATH}.mlpackage"
    # ct.convert already returns an MLModel, so save it directly.
    coreml_model.save(output_path)
    return output_path
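
# Optional hedged helper (an addition, not in the original app): a quick
# smoke test for the converted package.
def smoke_test_coreml(model_path, tokenizer, prompt="Once upon a time"):
    """Hedged sketch: run one prediction against the saved .mlpackage.
    Core ML inference only runs on macOS, and the output key name is
    assigned by the converter, so the raw prediction dict is returned."""
    mlmodel = ct.models.MLModel(model_path)
    ids = tokenizer(
        prompt,
        return_tensors="np",
        padding="max_length",
        max_length=512,
    ).input_ids.astype(np.int32)
    return mlmodel.predict({"input_ids": ids})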
def main():
    st.title("LoRA Fine-Tuning of distilgpt2 for TinyStories")
    st.write(
        "This app fine-tunes the `distilbert/distilgpt2` model on the "
        "`TinyStories` dataset using LoRA and PEFT."
    )

    # --- Load Model and Tokenizer ---
    with st.spinner("Loading base model and tokenizer..."):
        base_model, tokenizer = load_base_model_and_tokenizer()
        st.session_state.base_model = base_model
        st.session_state.tokenizer = tokenizer
    st.success("Base model and tokenizer loaded.")
    st.markdown(f"**Model:** `{MODEL_NAME}`")

    # --- Fine-Tuning ---
    st.header("1. LoRA Fine-Tuning")
    if st.button("Start Fine-Tuning"):
        with st.spinner("Loading dataset and fine-tuning... This might take a few minutes."):
            tokenized_dataset = load_and_prepare_dataset(tokenizer)
            st.session_state.tokenized_dataset = tokenized_dataset
            # Report dataset size when the dataset type supports len().
            try:
                st.info(f"Dataset loaded with {len(tokenized_dataset)} examples.")
            except (TypeError, AttributeError):
                st.info("Dataset loaded (length unknown).")
            peft_model = fine_tune_model(base_model, tokenizer, tokenized_dataset)
            st.session_state.peft_model = peft_model
        st.success("Fine-tuning complete! LoRA adapter saved.")
        st.balloons()

    # Offer to load a previously saved adapter if one exists on disk.
    if os.path.exists(ADAPTER_PATH) and "peft_model" not in st.session_state:
        if st.button("Load Fine-Tuned LoRA Adapter"):
            with st.spinner("Loading fine-tuned model..."):
                peft_model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
                st.session_state.peft_model = peft_model
            st.success("Fine-tuned LoRA model loaded.")
    # --- Text Generation ---
    if "peft_model" in st.session_state:
        st.header("2. Generate Story")
        prompt = st.text_input(
            "Enter a prompt to start a story:",
            "Once upon a time, in a land full of sunshine,",
        )

        # Generation parameters
        col1, col2, col3 = st.columns(3)
        with col1:
            temperature = st.slider("Temperature", 0.1, 2.0, 0.8, 0.1)
        with col2:
            max_length = st.slider("Max Length", 50, 200, 100, 10)
        with col3:
            repetition_penalty = st.slider("Repetition Penalty", 1.0, 2.0, 1.2, 0.1)

        if st.button("Generate"):
            with st.spinner("Generating text..."):
                model = st.session_state.peft_model
                inputs = tokenizer(prompt, return_tensors="pt")
                # Move inputs to the same device as the model.
                device = next(model.parameters()).device
                inputs = {k: v.to(device) for k, v in inputs.items()}
                outputs = model.generate(
                    **inputs,
                    max_length=max_length,
                    num_return_sequences=1,
                    temperature=temperature,
                    do_sample=True,
                    top_k=50,
                    top_p=0.9,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    no_repeat_ngram_size=3,
                )
                generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            st.write("### Generated Story:")
            st.write(generated_text)
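
        # Sampling notes: top_k=50 restricts each step to the 50 most likely
        # tokens, top_p=0.9 then keeps the smallest set whose cumulative
        # probability reaches 0.9, and repetition_penalty together with
        # no_repeat_ngram_size discourage the short loops that small models
        # like distilgpt2 tend to fall into.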
        # --- CoreML Conversion ---
        st.header("3. Convert to CoreML")
        if st.button("Convert Model to CoreML"):
            with st.spinner("Converting model to CoreML format..."):
                coreml_model_path = convert_to_coreml(
                    st.session_state.peft_model, st.session_state.tokenizer
                )
            st.success(f"Model successfully converted and saved to `{coreml_model_path}`")

            # An .mlpackage is a directory, so zip it up for download.
            zip_path = f"{ADAPTER_PATH}.zip"
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                for root, dirs, files in os.walk(coreml_model_path):
                    for file in files:
                        file_path = os.path.join(root, file)
                        arcname = os.path.relpath(file_path, coreml_model_path)
                        zipf.write(file_path, arcname)
            with open(zip_path, "rb") as f:
                st.download_button(
                    label="Download CoreML Model",
                    data=f,
                    file_name=os.path.basename(zip_path),
                    mime="application/zip",
                )
if __name__ == "__main__":
    main()