Spaces:
Build error
Build error
| package github | |
| import ( | |
| "encoding/json" | |
| "net/http" | |
| "testing" | |
| "github.com/mark3labs/mcp-go/mcp" | |
| "github.com/stretchr/testify/assert" | |
| "github.com/stretchr/testify/require" | |
| ) | |
// expectations groups the request checks that expect wires into a partialMock.
// Fields left at their zero value are not checked.
type expectations struct {
	path        string            // expected r.URL.Path
	queryParams map[string]string // expected query parameters, keyed by name
	requestBody any               // expected JSON request body, in decoded form
}
| // expect is a helper function to create a partial mock that expects various | |
| // request behaviors, such as path, query parameters, and request body. | |
| func expect(t *testing.T, e expectations) *partialMock { | |
| return &partialMock{ | |
| t: t, | |
| expectedPath: e.path, | |
| expectedQueryParams: e.queryParams, | |
| expectedRequestBody: e.requestBody, | |
| } | |
| } | |
| // expectPath is a helper function to create a partial mock that expects a | |
| // request with the given path, with the ability to chain a response handler. | |
| func expectPath(t *testing.T, expectedPath string) *partialMock { | |
| return &partialMock{ | |
| t: t, | |
| expectedPath: expectedPath, | |
| } | |
| } | |
| // expectQueryParams is a helper function to create a partial mock that expects a | |
| // request with the given query parameters, with the ability to chain a response handler. | |
| func expectQueryParams(t *testing.T, expectedQueryParams map[string]string) *partialMock { | |
| return &partialMock{ | |
| t: t, | |
| expectedQueryParams: expectedQueryParams, | |
| } | |
| } | |
| // expectRequestBody is a helper function to create a partial mock that expects a | |
| // request with the given body, with the ability to chain a response handler. | |
| func expectRequestBody(t *testing.T, expectedRequestBody any) *partialMock { | |
| return &partialMock{ | |
| t: t, | |
| expectedRequestBody: expectedRequestBody, | |
| } | |
| } | |
// partialMock holds the request expectations verified by andThen before it
// delegates to a response handler. andThen skips any expectation left at its
// zero value.
type partialMock struct {
	t *testing.T // test that receives assertion failures

	expectedPath        string            // checked against r.URL.Path when non-empty
	expectedQueryParams map[string]string // checked against the parsed query when non-nil
	expectedRequestBody any               // checked against the decoded JSON body when non-nil
}
| func (p *partialMock) andThen(responseHandler http.HandlerFunc) http.HandlerFunc { | |
| p.t.Helper() | |
| return func(w http.ResponseWriter, r *http.Request) { | |
| if p.expectedPath != "" { | |
| require.Equal(p.t, p.expectedPath, r.URL.Path) | |
| } | |
| if p.expectedQueryParams != nil { | |
| require.Equal(p.t, len(p.expectedQueryParams), len(r.URL.Query())) | |
| for k, v := range p.expectedQueryParams { | |
| require.Equal(p.t, v, r.URL.Query().Get(k)) | |
| } | |
| } | |
| if p.expectedRequestBody != nil { | |
| var unmarshaledRequestBody any | |
| err := json.NewDecoder(r.Body).Decode(&unmarshaledRequestBody) | |
| require.NoError(p.t, err) | |
| require.Equal(p.t, p.expectedRequestBody, unmarshaledRequestBody) | |
| } | |
| responseHandler(w, r) | |
| } | |
| } | |
| // mockResponse is a helper function to create a mock HTTP response handler | |
| // that returns a specified status code and marshaled body. | |
| func mockResponse(t *testing.T, code int, body interface{}) http.HandlerFunc { | |
| t.Helper() | |
| return func(w http.ResponseWriter, _ *http.Request) { | |
| w.WriteHeader(code) | |
| // Some tests do not expect to return a JSON object, such as fetching a raw pull request diff, | |
| // so allow strings to be returned directly. | |
| s, ok := body.(string) | |
| if ok { | |
| _, _ = w.Write([]byte(s)) | |
| return | |
| } | |
| b, err := json.Marshal(body) | |
| require.NoError(t, err) | |
| _, _ = w.Write(b) | |
| } | |
| } | |
| // createMCPRequest is a helper function to create a MCP request with the given arguments. | |
| func createMCPRequest(args any) mcp.CallToolRequest { | |
| return mcp.CallToolRequest{ | |
| Params: struct { | |
| Name string `json:"name"` | |
| Arguments any `json:"arguments,omitempty"` | |
| Meta *mcp.Meta `json:"_meta,omitempty"` | |
| }{ | |
| Arguments: args, | |
| }, | |
| } | |
| } | |
| // getTextResult is a helper function that returns a text result from a tool call. | |
| func getTextResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent { | |
| t.Helper() | |
| assert.NotNil(t, result) | |
| require.Len(t, result.Content, 1) | |
| require.IsType(t, mcp.TextContent{}, result.Content[0]) | |
| textContent := result.Content[0].(mcp.TextContent) | |
| assert.Equal(t, "text", textContent.Type) | |
| return textContent | |
| } | |
| func getErrorResult(t *testing.T, result *mcp.CallToolResult) mcp.TextContent { | |
| res := getTextResult(t, result) | |
| require.True(t, result.IsError, "expected tool call result to be an error") | |
| return res | |
| } | |
| // getTextResourceResult is a helper function that returns a text result from a tool call. | |
| func getTextResourceResult(t *testing.T, result *mcp.CallToolResult) mcp.TextResourceContents { | |
| t.Helper() | |
| assert.NotNil(t, result) | |
| require.Len(t, result.Content, 2) | |
| content := result.Content[1] | |
| require.IsType(t, mcp.EmbeddedResource{}, content) | |
| resource := content.(mcp.EmbeddedResource) | |
| require.IsType(t, mcp.TextResourceContents{}, resource.Resource) | |
| return resource.Resource.(mcp.TextResourceContents) | |
| } | |
| // getBlobResourceResult is a helper function that returns a blob result from a tool call. | |
| func getBlobResourceResult(t *testing.T, result *mcp.CallToolResult) mcp.BlobResourceContents { | |
| t.Helper() | |
| assert.NotNil(t, result) | |
| require.Len(t, result.Content, 2) | |
| content := result.Content[1] | |
| require.IsType(t, mcp.EmbeddedResource{}, content) | |
| resource := content.(mcp.EmbeddedResource) | |
| require.IsType(t, mcp.BlobResourceContents{}, resource.Resource) | |
| return resource.Resource.(mcp.BlobResourceContents) | |
| } | |
| func TestOptionalParamOK(t *testing.T) { | |
| tests := []struct { | |
| name string | |
| args map[string]interface{} | |
| paramName string | |
| expectedVal interface{} | |
| expectedOk bool | |
| expectError bool | |
| errorMsg string | |
| }{ | |
| { | |
| name: "present and correct type (string)", | |
| args: map[string]interface{}{"myParam": "hello"}, | |
| paramName: "myParam", | |
| expectedVal: "hello", | |
| expectedOk: true, | |
| expectError: false, | |
| }, | |
| { | |
| name: "present and correct type (bool)", | |
| args: map[string]interface{}{"myParam": true}, | |
| paramName: "myParam", | |
| expectedVal: true, | |
| expectedOk: true, | |
| expectError: false, | |
| }, | |
| { | |
| name: "present and correct type (number)", | |
| args: map[string]interface{}{"myParam": float64(123)}, | |
| paramName: "myParam", | |
| expectedVal: float64(123), | |
| expectedOk: true, | |
| expectError: false, | |
| }, | |
| { | |
| name: "present but wrong type (string expected, got bool)", | |
| args: map[string]interface{}{"myParam": true}, | |
| paramName: "myParam", | |
| expectedVal: "", // Zero value for string | |
| expectedOk: true, // ok is true because param exists | |
| expectError: true, | |
| errorMsg: "parameter myParam is not of type string, is bool", | |
| }, | |
| { | |
| name: "present but wrong type (bool expected, got string)", | |
| args: map[string]interface{}{"myParam": "true"}, | |
| paramName: "myParam", | |
| expectedVal: false, // Zero value for bool | |
| expectedOk: true, // ok is true because param exists | |
| expectError: true, | |
| errorMsg: "parameter myParam is not of type bool, is string", | |
| }, | |
| { | |
| name: "parameter not present", | |
| args: map[string]interface{}{"anotherParam": "value"}, | |
| paramName: "myParam", | |
| expectedVal: "", // Zero value for string | |
| expectedOk: false, | |
| expectError: false, | |
| }, | |
| } | |
| for _, tc := range tests { | |
| t.Run(tc.name, func(t *testing.T) { | |
| request := createMCPRequest(tc.args) | |
| // Test with string type assertion | |
| if _, isString := tc.expectedVal.(string); isString || tc.errorMsg == "parameter myParam is not of type string, is bool" { | |
| val, ok, err := OptionalParamOK[string](request, tc.paramName) | |
| if tc.expectError { | |
| require.Error(t, err) | |
| assert.Contains(t, err.Error(), tc.errorMsg) | |
| assert.Equal(t, tc.expectedOk, ok) // Check ok even on error | |
| assert.Equal(t, tc.expectedVal, val) // Check zero value on error | |
| } else { | |
| require.NoError(t, err) | |
| assert.Equal(t, tc.expectedOk, ok) | |
| assert.Equal(t, tc.expectedVal, val) | |
| } | |
| } | |
| // Test with bool type assertion | |
| if _, isBool := tc.expectedVal.(bool); isBool || tc.errorMsg == "parameter myParam is not of type bool, is string" { | |
| val, ok, err := OptionalParamOK[bool](request, tc.paramName) | |
| if tc.expectError { | |
| require.Error(t, err) | |
| assert.Contains(t, err.Error(), tc.errorMsg) | |
| assert.Equal(t, tc.expectedOk, ok) // Check ok even on error | |
| assert.Equal(t, tc.expectedVal, val) // Check zero value on error | |
| } else { | |
| require.NoError(t, err) | |
| assert.Equal(t, tc.expectedOk, ok) | |
| assert.Equal(t, tc.expectedVal, val) | |
| } | |
| } | |
| // Test with float64 type assertion (for number case) | |
| if _, isFloat := tc.expectedVal.(float64); isFloat { | |
| val, ok, err := OptionalParamOK[float64](request, tc.paramName) | |
| if tc.expectError { | |
| // This case shouldn't happen for float64 in the defined tests | |
| require.Fail(t, "Unexpected error case for float64") | |
| } else { | |
| require.NoError(t, err) | |
| assert.Equal(t, tc.expectedOk, ok) | |
| assert.Equal(t, tc.expectedVal, val) | |
| } | |
| } | |
| }) | |
| } | |
| } | |