mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-10-30 18:17:55 +00:00 
			
		
		
		
	Refactor Code Focused on DevTLS Mode into New Function (#20376)
* refactor code focused on DevTLS mode into new function * add tests for configureDevTLS function * replace testcase comments with fields in testcase struct
This commit is contained in:
		| @@ -930,6 +930,69 @@ func (c *ServerCommand) InitListeners(config *server.Config, disableClustering b | |||||||
| 	return 0, lns, clusterAddrs, nil | 	return 0, lns, clusterAddrs, nil | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func configureDevTLS(c *ServerCommand) (func(), *server.Config, string, error) { | ||||||
|  | 	var devStorageType string | ||||||
|  |  | ||||||
|  | 	switch { | ||||||
|  | 	case c.flagDevConsul: | ||||||
|  | 		devStorageType = "consul" | ||||||
|  | 	case c.flagDevHA && c.flagDevTransactional: | ||||||
|  | 		devStorageType = "inmem_transactional_ha" | ||||||
|  | 	case !c.flagDevHA && c.flagDevTransactional: | ||||||
|  | 		devStorageType = "inmem_transactional" | ||||||
|  | 	case c.flagDevHA && !c.flagDevTransactional: | ||||||
|  | 		devStorageType = "inmem_ha" | ||||||
|  | 	default: | ||||||
|  | 		devStorageType = "inmem" | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var certDir string | ||||||
|  | 	var err error | ||||||
|  | 	var config *server.Config | ||||||
|  | 	var f func() | ||||||
|  |  | ||||||
|  | 	if c.flagDevTLS { | ||||||
|  | 		if c.flagDevTLSCertDir != "" { | ||||||
|  | 			if _, err = os.Stat(c.flagDevTLSCertDir); err != nil { | ||||||
|  | 				return nil, nil, "", err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			certDir = c.flagDevTLSCertDir | ||||||
|  | 		} else { | ||||||
|  | 			if certDir, err = os.MkdirTemp("", "vault-tls"); err != nil { | ||||||
|  | 				return nil, nil, certDir, err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		config, err = server.DevTLSConfig(devStorageType, certDir) | ||||||
|  |  | ||||||
|  | 		f = func() { | ||||||
|  | 			if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)); err != nil { | ||||||
|  | 				c.UI.Error(err.Error()) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename)); err != nil { | ||||||
|  | 				c.UI.Error(err.Error()) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename)); err != nil { | ||||||
|  | 				c.UI.Error(err.Error()) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// Only delete temp directories we made. | ||||||
|  | 			if c.flagDevTLSCertDir == "" { | ||||||
|  | 				if err := os.Remove(certDir); err != nil { | ||||||
|  | 					c.UI.Error(err.Error()) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 	} else { | ||||||
|  | 		config, err = server.DevConfig(devStorageType) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return f, config, certDir, err | ||||||
|  | } | ||||||
|  |  | ||||||
| func (c *ServerCommand) Run(args []string) int { | func (c *ServerCommand) Run(args []string) int { | ||||||
| 	f := c.Flags() | 	f := c.Flags() | ||||||
|  |  | ||||||
| @@ -970,68 +1033,11 @@ func (c *ServerCommand) Run(args []string) int { | |||||||
|  |  | ||||||
| 	// Load the configuration | 	// Load the configuration | ||||||
| 	var config *server.Config | 	var config *server.Config | ||||||
| 	var err error |  | ||||||
| 	var certDir string | 	var certDir string | ||||||
| 	if c.flagDev { | 	if c.flagDev { | ||||||
| 		var devStorageType string | 		df, cfg, dir, err := configureDevTLS(c) | ||||||
| 		switch { | 		if df != nil { | ||||||
| 		case c.flagDevConsul: | 			defer df() | ||||||
| 			devStorageType = "consul" |  | ||||||
| 		case c.flagDevHA && c.flagDevTransactional: |  | ||||||
| 			devStorageType = "inmem_transactional_ha" |  | ||||||
| 		case !c.flagDevHA && c.flagDevTransactional: |  | ||||||
| 			devStorageType = "inmem_transactional" |  | ||||||
| 		case c.flagDevHA && !c.flagDevTransactional: |  | ||||||
| 			devStorageType = "inmem_ha" |  | ||||||
| 		default: |  | ||||||
| 			devStorageType = "inmem" |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if c.flagDevTLS { |  | ||||||
| 			if c.flagDevTLSCertDir != "" { |  | ||||||
| 				_, err := os.Stat(c.flagDevTLSCertDir) |  | ||||||
| 				if err != nil { |  | ||||||
| 					c.UI.Error(err.Error()) |  | ||||||
| 					return 1 |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				certDir = c.flagDevTLSCertDir |  | ||||||
| 			} else { |  | ||||||
| 				certDir, err = os.MkdirTemp("", "vault-tls") |  | ||||||
| 				if err != nil { |  | ||||||
| 					c.UI.Error(err.Error()) |  | ||||||
| 					return 1 |  | ||||||
| 				} |  | ||||||
| 			} |  | ||||||
| 			config, err = server.DevTLSConfig(devStorageType, certDir) |  | ||||||
|  |  | ||||||
| 			defer func() { |  | ||||||
| 				err := os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCAFilename)) |  | ||||||
| 				if err != nil { |  | ||||||
| 					c.UI.Error(err.Error()) |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevCertFilename)) |  | ||||||
| 				if err != nil { |  | ||||||
| 					c.UI.Error(err.Error()) |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				err = os.Remove(fmt.Sprintf("%s/%s", certDir, server.VaultDevKeyFilename)) |  | ||||||
| 				if err != nil { |  | ||||||
| 					c.UI.Error(err.Error()) |  | ||||||
| 				} |  | ||||||
|  |  | ||||||
| 				// Only delete temp directories we made. |  | ||||||
| 				if c.flagDevTLSCertDir == "" { |  | ||||||
| 					err = os.Remove(certDir) |  | ||||||
| 					if err != nil { |  | ||||||
| 						c.UI.Error(err.Error()) |  | ||||||
| 					} |  | ||||||
| 				} |  | ||||||
| 			}() |  | ||||||
|  |  | ||||||
| 		} else { |  | ||||||
| 			config, err = server.DevConfig(devStorageType) |  | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -1039,6 +1045,9 @@ func (c *ServerCommand) Run(args []string) int { | |||||||
| 			return 1 | 			return 1 | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | 		config = cfg | ||||||
|  | 		certDir = dir | ||||||
|  |  | ||||||
| 		if c.flagDevListenAddr != "" { | 		if c.flagDevListenAddr != "" { | ||||||
| 			config.Listeners[0].Address = c.flagDevListenAddr | 			config.Listeners[0].Address = c.flagDevListenAddr | ||||||
| 		} | 		} | ||||||
|   | |||||||
| @@ -330,3 +330,66 @@ func TestServer_DevTLS(t *testing.T) { | |||||||
| 	require.Equal(t, 0, retCode, output) | 	require.Equal(t, 0, retCode, output) | ||||||
| 	require.Contains(t, output, `tls: "enabled"`) | 	require.Contains(t, output, `tls: "enabled"`) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // TestConfigureDevTLS verifies the various logic paths that flow through the | ||||||
|  | // configureDevTLS function. | ||||||
|  | func TestConfigureDevTLS(t *testing.T) { | ||||||
|  | 	testcases := []struct { | ||||||
|  | 		ServerCommand   *ServerCommand | ||||||
|  | 		DeferFuncNotNil bool | ||||||
|  | 		ConfigNotNil    bool | ||||||
|  | 		TLSDisable      bool | ||||||
|  | 		CertPathEmpty   bool | ||||||
|  | 		ErrNotNil       bool | ||||||
|  | 		TestDescription string | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			ServerCommand: &ServerCommand{ | ||||||
|  | 				flagDevTLS: false, | ||||||
|  | 			}, | ||||||
|  | 			ConfigNotNil:    true, | ||||||
|  | 			TLSDisable:      true, | ||||||
|  | 			CertPathEmpty:   true, | ||||||
|  | 			ErrNotNil:       false, | ||||||
|  | 			TestDescription: "flagDev is false, nothing will be configured", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ServerCommand: &ServerCommand{ | ||||||
|  | 				flagDevTLS:        true, | ||||||
|  | 				flagDevTLSCertDir: "", | ||||||
|  | 			}, | ||||||
|  | 			DeferFuncNotNil: true, | ||||||
|  | 			ConfigNotNil:    true, | ||||||
|  | 			ErrNotNil:       false, | ||||||
|  | 			TestDescription: "flagDevTLSCertDir is empty", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			ServerCommand: &ServerCommand{ | ||||||
|  | 				flagDevTLS:        true, | ||||||
|  | 				flagDevTLSCertDir: "@/#", | ||||||
|  | 			}, | ||||||
|  | 			CertPathEmpty:   true, | ||||||
|  | 			ErrNotNil:       true, | ||||||
|  | 			TestDescription: "flagDevTLSCertDir is set to something invalid", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, testcase := range testcases { | ||||||
|  | 		fun, cfg, certPath, err := configureDevTLS(testcase.ServerCommand) | ||||||
|  | 		if fun != nil { | ||||||
|  | 			// If a function is returned, call it right away to clean up | ||||||
|  | 			// files created in the temporary directory before anything else has | ||||||
|  | 			// a chance to fail this test. | ||||||
|  | 			fun() | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		require.Equal(t, testcase.DeferFuncNotNil, (fun != nil), "test description %s", testcase.TestDescription) | ||||||
|  | 		require.Equal(t, testcase.ConfigNotNil, cfg != nil, "test description %s", testcase.TestDescription) | ||||||
|  | 		if testcase.ConfigNotNil { | ||||||
|  | 			require.True(t, len(cfg.Listeners) > 0, "test description %s", testcase.TestDescription) | ||||||
|  | 			require.Equal(t, testcase.TLSDisable, cfg.Listeners[0].TLSDisable, "test description %s", testcase.TestDescription) | ||||||
|  | 		} | ||||||
|  | 		require.Equal(t, testcase.CertPathEmpty, len(certPath) == 0, "test description %s", testcase.TestDescription) | ||||||
|  | 		require.Equal(t, testcase.ErrNotNil, (err != nil), "test description %s", testcase.TestDescription) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Marc Boudreau
					Marc Boudreau