diff --git a/app/cli/cmd/casbackend_add_azureblob.go b/app/cli/cmd/casbackend_add_azureblob.go index 6dffde22c..4f3be9639 100644 --- a/app/cli/cmd/casbackend_add_azureblob.go +++ b/app/cli/cmd/casbackend_add_azureblob.go @@ -26,7 +26,7 @@ import ( ) func newCASBackendAddAzureBlobStorageCmd() *cobra.Command { - var storageAccountName, tenantID, clientID, clientSecret, container string + var storageAccountName, tenantID, clientID, clientSecret, container, endpoint string cmd := &cobra.Command{ Use: "azure-blob", Short: "Register a Azure Blob Storage CAS Backend", @@ -54,9 +54,14 @@ func newCASBackendAddAzureBlobStorageCmd() *cobra.Command { } } + location := fmt.Sprintf("%s/%s", storageAccountName, container) + if endpoint != "" { + location = fmt.Sprintf("%s/%s/%s", endpoint, storageAccountName, container) + } + opts := &action.NewCASBackendAddOpts{ Name: name, - Location: fmt.Sprintf("%s/%s", storageAccountName, container), + Location: location, Provider: azureblob.ProviderID, Description: description, Credentials: map[string]any{ @@ -97,5 +102,7 @@ func newCASBackendAddAzureBlobStorageCmd() *cobra.Command { cmd.Flags().StringVar(&container, "container", "chainloop", "Storage Container Name") + cmd.Flags().StringVar(&endpoint, "endpoint", "", "Custom Azure Blob endpoint suffix (e.g., blob.core.usgovcloudapi.net), if not provided, the public Azure cloud endpoint will be used.") + return cmd } diff --git a/app/cli/documentation/cli-reference.mdx b/app/cli/documentation/cli-reference.mdx index 4cdef395f..d1a636aaf 100755 --- a/app/cli/documentation/cli-reference.mdx +++ b/app/cli/documentation/cli-reference.mdx @@ -713,6 +713,7 @@ Options --client-id string Service Principal Client ID --client-secret string Service Principal Client Secret --container string Storage Container Name (default "chainloop") +--endpoint string Custom Azure Blob endpoint suffix (e.g., blob.core.usgovcloudapi.net), if not provided, the public Azure cloud endpoint will be used. -h, --help help for azure-blob --storage-account string Storage Account Name --tenant string Active Directory Tenant ID diff --git a/pkg/blobmanager/azureblob/backend.go b/pkg/blobmanager/azureblob/backend.go index 1c80927a3..1f41486e0 100644 --- a/pkg/blobmanager/azureblob/backend.go +++ b/pkg/blobmanager/azureblob/backend.go @@ -1,5 +1,5 @@ // -// Copyright 2023 The Chainloop Authors. +// Copyright 2023-2025 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ type Backend struct { storageAccountName string container string credentials *azidentity.ClientSecretCredential + endpoint string } var _ backend.UploaderDownloader = (*Backend)(nil) @@ -48,12 +49,13 @@ func NewBackend(creds *Credentials) (*Backend, error) { storageAccountName: creds.StorageAccountName, credentials: credential, container: creds.Container, + endpoint: creds.Endpoint, }, nil } // top level client used for creation/upload/download/listing operations func (b *Backend) client() (*azblob.Client, error) { - url := fmt.Sprintf("https://%s.blob.core.windows.net/", b.storageAccountName) + url := b.getServiceURL() // Top level client client, err := azblob.NewClient(url, b.credentials, nil) if err != nil { @@ -63,6 +65,14 @@ func (b *Backend) client() (*azblob.Client, error) { return client, nil } +// getServiceURL returns the Azure Blob Storage service URL. Uses custom endpoint if provided, otherwise defaults to public Azure cloud +func (b *Backend) getServiceURL() string { + if b.endpoint != "" { + return fmt.Sprintf("https://%s.%s/", b.storageAccountName, b.endpoint) + } + return fmt.Sprintf("https://%s.blob.core.windows.net/", b.storageAccountName) +} + // blob client used for operating with a single blob func (b *Backend) blobClient(digest string) (*blob.Client, error) { blobClient, err := blob.NewClient(b.resourcePath(digest), b.credentials, nil) @@ -74,6 +84,9 @@ func (b *Backend) blobClient(digest string) (*blob.Client, error) { } func (b *Backend) resourcePath(digest string) string { + if b.endpoint != "" { + return fmt.Sprintf("https://%s.%s/%s/%s", b.storageAccountName, b.endpoint, b.container, resourceName(digest)) + } return fmt.Sprintf("https://%s.blob.core.windows.net/%s/%s", b.storageAccountName, b.container, resourceName(digest)) } diff --git a/pkg/blobmanager/azureblob/provider.go b/pkg/blobmanager/azureblob/provider.go index 8016f8319..cc265d810 100644 --- a/pkg/blobmanager/azureblob/provider.go +++ b/pkg/blobmanager/azureblob/provider.go @@ -1,5 +1,5 @@ // -// Copyright 2023 The Chainloop Authors. +// Copyright 2023-2025 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -62,14 +62,15 @@ func extractCreds(location string, credsJSON []byte) (*Credentials, error) { return nil, fmt.Errorf("unmarshaling credentials: %w", err) } - parts := strings.Split(location, "/") - if len(parts) != 2 { - return nil, errors.New("invalid location: must be in the format /") + endpoint, storageAccount, container, err := extractLocationAndContainer(location) + if err != nil { + return nil, err } // Override the location in the credentials since that's something we don't allow updating - creds.StorageAccountName = parts[0] - creds.Container = parts[1] + creds.Endpoint = endpoint + creds.StorageAccountName = storageAccount + creds.Container = container if err := creds.Validate(); err != nil { return nil, fmt.Errorf("invalid credentials: %w", err) @@ -78,6 +79,24 @@ func extractCreds(location string, credsJSON []byte) (*Credentials, error) { return creds, nil } +// Extract the custom endpoint, storage account name, and container name from the location string +// The location string can be either: +// - / (uses default Azure blob endpoint) +// - // (uses custom endpoint for Azure Government, etc.) +func extractLocationAndContainer(location string) (string, string, string, error) { + parts := strings.Split(location, "/") + + if len(parts) == 2 { + return "", parts[0], parts[1], nil + } + + if len(parts) == 3 { + return parts[0], parts[1], parts[2], nil + } + + return "", "", "", errors.New("invalid location: must be in the format / or //") +} + func (p *BackendProvider) ValidateAndExtractCredentials(location string, credsJSON []byte) (any, error) { creds, err := extractCreds(location, credsJSON) if err != nil { @@ -108,6 +127,9 @@ type Credentials struct { ClientID string // Registered application / service principal client secret ClientSecret string + // Optional custom endpoint URL + // If empty, defaults to blob.core.windows.net + Endpoint string } // Validate that the APICreds has all its properties set diff --git a/pkg/blobmanager/azureblob/provider_test.go b/pkg/blobmanager/azureblob/provider_test.go index 660b648eb..dc1ddac2b 100644 --- a/pkg/blobmanager/azureblob/provider_test.go +++ b/pkg/blobmanager/azureblob/provider_test.go @@ -1,5 +1,5 @@ // -// Copyright 2023 The Chainloop Authors. +// Copyright 2023-2025 The Chainloop Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -130,14 +130,15 @@ func TestFromCredentials(t *testing.T) { } func TestExtractCreds(t *testing.T) { - tetCases := []struct { + testCases := []struct { name string location string credsJSON []byte + wantCreds *Credentials wantErr bool }{ { - name: "valid credentials", + name: "valid credentials without endpoint", location: "account/container", credsJSON: []byte(`{ "storageAccountName": "test", @@ -146,6 +147,33 @@ func TestExtractCreds(t *testing.T) { "clientID": "test", "clientSecret": "test" }`), + wantCreds: &Credentials{ + StorageAccountName: "account", + Container: "container", + TenantID: "test", + ClientID: "test", + ClientSecret: "test", + Endpoint: "", + }, + }, + { + name: "valid credentials with custom endpoint", + location: "blob.core.usgovcloudapi.net/account/container", + credsJSON: []byte(`{ + "storageAccountName": "test", + "container": "test", + "tenantID": "test", + "clientID": "test", + "clientSecret": "test" + }`), + wantCreds: &Credentials{ + StorageAccountName: "account", + Container: "container", + TenantID: "test", + ClientID: "test", + ClientSecret: "test", + Endpoint: "blob.core.usgovcloudapi.net", + }, }, { name: "invalid location, missing container", @@ -173,20 +201,78 @@ func TestExtractCreds(t *testing.T) { }, } - for _, tc := range tetCases { + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { creds, err := extractCreds(tc.location, tc.credsJSON) if tc.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) - assert.Equal(t, &Credentials{ - StorageAccountName: "account", - Container: "container", - TenantID: "test", - ClientID: "test", - ClientSecret: "test", - }, creds) + assert.Equal(t, tc.wantCreds, creds) + } + }) + } +} + +func TestExtractLocationAndContainer(t *testing.T) { + testCases := []struct { + name string + location string + wantEndpoint string + wantAccount string + wantContainer string + wantErr bool + }{ + { + name: "simple location without endpoint", + location: "myaccount/mycontainer", + wantEndpoint: "", + wantAccount: "myaccount", + wantContainer: "mycontainer", + }, + { + name: "Azure Government Cloud endpoint", + location: "blob.core.usgovcloudapi.net/myaccount/mycontainer", + wantEndpoint: "blob.core.usgovcloudapi.net", + wantAccount: "myaccount", + wantContainer: "mycontainer", + }, + { + name: "Azure Stack Hub endpoint", + location: "blob.local.azurestack.external/myaccount/mycontainer", + wantEndpoint: "blob.local.azurestack.external", + wantAccount: "myaccount", + wantContainer: "mycontainer", + }, + { + name: "custom endpoint with path segments", + location: "custom.endpoint.com/account/container", + wantEndpoint: "custom.endpoint.com", + wantAccount: "account", + wantContainer: "container", + }, + { + name: "invalid simple location - missing container", + location: "myaccount", + wantErr: true, + }, + { + name: "invalid location - too many segments", + location: "endpoint/account/container/extra", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + endpoint, account, container, err := extractLocationAndContainer(tc.location) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.wantEndpoint, endpoint) + assert.Equal(t, tc.wantAccount, account) + assert.Equal(t, tc.wantContainer, container) } }) }