# File: evolutiontransformer/tests/test_hf_api.py
# Last commit by tcmmichaelb139 (66feb21): list_model tests.
import pytest
import requests
import time
import re
def get_final_answer(text: str) -> int | None:
numbers = re.findall(r"\d+", text)
return int(numbers[-1]) if numbers else None
# Root URL of the remote service under test; the endpoint paths requested
# below are appended to it.
BASE_URL = "https://tcmmichaelb139-evolutiontransformer.hf.space"
@pytest.fixture
def session():
    """Provide a ``requests.Session`` to a test and close it afterwards.

    Yielding from inside the ``with`` block means the session is closed
    during fixture teardown, whether the test passed or failed, instead of
    being left open for the remainder of the run.

    Yields:
        requests.Session: A fresh session, one per test function.
    """
    with requests.Session() as http_session:
        yield http_session
def await_task_completion(session, task_id, timeout=60):
    """Poll ``/tasks/{task_id}`` until the task succeeds, errors, or times out.

    Args:
        session: Object whose ``get(url)`` returns a response exposing
            ``status_code`` and ``json()``.
        task_id: Identifier interpolated into the polled URL.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        The ``"result"`` entry of the status payload once its ``"status"``
        equals ``"SUCCESS"``, or ``{"error": <detail>}`` when the endpoint
        answers with HTTP 500 (``<detail>`` falls back to
        ``"Unknown error"`` if the payload has no ``"detail"`` key).

    Raises:
        AssertionError: If a poll answers with a status code other than
            200 or 500.
        The exception raised by ``pytest.fail``: If the task has not
            reported ``"SUCCESS"`` once ``timeout`` seconds have elapsed.
    """
    # Monotonic clock: the elapsed-time check cannot be skewed by wall-clock
    # adjustments made while the test is running.
    started_at = time.monotonic()
    while time.monotonic() - started_at < timeout:
        # NOTE(review): no per-request timeout is passed here, so a single
        # stalled request can outlast ``timeout`` — confirm whether that is
        # acceptable.
        status_response = session.get(f"{BASE_URL}/tasks/{task_id}")

        # Decode the body once and reuse it for logging and every check.
        status_data = status_response.json()
        print(status_data)

        if status_response.status_code == 500:
            return {"error": status_data.get("detail", "Unknown error")}
        assert status_response.status_code == 200

        if status_data["status"] == "SUCCESS":
            return status_data["result"]

        # Pause between polls.
        time.sleep(2)

    # The loop only ends by running out of time; pytest.fail raises, so
    # nothing after it would ever execute.
    pytest.fail(
        f"Task {task_id} did not complete within the {timeout}-second timeout."
    )
def test_generate_endpoint_svamp(session):
    """Submit one generation request for the ``svamp`` model and check the answer.

    The prompt asks for 8 + 6; the last number in the generated text must be 14.
    """
    payload = {
        "model_name": "svamp",
        "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
        "max_new_tokens": 50,
        "temperature": 0.7,
    }
    submit = session.post(f"{BASE_URL}/generate", json=payload)
    assert submit.status_code == 200

    # The endpoint replies with a task handle rather than the generated text.
    submitted = submit.json()
    assert "task_id" in submitted

    outcome = await_task_completion(session, submitted["task_id"])
    assert "response" in outcome

    assert get_final_answer(outcome["response"]) == 14
