mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	Fix up error detection regression to return correct status codes
This commit is contained in:
		| @@ -97,11 +97,7 @@ func request(core *vault.Core, w http.ResponseWriter, rawReq *http.Request, r *l | |||||||
| 		respondStandby(core, w, rawReq.URL) | 		respondStandby(core, w, rawReq.URL) | ||||||
| 		return resp, false | 		return resp, false | ||||||
| 	} | 	} | ||||||
| 	if respondCommon(w, resp, err) { | 	if respondErrorCommon(w, resp, err) { | ||||||
| 		return resp, false |  | ||||||
| 	} |  | ||||||
| 	if err != nil { |  | ||||||
| 		respondErrorStatus(w, err) |  | ||||||
| 		return resp, false | 		return resp, false | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| @@ -192,18 +188,6 @@ func requestWrapTTL(r *http.Request, req *logical.Request) (*logical.Request, er | |||||||
| 	return req, nil | 	return req, nil | ||||||
| } | } | ||||||
|  |  | ||||||
| // Determines the type of the error being returned and sets the HTTP |  | ||||||
| // status code appropriately |  | ||||||
| func respondErrorStatus(w http.ResponseWriter, err error) { |  | ||||||
| 	status := http.StatusInternalServerError |  | ||||||
| 	switch { |  | ||||||
| 	// Keep adding more error types here to appropriate the status codes |  | ||||||
| 	case err != nil && errwrap.ContainsType(err, new(vault.StatusBadRequest)): |  | ||||||
| 		status = http.StatusBadRequest |  | ||||||
| 	} |  | ||||||
| 	respondError(w, status, err) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func respondError(w http.ResponseWriter, status int, err error) { | func respondError(w http.ResponseWriter, status int, err error) { | ||||||
| 	// Adjust status code when sealed | 	// Adjust status code when sealed | ||||||
| 	if errwrap.Contains(err, vault.ErrSealed.Error()) { | 	if errwrap.Contains(err, vault.ErrSealed.Error()) { | ||||||
| @@ -227,33 +211,43 @@ func respondError(w http.ResponseWriter, status int, err error) { | |||||||
| 	enc.Encode(resp) | 	enc.Encode(resp) | ||||||
| } | } | ||||||
|  |  | ||||||
| func respondCommon(w http.ResponseWriter, resp *logical.Response, err error) bool { | func respondErrorCommon(w http.ResponseWriter, resp *logical.Response, err error) bool { | ||||||
| 	if resp == nil { | 	// If there are no errors return | ||||||
|  | 	if err == nil && (resp == nil || !resp.IsError()) { | ||||||
| 		return false | 		return false | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	if resp.IsError() { | 	// Start out with internal server error since in most of these cases there | ||||||
| 		statusCode := http.StatusBadRequest | 	// won't be a response so this won't be overridden | ||||||
|  | 	statusCode := http.StatusInternalServerError | ||||||
| 		if err != nil { | 	// If we actually have a response, start out with bad request | ||||||
| 			switch err { | 	if resp != nil { | ||||||
| 			case logical.ErrPermissionDenied: | 		statusCode = http.StatusBadRequest | ||||||
| 				statusCode = http.StatusForbidden |  | ||||||
| 			case logical.ErrUnsupportedOperation: |  | ||||||
| 				statusCode = http.StatusMethodNotAllowed |  | ||||||
| 			case logical.ErrUnsupportedPath: |  | ||||||
| 				statusCode = http.StatusNotFound |  | ||||||
| 			case logical.ErrInvalidRequest: |  | ||||||
| 				statusCode = http.StatusBadRequest |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		err := fmt.Errorf("%s", resp.Data["error"].(string)) |  | ||||||
| 		respondError(w, statusCode, err) |  | ||||||
| 		return true |  | ||||||
| 	} | 	} | ||||||
|  |  | ||||||
| 	return false | 	// Now, check the error itself; if it has a specific logical error, set the | ||||||
|  | 	// appropriate code | ||||||
|  | 	if err != nil { | ||||||
|  | 		switch { | ||||||
|  | 		case errwrap.ContainsType(err, new(vault.StatusBadRequest)): | ||||||
|  | 			statusCode = http.StatusBadRequest | ||||||
|  | 		case errwrap.Contains(err, logical.ErrPermissionDenied.Error()): | ||||||
|  | 			statusCode = http.StatusForbidden | ||||||
|  | 		case errwrap.Contains(err, logical.ErrUnsupportedOperation.Error()): | ||||||
|  | 			statusCode = http.StatusMethodNotAllowed | ||||||
|  | 		case errwrap.Contains(err, logical.ErrUnsupportedPath.Error()): | ||||||
|  | 			statusCode = http.StatusNotFound | ||||||
|  | 		case errwrap.Contains(err, logical.ErrInvalidRequest.Error()): | ||||||
|  | 			statusCode = http.StatusBadRequest | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if resp != nil { | ||||||
|  | 		err = fmt.Errorf("%s", resp.Data["error"].(string)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	respondError(w, statusCode, err) | ||||||
|  | 	return true | ||||||
| } | } | ||||||
|  |  | ||||||
| func respondOk(w http.ResponseWriter, body interface{}) { | func respondOk(w http.ResponseWriter, body interface{}) { | ||||||
|   | |||||||
| @@ -30,8 +30,11 @@ func TestLogical(t *testing.T) { | |||||||
| 	testResponseStatus(t, resp, 204) | 	testResponseStatus(t, resp, 204) | ||||||
|  |  | ||||||
| 	// READ | 	// READ | ||||||
| 	resp = testHttpGet(t, token, addr+"/v1/secret/foo") | 	// Bad token should return a 403 | ||||||
|  | 	resp = testHttpGet(t, token+"bad", addr+"/v1/secret/foo") | ||||||
|  | 	testResponseStatus(t, resp, 403) | ||||||
|  |  | ||||||
|  | 	resp = testHttpGet(t, token, addr+"/v1/secret/foo") | ||||||
| 	var actual map[string]interface{} | 	var actual map[string]interface{} | ||||||
| 	var nilWarnings interface{} | 	var nilWarnings interface{} | ||||||
| 	expected := map[string]interface{}{ | 	expected := map[string]interface{}{ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Jeff Mitchell
					Jeff Mitchell