mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 10:37:56 +00:00 
			
		
		
		
	Fix up error detection regression to return correct status codes
This commit is contained in:
		| @@ -97,11 +97,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l | ||||
| 		respondStandby(core, w, rawReq.URL) | ||||
| 		return resp, false | ||||
| 	} | ||||
| 	if respondCommon(w, resp, err) { | ||||
| 		return resp, false | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		respondErrorStatus(w, err) | ||||
| 	if respondErrorCommon(w, resp, err) { | ||||
| 		return resp, false | ||||
| 	} | ||||
|  | ||||
| @@ -192,18 +188,6 @@ func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, er | ||||
| 	return req, nil | ||||
| } | ||||
|  | ||||
| // Determines the type of the error being returned and sets the HTTP | ||||
| // status code appropriately | ||||
| func respondErrorStatus(w http.ResponseWriter, err error) { | ||||
| 	status := http.StatusInternalServerError | ||||
| 	switch { | ||||
| 	// Keep adding more error types here to appropriate the status codes | ||||
| 	case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)): | ||||
| 		status = http.StatusBadRequest | ||||
| 	} | ||||
| 	respondError(w, status, err) | ||||
| } | ||||
|  | ||||
| func respondError(w http.ResponseWriter, status int, err error) { | ||||
| 	// Adjust status code when sealed | ||||
| 	if errwrap.Contains(err, vault.ErrSealed.Error()) { | ||||
| @@ -227,33 +211,43 @@ func respondError(w http.ResponseWriter, status int, err error) { | ||||
| 	enc.Encode(resp) | ||||
| } | ||||
|  | ||||
| func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) bool { | ||||
| 	if resp == nil { | ||||
| func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool { | ||||
| 	// If there are no errors return | ||||
| 	if err == nil && (resp == nil || !resp.IsError()) { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	if resp.IsError() { | ||||
| 		statusCode := http.StatusBadRequest | ||||
|  | ||||
| 		if err != nil { | ||||
| 			switch err { | ||||
| 			case logical.ErrPermissionDenied: | ||||
| 				statusCode = http.StatusForbidden | ||||
| 			case logical.ErrUnsupportedOperation: | ||||
| 				statusCode = http.StatusMethodNotAllowed | ||||
| 			case logical.ErrUnsupportedPath: | ||||
| 				statusCode = http.StatusNotFound | ||||
| 			case logical.ErrInvalidRequest: | ||||
| 				statusCode = http.StatusBadRequest | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		err := fmt.Errorf("%s", resp.Data["error"].(string)) | ||||
| 		respondError(w, statusCode, err) | ||||
| 		return true | ||||
| 	// Start out with internal server error since in most of these cases there | ||||
| 	// won't be a response so this won't be overridden | ||||
| 	statusCode := http.StatusInternalServerError | ||||
| 	// If we actually have a response, start out with bad request | ||||
| 	if resp != nil { | ||||
| 		statusCode = http.StatusBadRequest | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| 	// Now, check the error itself; if it has a specific logical error, set the | ||||
| 	// appropriate code | ||||
| 	if err != nil { | ||||
| 		switch { | ||||
| 		case errwrap.ContainsType(err, new(vault.StatusBadRequest)): | ||||
| 			statusCode = http.StatusBadRequest | ||||
| 		case errwrap.Contains(err, logical.ErrPermissionDenied.Error()): | ||||
| 			statusCode = http.StatusForbidden | ||||
| 		case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()): | ||||
| 			statusCode = http.StatusMethodNotAllowed | ||||
| 		case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()): | ||||
| 			statusCode = http.StatusNotFound | ||||
| 		case errwrap.Contains(err, logical.ErrInvalidRequest.Error()): | ||||
| 			statusCode = http.StatusBadRequest | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if resp != nil { | ||||
| 		err = fmt.Errorf("%s", resp.Data["error"].(string)) | ||||
| 	} | ||||
|  | ||||
| 	respondError(w, statusCode, err) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func respondOk(w http.ResponseWriter, body interface{}) { | ||||
|   | ||||
| @@ -30,8 +30,11 @@ func TestLogical(t *testing.T) { | ||||
| 	testResponseStatus(t, resp, 204) | ||||
|  | ||||
| 	// READ | ||||
| 	resp = testHttpGet(t, token, addr+"/v1/secret/foo") | ||||
| 	// Bad token should return a 403 | ||||
| 	resp = testHttpGet(t, token+"bad", addr+"/v1/secret/foo") | ||||
| 	testResponseStatus(t, resp, 403) | ||||
|  | ||||
| 	resp = testHttpGet(t, token, addr+"/v1/secret/foo") | ||||
| 	var actual map[string]interface{} | ||||
| 	var nilWarnings interface{} | ||||
| 	expected := map[string]interface{}{ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jeff Mitchell
					Jeff Mitchell