327 lines
8.6 KiB
Go
327 lines
8.6 KiB
Go
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package oracle
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"go/ast"
|
||
|
"go/token"
|
||
|
"sort"
|
||
|
|
||
|
"golang.org/x/tools/go/ast/astutil"
|
||
|
"golang.org/x/tools/go/loader"
|
||
|
"golang.org/x/tools/go/ssa"
|
||
|
"golang.org/x/tools/go/ssa/ssautil"
|
||
|
"golang.org/x/tools/go/types"
|
||
|
"golang.org/x/tools/oracle/serial"
|
||
|
)
|
||
|
|
||
|
// builtinErrorType is the type of the predeclared identifier "error",
// looked up once from the universe scope.
var builtinErrorType = types.Universe.Lookup("error").Type()
|
||
|
|
||
|
// whicherrs takes an position to an error and tries to find all types, constants
|
||
|
// and global value which a given error can point to and which can be checked from the
|
||
|
// scope where the error lives.
|
||
|
// In short, it returns a list of things that can be checked against in order to handle
|
||
|
// an error properly.
|
||
|
//
|
||
|
// TODO(dmorsing): figure out if fields in errors like *os.PathError.Err
|
||
|
// can be queried recursively somehow.
|
||
|
func whicherrs(q *Query) error {
|
||
|
lconf := loader.Config{Build: q.Build}
|
||
|
|
||
|
if err := setPTAScope(&lconf, q.Scope); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Load/parse/type-check the program.
|
||
|
lprog, err := lconf.Load()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
q.Fset = lprog.Fset
|
||
|
|
||
|
qpos, err := parseQueryPos(lprog, q.Pos, true) // needs exact pos
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
prog := ssautil.CreateProgram(lprog, ssa.GlobalDebug)
|
||
|
|
||
|
ptaConfig, err := setupPTA(prog, lprog, q.PTALog, q.Reflection)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
path, action := findInterestingNode(qpos.info, qpos.path)
|
||
|
if action != actionExpr {
|
||
|
return fmt.Errorf("whicherrs wants an expression; got %s",
|
||
|
astutil.NodeDescription(qpos.path[0]))
|
||
|
}
|
||
|
var expr ast.Expr
|
||
|
var obj types.Object
|
||
|
switch n := path[0].(type) {
|
||
|
case *ast.ValueSpec:
|
||
|
// ambiguous ValueSpec containing multiple names
|
||
|
return fmt.Errorf("multiple value specification")
|
||
|
case *ast.Ident:
|
||
|
obj = qpos.info.ObjectOf(n)
|
||
|
expr = n
|
||
|
case ast.Expr:
|
||
|
expr = n
|
||
|
default:
|
||
|
return fmt.Errorf("unexpected AST for expr: %T", n)
|
||
|
}
|
||
|
|
||
|
typ := qpos.info.TypeOf(expr)
|
||
|
if !types.Identical(typ, builtinErrorType) {
|
||
|
return fmt.Errorf("selection is not an expression of type 'error'")
|
||
|
}
|
||
|
// Determine the ssa.Value for the expression.
|
||
|
var value ssa.Value
|
||
|
if obj != nil {
|
||
|
// def/ref of func/var object
|
||
|
value, _, err = ssaValueForIdent(prog, qpos.info, obj, path)
|
||
|
} else {
|
||
|
value, _, err = ssaValueForExpr(prog, qpos.info, path)
|
||
|
}
|
||
|
if err != nil {
|
||
|
return err // e.g. trivially dead code
|
||
|
}
|
||
|
|
||
|
// Defer SSA construction till after errors are reported.
|
||
|
prog.BuildAll()
|
||
|
|
||
|
globals := findVisibleErrs(prog, qpos)
|
||
|
constants := findVisibleConsts(prog, qpos)
|
||
|
|
||
|
res := &whicherrsResult{
|
||
|
qpos: qpos,
|
||
|
errpos: expr.Pos(),
|
||
|
}
|
||
|
|
||
|
// TODO(adonovan): the following code is heavily duplicated
|
||
|
// w.r.t. "pointsto". Refactor?
|
||
|
|
||
|
// Find the instruction which initialized the
|
||
|
// global error. If more than one instruction has stored to the global
|
||
|
// remove the global from the set of values that we want to query.
|
||
|
allFuncs := ssautil.AllFunctions(prog)
|
||
|
for fn := range allFuncs {
|
||
|
for _, b := range fn.Blocks {
|
||
|
for _, instr := range b.Instrs {
|
||
|
store, ok := instr.(*ssa.Store)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
gval, ok := store.Addr.(*ssa.Global)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
gbl, ok := globals[gval]
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
// we already found a store to this global
|
||
|
// The normal error define is just one store in the init
|
||
|
// so we just remove this global from the set we want to query
|
||
|
if gbl != nil {
|
||
|
delete(globals, gval)
|
||
|
}
|
||
|
globals[gval] = store.Val
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
ptaConfig.AddQuery(value)
|
||
|
for _, v := range globals {
|
||
|
ptaConfig.AddQuery(v)
|
||
|
}
|
||
|
|
||
|
ptares := ptrAnalysis(ptaConfig)
|
||
|
valueptr := ptares.Queries[value]
|
||
|
for g, v := range globals {
|
||
|
ptr, ok := ptares.Queries[v]
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
if !ptr.MayAlias(valueptr) {
|
||
|
continue
|
||
|
}
|
||
|
res.globals = append(res.globals, g)
|
||
|
}
|
||
|
pts := valueptr.PointsTo()
|
||
|
dedup := make(map[*ssa.NamedConst]bool)
|
||
|
for _, label := range pts.Labels() {
|
||
|
// These values are either MakeInterfaces or reflect
|
||
|
// generated interfaces. For the purposes of this
|
||
|
// analysis, we don't care about reflect generated ones
|
||
|
makeiface, ok := label.Value().(*ssa.MakeInterface)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
constval, ok := makeiface.X.(*ssa.Const)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
c := constants[*constval]
|
||
|
if c != nil && !dedup[c] {
|
||
|
dedup[c] = true
|
||
|
res.consts = append(res.consts, c)
|
||
|
}
|
||
|
}
|
||
|
concs := pts.DynamicTypes()
|
||
|
concs.Iterate(func(conc types.Type, _ interface{}) {
|
||
|
// go/types is a bit annoying here.
|
||
|
// We want to find all the types that we can
|
||
|
// typeswitch or assert to. This means finding out
|
||
|
// if the type pointed to can be seen by us.
|
||
|
//
|
||
|
// For the purposes of this analysis, the type is always
|
||
|
// either a Named type or a pointer to one.
|
||
|
// There are cases where error can be implemented
|
||
|
// by unnamed types, but in that case, we can't assert to
|
||
|
// it, so we don't care about it for this analysis.
|
||
|
var name *types.TypeName
|
||
|
switch t := conc.(type) {
|
||
|
case *types.Pointer:
|
||
|
named, ok := t.Elem().(*types.Named)
|
||
|
if !ok {
|
||
|
return
|
||
|
}
|
||
|
name = named.Obj()
|
||
|
case *types.Named:
|
||
|
name = t.Obj()
|
||
|
default:
|
||
|
return
|
||
|
}
|
||
|
if !isAccessibleFrom(name, qpos.info.Pkg) {
|
||
|
return
|
||
|
}
|
||
|
res.types = append(res.types, &errorType{conc, name})
|
||
|
})
|
||
|
sort.Sort(membersByPosAndString(res.globals))
|
||
|
sort.Sort(membersByPosAndString(res.consts))
|
||
|
sort.Sort(sorterrorType(res.types))
|
||
|
|
||
|
q.result = res
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// findVisibleErrs returns a mapping from each package-level variable of type "error" to nil.
|
||
|
func findVisibleErrs(prog *ssa.Program, qpos *queryPos) map[*ssa.Global]ssa.Value {
|
||
|
globals := make(map[*ssa.Global]ssa.Value)
|
||
|
for _, pkg := range prog.AllPackages() {
|
||
|
for _, mem := range pkg.Members {
|
||
|
gbl, ok := mem.(*ssa.Global)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
gbltype := gbl.Type()
|
||
|
// globals are always pointers
|
||
|
if !types.Identical(deref(gbltype), builtinErrorType) {
|
||
|
continue
|
||
|
}
|
||
|
if !isAccessibleFrom(gbl.Object(), qpos.info.Pkg) {
|
||
|
continue
|
||
|
}
|
||
|
globals[gbl] = nil
|
||
|
}
|
||
|
}
|
||
|
return globals
|
||
|
}
|
||
|
|
||
|
// findVisibleConsts returns a mapping from each package-level constant assignable to type "error", to nil.
|
||
|
func findVisibleConsts(prog *ssa.Program, qpos *queryPos) map[ssa.Const]*ssa.NamedConst {
|
||
|
constants := make(map[ssa.Const]*ssa.NamedConst)
|
||
|
for _, pkg := range prog.AllPackages() {
|
||
|
for _, mem := range pkg.Members {
|
||
|
obj, ok := mem.(*ssa.NamedConst)
|
||
|
if !ok {
|
||
|
continue
|
||
|
}
|
||
|
consttype := obj.Type()
|
||
|
if !types.AssignableTo(consttype, builtinErrorType) {
|
||
|
continue
|
||
|
}
|
||
|
if !isAccessibleFrom(obj.Object(), qpos.info.Pkg) {
|
||
|
continue
|
||
|
}
|
||
|
constants[*obj.Value] = obj
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return constants
|
||
|
}
|
||
|
|
||
|
type membersByPosAndString []ssa.Member
|
||
|
|
||
|
func (a membersByPosAndString) Len() int { return len(a) }
|
||
|
func (a membersByPosAndString) Less(i, j int) bool {
|
||
|
cmp := a[i].Pos() - a[j].Pos()
|
||
|
return cmp < 0 || cmp == 0 && a[i].String() < a[j].String()
|
||
|
}
|
||
|
func (a membersByPosAndString) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||
|
|
||
|
type sorterrorType []*errorType
|
||
|
|
||
|
func (a sorterrorType) Len() int { return len(a) }
|
||
|
func (a sorterrorType) Less(i, j int) bool {
|
||
|
cmp := a[i].obj.Pos() - a[j].obj.Pos()
|
||
|
return cmp < 0 || cmp == 0 && a[i].typ.String() < a[j].typ.String()
|
||
|
}
|
||
|
func (a sorterrorType) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||
|
|
||
|
type errorType struct {
|
||
|
typ types.Type // concrete type N or *N that implements error
|
||
|
obj *types.TypeName // the named type N
|
||
|
}
|
||
|
|
||
|
type whicherrsResult struct {
|
||
|
qpos *queryPos
|
||
|
errpos token.Pos
|
||
|
globals []ssa.Member
|
||
|
consts []ssa.Member
|
||
|
types []*errorType
|
||
|
}
|
||
|
|
||
|
func (r *whicherrsResult) display(printf printfFunc) {
|
||
|
if len(r.globals) > 0 {
|
||
|
printf(r.qpos, "this error may point to these globals:")
|
||
|
for _, g := range r.globals {
|
||
|
printf(g.Pos(), "\t%s", g.RelString(r.qpos.info.Pkg))
|
||
|
}
|
||
|
}
|
||
|
if len(r.consts) > 0 {
|
||
|
printf(r.qpos, "this error may contain these constants:")
|
||
|
for _, c := range r.consts {
|
||
|
printf(c.Pos(), "\t%s", c.RelString(r.qpos.info.Pkg))
|
||
|
}
|
||
|
}
|
||
|
if len(r.types) > 0 {
|
||
|
printf(r.qpos, "this error may contain these dynamic types:")
|
||
|
for _, t := range r.types {
|
||
|
printf(t.obj.Pos(), "\t%s", r.qpos.typeString(t.typ))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (r *whicherrsResult) toSerial(res *serial.Result, fset *token.FileSet) {
|
||
|
we := &serial.WhichErrs{}
|
||
|
we.ErrPos = fset.Position(r.errpos).String()
|
||
|
for _, g := range r.globals {
|
||
|
we.Globals = append(we.Globals, fset.Position(g.Pos()).String())
|
||
|
}
|
||
|
for _, c := range r.consts {
|
||
|
we.Constants = append(we.Constants, fset.Position(c.Pos()).String())
|
||
|
}
|
||
|
for _, t := range r.types {
|
||
|
var et serial.WhichErrsType
|
||
|
et.Type = r.qpos.typeString(t.typ)
|
||
|
et.Position = fset.Position(t.obj.Pos()).String()
|
||
|
we.Types = append(we.Types, et)
|
||
|
}
|
||
|
res.WhichErrs = we
|
||
|
}
|