diff --git a/authority/provisioner/nebula_test.go b/authority/provisioner/nebula_test.go index 84bf2926..a5925b7d 100644 --- a/authority/provisioner/nebula_test.go +++ b/authority/provisioner/nebula_test.go @@ -7,11 +7,13 @@ import ( "crypto/rand" "crypto/x509" "net" + "net/url" "reflect" "strings" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/slackhq/nebula/cert" "go.step.sm/crypto/jose" "go.step.sm/crypto/randutil" @@ -26,7 +28,7 @@ func mustNebulaIPNet(t *testing.T, s string) *net.IPNet { if err != nil { t.Fatal(err) } - if ip.To4() == nil { + if ip = ip.To4(); ip == nil { t.Fatalf("nebula only supports ipv4, have %s", s) } ipNet.IP = ip @@ -46,7 +48,7 @@ func mustNebulaCA(t *testing.T) (*cert.NebulaCertificate, ed25519.PrivateKey) { Ips: []*net.IPNet{ mustNebulaIPNet(t, "10.1.0.0/16"), }, - Subnets: nil, + Subnets: []*net.IPNet{}, NotBefore: time.Now(), NotAfter: time.Now().Add(10 * time.Minute), PublicKey: pub, @@ -72,16 +74,24 @@ func mustNebulaCert(t *testing.T, name string, ipNet *net.IPNet, groups []string t.Fatal(err) } + invertedGroups := make(map[string]struct{}, len(groups)) + for _, name := range groups { + invertedGroups[name] = struct{}{} + } + + t1 := time.Now().Truncate(time.Second) nc := &cert.NebulaCertificate{ Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{ipNet}, - Groups: groups, - NotBefore: time.Now(), - NotAfter: time.Now().Add(5 * time.Minute), - PublicKey: pub, - IsCA: false, - Issuer: issuer, + Name: name, + Ips: []*net.IPNet{ipNet}, + Subnets: []*net.IPNet{}, + Groups: groups, + NotBefore: t1, + NotAfter: t1.Add(5 * time.Minute), + PublicKey: pub, + IsCA: false, + Issuer: issuer, + InvertedGroups: invertedGroups, }, } @@ -244,6 +254,10 @@ func TestNebula_Init(t *testing.T) { {"fail type", fields{"", "Nebulous", ncPem, nil, nil}, args{cfg}, true}, {"fail name", fields{"Nebula", "", ncPem, nil, nil}, args{cfg}, true}, {"fail root", fields{"Nebula", "Nebulous", nil, nil, nil}, args{cfg}, true}, + {"fail bad root", fields{"Nebula", "Nebulous", ncPem[:16], nil, nil}, args{cfg}, true}, + {"fail bad claims", fields{"Nebula", "Nebulous", ncPem, &Claims{ + MinTLSDur: &Duration{Duration: 0}, + }, nil}, args{cfg}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -707,3 +721,235 @@ func TestNebula_AuthorizeSSHRekey(t *testing.T) { }) } } + +func TestNebula_authorizeToken(t *testing.T) { + t1 := now() + p, ca, signer := mustNebulaProvisioner(t) + crt, priv := mustNebulaCert(t, "test.lan", mustNebulaIPNet(t, "10.1.0.1/16"), []string{"test"}, ca, signer) + ok := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + okNoSANs := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1, nil, crt, priv) + okSSH := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, &SignSSHOptions{ + CertType: "host", + KeyID: "test.lan", + Principals: []string{"test.lan"}, + }, crt, priv) + okSSHNoOptions := mustNebulaSSHToken(t, "test.lan", p.Name, p.audiences.SSHSign[0], t1, nil, crt, priv) + + // Token with errors + failNotBefore := mustNebulaToken(t, "test.lan", p.Name, p.audiences.Sign[0], t1.Add(1*time.Hour), []string{"10.1.0.1"}, crt, priv) + failIssuer := mustNebulaToken(t, "test.lan", "foo", p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + failAudience := mustNebulaToken(t, "test.lan", p.Name, "foo", t1, []string{"10.1.0.1"}, crt, priv) + failSubject := mustNebulaToken(t, "", p.Name, p.audiences.Sign[0], t1, []string{"10.1.0.1"}, crt, priv) + + // Not a nebula token + jwk, err := generateJSONWebKey() + if err != nil { + t.Fatal(err) + } + simpleToken, err := generateSimpleToken("iss", "aud", jwk) + if err != nil { + t.Fatal(err) + } + + // Provisioner with a different CA + p2, _, _ := mustNebulaProvisioner(t) + + x509Claims := jose.Claims{ + ID: "[REPLACEME]", + Subject: "test.lan", + Issuer: p.Name, + IssuedAt: jose.NewNumericDate(t1), + NotBefore: jose.NewNumericDate(t1), + Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), + Audience: []string{p.audiences.Sign[0]}, + } + sshClaims := jose.Claims{ + ID: "[REPLACEME]", + Subject: "test.lan", + Issuer: p.Name, + IssuedAt: jose.NewNumericDate(t1), + NotBefore: jose.NewNumericDate(t1), + Expiry: jose.NewNumericDate(t1.Add(5 * time.Minute)), + Audience: []string{p.audiences.SSHSign[0]}, + } + + type args struct { + token string + audiences []string + } + tests := []struct { + name string + p *Nebula + args args + want *cert.NebulaCertificate + want1 *jwtPayload + wantErr bool + }{ + {"ok x509", p, args{ok, p.audiences.Sign}, crt, &jwtPayload{ + Claims: x509Claims, + SANs: []string{"10.1.0.1"}, + }, false}, + {"ok x509 no sans", p, args{okNoSANs, p.audiences.Sign}, crt, &jwtPayload{ + Claims: x509Claims, + }, false}, + {"ok ssh", p, args{okSSH, p.audiences.SSHSign}, crt, &jwtPayload{ + Claims: sshClaims, + Step: &stepPayload{ + SSH: &SignSSHOptions{ + CertType: "host", + KeyID: "test.lan", + Principals: []string{"test.lan"}, + }, + }, + }, false}, + {"ok ssh no principals", p, args{okSSHNoOptions, p.audiences.SSHSign}, crt, &jwtPayload{ + Claims: sshClaims, + }, false}, + {"fail parse", p, args{"bad.token", p.audiences.Sign}, nil, nil, true}, + {"fail header", p, args{simpleToken, p.audiences.Sign}, nil, nil, true}, + {"fail verify", p2, args{ok, p.audiences.Sign}, nil, nil, true}, + {"fail claims nbf", p, args{failNotBefore, p.audiences.Sign}, nil, nil, true}, + {"fail claims iss", p, args{failIssuer, p.audiences.Sign}, nil, nil, true}, + {"fail claims aud", p, args{failAudience, p.audiences.Sign}, nil, nil, true}, + {"fail claims sub", p, args{failSubject, p.audiences.Sign}, nil, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1, err := tt.p.authorizeToken(tt.args.token, tt.args.audiences) + if (err != nil) != tt.wantErr { + t.Errorf("Nebula.authorizeToken() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Nebula.authorizeToken() got = %#v, want %#v", got, tt.want) + t.Error(cmp.Equal(got, tt.want)) + } + + if got1 != nil && tt.want1 != nil { + tt.want1.ID = got1.ID + } + + if !reflect.DeepEqual(got1, tt.want1) { + t.Errorf("Nebula.authorizeToken() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func Test_nebulaSANsValidator_Valid(t *testing.T) { + ipNet := mustNebulaIPNet(t, "10.1.2.3/16") + type fields struct { + Name string + IPs []*net.IPNet + } + type args struct { + req *x509.CertificateRequest + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + DNSNames: []string{"dns.name"}, + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, + }}, false}, + {"ok name only", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + DNSNames: []string{"dns.name"}, + }}, false}, + {"ok ip only", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, + }}, false}, + {"ok email name", fields{"jane@doe.org", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + EmailAddresses: []string{"jane@doe.org"}, + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, + }}, false}, + {"ok uri name", fields{"urn:foobar", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + URIs: []*url.URL{{Scheme: "urn", Opaque: "foobar"}}, + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, + }}, false}, + {"ok ip name", fields{"127.0.0.1", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv4(10, 1, 2, 3)}, + }}, false}, + {"ok multiple ips", fields{"dns.name", []*net.IPNet{ipNet, mustNebulaIPNet(t, "10.2.2.3/8")}}, args{&x509.CertificateRequest{ + DNSNames: []string{"dns.name"}, + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3), net.IPv4(10, 2, 2, 3)}, + }}, false}, + {"fail dns", fields{"fail.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + DNSNames: []string{"dns.name"}, + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, + }}, true}, + {"fail email", fields{"fail@doe.org", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + EmailAddresses: []string{"jane@doe.org"}, + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, + }}, true}, + {"fail uri", fields{"urn:barfoo", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + URIs: []*url.URL{{Scheme: "urn", Opaque: "foobar"}}, + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 3)}, + }}, true}, + {"fail ip", fields{"127.0.0.1", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + IPAddresses: []net.IP{net.IPv4(10, 1, 2, 1), net.IPv4(10, 1, 2, 3)}, + }}, true}, + {"fail nebula ip", fields{"dns.name", []*net.IPNet{ipNet}}, args{&x509.CertificateRequest{ + DNSNames: []string{"dns.name"}, + IPAddresses: []net.IP{net.IPv4(10, 2, 2, 3)}, + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := nebulaSANsValidator{ + Name: tt.fields.Name, + IPs: tt.fields.IPs, + } + if err := v.Valid(tt.args.req); (err != nil) != tt.wantErr { + t.Errorf("nebulaSANsValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_nebulaPrincipalsValidator_Valid(t *testing.T) { + ipNet := mustNebulaIPNet(t, "10.1.2.3/16") + + type fields struct { + Name string + IPs []*net.IPNet + } + type args struct { + got SignSSHOptions + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + {"ok", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{ + Principals: []string{"dns.name", "10.1.2.3"}, + }}, false}, + {"ok name", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{ + Principals: []string{"dns.name"}, + }}, false}, + {"ok ip", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{ + Principals: []string{"10.1.2.3"}, + }}, false}, + {"fail name", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{ + Principals: []string{"foo.name", "10.1.2.3"}, + }}, true}, + {"fail ip", fields{"dns.name", []*net.IPNet{ipNet}}, args{SignSSHOptions{ + Principals: []string{"dns.name", "10.2.2.3"}, + }}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := nebulaPrincipalsValidator{ + Name: tt.fields.Name, + IPs: tt.fields.IPs, + } + if err := v.Valid(tt.args.got); (err != nil) != tt.wantErr { + t.Errorf("nebulaPrincipalsValidator.Valid() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}