Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agent-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -1781,7 +1781,7 @@
},
"timeout": {
"type": "integer",
"description": "Timeout in seconds for the fetch tool",
"description": "HTTP timeout in seconds (valid for type 'fetch', 'api', and 'openapi'). Defaults to 30 seconds when omitted.",
"minimum": 1
},
"allowed_domains": {
Expand Down Expand Up @@ -1825,7 +1825,7 @@
},
"allow_private_ips": {
"type": "boolean",
"description": "Opt in to dialling non-public IP addresses (valid for type 'fetch' and remote MCP toolsets). By default protected HTTP clients refuse connections \u2014 after DNS resolution, so DNS rebinding is also blocked \u2014 to loopback, RFC1918 private ranges, link-local (including the cloud metadata endpoint at 169.254.169.254), multicast and the unspecified address. Set this to true when an agent legitimately needs to call internal services. For fetch, 'allowed_domains' / 'blocked_domains' are evaluated independently and still apply."
"description": "Opt in to dialling non-public IP addresses (valid for type 'fetch', 'api', 'openapi', and remote MCP toolsets). By default protected HTTP clients refuse connections \u2014 after DNS resolution, so DNS rebinding is also blocked \u2014 to loopback, RFC1918 private ranges, link-local (including the cloud metadata endpoint at 169.254.169.254), multicast and the unspecified address. Set this to true when an agent legitimately needs to call internal services. For fetch, 'allowed_domains' / 'blocked_domains' are evaluated independently and still apply."
},
"url": {
"type": "string",
Expand Down
7 changes: 4 additions & 3 deletions pkg/config/latest/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,8 @@ type Toolset struct {
// For the `lsp` tool
FileTypes []string `json:"file_types,omitempty"`

// For the `fetch` tool
// HTTP timeout in seconds for `fetch`, `api`, and `openapi` toolsets.
// Defaults to 30 seconds when omitted.
Timeout int `json:"timeout,omitempty"`

// For the `fetch` tool - allow-list of domains the tool is permitted to fetch.
Expand All @@ -898,8 +899,8 @@ type Toolset struct {
// `allowed_domains`.
BlockedDomains []string `json:"blocked_domains,omitempty" yaml:"blocked_domains,omitempty"`

// For the `fetch` tool and remote `mcp` toolsets — opt in to dialling
// non-public IP addresses.
// For the `fetch`, `api`, `openapi` and remote `mcp` toolsets — opt in to
// dialling non-public IP addresses.
//
// By default, protected HTTP clients refuse connections (after DNS
// resolution, so DNS rebinding is also blocked) to loopback (127/8,
Expand Down
6 changes: 3 additions & 3 deletions pkg/config/latest/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ func (t *Toolset) validate() error {
if len(t.BlockedDomains) > 0 && t.Type != "fetch" {
return errors.New("blocked_domains can only be used with type 'fetch'")
}
if t.AllowPrivateIPsEnabled() && t.Type != "fetch" && t.Type != "mcp" {
return errors.New("allow_private_ips can only be used with type 'fetch' or remote MCP toolsets")
if t.AllowPrivateIPsEnabled() && t.Type != "fetch" && t.Type != "mcp" && t.Type != "api" && t.Type != "openapi" {
return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets")
}
if len(t.AllowedDomains) > 0 && len(t.BlockedDomains) > 0 {
return errors.New("allowed_domains and blocked_domains are mutually exclusive")
Expand Down Expand Up @@ -235,7 +235,7 @@ func (t *Toolset) validate() error {
return errors.New("either command, remote or ref must be set, but only one of those")
}
if t.AllowPrivateIPsEnabled() && t.Remote.URL == "" && t.Ref == "" {
return errors.New("allow_private_ips can only be used with type 'fetch' or remote MCP toolsets")
return errors.New("allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets")
}
if t.Remote.OAuth != nil {
if t.Remote.URL == "" {
Expand Down
32 changes: 30 additions & 2 deletions pkg/config/toolset_validate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ agents:
- type: shell
allow_private_ips: true
`,
wantErr: "allow_private_ips can only be used with type 'fetch' or remote MCP toolsets",
wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets",
},
{
name: "allow_private_ips on fetch toolset is accepted",
Expand All @@ -305,6 +305,34 @@ agents:
toolsets:
- type: fetch
allow_private_ips: true
`,
},
{
name: "allow_private_ips on api toolset is accepted",
config: `
agents:
root:
model: "openai/gpt-4"
toolsets:
- type: api
allow_private_ips: true
api_config:
name: probe
method: GET
endpoint: http://10.0.0.1/health
instruction: probe
`,
},
{
name: "allow_private_ips on openapi toolset is accepted",
config: `
agents:
root:
model: "openai/gpt-4"
toolsets:
- type: openapi
url: http://10.0.0.1/openapi.json
allow_private_ips: true
`,
},
{
Expand Down Expand Up @@ -332,7 +360,7 @@ agents:
allow_private_ips: true
command: docker
`,
wantErr: "allow_private_ips can only be used with type 'fetch' or remote MCP toolsets",
wantErr: "allow_private_ips can only be used with type 'fetch', 'api', 'openapi' or remote MCP toolsets",
},
{
name: "empty allowed_domains entry is rejected",
Expand Down
43 changes: 36 additions & 7 deletions pkg/tools/builtin/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@ type ToolSet struct {
config latest.APIToolConfig
expander *js.Expander

// unsafe disables SSRF dial-time protection. Only set by the test-only
// constructor in api_test.go (httptest.NewServer binds to 127.0.0.1).
unsafe bool
timeout time.Duration
allowPrivateIPs bool
}

const defaultHTTPTimeout = 30 * time.Second

// Verify interface compliance
var (
_ tools.ToolSet = (*ToolSet)(nil)
_ tools.Instructable = (*ToolSet)(nil)
)

func (t *ToolSet) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools.ToolCallResult, error) {
client := httpclient.NewSafeClient(30*time.Second, t.unsafe)
client := httpclient.NewSafeClient(t.timeout, t.allowPrivateIPs)

endpoint := t.config.Endpoint
var reqBody io.Reader = http.NoBody
Expand Down Expand Up @@ -100,14 +101,42 @@ func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *confi
toolset.APIConfig.Endpoint = expander.Expand(ctx, toolset.APIConfig.Endpoint, nil)
toolset.APIConfig.Headers = expander.ExpandMap(ctx, toolset.APIConfig.Headers)

return New(toolset.APIConfig, expander), nil
var opts []Option
if toolset.Timeout > 0 {
opts = append(opts, WithTimeout(time.Duration(toolset.Timeout)*time.Second))
}
if toolset.AllowPrivateIPsEnabled() {
opts = append(opts, WithAllowPrivateIPs(true))
}
return New(toolset.APIConfig, expander, opts...), nil
}

func New(apiConfig latest.APIToolConfig, expander *js.Expander) *ToolSet {
return &ToolSet{
// Option configures an api ToolSet.
type Option func(*ToolSet)

// WithTimeout overrides the default 30s HTTP client timeout.
func WithTimeout(d time.Duration) Option {
return func(t *ToolSet) { t.timeout = d }
}

// WithAllowPrivateIPs disables SSRF dial-time protection so the api tool
// may dial loopback / RFC1918 / link-local addresses. Operators opt in via
// `allow_private_ips: true` when the configured endpoint legitimately
// targets internal services. Tests use this to talk to httptest.NewServer.
func WithAllowPrivateIPs(allow bool) Option {
return func(t *ToolSet) { t.allowPrivateIPs = allow }
}

func New(apiConfig latest.APIToolConfig, expander *js.Expander, opts ...Option) *ToolSet {
t := &ToolSet{
config: apiConfig,
expander: expander,
timeout: defaultHTTPTimeout,
}
for _, opt := range opts {
opt(t)
}
return t
}

func (t *ToolSet) Instructions() string {
Expand Down
38 changes: 35 additions & 3 deletions pkg/tools/builtin/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -23,9 +24,7 @@ import (
// 127.0.0.1). It is defined in a *_test.go file so it is not compiled
// into release binaries. Production callers must use [New].
func newAPIToolForTest(config latest.APIToolConfig, expander *js.Expander) *ToolSet {
t := New(config, expander)
t.unsafe = true
return t
return New(config, expander, WithAllowPrivateIPs(true))
}

type testServer struct {
Expand Down Expand Up @@ -235,6 +234,39 @@ func TestAPITool_RejectsLocalAddresses(t *testing.T) {
}
}

// TestAPITool_AllowPrivateIPsRestoresLegacyBehaviour verifies that the
// allow_private_ips opt-in actually disables the SSRF dial filter.
func TestAPITool_AllowPrivateIPsRestoresLegacyBehaviour(t *testing.T) {
t.Parallel()
ts := getTestServer(t)

tool := New(latest.APIToolConfig{
Method: http.MethodGet,
Endpoint: ts.serverURL,
}, testExpander(), WithAllowPrivateIPs(true))

_, err := tool.callTool(t.Context(), tools.ToolCall{})
require.NoError(t, err, "WithAllowPrivateIPs(true) must permit dialling 127.0.0.1")
}

// TestAPITool_TimeoutHonoured confirms WithTimeout caps the request.
func TestAPITool_TimeoutHonoured(t *testing.T) {
t.Parallel()

slow := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
<-r.Context().Done()
}))
t.Cleanup(slow.Close)

tool := New(latest.APIToolConfig{
Method: http.MethodGet,
Endpoint: slow.URL,
}, testExpander(), WithAllowPrivateIPs(true), WithTimeout(50*time.Millisecond))

_, err := tool.callTool(t.Context(), tools.ToolCall{})
require.Error(t, err)
}

type noopEnvProvider struct{}

func (noopEnvProvider) Get(context.Context, string) (string, bool) { return "", false }
Expand Down
64 changes: 46 additions & 18 deletions pkg/tools/builtin/openapi/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"github.com/docker/docker-agent/pkg/useragent"
)

const httpTimeout = 30 * time.Second
const defaultHTTPTimeout = 30 * time.Second

// CreateToolSet is used by the tools registry.
func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *config.RuntimeConfig) (tools.ToolSet, error) {
Expand All @@ -36,19 +36,23 @@ func CreateToolSet(ctx context.Context, toolset latest.Toolset, runConfig *confi
specURL := expander.Expand(ctx, toolset.URL, nil)
headers := expander.ExpandMap(ctx, toolset.Headers)

return New(specURL, headers), nil
var opts []Option
if toolset.Timeout > 0 {
opts = append(opts, WithTimeout(time.Duration(toolset.Timeout)*time.Second))
}
if toolset.AllowPrivateIPsEnabled() {
opts = append(opts, WithAllowPrivateIPs(true))
}
return New(specURL, headers, opts...), nil
}

// ToolSet generates HTTP tools from an OpenAPI specification.
type ToolSet struct {
specURL string
headers map[string]string

// unsafe disables SSRF dial-time protection on both the spec fetch
// and the generated tools' HTTP calls. It is only set by the
// test-only constructor in openapi_test.go (which exists because
// tests use httptest.NewServer that binds to 127.0.0.1).
unsafe bool
timeout time.Duration
allowPrivateIPs bool
}

// Verify interface compliance.
Expand All @@ -57,12 +61,34 @@ var (
_ tools.Instructable = (*ToolSet)(nil)
)

// Option configures an openapi ToolSet.
type Option func(*ToolSet)

// WithTimeout overrides the default 30s HTTP client timeout used both for
// fetching the spec and for the generated tools' HTTP calls.
func WithTimeout(d time.Duration) Option {
return func(t *ToolSet) { t.timeout = d }
}

// WithAllowPrivateIPs disables SSRF dial-time protection on both the spec
// fetch and the generated tools' HTTP calls. Operators opt in via
// `allow_private_ips: true` when the spec or its servers legitimately
// target internal services. Tests use this to talk to httptest.NewServer.
func WithAllowPrivateIPs(allow bool) Option {
return func(t *ToolSet) { t.allowPrivateIPs = allow }
}

// New creates a new OpenAPI toolset from the given spec URL.
func New(specURL string, headers map[string]string) *ToolSet {
return &ToolSet{
func New(specURL string, headers map[string]string, opts ...Option) *ToolSet {
t := &ToolSet{
specURL: specURL,
headers: headers,
timeout: defaultHTTPTimeout,
}
for _, opt := range opts {
opt(t)
}
return t
}

// Instructions returns usage instructions for the OpenAPI toolset.
Expand Down Expand Up @@ -93,7 +119,7 @@ func (t *ToolSet) fetchSpec(ctx context.Context) (*v3.Document, error) {
req.Header.Set("Accept", "application/json")
setHeaders(req, t.headers)

resp, err := httpclient.NewSafeClient(httpTimeout, t.unsafe).Do(req)
resp, err := httpclient.NewSafeClient(t.timeout, t.allowPrivateIPs).Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
Expand Down Expand Up @@ -220,11 +246,12 @@ func (t *ToolSet) operationToTool(baseURL, path, method string, op *v3.Operation
Description: desc,
Parameters: schema,
Handler: tools.NewHandler((&openAPIHandler{
baseURL: baseURL,
path: path,
method: method,
headers: t.headers,
unsafe: t.unsafe,
baseURL: baseURL,
path: path,
method: method,
headers: t.headers,
timeout: t.timeout,
allowPrivateIPs: t.allowPrivateIPs,
}).callTool),
Annotations: tools.ToolAnnotations{
ReadOnlyHint: readOnly,
Expand Down Expand Up @@ -412,8 +439,9 @@ type openAPIHandler struct {
path string
method string
headers map[string]string
// unsafe disables SSRF dial-time protection. See OpenAPITool.unsafe.
unsafe bool

timeout time.Duration
allowPrivateIPs bool
}

type openAPICallArgs map[string]any
Expand Down Expand Up @@ -446,7 +474,7 @@ func (h *openAPIHandler) callTool(ctx context.Context, params openAPICallArgs) (
req.Header.Set("Accept", "application/json")
setHeaders(req, h.headers)

resp, err := httpclient.NewSafeClient(httpTimeout, h.unsafe).Do(req)
resp, err := httpclient.NewSafeClient(h.timeout, h.allowPrivateIPs).Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
Expand Down
Loading
Loading