Adding oauth2 authorization handlers
8 files changed, 350 insertions(+), 7 deletions(-)

M bearer.go
M logic.go
M models.go
A => oauth2_authorizations.go
M oauth2_clients.go
M oauth2_grants.go
M routes.go
M schema.sql
M bearer.go +40 -0
@@ 136,6 136,9 @@ func DecodeGrants(grants string) Grants 
 			access = "RO"
 		} else {
 			access = parts[1]
+			if access != "RW" && access != "RO" {
+				access = "RO"
+			}
 		}
 		accessMap[scope] = access
 	}

          
@@ 174,6 177,35 @@ func (g *Grants) Encode() string {
 	return g.encoded
 }
 
+// Validate ...
+func (g *Grants) Validate(scopes []string) []string {
+	var errors []string
+	for k := range g.grants {
+		if !contains(scopes, k) {
+			errors = append(errors, fmt.Sprintf("Invalid scope: %s", k))
+		}
+	}
+	return errors
+}
+
+// List ...
+func (g *Grants) List() []string {
+	var grants []string
+	for k, v := range g.grants {
+		grants = append(grants, fmt.Sprintf("%s:%s", k, v))
+	}
+	return grants
+}
+
+func contains(values []string, str string) bool {
+	for _, v := range values {
+		if v == str {
+			return true
+		}
+	}
+	return false
+}
+
 // TokenUser wrapper for gobwebs.User and token grants
 type TokenUser struct {
 	User      gobwebs.User

          
@@ 181,3 213,11 @@ type TokenUser struct {
 	Grants    *Grants
 	TokenHash [64]byte
 }
+
+// AuthorizationPayload holds temporary approval while the oauth2 cycle is
+// in process
+type AuthorizationPayload struct {
+	Grants    string
+	ClientKey string
+	UserID    int
+}

          
M logic.go +4 -4
@@ 25,7 25,10 @@ func OAuth2(ctx context.Context, token s
 	hash := sha512.Sum512([]byte(token))
 	hashStr := hex.EncodeToString(hash[:])
 	opts := &database.FilterOptions{
-		Filter: sq.Eq{"token_hash": hashStr},
+		Filter: sq.And{
+			sq.Eq{"token_hash": hashStr},
+			sq.Expr("expires > NOW() at time zone 'UTC'"),
+		},
 	}
 	grants, err := GetGrants(ctx, opts)
 	if err != nil {

          
@@ 39,9 42,6 @@ func OAuth2(ctx context.Context, token s
 		return nil, fmt.Errorf("Error with provided OAuth 2.0 bearer token")
 	}
 	grant := grants[0]
-	if grant.IsExpired() {
-		return nil, fmt.Errorf("Invalid or expired OAuth 2.0 bearer token")
-	}
 
 	bt.Issued = ToTimestamp(grant.Issued)
 	gt := DecodeGrants(bt.Grants)

          
M models.go +8 -0
@@ 31,3 31,11 @@ type Grant struct {
 	UserID    int           `db:"user_id"`
 	ClientID  sql.NullInt64 `db:"client_id"`
 }
+
+// Authorization ...
+type Authorization struct {
+	ID        int       `db:"id"`
+	Code      string    `db:"code"`
+	Payload   string    `db:"payload"`
+	CreatedOn time.Time `db:"created_on"`
+}

          
A => oauth2_authorizations.go +92 -0
@@ 0,0 1,92 @@ 
+package oauth2
+
+import (
+	"context"
+	"database/sql"
+	"fmt"
+
+	sq "github.com/Masterminds/squirrel"
+	"hg.code.netlandish.com/~netlandish/gobwebs/database"
+)
+
+// GetAuthorizations retuns oauth2 authorizations using the given filters
+func GetAuthorizations(ctx context.Context, opts *database.FilterOptions) ([]*Authorization, error) {
+	if opts == nil {
+		opts = &database.FilterOptions{}
+	}
+	auths := make([]*Authorization, 0)
+	if err := database.WithTx(ctx, database.TxOptionsRO, func(tx *sql.Tx) error {
+		q := opts.GetBuilder(nil)
+		rows, err := q.
+			Columns("id", "code", "payload", "created_on").
+			From("oauth2_authorizations").
+			PlaceholderFormat(sq.Dollar).
+			RunWith(tx).
+			QueryContext(ctx)
+		if err != nil {
+			if err == sql.ErrNoRows {
+				return nil
+			}
+			return err
+		}
+		defer rows.Close()
+
+		for rows.Next() {
+			var a Authorization
+			if err = rows.Scan(&a.ID, &a.Code, &a.Payload, &a.CreatedOn); err != nil {
+				return err
+			}
+			auths = append(auths, &a)
+		}
+		return nil
+	}); err != nil {
+		return nil, err
+	}
+	return auths, nil
+}
+
+// Store will save a client
+func (a *Authorization) Store(ctx context.Context) error {
+	err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {
+		var err error
+		if a.ID == 0 {
+			err = sq.
+				Insert("oauth2_authorizations").
+				Columns("code", "payload").
+				Values(a.Code, a.Payload).
+				Suffix(`RETURNING (id)`).
+				PlaceholderFormat(sq.Dollar).
+				RunWith(tx).
+				ScanContext(ctx, &a.ID)
+		} else {
+			err = sq.
+				Update("oauth2_authorizations").
+				Set("code", a.Code).
+				Set("payload", a.Payload).
+				Where("id = ?", a.ID).
+				Suffix(`RETURNING (id)`).
+				PlaceholderFormat(sq.Dollar).
+				RunWith(tx).
+				ScanContext(ctx, &a.ID)
+		}
+		return err
+	})
+	return err
+}
+
+// Delete will delete this rate
+func (a *Authorization) Delete(ctx context.Context) error {
+	if a.ID == 0 {
+		return fmt.Errorf("Authorization object is not populated")
+	}
+	err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {
+		_, err := sq.
+			Delete("oauth2_authorizations").
+			Where("id = ?", a.ID).
+			PlaceholderFormat(sq.Dollar).
+			RunWith(tx).
+			ExecContext(ctx)
+		return err
+	})
+	return err
+}

          
M oauth2_clients.go +15 -0
@@ 54,6 54,21 @@ func GetClients(ctx context.Context, opt
 	return clients, nil
 }
 
+// GetClientByID will fetch a client by given client id
+func GetClientByID(ctx context.Context, clientID string) (*Client, error) {
+	opts := &database.FilterOptions{
+		Filter: sq.Eq{"key": clientID},
+	}
+	clients, err := GetClients(ctx, opts)
+	if err != nil {
+		return nil, err
+	}
+	if len(clients) == 0 {
+		return nil, nil
+	}
+	return clients[0], nil
+}
+
 // Store will save a client
 func (c *Client) Store(ctx context.Context) error {
 	err := database.WithTx(ctx, nil, func(tx *sql.Tx) error {

          
M oauth2_grants.go +6 -0
@@ 85,6 85,12 @@ func (g *Grant) IsExpired() bool {
 	return time.Now().UTC().After(g.Expires.UTC())
 }
 
+// Revoke ...
+func (g *Grant) Revoke(ctx context.Context) error {
+	g.Expires = time.Now().UTC()
+	return g.Store(ctx)
+}
+
 // Delete will delete this rate
 func (g *Grant) Delete(ctx context.Context) error {
 	if g.ID == 0 {

          
M routes.go +174 -3
@@ 1,10 1,14 @@ 
 package oauth2
 
 import (
+	"crypto/rand"
 	"crypto/sha512"
 	"encoding/hex"
+	"encoding/json"
 	"fmt"
 	"net/http"
+	"net/url"
+	"strings"
 	"time"
 
 	sq "github.com/Masterminds/squirrel"

          
@@ 15,11 19,18 @@ import (
 	"hg.code.netlandish.com/~netlandish/gobwebs/server"
 )
 
+// ServiceConfig let's you add basic config variables to service
+type ServiceConfig struct {
+	DocumentationURL string
+	Scopes           []string
+}
+
 // Service is the base accounts service struct
 type Service struct {
 	name   string
 	eg     *echo.Group
 	helper Helper
+	config *ServiceConfig
 }
 
 // RegisterRoutes ...

          
@@ 33,6 44,8 @@ func (s *Service) RegisterRoutes() {
 	s.eg.GET("/clients", s.ListClients).Name = s.RouteName("list_clients")
 	s.eg.GET("/clients/add", s.AddClient).Name = s.RouteName("add_client")
 	s.eg.POST("/clients/add", s.AddClient).Name = s.RouteName("add_client_post")
+	s.eg.GET("/authorize", s.Authorize).Name = s.RouteName("authorize")
+	s.eg.POST("/authorize", s.AuthorizePOST).Name = s.RouteName("authorize_post")
 }
 
 // ListPersonal ...

          
@@ 154,6 167,163 @@ func (s *Service) AddClient(c echo.Conte
 	return gctx.Render(http.StatusOK, "oauth2_add_client.html", gmap)
 }
 
+func oauth2Redirect(c echo.Context, redirectURI string, params gobwebs.Map) error {
+	parts, err := url.Parse(redirectURI)
+	if err != nil {
+		return err
+	}
+	qs := parts.Query()
+	for k, v := range params {
+		qs.Set(k, v.(string))
+	}
+	parts.RawQuery = qs.Encode()
+	return c.Redirect(http.StatusMovedPermanently, parts.String())
+}
+
+func (s *Service) authorizeError(c echo.Context,
+	redirectURI, state, errorCode, errorDescription string) error {
+	if redirectURI == "" {
+		gctx := c.(*server.Context)
+		return gctx.Render(http.StatusOK, "oauth2_error.html", gobwebs.Map{
+			"code":        errorCode,
+			"description": errorDescription,
+		})
+	}
+	return oauth2Redirect(c, redirectURI, gobwebs.Map{
+		"error":             errorCode,
+		"error_description": errorDescription,
+		"error_uri":         s.config.DocumentationURL,
+		"state":             state,
+	})
+}
+
+// Authorize ...
+func (s *Service) Authorize(c echo.Context) error {
+	respType := c.QueryParam("response_type")
+	clientID := c.QueryParam("client_id")
+	scope := c.QueryParam("scope")
+	state := c.QueryParam("state")
+	redirectURL := c.QueryParam("redirect_uri")
+
+	if clientID == "" {
+		return s.authorizeError(c, "", state, "invalid_request",
+			"The client_id parameter is required")
+	}
+
+	client, err := GetClientByID(c.Request().Context(), clientID)
+	if err != nil {
+		return s.authorizeError(c, "", state, "server_error", err.Error())
+	}
+	if client == nil {
+		return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID")
+	}
+
+	if redirectURL != "" && redirectURL != client.RedirectURL {
+		return s.authorizeError(c, "", state, "invalid_request",
+			"The redirect_uri parameter doesn't match the registered client's")
+	}
+	if respType != "code" {
+		return s.authorizeError(c, redirectURL, state, "unsupported_response_type",
+			"The response_type parameter must be set to 'code'")
+	}
+	if scope == "" {
+		return s.authorizeError(c, redirectURL, state, "invalid_scope",
+			"The scope parameter is required")
+	}
+
+	grants := DecodeGrants(scope)
+	errors := grants.Validate(s.config.Scopes)
+	if len(errors) > 0 {
+		return s.authorizeError(c, redirectURL, state,
+			"invalid_scope", strings.Join(errors, ", "))
+	}
+
+	gctx := c.(*server.Context)
+	return gctx.Render(http.StatusOK, "oauth2_authorization.html", gobwebs.Map{
+		"client":       client,
+		"grants":       grants.List(),
+		"client_id":    clientID,
+		"redirect_uri": redirectURL,
+		"state":        state,
+	})
+}
+
+// AuthorisePOST ...
+func (s *Service) AuthorizePOST(c echo.Context) error {
+	params, err := c.FormParams()
+	if err != nil {
+		return err
+	}
+	clientID := params.Get("client_id")
+	redirectURL := params.Get("redirect_uri")
+	state := params.Get("state")
+
+	if params.Has("reject") {
+		return s.authorizeError(c, redirectURL, state, "access_denied",
+			"The resource owner denied the request.")
+	}
+
+	subgrants := []string{}
+	// XXX csrf shouldn't be hard coded here
+	skip := []string{"accept", "client_id", "redirect_uri", "state", "csrf"}
+	for grant := range params {
+		if contains(skip, grant) {
+			continue
+		}
+		subgrants = append(subgrants, grant)
+	}
+
+	grants := DecodeGrants(strings.Join(subgrants, " "))
+	errors := grants.Validate(s.config.Scopes)
+	if len(errors) > 0 {
+		return s.authorizeError(c, redirectURL, state,
+			"invalid_scope", strings.Join(errors, ", "))
+	}
+
+	client, err := GetClientByID(c.Request().Context(), clientID)
+	if err != nil {
+		return s.authorizeError(c, "", state, "server_error", err.Error())
+	}
+	if client == nil {
+		return s.authorizeError(c, "", state, "invalid_request", "Invalid client ID")
+	}
+
+	var seed [64]byte
+	gctx := c.(*server.Context)
+	n, err := rand.Read(seed[:])
+	if err != nil || n != len(seed) {
+		panic(err)
+	}
+	hash := sha512.Sum512(seed[:])
+	code := hex.EncodeToString(hash[:])[:32]
+
+	payload := AuthorizationPayload{
+		Grants:    grants.encoded,
+		ClientKey: clientID,
+		UserID:    int(gctx.User.GetID()),
+	}
+	data, err := json.Marshal(&payload)
+	if err != nil {
+		panic(err)
+	}
+
+	auth := &Authorization{
+		Code:    code,
+		Payload: string(data),
+	}
+	if err := auth.Store(c.Request().Context()); err != nil {
+		return s.authorizeError(c, "", state, "server_error", err.Error())
+	}
+
+	gmap := gobwebs.Map{
+		"code": code,
+	}
+	if state != "" {
+		gmap["state"] = state
+	}
+	return oauth2Redirect(c, client.RedirectURL, gmap)
+}
+
 // Introspect ...
 func (s *Service) Introspect(c echo.Context) error {
 	req := c.Request()

          
@@ 162,10 332,11 @@ func (s *Service) Introspect(c echo.Cont
 		retErr := struct {
 			Err  string `json:"error"`
 			Desc string `json:"error_description"`
-			URI  string `json:"error_url"` // TODO Make this customizable
+			URI  string `json:"error_url"`
 		}{
 			Err:  "invalid request",
 			Desc: "Content-Type must be application/x-www-form-urlencoded",
+			URI:  s.config.DocumentationURL,
 		}
 		return c.JSON(http.StatusBadRequest, &retErr)
 	}

          
@@ 207,11 378,11 @@ func (s *Service) RouteName(value string
 }
 
 // NewService return service
-func NewService(eg *echo.Group, name string, helper Helper) *Service {
+func NewService(eg *echo.Group, name string, helper Helper, config *ServiceConfig) *Service {
 	if name == "" {
 		name = "oauth2"
 	}
-	service := &Service{name: name, eg: eg, helper: helper}
+	service := &Service{name: name, eg: eg, helper: helper, config: config}
 	service.RegisterRoutes()
 	return service
 }

          
M schema.sql +11 -0
@@ 45,3 45,14 @@ CREATE TABLE oauth2_grants (
 CREATE INDEX oauth2_grants_id_idx ON oauth2_grants (id);
 CREATE INDEX oauth2_grants_user_id_idx ON oauth2_grants (user_id);
 CREATE INDEX oauth2_grants_client_id_idx ON oauth2_grants (client_id);
+
+
+CREATE TABLE oauth2_authorizations (
+        id serial PRIMARY KEY,
+	code character varying(128) NOT NULL,
+	payload character varying NOT NULL,
+	created_on TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP
+);
+
+CREATE INDEX oauth2_authorizations_id_idx ON oauth2_authorizations (id);
+CREATE INDEX oauth2_authorizations_code_idx ON oauth2_authorizations (code);