diff --git a/http/handler.go b/http/handler.go index d493fbe3e9..aef85e9756 100644 --- a/http/handler.go +++ b/http/handler.go @@ -97,11 +97,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l respondStandby(core, w, rawReq.URL) return resp, false } - if respondCommon(w, resp, err) { - return resp, false - } - if err != nil { - respondErrorStatus(w, err) + if respondErrorCommon(w, resp, err) { return resp, false } @@ -192,18 +188,6 @@ func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, er return req, nil } -// Determines the type of the error being returned and sets the HTTP -// status code appropriately -func respondErrorStatus(w http.ResponseWriter, err error) { - status := http.StatusInternalServerError - switch { - // Keep adding more error types here to appropriate the status codes - case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)): - status = http.StatusBadRequest - } - respondError(w, status, err) -} - func respondError(w http.ResponseWriter, status int, err error) { // Adjust status code when sealed if errwrap.Contains(err, vault.ErrSealed.Error()) { @@ -227,33 +211,43 @@ func respondError(w http.ResponseWriter, status int, err error) { enc.Encode(resp) } -func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) bool { - if resp == nil { +func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool { + // If there are no errors return + if err == nil && (resp == nil || !resp.IsError()) { return false } - if resp.IsError() { - statusCode := http.StatusBadRequest - - if err != nil { - switch err { - case logical.ErrPermissionDenied: - statusCode = http.StatusForbidden - case logical.ErrUnsupportedOperation: - statusCode = http.StatusMethodNotAllowed - case logical.ErrUnsupportedPath: - statusCode = http.StatusNotFound - case logical.ErrInvalidRequest: - statusCode = http.StatusBadRequest - } - } - - err := fmt.Errorf("%s", resp.Data["error"].(string)) - respondError(w, statusCode, err) - return true + // Start out with internal server error since in most of these cases there + // won't be a response so this won't be overridden + statusCode := http.StatusInternalServerError + // If we actually have a response, start out with bad request + if resp != nil { + statusCode = http.StatusBadRequest } - return false + // Now, check the error itself; if it has a specific logical error, set the + // appropriate code + if err != nil { + switch { + case errwrap.ContainsType(err, new(vault.StatusBadRequest)): + statusCode = http.StatusBadRequest + case errwrap.Contains(err, logical.ErrPermissionDenied.Error()): + statusCode = http.StatusForbidden + case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()): + statusCode = http.StatusMethodNotAllowed + case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()): + statusCode = http.StatusNotFound + case errwrap.Contains(err, logical.ErrInvalidRequest.Error()): + statusCode = http.StatusBadRequest + } + } + + if resp != nil { + err = fmt.Errorf("%s", resp.Data["error"].(string)) + } + + respondError(w, statusCode, err) + return true } func respondOk(w http.ResponseWriter, body interface{}) { diff --git a/http/logical_test.go b/http/logical_test.go index c498140d2b..0a8296ba3d 100644 --- a/http/logical_test.go +++ b/http/logical_test.go @@ -30,8 +30,11 @@ func TestLogical(t *testing.T) { testResponseStatus(t, resp, 204) // READ - resp = testHttpGet(t, token, addr+"/v1/secret/foo") + // Bad token should return a 403 + resp = testHttpGet(t, token+"bad", addr+"/v1/secret/foo") + testResponseStatus(t, resp, 403) + resp = testHttpGet(t, token, addr+"/v1/secret/foo") var actual map[string]interface{} var nilWarnings interface{} expected := map[string]interface{}{