|
|
|
|
|
""" |
|
|
Example client for interacting with the Mistral 7B AHB2APB API |
|
|
""" |
|
|
|
|
|
import requests |
|
|
import json |
|
|
|
|
|
|
|
|
# Base URL of the API server. The request helpers in this script build their
# endpoint URLs (/health, /api/generate, /api/generate/batch) from it.
API_BASE_URL = "http://localhost:8000"
|
|
|
|
|
def check_api_health(timeout: float = 10):
    """Check whether the API server is reachable and reports healthy.

    Sends ``GET {API_BASE_URL}/health`` and prints the ``model_path``,
    ``device`` and ``model_loaded`` fields of the JSON reply.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` would wait indefinitely on a server that
            accepts the connection but never answers.

    Returns:
        True if the request succeeded and the reply contained the expected
        fields; False otherwise (an explanatory message is printed).
    """
    try:
        response = requests.get(f"{API_BASE_URL}/health", timeout=timeout)
        response.raise_for_status()
        health = response.json()

        # Read every field before announcing success, so a malformed reply
        # is not reported as "healthy" right before failing.
        model_path = health['model_path']
        device = health['device']
        model_loaded = health['model_loaded']

        # NOTE(review): "β" looks like a mis-encoded status glyph (perhaps a
        # check/cross mark) — confirm the intended character and fix.
        print("β API is healthy!")
        print(f" Model: {model_path}")
        print(f" Device: {device}")
        print(f" Model loaded: {model_loaded}")
        return True

    except requests.exceptions.ConnectionError:
        print("β Cannot connect to API. Is the server running?")
        print(" Start it with: python api_server.py")
        return False

    except Exception as e:
        # Boundary of a bool-returning check: any other failure (timeout,
        # HTTP error status, bad JSON, missing field) is reported as unhealthy.
        print(f"β Error checking API health: {e}")
        return False
|
|
|
|
|
def generate(prompt: str, max_length: int = 512, temperature: float = 0.7,
             timeout: float = 120):
    """Generate text for one prompt via ``POST {API_BASE_URL}/api/generate``.

    Args:
        prompt: Input text sent as the ``prompt`` request field.
        max_length: Value sent as the ``max_length`` request field.
        temperature: Value sent as the ``temperature`` request field.
        timeout: Seconds to wait for the server (default 120).

    Returns:
        The ``response`` field of the JSON reply, or None if the request
        failed or the reply did not have the expected shape. An error
        message is printed in the failure cases.
    """
    payload = {
        "prompt": prompt,
        "max_length": max_length,
        "temperature": temperature,
    }

    try:
        response = requests.post(
            f"{API_BASE_URL}/api/generate",
            json=payload,
            timeout=timeout,
        )
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error calling API: {e}")
        # e.response is None for connection errors and timeouts; only an
        # HTTP error carries a server reply worth showing.
        if e.response is not None:
            print(f"Response: {e.response.text}")
        return None
    except ValueError as e:
        # Older requests versions raise a plain ValueError (json decode
        # error) for a non-JSON body instead of a RequestException.
        print(f"Error calling API: invalid JSON in response: {e}")
        return None

    try:
        return result['response']
    except (KeyError, TypeError):
        # Reply was JSON but not an object with a 'response' key.
        print(f"Unexpected API response: {result!r}")
        return None
|
|
|
|
|
def generate_batch(prompts: list, max_length: int = 512, temperature: float = 0.7,
                   timeout: float = 300):
    """Generate text for several prompts in one batch request.

    Sends ``POST {API_BASE_URL}/api/generate/batch`` with a JSON list of
    request objects, one per prompt, all sharing ``max_length`` and
    ``temperature``.

    Args:
        prompts: Prompts to generate for.
        max_length: Value sent as each request's ``max_length`` field.
        temperature: Value sent as each request's ``temperature`` field.
        timeout: Seconds to wait for the server (default 300).

    Returns:
        A list of the ``response`` fields from the reply's ``results`` list,
        in reply order; ``[]`` for an empty ``prompts`` list (no request is
        sent); or None if the request failed or the reply did not have the
        expected shape. An error message is printed in the failure cases.
    """
    if not prompts:
        # Nothing to generate; avoid a round-trip the server may reject.
        return []

    requests_data = [
        {
            "prompt": prompt,
            "max_length": max_length,
            "temperature": temperature,
        }
        for prompt in prompts
    ]

    try:
        response = requests.post(
            f"{API_BASE_URL}/api/generate/batch",
            json=requests_data,
            timeout=timeout,
        )
        response.raise_for_status()
        result = response.json()
    except requests.exceptions.RequestException as e:
        print(f"Error calling batch API: {e}")
        # e.response is None for connection errors and timeouts.
        if e.response is not None:
            print(f"Response: {e.response.text}")
        return None
    except ValueError as e:
        # Older requests versions raise a plain ValueError for a non-JSON body.
        print(f"Error calling batch API: invalid JSON in response: {e}")
        return None

    try:
        return [item['response'] for item in result['results']]
    except (KeyError, TypeError):
        # Reply was JSON but not shaped like {"results": [{"response": ...}]}.
        print(f"Unexpected batch API response: {result!r}")
        return None
|
|
|
|
|
def main():
    """Run the example: check server health, then generate for one prompt.

    Returns early (after the health check prints why) if the server is not
    healthy. Otherwise prints the prompt and the generated response, or a
    failure message if ``generate`` returned None.
    """
    print("=" * 70)
    print("Mistral 7B AHB2APB API Client Example")
    print("=" * 70)
    print()

    if not check_api_health():
        return

    print()
    print("=" * 70)
    print("Generating Response")
    print("=" * 70)
    print()

    prompt = "Convert this AHB burst to APB"

    print(f"Prompt: {prompt}")
    print()
    print("Response:")
    print("-" * 70)

    response = generate(prompt, max_length=512, temperature=0.7)

    # generate() signals failure with None; an empty string is a valid
    # (if unhelpful) generation and should not be reported as a failure.
    if response is not None:
        print(response)
        print("-" * 70)
    else:
        print("Failed to generate response")

    print()
|
|
|
|
|
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|
|