Skip to content

Commit

Permalink
ruleguard: implement SinkType filter (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
quasilyte authored Mar 22, 2022
1 parent 3c6c7c9 commit 003e476
Show file tree
Hide file tree
Showing 12 changed files with 295 additions and 8 deletions.
120 changes: 120 additions & 0 deletions analyzer/testdata/src/filtertest/f1.go
Original file line number Diff line number Diff line change
Expand Up @@ -955,3 +955,123 @@ func detectGlobal() {
print(globalVar)
}
}

func detectSinkType() {
// Call argument context.
_ = acceptReader(newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
_ = acceptReader(newIface("sink is io.Reader").(io.Reader)) // want `true`
_ = acceptReader((newIface("sink is io.Reader").(io.Reader))) // want `true`
_ = acceptBuffer(newIface("sink is io.Reader").(*bytes.Buffer))
_ = acceptReaderVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
_ = acceptReaderVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
_ = acceptReaderVariadic(10, nil, newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
_ = acceptReaderVariadic(10, newIface("sink is io.Reader").([]io.Reader)...)
_ = acceptWriterVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer))
_ = acceptWriterVariadic(10, nil, newIface("sink is io.Reader").(*bytes.Buffer))
_ = acceptWriterVariadic(10, nil, nil, newIface("sink is io.Reader").(*bytes.Buffer))
_ = acceptVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer))
_ = acceptVariadic(10, nil, newIface("sink is io.Reader").(*bytes.Buffer))
_ = acceptVariadic(10, nil, nil, newIface("sink is io.Reader").(*bytes.Buffer))
fmt.Println(newIface("sink is interface{}").(int)) // want `true`
fmt.Println(1, newIface("sink is interface{}").(io.Reader)) // want `true`

