summaryrefslogtreecommitdiff
path: root/vendor/github.com/authzed/spicedb/internal/middleware/handwrittenvalidation/handwrittenvalidation.go
blob: 2adc4b3e31a3bb01353d84abfbb2ed66a732ebd7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package handwrittenvalidation

import (
	"context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// handwrittenValidator is implemented by request/message types that provide
// their own validation method. Both UnaryServerInterceptor and the stream
// wrapper in this package type-assert incoming messages against this
// interface and, when it matches, call HandwrittenValidate before the message
// reaches the handler; a non-nil error is treated as a validation failure.
type handwrittenValidator interface {
	// HandwrittenValidate returns a non-nil error if the message is invalid.
	HandwrittenValidate() error
}

// UnaryServerInterceptor is itself a unary server interceptor (it matches the
// grpc.UnaryServerInterceptor signature). Before invoking handler it checks
// whether req implements handwrittenValidator; if it does, HandwrittenValidate
// is run and any failure short-circuits the call with a codes.InvalidArgument
// status carrying the validation error's text. Requests that do not implement
// the interface, or that validate successfully, are passed to handler as-is.
func UnaryServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
	v, isValidatable := req.(handwrittenValidator)
	if !isValidatable {
		return handler(ctx, req)
	}

	if validationErr := v.HandwrittenValidate(); validationErr != nil {
		return nil, status.Errorf(codes.InvalidArgument, "%s", validationErr)
	}

	return handler(ctx, req)
}

// StreamServerInterceptor is itself a stream server interceptor (it matches
// the grpc.StreamServerInterceptor signature). It hands handler a wrapped
// stream whose RecvMsg runs HandwrittenValidate on every received message
// that implements handwrittenValidator, returning an error from RecvMsg when
// validation fails. All other stream behavior comes from the original stream.
func StreamServerInterceptor(srv interface{}, stream grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	return handler(srv, &recvWrapper{ServerStream: stream})
}

// recvWrapper decorates a grpc.ServerStream so that each message read via
// RecvMsg is validated with HandwrittenValidate when the message implements
// handwrittenValidator. Every other ServerStream method is provided by the
// embedded stream unchanged.
type recvWrapper struct {
	grpc.ServerStream
}

// RecvMsg reads the next message from the underlying stream into m and then
// validates it.
//
// Errors from the underlying RecvMsg are returned untouched so callers can
// still recognize stream termination (gRPC signals a clean end of stream with
// io.EOF). A validation failure is converted into a codes.InvalidArgument
// status, the same code the unary interceptor uses. A bare, non-status error
// would otherwise be reported to the client as codes.Unknown.
func (s *recvWrapper) RecvMsg(m interface{}) error {
	if err := s.ServerStream.RecvMsg(m); err != nil {
		return err
	}

	validator, ok := m.(handwrittenValidator)
	if !ok {
		// Message has no handwritten validation; accept it as received.
		return nil
	}

	if err := validator.HandwrittenValidate(); err != nil {
		return status.Errorf(codes.InvalidArgument, "%s", err)
	}

	return nil
}