def test_merge_then_inference_svamp_1(session):
    """Merge ``svamp`` with ``tinystories``, then run inference on the result.

    The recipe has 24 single-entry layers of the form ``(i, 0, 1.0)``.
    NOTE(review): the triple is presumably (layer index, source model index,
    weight) — confirm against the service's merge schema.
    """
    merge_response = session.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 0, 1.0)] for i in range(24)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )
    assert merge_response.status_code == 200
    merge_data = merge_response.json()
    assert "task_id" in merge_data

    merge_result = await_task_completion(session, merge_data["task_id"])
    # Check the payload shape before indexing so that a failed merge is
    # reported with its payload instead of a bare KeyError.
    assert "response" in merge_result
    model_name = merge_result["response"]

    # NOTE(review): presumably gives the service time to make the merged
    # model available before it is queried — confirm.
    time.sleep(5)

    generate_response = session.post(
        f"{BASE_URL}/generate",
        json={
            "model_name": model_name,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )
    assert generate_response.status_code == 200
    generate_data = generate_response.json()
    assert "task_id" in generate_data

    final_result = await_task_completion(session, generate_data["task_id"])
    assert "response" in final_result

    # 8 + 6: the last number in the generated text must be 14.
    answer = get_final_answer(final_result["response"])
    assert answer == 14
def test_merge_then_inference_svamp_2(session):
    """Merge twice in a row, then run inference on the second merged model.

    The first merge uses a 48-entry recipe: indices 0-23 with weight 1.0
    followed by indices 0-23 again with weight 0.5. Its output is then merged
    with ``tinystories`` again using a 24-entry recipe and all-zero lambdas.
    NOTE(review): each recipe triple is presumably (layer index, source model
    index, weight) — confirm against the service's merge schema.
    """
    merge_response = session.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i % 24, 0, 1.0 if i < 24 else 0.5)] for i in range(48)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )
    assert merge_response.status_code == 200
    merge_data = merge_response.json()
    assert "task_id" in merge_data

    merge_result = await_task_completion(session, merge_data["task_id"])
    # Check the payload shape before indexing so that a failed merge is
    # reported with its payload instead of a bare KeyError.
    assert "response" in merge_result
    model_name = merge_result["response"]

    # Second merge: the first merge's output takes the model1 slot.
    merge_response2 = session.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": model_name,
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 1, 0.25)] for i in range(24)],
            "embedding_lambdas": [0.0, 0.0],
            "linear_lambdas": [0.0, 0.0],
            "merged_name": "svamp_merged",
        },
    )
    assert merge_response2.status_code == 200
    merge_data2 = merge_response2.json()
    assert "task_id" in merge_data2

    merge_result2 = await_task_completion(session, merge_data2["task_id"])
    assert "response" in merge_result2
    model_name2 = merge_result2["response"]

    # NOTE(review): presumably gives the service time to make the merged
    # model available before it is queried — confirm.
    time.sleep(5)

    generate_response = session.post(
        f"{BASE_URL}/generate",
        json={
            "model_name": model_name2,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )
    assert generate_response.status_code == 200
    generate_data = generate_response.json()
    assert "task_id" in generate_data

    final_result = await_task_completion(session, generate_data["task_id"])
    assert "response" in final_result

    # 8 + 6: the last number in the generated text must be 14.
    answer = get_final_answer(final_result["response"])
    assert answer == 14
def test_merge_two_children_then_merge(session):
    """Create two merged children, merge them together, then run inference.

    Steps:
        1. Merge ``svamp`` + ``tinystories`` into ``child1``.
        2. Merge ``svamp`` + ``tinystories`` into ``child2`` with a different
           recipe and lambdas.
        3. Merge the two children into ``final_merged``.
        4. Ask ``/list_models`` for the model list and expect 5 entries.
        5. Generate with the final model and expect the answer 14.

    NOTE(review): each recipe triple is presumably (layer index, source model
    index, weight) — confirm against the service's merge schema.
    """
    # --- child 1 ---------------------------------------------------------
    merge_response1 = session.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 0, 0.8)] for i in range(12)]
            + [[(i, 1, 0.6)] for i in range(12)],
            "embedding_lambdas": [0.7, 0.3],
            "linear_lambdas": [0.8, 0.2],
            "merged_name": "child1",
        },
    )
    assert merge_response1.status_code == 200
    merge_data1 = merge_response1.json()
    assert "task_id" in merge_data1

    merge_result1 = await_task_completion(session, merge_data1["task_id"])
    # Check the payload shape before indexing so that a failed merge is
    # reported with its payload instead of a bare KeyError.
    assert "response" in merge_result1
    child1_name = merge_result1["response"]

    # --- child 2 ---------------------------------------------------------
    merge_response2 = session.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 1, 0.9)] for i in range(8)]
            + [[(i, 0, 0.4)] for i in range(16)],
            "embedding_lambdas": [0.2, 0.9],
            "linear_lambdas": [0.3, 0.7],
            "merged_name": "child2",
        },
    )
    assert merge_response2.status_code == 200
    merge_data2 = merge_response2.json()
    assert "task_id" in merge_data2

    merge_result2 = await_task_completion(session, merge_data2["task_id"])
    assert "response" in merge_result2
    child2_name = merge_result2["response"]

    # --- merge the two children ------------------------------------------
    merge_response3 = session.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": child1_name,
            "model2_name": child2_name,
            "layer_recipe": [[(i, 0, 0.6), (i, 1, 0.4)] for i in range(24)],
            "embedding_lambdas": [0.5, 0.5],
            "linear_lambdas": [0.6, 0.4],
            "merged_name": "final_merged",
        },
    )
    assert merge_response3.status_code == 200
    merge_data3 = merge_response3.json()
    assert "task_id" in merge_data3

    merge_result3 = await_task_completion(session, merge_data3["task_id"])
    assert "response" in merge_result3
    final_model_name = merge_result3["response"]

    # NOTE(review): presumably gives the service time to make the merged
    # models available before they are listed and queried — confirm.
    time.sleep(5)

    # --- model listing ---------------------------------------------------
    number_of_models = session.post(f"{BASE_URL}/list_models")
    assert number_of_models.status_code == 200
    number_of_models_data = number_of_models.json()
    assert "task_id" in number_of_models_data

    number_of_models_result = await_task_completion(
        session, number_of_models_data["task_id"]
    )
    assert "response" in number_of_models_result
    models = number_of_models_result["response"]
    print(models)
    # NOTE(review): presumably the two base models plus the three merges
    # created above — confirm how the service counts models.
    assert len(models) == 5

    # --- inference on the final model ------------------------------------
    generate_response = session.post(
        f"{BASE_URL}/generate",
        json={
            "model_name": final_model_name,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )
    assert generate_response.status_code == 200
    generate_data = generate_response.json()
    assert "task_id" in generate_data

    final_result = await_task_completion(session, generate_data["task_id"])
    assert "response" in final_result

    # 8 + 6: the last number in the generated text must be 14.
    answer = get_final_answer(final_result["response"])
    assert answer == 14
def test_merge_fail(session):
    """Submit a 50-entry layer recipe and expect the merge task to be rejected.

    The request itself must be accepted (HTTP 200 with a task id); the failure
    is expected in the task outcome, which must carry an ``"error"`` containing
    "Layer recipe too long" and no ``"response"``.
    """
    oversized_recipe = [[(layer, 0, 1.0)] for layer in range(50)]
    payload = {
        "model1_name": "svamp",
        "model2_name": "tinystories",
        "layer_recipe": oversized_recipe,
        "embedding_lambdas": [1.0, 1.0],
        "linear_lambdas": [1.0, 1.0],
        "merged_name": "svamp_merged",
    }
    submit = session.post(f"{BASE_URL}/merge", json=payload)
    assert submit.status_code == 200

    submitted = submit.json()
    assert "task_id" in submitted

    outcome = await_task_completion(session, submitted["task_id"])
    assert "response" not in outcome
    assert "error" in outcome
    assert "Layer recipe too long" in outcome["error"]