blob: b575dcc4ec67d292567e303185619dbe6300425b [file] [log] [blame]
// Copyright 2015 The Vanadium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package main
import (
"fmt"
"go/ast"
"go/build"
"go/parser"
"go/token"
"go/types"
"io"
"io/ioutil"
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"
"unicode"
"v.io/jiri"
"v.io/jiri/collect"
"v.io/x/devtools/internal/goutil"
)
const (
	// nologComment is the magic comment text that disables log injection.
	nologComment = "nologcall"
	// logCallComment is the comment to be appended to all injected calls.
	logCallComment = "// gologcop: DO NOT EDIT, MUST BE FIRST STATEMENT"
	// v23ContextPackage and v23ContextTypeName identify the context type
	// (*T from v.io/v23/context). When useContextFlag is set, a parameter of
	// this type is passed to the injected log call instead of being logged.
	v23ContextPackage = "v.io/v23/context"
	v23ContextTypeName = "T"
)
var (
	// injectImportTag is the name given to the injected package in its
	// import declaration, if any, as in: import tag "path".
	injectImportTag string
	// injectImportPath is the import path of the package to inject.
	injectImportPath string
	// injectPackage is the package name to use at call sites. It is either
	// injectImportTag, if one is specified, or the base name of
	// injectImportPath.
	injectPackage string
	// injectCall is the function called by injected statements, without the
	// package name (genCall appends "f" for the formatted variant).
	injectCall string
	// removePackage and removeCall are the package name and function name of
	// the call to be removed, as split out of removeCallFlag.
	removePackage, removeCall string
)
// parseState encapsulates all of the state acquired during parsing and
// type checking. It makes sure that any given file or package
// is parsed or type checked once and only once.
type parseState struct {
	// jirix supplies the output streams used for progress and warnings.
	jirix *jiri.X
	// config is shared by every type check; newState sets its Importer to
	// this parseState so that imports are also loaded from source.
	config *types.Config
	// fset holds the positions of every file parsed through this state.
	fset *token.FileSet
	// info accumulates type information for all packages checked so far.
	info     *types.Info
	packages map[string]*types.Package // keyed by the package path name.
	asts     map[string][]*ast.File    // keyed by the package path name
}
// newState returns an empty parseState for jirix.
//
// Every package checked through the returned state shares one token.FileSet
// and one types.Info, which records Types, Defs and Uses. The type checker
// skips function bodies and resolves imports through the parseState itself,
// so imports are also loaded from source.
func newState(jirix *jiri.X) *parseState {
	ps := new(parseState)
	ps.jirix = jirix
	ps.fset = token.NewFileSet()
	ps.packages = map[string]*types.Package{}
	ps.asts = map[string][]*ast.File{}
	ps.config = &types.Config{IgnoreFuncBodies: true}
	ps.info = &types.Info{
		Types: map[ast.Expr]types.TypeAndValue{},
		Defs:  map[*ast.Ident]types.Object{},
		Uses:  map[*ast.Ident]types.Object{},
	}
	ps.config.Importer = ps
	return ps
}
// Import implements types.Importer for the type checker configured in
// newState. Every import is delegated to sourceImporter, so dependencies are
// parsed and type checked from source and cached in ps.
func (ps *parseState) Import(path string) (*types.Package, error) {
	pkg, err := ps.sourceImporter(path)
	return pkg, err
}
// parsedPackage returns the cached type-checked package and parsed files for
// the import path. Both are nil if the path has not been processed yet.
func (ps *parseState) parsedPackage(path string) (*types.Package, []*ast.File) {
	pkg := ps.packages[path]
	files := ps.asts[path]
	return pkg, files
}
// addParsedPackage caches pkg and its parsed files under path.
//
// If a non-nil package is already cached for path, the existing entry is
// kept and a warning is printed instead.
func (ps *parseState) addParsedPackage(path string, pkg *types.Package, asts []*ast.File) {
	if existing, _ := ps.parsedPackage(path); existing == nil {
		ps.packages[path] = pkg
		ps.asts[path] = asts
		return
	}
	fmt.Fprintf(ps.jirix.Stdout(), "Warning: %s is already cached\n", path)
}
// sourceImporter imports the package with the given import path, always from
// source code.
//
// Packages already processed by ps are returned from its cache. Otherwise the
// package is located with go/build and then parsed and type checked via
// parseAndTypeCheckPackage. An error is returned if the package cannot be
// located or fails to parse or type check.
func (ps *parseState) sourceImporter(path string) (*types.Package, error) {
	// It seems that we need to special case the unsafe package: it has no
	// source to parse, and go/types provides it directly.
	if path == "unsafe" {
		return types.Unsafe, nil
	}
	if pkg, _ := ps.parsedPackage(path); pkg != nil {
		return pkg, nil
	}
	progressMsg(ps.jirix.Stdout(), "importing from source: %s\n", path)
	bpkg, err := build.Default.Import(path, ".", build.ImportMode(build.ImportComment))
	if err != nil {
		// Do not continue with a partially populated build.Package; the
		// failure would otherwise surface later as a confusing type error.
		return nil, fmt.Errorf("failed to import %q from source: %v", path, err)
	}
	_, pkg, err := ps.parseAndTypeCheckPackage(bpkg)
	if err != nil {
		return nil, err
	}
	return pkg, nil
}
// parseAndTypeCheckPackage parses every file listed in bpkg.GoFiles, keeping
// comments, and type checks them together as one package. The results are
// recorded in the shared ps.info.
//
// Results are cached by import path, so a package that was already processed
// is returned from the cache without being parsed again. It returns the
// parsed files and the type-checked package, or the first parse or
// type-check error encountered.
func (ps *parseState) parseAndTypeCheckPackage(bpkg *build.Package) ([]*ast.File, *types.Package, error) {
	importPath := bpkg.ImportPath
	if cachedPkg, cachedFiles := ps.parsedPackage(importPath); cachedPkg != nil {
		return cachedFiles, cachedPkg, nil
	}
	tpkg := types.NewPackage(importPath, bpkg.Name)
	checker := types.NewChecker(ps.config, ps.fset, tpkg, ps.info)
	files := make([]*ast.File, 0, len(bpkg.GoFiles))
	for _, name := range bpkg.GoFiles {
		parsed, err := parser.ParseFile(ps.fset, filepath.Join(bpkg.Dir, name), nil, parser.ParseComments)
		if err != nil {
			return nil, nil, err
		}
		files = append(files, parsed)
	}
	if err := checker.Files(files); err != nil {
		return nil, nil, err
	}
	// A successful check should always leave the package complete; treat
	// anything else as an internal assertion failure.
	if !tpkg.Complete() {
		return nil, nil, fmt.Errorf("checked %q is not completely parsed+checked", bpkg.Name)
	}
	progressMsg(ps.jirix.Stdout(), "parsed from source: %s\n", importPath)
	ps.addParsedPackage(importPath, tpkg, files)
	return files, tpkg, nil
}
// importPkgs expands packageSpec with 'go list' and then locates each
// resulting package with go/build. This lets a spec such as v.io/v23/...
// name many packages.
//
// goFlags are passed to 'go list' along with a -merge-policies flag derived
// from mergePoliciesFlag. goFlags itself is never modified.
func importPkgs(jirix *jiri.X, goFlags, packageSpec []string) ([]*build.Package, error) {
	// Build a fresh slice: appending directly to goFlags could write into
	// the caller's backing array, and callers reuse goFlags across calls.
	flags := make([]string, 0, len(goFlags)+1)
	flags = append(flags, goFlags...)
	flags = append(flags, "-merge-policies="+mergePoliciesFlag.String())
	pkgs, err := goutil.List(jirix, flags, packageSpec...)
	if err != nil {
		return nil, err
	}
	bpkgs := make([]*build.Package, 0, len(pkgs))
	for _, pkg := range pkgs {
		bpkg, err := build.Default.Import(pkg, ".", build.ImportMode(build.ImportComment))
		if err != nil {
			return nil, fmt.Errorf("error importing packages: %v", err)
		}
		bpkgs = append(bpkgs, bpkg)
	}
	return bpkgs, nil
}
// exists is the value stored in maps used as sets (map[K]struct{}). Only the
// presence of a key matters; the zero-size value carries no information.
var exists struct{}
// initInjectorFlags derives the inject* package variables from
// injectCallImportFlag and injectCallFlag.
//
// injectCallImportFlag must look like an import declaration's spec, i.e.
// either `path` or `tag path`. The path may optionally be quoted, and
// surrounding whitespace is ignored. Any other shape is rejected with an
// error.
func initInjectorFlags() error {
	parts := strings.FieldsFunc(injectCallImportFlag, unicode.IsSpace)
	// unquote strips Go string quotes from s. It returns s unchanged if s is
	// not a valid quoted string, so both "v.io/x" and v.io/x are accepted.
	unquote := func(s string) string {
		if u, err := strconv.Unquote(s); err == nil {
			return u
		}
		return s
	}
	switch len(parts) {
	case 1:
		injectImportTag = ""
		// Use the whitespace-trimmed field rather than the raw flag, so that
		// padding around the path neither defeats unquoting nor ends up in
		// the import path.
		injectImportPath = unquote(parts[0])
		injectPackage = path.Base(injectImportPath)
	case 2:
		injectImportTag = parts[0]
		injectImportPath = unquote(parts[1])
		injectPackage = injectImportTag
	default:
		return fmt.Errorf("%q doesn't look like an import declaration", injectCallImportFlag)
	}
	injectCall = injectCallFlag
	return nil
}
// runInjector runs the log injector.
//
// It loads the interface packages named by interfaceList and the
// implementation packages named by implementationList, both expanded via
// 'go list'. It then finds the methods in the implementation packages that
// implement exported, non-empty interfaces from the interface packages, and
// checks whether each such method begins with the injected log call (or
// carries a nologcall comment).
//
// If checkOnly is set, offending methods are reported. If any package failed
// the check, the process exits with status 1 via os.Exit rather than
// returning an error. Otherwise the missing log calls are injected into the
// source files.
func runInjector(jirix *jiri.X, goFlags, interfaceList, implementationList []string, checkOnly bool) error {
	if err := initInjectorFlags(); err != nil {
		return err
	}
	// use 'go list' and the builder to import all of the packages
	// specified as interfaces and implementations.
	ifcs, err := importPkgs(jirix, goFlags, interfaceList)
	if err != nil {
		return err
	}
	impls, err := importPkgs(jirix, goFlags, implementationList)
	if err != nil {
		return err
	}
	printHeader(jirix.Stdout(), "Package Summary")
	progressMsg(jirix.Stdout(), "%v expands to %d interface packages\n", interfaceList, len(ifcs))
	progressMsg(jirix.Stdout(), "%v expands to %d implementation packages\n", implementationList, len(impls))
	// A single parseState is shared so that each package is parsed and
	// type checked only once, even if it is both an interface and an
	// implementation package or is imported by several of them.
	ps := newState(jirix)
	// checkFailed collects the import paths of packages that failed the
	// check (only used when checkOnly is set).
	checkFailed := []string{}
	printHeader(jirix.Stdout(), "Parsing and Type Checking Interface Packages")
	ifcPkgs := []*types.Package{}
	for _, ifc := range ifcs {
		_, tpkg, err := ps.parseAndTypeCheckPackage(ifc)
		if err != nil {
			return fmt.Errorf("failed to parse+type check: %s: %s", ifc.ImportPath, err)
		}
		ifcPkgs = append(ifcPkgs, tpkg)
	}
	publicInterfaces := findPublicInterfaces(jirix, ifcPkgs)
	for _, impl := range impls {
		printHeader(jirix.Stdout(), "Parsing and Type Checking Implementation Packages")
		asts, tpkg, err := ps.parseAndTypeCheckPackage(impl)
		if err != nil {
			return fmt.Errorf("failed to parse+type check: %s: %s", impl.ImportPath, err)
		}
		// Now find the methods that implement those public interfaces.
		methods := findMethodsImplementing(jirix, ps.fset, tpkg, publicInterfaces)
		// and their positions in the files.
		methodPositions, err := functionDeclarationsAtPositions(ps.fset, asts, ps.info, methods)
		if err != nil {
			return err
		}
		// then check to see if those methods already have logging statements.
		needsInjection := checkMethods(methodPositions)
		if checkOnly {
			if len(needsInjection) > 0 {
				printHeader(jirix.Stdout(), "Check Results")
				reportResults(jirix, ps.fset, needsInjection)
				checkFailed = append(checkFailed, impl.ImportPath)
			}
		} else {
			if err := inject(jirix, ps.fset, needsInjection); err != nil {
				return fmt.Errorf("injection failed for: %s: %s", impl.ImportPath, err)
			}
		}
	}
	if checkOnly && len(checkFailed) > 0 {
		for _, p := range checkFailed {
			fmt.Fprintf(jirix.Stdout(), "check failed for: %s\n", p)
		}
		// A failed check terminates the process directly with a non-zero
		// status instead of returning an error to the caller.
		os.Exit(1)
	}
	return nil
}
// initRemoverFlags splits removeCallFlag, which must have the form pkg.Func,
// into removePackage and removeCall. Any other shape (no dot, or more than
// one) is rejected with an error.
func initRemoverFlags() error {
	parts := strings.Split(removeCallFlag, ".")
	if len(parts) != 2 {
		return fmt.Errorf("%q doesn't look like a function call on an imported package", removeCallFlag)
	}
	removePackage, removeCall = parts[0], parts[1]
	return nil
}
// runRemover removes the call named by removeCallFlag (pkg.Func) from the
// start of functions and methods in the packages named by
// implementationList, expanded via 'go list'. Only declarations whose first
// statement is a matching defer call are modified.
func runRemover(jirix *jiri.X, goFlags, implementationList []string) error {
	if err := initRemoverFlags(); err != nil {
		return err
	}
	// Expand the specs with 'go list' and locate each package via go/build.
	impls, err := importPkgs(jirix, goFlags, implementationList)
	if err != nil {
		return err
	}
	ps := newState(jirix)
	printHeader(jirix.Stdout(), "Package Summary")
	progressMsg(jirix.Stdout(), "%v expands to %d implementation packages\n", implementationList, len(impls))
	for _, impl := range impls {
		files, tpkg, err := ps.parseAndTypeCheckPackage(impl)
		if err != nil {
			return fmt.Errorf("failed to parse+type check: %s: %s", impl.ImportPath, err)
		}
		candidates := findMethods(jirix, ps.fset, tpkg)
		refs, err := functionDeclarationsAtPositions(ps.fset, files, ps.info, candidates)
		if err != nil {
			return err
		}
		if err := remove(jirix, ps.fset, findRemovals(refs)); err != nil {
			return fmt.Errorf("removal failed for: %s: %s", impl.ImportPath, err)
		}
	}
	return nil
}
// funcDeclRef stores a reference to a function declaration, paired
// with the file containing it.
type funcDeclRef struct {
	// Decl is the function or method declaration.
	Decl *ast.FuncDecl
	// File is the parsed file that contains Decl.
	File *ast.File
	// LogCall is the source text of the log statement generated for Decl by
	// genCall; inject inserts it right after Decl's opening brace.
	LogCall string
}
// methodSetVisibleThroughInterfaces returns the set of exported method names
// declared by those interfaces that t or *t implements. Since t implements
// each such interface, these are exactly the exported methods of t that can
// be reached through the given interfaces.
func methodSetVisibleThroughInterfaces(t types.Type, interfaces []*types.Interface) map[string]struct{} {
	visible := make(map[string]struct{})
	ptr := types.NewPointer(t)
	for _, iface := range interfaces {
		if !types.Implements(t, iface) && !types.Implements(ptr, iface) {
			continue
		}
		for i, n := 0, iface.NumMethods(); i < n; i++ {
			name := iface.Method(i).Name()
			if ast.IsExported(name) {
				visible[name] = exists
			}
		}
	}
	return visible
}
// hasV23Context looks among parameters for a field whose type is a pointer to
// the named type v23ContextPackage.v23ContextTypeName, i.e. *context.T.
//
// It returns a parameter list with that field removed, along with the
// expression to pass as the context at the injected call site. That
// expression is the field's first name, or "nil" in any of these cases:
//   - the field is unnamed or blank;
//   - no such field exists;
//   - parameters is nil.
//
// Only the first matching field is removed.
//
// When useContextFlag is false no context is used: parameters is returned
// unchanged together with an empty string.
//
// parameters belongs to the AST of the file being rewritten and is never
// modified.
func hasV23Context(info *types.Info, parameters *ast.FieldList) (*ast.FieldList, string) {
	if !useContextFlag {
		return parameters, ""
	}
	if parameters == nil {
		return nil, "nil"
	}
	filtered := *parameters
	for i, field := range parameters.List {
		ptr, ok := info.TypeOf(field.Type).(*types.Pointer)
		if !ok {
			continue
		}
		named, ok := ptr.Elem().(*types.Named)
		if !ok {
			continue
		}
		obj := named.Obj()
		// Predeclared named types such as error have no package.
		if obj.Pkg() == nil || obj.Pkg().Path() != v23ContextPackage || obj.Name() != v23ContextTypeName {
			continue
		}
		// Build a new slice rather than appending to parameters.List[:i].
		// That would shift the remaining fields inside the original AST's
		// backing array and corrupt the caller's parameter list.
		list := make([]*ast.Field, 0, len(parameters.List)-1)
		list = append(list, parameters.List[:i]...)
		filtered.List = append(list, parameters.List[i+1:]...)
		ctxName := "nil"
		if len(field.Names) > 0 && field.Names[0].Name != "_" {
			ctxName = field.Names[0].Name
		}
		return &filtered, ctxName
	}
	return &filtered, "nil"
}
// genFmt builds the pieces of a printf-style call describing fields, which
// are function parameters or results.
//
// For each named, non-blank field it adds a format entry to the first
// result:
//   - "name=%v" for printable values;
//   - "name=%.10s..." for string-kinded values;
//   - "name...=%v" for variadic parameters;
//   - a bare "name=" for values that are not printed.
//
// For every printed field it adds the matching argument, "name" (or "&name"
// when indirect is set, as for results), to the second result.
//
// Printable means a basic type, a named type whose underlying type is
// basic, a named interface type called error, or a variadic parameter. A
// field with no recorded type that is not variadic yields an error. A nil
// fields yields all-nil results.
func genFmt(info *types.Info, fields *ast.FieldList, indirect bool) ([]string, []string, error) {
	if fields == nil {
		return nil, nil, nil
	}
	basicVerb := func(b *types.Basic) string {
		if b.Kind() == types.String {
			return "%.10s..."
		}
		return "%v"
	}
	format, args := []string{}, []string{}
	for _, field := range fields.List {
		verb, separator := "", "="
		canPrint := false
		switch typ := info.TypeOf(field.Type).(type) {
		case nil:
			// Variadic parameters are printed as a slice of their element type.
			if _, isEllipsis := field.Type.(*ast.Ellipsis); !isEllipsis {
				return nil, nil, fmt.Errorf("failed to locate type for %v", field.Names)
			}
			verb, separator, canPrint = "%v", "...=", true
		case *types.Basic:
			verb, canPrint = basicVerb(typ), true
		case *types.Named:
			switch under := typ.Underlying().(type) {
			case *types.Basic:
				verb, canPrint = basicVerb(under), true
			case *types.Interface:
				if typ.Obj().Name() == "error" {
					verb, canPrint = "%v", true
				}
			}
		}
		for _, ident := range field.Names {
			if len(ident.Name) == 0 || ident.Name == "_" {
				continue
			}
			if !canPrint {
				format = append(format, ident.Name+"=")
				continue
			}
			format = append(format, ident.Name+separator+verb)
			arg := ident.Name
			if indirect {
				arg = "&" + arg
			}
			args = append(args, arg)
		}
	}
	return format, args, nil
}
func genCall(info *types.Info, params, results *ast.FieldList) (string, error) {
params, contextPar := hasV23Context(info, params)
noargs := fmt.Sprintf("\n\tdefer %s.%s(%s)(%s) %s", injectPackage, injectCall, contextPar, contextPar, logCallComment)
if info == nil {
return noargs, nil
}
argFormat, printableArgs, err := genFmt(info, params, false)
if err != nil {
return "", err
}
resFormat, printableResults, err := genFmt(info, results, true)
if err != nil {
return "", err
}
if len(argFormat) == 0 && len(resFormat) == 0 {
return noargs, nil
}
formatArgs := func(format, parameters []string) string {
if len(format) > 0 {
formatStr := strings.TrimSpace(strings.Join(format, ","))
parametersStr := strings.Join(parameters, ",")
return fmt.Sprintf("\"%s\", %s", formatStr, parametersStr)
}
return "\"\""
}
pars := formatArgs(argFormat, printableArgs)
res := formatArgs(resFormat, printableResults)
contextParArg, contextParRes := contextPar, contextPar
if len(contextPar) > 0 {
if len(pars) > 0 {
contextParArg += ", "
}
if len(res) > 0 {
contextParRes += ", "
}
}
return fmt.Sprintf("\n\tdefer %s.%sf(%s%s)(%s%s) %s", injectPackage, injectCall, contextParArg, pars, contextParRes, res, logCallComment), nil
}
// functionDeclarationsAtPositions returns references to the function and
// method declarations in files whose name identifier is located at one of
// positions. Each reference is paired with its file and with the log call
// genCall generates for it.
//
// Declarations without a body (e.g. functions implemented in assembly) are
// skipped, because there is no body to inject into or remove from. genCall
// is only invoked for selected declarations, so declarations that are not of
// interest can never cause an error. An error from genCall is returned
// prefixed with the declaration's file and line.
func functionDeclarationsAtPositions(fset *token.FileSet, files []*ast.File, info *types.Info, positions map[token.Pos]struct{}) ([]funcDeclRef, error) {
	result := []funcDeclRef{}
	for _, file := range files {
		for _, decl := range file.Decls {
			fn, ok := decl.(*ast.FuncDecl)
			if !ok || fn.Body == nil {
				continue
			}
			// It's important not to use fn.Pos() here: it gives the position
			// of the "func" token, whereas positions has collected the
			// locations of method name tokens.
			if _, wanted := positions[fn.Name.Pos()]; !wanted {
				continue
			}
			call, err := genCall(info, fn.Type.Params, fn.Type.Results)
			if err != nil {
				pos := fset.Position(fn.Pos())
				return nil, fmt.Errorf("%s:%d: %v", pos.Filename, pos.Line, err)
			}
			result = append(result, funcDeclRef{fn, file, call})
		}
	}
	return result, nil
}
// findMethodsImplementing returns the source positions of the name
// identifiers of methods, declared on types in tpkg, that implement exported
// methods of the specified interfaces.
//
// The reason positions are collected, rather than declarations, is that
// go/types has no easy way to map a types.Func to its ast.FuncDecl. The
// caller instead looks for AST declarations at matching positions.
func findMethodsImplementing(jirix *jiri.X, fset *token.FileSet, tpkg *types.Package, interfaces []*types.Interface) map[token.Pos]struct{} {
	positions := map[token.Pos]struct{}{}
	printHeader(jirix.Stdout(), "Methods Implementing Public Interfaces in %s", tpkg.Path())
	scope := tpkg.Scope()
	for _, child := range scope.Names() {
		typ := scope.Lookup(child).Type()
		// Interfaces have no method implementations.
		if types.IsInterface(typ) {
			continue
		}
		apiMethodSet := methodSetVisibleThroughInterfaces(typ, interfaces)
		// Optimization: typ implements none of the interfaces we care about.
		if len(apiMethodSet) == 0 {
			continue
		}
		// Examine both t's and *t's method sets.
		//   - *t's set also contains pointer-receiver methods, which t's set
		//     lacks even when t has some value-receiver methods.
		//   - t's set covers the case where t is itself a pointer type.
		methodSets := []*types.MethodSet{
			types.NewMethodSet(typ),
			types.NewMethodSet(types.NewPointer(typ)),
		}
		for _, methodSet := range methodSets {
			for i := 0; i < methodSet.Len(); i++ {
				fn := methodSet.At(i).Obj().(*types.Func)
				// t may have methods not declared in any interface we care
				// about; there is no need to log those.
				if _, ok := apiMethodSet[fn.Name()]; !ok {
					continue
				}
				// Embedded functions show up with a zero pos.
				if !fn.Pos().IsValid() {
					continue
				}
				// Value-receiver methods appear in both method sets.
				if _, seen := positions[fn.Pos()]; seen {
					continue
				}
				progressMsg(jirix.Stdout(), "%s.%s: %s\n", tpkg.Path(), fn.Name(), fset.Position(fn.Pos()))
				positions[fn.Pos()] = exists
			}
		}
	}
	return positions
}
// findMethodsInScope adds to positions the position of every function
// declared in scope. It also adds every method of each object in scope whose
// type is a named type.
//
// The jirix and fset parameters are currently unused.
func findMethodsInScope(jirix *jiri.X, fset *token.FileSet, positions map[token.Pos]struct{}, scope *types.Scope) {
	for _, name := range scope.Names() {
		obj := scope.Lookup(name)
		if named, ok := obj.Type().(*types.Named); ok {
			for i, n := 0, named.NumMethods(); i < n; i++ {
				positions[named.Method(i).Pos()] = exists
			}
			continue
		}
		if _, isFunc := obj.Type().(*types.Signature); isFunc {
			positions[obj.Pos()] = exists
		}
	}
}
// findMethods returns the positions of all functions and named-type methods
// declared at package scope in tpkg (see findMethodsInScope).
func findMethods(jirix *jiri.X, fset *token.FileSet, tpkg *types.Package) map[token.Pos]struct{} {
	printHeader(jirix.Stdout(), "Methods in %s", tpkg.Path())
	positions := make(map[token.Pos]struct{})
	findMethodsInScope(jirix, fset, positions, tpkg.Scope())
	return positions
}
// patch describes one edit to a file's source. When applied, the bytes in
// [Offset, NextOffset) are replaced by Text.
type patch struct {
	Offset     int    // byte offset at which the edit starts
	Text       string // text emitted at Offset
	NextOffset int    // byte offset at which copying of the original resumes
}

