diff --git a/handler.go b/handler.go index 8b2a709..2673f35 100644 --- a/handler.go +++ b/handler.go @@ -36,8 +36,8 @@ func DecodeParams[T any](p RawParams) (T, error) { // methodHandler is a handler for a single method type methodHandler struct { - paramReceivers []reflect.Type - nParams int + paramReceivers []reflect.Type + reqParams, optParams int receiver reflect.Value handlerFunc reflect.Value @@ -144,30 +144,34 @@ func (s *handler) register(namespace string, r interface{}) { for i := 0; i < val.NumMethod(); i++ { method := val.Type().Method(i) - funcType := method.Func.Type() hasCtx := 0 - if funcType.NumIn() >= 2 && funcType.In(1) == contextType { + if method.Type.NumIn() >= 2 && method.Type.In(1) == contextType { hasCtx = 1 } hasRawParams := false - ins := funcType.NumIn() - 1 - hasCtx + ins := method.Type.NumIn() - 1 - hasCtx + opt := 0 recvs := make([]reflect.Type, ins) for i := 0; i < ins; i++ { if hasRawParams && i > 0 { panic("raw params must be the last parameter") } - if funcType.In(i+1+hasCtx) == rtRawParams { + if method.Type.In(i+1+hasCtx) == rtRawParams { hasRawParams = true } recvs[i] = method.Type.In(i + 1 + hasCtx) + if recvs[i].Kind() == reflect.Pointer { + opt++ + } } - valOut, errOut, _ := processFuncOut(funcType) + valOut, errOut, _ := processFuncOut(method.Type) s.methods[namespace+"."+method.Name] = methodHandler{ paramReceivers: recvs, - nParams: ins, + reqParams: ins - opt, + optParams: opt, handlerFunc: method.Func, receiver: val, @@ -361,7 +365,7 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer return } - callParams := make([]reflect.Value, 1+handler.hasCtx+handler.nParams) + callParams := make([]reflect.Value, 1+handler.hasCtx+handler.reqParams+handler.optParams) callParams[0] = handler.receiver if handler.hasCtx == 1 { callParams[1] = reflect.ValueOf(ctx) @@ -385,21 +389,30 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer } } - if len(ps) != handler.nParams { - rpcError(w, &req, rpcInvalidParams, fmt.Errorf("wrong param count (method '%s'): %d != %d", req.Method, len(ps), handler.nParams)) + if len(ps) > handler.reqParams+handler.optParams || len(ps) < handler.reqParams { + var err error + if handler.optParams > 0 { + err = fmt.Errorf("wrong param count %d (method '%s'): expected %d - %d", + len(ps), req.Method, handler.reqParams, + handler.reqParams+handler.optParams, + ) + } else { + err = fmt.Errorf("wrong param count %d (method '%s'): expected %d", len(ps), req.Method, handler.reqParams) + } + rpcError(w, &req, rpcInvalidParams, err) stats.Record(ctx, metrics.RPCRequestError.M(1)) done(false) return } - for i := 0; i < handler.nParams; i++ { + for i, p := range ps { var rp reflect.Value typ := handler.paramReceivers[i] dec, found := s.paramDecoders[typ] if !found { rp = reflect.New(typ) - if err := json.NewDecoder(bytes.NewReader(ps[i].data)).Decode(rp.Interface()); err != nil { + if err := json.NewDecoder(bytes.NewReader(p.data)).Decode(rp.Interface()); err != nil { rpcError(w, &req, rpcParseError, xerrors.Errorf("unmarshaling params for '%s' (param: %T): %w", req.Method, rp.Interface(), err)) stats.Record(ctx, metrics.RPCRequestError.M(1)) return @@ -407,7 +420,7 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer rp = rp.Elem() } else { var err error - rp, err = dec(ctx, ps[i].data) + rp, err = dec(ctx, p.data) if err != nil { rpcError(w, &req, rpcParseError, xerrors.Errorf("decoding params for '%s' (param: %d; custom decoder): %w", req.Method, i, err)) stats.Record(ctx, metrics.RPCRequestError.M(1)) diff --git a/rpc_test.go b/rpc_test.go index 73bd9c1..df105aa 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -318,7 +318,7 @@ func TestRPC(t *testing.T) { require.NoError(t, err) _, err = erronly.AddGet() - if err == nil || err.Error() != "RPC error (-32602): wrong param count (method 'SimpleServerHandler.AddGet'): 0 != 1" { + if err == nil || err.Error() != "RPC error (-32602): wrong param count 0 (method 'SimpleServerHandler.AddGet'): expected 1" { t.Error("wrong error:", err) } closer() @@ -429,7 +429,7 @@ func TestRPCHttpClient(t *testing.T) { require.NoError(t, err) _, err = erronly.AddGet() - if err == nil || err.Error() != "RPC error (-32602): wrong param count (method 'SimpleServerHandler.AddGet'): 0 != 1" { + if err == nil || err.Error() != "RPC error (-32602): wrong param count 0 (method 'SimpleServerHandler.AddGet'): expected 1" { t.Error("wrong error:", err) } closer()