diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index 4288cab7..5ab7c2a0 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -8,7 +8,6 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/x509" - "fmt" "net" "net/url" "reflect" @@ -16,8 +15,8 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" "github.com/slackhq/nebula/cert" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -30,9 +29,7 @@ import ( func mustNebulaIPNet(t *testing.T, s string) *net.IPNet { t.Helper() ip, ipNet, err := net.ParseCIDR(s) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if ip = ip.To4(); ip == nil { t.Fatalf("nebula only supports ipv4, have %s", s) } @@ -43,9 +40,7 @@ func mustNebulaIPNet(t *testing.T, s string) *net.IPNet { func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) { t.Helper() pub, priv, err := ed25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) nc := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ Name: "TestCA", @@ -61,9 +56,9 @@ func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) { Curve: cert.Curve_CURVE25519, }, } - if err := nc.Sign(cert.Curve_CURVE25519, priv); err != nil { - t.Fatal(err) - } + err = nc.Sign(cert.Curve_CURVE25519, priv) + require.NoError(t, err) + return nc, priv } @@ -99,14 +94,10 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string t.Helper() pub, priv, err := x25519.GenerateKey(rand.Reader) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) issuer, err := ca.Sha256Sum() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) invertedGroups := make(map[string]struct{}, len(groups)) for _, name := range groups { @@ -130,9 +121,8 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string }, } - if err := nc.Sign(cert.Curve_CURVE25519, signer); err != nil { - t.Fatal(err) - } + err = nc.Sign(cert.Curve_CURVE25519, signer) + require.NoError(t, err) return nc, priv } @@ -184,9 +174,7 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25 nc, signer := mustNebulaCA(t) ncPem, err := nc.MarshalToPEM() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) bTrue := true p := &Nebula{ Type: TypeNebula.String(), @@ -196,12 +184,11 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25 EnableSSHCA: &bTrue, }, } - if err := p.Init(Config{ + err = p.Init(Config{ Claims: globalProvisionerClaims, Audiences: testAudiences, - }); err != nil { - t.Fatal(err) - } + }) + require.NoError(t, err) return p, nc, signer } @@ -310,9 +297,7 @@ func mustNebulaSSHToken(t *testing.T, sub, iss, aud string, iat time.Time, opts func TestNebula_Init(t *testing.T) { nc, _ := mustNebulaCA(t) ncPem, err := nc.MarshalToPEM() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) cfg := Config{ Claims: globalProvisionerClaims, @@ -416,9 +401,7 @@ func TestNebula_GetTokenID(t *testing.T) { c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer) t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv, jose.XEdDSA) _, claims, err := parseToken(t1) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) type args struct { token string @@ -838,13 +821,9 @@ func TestNebula_authorizeToken(t *testing.T) { // Not a nebula token jwk, err := generateJSONWebKey() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) simpleToken, err := generateSimpleToken("iss", "aud", jwk) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // Provisioner with a different CA p2, _, _ := mustNebulaProvisioner(t) @@ -911,22 +890,20 @@ func TestNebula_authorizeToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences) - if (err != nil) != tt.wantErr { - t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + assert.Nil(t, got1) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want) - t.Error(cmp.Equal(got, tt.want)) - } if got1 != nil && tt.want1 != nil { tt.want1.ID = got1.ID } - if !reflect.DeepEqual(got1, tt.want1) { - t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1) - } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want1, got1) }) } } @@ -1021,23 +998,20 @@ func TestNebula_authorizeToken_P256(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences) - if (err != nil) != tt.wantErr { - fmt.Println(err) - t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + assert.Nil(t, got1) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want) - t.Error(cmp.Equal(got, tt.want)) - } if got1 != nil && tt.want1 != nil { tt.want1.ID = got1.ID } - if !reflect.DeepEqual(got1, tt.want1) { - t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1) - } + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.want1, got1) }) } }