| import json |
| import argparse |
| import sys |
| from handler import EndpointHandler |
|
|
def test_inference(model_path=".", prompt=None, max_tokens=150, temperature=0.7):
    """
    Test the inference endpoint handler with a sample request.

    Args:
        model_path: Path to the model directory.
        prompt: Custom user prompt; when None, a default question is used.
        max_tokens: Maximum number of tokens to generate.
        temperature: Temperature for generation.

    Returns:
        The handler's response (echoed to stdout as JSON) on success, or a
        dict of the form {"error": <message>} if initialization or
        inference raised.
    """
    try:
        print(f"Initializing handler with model path: {model_path}")
        handler = EndpointHandler(model_path)

        # Single construction path: only the user message content varies,
        # so select it first instead of duplicating the whole list.
        user_content = prompt if prompt is not None else "Explain quantum computing in simple terms."
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]

        request = {
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": 0.95,
        }

        print("Sending request to handler...")
        print(f"Request: {json.dumps(request, indent=2)}")

        response = handler(request)

        print("\nResponse:")
        print(json.dumps(response, indent=2))

        return response

    except Exception as e:
        # Top-level boundary of a CLI test script: report the failure with
        # a full traceback on stderr and return a structured error rather
        # than crashing the process.
        print(f"Error during inference: {str(e)}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return {"error": str(e)}
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="Test Phi-4 Mini inference") |
| parser.add_argument("--model_path", type=str, default=".", help="Path to the model directory") |
| parser.add_argument("--prompt", type=str, help="Custom prompt to use") |
| parser.add_argument("--max_tokens", type=int, default=150, help="Maximum number of tokens to generate") |
| parser.add_argument("--temperature", type=float, default=0.7, help="Temperature for generation") |
| |
| args = parser.parse_args() |
| test_inference( |
| model_path=args.model_path, |
| prompt=args.prompt, |
| max_tokens=args.max_tokens, |
| temperature=args.temperature |
| ) |