Skip to content

Commit

Permalink
Added aad client certificate support (#36)
Browse files Browse the repository at this point in the history
* added aad client certificate support

* Update pkg/token/options.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

* Update pkg/converter/convert.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

* Update pkg/converter/convert.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

* Update pkg/token/options.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

* addressed comnments

* Update pkg/token/serviceprincipaltoken.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

* Update pkg/token/serviceprincipaltoken.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

* addressed comments

* addressed comments

* Update pkg/token/serviceprincipaltoken.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

* Update pkg/token/serviceprincipaltoken.go

Co-authored-by: Weinong Wang <weinong@outlook.com>

Co-authored-by: Weinong Wang <weinong@outlook.com>
  • Loading branch information
tamilmani1989 and weinong authored Sep 4, 2020
1 parent 311d493 commit 74bd077
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 12 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ require (
github.com/golang/mock v1.2.0
github.com/spf13/cobra v0.0.6
github.com/spf13/pflag v1.0.5
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
k8s.io/apimachinery v0.17.4
k8s.io/cli-runtime v0.17.4
k8s.io/client-go v0.17.4
Expand Down
6 changes: 6 additions & 0 deletions pkg/converter/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
argTenantID = "--tenant-id"
argEnvironment = "--environment"
argClientSecret = "--client-secret"
argClientCert = "--client-certificate"
argIsLegacy = "--legacy"
argUsername = "--username"
argPassword = "--password"
Expand All @@ -31,6 +32,7 @@ const (
flagTenantID = "tenant-id"
flagEnvironment = "environment"
flagClientSecret = "client-secret"
flagClientCert = "client-certificate"
flagIsLegacy = "legacy"
flagUsername = "username"
flagPassword = "password"
Expand Down Expand Up @@ -98,6 +100,10 @@ func Convert(o Options) error {
exec.Args = append(exec.Args, argClientSecret)
exec.Args = append(exec.Args, o.TokenOptions.ClientSecret)
}
if !isMSI && o.isSet(flagClientCert) {
exec.Args = append(exec.Args, argClientCert)
exec.Args = append(exec.Args, o.TokenOptions.ClientCert)
}
if !isMSI && o.isSet(flagUsername) {
exec.Args = append(exec.Args, argUsername)
exec.Args = append(exec.Args, o.TokenOptions.Username)
Expand Down
5 changes: 5 additions & 0 deletions pkg/converter/convet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ func TestConvert(t *testing.T) {
clientID = "clientID"
tenantID = "tenantID"
clientSecret = "foosecret"
clientCert = "/tmp/clientcert"
username = "foo123"
password = "foobar"
loginMethod = "device"
Expand Down Expand Up @@ -77,6 +78,7 @@ func TestConvert(t *testing.T) {
flagClientID: clientID,
flagTenantID: tenantID,
flagClientSecret: clientSecret,
flagClientCert: clientCert,
flagUsername: username,
flagPassword: password,
flagLoginMethod: loginMethod,
Expand All @@ -89,6 +91,7 @@ func TestConvert(t *testing.T) {
argTenantID, tenantID,
argIsLegacy,
argClientSecret, clientSecret,
argClientCert, clientCert,
argUsername, username,
argPassword, password,
argLoginMethod, loginMethod,
Expand All @@ -105,6 +108,7 @@ func TestConvert(t *testing.T) {
flagClientID: clientID,
flagTenantID: tenantID,
flagClientSecret: clientSecret,
flagClientCert: clientCert,
flagUsername: username,
flagPassword: password,
flagLoginMethod: loginMethod,
Expand All @@ -118,6 +122,7 @@ func TestConvert(t *testing.T) {
argTenantID, tenantID,
argIsLegacy,
argClientSecret, clientSecret,
argClientCert, clientCert,
argUsername, username,
argPassword, password,
argLoginMethod, loginMethod,
Expand Down
6 changes: 6 additions & 0 deletions pkg/token/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Options struct {
LoginMethod string
ClientID string
ClientSecret string
ClientCert string
Username string
Password string
ServerID string
Expand All @@ -33,6 +34,7 @@ const (

envServicePrincipalClientID = "AAD_SERVICE_PRINCIPAL_CLIENT_ID"
envServicePrincipalClientSecret = "AAD_SERVICE_PRINCIPAL_CLIENT_SECRET"
envServicePrincipalClientCert = "AAD_SERVICE_PRINCIPAL_CLIENT_CERTIFICATE"
envROPCUsername = "AAD_USER_PRINCIPAL_NAME"
envROPCPassword = "AAD_USER_PRINCIPAL_PASSWORD"
envLoginMethod = "AAD_LOGIN_METHOD"
Expand All @@ -59,6 +61,7 @@ func (o *Options) AddFlags(fs *pflag.FlagSet) {
fs.StringVarP(&o.LoginMethod, "login", "l", o.LoginMethod, fmt.Sprintf("Login method. Supported methods: %s. It may be specified in %s environment variable", GetSupportedLogins(), envLoginMethod))
fs.StringVar(&o.ClientID, "client-id", o.ClientID, fmt.Sprintf("AAD client application ID. It may be specified in %s environment variable", envServicePrincipalClientID))
fs.StringVar(&o.ClientSecret, "client-secret", o.ClientSecret, fmt.Sprintf("AAD client application secret. Used in spn login. It may be specified in %s environment variable", envServicePrincipalClientSecret))
fs.StringVar(&o.ClientCert, "client-certificate", o.ClientCert, fmt.Sprintf("AAD client cert in pfx. Used in spn login. It may be specified in %s environment variable", envServicePrincipalClientCert))
fs.StringVar(&o.Username, "username", o.Username, fmt.Sprintf("user name for ropc login flow. It may be specified in %s environment variable", envROPCUsername))
fs.StringVar(&o.Password, "password", o.Password, fmt.Sprintf("password for ropc login flow. It may be specified in %s environment variable", envROPCPassword))
fs.StringVar(&o.IdentityResourceId, "identity-resource-id", o.IdentityResourceId, "Managed Identity resource id.")
Expand Down Expand Up @@ -89,6 +92,9 @@ func (o *Options) UpdateFromEnv() {
if v, ok := os.LookupEnv(envServicePrincipalClientSecret); ok {
o.ClientSecret = v
}
if v, ok := os.LookupEnv(envServicePrincipalClientCert); ok {
o.ClientCert = v
}
if v, ok := os.LookupEnv(envROPCUsername); ok {
o.Username = v
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/token/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func newTokenProvider(o *Options) (TokenProvider, error) {
case DeviceCodeLogin:
return newDeviceCodeTokenProvider(*oAuthConfig, o.ClientID, o.ServerID, o.TenantID)
case ServicePrincipalLogin:
return newServicePrincipalToken(*oAuthConfig, o.ClientID, o.ClientSecret, o.ServerID, o.TenantID)
return newServicePrincipalToken(*oAuthConfig, o.ClientID, o.ClientSecret, o.ClientCert, o.ServerID, o.TenantID)
case ROPCLogin:
return newResourceOwnerToken(*oAuthConfig, o.ClientID, o.Username, o.Password, o.ServerID, o.TenantID)
case MSILogin:
Expand Down
175 changes: 164 additions & 11 deletions pkg/token/serviceprincipaltoken.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,45 @@
package token

import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"

"github.com/Azure/go-autorest/autorest/adal"
"golang.org/x/crypto/pkcs12"
)

//pem block types
const (
certificate = "CERTIFICATE"
privateKey = "PRIVATE KEY"
)

const (
defaultEnvironment = "AzurePublicCloud"
)

type servicePrincipalToken struct {
clientID string
clientSecret string
clientCert string
resourceID string
tenantID string
oAuthConfig adal.OAuthConfig
}

func newServicePrincipalToken(oAuthConfig adal.OAuthConfig, clientID, clientSecret, resourceID, tenantID string) (TokenProvider, error) {
func newServicePrincipalToken(oAuthConfig adal.OAuthConfig, clientID, clientSecret, clientCert, resourceID, tenantID string) (TokenProvider, error) {
if clientID == "" {
return nil, errors.New("clientID cannot be empty")
}
if clientSecret == "" {
return nil, errors.New("clientSecret cannot be empty")
if clientSecret == "" && clientCert == "" {
return nil, errors.New("Both clientSecret and clientcert cannot be empty")
}
if clientSecret != "" && clientCert != "" {
return nil, errors.New("Client secret and client certificate cannot be set at the same time. Only one has to be specified")
}
if resourceID == "" {
return nil, errors.New("resourceID cannot be empty")
Expand All @@ -32,6 +51,7 @@ func newServicePrincipalToken(oAuthConfig adal.OAuthConfig, clientID, clientSecr
return &servicePrincipalToken{
clientID: clientID,
clientSecret: clientSecret,
clientCert: clientCert,
resourceID: resourceID,
tenantID: tenantID,
oAuthConfig: oAuthConfig,
Expand All @@ -43,14 +63,44 @@ func (p *servicePrincipalToken) Token() (adal.Token, error) {
callback := func(t adal.Token) error {
return nil
}
spt, err := adal.NewServicePrincipalToken(
p.oAuthConfig,
p.clientID,
p.clientSecret,
p.resourceID,
callback)
if err != nil {
return emptyToken, fmt.Errorf("failed to create service principal token: %s", err)

var (
spt *adal.ServicePrincipalToken
err error
)

if p.clientSecret != "" {
spt, err = adal.NewServicePrincipalToken(
p.oAuthConfig,
p.clientID,
p.clientSecret,
p.resourceID,
callback)
if err != nil {
return emptyToken, fmt.Errorf("failed to create service principal token using secret: %s", err)
}
} else if p.clientCert != "" {
certData, err := ioutil.ReadFile(p.clientCert)
if err != nil {
return emptyToken, fmt.Errorf("failed to read the certificate file (%s): %w", p.clientCert, err)
}

// Get the certificate and private key from pfx file
cert, rsaPrivateKey, err := decodePkcs12(certData, "")
if err != nil {
return emptyToken, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %w", err)
}

spt, err = adal.NewServicePrincipalTokenFromCertificate(
p.oAuthConfig,
p.clientID,
cert,
rsaPrivateKey,
p.resourceID,
callback)
if err != nil {
return emptyToken, fmt.Errorf("failed to create service principal token using cert: %s", err)
}
}

err = spt.Refresh()
Expand All @@ -59,3 +109,106 @@ func (p *servicePrincipalToken) Token() (adal.Token, error) {
}
return spt.Token(), nil
}

func isPublicKeyEqual(key1, key2 *rsa.PublicKey) bool {
if key1.N == nil || key2.N == nil {
return false
}
return key1.E == key2.E && key1.N.Cmp(key2.N) == 0
}

func splitPEMBlock(pemBlock []byte) (certPEM []byte, keyPEM []byte) {
for {
var derBlock *pem.Block
derBlock, pemBlock = pem.Decode(pemBlock)
if derBlock == nil {
break
}
if derBlock.Type == certificate {
certPEM = append(certPEM, pem.EncodeToMemory(derBlock)...)
} else if derBlock.Type == privateKey {
keyPEM = append(keyPEM, pem.EncodeToMemory(derBlock)...)
}
}

return certPEM, keyPEM
}

func parseRsaPrivateKey(privateKeyPEM []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(privateKeyPEM)
if block == nil {
return nil, fmt.Errorf("Failed to decode a pem block from private key")
}

privatePkcs1Key, errPkcs1 := x509.ParsePKCS1PrivateKey(block.Bytes)
if errPkcs1 == nil {
return privatePkcs1Key, nil
}

privatePkcs8Key, errPkcs8 := x509.ParsePKCS8PrivateKey(block.Bytes)
if errPkcs8 == nil {
privatePkcs8RsaKey, ok := privatePkcs8Key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("pkcs8 contained non-RSA key. Expected RSA key")
}
return privatePkcs8RsaKey, nil
}

return nil, fmt.Errorf("failed to parse private key as Pkcs#1 or Pkcs#8. (%s). (%s)", errPkcs1, errPkcs8)
}

func parseKeyPairFromPEMBlock(pemBlock []byte) (*x509.Certificate, *rsa.PrivateKey, error) {
certPEM, keyPEM := splitPEMBlock(pemBlock)

privateKey, err := parseRsaPrivateKey(keyPEM)
if err != nil {
return nil, nil, err
}

found := false
var cert *x509.Certificate
for {
var certBlock *pem.Block
var err error
certBlock, certPEM = pem.Decode(certPEM)
if certBlock == nil {
break
}

cert, err = x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("unable to parse certificate. %w", err)
}

certPublicKey, ok := cert.PublicKey.(*rsa.PublicKey)
if ok {
if isPublicKeyEqual(certPublicKey, &privateKey.PublicKey) {
found = true
break
}
}
}

if !found {
return nil, nil, fmt.Errorf("Unable to find a matching public certificate")
}

return cert, privateKey, nil
}

func decodePkcs12(pkcs []byte, password string) (*x509.Certificate, *rsa.PrivateKey, error) {
blocks, err := pkcs12.ToPEM(pkcs, password)
if err != nil {
return nil, nil, err
}

var (
pemData []byte
)

for _, b := range blocks {
pemData = append(pemData, pem.EncodeToMemory(b)...)
}

return parseKeyPairFromPEMBlock(pemData)
}

0 comments on commit 74bd077

Please sign in to comment.