// patchSorter orders patches by ascending Offset, for use with sort.Sort.
type patchSorter []patch

// insertAt returns a patch that inserts text at offset and removes nothing.
func insertAt(offset int, text string) patch {
	return patch{offset, text, offset}
}

// removeRange returns a patch that deletes the bytes in [from, to).
func removeRange(from, to int) patch {
	var p patch
	p.Offset, p.NextOffset = from, to
	return p
}

func (ps patchSorter) Len() int           { return len(ps) }
func (ps patchSorter) Less(a, b int) bool { return ps[a].Offset < ps[b].Offset }
func (ps patchSorter) Swap(a, b int)      { ps[a], ps[b] = ps[b], ps[a] }
// countOverlap returns the length, in bytes, of the longest common prefix of
// a and b.
func countOverlap(a, b string) (i int) {
	for i < len(a) && i < len(b) {
		if a[i] != b[i] {
			break
		}
		i++
	}
	return i
}
// ensureImportLogPackage makes sure that file includes an import declaration
// for the package to be injected (injectImportPath, optionally named
// injectImportTag). If the import is already present it returns
// (patch{}, false). Otherwise it returns a patch that adds the import, and
// true.
//
// If the file has a parenthesized import declaration, the new import is
// inserted just before the existing import whose quoted path shares the
// longest prefix with the new one, near its lexicographic position. If not,
// a new single-line import declaration is inserted before the file's first
// declaration.
//
// NOTE(review): an existing import that merely uses injectImportTag as its
// name, or that imports injectImportPath under a different name, is treated
// as satisfying the requirement. The fallback indexes file.Decls[0], so file
// is assumed to have at least one declaration.
func ensureImportLogPackage(fset *token.FileSet, file *ast.File) (patch, bool) {
	maxOverlap := 0
	var candidate token.Pos
	quotedImportPath := strconv.Quote(injectImportPath)
	for _, d := range file.Decls {
		d, ok := d.(*ast.GenDecl)
		if !ok || d.Tok != token.IMPORT {
			// We encountered a non-import declaration. As
			// imports always precede other declarations,
			// we are done with our search.
			break
		}
		for _, s := range d.Specs {
			s := s.(*ast.ImportSpec)
			tag := ""
			if s.Name != nil {
				tag = s.Name.Name
			}
			// path is the import path exactly as written, including quotes.
			path := s.Path.Value
			// Match import tag.
			if len(injectImportTag) > 0 && injectImportTag == tag {
				return patch{}, false
			}
			// Match path.
			if quotedImportPath == path {
				return patch{}, false
			}
			// Keep track of which import in a parenthesised list of imports
			// has the greatest overlap with the one we're going to add - i.e.
			// make sure we insert the new import in the lexicographically ordered
			// location.
			overlap := countOverlap(s.Path.Value, quotedImportPath)
			if d.Lparen.IsValid() && overlap > maxOverlap {
				maxOverlap = overlap
				candidate = s.Pos()
			}
		}
	}
	// impStmt returns the import spec to add, terminated by a newline.
	impStmt := func() string {
		if len(injectImportTag) > 0 {
			return injectImportTag + " " + quotedImportPath + "\n"
		}
		return quotedImportPath + "\n"
	}
	// Double-quoted paths always share at least the opening quote, so for
	// such paths maxOverlap > 0 whenever a parenthesized import list exists.
	if maxOverlap > 0 {
		return insertAt(fset.Position(candidate).Offset, impStmt()), true
	}
	// No import declaration found with parenthesis; create a new
	// one and add it to the beginning of the file.
	return insertAt(fset.Position(file.Decls[0].Pos()).Offset, "import "+impStmt()), true
}
// methodBeginsWithNoLogComment reports whether the body of m contains a
// comment with a line reading exactly nologComment (surrounding whitespace
// ignored). Only comments placed entirely between the opening brace and the
// first statement are considered, or before the closing brace if the body is
// empty.
func methodBeginsWithNoLogComment(m funcDeclRef) bool {
	body := m.Decl.Body
	start, end := body.Lbrace, body.Rbrace
	if len(body.List) != 0 {
		end = body.List[0].Pos()
	}
	for _, group := range m.File.Comments {
		if group.Pos() < start || group.End() > end {
			continue
		}
		for _, line := range strings.Split(group.Text(), "\n") {
			if strings.TrimSpace(line) == nologComment {
				return true
			}
		}
	}
	return false
}
// findRemovals returns, as a set keyed by funcDeclRef with nil values, those
// methods whose first statement is a valid defer of removePackage.removeCall
// (or its "f" variant), i.e. the ones that remove should rewrite.
func findRemovals(methods []funcDeclRef) map[funcDeclRef]error {
	removals := make(map[funcDeclRef]error)
	for _, m := range methods {
		if validateLogStatement(m.Decl, removePackage, removeCall) != nil {
			continue
		}
		removals[m] = nil
	}
	return removals
}
// checkMethods returns the subset of methods that fail checkMethod, i.e.
// those without a valid log statement and without a nologcall exemption.
// Each is mapped to the reason it failed.
func checkMethods(methods []funcDeclRef) map[funcDeclRef]error {
	failures := make(map[funcDeclRef]error)
	for i := range methods {
		err := checkMethod(methods[i])
		if err == nil {
			continue
		}
		failures[methods[i]] = err
	}
	return failures
}
// checkMethod returns nil if method begins with an acceptable injected
// logging construct, or if it opts out with a nologcall comment before its
// first statement. Otherwise it returns the validation error.
func checkMethod(method funcDeclRef) error {
	err := validateLogStatement(method.Decl, injectPackage, injectCall)
	if err == nil || methodBeginsWithNoLogComment(method) {
		return nil
	}
	return err
}
// gofmt runs "gofmt -w files...". It does nothing if files is empty or if
// formatting has been disabled via gofmtFlag.
func gofmt(jirix *jiri.X, verbose bool, files []string) error {
	if !gofmtFlag || len(files) == 0 {
		return nil
	}
	args := make([]string, 0, len(files)+1)
	args = append(args, "-w")
	args = append(args, files...)
	return jirix.NewSeq().Verbose(verbose).Last("gofmt", args...)
}
// writeFiles applies the patches collected for each file.
//
// By default each patched file is written back in place, and gofmt is run on
// all of them afterwards. When diffOnlyFlag is set, each patched file is
// instead written to a temporary directory, formatted, and diffed against the
// original; the temporary directory is removed afterwards.
//
// Files are processed in sorted filename order so that other tools/tests can
// count on the diff output. Patches are sorted by offset in place.
func writeFiles(jirix *jiri.X, fset *token.FileSet, files map[*ast.File][]patch) error {
	filenames := make([]string, 0, len(files))
	asts := make(map[string]*ast.File, len(files))
	for file := range files {
		filename := fset.Position(file.Pos()).Filename
		filenames = append(filenames, filename)
		asts[filename] = file
	}
	sort.Strings(filenames)
	s := jirix.NewSeq()
	for _, filename := range filenames {
		patches := files[asts[filename]]
		sort.Sort(patchSorter(patches))
		src, err := ioutil.ReadFile(filename)
		if err != nil {
			return err
		}
		// Copy the original source between patches; each patch replaces
		// src[Offset:NextOffset] with its Text.
		beginOffset := 0
		patchedSrc := make([]byte, 0, len(src))
		for _, p := range patches {
			patchedSrc = append(patchedSrc, src[beginOffset:p.Offset]...)
			patchedSrc = append(patchedSrc, p.Text...)
			beginOffset = p.NextOffset
		}
		patchedSrc = append(patchedSrc, src[beginOffset:]...)
		if !diffOnlyFlag {
			// The mode must be the octal 0644 (rw-r--r--); a decimal 644 would
			// set unrelated permission bits.
			if err := s.WriteFile(filename, patchedSrc, os.FileMode(0644)).Done(); err != nil {
				return err
			}
			continue
		}
		// Diff inside a function literal so that the deferred removal of the
		// temporary directory runs as soon as this file is done, instead of
		// accumulating directories until writeFiles returns.
		if err := func() (e error) {
			tmpDir, err := s.TempDir("", "")
			if err != nil {
				return err
			}
			defer collect.Error(func() error { return jirix.NewSeq().RemoveAll(tmpDir).Done() }, &e)
			tmpFilename := filepath.Join(tmpDir, "gologcop-"+filepath.Base(filename))
			if err := s.WriteFile(tmpFilename, patchedSrc, os.FileMode(0644)).Done(); err != nil {
				return err
			}
			progressMsg(jirix.Stdout(), "Diffing %s with %s\n", filename, tmpFilename)
			// Formatting only makes the diff easier to read, so a failure is
			// reported but does not abort the diff.
			if err := gofmt(jirix, false, []string{tmpFilename}); err != nil {
				fmt.Fprintf(jirix.Stderr(), "Warning: gofmt %s: %v\n", tmpFilename, err)
			}
			// diff exits with a non-zero status when the files differ, which is
			// the expected outcome here, so its result is deliberately ignored.
			s.Verbose(false).Capture(jirix.Stdout(), jirix.Stderr()).Last("diff", filename, tmpFilename)
			return nil
		}(); err != nil {
			return err
		}
	}
	if diffOnlyFlag {
		return nil
	}
	return gofmt(jirix, jirix.Verbose(), filenames)
}
// remove removes a log call at the beginning of each method in methods.
//
// The first statement of each method's body is deleted, together with any
// comment on the same line as it. The removed range ends at the next
// statement, at a comment preceding it, or at the closing brace. The
// resulting files are written via writeFiles. The error values stored in
// methods are not used. An error is returned if a method has an empty body.
func remove(jirix *jiri.X, fset *token.FileSet, methods map[funcDeclRef]error) error {
	files := map[*ast.File][]patch{}
	// Build one comment map per file, so comments can be associated with
	// the statements surrounding each removed call.
	comments := map[*ast.File]ast.CommentMap{}
	for fdRef, _ := range methods {
		file := fdRef.File
		if _, present := comments[file]; !present {
			comments[fdRef.File] = ast.NewCommentMap(fset, file, file.Comments)
		}
	}
	// endAt returns the position of the next statement, comment or function,
	// i.e. the end of the block of code to be removed.
	endAt := func(fn *ast.FuncDecl, cm ast.CommentMap) int {
		endpos := fn.Body.Rbrace
		stmt := fn.Body.List[0]
		if len(fn.Body.List) > 1 {
			nextStmt := fn.Body.List[1]
			endpos = nextStmt.Pos()
			if cg := cm.Filter(nextStmt).Comments(); len(cg) > 0 {
				if len(cg[0].List) > 0 {
					if cg[0].List[0].Pos() < endpos {
						// Only use a comment if it comes before the next statement.
						endpos = cg[0].List[0].Pos()
					}
				}
			}
		}
		stmtLine := fset.Position(stmt.Pos()).Line
		// Delete any comment on the same line as the logcall: the removal
		// stops at the first comment that starts on a later line instead.
		for _, cg := range cm.Filter(stmt).Comments() {
			for _, c := range cg.List {
				if fset.Position(c.Pos()).Line > stmtLine {
					endpos = c.Pos()
					break
				}
			}
		}
		return fset.Position(endpos).Offset
	}
	for m, _ := range methods {
		file := m.File
		stmts := m.Decl.Body.List
		if len(stmts) == 0 {
			return fmt.Errorf("no statements found for %s", m.Decl.Name)
		}
		// The first statement should be the call we want to remove.
		start := fset.Position(stmts[0].Pos()).Offset
		end := endAt(m.Decl, comments[m.File])
		files[file] = append(files[file], removeRange(start, end))
	}
	return writeFiles(jirix, fset, files)
}
// inject inserts each method's LogCall immediately after the opening brace of
// its body. It also adds an import of the injected package to every affected
// file that lacks one (see ensureImportLogPackage), then writes the results
// via writeFiles.
//
// Methods whose recorded error is an *errInvalid already begin with something
// that looks like a logging construct but failed validation. They are
// reported as warnings first, and are still injected.
func inject(jirix *jiri.X, fset *token.FileSet, methods map[funcDeclRef]error) error {
	for ref, verr := range methods {
		if _, invalid := verr.(*errInvalid); !invalid {
			continue
		}
		fmt.Fprintf(jirix.Stdout(), "Warning: %v: %s: %v\n", fset.Position(ref.Decl.Pos()), ref.Decl.Name.Name, verr)
	}
	patchesByFile := map[*ast.File][]patch{}
	for ref := range methods {
		body := ref.Decl.Body
		text := ref.LogCall
		// For a body written on a single line, e.g. func() {}, add a newline
		// so the closing brace is pushed onto the next line.
		if fset.Position(body.Lbrace).Line == fset.Position(body.Rbrace).Line {
			text += "\n"
		}
		patchesByFile[ref.File] = append(patchesByFile[ref.File], insertAt(fset.Position(body.Lbrace).Offset+1, text))
	}
	for file, patches := range patchesByFile {
		if importPatch, needed := ensureImportLogPackage(fset, file); needed {
			patchesByFile[file] = append(patches, importPatch)
		}
	}
	return writeFiles(jirix, fset, patchesByFile)
}
// reportResults prints one line per entry in methods, as produced by
// checkMethods, in the form "<position>: <method name>: <error>". Lines
// appear in map iteration order, i.e. unspecified.
func reportResults(jirix *jiri.X, fset *token.FileSet, methods map[funcDeclRef]error) {
	for ref, problem := range methods {
		decl := ref.Decl
		pos := fset.Position(decl.Pos())
		fmt.Fprintf(jirix.Stdout(), "%v: %s: %v\n", pos, decl.Name.Name, problem)
	}
}
// ensureExprsArePointers returns an *errInvalid if any expression in exprs is
// not an address-of expression (&x), and nil otherwise (including for an
// empty list).
func ensureExprsArePointers(exprs []ast.Expr) error {
	for i := range exprs {
		if isAddressOfExpression(exprs[i]) {
			continue
		}
		return &errInvalid{"output arguments should be passed to the log function via their addresses"}
	}
	return nil
}
// validateLogStatement checks that method's first statement has the shape
//
//	defer pkg.name(...)(outer args...)   or   defer pkg.namef(...)(outer args...)
//
// It returns an *errNotExists if the statement is missing or has a different
// shape, package or function name. It returns an *errInvalid if the shape
// matches but the arguments are unacceptable, as decided by the checks
// below. When useContextFlag is set, the first outer argument (the context)
// is not checked.
func validateLogStatement(method *ast.FuncDecl, pkg, name string) error {
	body := method.Body.List
	if len(body) == 0 {
		return &errNotExists{"empty method"}
	}
	deferStmt, isDefer := body[0].(*ast.DeferStmt)
	if !isDefer {
		return &errNotExists{"no defer statement"}
	}
	// The deferred call's function must itself be a call, pkg.name(...),
	// whose result is the function that the defer eventually invokes.
	inner, isCall := deferStmt.Call.Fun.(*ast.CallExpr)
	if !isCall {
		return &errNotExists{"defer is a not a function call"}
	}
	sel, isSel := inner.Fun.(*ast.SelectorExpr)
	if !isSel {
		return &errNotExists{"not a <pkg>.<method> call"}
	}
	pkgIdent, isIdent := sel.X.(*ast.Ident)
	if !isIdent {
		return &errNotExists{"not a valid package selector"}
	}
	if pkgIdent.Name != pkg {
		return &errNotExists{fmt.Sprintf("wrong package: got %q, want %q", pkgIdent.Name, pkg)}
	}
	outerArgs := deferStmt.Call.Args
	if useContextFlag && len(outerArgs) > 0 {
		outerArgs = outerArgs[1:]
	}
	switch sel.Sel.Name {
	case name:
		return ensureExprsArePointers(outerArgs)
	case name + "f":
		if len(inner.Args) < 1 {
			return &errInvalid{"no format specifier specified for called defer func: " + name}
		}
		if len(deferStmt.Call.Args) < 1 {
			return &errInvalid{"no format specifier specified for returned defer func: " + name}
		}
		// Skip past the format argument. When called for a removal there may
		// be no format at all, hence the length check.
		if len(outerArgs) > 0 {
			outerArgs = outerArgs[1:]
		}
		return ensureExprsArePointers(outerArgs)
	}
	return &errNotExists{fmt.Sprintf("got \"%s.%s\", want \"%s.%s\"", pkgIdent.Name, sel.Sel.Name, pkg, name)}
}
// isAddressOfExpression reports whether expr takes the address of its
// operand, i.e. has the form &x. Redundant parentheses are looked through,
// so (&x) and ((&x)) are accepted as well.
func isAddressOfExpression(expr ast.Expr) bool {
	for paren, ok := expr.(*ast.ParenExpr); ok; paren, ok = expr.(*ast.ParenExpr) {
		expr = paren.X
	}
	unaryExpr, ok := expr.(*ast.UnaryExpr)
	return ok && unaryExpr.Op == token.AND
}
// printHeader writes a section header to out when progress output is enabled
// (progressFlag): a blank line, the formatted title, and a line of '='
// characters as long as the title in bytes.
func printHeader(out io.Writer, format string, args ...interface{}) {
	if !progressFlag {
		return
	}
	title := fmt.Sprintf(format, args...)
	underline := strings.Repeat("=", len(title))
	fmt.Fprintf(out, "\n%s\n%s\n", title, underline)
}
// progressMsg writes the formatted message to out, but only when progress
// output is enabled (progressFlag).
func progressMsg(out io.Writer, format string, args ...interface{}) {
	if !progressFlag {
		return
	}
	fmt.Fprintf(out, format, args...)
}
// findPublicInterfaces returns every exported, non-empty interface found at
// package scope in the supplied packages. Results are in package order and,
// within a package, in the order returned by scope.Names().
func findPublicInterfaces(jirix *jiri.X, ifcs []*types.Package) (interfaces []*types.Interface) {
	for _, pkg := range ifcs {
		printHeader(jirix.Stdout(), "Public Interfaces for %s", pkg.Path())
		scope := pkg.Scope()
		for _, name := range scope.Names() {
			obj := scope.Lookup(name)
			if !obj.Exported() || !types.IsInterface(obj.Type()) {
				continue
			}
			iface := obj.Type().Underlying().(*types.Interface)
			if iface.Empty() {
				continue
			}
			progressMsg(jirix.Stdout(), "%s.%s\n", pkg.Path(), obj.Name())
			interfaces = append(interfaces, iface)
		}
	}
	return interfaces
}
// errInvalid reports a log statement that is present but malformed.
type errInvalid struct {
	message string
}

// Error implements the error interface. It falls back to a generic
// description when no message was supplied.
func (e errInvalid) Error() string {
	if e.message == "" {
		return "invalid log statement"
	}
	return e.message
}
// errNotExists reports that the expected log statement is absent; message
// explains what was found instead.
type errNotExists struct {
	message string
}

// Error implements the error interface.
func (e errNotExists) Error() string {
	return "injected statement does not exist: " + e.message
}