Skip to content

Commit dd92e07

Browse files
committed
feat: add endpoint parameter to azure-blob cas-backend
1 parent 6fb89be commit dd92e07

File tree

4 files changed

+146
-18
lines changed

4 files changed

+146
-18
lines changed

app/cli/cmd/casbackend_add_azureblob.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
)
2727

2828
func newCASBackendAddAzureBlobStorageCmd() *cobra.Command {
29-
var storageAccountName, tenantID, clientID, clientSecret, container string
29+
var storageAccountName, tenantID, clientID, clientSecret, container, endpoint string
3030
cmd := &cobra.Command{
3131
Use: "azure-blob",
3232
Short: "Register a Azure Blob Storage CAS Backend",
@@ -54,9 +54,14 @@ func newCASBackendAddAzureBlobStorageCmd() *cobra.Command {
5454
}
5555
}
5656

57+
location := fmt.Sprintf("%s/%s", storageAccountName, container)
58+
if endpoint != "" {
59+
location = fmt.Sprintf("%s/%s/%s", endpoint, storageAccountName, container)
60+
}
61+
5762
opts := &action.NewCASBackendAddOpts{
5863
Name: name,
59-
Location: fmt.Sprintf("%s/%s", storageAccountName, container),
64+
Location: location,
6065
Provider: azureblob.ProviderID,
6166
Description: description,
6267
Credentials: map[string]any{
@@ -97,5 +102,7 @@ func newCASBackendAddAzureBlobStorageCmd() *cobra.Command {
97102

98103
cmd.Flags().StringVar(&container, "container", "chainloop", "Storage Container Name")
99104

105+
cmd.Flags().StringVar(&endpoint, "endpoint", "", "Azure Blob endpoint suffix (e.g., blob.core.windows.net, blob.core.usgovcloudapi.net)")
106+
100107
return cmd
101108
}

pkg/blobmanager/azureblob/backend.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ type Backend struct {
3434
storageAccountName string
3535
container string
3636
credentials *azidentity.ClientSecretCredential
37+
endpoint string
3738
}
3839

3940
var _ backend.UploaderDownloader = (*Backend)(nil)
@@ -48,12 +49,13 @@ func NewBackend(creds *Credentials) (*Backend, error) {
4849
storageAccountName: creds.StorageAccountName,
4950
credentials: credential,
5051
container: creds.Container,
52+
endpoint: creds.Endpoint,
5153
}, nil
5254
}
5355

5456
// top level client used for creation/upload/download/listing operations
5557
func (b *Backend) client() (*azblob.Client, error) {
56-
url := fmt.Sprintf("https://%s.blob.core.windows.net/", b.storageAccountName)
58+
url := b.getServiceURL()
5759
// Top level client
5860
client, err := azblob.NewClient(url, b.credentials, nil)
5961
if err != nil {
@@ -63,6 +65,14 @@ func (b *Backend) client() (*azblob.Client, error) {
6365
return client, nil
6466
}
6567

68+
// getServiceURL returns the Azure Blob Storage service URL. Uses custom endpoint if provided, otherwise defaults to public Azure cloud
69+
func (b *Backend) getServiceURL() string {
70+
if b.endpoint != "" {
71+
return fmt.Sprintf("https://%s.%s/", b.storageAccountName, b.endpoint)
72+
}
73+
return fmt.Sprintf("https://%s.blob.core.windows.net/", b.storageAccountName)
74+
}
75+
6676
// blob client used for operating with a single blob
6777
func (b *Backend) blobClient(digest string) (*blob.Client, error) {
6878
blobClient, err := blob.NewClient(b.resourcePath(digest), b.credentials, nil)
@@ -74,6 +84,9 @@ func (b *Backend) blobClient(digest string) (*blob.Client, error) {
7484
}
7585

7686
func (b *Backend) resourcePath(digest string) string {
87+
if b.endpoint != "" {
88+
return fmt.Sprintf("https://%s.%s/%s/%s", b.storageAccountName, b.endpoint, b.container, resourceName(digest))
89+
}
7790
return fmt.Sprintf("https://%s.blob.core.windows.net/%s/%s", b.storageAccountName, b.container, resourceName(digest))
7891
}
7992

pkg/blobmanager/azureblob/provider.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,15 @@ func extractCreds(location string, credsJSON []byte) (*Credentials, error) {
6262
return nil, fmt.Errorf("unmarshaling credentials: %w", err)
6363
}
6464

65-
parts := strings.Split(location, "/")
66-
if len(parts) != 2 {
67-
return nil, errors.New("invalid location: must be in the format <account>/<container>")
65+
endpoint, storageAccount, container, err := extractLocationAndContainer(location)
66+
if err != nil {
67+
return nil, err
6868
}
6969

7070
// Override the location in the credentials since that's something we don't allow updating
71-
creds.StorageAccountName = parts[0]
72-
creds.Container = parts[1]
71+
creds.Endpoint = endpoint
72+
creds.StorageAccountName = storageAccount
73+
creds.Container = container
7374

7475
if err := creds.Validate(); err != nil {
7576
return nil, fmt.Errorf("invalid credentials: %w", err)
@@ -78,6 +79,24 @@ func extractCreds(location string, credsJSON []byte) (*Credentials, error) {
7879
return creds, nil
7980
}
8081

82+
// Extract the custom endpoint, storage account name, and container name from the location string
83+
// The location string can be either:
84+
// - <account>/<container> (uses default Azure blob endpoint)
85+
// - <endpoint>/<account>/<container> (uses custom endpoint for Azure Government, etc.)
86+
func extractLocationAndContainer(location string) (string, string, string, error) {
87+
parts := strings.Split(location, "/")
88+
89+
if len(parts) == 2 {
90+
return "", parts[0], parts[1], nil
91+
}
92+
93+
if len(parts) == 3 {
94+
return parts[0], parts[1], parts[2], nil
95+
}
96+
97+
return "", "", "", errors.New("invalid location: must be in the format <account>/<container> or <endpoint>/<account>/<container>")
98+
}
99+
81100
func (p *BackendProvider) ValidateAndExtractCredentials(location string, credsJSON []byte) (any, error) {
82101
creds, err := extractCreds(location, credsJSON)
83102
if err != nil {
@@ -108,6 +127,9 @@ type Credentials struct {
108127
ClientID string
109128
// Registered application / service principal client secret
110129
ClientSecret string
130+
// Optional custom endpoint URL
131+
// If empty, defaults to blob.core.windows.net
132+
Endpoint string
111133
}
112134

113135
// Validate that the APICreds has all its properties set

pkg/blobmanager/azureblob/provider_test.go

Lines changed: 96 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,15 @@ func TestFromCredentials(t *testing.T) {
130130
}
131131

132132
func TestExtractCreds(t *testing.T) {
133-
tetCases := []struct {
133+
testCases := []struct {
134134
name string
135135
location string
136136
credsJSON []byte
137+
wantCreds *Credentials
137138
wantErr bool
138139
}{
139140
{
140-
name: "valid credentials",
141+
name: "valid credentials without endpoint",
141142
location: "account/container",
142143
credsJSON: []byte(`{
143144
"storageAccountName": "test",
@@ -146,6 +147,33 @@ func TestExtractCreds(t *testing.T) {
146147
"clientID": "test",
147148
"clientSecret": "test"
148149
}`),
150+
wantCreds: &Credentials{
151+
StorageAccountName: "account",
152+
Container: "container",
153+
TenantID: "test",
154+
ClientID: "test",
155+
ClientSecret: "test",
156+
Endpoint: "",
157+
},
158+
},
159+
{
160+
name: "valid credentials with custom endpoint",
161+
location: "blob.core.usgovcloudapi.net/account/container",
162+
credsJSON: []byte(`{
163+
"storageAccountName": "test",
164+
"container": "test",
165+
"tenantID": "test",
166+
"clientID": "test",
167+
"clientSecret": "test"
168+
}`),
169+
wantCreds: &Credentials{
170+
StorageAccountName: "account",
171+
Container: "container",
172+
TenantID: "test",
173+
ClientID: "test",
174+
ClientSecret: "test",
175+
Endpoint: "blob.core.usgovcloudapi.net",
176+
},
149177
},
150178
{
151179
name: "invalid location, missing container",
@@ -173,20 +201,78 @@ func TestExtractCreds(t *testing.T) {
173201
},
174202
}
175203

176-
for _, tc := range tetCases {
204+
for _, tc := range testCases {
177205
t.Run(tc.name, func(t *testing.T) {
178206
creds, err := extractCreds(tc.location, tc.credsJSON)
179207
if tc.wantErr {
180208
assert.Error(t, err)
181209
} else {
182210
assert.NoError(t, err)
183-
assert.Equal(t, &Credentials{
184-
StorageAccountName: "account",
185-
Container: "container",
186-
TenantID: "test",
187-
ClientID: "test",
188-
ClientSecret: "test",
189-
}, creds)
211+
assert.Equal(t, tc.wantCreds, creds)
212+
}
213+
})
214+
}
215+
}
216+
217+
func TestExtractLocationAndContainer(t *testing.T) {
218+
testCases := []struct {
219+
name string
220+
location string
221+
wantEndpoint string
222+
wantAccount string
223+
wantContainer string
224+
wantErr bool
225+
}{
226+
{
227+
name: "simple location without endpoint",
228+
location: "myaccount/mycontainer",
229+
wantEndpoint: "",
230+
wantAccount: "myaccount",
231+
wantContainer: "mycontainer",
232+
},
233+
{
234+
name: "Azure Government Cloud endpoint",
235+
location: "blob.core.usgovcloudapi.net/myaccount/mycontainer",
236+
wantEndpoint: "blob.core.usgovcloudapi.net",
237+
wantAccount: "myaccount",
238+
wantContainer: "mycontainer",
239+
},
240+
{
241+
name: "Azure Stack Hub endpoint",
242+
location: "blob.local.azurestack.external/myaccount/mycontainer",
243+
wantEndpoint: "blob.local.azurestack.external",
244+
wantAccount: "myaccount",
245+
wantContainer: "mycontainer",
246+
},
247+
{
248+
name: "custom endpoint with path segments",
249+
location: "custom.endpoint.com/account/container",
250+
wantEndpoint: "custom.endpoint.com",
251+
wantAccount: "account",
252+
wantContainer: "container",
253+
},
254+
{
255+
name: "invalid simple location - missing container",
256+
location: "myaccount",
257+
wantErr: true,
258+
},
259+
{
260+
name: "invalid location - too many segments",
261+
location: "endpoint/account/container/extra",
262+
wantErr: true,
263+
},
264+
}
265+
266+
for _, tc := range testCases {
267+
t.Run(tc.name, func(t *testing.T) {
268+
endpoint, account, container, err := extractLocationAndContainer(tc.location)
269+
if tc.wantErr {
270+
assert.Error(t, err)
271+
} else {
272+
assert.NoError(t, err)
273+
assert.Equal(t, tc.wantEndpoint, endpoint)
274+
assert.Equal(t, tc.wantAccount, account)
275+
assert.Equal(t, tc.wantContainer, container)
190276
}
191277
})
192278
}

0 commit comments

Comments
 (0)