Skip to content

Commit

Permalink
SAML Bearer support
Browse files Browse the repository at this point in the history
IAS does not support it, but CF
  • Loading branch information
strehle committed Dec 27, 2024
1 parent 7d9938c commit f23cfb0
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 22 deletions.
49 changes: 34 additions & 15 deletions openid-client/openid-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@ func main() {
" password Perform resource owner flow, also known as password flow.\n" +
" token-exchange Perform OAuth2 Token Exchange (RFC 8693).\n" +
" jwt-bearer Perform OAuth2 JWT Bearer Grant Type.\n" +
" saml-bearer Perform OAuth2 SAML 2.0 Bearer Grant Type.\n" +
" passcode Retrieve user passcode from X509 user authentication.\n" +
" version Show version.\n" +
" help Show this help for more details.\n" +
"\n" +
"Flags:\n" +
" -issuer IAS. Default is https://<tenant>.accounts.ondemand.com; XSUAA Default is: https://uaa.cf.eu10.hana.ondemand.com/oauth/token\n" +
" -url Generic endpoint for request. Used if issuer is not OIDC complaint with support of discovery endpoint.\n" +
" -client_id OIDC client ID. This is a mandatory flag.\n" +
" -client_secret OIDC client secret. This is an optional flag and only needed for confidential clients.\n" +
" -client_tls P12 file for client mTLS authentication. This is an optional flag and only needed for confidential clients as replacement for client_secret.\n" +
Expand Down Expand Up @@ -78,6 +80,7 @@ func main() {
}

var issEndPoint = flag.String("issuer", "", "OIDC Issuer URI")
var urlEndPoint = flag.String("url", "", "Generic URL endpoint")
var clientID = flag.String("client_id", "", "OIDC client ID")
var clientSecret = flag.String("client_secret", "", "OIDC client secret")
var doRefresh = flag.Bool("refresh", false, "Refresh the received id_token")
Expand Down Expand Up @@ -117,9 +120,9 @@ func main() {
} else {
arguments = os.Args[1:]
}
err := flag.CommandLine.Parse(arguments)
if err != nil {
log.Fatal(err)
oidcError := flag.CommandLine.Parse(arguments)
if oidcError != nil {
log.Fatal(oidcError)
}
switch *command {
case "jwks":
Expand All @@ -130,7 +133,7 @@ func main() {
case "version":
showVersion()
return
case "client_credentials", "password", "token-exchange", "jwt-bearer", "":
case "client_credentials", "password", "token-exchange", "jwt-bearer", "saml-bearer", "":
case "passcode":
*clientID = "T000000" /* default */
case "authorization_code":
Expand All @@ -153,17 +156,25 @@ func main() {
}
var callbackURL = "http://localhost:" + *portParameter + "/callback"
ctx := context.Background()
provider, err := oidc.NewProvider(ctx, *issEndPoint)
if err != nil {
log.Fatal(err)
}
var claims struct {
AuthorizeEndpoint string `json:"authorization_endpoint"`
EndSessionEndpoint string `json:"end_session_endpoint"`
TokenEndPoint string `json:"token_endpoint"`
}
err = provider.Claims(&claims)
if err != nil {
log.Fatal(err)
provider, oidcError := oidc.NewProvider(ctx, *issEndPoint)
if oidcError != nil {
if *urlEndPoint != "" && *command != "" {
claims.TokenEndPoint = *urlEndPoint
claims.AuthorizeEndpoint = *urlEndPoint
claims.EndSessionEndpoint = ""
} else {
log.Fatal(oidcError)
}
} else {
oidcError = provider.Claims(&claims)
if oidcError != nil {
log.Fatal(oidcError)
}
}
tlsClient := &http.Client{
Transport: &http.Transport{
Expand Down Expand Up @@ -363,15 +374,23 @@ func main() {
if *resourceParam != "" {
requestMap.Add("resource", *resourceParam)
}
var exchangedTokenResponse = client.HandleTokenExchangeGrant(requestMap, *provider, *tlsClient, verbose)
var exchangedTokenResponse = client.HandleTokenExchangeGrant(requestMap, claims.TokenEndPoint, *tlsClient, verbose)
fmt.Println(exchangedTokenResponse)
} else if *command == "jwt-bearer" {
requestMap.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
if *assertionToken == "" {
log.Fatal("assertion parameter not set. Needed to pass it for JWT bearer")
}
requestMap.Set("assertion", *assertionToken)
var jwtBearerTokenResponse = client.HandleJwtBearerGrant(requestMap, *provider, *tlsClient, verbose)
var jwtBearerTokenResponse = client.HandleJwtBearerGrant(requestMap, claims.TokenEndPoint, *tlsClient, verbose)
fmt.Println(jwtBearerTokenResponse)
} else if *command == "saml-bearer" {
requestMap.Set("grant_type", "urn:ietf:params:oauth:grant-type:saml2-bearer")
if *assertionToken == "" {
log.Fatal("assertion parameter not set. Needed to pass it for SAML bearer")
}
requestMap.Set("assertion", *assertionToken)
var jwtBearerTokenResponse = client.HandleSamlBearerGrant(requestMap, claims.TokenEndPoint, *tlsClient, verbose)
fmt.Println(jwtBearerTokenResponse)
} else if *command == "passcode" {
if *issEndPoint == "" || !strings.HasPrefix(*issEndPoint, "https://") {
Expand Down Expand Up @@ -406,7 +425,7 @@ func main() {
log.Fatal("client_secret is required to run this command")
return
}
var idpTokenResponse = client.HandleCorpIdpExchangeFlow(*clientID, *clientSecret, idToken, *idpScopeParameter, privateKeyJwt, *provider, *tlsClient)
var idpTokenResponse = client.HandleCorpIdpExchangeFlow(*clientID, *clientSecret, idToken, *idpScopeParameter, privateKeyJwt, claims.TokenEndPoint, *tlsClient)
data, _ := json.MarshalIndent(idpTokenResponse, "", " ")
if verbose {
fmt.Println("Response from endpoint /exchange/corporateidp")
Expand All @@ -429,7 +448,7 @@ func main() {
requestMap.Add("resource", *resourceParam)
}

var exchangedTokenResponse = client.HandleTokenExchangeGrant(requestMap, *provider, *tlsClient, verbose)
var exchangedTokenResponse = client.HandleTokenExchangeGrant(requestMap, claims.TokenEndPoint, *tlsClient, verbose)
fmt.Println(exchangedTokenResponse)
}
}
Expand Down
48 changes: 41 additions & 7 deletions pkg/client/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package client
import (
"encoding/json"
"fmt"
"github.com/coreos/go-oidc/v3/oidc"
"golang.org/x/oauth2"
"io"
"log"
Expand All @@ -12,7 +11,7 @@ import (
"strings"
)

func HandleCorpIdpExchangeFlow(clientID string, clientSecret string, existingIdToken string, idpScopeParameter string, privateKeyJwt string, provider oidc.Provider, tlsClient http.Client) map[string]interface{} {
func HandleCorpIdpExchangeFlow(clientID string, clientSecret string, existingIdToken string, idpScopeParameter string, privateKeyJwt string, tokenEndpoint string, tlsClient http.Client) map[string]interface{} {

params := url.Values{}
params.Add("assertion", existingIdToken)
Expand All @@ -28,7 +27,7 @@ func HandleCorpIdpExchangeFlow(clientID string, clientSecret string, existingIdT

body := strings.NewReader(params.Encode())

tokenEndPoint := strings.Replace(provider.Endpoint().TokenURL, "/token", "/exchange/corporateidp", 1)
tokenEndPoint := strings.Replace(tokenEndpoint, "/token", "/exchange/corporateidp", 1)
fmt.Println("Call IdP Token Exchange Endpoint: " + tokenEndPoint)
req, err := http.NewRequest("POST", tokenEndPoint, body)
if err != nil {
Expand Down Expand Up @@ -58,10 +57,10 @@ func HandleCorpIdpExchangeFlow(clientID string, clientSecret string, existingIdT
return outBodyMap
}

func HandleTokenExchangeGrant(request url.Values, provider oidc.Provider, tlsClient http.Client, verbose bool) string {
func HandleTokenExchangeGrant(request url.Values, tokenEndpoint string, tlsClient http.Client, verbose bool) string {
accessToken := ""
request.Set("grant_type", "urn:ietf:params:oauth:grant-type:token-exchange")
req, requestError := http.NewRequest("POST", provider.Endpoint().TokenURL, strings.NewReader(request.Encode()))
req, requestError := http.NewRequest("POST", tokenEndpoint, strings.NewReader(request.Encode()))
if requestError != nil {
log.Fatal(requestError)
}
Expand Down Expand Up @@ -93,10 +92,10 @@ func HandleTokenExchangeGrant(request url.Values, provider oidc.Provider, tlsCli
return accessToken
}

func HandleJwtBearerGrant(request url.Values, provider oidc.Provider, tlsClient http.Client, verbose bool) string {
func HandleJwtBearerGrant(request url.Values, tokenEndpoint string, tlsClient http.Client, verbose bool) string {
accessToken := ""
request.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
req, requestError := http.NewRequest("POST", provider.Endpoint().TokenURL, strings.NewReader(request.Encode()))
req, requestError := http.NewRequest("POST", tokenEndpoint, strings.NewReader(request.Encode()))
if requestError != nil {
log.Fatal(requestError)
}
Expand Down Expand Up @@ -128,6 +127,41 @@ func HandleJwtBearerGrant(request url.Values, provider oidc.Provider, tlsClient
return accessToken
}

func HandleSamlBearerGrant(request url.Values, tokenEndpoint string, tlsClient http.Client, verbose bool) string {
accessToken := ""
request.Set("grant_type", "urn:ietf:params:oauth:grant-type:saml2-bearer")
req, requestError := http.NewRequest("POST", tokenEndpoint, strings.NewReader(request.Encode()))
if requestError != nil {
log.Fatal(requestError)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, clientError := tlsClient.Do(req)
if clientError != nil {
log.Fatal(clientError)
}
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
if result != nil {
jsonStr, marshalError := json.Marshal(result)
if marshalError != nil {
log.Fatal(marshalError)
}
var myToken oauth2.Token
json.Unmarshal([]byte(jsonStr), &myToken)
if myToken.AccessToken == "" {
fmt.Println(string(jsonStr))
} else {
if verbose {
fmt.Println("Response from SAML bearer endpoint ")
ShowJSonResponse(result, verbose)
}
accessToken = myToken.AccessToken
}
}
return accessToken
}

func ShowJSonResponse(result map[string]interface{}, verbose bool) {
fmt.Println("==========")
resultJson, _ := json.MarshalIndent(result, "", " ")
Expand Down

0 comments on commit f23cfb0

Please sign in to comment.