diff --git a/src/agents/bedrock-discovery.test.ts b/src/agents/bedrock-discovery.test.ts
index a8fc1b2e9..1477dfca0 100644
--- a/src/agents/bedrock-discovery.test.ts
+++ b/src/agents/bedrock-discovery.test.ts
@@ -14,6 +14,7 @@ describe("bedrock discovery", () => {
       await import("./bedrock-discovery.js");
     resetBedrockDiscoveryCacheForTest();
 
+    // Mock ListFoundationModelsCommand response
     sendMock.mockResolvedValueOnce({
       modelSummaries: [
         {
@@ -54,6 +55,8 @@ describe("bedrock discovery", () => {
         },
       ],
     });
+    // Mock ListInferenceProfilesCommand response (empty)
+    sendMock.mockResolvedValueOnce({ inferenceProfileSummaries: [] });
 
     const models = await discoverBedrockModels({ region: "us-east-1", clientFactory });
     expect(models).toHaveLength(1);
@@ -85,6 +88,7 @@ describe("bedrock discovery", () => {
         },
       ],
     });
+    sendMock.mockResolvedValueOnce({ inferenceProfileSummaries: [] });
 
     const models = await discoverBedrockModels({
       region: "us-east-1",
@@ -112,6 +116,7 @@ describe("bedrock discovery", () => {
         },
       ],
     });
+    sendMock.mockResolvedValueOnce({ inferenceProfileSummaries: [] });
 
     const models = await discoverBedrockModels({
       region: "us-east-1",
@@ -139,10 +144,12 @@ describe("bedrock discovery", () => {
         },
       ],
     });
+    sendMock.mockResolvedValueOnce({ inferenceProfileSummaries: [] });
 
     await discoverBedrockModels({ region: "us-east-1", clientFactory });
     await discoverBedrockModels({ region: "us-east-1", clientFactory });
-    expect(sendMock).toHaveBeenCalledTimes(1);
+    // 2 calls for first discovery (foundation models + inference profiles), 0 for second (cached)
+    expect(sendMock).toHaveBeenCalledTimes(2);
   });
 
   it("skips cache when refreshInterval is 0", async () => {
@@ -150,6 +157,7 @@ describe("bedrock discovery", () => {
       await import("./bedrock-discovery.js");
     resetBedrockDiscoveryCacheForTest();
 
+    // First call - foundation models + inference profiles
     sendMock
       .mockResolvedValueOnce({
         modelSummaries: [
@@ -164,6 +172,8 @@ describe("bedrock discovery", () => {
           },
         ],
       })
+      .mockResolvedValueOnce({ inferenceProfileSummaries: [] })
+      // Second call - foundation models + inference profiles
       .mockResolvedValueOnce({
         modelSummaries: [
           {
@@ -176,7 +186,8 @@ describe("bedrock discovery", () => {
             modelLifecycle: { status: "ACTIVE" },
           },
         ],
-      });
+      })
+      .mockResolvedValueOnce({ inferenceProfileSummaries: [] });
 
     await discoverBedrockModels({
       region: "us-east-1",
@@ -188,6 +199,195 @@ describe("bedrock discovery", () => {
       config: { refreshInterval: 0 },
       clientFactory,
     });
-    expect(sendMock).toHaveBeenCalledTimes(2);
+    // 2 calls per discovery * 2 discoveries = 4 calls
+    expect(sendMock).toHaveBeenCalledTimes(4);
+  });
+
+  it("discovers inference profiles with CRIS prefixes", async () => {
+    const { discoverBedrockModels, resetBedrockDiscoveryCacheForTest } =
+      await import("./bedrock-discovery.js");
+    resetBedrockDiscoveryCacheForTest();
+
+    sendMock.mockResolvedValueOnce({
+      modelSummaries: [
+        {
+          modelId: "anthropic.claude-3-7-sonnet-20250219-v1:0",
+          modelName: "Claude 3.7 Sonnet",
+          providerName: "anthropic",
+          inputModalities: ["TEXT", "IMAGE"],
+          outputModalities: ["TEXT"],
+          responseStreamingSupported: true,
+          modelLifecycle: { status: "ACTIVE" },
+        },
+      ],
+    });
+    sendMock.mockResolvedValueOnce({
+      inferenceProfileSummaries: [
+        {
+          inferenceProfileId: "global.anthropic.claude-opus-4-5-20251101-v1:0",
+          inferenceProfileName: "GLOBAL Anthropic Claude Opus 4.5",
+          status: "ACTIVE",
+          models: [
+            {
+              modelArn:
+                "arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-opus-4-5-20251101-v1:0",
+            },
+          ],
+        },
+        {
+          inferenceProfileId: "eu.anthropic.claude-3-sonnet-20240229-v1:0",
+          inferenceProfileName: "EU Anthropic Claude 3 Sonnet",
+          status: "ACTIVE",
+          models: [
+            {
+              modelArn:
+                "arn:aws:bedrock:eu-west-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
+            },
+          ],
+        },
+        {
+          inferenceProfileId: "us.amazon.nova-pro-v1:0",
+          inferenceProfileName: "US Amazon Nova Pro",
+          status: "INACTIVE",
+          models: [],
+        },
+      ],
+    });
+
+    const models = await discoverBedrockModels({ region: "us-east-1", clientFactory });
+
+    // Should include 1 foundation model + 2 active inference profiles
+    expect(models).toHaveLength(3);
+
+    const foundationModel = models.find(
+      (m) => m.id === "anthropic.claude-3-7-sonnet-20250219-v1:0",
+    );
+    expect(foundationModel).toBeDefined();
+
+    const globalProfile = models.find(
+      (m) => m.id === "global.anthropic.claude-opus-4-5-20251101-v1:0",
+    );
+    expect(globalProfile).toBeDefined();
+    expect(globalProfile?.name).toBe("GLOBAL Anthropic Claude Opus 4.5");
+
+    const euProfile = models.find((m) => m.id === "eu.anthropic.claude-3-sonnet-20240229-v1:0");
+    expect(euProfile).toBeDefined();
+    expect(euProfile?.name).toBe("EU Anthropic Claude 3 Sonnet");
+
+    // Inactive profile should not be included
+    const usProfile = models.find((m) => m.id === "us.amazon.nova-pro-v1:0");
+    expect(usProfile).toBeUndefined();
+  });
+
+  it("inherits capabilities from foundation model for inference profiles", async () => {
+    const { discoverBedrockModels, resetBedrockDiscoveryCacheForTest } =
+      await import("./bedrock-discovery.js");
+    resetBedrockDiscoveryCacheForTest();
+
+    sendMock.mockResolvedValueOnce({
+      modelSummaries: [
+        {
+          modelId: "anthropic.claude-3-sonnet-20240229-v1:0",
+          modelName: "Claude 3 Sonnet",
+          providerName: "anthropic",
+          inputModalities: ["TEXT", "IMAGE"],
+          outputModalities: ["TEXT"],
+          responseStreamingSupported: true,
+          modelLifecycle: { status: "ACTIVE" },
+        },
+      ],
+    });
+    sendMock.mockResolvedValueOnce({
+      inferenceProfileSummaries: [
+        {
+          inferenceProfileId: "eu.anthropic.claude-3-sonnet-20240229-v1:0",
+          inferenceProfileName: "EU Anthropic Claude 3 Sonnet",
+          status: "ACTIVE",
+          models: [
+            {
+              modelArn:
+                "arn:aws:bedrock:eu-west-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0",
+            },
+          ],
+        },
+      ],
+    });
+
+    const models = await discoverBedrockModels({ region: "us-east-1", clientFactory });
+
+    const euProfile = models.find((m) => m.id === "eu.anthropic.claude-3-sonnet-20240229-v1:0");
+    expect(euProfile).toBeDefined();
+    // Should inherit text+image input from the foundation model
+    expect(euProfile?.input).toEqual(["text", "image"]);
+  });
+
+  it("applies provider filter to inference profiles", async () => {
+    const { discoverBedrockModels, resetBedrockDiscoveryCacheForTest } =
+      await import("./bedrock-discovery.js");
+    resetBedrockDiscoveryCacheForTest();
+
+    sendMock.mockResolvedValueOnce({ modelSummaries: [] });
+    sendMock.mockResolvedValueOnce({
+      inferenceProfileSummaries: [
+        {
+          inferenceProfileId: "global.anthropic.claude-opus-4-5-20251101-v1:0",
+          inferenceProfileName: "GLOBAL Anthropic Claude Opus 4.5",
+          status: "ACTIVE",
+          models: [],
+        },
+        {
+          inferenceProfileId: "global.amazon.nova-pro-v1:0",
+          inferenceProfileName: "GLOBAL Amazon Nova Pro",
+          status: "ACTIVE",
+          models: [],
+        },
+      ],
+    });
+
+    const models = await discoverBedrockModels({
+      region: "us-east-1",
+      config: { providerFilter: ["anthropic"] },
+      clientFactory,
+    });
+
+    expect(models).toHaveLength(1);
+    expect(models[0].id).toBe("global.anthropic.claude-opus-4-5-20251101-v1:0");
+  });
+
+  it("follows nextToken when listing inference profiles", async () => {
+    const { discoverBedrockModels, resetBedrockDiscoveryCacheForTest } =
+      await import("./bedrock-discovery.js");
+    resetBedrockDiscoveryCacheForTest();
+
+    sendMock.mockResolvedValueOnce({ modelSummaries: [] });
+    sendMock.mockResolvedValueOnce({
+      inferenceProfileSummaries: [
+        {
+          inferenceProfileId: "us.anthropic.claude-3-sonnet-20240229-v1:0",
+          inferenceProfileName: "US Anthropic Claude 3 Sonnet",
+          status: "ACTIVE",
+          models: [],
+        },
+      ],
+      nextToken: "page-2",
+    });
+    sendMock.mockResolvedValueOnce({
+      inferenceProfileSummaries: [
+        {
+          inferenceProfileId: "eu.anthropic.claude-3-sonnet-20240229-v1:0",
+          inferenceProfileName: "EU Anthropic Claude 3 Sonnet",
+          status: "ACTIVE",
+          models: [],
+        },
+      ],
+    });
+
+    const models = await discoverBedrockModels({ region: "us-east-1", clientFactory });
+
+    expect(models.map((m) => m.id)).toEqual([
+      "eu.anthropic.claude-3-sonnet-20240229-v1:0",
+      "us.anthropic.claude-3-sonnet-20240229-v1:0",
+    ]);
+    expect(sendMock).toHaveBeenCalledTimes(3);
   });
 });
diff --git a/src/agents/bedrock-discovery.ts b/src/agents/bedrock-discovery.ts
index 3b42d0081..0b6c850c2 100644
--- a/src/agents/bedrock-discovery.ts
+++ b/src/agents/bedrock-discovery.ts
@@ -1,7 +1,9 @@
 import {
   BedrockClient,
   ListFoundationModelsCommand,
+  ListInferenceProfilesCommand,
   type ListFoundationModelsCommandOutput,
+  type ListInferenceProfilesCommandOutput,
 } from "@aws-sdk/client-bedrock";
 
 import type { BedrockDiscoveryConfig, ModelDefinitionConfig } from "../config/types.js";
@@ -17,6 +19,9 @@ const DEFAULT_COST = {
 };
 
 type BedrockModelSummary = NonNullable<ListFoundationModelsCommandOutput["modelSummaries"]>[number];
+type BedrockInferenceProfileSummary = NonNullable<
+  ListInferenceProfilesCommandOutput["inferenceProfileSummaries"]
+>[number];
 
 type BedrockDiscoveryCacheEntry = {
   expiresAt: number;
@@ -116,6 +121,115 @@ function toModelDefinition(
   };
 }
 
+/**
+ * Extracts the foundation model ID from a foundation-model ARN.
+ * Returns undefined when the ARN has no `foundation-model/` segment.
+ */
+function extractBaseModelIdFromArn(arn: string): string | undefined {
+  // ARN format: arn:aws:bedrock:region::foundation-model/model-id
+  const match = /foundation-model\/(.+)$/.exec(arn);
+  return match?.[1];
+}
+
+/** Returns true only when the profile reports an ACTIVE status (compared case-insensitively). */
+function isActiveInferenceProfile(summary: BedrockInferenceProfileSummary): boolean {
+  const status = summary.status;
+  return typeof status === "string" ? status.toUpperCase() === "ACTIVE" : false;
+}
+
+/**
+ * Applies the provider filter to an inference profile. An empty filter matches everything;
+ * otherwise the second dot-separated segment of the profile ID is compared against the filter.
+ */
+function matchesInferenceProfileProviderFilter(
+  summary: BedrockInferenceProfileSummary,
+  filter: string[],
+): boolean {
+  if (filter.length === 0) return true;
+  // Extract provider from inference profile ID (e.g., "global.anthropic.claude-..." -> "anthropic")
+  const profileId = summary.inferenceProfileId ?? "";
+  const parts = profileId.split(".");
+  // Format is: prefix.provider.model (e.g., global.anthropic.claude-3-sonnet...)
+  const providerName = parts.length >= 2 ? parts[1] : undefined;
+  const normalized = providerName?.trim().toLowerCase();
+  if (!normalized) return false;
+  return filter.includes(normalized);
+}
+
+/**
+ * Derives input modalities and reasoning support for an inference profile. The first
+ * underlying model found in `foundationModels` wins; otherwise the result is guessed
+ * from keywords in the profile ID and name.
+ */
+function inferInferenceProfileCapabilities(
+  summary: BedrockInferenceProfileSummary,
+  foundationModels: Map<string, BedrockModelSummary>,
+): { input: Array<"text" | "image">; reasoning: boolean } {
+  // Try to get capabilities from the first underlying foundation model
+  const modelArns = summary.models ?? [];
+  for (const model of modelArns) {
+    const modelArn = model.modelArn;
+    if (!modelArn) continue;
+    const baseModelId = extractBaseModelIdFromArn(modelArn);
+    if (!baseModelId) continue;
+    const foundationModel = foundationModels.get(baseModelId);
+    if (foundationModel) {
+      return {
+        input: mapInputModalities(foundationModel),
+        reasoning: inferReasoningSupport(foundationModel),
+      };
+    }
+  }
+  // Fall back to inferring from the profile ID/name
+  const haystack =
+    `${summary.inferenceProfileId ?? ""} ${summary.inferenceProfileName ?? ""}`.toLowerCase();
+  return {
+    input: haystack.includes("embed")
+      ? (["text"] as Array<"text" | "image">)
+      : (["text", "image"] as Array<"text" | "image">),
+    reasoning: haystack.includes("reasoning") || haystack.includes("thinking"),
+  };
+}
+
+/**
+ * Builds a model definition for an inference profile. Cost always uses DEFAULT_COST and
+ * the context window / max tokens come from the supplied defaults.
+ */
+function inferenceProfileToModelDefinition(
+  summary: BedrockInferenceProfileSummary,
+  foundationModels: Map<string, BedrockModelSummary>,
+  defaults: { contextWindow: number; maxTokens: number },
+): ModelDefinitionConfig {
+  const id = summary.inferenceProfileId?.trim() ?? "";
+  const capabilities = inferInferenceProfileCapabilities(summary, foundationModels);
+  return {
+    id,
+    name: summary.inferenceProfileName?.trim() || id,
+    reasoning: capabilities.reasoning,
+    input: capabilities.input,
+    cost: DEFAULT_COST,
+    contextWindow: defaults.contextWindow,
+    maxTokens: defaults.maxTokens,
+  };
+}
+
+/**
+ * Collects every inference profile summary by following `nextToken` until the
+ * paginated ListInferenceProfiles listing is exhausted.
+ */
+async function listAllInferenceProfileSummaries(
+  fetchPage: (nextToken?: string) => Promise<ListInferenceProfilesCommandOutput>,
+): Promise<BedrockInferenceProfileSummary[]> {
+  const summaries: BedrockInferenceProfileSummary[] = [];
+  let nextToken: string | undefined;
+  do {
+    const page = await fetchPage(nextToken);
+    summaries.push(...(page.inferenceProfileSummaries ?? []));
+    nextToken = page.nextToken;
+  } while (nextToken);
+  return summaries;
+}
+
 export function resetBedrockDiscoveryCacheForTest(): void {
   discoveryCache.clear();
   hasLoggedBedrockError = false;
@@ -157,9 +271,27 @@ export async function discoverBedrockModels(params: {
   const client = clientFactory(params.region);
 
   const discoveryPromise = (async () => {
-    const response = await client.send(new ListFoundationModelsCommand({}));
+    // Fetch foundation models and inference profiles in parallel
+    const [foundationResponse, inferenceProfileSummaries] = await Promise.all([
+      client.send(new ListFoundationModelsCommand({})),
+      listAllInferenceProfileSummaries((nextToken) =>
+        client.send(new ListInferenceProfilesCommand({ nextToken })),
+      ),
+    ]);
+
+    // Build a map of foundation models for capability lookups
+    const foundationModelMap = new Map<string, BedrockModelSummary>();
+    for (const summary of foundationResponse.modelSummaries ?? []) {
+      const modelId = summary.modelId?.trim();
+      if (modelId) {
+        foundationModelMap.set(modelId, summary);
+      }
+    }
+
     const discovered: ModelDefinitionConfig[] = [];
-    for (const summary of response.modelSummaries ?? []) {
+
+    // Add foundation models
+    for (const summary of foundationResponse.modelSummaries ?? []) {
       if (!shouldIncludeSummary(summary, providerFilter)) continue;
       discovered.push(
         toModelDefinition(summary, {
@@ -168,6 +300,20 @@ export async function discoverBedrockModels(params: {
         }),
       );
     }
+
+    // Add inference profiles (CRIS: global., us., eu., etc.)
+    for (const summary of inferenceProfileSummaries) {
+      if (!summary.inferenceProfileId?.trim()) continue;
+      if (!isActiveInferenceProfile(summary)) continue;
+      if (!matchesInferenceProfileProviderFilter(summary, providerFilter)) continue;
+      discovered.push(
+        inferenceProfileToModelDefinition(summary, foundationModelMap, {
+          contextWindow: defaultContextWindow,
+          maxTokens: defaultMaxTokens,
+        }),
+      );
+    }
+
     return discovered.sort((a, b) => a.name.localeCompare(b.name));
   })();
 