gvisor/tools/go_generics/globals/globals_visitor.go

598 lines
14 KiB
Go

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package globals provides an AST visitor that calls the visit function for all
// global identifiers.
package globals
import (
"fmt"
"go/ast"
"go/token"
"path/filepath"
"strconv"
)
// globalsVisitor holds the state used while traversing the nodes of a file in
// search of globals.
//
// The visitor does two passes on the global declarations: the first one adds
// all globals to the global scope (since Go allows references to globals that
// haven't been declared yet), and the second one calls f() for the definition
// and uses of globals found in the first pass.
//
// The implementation correctly handles cases when globals are aliased by
// locals; in such cases, f() is not called.
type globalsVisitor struct {
	// file is the file whose nodes are being visited.
	file *ast.File

	// fset is the file set the file being visited belongs to. It is used to
	// turn token positions into readable locations when an unexpected node
	// is reported.
	fset *token.FileSet

	// f is the visit function to be called when a global symbol is reached.
	// It may modify the identifier it is given, e.g. to rename it.
	f func(*ast.Ident, SymKind)

	// scope is the current (innermost) scope as nodes are visited. The
	// outermost scope, created first, is the global one.
	scope *scope

	// processAnon indicates whether we should process anonymous struct
	// fields. When set, the selector name of every selector expression (the
	// Sel in x.Sel) is also looked up and reported if it resolves to a
	// global. The check is purely name-based and does not inspect x, so any
	// field, method or parameter type that merely shares a global type's name
	// is reported (and therefore typically renamed) as well.
	processAnon bool
}
// unexpected reports an AST node that the traversal has no handling for. The
// traversal is meant to cover every node kind it can meet, so reaching this
// indicates a bug in the traversal code; it panics with the source location of
// p to make the offending construct easy to find.
func (v *globalsVisitor) unexpected(p token.Pos) {
	where := v.fset.Position(p)
	msg := fmt.Sprintf("Unable to parse at %v", where)
	panic(msg)
}
// pushScope opens a new innermost scope nested inside the current one; it is
// undone by popScope.
func (v *globalsVisitor) pushScope() {
	inner := newScope(v.scope)
	v.scope = inner
}
// popScope discards the innermost scope and makes its enclosing scope current
// again, undoing the most recent pushScope.
func (v *globalsVisitor) popScope() {
	enclosing := v.scope.outer
	v.scope = enclosing
}
// visitType is called when an expression is known to be a type, for example,
// on the first argument of make(). It visits all children nodes and reports
// any globals.
//
// A nil ge is ignored, so optional type fields (such as the elided type of a
// nested composite literal) can be passed without a prior nil check.
func (v *globalsVisitor) visitType(ge ast.Expr) {
	switch e := ge.(type) {
	case nil:
		// No type to visit. Without this case a nil interface would reach
		// the default branch and crash calling ge.Pos().
	case *ast.Ident:
		if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
			v.f(e, s.kind)
		}
	case *ast.SelectorExpr:
		// pkg.T names a type from another package, so there is nothing in
		// this file to report; the qualifier must still be a (possibly
		// parenthesized) identifier.
		if GetIdent(e.X) == nil {
			v.unexpected(e.X.Pos())
		}
	case *ast.StarExpr:
		v.visitType(e.X)
	case *ast.ParenExpr:
		v.visitType(e.X)
	case *ast.ChanType:
		v.visitType(e.Value)
	case *ast.Ellipsis:
		v.visitType(e.Elt)
	case *ast.ArrayType:
		// Len is nil for slice types (visitExpr ignores nil) and an
		// *ast.Ellipsis for [...]T array literals. The latter is not a
		// length expression and visitExpr has no case for it, so it is
		// skipped. An explicit length may reference a global constant and
		// is visited.
		if _, inferred := e.Len.(*ast.Ellipsis); !inferred {
			v.visitExpr(e.Len)
		}
		v.visitType(e.Elt)
	case *ast.MapType:
		v.visitType(e.Key)
		v.visitType(e.Value)
	case *ast.StructType:
		v.visitFields(e.Fields, KindUnknown)
	case *ast.FuncType:
		v.visitFields(e.Params, KindUnknown)
		v.visitFields(e.Results, KindUnknown)
	case *ast.InterfaceType:
		v.visitFields(e.Methods, KindUnknown)
	default:
		v.unexpected(ge.Pos())
	}
}
// visitFields walks every field of l, reporting globals used in each field's
// type. When kind is not KindUnknown, the field names are also declared in the
// current scope with that kind (this is how parameters, results and receivers
// are introduced). Each field tag is handed to f as a synthetic KindTag
// identifier whose Name is the raw tag literal; if f changes the name, the tag
// is rewritten. A nil list is a no-op.
func (v *globalsVisitor) visitFields(l *ast.FieldList, kind SymKind) {
	if l == nil {
		return
	}
	declare := kind != KindUnknown
	for i := range l.List {
		field := l.List[i]
		if declare {
			for _, name := range field.Names {
				v.scope.add(name.Name, kind, name.Pos())
			}
		}
		v.visitType(field.Type)

		if field.Tag == nil {
			continue
		}
		synthetic := ast.NewIdent(field.Tag.Value)
		v.f(synthetic, KindTag)
		if synthetic.Name != field.Tag.Value {
			field.Tag.Value = synthetic.Name
		}
	}
}
// visitGenDecl handles a generic declaration (import, type, const or var),
// whether at the top level or inside a function body. Each newly declared name
// is added to the current scope, and reported to f when the current scope is
// the global one. Import declarations need no work here; the first pass
// registers them via globalsFromGenDecl.
func (v *globalsVisitor) visitGenDecl(d *ast.GenDecl) {
	switch tok := d.Tok; tok {
	case token.IMPORT:
		// Nothing to do.
	case token.TYPE:
		for _, spec := range d.Specs {
			ts := spec.(*ast.TypeSpec)
			// The name is declared before its definition is visited, so
			// a type that refers to itself resolves to this symbol.
			v.scope.add(ts.Name.Name, KindType, ts.Name.Pos())
			if v.scope.isGlobal() {
				v.f(ts.Name, KindType)
			}
			v.visitType(ts.Type)
		}
	case token.CONST, token.VAR:
		kind := KindVar
		if tok == token.CONST {
			kind = KindConst
		}
		for _, spec := range d.Specs {
			vs := spec.(*ast.ValueSpec)
			if vs.Type != nil {
				v.visitType(vs.Type)
			}
			// Initializers are visited before the names are declared, so in
			// a local `var x = x` the right-hand x still resolves to the
			// outer symbol.
			for _, init := range vs.Values {
				v.visitExpr(init)
			}
			for _, name := range vs.Names {
				// NOTE(review): unlike the type case, the name is reported
				// to f before it is added to the scope, so if f renames it
				// the scope records the new name.
				if v.scope.isGlobal() {
					v.f(name, kind)
				}
				v.scope.add(name.Name, kind, name.Pos())
			}
		}
	default:
		v.unexpected(d.Pos())
	}
}
// isViableType reports whether expr could be interpreted as a type, for
// example sync.Mutex, myType or func(int) int, as opposed to value
// expressions such as -1, 2 * 2 or a + b. It errs on the side of "yes" for
// bare names it cannot resolve (e.g. predeclared types such as int, which this
// visitor never declares).
func (v *globalsVisitor) isViableType(expr ast.Expr) bool {
	switch e := expr.(type) {
	case *ast.StarExpr:
		// *T is viable whenever T is.
		return v.isViableType(e.X)
	case *ast.ParenExpr:
		// (T) is viable whenever T is.
		return v.isViableType(e.X)
	case *ast.Ident:
		// A plain name is viable if it resolves to a type, or if it does
		// not resolve at all.
		sym := v.scope.deepLookup(e.Name)
		if sym == nil {
			return true
		}
		return sym.kind == KindType
	case *ast.SelectorExpr:
		// X.T can only be a type when X is an imported package. Imports are
		// not parsed, so whether T really is a type stays unknown; that is
		// harmless because imported symbols are never visited.
		pkg := GetIdent(e.X)
		if pkg == nil {
			return false
		}
		sym := v.scope.deepLookup(pkg.Name)
		return sym != nil && sym.kind == KindImport
	case *ast.ChanType, *ast.ArrayType, *ast.MapType, *ast.StructType,
		*ast.FuncType, *ast.InterfaceType, *ast.Ellipsis:
		// Type literals: chan T / <-chan T / chan<- T, [N]T, map[K]V,
		// struct{...}, func(...)..., interface{...} and ...T.
		return true
	}
	return false
}
// visitCallExpr visits a call expression, which is either a function or method
// call (f(), pkg.f(), obj.f()) or a type conversion (int32(1),
// (*sync.Mutex)(ptr)). The callee is visited as a type when it could be one
// and as a value otherwise. Arguments are visited as values, except the first
// argument of make or new, which is visited as a type.
func (v *globalsVisitor) visitCallExpr(e *ast.CallExpr) {
	callee := e.Fun
	if !v.isViableType(callee) {
		v.visitExpr(callee)
	} else {
		v.visitType(callee)
	}

	// The callee's name is inspected only after it has been visited, i.e.
	// after f had its chance to rewrite it.
	rest := e.Args
	if id := GetIdent(callee); id != nil {
		switch id.Name {
		case "make", "new":
			if len(rest) == 0 {
				return
			}
			v.visitType(rest[0])
			rest = rest[1:]
		}
	}
	for _, arg := range rest {
		v.visitExpr(arg)
	}
}
// visitExpr visits all nodes of an expression and reports any globals that it
// finds. A nil expression (for example an omitted slice bound or loop
// condition) is ignored.
func (v *globalsVisitor) visitExpr(ge ast.Expr) {
	switch e := ge.(type) {
	case nil:
	case *ast.Ident:
		if s := v.scope.deepLookup(e.Name); s != nil && s.scope.isGlobal() {
			v.f(e, s.kind)
		}
	case *ast.BasicLit:
	case *ast.CompositeLit:
		// Type is nil when it is elided, as for the elements of []T{{...}}
		// or the values of map[K]T{k: {...}}. Passing nil on to visitType
		// would crash, so only visit an explicit type.
		if e.Type != nil {
			v.visitType(e.Type)
		}
		for _, ne := range e.Elts {
			v.visitExpr(ne)
		}
	case *ast.FuncLit:
		// Parameters and results live in a scope of their own that
		// encloses the body's block scope.
		v.pushScope()
		v.visitFields(e.Type.Params, KindParameter)
		v.visitFields(e.Type.Results, KindResult)
		v.visitBlockStmt(e.Body)
		v.popScope()
	case *ast.BinaryExpr:
		v.visitExpr(e.X)
		v.visitExpr(e.Y)
	case *ast.CallExpr:
		v.visitCallExpr(e)
	case *ast.IndexExpr:
		v.visitExpr(e.X)
		v.visitExpr(e.Index)
	case *ast.KeyValueExpr:
		// Only the value is visited, presumably because in struct literals
		// the key is a field name rather than a symbol reference.
		// NOTE(review): as a consequence, map-literal keys that name a
		// global (map[int]T{someConst: ...}) are not reported.
		v.visitExpr(e.Value)
	case *ast.ParenExpr:
		v.visitExpr(e.X)
	case *ast.SelectorExpr:
		// The selector name itself is only looked up when processAnon is
		// set; the check is name-based only.
		v.visitExpr(e.X)
		if v.processAnon {
			v.visitExpr(e.Sel)
		}
	case *ast.SliceExpr:
		v.visitExpr(e.X)
		v.visitExpr(e.Low)
		v.visitExpr(e.High)
		v.visitExpr(e.Max)
	case *ast.StarExpr:
		v.visitExpr(e.X)
	case *ast.TypeAssertExpr:
		v.visitExpr(e.X)
		// Type is nil for the x.(type) guard of a type switch.
		if e.Type != nil {
			v.visitType(e.Type)
		}
	case *ast.UnaryExpr:
		v.visitExpr(e.X)
	default:
		v.unexpected(ge.Pos())
	}
}
// GetIdent returns the identifier denoted by expr after stripping any number
// of enclosing parentheses, so both x and ((x)) yield x. It returns nil when
// the unwrapped expression is not an identifier, or when expr is nil.
func GetIdent(expr ast.Expr) *ast.Ident {
	for {
		paren, ok := expr.(*ast.ParenExpr)
		if !ok {
			break
		}
		expr = paren.X
	}
	id, _ := expr.(*ast.Ident)
	return id
}
// visitStmt visits all nodes of a statement, and reports any globals that it
// finds. It also adds to the current scope new symbols defined/declared.
//
// Symbols introduced by the statement (via :=, local declarations, range
// variables) are added to the current scope, so later references resolve to
// them instead of to a global of the same name. Statements that open an
// implicit block (for, if, range, switch, type switch, and each switch/select
// clause) push their own scope and pop it when done.
func (v *globalsVisitor) visitStmt(gs ast.Stmt) {
	switch s := gs.(type) {
	case nil, *ast.BranchStmt, *ast.EmptyStmt:
		// Nothing to visit; nil stands for an absent Init/Post/Else/Comm.
	case *ast.AssignStmt:
		for _, e := range s.Rhs {
			v.visitExpr(e)
		}
		// We visit the LHS after the RHS because the symbols we'll
		// potentially add to the table aren't meant to be visible to
		// the RHS.
		for _, e := range s.Lhs {
			if s.Tok == token.DEFINE {
				// The new local is added to the current scope before
				// visitExpr runs, so the LHS name resolves to it rather
				// than to a global of the same name.
				if n := GetIdent(e); n != nil {
					v.scope.add(n.Name, KindVar, n.Pos())
				}
			}
			v.visitExpr(e)
		}
	case *ast.BlockStmt:
		v.visitBlockStmt(s)
	case *ast.DeclStmt:
		// go/ast documents DeclStmt.Decl as a *GenDecl (const, type or var).
		v.visitGenDecl(s.Decl.(*ast.GenDecl))
	case *ast.DeferStmt:
		v.visitCallExpr(s.Call)
	case *ast.ExprStmt:
		v.visitExpr(s.X)
	case *ast.ForStmt:
		// Init's variables are scoped to the whole loop; the body gets a
		// further nested scope from visitBlockStmt.
		v.pushScope()
		v.visitStmt(s.Init)
		v.visitExpr(s.Cond)
		v.visitStmt(s.Post)
		v.visitBlockStmt(s.Body)
		v.popScope()
	case *ast.GoStmt:
		v.visitCallExpr(s.Call)
	case *ast.IfStmt:
		// Else is visited inside the scope holding Init's variables, so
		// they stay visible in the else branch (and in else-if chains).
		v.pushScope()
		v.visitStmt(s.Init)
		v.visitExpr(s.Cond)
		v.visitBlockStmt(s.Body)
		v.visitStmt(s.Else)
		v.popScope()
	case *ast.IncDecStmt:
		v.visitExpr(s.X)
	case *ast.LabeledStmt:
		v.visitStmt(s.Stmt)
	case *ast.RangeStmt:
		// The range expression is visited before Key/Value are declared,
		// so it cannot see them. With `=` instead of `:=`, Key/Value are
		// existing variables and are only visited.
		v.pushScope()
		v.visitExpr(s.X)
		if s.Tok == token.DEFINE {
			if n := GetIdent(s.Key); n != nil {
				v.scope.add(n.Name, KindVar, n.Pos())
			}
			if n := GetIdent(s.Value); n != nil {
				v.scope.add(n.Name, KindVar, n.Pos())
			}
		}
		v.visitExpr(s.Key)
		v.visitExpr(s.Value)
		v.visitBlockStmt(s.Body)
		v.popScope()
	case *ast.ReturnStmt:
		for _, r := range s.Results {
			v.visitExpr(r)
		}
	case *ast.SelectStmt:
		// Each communication clause gets its own scope, so a variable
		// declared in its Comm (e.g. x := <-ch) is visible only in that
		// clause's body.
		for _, ns := range s.Body.List {
			c := ns.(*ast.CommClause)
			v.pushScope()
			v.visitStmt(c.Comm)
			for _, bs := range c.Body {
				v.visitStmt(bs)
			}
			v.popScope()
		}
	case *ast.SendStmt:
		v.visitExpr(s.Chan)
		v.visitExpr(s.Value)
	case *ast.SwitchStmt:
		// Outer scope for Init; one nested scope per case clause. The
		// clause's List is empty for the default clause.
		v.pushScope()
		v.visitStmt(s.Init)
		v.visitExpr(s.Tag)
		for _, ns := range s.Body.List {
			c := ns.(*ast.CaseClause)
			v.pushScope()
			for _, ce := range c.List {
				v.visitExpr(ce)
			}
			for _, bs := range c.Body {
				v.visitStmt(bs)
			}
			v.popScope()
		}
		v.popScope()
	case *ast.TypeSwitchStmt:
		// Assign is either `x := y.(type)` (x is declared once in the
		// switch scope) or `y.(type)`. Case expressions are types; a `nil`
		// case is looked up like any other identifier.
		v.pushScope()
		v.visitStmt(s.Init)
		v.visitStmt(s.Assign)
		for _, ns := range s.Body.List {
			c := ns.(*ast.CaseClause)
			v.pushScope()
			for _, ce := range c.List {
				v.visitType(ce)
			}
			for _, bs := range c.Body {
				v.visitStmt(bs)
			}
			v.popScope()
		}
		v.popScope()
	default:
		v.unexpected(gs.Pos())
	}
}
// visitBlockStmt visits the statements of a block, in order, inside a fresh
// scope, so that anything the block declares is forgotten once it ends.
func (v *globalsVisitor) visitBlockStmt(s *ast.BlockStmt) {
	v.pushScope()
	stmts := s.List
	for i := 0; i < len(stmts); i++ {
		v.visitStmt(stmts[i])
	}
	v.popScope()
}
// visitFuncDecl handles a function or method declaration. The name of a plain
// function is reported to f; method names are not. The optional receiver, the
// parameters and the results are declared in a new scope that encloses the
// body, and the body is visited when present (go/ast leaves it nil for
// functions implemented outside Go).
func (v *globalsVisitor) visitFuncDecl(d *ast.FuncDecl) {
	isMethod := d.Recv != nil
	if !isMethod {
		v.f(d.Name, KindFunction)
	}

	sig := d.Type
	v.pushScope()
	v.visitFields(d.Recv, KindReceiver)
	v.visitFields(sig.Params, KindParameter)
	v.visitFields(sig.Results, KindResult)
	if body := d.Body; body != nil {
		v.visitBlockStmt(body)
	}
	v.popScope()
}
// globalsFromGenDecl is used by the first pass: it adds every name declared by
// a top-level generic declaration to the current (global) scope without
// reporting anything to f.
//
// An import without an explicit name is registered under the last element of
// its path. NOTE(review): this is a heuristic; a package whose declared name
// differs from its last path element is registered under the path element.
// Blank (_) imports are skipped.
func (v *globalsVisitor) globalsFromGenDecl(d *ast.GenDecl) {
	switch d.Tok {
	case token.IMPORT:
		for _, spec := range d.Specs {
			is := spec.(*ast.ImportSpec)
			switch {
			case is.Name == nil:
				// Path is a quoted string literal. If Unquote ever failed,
				// the empty result would make filepath.Base return ".".
				importPath, _ := strconv.Unquote(is.Path.Value)
				v.scope.add(filepath.Base(importPath), KindImport, is.Path.Pos())
			case is.Name.Name != "_":
				v.scope.add(is.Name.Name, KindImport, is.Name.Pos())
			}
		}
	case token.TYPE:
		for _, spec := range d.Specs {
			ts := spec.(*ast.TypeSpec)
			v.scope.add(ts.Name.Name, KindType, ts.Name.Pos())
		}
	case token.CONST, token.VAR:
		kind := KindVar
		if d.Tok == token.CONST {
			kind = KindConst
		}
		for _, spec := range d.Specs {
			vs := spec.(*ast.ValueSpec)
			for _, name := range vs.Names {
				v.scope.add(name.Name, kind, name.Pos())
			}
		}
	default:
		v.unexpected(d.Pos())
	}
}
// visit performs the two passes described on globalsVisitor. It first opens
// the global scope and fills it with every top-level name (methods excluded),
// so forward references resolve. It then walks the contents of each
// declaration, reporting globals to f. The global scope is left as the
// current scope when visit returns.
func (v *globalsVisitor) visit() {
	v.pushScope()

	// Pass 1: declare all globals.
	for _, decl := range v.file.Decls {
		switch d := decl.(type) {
		case *ast.FuncDecl:
			if d.Recv != nil {
				continue // Methods are not globals.
			}
			v.scope.add(d.Name.Name, KindFunction, d.Name.Pos())
		case *ast.GenDecl:
			v.globalsFromGenDecl(d)
		default:
			v.unexpected(decl.Pos())
		}
	}

	// Pass 2: visit declaration contents. Any other declaration kind has
	// already been rejected by pass 1.
	for _, decl := range v.file.Decls {
		if fd, ok := decl.(*ast.FuncDecl); ok {
			v.visitFuncDecl(fd)
		} else if gd, ok := decl.(*ast.GenDecl); ok {
			v.visitGenDecl(gd)
		}
	}
}
// Visit traverses the provided AST and calls f() for each identifier that
// refers to a global name. The global name must be defined in the file itself.
//
// The function f() is allowed to modify the identifier, for example, to rename
// uses of global references. Struct field tags are also passed to f() as
// KindTag identifiers holding the raw tag literal; a changed name rewrites the
// tag. When processAnon is set, selector names (the Sel in x.Sel) that match a
// global are reported as well.
func Visit(fset *token.FileSet, file *ast.File, f func(*ast.Ident, SymKind), processAnon bool) {
	v := &globalsVisitor{file: file, fset: fset, f: f}
	v.processAnon = processAnon
	v.visit()
}