diff --git a/agent-schema.json b/agent-schema.json index e92890edb..e70a17de6 100644 --- a/agent-schema.json +++ b/agent-schema.json @@ -1781,7 +1781,7 @@ }, "timeout": { "type": "integer", - "description": "Timeout in seconds for the fetch tool", + "description": "HTTP timeout in seconds (valid for type 'fetch', 'api', and 'openapi'). Defaults to 30 seconds when omitted.", "minimum": 1 }, "allowed_domains": { @@ -1825,7 +1825,7 @@ }, "allow_private_ips": { "type": "boolean", - "description": "Opt in to dialling non-public IP addresses (valid for type 'fetch' and remote MCP toolsets). By default protected HTTP clients refuse connections \u2014 after DNS resolution, so DNS rebinding is also blocked \u2014 to loopback, RFC1918 private ranges, link-local (including the cloud metadata endpoint at 169.254.169.254), multicast and the unspecified address. Set this to true when an agent legitimately needs to call internal services. For fetch, 'allowed_domains' / 'blocked_domains' are evaluated independently and still apply." + "description": "Opt in to dialling non-public IP addresses (valid for type 'fetch', 'api', 'openapi', and remote MCP toolsets). By default protected HTTP clients refuse connections \u2014 after DNS resolution, so DNS rebinding is also blocked \u2014 to loopback, RFC1918 private ranges, link-local (including the cloud metadata endpoint at 169.254.169.254), multicast and the unspecified address. Set this to true when an agent legitimately needs to call internal services. For fetch, 'allowed_domains' / 'blocked_domains' are evaluated independently and still apply." }, "url": { "type": "string", diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 11b6a2bdb..104f0caa4 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -883,7 +883,8 @@ type Toolset struct { // For the `lsp` tool FileTypes []string `json:"file_types,omitempty"` - // For the `fetch` tool + // HTTP timeout in seconds for `fetch`, `api`, and `openapi` toolsets. + // Defaults to 30 seconds when omitted. Timeout int `json:"timeout,omitempty"` // For the `fetch` tool - allow-list of domains the tool is permitted to fetch. @@ -898,8 +899,8 @@ type Toolset struct { // `allowed_domains`. BlockedDomains []string `json:"blocked_domains,omitempty" yaml:"blocked_domains,omitempty"` - // For the `fetch` tool and remote `mcp` toolsets — opt in to dialling - // non-public IP addresses. + // For the `fetch`, `api`, `openapi` and remote `mcp` toolsets — opt in to + // dialling non-public IP addresses. // // By default, protected HTTP clients refuse connections (after DNS // resolution, so DNS rebinding is also blocked) to loopback (127/8, diff --git a/pkg/config/latest/validate.go b/pkg/config/latest/validate.go index c6170a4dc..d6ac582c1 100644 --- a/pkg/config/latest/validate.go +++ b/pkg/config/latest/validate.go @@ -145,8 +145,8 @@ func (t *Toolset) validate() error { if len(t.BlockedDomains) > 0 && t.Type != "fetch" { return errors.New("blocked_domains can only be used with type 'fetch'") } - if t.AllowPrivateIPsEnabled() && t.Type != "fetch" && t.Type != "mcp" { - return errors.New("allow_private_ips can only be used with type 'fetch' or remote MCP toolsets") + if t.AllowPrivateIPsEnabled() && t.Type != "fetch" && t.Type != "mcp" && t.Type != "api" && t.Type != "openapi" { + return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets") } if len(t.AllowedDomains) > 0 && len(t.BlockedDomains) > 0 { return errors.New("allowed_domains and blocked_domains are mutually exclusive") @@ -235,7 +235,7 @@ func (t *Toolset) validate() error { return errors.New("either command, remote or ref must be set, but only one of those") } if t.AllowPrivateIPsEnabled() && t.Remote.URL == "" && t.Ref == "" { - return errors.New("allow_private_ips can only be used with type 'fetch' or remote MCP toolsets") + return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets") } if t.Remote.OAuth != nil { if t.Remote.URL == "" { diff --git a/pkg/config/toolset_validate_test.go b/pkg/config/toolset_validate_test.go index 693f0a8e1..4dfd2c1e0 100644 --- a/pkg/config/toolset_validate_test.go +++ b/pkg/config/toolset_validate_test.go @@ -294,7 +294,7 @@ agents: - type: shell allow_private_ips: true `, - wantErr: "allow_private_ips can only be used with type 'fetch' or remote MCP toolsets", + wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets", }, { name: "allow_private_ips on fetch toolset is accepted", @@ -305,6 +305,34 @@ agents: toolsets: - type: fetch allow_private_ips: true +`, + }, + { + name: "allow_private_ips on api toolset is accepted", + config: ` +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: api + allow_private_ips: true + api_config: + name: probe + method: GET + endpoint: http://10.0.0.1/health + instruction: probe +`, + }, + { + name: "allow_private_ips on openapi toolset is accepted", + config: ` +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: openapi + url: http://10.0.0.1/openapi.json + allow_private_ips: true `, }, { @@ -332,7 +360,7 @@ agents: allow_private_ips: true command: docker `, - wantErr: "allow_private_ips can only be used with type 'fetch' or remote MCP toolsets", + wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets", }, { name: "empty allowed_domains entry is rejected", diff --git a/pkg/tools/builtin/api/api.go b/pkg/tools/builtin/api/api.go index 6032f084c..9792e6d5b 100644 --- a/pkg/tools/builtin/api/api.go +++ b/pkg/tools/builtin/api/api.go @@ -25,11 +25,12 @@ type ToolSet struct { config latest.APIToolConfig expander *js.Expander - // unsafe disables SSRF dial-time protection. Only set by the test-only - // constructor in api_test.go (httptest.NewServer binds to 127.0.0.1). - unsafe bool + timeout time.Duration + allowPrivateIPs bool } +const defaultHTTPTimeout = 30 * time.Second + // Verify interface compliance var ( _ tools.ToolSet = (*ToolSet)(nil) @@ -37,7 +38,7 @@ var ( ) func (t *ToolSet) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) { - client := httpclient.NewSafeClient(30*time.Second, t.unsafe) + client := httpclient.NewSafeClient(t.timeout, t.allowPrivateIPs) endpoint := t.config.Endpoint var reqBody io.Reader = http.NoBody @@ -100,14 +101,42 @@ func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *confi toolset.APIConfig.Endpoint = expander.Expand(ctx, toolset.APIConfig.Endpoint, nil) toolset.APIConfig.Headers = expander.ExpandMap(ctx, toolset.APIConfig.Headers) - return New(toolset.APIConfig, expander), nil + var opts []Option + if toolset.Timeout > 0 { + opts = append(opts, WithTimeout(time.Duration(toolset.Timeout)*time.Second)) + } + if toolset.AllowPrivateIPsEnabled() { + opts = append(opts, WithAllowPrivateIPs(true)) + } + return New(toolset.APIConfig, expander, opts...), nil } -func New(apiConfig latest.APIToolConfig, expander *js.Expander) *ToolSet { - return &ToolSet{ +// Option configures an api ToolSet. +type Option func(*ToolSet) + +// WithTimeout overrides the default 30s HTTP client timeout. +func WithTimeout(d time.Duration) Option { + return func(t *ToolSet) { t.timeout = d } +} + +// WithAllowPrivateIPs disables SSRF dial-time protection so the api tool +// may dial loopback / RFC1918 / link-local addresses. Operators opt in via +// `allow_private_ips: true` when the configured endpoint legitimately +// targets internal services. Tests use this to talk to httptest.NewServer. +func WithAllowPrivateIPs(allow bool) Option { + return func(t *ToolSet) { t.allowPrivateIPs = allow } +} + +func New(apiConfig latest.APIToolConfig, expander *js.Expander, opts ...Option) *ToolSet { + t := &ToolSet{ config: apiConfig, expander: expander, + timeout: defaultHTTPTimeout, + } + for _, opt := range opts { + opt(t) } + return t } func (t *ToolSet) Instructions() string { diff --git a/pkg/tools/builtin/api/api_test.go b/pkg/tools/builtin/api/api_test.go index 6fb4bb786..5d2ba77d8 100644 --- a/pkg/tools/builtin/api/api_test.go +++ b/pkg/tools/builtin/api/api_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -23,9 +24,7 @@ import ( // 127.0.0.1). It is defined in a *_test.go file so it is not compiled // into release binaries. Production callers must use [New]. func newAPIToolForTest(config latest.APIToolConfig, expander *js.Expander) *ToolSet { - t := New(config, expander) - t.unsafe = true - return t + return New(config, expander, WithAllowPrivateIPs(true)) } type testServer struct { @@ -235,6 +234,39 @@ func TestAPITool_RejectsLocalAddresses(t *testing.T) { } } +// TestAPITool_AllowPrivateIPsRestoresLegacyBehaviour verifies that the +// allow_private_ips opt-in actually disables the SSRF dial filter. +func TestAPITool_AllowPrivateIPsRestoresLegacyBehaviour(t *testing.T) { + t.Parallel() + ts := getTestServer(t) + + tool := New(latest.APIToolConfig{ + Method: http.MethodGet, + Endpoint: ts.serverURL, + }, testExpander(), WithAllowPrivateIPs(true)) + + _, err := tool.callTool(t.Context(), tools.ToolCall{}) + require.NoError(t, err, "WithAllowPrivateIPs(true) must permit dialling 127.0.0.1") +} + +// TestAPITool_TimeoutHonoured confirms WithTimeout caps the request. +func TestAPITool_TimeoutHonoured(t *testing.T) { + t.Parallel() + + slow := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + t.Cleanup(slow.Close) + + tool := New(latest.APIToolConfig{ + Method: http.MethodGet, + Endpoint: slow.URL, + }, testExpander(), WithAllowPrivateIPs(true), WithTimeout(50*time.Millisecond)) + + _, err := tool.callTool(t.Context(), tools.ToolCall{}) + require.Error(t, err) +} + type noopEnvProvider struct{} func (noopEnvProvider) Get(context.Context, string) (string, bool) { return "", false } diff --git a/pkg/tools/builtin/openapi/openapi.go b/pkg/tools/builtin/openapi/openapi.go index 26d4ed14d..10d9edb5f 100644 --- a/pkg/tools/builtin/openapi/openapi.go +++ b/pkg/tools/builtin/openapi/openapi.go @@ -27,7 +27,7 @@ import ( "github.com/docker/docker-agent/pkg/useragent" ) -const httpTimeout = 30 * time.Second +const defaultHTTPTimeout = 30 * time.Second // CreateToolSet is used by the tools registry. func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) { @@ -36,7 +36,14 @@ func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *confi specURL := expander.Expand(ctx, toolset.URL, nil) headers := expander.ExpandMap(ctx, toolset.Headers) - return New(specURL, headers), nil + var opts []Option + if toolset.Timeout > 0 { + opts = append(opts, WithTimeout(time.Duration(toolset.Timeout)*time.Second)) + } + if toolset.AllowPrivateIPsEnabled() { + opts = append(opts, WithAllowPrivateIPs(true)) + } + return New(specURL, headers, opts...), nil } // ToolSet generates HTTP tools from an OpenAPI specification. @@ -44,11 +51,8 @@ type ToolSet struct { specURL string headers map[string]string - // unsafe disables SSRF dial-time protection on both the spec fetch - // and the generated tools' HTTP calls. It is only set by the - // test-only constructor in openapi_test.go (which exists because - // tests use httptest.NewServer that binds to 127.0.0.1). - unsafe bool + timeout time.Duration + allowPrivateIPs bool } // Verify interface compliance. @@ -57,12 +61,34 @@ var ( _ tools.Instructable = (*ToolSet)(nil) ) +// Option configures an openapi ToolSet. +type Option func(*ToolSet) + +// WithTimeout overrides the default 30s HTTP client timeout used both for +// fetching the spec and for the generated tools' HTTP calls. +func WithTimeout(d time.Duration) Option { + return func(t *ToolSet) { t.timeout = d } +} + +// WithAllowPrivateIPs disables SSRF dial-time protection on both the spec +// fetch and the generated tools' HTTP calls. Operators opt in via +// `allow_private_ips: true` when the spec or its servers legitimately +// target internal services. Tests use this to talk to httptest.NewServer. +func WithAllowPrivateIPs(allow bool) Option { + return func(t *ToolSet) { t.allowPrivateIPs = allow } +} + // New creates a new OpenAPI toolset from the given spec URL. -func New(specURL string, headers map[string]string) *ToolSet { - return &ToolSet{ +func New(specURL string, headers map[string]string, opts ...Option) *ToolSet { + t := &ToolSet{ specURL: specURL, headers: headers, + timeout: defaultHTTPTimeout, } + for _, opt := range opts { + opt(t) + } + return t } // Instructions returns usage instructions for the OpenAPI toolset. @@ -93,7 +119,7 @@ func (t *ToolSet) fetchSpec(ctx context.Context) (*v3.Document, error) { req.Header.Set("Accept", "application/json") setHeaders(req, t.headers) - resp, err := httpclient.NewSafeClient(httpTimeout, t.unsafe).Do(req) + resp, err := httpclient.NewSafeClient(t.timeout, t.allowPrivateIPs).Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) } @@ -220,11 +246,12 @@ func (t *ToolSet) operationToTool(baseURL, path, method string, op *v3.Operation Description: desc, Parameters: schema, Handler: tools.NewHandler((&openAPIHandler{ - baseURL: baseURL, - path: path, - method: method, - headers: t.headers, - unsafe: t.unsafe, + baseURL: baseURL, + path: path, + method: method, + headers: t.headers, + timeout: t.timeout, + allowPrivateIPs: t.allowPrivateIPs, }).callTool), Annotations: tools.ToolAnnotations{ ReadOnlyHint: readOnly, @@ -412,8 +439,9 @@ type openAPIHandler struct { path string method string headers map[string]string - // unsafe disables SSRF dial-time protection. See OpenAPITool.unsafe. - unsafe bool + + timeout time.Duration + allowPrivateIPs bool } type openAPICallArgs map[string]any @@ -446,7 +474,7 @@ func (h *openAPIHandler) callTool(ctx context.Context, params openAPICallArgs) ( req.Header.Set("Accept", "application/json") setHeaders(req, h.headers) - resp, err := httpclient.NewSafeClient(httpTimeout, h.unsafe).Do(req) + resp, err := httpclient.NewSafeClient(h.timeout, h.allowPrivateIPs).Do(req) if err != nil { return nil, fmt.Errorf("request failed: %w", err) } diff --git a/pkg/tools/builtin/openapi/openapi_test.go b/pkg/tools/builtin/openapi/openapi_test.go index d56b02e8a..cabb19afa 100644 --- a/pkg/tools/builtin/openapi/openapi_test.go +++ b/pkg/tools/builtin/openapi/openapi_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -21,9 +22,7 @@ import ( // compiled into release binaries. Production callers must use // [New]. func newOpenAPIToolForTest(specURL string, headers map[string]string) *ToolSet { - t := New(specURL, headers) - t.unsafe = true - return t + return New(specURL, headers, WithAllowPrivateIPs(true)) } const petStoreSpec = `{ @@ -412,6 +411,33 @@ func TestOpenAPITool_RejectsLocalSpecURL(t *testing.T) { } } +// TestOpenAPITool_AllowPrivateIPsRestoresLegacyBehaviour verifies that +// WithAllowPrivateIPs(true) lets the spec fetch reach a loopback host. +func TestOpenAPITool_AllowPrivateIPsRestoresLegacyBehaviour(t *testing.T) { + t.Parallel() + + specServer := serveSpec(t, petStoreSpec) + + _, err := New(specServer.URL+"/openapi.json", nil, WithAllowPrivateIPs(true)).Tools(t.Context()) + require.NoError(t, err, "WithAllowPrivateIPs(true) must permit dialling 127.0.0.1") +} + +// TestOpenAPITool_TimeoutHonoured confirms WithTimeout caps the spec fetch. +func TestOpenAPITool_TimeoutHonoured(t *testing.T) { + t.Parallel() + + slow := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + })) + t.Cleanup(slow.Close) + + _, err := New(slow.URL+"/openapi.json", nil, + WithAllowPrivateIPs(true), + WithTimeout(50*time.Millisecond), + ).Tools(t.Context()) + require.Error(t, err) +} + func TestOpenAPITool_RejectsLocalSpecServerURL(t *testing.T) { // Even when the spec itself comes from a public URL, the malicious // `servers[].url` it advertises must not be silently dialled. We @@ -442,9 +468,10 @@ func TestOpenAPITool_RejectsLocalSpecServerURL(t *testing.T) { require.NoError(t, err) require.Len(t, toolsList, 1) - // Even though the spec was fetched in unsafe mode, the generated - // handler still inherits the unsafe flag — so for the real safety - // guarantee we re-run the operation through the production path. + // The test constructor opted into private IPs for the spec fetch, + // and that opt-in propagates to the generated handlers — so to + // validate the production guarantee we re-run the operation through + // a freshly-constructed production client (default-deny). prod := New(specServer.URL+"/openapi.json", nil) prodTools, err := prod.Tools(t.Context()) require.Error(t, err, "production constructor must refuse a loopback spec server")