| // Copyright 2015 The Vanadium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package main |
| |
| import ( |
| "fmt" |
| "go/ast" |
| "go/build" |
| "go/parser" |
| "go/token" |
| "go/types" |
| "io" |
| "io/ioutil" |
| "os" |
| "path" |
| "path/filepath" |
| "sort" |
| "strconv" |
| "strings" |
| "unicode" |
| |
| "v.io/jiri" |
| "v.io/jiri/collect" |
| "v.io/x/devtools/internal/goutil" |
| ) |
| |
| const ( |
| // nologComment is the magic comment text that disables log injection. |
| nologComment = "nologcall" |
| // logCallComment is the comment to be appended to all injected calls. |
| logCallComment = "// gologcop: DO NOT EDIT, MUST BE FIRST STATEMENT" |
| |
| v23ContextPackage = "v.io/v23/context" |
| v23ContextTypeName = "T" |
| ) |
| |
| var ( |
| // the import tag for inject, if any, as in import tag "path" |
| injectImportTag string |
| // the import path for inject. |
| injectImportPath string |
| // the package name to use at call sites, it's either the |
| // injectImportTag if one is specified or the base name of injectImportPath |
| injectPackage string |
| // the call to be injected, without the package name. |
| injectCall string |
| |
| // the package and call to be removed |
| removePackage, removeCall string |
| ) |
| |
| // parseState encapsulates all of the state acquired during parsing and |
| // type checking. It makes sure that any give file or package |
| // is parsed or type checked once and only once. |
| type parseState struct { |
| jirix *jiri.X |
| config *types.Config |
| fset *token.FileSet |
| info *types.Info |
| packages map[string]*types.Package // keyed by the package path name. |
| asts map[string][]*ast.File // keyed by the package path name |
| } |
| |
| func newState(jirix *jiri.X) *parseState { |
| ps := &parseState{ |
| jirix: jirix, |
| fset: token.NewFileSet(), |
| packages: make(map[string]*types.Package), |
| asts: make(map[string][]*ast.File), |
| config: &types.Config{ |
| IgnoreFuncBodies: true, |
| }, |
| info: &types.Info{ |
| Types: make(map[ast.Expr]types.TypeAndValue), |
| Defs: make(map[*ast.Ident]types.Object), |
| Uses: make(map[*ast.Ident]types.Object), |
| }, |
| } |
| ps.config.Importer = ps |
| return ps |
| } |
| |
| func (ps *parseState) Import(path string) (*types.Package, error) { |
| return ps.sourceImporter(path) |
| } |
| |
| func (ps *parseState) parsedPackage(path string) (*types.Package, []*ast.File) { |
| return ps.packages[path], ps.asts[path] |
| } |
| |
| func (ps *parseState) addParsedPackage(path string, pkg *types.Package, asts []*ast.File) { |
| if p, _ := ps.parsedPackage(path); p != nil { |
| fmt.Fprintf(ps.jirix.Stdout(), "Warning: %s is already cached\n", path) |
| return |
| } |
| ps.packages[path] = pkg |
| ps.asts[path] = asts |
| } |
| |
| // sourceImporter will always import from source code. |
| func (ps *parseState) sourceImporter(path string) (*types.Package, error) { |
| // It seems that we need to special case the unsafe package. |
| if path == "unsafe" { |
| return types.Unsafe, nil |
| } |
| if pkg, _ := ps.parsedPackage(path); pkg != nil { |
| return pkg, nil |
| } |
| progressMsg(ps.jirix.Stdout(), "importing from source: %s\n", path) |
| bpkg, err := build.Default.Import(path, ".", build.ImportMode(build.ImportComment)) |
| _, pkg, err := ps.parseAndTypeCheckPackage(bpkg) |
| if err != nil { |
| return nil, err |
| } |
| return pkg, err |
| } |
| |
| // parseAndTypeCheckPackage will parse and type check a given package. |
| func (ps *parseState) parseAndTypeCheckPackage(bpkg *build.Package) ([]*ast.File, *types.Package, error) { |
| if tpkg, asts := ps.parsedPackage(bpkg.ImportPath); tpkg != nil { |
| return asts, tpkg, nil |
| } |
| |
| tpkg := types.NewPackage(bpkg.ImportPath, bpkg.Name) |
| checker := types.NewChecker(ps.config, ps.fset, tpkg, ps.info) |
| |
| // Parse the files in this package. |
| asts := []*ast.File{} |
| dir := bpkg.Dir |
| for _, fileInPkg := range bpkg.GoFiles { |
| file := filepath.Join(dir, fileInPkg) |
| a, err := parser.ParseFile(ps.fset, file, nil, parser.ParseComments) |
| if err != nil { |
| return nil, nil, err |
| } |
| asts = append(asts, a) |
| } |
| if err := checker.Files(asts); err != nil { |
| return nil, nil, err |
| } |
| // make sure that type checking is complete at this stage. It should |
| // always be so, so this is really an 'assertion' that it is. |
| if !tpkg.Complete() { |
| return nil, nil, fmt.Errorf("checked %q is not completely parsed+checked", bpkg.Name) |
| } |
| progressMsg(ps.jirix.Stdout(), "parsed from source: %s\n", bpkg.ImportPath) |
| ps.addParsedPackage(bpkg.ImportPath, tpkg, asts) |
| return asts, tpkg, nil |
| } |
| |
| // importPkgs will expand the supplied list of packages using go list |
| // (so v.io/v23/... can be used as an interface package spec for example) and |
| // then import those packages. |
| func importPkgs(jirix *jiri.X, goFlags, packageSpec []string) (ifcs []*build.Package, err error) { |
| flags := append(goFlags, "-merge-policies="+mergePoliciesFlag.String()) |
| pkgs, err := goutil.List(jirix, flags, packageSpec...) |
| if err != nil { |
| return nil, err |
| } |
| importer := func(pkgs []string) ([]*build.Package, error) { |
| pkgInfos := []*build.Package{} |
| for _, pkg := range pkgs { |
| pkgInfo, err := build.Default.Import(pkg, ".", build.ImportMode(build.ImportComment)) |
| if err != nil { |
| return nil, err |
| } |
| pkgInfos = append(pkgInfos, pkgInfo) |
| } |
| return pkgInfos, nil |
| } |
| bpkgs, err := importer(pkgs) |
| if err != nil { |
| return nil, fmt.Errorf("error importing packages: %v", err) |
| } |
| return bpkgs, nil |
| } |
| |
| // exists is used as the value to indicate existence for maps that |
| // function as sets. |
| var exists = struct{}{} |
| |
| func initInjectorFlags() error { |
| parts := strings.FieldsFunc(injectCallImportFlag, unicode.IsSpace) |
| var err error |
| switch len(parts) { |
| case 1: |
| injectImportTag = "" |
| injectImportPath, err = strconv.Unquote(injectCallImportFlag) |
| if err != nil { |
| injectImportPath = injectCallImportFlag |
| } |
| injectPackage = path.Base(injectImportPath) |
| case 2: |
| injectImportTag = parts[0] |
| injectImportPath, err = strconv.Unquote(parts[1]) |
| if err != nil { |
| injectImportPath = parts[1] |
| } |
| injectPackage = injectImportTag |
| default: |
| return fmt.Errorf("%q doesn't look like an import declaration", injectCallImportFlag) |
| } |
| injectCall = injectCallFlag |
| return nil |
| } |
| |
| // run runs the log injector. |
| func runInjector(jirix *jiri.X, goFlags, interfaceList, implementationList []string, checkOnly bool) error { |
| if err := initInjectorFlags(); err != nil { |
| return err |
| } |
| // use 'go list' and the builder to import all of the packages |
| // specified as interfaces and implementations. |
| ifcs, err := importPkgs(jirix, goFlags, interfaceList) |
| if err != nil { |
| return err |
| } |
| |
| impls, err := importPkgs(jirix, goFlags, implementationList) |
| if err != nil { |
| return err |
| } |
| |
| printHeader(jirix.Stdout(), "Package Summary") |
| progressMsg(jirix.Stdout(), "%v expands to %d interface packages\n", interfaceList, len(ifcs)) |
| progressMsg(jirix.Stdout(), "%v expands to %d implementation packages\n", implementationList, len(impls)) |
| |
| ps := newState(jirix) |
| checkFailed := []string{} |
| |
| printHeader(jirix.Stdout(), "Parsing and Type Checking Interface Packages") |
| |
| ifcPkgs := []*types.Package{} |
| for _, ifc := range ifcs { |
| _, tpkg, err := ps.parseAndTypeCheckPackage(ifc) |
| if err != nil { |
| return fmt.Errorf("failed to parse+type check: %s: %s", ifc.ImportPath, err) |
| } |
| ifcPkgs = append(ifcPkgs, tpkg) |
| } |
| publicInterfaces := findPublicInterfaces(jirix, ifcPkgs) |
| |
| for _, impl := range impls { |
| printHeader(jirix.Stdout(), "Parsing and Type Checking Implementation Packages") |
| asts, tpkg, err := ps.parseAndTypeCheckPackage(impl) |
| if err != nil { |
| return fmt.Errorf("failed to parse+type check: %s: %s", impl.ImportPath, err) |
| } |
| |
| // Now find the methods that implement those public interfaces. |
| methods := findMethodsImplementing(jirix, ps.fset, tpkg, publicInterfaces) |
| |
| // and their positions in the files. |
| methodPositions, err := functionDeclarationsAtPositions(ps.fset, asts, ps.info, methods) |
| if err != nil { |
| return err |
| } |
| // then check to see if those methods already have logging statements. |
| needsInjection := checkMethods(methodPositions) |
| |
| if checkOnly { |
| if len(needsInjection) > 0 { |
| printHeader(jirix.Stdout(), "Check Results") |
| reportResults(jirix, ps.fset, needsInjection) |
| checkFailed = append(checkFailed, impl.ImportPath) |
| } |
| } else { |
| if err := inject(jirix, ps.fset, needsInjection); err != nil { |
| return fmt.Errorf("injection failed for: %s: %s", impl.ImportPath, err) |
| } |
| } |
| } |
| |
| if checkOnly && len(checkFailed) > 0 { |
| for _, p := range checkFailed { |
| fmt.Fprintf(jirix.Stdout(), "check failed for: %s\n", p) |
| } |
| os.Exit(1) |
| } |
| |
| return nil |
| } |
| |
| func initRemoverFlags() error { |
| parts := strings.Split(removeCallFlag, ".") |
| switch len(parts) { |
| case 2: |
| removePackage = parts[0] |
| removeCall = parts[1] |
| default: |
| return fmt.Errorf("%q doesn't look like a function call on an imported package", removeCallFlag) |
| } |
| return nil |
| } |
| |
| func runRemover(jirix *jiri.X, goFlags, implementationList []string) error { |
| if err := initRemoverFlags(); err != nil { |
| return err |
| } |
| |
| // use 'go list' and the builder to import all of the packages |
| // specified as implementations. |
| impls, err := importPkgs(jirix, goFlags, implementationList) |
| if err != nil { |
| return err |
| } |
| |
| ps := newState(jirix) |
| |
| printHeader(jirix.Stdout(), "Package Summary") |
| progressMsg(jirix.Stdout(), "%v expands to %d implementation packages\n", implementationList, len(impls)) |
| |
| for _, impl := range impls { |
| asts, tpkg, err := ps.parseAndTypeCheckPackage(impl) |
| if err != nil { |
| return fmt.Errorf("failed to parse+type check: %s: %s", impl.ImportPath, err) |
| } |
| methods := findMethods(jirix, ps.fset, tpkg) |
| methodPositions, err := functionDeclarationsAtPositions(ps.fset, asts, ps.info, methods) |
| if err != nil { |
| return err |
| } |
| needsRemoval := findRemovals(methodPositions) |
| if err := remove(jirix, ps.fset, needsRemoval); err != nil { |
| return fmt.Errorf("removal failed for: %s: %s", impl.ImportPath, err) |
| } |
| } |
| return nil |
| } |
| |
| // funcDeclRef stores a reference to a function declaration, paired |
| // with the file containing it. |
| type funcDeclRef struct { |
| Decl *ast.FuncDecl |
| File *ast.File |
| LogCall string |
| } |
| |
| // methodSetVisibleThroughInterfaces returns intersection of all |
| // exported method names implemented by t and the union of all method |
| // names declared by interfaces. |
| func methodSetVisibleThroughInterfaces(t types.Type, interfaces []*types.Interface) map[string]struct{} { |
| set := map[string]struct{}{} |
| for _, ifc := range interfaces { |
| if types.Implements(t, ifc) || types.Implements(types.NewPointer(t), ifc) { |
| // t implements ifc, so add all the public |
| // method names of ifc to set. |
| for i := 0; i < ifc.NumMethods(); i++ { |
| if name := ifc.Method(i).Name(); ast.IsExported(name) { |
| set[name] = exists |
| } |
| } |
| } |
| } |
| return set |
| } |
| |
| func hasV23Context(info *types.Info, parameters *ast.FieldList) (*ast.FieldList, string) { |
| if !useContextFlag { |
| return parameters, "" |
| } |
| if parameters == nil { |
| return nil, "nil" |
| } |
| filtered := *parameters |
| for i, field := range filtered.List { |
| typ := info.TypeOf(field.Type) |
| ptr, ok := typ.(*types.Pointer) |
| if !ok { |
| continue |
| } |
| named, ok := ptr.Elem().(*types.Named) |
| if !ok { |
| continue |
| } |
| name := named.Obj() |
| if name.Pkg().Path() == v23ContextPackage && name.Name() == v23ContextTypeName { |
| filtered.List = append(filtered.List[:i], filtered.List[i+1:]...) |
| jirixname := "nil" |
| if len(field.Names) > 0 && field.Names[0].Name != "_" { |
| jirixname = field.Names[0].Name |
| } |
| return &filtered, jirixname |
| } |
| } |
| return &filtered, "nil" |
| } |
| |
| func genFmt(info *types.Info, fields *ast.FieldList, indirect bool) ([]string, []string, error) { |
| |
| fmtForBasicType := func(typ *types.Basic) string { |
| if typ.Kind() == types.String { |
| return "%.10s..." |
| } else { |
| return "%v" |
| } |
| } |
| |
| if fields == nil { |
| return nil, nil, nil |
| } |
| format := []string{} |
| args := []string{} |
| for _, param := range fields.List { |
| typ := info.TypeOf(param.Type) |
| var f string |
| printable := false |
| ellipsis := false |
| switch v := typ.(type) { |
| case *types.Basic: |
| f = fmtForBasicType(v) |
| printable = true |
| case *types.Named: |
| switch u := typ.Underlying().(type) { |
| case *types.Basic: |
| f = fmtForBasicType(u) |
| printable = true |
| case *types.Interface: |
| if v.Obj().Name() == "error" { |
| f = "%v" |
| printable = true |
| } |
| } |
| case nil: |
| if _, ok := param.Type.(*ast.Ellipsis); !ok { |
| return nil, nil, fmt.Errorf("failed to locate type for %v", param.Names) |
| } |
| // We'll print out the ellipsis args as a slice of whatever type it is. |
| f = "%v" |
| printable = true |
| ellipsis = true |
| } |
| for _, n := range param.Names { |
| if n.Name != "_" && len(n.Name) > 0 { |
| if printable { |
| if ellipsis { |
| format = append(format, n.Name+"...="+f) |
| } else { |
| format = append(format, n.Name+"="+f) |
| } |
| name := n.Name |
| if indirect { |
| name = "&" + name |
| } |
| args = append(args, name) |
| } else { |
| format = append(format, n.Name+"=") |
| } |
| } |
| } |
| } |
| return format, args, nil |
| } |
| |
| func genCall(info *types.Info, params, results *ast.FieldList) (string, error) { |
| params, contextPar := hasV23Context(info, params) |
| noargs := fmt.Sprintf("\n\tdefer %s.%s(%s)(%s) %s", injectPackage, injectCall, contextPar, contextPar, logCallComment) |
| if info == nil { |
| return noargs, nil |
| } |
| |
| argFormat, printableArgs, err := genFmt(info, params, false) |
| if err != nil { |
| return "", err |
| } |
| |
| resFormat, printableResults, err := genFmt(info, results, true) |
| if err != nil { |
| return "", err |
| } |
| |
| if len(argFormat) == 0 && len(resFormat) == 0 { |
| return noargs, nil |
| } |
| |
| formatArgs := func(format, parameters []string) string { |
| if len(format) > 0 { |
| formatStr := strings.TrimSpace(strings.Join(format, ",")) |
| parametersStr := strings.Join(parameters, ",") |
| return fmt.Sprintf("\"%s\", %s", formatStr, parametersStr) |
| } |
| return "\"\"" |
| } |
| |
| pars := formatArgs(argFormat, printableArgs) |
| res := formatArgs(resFormat, printableResults) |
| |
| contextParArg, contextParRes := contextPar, contextPar |
| if len(contextPar) > 0 { |
| if len(pars) > 0 { |
| contextParArg += ", " |
| } |
| if len(res) > 0 { |
| contextParRes += ", " |
| } |
| } |
| |
| return fmt.Sprintf("\n\tdefer %s.%sf(%s%s)(%s%s) %s", injectPackage, injectCall, contextParArg, pars, contextParRes, res, logCallComment), nil |
| } |
| |
| // functionDeclarationsAtPositions returns references to function |
| // declarations in packages where the position of the identifier token |
| // representing the name of the function is in positions. |
| func functionDeclarationsAtPositions(fset *token.FileSet, files []*ast.File, info *types.Info, positions map[token.Pos]struct{}) ([]funcDeclRef, error) { |
| result := []funcDeclRef{} |
| for _, file := range files { |
| for _, decl := range file.Decls { |
| if decl, ok := decl.(*ast.FuncDecl); ok { |
| call, err := genCall(info, decl.Type.Params, decl.Type.Results) |
| if err != nil { |
| pos := fset.Position(decl.Pos()) |
| return nil, fmt.Errorf("%s:%d: %v", pos.Filename, pos.Line, err) |
| } |
| // for each function declaration in packages: |
| // |
| // it's important not to use decl.Pos() here |
| // as it gives us the position of the "func" |
| // token, whereas positions has collected |
| // the locations of method name tokens: |
| if _, ok := positions[decl.Name.Pos()]; ok { |
| result = append(result, funcDeclRef{decl, file, call}) |
| } |
| } |
| } |
| } |
| return result, nil |
| } |
| |
| // findMethodsImplementing searches the specified packages and returns |
| // a list of function declarations that are implementations for |
| // the specified interfaces. |
| func findMethodsImplementing(jirix *jiri.X, fset *token.FileSet, tpkg *types.Package, interfaces []*types.Interface) map[token.Pos]struct{} { |
| // positions will hold the set of Pos values of methods |
| // that should be logged. Each element will be the position of |
| // the identifier token representing the method name of such |
| // methods. The reason we collect the positions first is that |
| // our static analysis library has no easy way to map types.Func |
| // objects to ast.FuncDecl objects, so we then look into AST |
| // declarations and find everything that has a matching position. |
| positions := map[token.Pos]struct{}{} |
| |
| printHeader(jirix.Stdout(), "Methods Implementing Public Interfaces in %s", tpkg.Path()) |
| |
| scope := tpkg.Scope() |
| for _, child := range scope.Names() { |
| object := scope.Lookup(child) |
| typ := object.Type() |
| // ignore interfaces as they have no method implementations |
| if types.IsInterface(typ) { |
| continue |
| } |
| |
| // for each non-interface type t declared in packages: |
| apiMethodSet := methodSetVisibleThroughInterfaces(typ, interfaces) |
| |
| // optimization: if t implements no non-empty interfaces that |
| // we care about, we can just ignore it. |
| if len(apiMethodSet) > 0 { |
| // find all the methods explicitly declared or implicitly |
| // inherited through embedding on type t or *t. |
| methodSet := types.NewMethodSet(typ) |
| if methodSet.Len() == 0 { |
| methodSet = types.NewMethodSet(types.NewPointer(typ)) |
| } |
| for i := 0; i < methodSet.Len(); i++ { |
| method := methodSet.At(i) |
| fn := method.Obj().(*types.Func) |
| // t may have a method that is not declared in any of |
| // the interfaces we care about. No need to log that. |
| if _, ok := apiMethodSet[fn.Name()]; ok { |
| if fn.Pos() == 0 { |
| // Embedded functions show up with a zero pos. |
| continue |
| } |
| progressMsg(jirix.Stdout(), "%s.%s: %s\n", tpkg.Path(), fn.Name(), fset.Position(fn.Pos())) |
| positions[fn.Pos()] = exists |
| } |
| } |
| } |
| } |
| return positions |
| } |
| |
| func findMethodsInScope(jirix *jiri.X, fset *token.FileSet, positions map[token.Pos]struct{}, scope *types.Scope) { |
| for _, child := range scope.Names() { |
| object := scope.Lookup(child) |
| typ := object.Type() |
| switch v := typ.(type) { |
| case *types.Named: |
| for i := 0; i < v.NumMethods(); i++ { |
| m := v.Method(i) |
| positions[m.Pos()] = exists |
| } |
| case *types.Signature: |
| positions[object.Pos()] = exists |
| } |
| } |
| } |
| |
| func findMethods(jirix *jiri.X, fset *token.FileSet, tpkg *types.Package) map[token.Pos]struct{} { |
| positions := map[token.Pos]struct{}{} |
| printHeader(jirix.Stdout(), "Methods in %s", tpkg.Path()) |
| scope := tpkg.Scope() |
| findMethodsInScope(jirix, fset, positions, scope) |
| return positions |
| } |
| |
| type patch struct { |
| Offset int |
| Text string |
| NextOffset int |
| } |
| |
| type patchSorter []patch |
| |
| func insertAt(offset int, text string) patch { |
| return patch{ |
| Offset: offset, |
| Text: text, |
| NextOffset: offset, |
| } |
| } |
| |
| func removeRange(from, to int) patch { |
| return patch{ |
| Offset: from, |
| NextOffset: to, |
| } |
| } |
| |
| func (p patchSorter) Len() int { |
| return len(p) |
| } |
| |
| func (p patchSorter) Less(i, j int) bool { |
| return p[i].Offset < p[j].Offset |
| } |
| |
| func (p patchSorter) Swap(i, j int) { |
| p[i], p[j] = p[j], p[i] |
| } |
| |
| // countOverlap counts the length of the common prefix between two strings. |
| func countOverlap(a, b string) (i int) { |
| for ; i < len(a) && i < len(b) && a[i] == b[i]; i++ { |
| } |
| return |
| } |
| |
| // ©3ImportLogPackage will make sure that the file includes an |
| // import declaration to the package to be injected, and adds one if it does not |
| // already. |
| func ensureImportLogPackage(fset *token.FileSet, file *ast.File) (patch, bool) { |
| maxOverlap := 0 |
| var candidate token.Pos |
| |
| quotedImportPath := strconv.Quote(injectImportPath) |
| |
| for _, d := range file.Decls { |
| d, ok := d.(*ast.GenDecl) |
| if !ok || d.Tok != token.IMPORT { |
| // We encountered a non-import declaration. As |
| // imports always precede other declarations, |
| // we are done with our search. |
| break |
| } |
| |
| for _, s := range d.Specs { |
| s := s.(*ast.ImportSpec) |
| tag := "" |
| if s.Name != nil { |
| tag = s.Name.Name |
| } |
| path := s.Path.Value |
| |
| // Match import tag. |
| if len(injectImportTag) > 0 && injectImportTag == tag { |
| return patch{}, false |
| } |
| |
| // Match path. |
| if quotedImportPath == path { |
| return patch{}, false |
| } |
| |
| // Keep track of which import in a parenthesised list of imports |
| // has the greatest overlap with the one we're going to add - i.e. |
| // make sure we insert the new import in the lexicographically ordered |
| // location. |
| overlap := countOverlap(s.Path.Value, quotedImportPath) |
| if d.Lparen.IsValid() && overlap > maxOverlap { |
| maxOverlap = overlap |
| candidate = s.Pos() |
| } |
| } |
| } |
| |
| impStmt := func() string { |
| if len(injectImportTag) > 0 { |
| return injectImportTag + " " + quotedImportPath + "\n" |
| } |
| return quotedImportPath + "\n" |
| } |
| |
| if maxOverlap > 0 { |
| return insertAt(fset.Position(candidate).Offset, impStmt()), true |
| } |
| |
| // No import declaration found with parenthesis; create a new |
| // one and add it to the beginning of the file. |
| return insertAt(fset.Position(file.Decls[0].Pos()).Offset, "import "+impStmt()), true |
| } |
| |
| // methodBeginsWithNoLogComment returns true if method has a |
| // "nologcall" comment before any non-whitespace or non-comment token. |
| func methodBeginsWithNoLogComment(m funcDeclRef) bool { |
| method := m.Decl |
| lbound := method.Body.Lbrace |
| ubound := method.Body.Rbrace |
| stmts := method.Body.List |
| if len(stmts) > 0 { |
| ubound = stmts[0].Pos() |
| } |
| |
| for _, cmt := range m.File.Comments { |
| if lbound <= cmt.Pos() && cmt.End() <= ubound { |
| for _, line := range strings.Split(cmt.Text(), "\n") { |
| line := strings.TrimSpace(line) |
| if line == nologComment { |
| return true |
| } |
| } |
| } |
| } |
| |
| return false |
| } |
| |
| func findRemovals(methods []funcDeclRef) map[funcDeclRef]error { |
| result := map[funcDeclRef]error{} |
| for _, m := range methods { |
| if err := validateLogStatement(m.Decl, removePackage, removeCall); err == nil { |
| result[m] = nil |
| } |
| } |
| return result |
| } |
| |
| // checkMethods checks all items in methods and returns the subset |
| // of them that do not have valid log statements. |
| func checkMethods(methods []funcDeclRef) map[funcDeclRef]error { |
| result := map[funcDeclRef]error{} |
| for _, m := range methods { |
| if err := checkMethod(m); err != nil { |
| result[m] = err |
| } |
| } |
| return result |
| } |
| |
| // checkMethod checks that method includes an acceptable logging |
| // construct before any other non-whitespace or non-comment token. |
| func checkMethod(method funcDeclRef) error { |
| if err := validateLogStatement(method.Decl, injectPackage, injectCall); err != nil && !methodBeginsWithNoLogComment(method) { |
| return err |
| } |
| return nil |
| } |
| |
| // gofmt runs "gofmt -w files...". |
| func gofmt(jirix *jiri.X, verbose bool, files []string) error { |
| if len(files) == 0 || !gofmtFlag { |
| return nil |
| } |
| return jirix.NewSeq().Verbose(verbose).Last("gofmt", append([]string{"-w"}, files...)...) |
| } |
| |
| // writeFiles writes out files modified by the patch sets supplied to it. |
| func writeFiles(jirix *jiri.X, fset *token.FileSet, files map[*ast.File][]patch) (e error) { |
| filesToFormat := []string{} |
| |
| // Write out files in a fixed order so that other tools/tests can count on the |
| // diff output. |
| filenames := []string{} |
| asts := map[string]*ast.File{} |
| for file, _ := range files { |
| filename := fset.Position(file.Pos()).Filename |
| filenames = append(filenames, filename) |
| asts[filename] = file |
| } |
| sort.Strings(filenames) |
| |
| s := jirix.NewSeq() |
| for _, filename := range filenames { |
| file := asts[filename] |
| patches := files[file] |
| filesToFormat = append(filesToFormat, filename) |
| sort.Sort(patchSorter(patches)) |
| src, err := ioutil.ReadFile(filename) |
| if err != nil { |
| return err |
| } |
| beginOffset := 0 |
| patchedSrc := []byte{} |
| for _, patch := range patches { |
| patchedSrc = append(patchedSrc, src[beginOffset:patch.Offset]...) |
| patchedSrc = append(patchedSrc, patch.Text...) |
| beginOffset = patch.NextOffset |
| } |
| patchedSrc = append(patchedSrc, src[beginOffset:]...) |
| if diffOnlyFlag { |
| tmpDir, err := s.TempDir("", "") |
| if err != nil { |
| return err |
| } |
| tmpFilename := filepath.Join(tmpDir, "gologcop-"+filepath.Base(filename)) |
| defer collect.Error(func() error { return jirix.NewSeq().RemoveAll(tmpDir).Done() }, &e) |
| if err := s.WriteFile(tmpFilename, patchedSrc, os.FileMode(0644)).Done(); err != nil { |
| return err |
| } |
| progressMsg(jirix.Stdout(), "Diffing %s with %s\n", filename, tmpFilename) |
| gofmt(jirix, false, []string{tmpFilename}) |
| s.Verbose(false).Capture(jirix.Stdout(), jirix.Stderr()).Last("diff", filename, tmpFilename) |
| } else { |
| s.WriteFile(filename, patchedSrc, 644).Done() |
| } |
| } |
| if diffOnlyFlag { |
| return nil |
| } |
| return gofmt(jirix, jirix.Verbose(), filesToFormat) |
| } |
| |
| // remove removes a log call at the beginning of each method in methods. |
| func remove(jirix *jiri.X, fset *token.FileSet, methods map[funcDeclRef]error) error { |
| files := map[*ast.File][]patch{} |
| comments := map[*ast.File]ast.CommentMap{} |
| for fdRef, _ := range methods { |
| file := fdRef.File |
| if _, present := comments[file]; !present { |
| comments[fdRef.File] = ast.NewCommentMap(fset, file, file.Comments) |
| } |
| } |
| |
| // endAt returns the position of the next statement, comment or function, |
| // i.e. the end of the block of code to be removed. |
| endAt := func(fn *ast.FuncDecl, cm ast.CommentMap) int { |
| endpos := fn.Body.Rbrace |
| stmt := fn.Body.List[0] |
| if len(fn.Body.List) > 1 { |
| nextStmt := fn.Body.List[1] |
| endpos = nextStmt.Pos() |
| if cg := cm.Filter(nextStmt).Comments(); len(cg) > 0 { |
| if len(cg[0].List) > 0 { |
| if cg[0].List[0].Pos() < endpos { |
| // Only use a comment if it comes before the next statemnt. |
| endpos = cg[0].List[0].Pos() |
| } |
| } |
| } |
| } |
| stmtLine := fset.Position(stmt.Pos()).Line |
| // Delete any comment on the same line as the logcall. |
| for _, cg := range cm.Filter(stmt).Comments() { |
| for _, c := range cg.List { |
| if fset.Position(c.Pos()).Line > stmtLine { |
| endpos = c.Pos() |
| break |
| } |
| } |
| } |
| return fset.Position(endpos).Offset |
| } |
| |
| for m, _ := range methods { |
| file := m.File |
| stmts := m.Decl.Body.List |
| if len(stmts) == 0 { |
| return fmt.Errorf("no statements found for %s", m.Decl.Name) |
| } |
| // The first statement should be the call we want to remove. |
| start := fset.Position(stmts[0].Pos()).Offset |
| end := endAt(m.Decl, comments[m.File]) |
| files[file] = append(files[file], removeRange(start, end)) |
| } |
| return writeFiles(jirix, fset, files) |
| } |
| |
| // inject injects a log call at the beginning of each method in methods. |
| func inject(jirix *jiri.X, fset *token.FileSet, methods map[funcDeclRef]error) error { |
| // Warn the user for methods that already have something at |
| // their beginning that looks like a logging construct, but it |
| // is invalid for some reason. |
| for m, err := range methods { |
| if _, ok := err.(*errInvalid); ok { |
| method := m.Decl |
| position := fset.Position(method.Pos()) |
| methodName := method.Name.Name |
| fmt.Fprintf(jirix.Stdout(), "Warning: %v: %s: %v\n", position, methodName, err) |
| } |
| } |
| |
| files := map[*ast.File][]patch{} |
| for m, _ := range methods { |
| text := m.LogCall |
| // Catch the case where the function body is on the same line - e.g. func() {} |
| // so that we make sure we add a newline to the comment to push the right brace |
| // onto the next line. |
| if fset.Position(m.Decl.Body.Lbrace).Line == fset.Position(m.Decl.Body.Rbrace).Line { |
| text += "\n" |
| } |
| delta := insertAt(fset.Position(m.Decl.Body.Lbrace).Offset+1, text) |
| file := m.File |
| files[file] = append(files[file], delta) |
| } |
| |
| for file, deltas := range files { |
| if delta, hasChanges := ensureImportLogPackage(fset, file); hasChanges { |
| files[file] = append(deltas, delta) |
| } |
| } |
| return writeFiles(jirix, fset, files) |
| } |
| |
| // reportResults prints out the validation results from checkMethods |
| // in a human-readable form. |
| func reportResults(jirix *jiri.X, fset *token.FileSet, methods map[funcDeclRef]error) { |
| for m, err := range methods { |
| fmt.Fprintf(jirix.Stdout(), "%v: %s: %v\n", fset.Position(m.Decl.Pos()), m.Decl.Name.Name, err) |
| } |
| } |
| |
| // ensureExprsArePointers returns an error if at least one of the |
| // expressions in exprs is not in the form of &x. |
| func ensureExprsArePointers(exprs []ast.Expr) error { |
| for _, expr := range exprs { |
| if !isAddressOfExpression(expr) { |
| return &errInvalid{"output arguments should be passed to the log function via their addresses"} |
| } |
| } |
| return nil |
| } |
| |
| // validateLogStatement returns an error if method does not begin |
| // with a valid defer call. |
| func validateLogStatement(method *ast.FuncDecl, pkg, name string) error { |
| stmtList := method.Body.List |
| |
| if len(stmtList) == 0 { |
| return &errNotExists{"empty method"} |
| } |
| |
| deferStmt, ok := stmtList[0].(*ast.DeferStmt) |
| if !ok { |
| return &errNotExists{"no defer statement"} |
| } |
| |
| logCall, ok := deferStmt.Call.Fun.(*ast.CallExpr) |
| if !ok { |
| return &errNotExists{"defer is a not a function call"} |
| } |
| |
| selector, ok := logCall.Fun.(*ast.SelectorExpr) |
| if !ok { |
| return &errNotExists{"not a <pkg>.<method> call"} |
| } |
| |
| packageIdent, ok := selector.X.(*ast.Ident) |
| if !ok { |
| return &errNotExists{"not a valid package selector"} |
| } |
| |
| if packageIdent.Name != pkg { |
| return &errNotExists{fmt.Sprintf("wrong package: got %q, want %q", packageIdent.Name, pkg)} |
| } |
| |
| deferArgs := deferStmt.Call.Args |
| if useContextFlag && len(deferArgs) > 0 { |
| deferArgs = deferArgs[1:] |
| } |
| |
| switch selector.Sel.Name { |
| case name: |
| return ensureExprsArePointers(deferArgs) |
| case name + "f": |
| nFnArgs := 0 |
| if fnCall, ok := deferStmt.Call.Fun.(*ast.CallExpr); ok { |
| nFnArgs = len(fnCall.Args) |
| } |
| if nFnArgs < 1 { |
| return &errInvalid{"no format specifier specified for called defer func: " + name} |
| } |
| nCallArgs := len(deferStmt.Call.Args) |
| if nCallArgs < 1 { |
| return &errInvalid{"no format specifier specified for returned defer func: " + name} |
| } |
| if len(deferArgs) > 0 { |
| // Skip past format flag, but if we're called for a Remove |
| // then we can't be sure there is a format. |
| deferArgs = deferArgs[1:] |
| } |
| return ensureExprsArePointers(deferArgs) |
| } |
| |
| return &errNotExists{fmt.Sprintf("got \"%s.%s\", want \"%s.%s\"", packageIdent.Name, selector.Sel.Name, pkg, name)} |
| } |
| |
| // isAddressOfExpression checks if expr is an expression in the form |
| // of `&expression` |
| func isAddressOfExpression(expr ast.Expr) (isAddrExpr bool) { |
| // TODO: support (&x) as well as &x |
| unaryExpr, ok := expr.(*ast.UnaryExpr) |
| return ok && unaryExpr.Op == token.AND |
| } |
| |
| func printHeader(out io.Writer, format string, args ...interface{}) { |
| if progressFlag { |
| s := fmt.Sprintf(format, args...) |
| fmt.Fprintln(out) |
| fmt.Fprintln(out, s) |
| fmt.Fprintln(out, strings.Repeat("=", len(s))) |
| } |
| } |
| |
| func progressMsg(out io.Writer, format string, args ...interface{}) { |
| if progressFlag { |
| fmt.Fprintf(out, format, args...) |
| } |
| } |
| |
| // findPublicInterfaces returns all the public interfaces defined in the |
| // supplied packages. |
| func findPublicInterfaces(jirix *jiri.X, ifcs []*types.Package) (interfaces []*types.Interface) { |
| for _, ifc := range ifcs { |
| printHeader(jirix.Stdout(), "Public Interfaces for %s", ifc.Path()) |
| scope := ifc.Scope() |
| for _, child := range scope.Names() { |
| object := scope.Lookup(child) |
| typ := object.Type() |
| |
| if object.Exported() && types.IsInterface(typ) { |
| ifcType := typ.Underlying().(*types.Interface) |
| |
| if !ifcType.Empty() { |
| progressMsg(jirix.Stdout(), "%s.%s\n", ifc.Path(), object.Name()) |
| interfaces = append(interfaces, ifcType) |
| } |
| } |
| } |
| } |
| return interfaces |
| } |
| |
| type errInvalid struct { |
| message string |
| } |
| |
| func (l errInvalid) Error() string { |
| if len(l.message) > 0 { |
| return l.message |
| } |
| return "invalid log statement" |
| } |
| |
| type errNotExists struct { |
| message string |
| } |
| |
| func (e errNotExists) Error() string { |
| return fmt.Sprintf("injected statement does not exist: %s", e.message) |
| } |