mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-31 18:48:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			160 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			160 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package api
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net/http"
 | |
| 
 | |
| 	"github.com/hashicorp/vault/sdk/helper/consts"
 | |
| )
 | |
| 
 | |
// RaftJoinResponse is the payload decoded from the reply of the raft join
// API.
type RaftJoinResponse struct {
	// Joined is carried in the "joined" JSON key.
	Joined bool `json:"joined"`
}
 | |
| 
 | |
// RaftJoinRequest holds the parameters that are JSON-encoded and sent to the
// raft join API.
type RaftJoinRequest struct {
	// LeaderAPIAddr is sent as "leader_api_addr"; it identifies the leader
	// of the raft cluster to join.
	LeaderAPIAddr string `json:"leader_api_addr"`

	// LeaderCACert is sent as "leader_ca_cert".
	// NOTE(review): presumably the CA certificate used to verify the leader;
	// confirm against the Vault API docs.
	LeaderCACert string `json:"leader_ca_cert"`

	// LeaderClientCert is sent as "leader_client_cert".
	// NOTE(review): presumably the client certificate presented to the
	// leader; confirm against the Vault API docs.
	LeaderClientCert string `json:"leader_client_cert"`

	// LeaderClientKey is sent as "leader_client_key".
	// NOTE(review): presumably the private key matching LeaderClientCert;
	// confirm against the Vault API docs.
	LeaderClientKey string `json:"leader_client_key"`

	// Retry is sent as "retry".
	// NOTE(review): presumably asks the server to keep retrying the join on
	// failure; confirm against the Vault API docs.
	Retry bool `json:"retry"`

	// NonVoter is sent as "non_voter".
	// NOTE(review): presumably joins the node as a non-voting member;
	// confirm against the Vault API docs.
	NonVoter bool `json:"non_voter"`
}
 | |
| 
 | |
| // RaftJoin adds the node from which this call is invoked from to the raft
 | |
| // cluster represented by the leader address in the parameter.
 | |
| func (c *Sys) RaftJoin(opts *RaftJoinRequest) (*RaftJoinResponse, error) {
 | |
| 	r := c.c.NewRequest("POST", "/v1/sys/storage/raft/join")
 | |
| 
 | |
| 	if err := r.SetJSONBody(opts); err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	ctx, cancelFunc := context.WithCancel(context.Background())
 | |
| 	defer cancelFunc()
 | |
| 	resp, err := c.c.RawRequestWithContext(ctx, r)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	defer resp.Body.Close()
 | |
| 
 | |
| 	var result RaftJoinResponse
 | |
| 	err = resp.DecodeJSON(&result)
 | |
| 	return &result, err
 | |
| }
 | |
| 
 | |
| // RaftSnapshot invokes the API that takes the snapshot of the raft cluster and
 | |
| // writes it to the supplied io.Writer.
 | |
| func (c *Sys) RaftSnapshot(snapWriter io.Writer) error {
 | |
| 	r := c.c.NewRequest("GET", "/v1/sys/storage/raft/snapshot")
 | |
| 	r.URL.RawQuery = r.Params.Encode()
 | |
| 
 | |
| 	req, err := http.NewRequest(http.MethodGet, r.URL.RequestURI(), nil)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	req.URL.User = r.URL.User
 | |
| 	req.URL.Scheme = r.URL.Scheme
 | |
| 	req.URL.Host = r.URL.Host
 | |
| 	req.Host = r.URL.Host
 | |
| 
 | |
| 	if r.Headers != nil {
 | |
| 		for header, vals := range r.Headers {
 | |
| 			for _, val := range vals {
 | |
| 				req.Header.Add(header, val)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if len(r.ClientToken) != 0 {
 | |
| 		req.Header.Set(consts.AuthHeaderName, r.ClientToken)
 | |
| 	}
 | |
| 
 | |
| 	if len(r.WrapTTL) != 0 {
 | |
| 		req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL)
 | |
| 	}
 | |
| 
 | |
| 	if len(r.MFAHeaderVals) != 0 {
 | |
| 		for _, mfaHeaderVal := range r.MFAHeaderVals {
 | |
| 			req.Header.Add("X-Vault-MFA", mfaHeaderVal)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	if r.PolicyOverride {
 | |
| 		req.Header.Set("X-Vault-Policy-Override", "true")
 | |
| 	}
 | |
| 
 | |
| 	// Avoiding the use of RawRequestWithContext which reads the response body
 | |
| 	// to determine if the body contains error message.
 | |
| 	var result *Response
 | |
| 	resp, err := c.c.config.HttpClient.Do(req)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	if resp == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	// Check for a redirect, only allowing for a single redirect
 | |
| 	if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 {
 | |
| 		// Parse the updated location
 | |
| 		respLoc, err := resp.Location()
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		// Ensure a protocol downgrade doesn't happen
 | |
| 		if req.URL.Scheme == "https" && respLoc.Scheme != "https" {
 | |
| 			return fmt.Errorf("redirect would cause protocol downgrade")
 | |
| 		}
 | |
| 
 | |
| 		// Update the request
 | |
| 		req.URL = respLoc
 | |
| 
 | |
| 		// Retry the request
 | |
| 		resp, err = c.c.config.HttpClient.Do(req)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	result = &Response{Response: resp}
 | |
| 	if err := result.Error(); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	_, err = io.Copy(snapWriter, resp.Body)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that
 | |
| // snapshot, returning the cluster to the state defined by it.
 | |
| func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error {
 | |
| 	path := "/v1/sys/storage/raft/snapshot"
 | |
| 	if force {
 | |
| 		path = "/v1/sys/storage/raft/snapshot-force"
 | |
| 	}
 | |
| 	r := c.c.NewRequest("POST", path)
 | |
| 
 | |
| 	r.Body = snapReader
 | |
| 
 | |
| 	ctx, cancelFunc := context.WithCancel(context.Background())
 | |
| 	defer cancelFunc()
 | |
| 	resp, err := c.c.RawRequestWithContext(ctx, r)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	defer resp.Body.Close()
 | |
| 
 | |
| 	return nil
 | |
| }
 | 
