Use require and assert in a few more Nebula test functions

This commit is contained in:
Herman Slatman
2024-08-20 23:09:11 +02:00
parent 74d30d975a
commit 1b09b1143e

View File

@@ -8,7 +8,6 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"fmt"
"net"
"net/url"
"reflect"
@@ -16,8 +15,8 @@ import (
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
@@ -30,9 +29,7 @@ import (
func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
t.Helper()
ip, ipNet, err := net.ParseCIDR(s)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
if ip = ip.To4(); ip == nil {
t.Fatalf("nebula only supports ipv4, have %s", s)
}
@@ -43,9 +40,7 @@ func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
t.Helper()
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
nc := &cert.NebulaCertificate{
Details: cert.NebulaCertificateDetails{
Name: "TestCA",
@@ -61,9 +56,9 @@ func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
Curve: cert.Curve_CURVE25519,
},
}
if err := nc.Sign(cert.Curve_CURVE25519, priv); err != nil {
t.Fatal(err)
}
err = nc.Sign(cert.Curve_CURVE25519, priv)
require.NoError(t, err)
return nc, priv
}
@@ -99,14 +94,10 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
t.Helper()
pub, priv, err := x25519.GenerateKey(rand.Reader)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
issuer, err := ca.Sha256Sum()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
invertedGroups := make(map[string]struct{}, len(groups))
for _, name := range groups {
@@ -130,9 +121,8 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
},
}
if err := nc.Sign(cert.Curve_CURVE25519, signer); err != nil {
t.Fatal(err)
}
err = nc.Sign(cert.Curve_CURVE25519, signer)
require.NoError(t, err)
return nc, priv
}
@@ -184,9 +174,7 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
nc, signer := mustNebulaCA(t)
ncPem, err := nc.MarshalToPEM()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
bTrue := true
p := &Nebula{
Type: TypeNebula.String(),
@@ -196,12 +184,11 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
EnableSSHCA: &bTrue,
},
}
if err := p.Init(Config{
err = p.Init(Config{
Claims: globalProvisionerClaims,
Audiences: testAudiences,
}); err != nil {
t.Fatal(err)
}
})
require.NoError(t, err)
return p, nc, signer
}
@@ -310,9 +297,7 @@ func mustNebulaSSHToken(t *testing.T, sub, iss, aud string, iat time.Time, opts
func TestNebula_Init(t *testing.T) {
nc, _ := mustNebulaCA(t)
ncPem, err := nc.MarshalToPEM()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
cfg := Config{
Claims: globalProvisionerClaims,
@@ -416,9 +401,7 @@ func TestNebula_GetTokenID(t *testing.T) {
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv, jose.XEdDSA)
_, claims, err := parseToken(t1)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
type args struct {
token string
@@ -838,13 +821,9 @@ func TestNebula_authorizeToken(t *testing.T) {
// Not a nebula token
jwk, err := generateJSONWebKey()
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
simpleToken, err := generateSimpleToken("iss", "aud", jwk)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// Provisioner with a different CA
p2, _, _ := mustNebulaProvisioner(t)
@@ -911,22 +890,20 @@ func TestNebula_authorizeToken(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences)
if (err != nil) != tt.wantErr {
t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
assert.Nil(t, got1)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want)
t.Error(cmp.Equal(got, tt.want))
}
if got1 != nil && tt.want1 != nil {
tt.want1.ID = got1.ID
}
if !reflect.DeepEqual(got1, tt.want1) {
t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1)
}
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.want1, got1)
})
}
}
@@ -1021,23 +998,20 @@ func TestNebula_authorizeToken_P256(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences)
if (err != nil) != tt.wantErr {
fmt.Println(err)
t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
assert.Nil(t, got1)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want)
t.Error(cmp.Equal(got, tt.want))
}
if got1 != nil && tt.want1 != nil {
tt.want1.ID = got1.ID
}
if !reflect.DeepEqual(got1, tt.want1) {
t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1)
}
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.want1, got1)
})
}
}