import streamlit as st
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import time
import json
import re
import os
import asyncio
# -------------------------------
# Utility Functions
# -------------------------------
token = st.secrets["HF_TOKEN"]
# Workaround for SSL certificate issues in some hosted environments;
# note that this disables certificate verification for outgoing requests.
os.environ['CURL_CA_BUNDLE'] = ''
@st.cache_resource
def load_model(model_id: str, token: str):
    """
    Loads and caches the Gemma model and tokenizer with an authentication token.
    """
    try:
        # Run a dummy coroutine so an event loop exists before model loading.
        asyncio.run(async_load(model_id, token))
        # Workaround (assumption): Streamlit's module watcher can crash when it
        # inspects torch.classes; clearing its __path__ is a commonly used fix.
        if hasattr(torch, "classes"):
            torch.classes.__path__ = []
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
        return tokenizer, model
    except Exception as e:
        print(f"An error occurred: {e}")
        st.error(f"Model loading failed: {e}")
        return None, None
async def async_load(model_id, token):
    """
    Dummy async function that ensures an event loop is initialized.
    """
    await asyncio.sleep(0.1)  # Dummy async operation
def preprocess_data(uploaded_file, file_extension):
    """
    Reads the uploaded file and returns a processed version.
    Supports CSV, JSONL, and TXT.
    """
    data = None
    try:
        if file_extension == "csv":
            data = pd.read_csv(uploaded_file)
        elif file_extension == "jsonl":
            # Each non-empty line is a JSON object.
            data = [json.loads(line) for line in uploaded_file.readlines() if line.strip()]
            try:
                data = pd.DataFrame(data)
            except Exception:
                st.warning("Unable to convert JSONL to a table. Previewing raw JSON objects.")
        elif file_extension == "txt":
            text_data = uploaded_file.read().decode("utf-8")
            data = text_data.splitlines()
    except Exception as e:
        st.error(f"Error processing file: {e}")
    return data
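# Example of what callers get back, based on the branches above:
#   df = preprocess_data(uploaded_csv, "csv")    # -> pd.DataFrame
#   lines = preprocess_data(uploaded_txt, "txt") # -> list of str (one per line)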
def clean_text(text, lowercase=True, remove_punctuation=True):
    """
    Cleans text data by applying basic normalization.
    """
    if lowercase:
        text = text.lower()
    if remove_punctuation:
        text = re.sub(r'[^\w\s]', '', text)
    return text
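# Worked example: clean_text("Hello, World!") -> "hello world"
# (lowercased first, then the [^\w\s] regex strips the comma and exclamation mark).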
def plot_training_metrics(epochs, loss_values, accuracy_values):
    """
    Returns a matplotlib figure plotting training loss and accuracy.
    """
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    ax[0].plot(range(1, epochs + 1), loss_values, marker='o', color='red')
    ax[0].set_title("Training Loss")
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Loss")
    ax[1].plot(range(1, epochs + 1), accuracy_values, marker='o', color='green')
    ax[1].set_title("Training Accuracy")
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Accuracy")
    return fig
def simulate_training(num_epochs):
    """
    Simulates a training loop for demonstration.
    Yields the current epoch, loss values so far, and accuracy values so far.
    Replace this with your actual fine-tuning loop.
    """
    loss_values = []
    accuracy_values = []
    for epoch in range(1, num_epochs + 1):
        loss = np.exp(-epoch) + np.random.random() * 0.1
        acc = 0.5 + (epoch / num_epochs) * 0.5 + np.random.random() * 0.05
        loss_values.append(loss)
        accuracy_values.append(acc)
        yield epoch, loss_values, accuracy_values
        time.sleep(1)  # Simulate computation time
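# Usage: iterate the generator to stream metrics into the UI as they arrive, e.g.
#   for epoch, losses, accs in simulate_training(3):
#       ...redraw the plot with the values so far...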
def quantize_model(model):
    """
    Applies dynamic quantization for demonstration.
    In practice, adjust this based on your model and target hardware.
    """
    quantized_model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
    return quantized_model
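# Illustrative sketch (not wired into the UI): one way to see what dynamic
# quantization buys is to compare serialized state_dict sizes before and after.
# The helper name below is hypothetical; float32 -> int8 Linear weights shrink
# roughly 4x in the common case, though exact savings depend on the model.
def _state_dict_size_mb(m):
    import io
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes / 1e6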
def convert_to_torchscript(model):
    """
    Converts the model to TorchScript format via tracing.
    """
    example_input = torch.randint(0, 100, (1, 10))
    traced_model = torch.jit.trace(model, example_input)
    return traced_model
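# Note (assumption): torch.jit.trace typically fails on Hugging Face models
# that return dict-like outputs. If tracing errors out here, reloading with
# AutoModelForCausalLM.from_pretrained(model_id, torchscript=True) yields a
# tuple-returning variant that is documented to be traceable.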
def convert_to_onnx(model, output_path="model.onnx"):
    """
    Converts the model to ONNX format.
    """
    dummy_input = torch.randint(0, 100, (1, 10))
    torch.onnx.export(model, dummy_input, output_path, input_names=["input"], output_names=["output"])
    return output_path
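# Note: the export above hard-codes the dummy input shape (1, 10). For
# variable-length prompts you would typically also pass
# dynamic_axes={"input": {0: "batch", 1: "sequence"}} (and an explicit
# opset_version) to torch.onnx.export.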
def load_finetuned_model(model, checkpoint_path="fine_tuned_model.pt"):
    """
    Loads the fine-tuned weights from the checkpoint into the model.
    """
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
        model.eval()
        st.success("Fine-tuned model loaded successfully!")
    else:
        st.error(f"Checkpoint not found: {checkpoint_path}")
    return model
def generate_response(prompt, model, tokenizer, max_length=200):
    """
    Generates a response using the fine-tuned model.
    """
    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    # Generate text; do_sample=True is required for temperature to take effect
    with torch.no_grad():
        outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1, do_sample=True, temperature=0.7)
    # Decode the output
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
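# Example: generate_response("Once upon a time...", model, tokenizer, 100)
# returns the prompt plus its continuation, since decode() runs over the
# full generated sequence of a causal LM.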
# -------------------------------
# Application Layout
# -------------------------------
st.title("One-Stop Gemma Model Fine-tuning, Quantization & Conversion UI")
st.markdown("""
This application is designed for beginners in generative AI.
It allows you to fine-tune, quantize, and convert Gemma models through an intuitive UI.
You can upload your dataset, clean and preview your data, configure training parameters, and export your model in different formats.
""")
# Sidebar: Model selection and data upload
st.sidebar.header("Configuration")
# Model Selection: map the friendly names shown in the UI to Hugging Face model IDs.
# (The original comparisons against raw model IDs never matched the option labels,
# so every choice silently fell through to the 1B model.)
model_options = {
    "Gemma-Small": "google/gemma-3-1b-it",
    "Gemma-Medium": "google/gemma-3-4b-it",
    "Gemma-Large": "google/gemma-3-1b-it",  # fallback, as in the original else branch
}
selected_model = st.sidebar.selectbox("Select Gemma Model", options=list(model_options.keys()))
model_id = model_options[selected_model]
loading_placeholder = st.sidebar.empty()
loading_placeholder.info("Loading model...")
tokenizer, model = load_model(model_id, token)
loading_placeholder.success("Model loaded.")
# Dataset Upload
uploaded_file = st.sidebar.file_uploader("Upload Dataset (CSV, JSONL, TXT)", type=["csv", "jsonl", "txt"])
data = None
if uploaded_file is not None:
    file_ext = uploaded_file.name.split('.')[-1].lower()
    data = preprocess_data(uploaded_file, file_ext)
    st.sidebar.subheader("Dataset Preview:")
    if isinstance(data, pd.DataFrame):
        st.sidebar.dataframe(data.head())
    elif isinstance(data, list):
        st.sidebar.write(data[:5])
    else:
        st.sidebar.write(data)
else:
    st.sidebar.info("Awaiting dataset upload.")
# Data Cleaning Options (for TXT files; skipped if preprocessing failed)
if uploaded_file is not None and file_ext == "txt" and isinstance(data, list):
    st.sidebar.subheader("Data Cleaning Options")
    lowercase_option = st.sidebar.checkbox("Convert to lowercase", value=True)
    remove_punct = st.sidebar.checkbox("Remove punctuation", value=True)
    cleaned_data = [clean_text(line, lowercase=lowercase_option, remove_punctuation=remove_punct) for line in data]
    st.sidebar.text_area("Cleaned Data Preview", value="\n".join(cleaned_data[:5]), height=150)
# Main Tabs for Different Operations
tabs = st.tabs(["Fine-tuning", "Quantization", "Model Conversion"])
# -------------------------------
# Fine-tuning Tab
# -------------------------------
with tabs[0]:
    st.header("Fine-tuning")
    st.markdown("Configure hyperparameters and start fine-tuning your Gemma model.")
    col1, col2, col3 = st.columns(3)
    with col1:
        learning_rate = st.number_input("Learning Rate", value=1e-4, format="%.5f")
    with col2:
        batch_size = st.number_input("Batch Size", value=16, step=1)
    with col3:
        epochs = st.number_input("Epochs", value=3, step=1)
    if st.button("Start Fine-tuning"):
        if data is None:
            st.error("Please upload a dataset first!")
        else:
            st.info("Starting fine-tuning...")
            progress_bar = st.progress(0)
            training_placeholder = st.empty()
            # Simulated training loop (replace with your actual training code;
            # see the Trainer sketch after this block)
            for epoch, losses, accs in simulate_training(epochs):
                fig = plot_training_metrics(epoch, losses, accs)
                training_placeholder.pyplot(fig)
                plt.close(fig)  # Free the figure; Streamlit keeps its own copy
                progress_bar.progress(epoch / epochs)
            st.success("Fine-tuning completed!")
            # Save the fine-tuned model (for demonstration, saving state_dict)
            if model is not None:
                torch.save(model.state_dict(), "fine_tuned_model.pt")
                with open("fine_tuned_model.pt", "rb") as f:
                    st.download_button("Download Fine-tuned Model", data=f, file_name="fine_tuned_model.pt", mime="application/octet-stream")
            else:
                st.error("Model not loaded. Cannot save.")
# -------------------------------
# Quantization Tab
# -------------------------------
with tabs[1]:
    st.header("Model Quantization")
    st.markdown("Quantize your model to optimize for inference performance.")
    quantize_choice = st.radio("Select Quantization Type", options=["Dynamic Quantization"], index=0)
    if st.button("Apply Quantization"):
        if model is None:
            st.error("Model not loaded. Cannot quantize.")
        else:
            with st.spinner("Applying quantization..."):
                quantized_model = quantize_model(model)
            st.success("Model quantized successfully!")
            torch.save(quantized_model.state_dict(), "quantized_model.pt")
            with open("quantized_model.pt", "rb") as f:
                st.download_button("Download Quantized Model", data=f, file_name="quantized_model.pt", mime="application/octet-stream")
# -------------------------------
# Model Conversion Tab
# -------------------------------
with tabs[2]:
    st.header("Model Conversion")
    st.markdown("Convert your model to a different format for deployment or optimization.")
    conversion_option = st.selectbox("Select Conversion Format", options=["TorchScript", "ONNX"])
    if st.button("Convert Model"):
        if model is None:
            st.error("Model not loaded. Cannot convert.")
        elif conversion_option == "TorchScript":
            with st.spinner("Converting to TorchScript..."):
                ts_model = convert_to_torchscript(model)
                ts_model.save("model_ts.pt")
            st.success("Converted to TorchScript!")
            with open("model_ts.pt", "rb") as f:
                st.download_button("Download TorchScript Model", data=f, file_name="model_ts.pt", mime="application/octet-stream")
        elif conversion_option == "ONNX":
            with st.spinner("Converting to ONNX..."):
                onnx_path = convert_to_onnx(model, "model.onnx")
            st.success("Converted to ONNX!")
            with open(onnx_path, "rb") as f:
                st.download_button("Download ONNX Model", data=f, file_name="model.onnx", mime="application/octet-stream")
# -------------------------------
# Response Generation Section
# -------------------------------
st.header("Generate Responses with Fine-Tuned Model")
st.markdown("Use the fine-tuned model to generate text responses based on your prompts.")
# Check that both the base model and a fine-tuned checkpoint are available
if model is not None and os.path.exists("fine_tuned_model.pt"):
    # Load the fine-tuned weights into the model
    model = load_finetuned_model(model, "fine_tuned_model.pt")
    # Input prompt for generating responses
    prompt = st.text_area("Enter a prompt:", "Once upon a time...")
    # Max length slider
    max_length = st.slider("Max Response Length", min_value=50, max_value=500, value=200, step=10)
    if st.button("Generate Response"):
        with st.spinner("Generating response..."):
            response = generate_response(prompt, model, tokenizer, max_length)
        st.success("Generated Response:")
        st.write(response)
else:
    st.warning("Fine-tuned model not found. Please fine-tune the model first.")
# -------------------------------
# Optional: Cloud Integration Snippet
# -------------------------------
st.header("Cloud Integration")
st.markdown("""
For large-scale training or model storage, consider integrating with Google Cloud Storage or Vertex AI.
Below is an example snippet for uploading your model to GCS:
""")
st.code("""
from google.cloud import storage

def upload_to_gcs(bucket_name, source_file_name, destination_blob_name):
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    blob = bucket.blob(destination_blob_name)
    blob.upload_from_filename(source_file_name)
    print(f"Uploaded {source_file_name} to {destination_blob_name}")

# Example usage:
# upload_to_gcs("your-bucket-name", "fine_tuned_model.pt", "models/fine_tuned_model.pt")
""", language="python")