Working token exchange
5 files changed, 209 insertions(+), 20 deletions(-)

M oauth2_authorizations.go => authorizations.go
M oauth2_clients.go => clients.go
M oauth2_grants.go => grants.go
M logic.go
M routes.go
M oauth2_authorizations.go => authorizations.go +6 -0
@@ 4,6 4,7 @@ import (
 	"context"
 	"database/sql"
 	"fmt"
+	"time"
 
 	sq "github.com/Masterminds/squirrel"
 	"hg.code.netlandish.com/~netlandish/gobwebs/database"

          
@@ 90,3 91,8 @@ func (a *Authorization) Delete(ctx conte
 	})
 	return err
 }
+
+// IsExpired will check the current authorization exiration status
+func (a *Authorization) IsExpired() bool {
+	return time.Now().UTC().After(a.CreatedOn.Add(5 * time.Minute).UTC())
+}

          
M oauth2_clients.go => clients.go +20 -1
@@ 4,6 4,7 @@ import (
 	"context"
 	"crypto/rand"
 	"crypto/sha512"
+	"crypto/subtle"
 	"database/sql"
 	"encoding/base64"
 	"encoding/hex"

          
@@ 110,7 111,7 @@ func (c *Client) Store(ctx context.Conte
 }
 
 // GenerateKeys will create keys for client
-func (c *Client) GenerateKeys() {
+func (c *Client) GenerateKeys() string {
 	var seed [64]byte
 	n, err := rand.Read(seed[:])
 	if err != nil || n != len(seed) {

          
@@ 124,6 125,24 @@ func (c *Client) GenerateKeys() {
 	c.Key = clientID.String()
 	c.SecretHash = hex.EncodeToString(hash[:])
 	c.SecretPartial = partial
+	return secret
+}
+
+// VerifyClientSecret will verify the secret key
+func (c *Client) VerifyClientSecret(clientSecret string) bool {
+	wantHash, err := hex.DecodeString(c.SecretHash)
+	if err != nil {
+		panic(err)
+	}
+
+	b, err := base64.StdEncoding.DecodeString(clientSecret)
+	if err != nil {
+		return false
+	}
+	gotHash := sha512.Sum512(b)
+
+	return subtle.ConstantTimeCompare(wantHash, gotHash[:]) == 1
+
 }
 
 // Delete will delete this client

          
M oauth2_grants.go => grants.go +0 -0

M logic.go +1 -1
@@ 27,7 27,7 @@ func OAuth2(ctx context.Context, token s
 	opts := &database.FilterOptions{
 		Filter: sq.And{
 			sq.Eq{"token_hash": hashStr},
-			sq.Expr("expires > NOW() at time zone 'UTC'"),
+			sq.Expr("expires at time zone 'UTC' > NOW() at time zone 'UTC'"),
 		},
 	}
 	grants, err := GetGrants(ctx, opts)

          
M routes.go +182 -18
@@ 3,6 3,8 @@ package oauth2
 import (
 	"crypto/rand"
 	"crypto/sha512"
+	"database/sql"
+	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
 	"fmt"

          
@@ 21,6 23,7 @@ import (
 
 // ServiceConfig let's you add basic config variables to service
 type ServiceConfig struct {
+	Helper           Helper
 	DocumentationURL string
 	Scopes           []string
 }

          
@@ 29,13 32,13 @@ type ServiceConfig struct {
 type Service struct {
 	name   string
 	eg     *echo.Group
-	helper Helper
 	config *ServiceConfig
 }
 
 // RegisterRoutes ...
 func (s *Service) RegisterRoutes() {
-	s.eg.POST("/introspect", s.Introspect).Name = s.RouteName("introspect_post")
+	s.eg.POST("/access-token", s.AccessTokenPOST).Name = s.RouteName("access_token_post")
+	s.eg.POST("/introspect", s.IntrospectPOST).Name = s.RouteName("introspect_post")
 
 	s.eg.Use(auth.AuthRequired())
 	s.eg.GET("/personal", s.ListPersonal).Name = s.RouteName("list_personal")

          
@@ 151,17 154,18 @@ func (s *Service) AddClient(c echo.Conte
 			RedirectURL: form.RedirectURL,
 			ClientURL:   form.ClientURL,
 		}
-		client.GenerateKeys()
+		token := client.GenerateKeys()
 		if err := client.Store(c.Request().Context()); err != nil {
 			return err
 		}
-		if s.helper != nil {
+		if s.config.Helper != nil {
 			c.Set("client", client)
-			if err := s.helper.ProcessSuccessfulClientAdd(c); err != nil {
+			if err := s.config.Helper.ProcessSuccessfulClientAdd(c); err != nil {
 				return err
 			}
 		}
 		gmap["client"] = client
+		gmap["token"] = token
 		return gctx.Render(http.StatusOK, "oauth2_add_client_done.html", gmap)
 	}
 	return gctx.Render(http.StatusOK, "oauth2_add_client.html", gmap)

          
@@ 324,21 328,181 @@ func (s *Service) AuthorizePOST(c echo.C
 	return oauth2Redirect(c, client.RedirectURL, gmap)
 }
 
-// Introspect ...
-func (s *Service) Introspect(c echo.Context) error {
+func (s *Service) accessTokenError(c echo.Context, err, desc string, status int) error {
+	if status == 0 {
+		status = http.StatusBadRequest // 400
+	}
+	retErr := struct {
+		Err  string `json:"error"`
+		Desc string `json:"error_description"`
+		URI  string `json:"error_url"`
+	}{
+		Err:  err,
+		Desc: desc,
+		URI:  s.config.DocumentationURL,
+	}
+	return c.JSON(status, &retErr)
+}
+
+// AccessTokenPOST ...
+func (s *Service) AccessTokenPOST(c echo.Context) error {
 	req := c.Request()
 	ctype := req.Header.Get("Content-Type")
 	if ctype != "application/x-www-form-urlencoded" {
-		retErr := struct {
-			Err  string `json:"error"`
-			Desc string `json:"error_description"`
-			URI  string `json:"error_url"`
-		}{
-			Err:  "invalid request",
-			Desc: "Content-Type must be application/x-www-form-urlencoded",
-			URI:  s.config.DocumentationURL,
+		return s.accessTokenError(c, "invalid_request",
+			"Content-Type must be application/x-www-form-urlencoded", 400)
+	}
+
+	params, err := c.FormParams()
+	if err != nil {
+		return err
+	}
+	grantType := params.Get("grant_type")
+	code := params.Get("code")
+	redirectURI := params.Get("redirect_uri")
+	clientID := params.Get("client_id")
+	clientSecret := params.Get("client_secret")
+
+	auth := req.Header.Get("Authorization")
+	if auth != "" && (clientID != "" || clientSecret != "") {
+		return s.accessTokenError(c, "invalid_client",
+			"Cannot supply both client_id & client_secret and Authorization header", 400)
+	} else if auth != "" {
+		parts := strings.SplitN(auth, " ", 2)
+		if len(parts) != 2 || parts[0] != "Basic" {
+			return s.accessTokenError(c, "invalid_client",
+				"Invalid Authorization header", 400)
+		}
+		bytes, err := base64.StdEncoding.DecodeString(parts[1])
+		if err != nil {
+			return s.accessTokenError(c, "invalid_client",
+				"Invalid Authorization header contents", 400)
+		}
+		auth = string(bytes)
+		if !strings.Contains(auth, ":") {
+			return s.accessTokenError(c, "invalid_client",
+				"Invalid Authorization header", 400)
+		}
+		parts = strings.SplitN(auth, ":", 2)
+		clientID, err = url.PathUnescape(parts[0])
+		if err != nil {
+			return s.accessTokenError(c, "invalid_client",
+				"Invalid Authorization header contents", 400)
+		}
+		clientSecret, err = url.PathUnescape(parts[1])
+		if err != nil {
+			return s.accessTokenError(c, "invalid_client",
+				"Invalid Authorization header contents", 400)
 		}
-		return c.JSON(http.StatusBadRequest, &retErr)
+	} else if clientID == "" || clientSecret == "" {
+		return s.accessTokenError(c, "invalid_client",
+			"Missing client authorization", 401)
+	}
+
+	if grantType == "" {
+		return s.accessTokenError(c, "invalid_request",
+			"The grant_type parameter is required", 400)
+	}
+	if grantType != "authorization_code" {
+		return s.accessTokenError(c, "unsupported_grant_type",
+			fmt.Sprintf("Unsupported grant type %s", grantType), 400)
+	}
+	if code == "" {
+		return s.accessTokenError(c, "invalid_request",
+			"The code parameter is required", 400)
+	}
+
+	opts := &database.FilterOptions{
+		Filter: sq.Eq{"code": code},
+	}
+	auths, err := GetAuthorizations(c.Request().Context(), opts)
+	if err != nil {
+		return s.accessTokenError(c, "server_error",
+			"server error occurred, try again.", 400)
+	}
+	if len(auths) == 0 {
+		return s.accessTokenError(c, "invalid_request",
+			"Invalid authorization code", 400)
+	}
+	authCode := auths[0]
+	if authCode.IsExpired() {
+		authCode.Delete(c.Request().Context())
+		return s.accessTokenError(c, "invalid_request",
+			"Authorization code expired", 400)
+	}
+
+	var payload AuthorizationPayload
+	if err = json.Unmarshal([]byte(authCode.Payload), &payload); err != nil {
+		panic(err)
+	}
+
+	issued := time.Now().UTC()
+	expires := issued.Add(366 * 24 * time.Hour)
+
+	client, err := GetClientByID(c.Request().Context(), clientID)
+	if err != nil {
+		return s.accessTokenError(c, "invalid_request",
+			"Invalid client id", 400)
+	}
+	if !client.VerifyClientSecret(clientSecret) {
+		return s.accessTokenError(c, "invalid_request",
+			"Invalid client secret", 400)
+	}
+
+	if redirectURI != "" && redirectURI != client.RedirectURL {
+		return s.accessTokenError(c, "invalid_request",
+			"Invalid redirect URI", 400)
+	}
+
+	bt := BearerToken{
+		Version:  TokenVersion,
+		Issued:   ToTimestamp(issued),
+		Expires:  ToTimestamp(expires),
+		Grants:   payload.Grants,
+		UserID:   payload.UserID,
+		ClientID: payload.ClientKey,
+	}
+
+	token := bt.Encode(c.Request().Context())
+	hash := sha512.Sum512([]byte(token))
+	tokenHash := hex.EncodeToString(hash[:])
+
+	grant := &Grant{
+		Issued:    issued,
+		Expires:   expires,
+		TokenHash: tokenHash,
+		UserID:    payload.UserID,
+		ClientID:  sql.NullInt64{Int64: int64(client.ID), Valid: true},
+	}
+	if err := grant.Store(c.Request().Context()); err != nil {
+		return s.accessTokenError(c, "server_error",
+			"server error occurred storing token, try again.", 400)
+	}
+
+	authCode.Delete(c.Request().Context())
+
+	ret := struct {
+		Token   string `json:"access_token"`
+		Type    string `json:"token_type"`
+		Expires int    `json:"expires_in"`
+		Scope   string `json:"scope"`
+	}{
+		Token:   token,
+		Type:    "bearer",
+		Expires: int(expires.Sub(time.Now().UTC()).Seconds()),
+		Scope:   payload.Grants,
+	}
+
+	return c.JSON(http.StatusOK, &ret)
+}
+
+// IntrospectPOST ...
+func (s *Service) IntrospectPOST(c echo.Context) error {
+	req := c.Request()
+	ctype := req.Header.Get("Content-Type")
+	if ctype != "application/x-www-form-urlencoded" {
+		return s.accessTokenError(c, "invalid_request",
+			"Content-Type must be application/x-www-form-urlencoded", 400)
 	}
 
 	retFalse := struct {

          
@@ 378,11 542,11 @@ func (s *Service) RouteName(value string
 }
 
 // NewService return service
-func NewService(eg *echo.Group, name string, helper Helper, config *ServiceConfig) *Service {
+func NewService(eg *echo.Group, name string, config *ServiceConfig) *Service {
 	if name == "" {
 		name = "oauth2"
 	}
-	service := &Service{name: name, eg: eg, helper: helper, config: config}
+	service := &Service{name: name, eg: eg, config: config}
 	service.RegisterRoutes()
 	return service
 }