Compare commits

...

1 Commits

Author SHA1 Message Date
zhangxiaoqiang
f827b71c57 commit pki2.0 2025-09-12 14:27:27 +08:00

View File

@@ -7,7 +7,13 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"crypto/rand"
"crypto/rsa"
"crypto/ecdsa"
"encoding/json"
"encoding/pem"
"encoding/base64"
"flag"
"fmt"
"io/ioutil"
@@ -22,10 +28,12 @@ import (
"syscall"
"time"
"os/exec"
"bytes"
"github.com/go-co-op/gocron"
"github.com/gorilla/websocket"
"golang.org/x/sys/unix"
"github.com/fullsailor/pkcs7"
)
type WsConn struct {
@@ -55,6 +63,7 @@ const (
var (
cloudDiscoveryHost = "https://discovery.open-lan.org/v1/devices/"
estHost = "est.certificates.open-lan.org"
)
var (
@@ -112,6 +121,7 @@ var (
PublicIpLookup = "ifconfig.me"
VlanStatsLast = map[string]InterfaceCounter{}
PortStatsLast = map[string]OLSInterfaceCounter{}
estServerList = []string{}
)
func sendMessageToController() error {
@@ -567,13 +577,378 @@ func getControllerUrl() string {
return ControllerUrl
}
func firstContact() bool {
ControllerAddr = getControllerUrl()
if ControllerAddr == "" {
logger.Error("Could not get ControllerAddr")
func getEstServer(domain string) []string {
CAA := []string{}
// Execute the dig command to retrieve CAA records
cmd := exec.Command("dig", "+short", "caa", domain)
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err != nil {
logger.Error("command dig failed, err is %s", err.Error())
return CAA
}
output := strings.TrimSpace(out.String())
if output == "" {
logger.Error("%s no CAA record", domain)
return CAA
}
// Split each row and extract the third field
lines := strings.Split(output, "\n")
var thirdFields []string
for _, line := range lines {
fields := strings.Fields(line)
if len(fields) >= 3 {
thirdField := fields[2]
thirdField = strings.Trim(thirdField, `"`)
thirdFields = append(thirdFields, thirdField)
}
}
for _, field := range thirdFields {
CAA = append(CAA, field)
}
logger.Info("CAA: %v", CAA)
return CAA
}
// Obtain signature algorithm based on private key type
func getSignatureAlgorithm(privateKey interface{}) x509.SignatureAlgorithm {
switch privateKey.(type) {
case *rsa.PrivateKey:
return x509.SHA256WithRSA
case *ecdsa.PrivateKey:
return x509.ECDSAWithSHA256
default:
return x509.UnknownSignatureAlgorithm
}
}
func getOperationalCA(estServer string) bool {
client, err := tlsclient(true)
if err != nil {
logger.Error("tls client created failed, err is %s", err.Error())
return false
}
caUrl := "https://" + estServer + "/cacerts"
resp, err := client.Get(caUrl)
if err != nil {
logger.Error("request failed, err is %s", err.Error())
return false
}
defer resp.Body.Close()
logger.Info("resp.StatusCode %d", resp.StatusCode)
calist, err := ioutil.ReadAll(resp.Body)
if err != nil {
logger.Error("read response failed, err is %s", err.Error())
return false
}
decoded, err := base64.StdEncoding.DecodeString(string(calist))
if err != nil {
logger.Info("Decode String failed, err is %s", err.Error())
return false
}
certPEM := decoded
// parse PKCS#7 data
p7, err := pkcs7.Parse(certPEM)
if err != nil {
logger.Info("parse PKCS7 failed, err is %s", err.Error())
return false
}
var certs []*x509.Certificate
if p7.Certificates != nil {
certs = p7.Certificates
}
logger.Info("Converted P7 to PEM")
if len(certs) == 0 {
logger.Info("cannot find operational.ca from response")
return false
}
// save to operational.ca
file, err := os.Create(operationalCAPath)
if err != nil {
logger.Error("cannot create file operational.ca, err is %s", err.Error())
return false
}
defer file.Close()
for _, cert := range certs {
err = pem.Encode(file, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})
if err != nil {
logger.Error("cannot write CA to file operational.ca, err is %s", err.Error())
return false
}
}
logger.Info("Persistently stored operational.ca")
return true
}
func getOperationalCert(estServer string, reenroll bool) bool {
result := false
// Read existing private key file (PEM format)
logger.Info("start to get operational.pem")
privateKeyPEM, err := os.ReadFile(keyPath)
if err != nil {
logger.Error("read %s failed, err is %s", keyPath, err.Error())
return result
}
block, _ := pem.Decode(privateKeyPEM)
if block == nil {
logger.Error("Invalid PEM format private key")
return result
}
// parse private key
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
logger.Error("parse private key failed, err is %s", err.Error())
return result
}
pemData := []byte{}
if !reenroll {
pemData, err = ioutil.ReadFile(certPath)
if err != nil {
logger.Error("read %s failed, err is %s", certPath, err.Error())
return result
}
} else {
pemData, err = ioutil.ReadFile(operationalPath)
if err != nil {
logger.Error("read %s failed, err is %s", operationalPath, err.Error())
return result
}
}
blockCert, _ := pem.Decode(pemData)
if blockCert == nil || blockCert.Type != "CERTIFICATE" {
logger.Error("Invalid PEM format cert key")
return result
}
cert, err := x509.ParseCertificate(blockCert.Bytes)
if err != nil {
logger.Error("parse cert key failed, err is %s", err.Error())
return result
}
subjectCert := cert.Subject
subject := pkix.Name{
CommonName: subjectCert.CommonName,
Organization: subjectCert.Organization,
}
// create CSR template
csrTemplate := x509.CertificateRequest{
Subject: subject,
SignatureAlgorithm: getSignatureAlgorithm(privateKey),
}
// create CSR
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &csrTemplate, privateKey)
if err != nil {
logger.Error("create CSR failed, err is %s", err.Error())
return result
}
logger.Info("Generated CSR")
encoded := base64.StdEncoding.EncodeToString(csrBytes)
reader := strings.NewReader(encoded)
caURL := "https://" + estServer + "/simpleenroll"
if reenroll {
caURL = "https://" + estServer + "/simplereenroll"
}
req, _ := http.NewRequest("POST", caURL, reader)
req.Header.Set("Content-Type", "application/pkcs10-base64")
req.Header.Set("Accept", "application/pkcs7")
client, err := tlsclient(true)
if err != nil {
logger.Error("tls client created failed, err is %s", err.Error())
return result
}
resp, err := client.Do(req)
if err != nil {
logger.Error("request failed, err is %s", err.Error())
return result
}
logger.Info("EST succeeded")
defer resp.Body.Close()
certPEM, err := ioutil.ReadAll(resp.Body)
if err != nil {
logger.Error("read response failed, err is %s", err.Error())
return result
}
decoded, err := base64.StdEncoding.DecodeString(string(certPEM))
if err != nil {
logger.Error("Decode String failed, err is %s", err.Error())
return result
}
certPEM = decoded
// parse PKCS#7 data
p7, err := pkcs7.Parse(certPEM)
if err != nil {
logger.Error("parse PKCS7 failed, err is %s", err.Error())
return result
}
var certs []*x509.Certificate
if p7.Certificates != nil {
certs = p7.Certificates
}
logger.Info("Converted P7 to PEM")
if len(certs) == 0 {
logger.Error("cannot find operational certificate from response")
return result
}
// save to operational.pem
file, err := os.Create(operationalPath)
if err != nil {
logger.Error("cannot create opreational.pem, err is %s", err.Error())
return result
}
defer file.Close()
for _, cert := range certs {
err = pem.Encode(file, &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
})
if err != nil {
logger.Error("write opreational.pem failed, err is %s", err.Error())
return result
}
}
logger.Info("Persistently stored operational.pem")
result = true
return result
}
func operationalExpireCheck() bool {
certPEM, err := ioutil.ReadFile(operationalPath)
if err != nil {
logger.Error("read opreational.pem failed, err is %s", err.Error())
return true
}
block, _ := pem.Decode(certPEM)
if block == nil {
logger.Error("decode opreational.pem failed, err is %s", err.Error())
return true
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
logger.Error("decode opreational.pem failed, err is %s", err.Error())
return true
}
now := time.Now()
notBefore := cert.NotBefore
notAfter := cert.NotAfter
totalDuration := notAfter.Sub(notBefore)
twoThirdsTime := notBefore.Add(totalDuration * 2 / 3)
logger.Info("Certificate issuance time: %v", notBefore)
logger.Info("Certificate expiration date: %v", notAfter)
if now.After(twoThirdsTime) || now.Equal(twoThirdsTime) {
logger.Info("The certificate has reached or exceeded 2/3 of its usage period")
return true
} else {
logger.Info("Certificate status is normal")
}
return false
}
func setQAorProduct() bool {
result := false
pemData, err := ioutil.ReadFile(certPath)
if err != nil {
logger.Error("read %s failed, err is %s", certPath, err.Error())
return result
}
blockCert, _ := pem.Decode(pemData)
if blockCert == nil || blockCert.Type != "CERTIFICATE" {
logger.Error("Invalid PEM format cert key")
return result
}
cert, err := x509.ParseCertificate(blockCert.Bytes)
if err != nil {
logger.Error("parse cert key failed, err is %s", err.Error())
return result
}
issue := cert.Issuer.String()
logger.Info("(Issuer): %s\n", issue)
if strings.Contains(issue, "OpenLAN Demo") {
estHost = "qaest.certificates.open-lan.org:8001"
cloudDiscoveryHost = "https://discovery-qa.open-lan.org/v1/devices/"
}
return true
}
func firstContact() bool {
resEnv := setQAorProduct()
if !resEnv {
return true
}
ControllerAddr = getControllerUrl()
if ControllerAddr == "" {
logger.Error("Could not get ControllerAddr")
estServerList = []string{estHost}
updateOperationalPem()
return true
}
estServerList = getEstServer(ControllerAddr)
estServerList = append(estServerList, estHost)
res := updateOperationalPem()
if !res {
return true
}
// set Cloud controller FQDN to connection instance, set it in redis
serverInfo := map[string]interface{}{
@@ -589,6 +964,53 @@ func firstContact() bool {
return false
}
func updateOperationalPem() bool {
for _, oneServer := range estServerList {
res := getOperationalCA(oneServer)
if res {
break
}
}
hasOperationalPem := false
_, err := ioutil.ReadFile(operationalPath)
if err != nil {
logger.Warn("Reading %s failed, err is %s", operationalPath, err.Error())
for _, oneServer := range estServerList {
res := getOperationalCert(oneServer, false)
if res{
hasOperationalPem = true
break
}
}
} else {
isExpire := operationalExpireCheck()
if isExpire {
for _, oneServer := range estServerList {
res := getOperationalCert(oneServer, true)
if res{
hasOperationalPem = true
break
}
}
}
hasOperationalPem = true
}
return hasOperationalPem
}
func periodUpdateOpertaional() {
ticker := time.NewTicker(60 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
updateOperationalPem()
}
}
}
func main() {
flag.Parse()
log.SetFlags(0)
@@ -648,6 +1070,8 @@ func main() {
break
}
go periodUpdateOpertaional()
// Start the main event loop in a goroutine
go startEventLoop()