mirror of
https://github.com/outbackdingo/certificates.git
synced 2026-01-27 10:18:34 +00:00
Use require and assert in a few more Nebula test functions
This commit is contained in:
@@ -8,7 +8,6 @@ import (
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"reflect"
|
||||
@@ -16,8 +15,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
@@ -30,9 +29,7 @@ import (
|
||||
func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
|
||||
t.Helper()
|
||||
ip, ipNet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
if ip = ip.To4(); ip == nil {
|
||||
t.Fatalf("nebula only supports ipv4, have %s", s)
|
||||
}
|
||||
@@ -43,9 +40,7 @@ func mustNebulaIPNet(t *testing.T, s string) *net.IPNet {
|
||||
func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
|
||||
t.Helper()
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
nc := &cert.NebulaCertificate{
|
||||
Details: cert.NebulaCertificateDetails{
|
||||
Name: "TestCA",
|
||||
@@ -61,9 +56,9 @@ func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) {
|
||||
Curve: cert.Curve_CURVE25519,
|
||||
},
|
||||
}
|
||||
if err := nc.Sign(cert.Curve_CURVE25519, priv); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = nc.Sign(cert.Curve_CURVE25519, priv)
|
||||
require.NoError(t, err)
|
||||
|
||||
return nc, priv
|
||||
}
|
||||
|
||||
@@ -99,14 +94,10 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
|
||||
t.Helper()
|
||||
|
||||
pub, priv, err := x25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
issuer, err := ca.Sha256Sum()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
invertedGroups := make(map[string]struct{}, len(groups))
|
||||
for _, name := range groups {
|
||||
@@ -130,9 +121,8 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string
|
||||
},
|
||||
}
|
||||
|
||||
if err := nc.Sign(cert.Curve_CURVE25519, signer); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = nc.Sign(cert.Curve_CURVE25519, signer)
|
||||
require.NoError(t, err)
|
||||
|
||||
return nc, priv
|
||||
}
|
||||
@@ -184,9 +174,7 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
|
||||
|
||||
nc, signer := mustNebulaCA(t)
|
||||
ncPem, err := nc.MarshalToPEM()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
bTrue := true
|
||||
p := &Nebula{
|
||||
Type: TypeNebula.String(),
|
||||
@@ -196,12 +184,11 @@ func mustNebulaProvisioner(t *testing.T) (*Nebula, *cert.NebulaCertificate, ed25
|
||||
EnableSSHCA: &bTrue,
|
||||
},
|
||||
}
|
||||
if err := p.Init(Config{
|
||||
err = p.Init(Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
Audiences: testAudiences,
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return p, nc, signer
|
||||
}
|
||||
@@ -310,9 +297,7 @@ func mustNebulaSSHToken(t *testing.T, sub, iss, aud string, iat time.Time, opts
|
||||
func TestNebula_Init(t *testing.T) {
|
||||
nc, _ := mustNebulaCA(t)
|
||||
ncPem, err := nc.MarshalToPEM()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := Config{
|
||||
Claims: globalProvisionerClaims,
|
||||
@@ -416,9 +401,7 @@ func TestNebula_GetTokenID(t *testing.T) {
|
||||
c1, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"group"}, ca, signer)
|
||||
t1 := mustNebulaToken(t, "test.lan", p.Name, p.ctl.Audiences.Sign[0], now(), []string{"test.lan", "10.1.0.1"}, c1, priv, jose.XEdDSA)
|
||||
_, claims, err := parseToken(t1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
token string
|
||||
@@ -838,13 +821,9 @@ func TestNebula_authorizeToken(t *testing.T) {
|
||||
|
||||
// Not a nebula token
|
||||
jwk, err := generateJSONWebKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
simpleToken, err := generateSimpleToken("iss", "aud", jwk)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Provisioner with a different CA
|
||||
p2, _, _ := mustNebulaProvisioner(t)
|
||||
@@ -911,22 +890,20 @@ func TestNebula_authorizeToken(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
assert.Nil(t, got1)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want)
|
||||
t.Error(cmp.Equal(got, tt.want))
|
||||
}
|
||||
|
||||
if got1 != nil && tt.want1 != nil {
|
||||
tt.want1.ID = got1.ID
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, tt.want1) {
|
||||
t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.want1, got1)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1021,23 +998,20 @@ func TestNebula_authorizeToken_P256(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences)
|
||||
if (err != nil) != tt.wantErr {
|
||||
fmt.Println(err)
|
||||
t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, got)
|
||||
assert.Nil(t, got1)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want)
|
||||
t.Error(cmp.Equal(got, tt.want))
|
||||
}
|
||||
|
||||
if got1 != nil && tt.want1 != nil {
|
||||
tt.want1.ID = got1.ID
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(got1, tt.want1) {
|
||||
t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.want1, got1)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user