Rework stubmaker logic so that if the funcs are found in Go, we don't attempt to write the file. (#21636)

This commit is contained in:
Nick Cabatoff
2023-07-07 13:16:27 -04:00
committed by GitHub
parent 87d37fecb7
commit 3d7aab7a34

View File

@@ -31,6 +31,8 @@ func main() {
Level: hclog.Trace,
})
// Setup git, both so we can determine if we're running on enterprise, and
// so we can make sure we don't clobber a non-transient file.
repo, err := git.PlainOpenWithOptions(".", &git.PlainOpenOptions{
DetectDotGit: true,
})
@@ -46,6 +48,29 @@ func main() {
return
}
// Read the file and figure out if we need to do anything.
inputFile := os.Getenv("GOFILE")
if !strings.HasSuffix(inputFile, "_oss.go") {
fatal(fmt.Errorf("stubmaker should only be invoked from files ending in _oss.go"))
}
baseFilename := strings.TrimSuffix(inputFile, "_oss.go")
outputFile := baseFilename + "_ent.go"
b, err := os.ReadFile(inputFile)
if err != nil {
fatal(err)
}
inputLines, err := readLines(bytes.NewBuffer(b))
funcs := getFuncs(inputLines)
if needed, err := isStubNeeded(funcs); err != nil {
fatal(err)
} else if !needed {
return
}
// We'd like to write the file, but first make sure that we're not going
// to blow away anyone's work or overwrite a file already in git.
head, err := repo.Head()
if err != nil {
fatal(err)
@@ -60,23 +85,27 @@ func main() {
fatal(err)
}
inputFile := os.Getenv("GOFILE")
if !strings.HasSuffix(inputFile, "_oss.go") {
fatal(fmt.Errorf("stubmaker should only be invoked from files ending in _oss.go"))
}
baseFilename := strings.TrimSuffix(inputFile, "_oss.go")
target := baseFilename + "_ent.go"
tracked, err := inGit(wt, st, obj, target)
tracked, err := inGit(wt, st, obj, outputFile)
if err != nil {
fatal(err)
}
if tracked {
fatal(fmt.Errorf("output file %s exists in git, not overwriting", target))
fatal(fmt.Errorf("output file %s exists in git, not overwriting", outputFile))
}
if err := writeStubIfNeeded(inputFile, target); err != nil {
// Now we can finally write the file
output, err := os.Create(outputFile + ".tmp")
if err != nil {
fatal(err)
}
_, err = io.WriteString(output, strings.Join(getOutput(inputLines), "\n")+"\n")
if err != nil {
// If we don't end up writing to the file, delete it.
os.Remove(outputFile + ".tmp")
} else {
os.Rename(outputFile+".tmp", outputFile)
}
if err != nil {
fatal(err)
}
}
@@ -165,40 +194,10 @@ func readLines(r io.Reader) ([]string, error) {
return lines, nil
}
func writeStubIfNeeded(inputFile, outputFile string) (err error) {
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
var output *os.File
b, err := os.ReadFile(inputFile)
if err != nil {
return err
}
inputLines, err := readLines(bytes.NewBuffer(b))
var funcs []string
var outputLines []string
for _, line := range inputLines {
switch line {
case "//go:build !enterprise":
outputLines = append(outputLines, warning, "")
line = "//go:build enterprise"
case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
continue
}
outputLines = append(outputLines, line)
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "func ") {
i := strings.Index(trimmed, "(")
if i != -1 {
funcs = append(funcs, trimmed[5:i])
}
}
}
func isStubNeeded(funcs []string) (bool, error) {
pkg, err := parsePackage(".", []string{"enterprise"})
if err != nil {
return err
return false, err
}
var found []string
@@ -218,26 +217,44 @@ func writeStubIfNeeded(inputFile, outputFile string) (err error) {
}
switch {
case len(found) == len(funcs):
return nil
return false, nil
case len(found) != 0:
return fmt.Errorf("funcs partially defined: need=%v, found=%v", funcs, found)
return false, fmt.Errorf("funcs partially defined: need=%v, found=%v", funcs, found)
}
output, err = os.Create(outputFile + ".tmp")
if err != nil {
return err
}
// If we don't end up writing to the file, delete it.
defer func() {
if err != nil {
os.Remove(outputFile + ".tmp")
} else {
os.Rename(outputFile+".tmp", outputFile)
return true, nil
}
func getFuncs(inputLines []string) []string {
var funcs []string
for _, line := range inputLines {
trimmed := strings.TrimSpace(line)
if strings.HasPrefix(trimmed, "func ") {
i := strings.Index(trimmed, "(")
if i != -1 {
funcs = append(funcs, trimmed[5:i])
}
}
}()
}
return funcs
}
_, err = io.WriteString(output, strings.Join(outputLines, "\n")+"\n")
return err
func getOutput(inputLines []string) []string {
warning := "// Code generated by tools/stubmaker; DO NOT EDIT."
var outputLines []string
for _, line := range inputLines {
switch line {
case "//go:build !enterprise":
outputLines = append(outputLines, warning, "")
line = "//go:build enterprise"
case "//go:generate go run github.com/hashicorp/vault/tools/stubmaker":
continue
}
outputLines = append(outputLines, line)
}
return outputLines
}
func parsePackage(name string, tags []string) (*packages.Package, error) {