mirror of
				https://github.com/optim-enterprises-bv/vault.git
				synced 2025-11-03 12:07:54 +00:00 
			
		
		
		
	License /shamir under MPL. Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
		
			
				
	
	
		
			202 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			202 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
// Copyright (c) HashiCorp, Inc.
 | 
						|
// SPDX-License-Identifier: MPL-2.0
 | 
						|
 | 
						|
package shamir
 | 
						|
 | 
						|
import (
 | 
						|
	"bytes"
 | 
						|
	"testing"
 | 
						|
)
 | 
						|
 | 
						|
func TestSplit_invalid(t *testing.T) {
 | 
						|
	secret := []byte("test")
 | 
						|
 | 
						|
	if _, err := Split(secret, 0, 0); err == nil {
 | 
						|
		t.Fatalf("expect error")
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := Split(secret, 2, 3); err == nil {
 | 
						|
		t.Fatalf("expect error")
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := Split(secret, 1000, 3); err == nil {
 | 
						|
		t.Fatalf("expect error")
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := Split(secret, 10, 1); err == nil {
 | 
						|
		t.Fatalf("expect error")
 | 
						|
	}
 | 
						|
 | 
						|
	if _, err := Split(nil, 3, 2); err == nil {
 | 
						|
		t.Fatalf("expect error")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestSplit(t *testing.T) {
 | 
						|
	secret := []byte("test")
 | 
						|
 | 
						|
	out, err := Split(secret, 5, 3)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if len(out) != 5 {
 | 
						|
		t.Fatalf("bad: %v", out)
 | 
						|
	}
 | 
						|
 | 
						|
	for _, share := range out {
 | 
						|
		if len(share) != len(secret)+1 {
 | 
						|
			t.Fatalf("bad: %v", out)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCombine_invalid(t *testing.T) {
 | 
						|
	// Not enough parts
 | 
						|
	if _, err := Combine(nil); err == nil {
 | 
						|
		t.Fatalf("should err")
 | 
						|
	}
 | 
						|
 | 
						|
	// Mis-match in length
 | 
						|
	parts := [][]byte{
 | 
						|
		[]byte("foo"),
 | 
						|
		[]byte("ba"),
 | 
						|
	}
 | 
						|
	if _, err := Combine(parts); err == nil {
 | 
						|
		t.Fatalf("should err")
 | 
						|
	}
 | 
						|
 | 
						|
	// Too short
 | 
						|
	parts = [][]byte{
 | 
						|
		[]byte("f"),
 | 
						|
		[]byte("b"),
 | 
						|
	}
 | 
						|
	if _, err := Combine(parts); err == nil {
 | 
						|
		t.Fatalf("should err")
 | 
						|
	}
 | 
						|
 | 
						|
	parts = [][]byte{
 | 
						|
		[]byte("foo"),
 | 
						|
		[]byte("foo"),
 | 
						|
	}
 | 
						|
	if _, err := Combine(parts); err == nil {
 | 
						|
		t.Fatalf("should err")
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestCombine(t *testing.T) {
 | 
						|
	secret := []byte("test")
 | 
						|
 | 
						|
	out, err := Split(secret, 5, 3)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	// There is 5*4*3 possible choices,
 | 
						|
	// we will just brute force try them all
 | 
						|
	for i := 0; i < 5; i++ {
 | 
						|
		for j := 0; j < 5; j++ {
 | 
						|
			if j == i {
 | 
						|
				continue
 | 
						|
			}
 | 
						|
			for k := 0; k < 5; k++ {
 | 
						|
				if k == i || k == j {
 | 
						|
					continue
 | 
						|
				}
 | 
						|
				parts := [][]byte{out[i], out[j], out[k]}
 | 
						|
				recomb, err := Combine(parts)
 | 
						|
				if err != nil {
 | 
						|
					t.Fatalf("err: %v", err)
 | 
						|
				}
 | 
						|
 | 
						|
				if !bytes.Equal(recomb, secret) {
 | 
						|
					t.Errorf("parts: (i:%d, j:%d, k:%d) %v", i, j, k, parts)
 | 
						|
					t.Fatalf("bad: %v %v", recomb, secret)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestField_Add(t *testing.T) {
 | 
						|
	if out := add(16, 16); out != 0 {
 | 
						|
		t.Fatalf("Bad: %v 16", out)
 | 
						|
	}
 | 
						|
 | 
						|
	if out := add(3, 4); out != 7 {
 | 
						|
		t.Fatalf("Bad: %v 7", out)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestField_Mult(t *testing.T) {
 | 
						|
	if out := mult(3, 7); out != 9 {
 | 
						|
		t.Fatalf("Bad: %v 9", out)
 | 
						|
	}
 | 
						|
 | 
						|
	if out := mult(3, 0); out != 0 {
 | 
						|
		t.Fatalf("Bad: %v 0", out)
 | 
						|
	}
 | 
						|
 | 
						|
	if out := mult(0, 3); out != 0 {
 | 
						|
		t.Fatalf("Bad: %v 0", out)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestField_Divide(t *testing.T) {
 | 
						|
	if out := div(0, 7); out != 0 {
 | 
						|
		t.Fatalf("Bad: %v 0", out)
 | 
						|
	}
 | 
						|
 | 
						|
	if out := div(3, 3); out != 1 {
 | 
						|
		t.Fatalf("Bad: %v 1", out)
 | 
						|
	}
 | 
						|
 | 
						|
	if out := div(6, 3); out != 2 {
 | 
						|
		t.Fatalf("Bad: %v 2", out)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPolynomial_Random(t *testing.T) {
 | 
						|
	p, err := makePolynomial(42, 2)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if p.coefficients[0] != 42 {
 | 
						|
		t.Fatalf("bad: %v", p.coefficients)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestPolynomial_Eval(t *testing.T) {
 | 
						|
	p, err := makePolynomial(42, 1)
 | 
						|
	if err != nil {
 | 
						|
		t.Fatalf("err: %v", err)
 | 
						|
	}
 | 
						|
 | 
						|
	if out := p.evaluate(0); out != 42 {
 | 
						|
		t.Fatalf("bad: %v", out)
 | 
						|
	}
 | 
						|
 | 
						|
	out := p.evaluate(1)
 | 
						|
	exp := add(42, mult(1, p.coefficients[1]))
 | 
						|
	if out != exp {
 | 
						|
		t.Fatalf("bad: %v %v %v", out, exp, p.coefficients)
 | 
						|
	}
 | 
						|
}
 | 
						|
 | 
						|
func TestInterpolate_Rand(t *testing.T) {
 | 
						|
	for i := 0; i < 256; i++ {
 | 
						|
		p, err := makePolynomial(uint8(i), 2)
 | 
						|
		if err != nil {
 | 
						|
			t.Fatalf("err: %v", err)
 | 
						|
		}
 | 
						|
 | 
						|
		x_vals := []uint8{1, 2, 3}
 | 
						|
		y_vals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
 | 
						|
		out := interpolatePolynomial(x_vals, y_vals, 0)
 | 
						|
		if out != uint8(i) {
 | 
						|
			t.Fatalf("Bad: %v %d", out, i)
 | 
						|
		}
 | 
						|
	}
 | 
						|
}
 |