mirror of
https://github.com/optim-enterprises-bv/vault.git
synced 2025-10-29 17:52:32 +00:00
Rework stubmaker logic so that if the funcs are found in Go, we don't attempt to write the file. (#21636)
This commit is contained in:
@@ -31,6 +31,8 @@ func main() {
|
||||
Level: hclog.Trace,
|
||||
})
|
||||
|
||||
// Setup git, both so we can determine if we're running on enterprise, and
|
||||
// so we can make sure we don't clobber a non-transient file.
|
||||
repo, err := git.PlainOpenWithOptions(".", &git.PlainOpenOptions{
|
||||
DetectDotGit: true,
|
||||
})
|
||||
@@ -46,6 +48,29 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
// Read the file and figure out if we need to do anything.
|
||||
inputFile := os.Getenv("GOFILE")
|
||||
if !strings.HasSuffix(inputFile, "_oss.go") {
|
||||
fatal(fmt.Errorf("stubmaker should only be invoked from files ending in _oss.go"))
|
||||
}
|
||||
|
||||
baseFilename := strings.TrimSuffix(inputFile, "_oss.go")
|
||||
outputFile := baseFilename + "_ent.go"
|
||||
b, err := os.ReadFile(inputFile)
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
|
||||
inputLines, err := readLines(bytes.NewBuffer(b))
|
||||
funcs := getFuncs(inputLines)
|
||||
if needed, err := isStubNeeded(funcs); err != nil {
|
||||
fatal(err)
|
||||
} else if !needed {
|
||||
return
|
||||
}
|
||||
|
||||
// We'd like to write the file, but first make sure that we're not going
|
||||
// to blow away anyone's work or overwrite a file already in git.
|
||||
head, err := repo.Head()
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
@@ -60,23 +85,27 @@ func main() {
|
||||
fatal(err)
|
||||
}
|
||||
|
||||
inputFile := os.Getenv("GOFILE")
|
||||
if !strings.HasSuffix(inputFile, "_oss.go") {
|
||||
fatal(fmt.Errorf("stubmaker should only be invoked from files ending in _oss.go"))
|
||||
}
|
||||
|
||||
baseFilename := strings.TrimSuffix(inputFile, "_oss.go")
|
||||
target := baseFilename + "_ent.go"
|
||||
|
||||
tracked, err := inGit(wt, st, obj, target)
|
||||
tracked, err := inGit(wt, st, obj, outputFile)
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
if tracked {
|
||||
fatal(fmt.Errorf("output file %s exists in git, not overwriting", target))
|
||||
fatal(fmt.Errorf("output file %s exists in git, not overwriting", outputFile))
|
||||
}
|
||||
|
||||
if err := writeStubIfNeeded(inputFile, target); err != nil {
|
||||
// Now we can finally write the file
|
||||
output, err := os.Create(outputFile + ".tmp")
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
_, err = io.WriteString(output, strings.Join(getOutput(inputLines), "\n")+"\n")
|
||||
if err != nil {
|
||||
// If we don't end up writing to the file, delete it.
|
||||
os.Remove(outputFile + ".tmp")
|
||||
} else {
|
||||
os.Rename(outputFile+".tmp", outputFile)
|
||||
}
|
||||
if err != nil {
|
||||
fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -165,40 +194,10 @@ func readLines(r io.Reader) ([]string, error) {
|
||||
return lines, nil
|
||||
}
|
||||
|
||||
func writeStubIfNeeded(inputFile, outputFile string) (err error) {
|
||||
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
|
||||
|
||||
var output *os.File
|
||||
b, err := os.ReadFile(inputFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
inputLines, err := readLines(bytes.NewBuffer(b))
|
||||
var funcs []string
|
||||
var outputLines []string
|
||||
for _, line := range inputLines {
|
||||
switch line {
|
||||
case "//go:build !enterprise":
|
||||
outputLines = append(outputLines, warning, "")
|
||||
line = "//go:build enterprise"
|
||||
case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
|
||||
continue
|
||||
}
|
||||
outputLines = append(outputLines, line)
|
||||
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "func ") {
|
||||
i := strings.Index(trimmed, "(")
|
||||
if i != -1 {
|
||||
funcs = append(funcs, trimmed[5:i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isStubNeeded(funcs []string) (bool, error) {
|
||||
pkg, err := parsePackage(".", []string{"enterprise"})
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
var found []string
|
||||
@@ -218,26 +217,44 @@ func writeStubIfNeeded(inputFile, outputFile string) (err error) {
|
||||
}
|
||||
switch {
|
||||
case len(found) == len(funcs):
|
||||
return nil
|
||||
return false, nil
|
||||
case len(found) != 0:
|
||||
return fmt.Errorf("funcs partially defined: need=%v, found=%v", funcs, found)
|
||||
return false, fmt.Errorf("funcs partially defined: need=%v, found=%v", funcs, found)
|
||||
}
|
||||
|
||||
output, err = os.Create(outputFile + ".tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If we don't end up writing to the file, delete it.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
os.Remove(outputFile + ".tmp")
|
||||
} else {
|
||||
os.Rename(outputFile+".tmp", outputFile)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func getFuncs(inputLines []string) []string {
|
||||
var funcs []string
|
||||
for _, line := range inputLines {
|
||||
trimmed := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmed, "func ") {
|
||||
i := strings.Index(trimmed, "(")
|
||||
if i != -1 {
|
||||
funcs = append(funcs, trimmed[5:i])
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
return funcs
|
||||
}
|
||||
|
||||
_, err = io.WriteString(output, strings.Join(outputLines, "\n")+"\n")
|
||||
return err
|
||||
func getOutput(inputLines []string) []string {
|
||||
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
|
||||
|
||||
var outputLines []string
|
||||
for _, line := range inputLines {
|
||||
switch line {
|
||||
case "//go:build !enterprise":
|
||||
outputLines = append(outputLines, warning, "")
|
||||
line = "//go:build enterprise"
|
||||
case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
|
||||
continue
|
||||
}
|
||||
outputLines = append(outputLines, line)
|
||||
}
|
||||
|
||||
return outputLines
|
||||
}
|
||||
|
||||
func parsePackage(name string, tags []string) (*packages.Package, error) {
|
||||
|
||||
Reference in New Issue
Block a user