Final_Assignment_Template / serve_test.py
mjschock's picture
Enhance serve.py with fine-tuning job management, including job creation, status tracking, and training process in a separate thread. Update serve_test.py to include a test for fine-tuning functionality. Modify .gitignore to exclude model files. This update improves model training capabilities and API integration.
145385b unverified
import json
import os
import time
from openai import OpenAI
# Shared client for every check in this script. It talks to the server on
# localhost:8000 through its OpenAI-compatible /v1 routes.
client = OpenAI(
    api_key="not-needed",  # placeholder; API key is not needed for local server
    base_url="http://localhost:8000/v1",
)
def test_chat_completion():
    """Send one chat completion request to the local server and print the result.

    Prints the generated message text, then the full response serialized as
    indented JSON. Any exception is caught and reported on stdout together
    with its traceback; nothing is raised to the caller.
    """
    try:
        print("Sending chat completion request...")
        request = {
            "model": "unsloth/SmolLM2-135M-Instruct-bnb-4bit",
            "messages": [{"role": "user", "content": "Hello"}],
            "temperature": 0.7,
            "max_tokens": 50,
        }
        completion = client.chat.completions.create(**request)

        # Generated text of the first choice.
        print("\nResponse:")
        reply = completion.choices[0].message.content
        print(reply)

        # Whole response, dumped for debugging.
        print("\nFull response object:")
        payload = completion.model_dump()
        print(json.dumps(payload, indent=2))
    except Exception as exc:
        print(f"Error occurred: {str(exc)}")
        import traceback

        print("\nFull traceback:")
        trace_text = traceback.format_exc()
        print(trace_text)
def test_fine_tuning():
    """Exercise the fine-tuning endpoints of the local server end to end.

    Writes a one-conversation training file to the current directory, creates
    a fine-tuning job from it, lists the known jobs, and then polls the job
    every five seconds until its status is no longer "created" or "running".

    Errors are reported on stdout with a traceback rather than raised. The
    training file is removed when the function exits, whether the run
    succeeded, failed, or was interrupted.
    """
    training_file = "training_data.json"
    try:
        # Create a sample training file
        training_data = {
            "conversations": [
                {
                    "from": "human",
                    "value": "What is the capital of France?",
                },
                {
                    "from": "gpt",
                    "value": "The capital of France is Paris.",
                },
            ]
        }
        with open(training_file, "w", encoding="utf-8") as f:
            json.dump(training_data, f)

        print("\nCreating fine-tuning job...")
        # NOTE(review): a local file path is passed as `training_file`;
        # presumably the local server reads it from disk directly -- confirm
        # against the server implementation.
        job = client.fine_tuning.jobs.create(
            training_file=training_file,
            model="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
        )
        print(f"Created job: {job.id}")

        # Wait for job to start
        print("\nWaiting for job to start...")
        time.sleep(2)

        # List jobs
        print("\nListing fine-tuning jobs...")
        jobs = client.fine_tuning.jobs.list()
        print(f"Found {len(jobs.data)} jobs")

        # Get job status
        print("\nGetting job status...")
        job = client.fine_tuning.jobs.retrieve(job.id)
        print(f"Job status: {job.status}")

        # Wait for job to complete or fail
        print("\nWaiting for job to complete...")
        while job.status in ["created", "running"]:
            time.sleep(5)
            job = client.fine_tuning.jobs.retrieve(job.id)
            print(f"Job status: {job.status}")
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        import traceback

        print("\nFull traceback:")
        print(traceback.format_exc())
    finally:
        # Clean up in `finally` so a failed or interrupted run does not leave
        # the training file behind in the working directory.
        try:
            os.remove(training_file)
        except FileNotFoundError:
            # The failure happened before the file was created.
            pass
        except OSError as cleanup_error:
            # Keep the function's report-don't-raise behavior for cleanup too.
            print(f"Error occurred: {str(cleanup_error)}")
if __name__ == "__main__":
    # Run each endpoint check in order, printing its banner first.
    for banner, check in (
        ("Testing chat completions endpoint...", test_chat_completion),
        ("\nTesting fine-tuning endpoints...", test_fine_tuning),
    ):
        print(banner)
        check()