Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 38 additions & 26 deletions instrumenter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"go/printer"
"go/token"
"os"
"os/exec"
"path/filepath"
"reflect"
"runtime"
"sort"
"strconv"
"strings"
)

Expand Down Expand Up @@ -95,7 +95,8 @@ func (i *instrumenter) instrument(srcDir, singleFile, dstDir string) bool {
i.instrumentFile(name, file, dstDir)
})
}
i.writeGobcoFiles(dstDir, pkgs)

i.writeGobcoFiles(srcDir, dstDir, pkgs)
return true
}

Expand Down Expand Up @@ -574,7 +575,7 @@ var fixedTemplate string
//go:embed templates/gobco_no_testmain_test.go
var noTestMainTemplate string

func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) {
func (i *instrumenter) writeGobcoFiles(srcDir, tmpDir string, pkgs []*ast.Package) {
pkgname := pkgs[0].Name
fixPkgname := func(str string) string {
str = strings.TrimPrefix(str, "//go:build ignore\n// +build ignore\n\n")
Expand All @@ -587,7 +588,7 @@ func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) {
writeFile(filepath.Join(tmpDir, "gobco_no_testmain_test.go"), fixPkgname(noTestMainTemplate))
}

i.writeGobcoBlackBox(pkgs, tmpDir)
i.writeGobcoBlackBox(pkgs, srcDir, tmpDir)
}

func (i *instrumenter) writeGobcoGo(filename, pkgname string) {
Expand All @@ -612,35 +613,46 @@ func (i *instrumenter) writeGobcoGo(filename, pkgname string) {
writeFile(filename, sb.String())
}

// findPackagePath finds import path of a package that srcDir indicates
func findPackagePath(srcDir string) (string, error) {
_, moduleRel, err := findInModule(srcDir)
if err != nil {
return "", err
}

moduleName, err := getModuleName()
if err != nil {
return "", err
}

if moduleRel == "." {
return moduleName, nil
} else {
pkgPath := fmt.Sprintf("%s/%s", moduleName, moduleRel)
return pkgPath, nil
}
}

func getModuleName() (string, error) {
cmd := exec.Command("go", "list", "-m")
output, err := cmd.Output()
if err != nil {
return "", err
}
return strings.TrimSpace(string(output)), nil
}

// writeGobcoBlackBox makes the function 'GobcoCover' available
// to black box tests (those in 'package x_test' instead of 'package x')
// by delegating to the function of the same name in the main package.
func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, dstDir string) {
func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, srcDir, dstDir string) {
if len(pkgs) < 2 {
return
}

// Copy the 'import' directive from one of the existing files.
pkgName, pkgPath := "", ""
for _, pkg := range pkgs {
forEachFile(pkg, func(name string, file *ast.File) {
for _, imp := range file.Imports {
var impName string
p, err := strconv.Unquote(imp.Path.Value)
ok(err)
if imp.Name != nil {
impName = imp.Name.Name
} else {
impName = filepath.Base(p)
}

if impName == pkgs[0].Name {
pkgName = impName
pkgPath = p
}
}
})
}
pkgPath, err := findPackagePath(srcDir)
ok(err)
pkgName := filepath.Base(pkgPath)

text := "" +
"package " + pkgs[0].Name + "_test\n" +
Expand Down
21 changes: 16 additions & 5 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,27 +205,38 @@ func (g *gobco) gopaths() string {
return filepath.Join(home, "go")
}

func (g *gobco) findInModule(dir string) (moduleRoot, moduleRel string) {
absDir, err := filepath.Abs(dir)
func (g *gobco) findInModule(dir string) (string, string) {
moduleRoot, moduleRel, err := findInModule(dir)
g.check(err)
return moduleRoot, moduleRel
}

// findInModule finds path of moduleRoot and relative path from the moduleRoot to dir
func findInModule(dir string) (moduleRoot, moduleRel string, err error) {
absDir, err := filepath.Abs(dir)
if err != nil {
return "", "", err
}

abs := absDir
for {
if _, err := os.Lstat(filepath.Join(abs, "go.mod")); err == nil {
rel, err := filepath.Rel(abs, absDir)
g.check(err)
if err != nil {
return "", "", err
}

root := abs
if rel == "." {
root = dir
}

return root, rel
return root, rel, nil
}

parent := filepath.Dir(abs)
if parent == abs {
return "", ""
return "", "", nil
}
abs = parent
}
Expand Down