// Type conversion context.
_ = io.Reader(newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
_ = io.Writer(newIface("sink is io.Reader").(*bytes.Buffer))

// Return stmt context.
{
_ = func() (io.Reader, io.Writer) {
return newIface("sink is io.Reader").(*bytes.Buffer), nil // want `true`
}
_ = func() (io.Reader, io.Writer) {
return nil, newIface("sink is io.Reader").(*bytes.Buffer)
}
_ = func() (io.Writer, io.Reader) {
return nil, newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
}
}

// Assignment context.
{
var r io.Reader = (newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
var _ io.Reader = newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
var w io.Writer = newIface("sink is io.Reader").(*bytes.Buffer)
x := newIface("sink is io.Reader").(*bytes.Buffer)
_ = r
_ = w
_ = x
var readers map[string]io.Reader
readers["foo"] = newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
var writers map[string]io.Writer
writers["foo"] = newIface("sink is io.Reader").(*bytes.Buffer)
var foo exampleStruct
foo.r = newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
foo.buf = newIface("sink is io.Reader").(*bytes.Buffer)
foo.w = newIface("sink is io.Reader").(*bytes.Buffer)
}

// Index expr context
{
var readerKeys map[io.Reader]string
readerKeys[newIface("sink is io.Reader").(*bytes.Buffer)] = "ok" // want `true`
readerKeys[(newIface("sink is io.Reader").(*bytes.Buffer))] = "ok" // want `true`
var writerKeys map[io.Writer]string
writerKeys[newIface("sink is io.Reader").(*bytes.Buffer)] = "ok"
writerKeys[(newIface("sink is io.Reader").(*bytes.Buffer))] = "ok"
}

// Composite lit element context.
_ = []io.Reader{
newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
}
_ = []io.Reader{
10: newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
}
_ = [10]io.Reader{
4: newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
}
_ = map[string]io.Reader{
"foo": newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
}
_ = map[io.Reader]string{
newIface("sink is io.Reader").(*bytes.Buffer): "foo", // want `true`
}
_ = map[io.Reader]string{
(newIface("sink is io.Reader").(*bytes.Buffer)): "foo", // want `true`
}
_ = []io.Writer{
(newIface("sink is io.Reader").(*bytes.Buffer)),
}
_ = exampleStruct{
w: newIface("sink is io.Reader").(*bytes.Buffer),
r: newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
}
_ = []interface{}{
newIface("sink is interface{}").(*bytes.Buffer), // want `true`
newIface("sink is interface{}").(int), // want `true`
}
}

func detectSinkType2() io.Reader {
return newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
}

func detectSinkType3() io.Writer {
return newIface("sink is io.Reader").(*bytes.Buffer)
}

func newIface(key string) interface{} { return nil }

func acceptReaderVariadic(a int, r ...io.Reader) int { return 0 }
func acceptWriterVariadic(a int, r ...io.Writer) int { return 0 }
func acceptVariadic(a int, r ...interface{}) int { return 0 }
func acceptReader(r io.Reader) int { return 0 }
func acceptWriter(r io.Writer) int { return 0 }
func acceptBuffer(b *bytes.Buffer) int { return 0 }

type exampleStruct struct {
r io.Reader
w io.Writer
buf *bytes.Buffer
}
8 changes: 8 additions & 0 deletions analyzer/testdata/src/filtertest/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,12 @@ func testRules(m dsl.Matcher) {
`$x := time.Now().String()`).
Where(m["x"].Object.IsGlobal()).
Report(`global var`)

m.Match(`newIface("sink is io.Reader").($_)`).
Where(m["$$"].SinkType.Is(`io.Reader`)).
Report(`true`)

m.Match(`newIface("sink is interface{}").($_)`).
Where(m["$$"].SinkType.Is(`interface{}`)).
Report(`true`)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ go 1.17
require (
github.com/go-toolsmith/astcopy v1.0.0
github.com/google/go-cmp v0.5.6
github.com/quasilyte/go-ruleguard/dsl v0.3.18
github.com/quasilyte/go-ruleguard/dsl v0.3.19
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71
github.com/quasilyte/gogrep v0.0.0-20220120141003-628d8b3623b5
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567
Expand Down
6 changes: 2 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/quasilyte/go-ruleguard v0.3.1-0.20210203134552-1b5a410e1cc8/go.mod h1:KsAh3x0e7Fkpgs+Q9pNLS5XpFSvYCEVl5gP9Pp1xp30=
github.com/quasilyte/go-ruleguard/dsl v0.3.0/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
github.com/quasilyte/go-ruleguard/dsl v0.3.17 h1:L5xf3nifnRIdYe9vyMuY2sDnZHIgQol/fDq74FQz7ZY=
github.com/quasilyte/go-ruleguard/dsl v0.3.17/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
github.com/quasilyte/go-ruleguard/dsl v0.3.18 h1:gzHcFxmTwhn+ZKZd6nGw7JyjoDcYuwcA+TY5MNn9oMk=
github.com/quasilyte/go-ruleguard/dsl v0.3.18/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
github.com/quasilyte/go-ruleguard/dsl v0.3.19 h1:5+KTKb2YREUYiqZFEIuifFyBxlcCUPWgNZkWy71XS0Q=
github.com/quasilyte/go-ruleguard/dsl v0.3.19/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
github.com/quasilyte/go-ruleguard/rules v0.0.0-20201231183845-9e62ed36efe1/go.mod h1:7JTjp89EGyU1d6XfBiXihJNG37wB2VRkd125Q1u7Plc=
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71 h1:CNooiryw5aisadVfzneSZPswRWvnVW8hF1bS/vo8ReI=
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71/go.mod h1:4cgAphtvu7Ftv7vOT2ZOYhC6CvBxZixcasr8qIOTA50=
Expand Down
136 changes: 136 additions & 0 deletions ruleguard/filters.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/quasilyte/gogrep"
"github.com/quasilyte/gogrep/nodetag"
"golang.org/x/tools/go/ast/astutil"

"github.com/quasilyte/go-ruleguard/internal/xtypes"
"github.com/quasilyte/go-ruleguard/ruleguard/quasigo"
Expand Down Expand Up @@ -303,6 +304,21 @@ func makeTypesIdenticalFilter(src, lhsVarname, rhsVarname string) filterFunc {
}
}

func makeRootSinkTypeIsFilter(src string, pat *typematch.Pattern) filterFunc {
return func(params *filterParams) matchFilterResult {
// TODO(quasilyte): add variadic support?
e, ok := params.match.Node().(ast.Expr)
if ok {
parent, kv := findSinkRoot(params)
typ := findSinkType(params, parent, kv, e)
if pat.MatchIdentical(params.typematchState, typ) {
return filterSuccess
}
}
return filterFailure(src)
}
}

func makeTypeIsFilter(src, varname string, underlying bool, pat *typematch.Pattern) filterFunc {
if underlying {
return func(params *filterParams) matchFilterResult {
Expand Down Expand Up @@ -660,3 +676,123 @@ func typeHasPointers(typ types.Type) bool {
return true
}
}

func findSinkRoot(params *filterParams) (ast.Node, *ast.KeyValueExpr) {
for i := 1; i < params.nodePath.Len(); i++ {
switch n := params.nodePath.NthParent(i).(type) {
case *ast.ParenExpr:
// Skip and continue.
continue
case *ast.KeyValueExpr:
return params.nodePath.NthParent(i + 1).(ast.Expr), n
default:
return n, nil
}
}
return nil, nil
}

func findContainingFunc(params *filterParams) *types.Signature {
for i := 2; i < params.nodePath.Len(); i++ {
switch n := params.nodePath.NthParent(i).(type) {
case *ast.FuncDecl:
fn, ok := params.ctx.Types.TypeOf(n.Name).(*types.Signature)
if ok {
return fn
}
case *ast.FuncLit:
fn, ok := params.ctx.Types.TypeOf(n.Type).(*types.Signature)
if ok {
return fn
}
}
}
return nil
}

func findSinkType(params *filterParams, parent ast.Node, kv *ast.KeyValueExpr, e ast.Expr) types.Type {
switch parent := parent.(type) {
case *ast.ValueSpec:
return params.ctx.Types.TypeOf(parent.Type)

case *ast.ReturnStmt:
for i, result := range parent.Results {
if astutil.Unparen(result) != e {
continue
}
sig := findContainingFunc(params)
if sig == nil {
break
}
return sig.Results().At(i).Type()
}

case *ast.IndexExpr:
if astutil.Unparen(parent.Index) == e {
switch typ := params.ctx.Types.TypeOf(parent.X).Underlying().(type) {
case *types.Map:
return typ.Key()
case *types.Slice, *types.Array:
return nil // TODO: some untyped int type?
}
}

case *ast.AssignStmt:
if parent.Tok != token.ASSIGN || len(parent.Lhs) != len(parent.Rhs) {
break
}
for i, rhs := range parent.Rhs {
if rhs == e {
return params.ctx.Types.TypeOf(parent.Lhs[i])
}
}

case *ast.CompositeLit:
switch typ := params.ctx.Types.TypeOf(parent).Underlying().(type) {
case *types.Slice:
return typ.Elem()
case *types.Array:
return typ.Elem()
case *types.Map:
if astutil.Unparen(kv.Key) == e {
return typ.Key()
}
return typ.Elem()
case *types.Struct:
fieldName, ok := kv.Key.(*ast.Ident)
if !ok {
break
}
for i := 0; i < typ.NumFields(); i++ {
field := typ.Field(i)
if field.Name() == fieldName.String() {
return field.Type()
}
}
}

case *ast.CallExpr:
switch typ := params.ctx.Types.TypeOf(parent.Fun).(type) {
case *types.Signature:
// A function call argument.
for i, arg := range parent.Args {
if astutil.Unparen(arg) != e {
continue
}
isVariadicArg := (i >= typ.Params().Len()-1) && typ.Variadic()
if isVariadicArg && !parent.Ellipsis.IsValid() {
return typ.Params().At(typ.Params().Len() - 1).Type().(*types.Slice).Elem()
}
if i < typ.Params().Len() {
return typ.Params().At(i).Type()
}
break
}
default:
// Probably a type cast.
return typ
}
}

return invalidType
}
2 changes: 1 addition & 1 deletion ruleguard/gorule.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (params *filterParams) typeofNode(n ast.Node) types.Type {
if typ := params.ctx.Types.TypeOf(e); typ != nil {
return typ
}
return types.Typ[types.Invalid]
return invalidType
}

func mergeRuleSets(toMerge []*goRuleSet) (*goRuleSet, error) {
Expand Down
4 changes: 4 additions & 0 deletions ruleguard/ir/filter_op.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ruleguard/ir/gen_filter_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func main() {
{name: "Int", comment: "$Value holds an int64 constant", valueType: "int64", flags: flagIsBasicLit},

{name: "RootNodeParentIs", comment: "m[`$$`].Node.Parent().Is($Args[0])"},
{name: "RootSinkTypeIs", comment: "m[`$$`].SinkType.Is($Args[0])"},
}

var buf bytes.Buffer
Expand Down
12 changes: 12 additions & 0 deletions ruleguard/ir_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,18 @@ func (l *irLoader) newFilter(filter ir.FilterExpr, info *filterInfo) (matchFilte
}
result.fn = makeNodeIsFilter(result.src, filter.Value.(string), tag)

case ir.FilterRootSinkTypeIsOp:
typeString := l.unwrapStringExpr(filter.Args[0])
if typeString == "" {
return result, l.errorf(filter.Line, nil, "expected a non-empty string argument")
}
ctx := typematch.Context{Itab: l.itab}
pat, err := typematch.Parse(&ctx, typeString)
if err != nil {
return result, l.errorf(filter.Line, err, "parse type expr")
}
result.fn = makeRootSinkTypeIsFilter(result.src, pat)

case ir.FilterVarTypeHasPointersOp:
result.fn = makeTypeHasPointersFilter(result.src, filter.Value.(string))

Expand Down
6 changes: 6 additions & 0 deletions ruleguard/irconv/irconv.go
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,12 @@ func (conv *converter) convertFilterExprImpl(e ast.Expr) ir.FilterExpr {
return ir.FilterExpr{Op: ir.FilterVarObjectIsOp, Value: op.varName, Args: args}
case "Object.IsGlobal":
return ir.FilterExpr{Op: ir.FilterVarObjectIsGlobalOp, Value: op.varName}
case "SinkType.Is":
if op.varName != "$$" {
// TODO: remove this restriction.
panic(conv.errorf(e.Args[0], "sink type is only implemented for $$ var"))
}
return ir.FilterExpr{Op: ir.FilterRootSinkTypeIsOp, Value: op.varName, Args: args}
case "Type.HasPointers":
return ir.FilterExpr{Op: ir.FilterVarTypeHasPointersOp, Value: op.varName}
case "Type.Is":
Expand Down
4 changes: 2 additions & 2 deletions ruleguard/nodepath.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ func (p nodePath) Current() ast.Node {
}

func (p nodePath) NthParent(n int) ast.Node {
index := len(p.stack) - n - 1
if index >= 0 {
index := uint(len(p.stack) - n - 1)
if index < uint(len(p.stack)) {
return p.stack[index]
}
return nil
Expand Down
2 changes: 2 additions & 0 deletions ruleguard/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"strings"
)

var invalidType = types.Typ[types.Invalid]

func regexpHasCaptureGroups(pattern string) bool {
// regexp.Compile() uses syntax.Perl flags, so
// we use the same flags here.
Expand Down

0 comments on commit 003e476

Please sign in to comment.