From e86ce1807366293f6cfea9558885a4fc23f69e21 Mon Sep 17 00:00:00 2001 From: Peter Goetz Date: Mon, 13 May 2019 23:09:33 +0200 Subject: [PATCH] Allow also directed channels in callback stubs Fixes previous commit which accidentally only allowed undirected channels. https://github.com/petergtz/pegomock/issues/84 --- dsl_test.go | 19 +++++++++++++++++++ mockgen/mockgen.go | 45 +++++++++++++++++++++++++++------------------ 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/dsl_test.go b/dsl_test.go index 18da964..59ab622 100644 --- a/dsl_test.go +++ b/dsl_test.go @@ -42,6 +42,7 @@ var ( BeTrue = gomega.BeTrue ConsistOf = gomega.ConsistOf ContainSubstring = gomega.ContainSubstring + MatchError = gomega.MatchError Equal = gomega.Equal Expect = gomega.Expect HaveLen = gomega.HaveLen @@ -894,6 +895,24 @@ var _ = Describe("MockDisplay", func() { }) display.ChanReturnValues() }) + + It("allows to return directed channels from callbacks", func() { + When(display.ChanReturnValues()).Then(func([]pegomock.Param) pegomock.ReturnValues { + return []ReturnValue{make(<-chan string), make(chan<- error)} + }) + display.ChanReturnValues() + }) + + It("does not allow to return directed channels from callbacks with wrong direction", func() { + When(display.ChanReturnValues()).Then(func([]pegomock.Param) pegomock.ReturnValues { + return []ReturnValue{make(chan<- string), make(chan<- error)} + }) + + Expect(func() { display.ChanReturnValues() }).To(PanicWithMessageTo(MatchError( + "interface conversion: pegomock.ReturnValue is chan<- string, not <-chan string", + ))) + + }) }) Context("using send-/receive-only channels", func() { diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index c3f392e..49438ef 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -155,7 +155,7 @@ func (g *generator) generateMockFor(iface *model.Interface, mockTypeName, selfPa g.generateVerifierType(mockTypeName) for _, method := range iface.Methods { ongoingVerificationTypeName := fmt.Sprintf("%v_%v_OngoingVerification", mockTypeName, method.Name) - args, argNames, argTypes, _, _ := argDataFor(method, g.packageMap, selfPackage) + args, argNames, argTypes, _ := argDataFor(method, g.packageMap, selfPackage) g.generateVerifierMethod(mockTypeName, method, selfPackage, ongoingVerificationTypeName, args, argNames) g.generateOngoingVerificationType(mockTypeName, ongoingVerificationTypeName) g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes) @@ -185,15 +185,15 @@ func (g *generator) generateMockType(mockTypeName string) { // If non-empty, pkgOverride is the package in which unqualified types reside. func (g *generator) generateMockMethod(mockType string, method *model.Method, pkgOverride string) *generator { - args, argNames, _, signatureReturnTypes, returnTypes := argDataFor(method, g.packageMap, pkgOverride) - g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(signatureReturnTypes)) + args, argNames, _, returnTypes := argDataFor(method, g.packageMap, pkgOverride) + g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(stringSliceFrom(returnTypes, g.packageMap, pkgOverride))) g.p("if mock == nil {"). p(" panic(\"mock must not be nil. Use myMock := New%v().\")", mockType). p("}") g.GenerateParamsDeclaration(argNames, method.Variadic != nil) reflectReturnTypes := make([]string, len(returnTypes)) for i, returnType := range returnTypes { - reflectReturnTypes[i] = fmt.Sprintf("reflect.TypeOf((*%v)(nil)).Elem()", returnType) + reflectReturnTypes[i] = fmt.Sprintf("reflect.TypeOf((*%v)(nil)).Elem()", returnType.String(g.packageMap, pkgOverride)) } resultAssignment := "" if len(method.Out) > 0 { @@ -204,13 +204,23 @@ func (g *generator) generateMockMethod(mockType string, method *model.Method, pk if len(method.Out) > 0 { // TODO: translate LastInvocation into a Matcher so it can be used as key for Stubbings for i, returnType := range returnTypes { - g.p("var ret%v %v", i, returnType) + g.p("var ret%v %v", i, returnType.String(g.packageMap, pkgOverride)) } g.p("if len(result) != 0 {") returnValues := make([]string, len(returnTypes)) for i, returnType := range returnTypes { g.p("if result[%v] != nil {", i) - g.p("ret%v = result[%v].(%v)", i, i, returnType) + if chanType, isChanType := returnType.(*model.ChanType); isChanType && chanType.Dir != 0 { + undirectedChanType := *chanType + undirectedChanType.Dir = 0 + g.p("var ok bool"). + p(" ret%v, ok = result[%v].(%v)", i, i, undirectedChanType.String(g.packageMap, pkgOverride)) + g.p("if !ok{"). + p("ret%v = result[%v].(%v)", i, i, chanType.String(g.packageMap, pkgOverride)). + p("}") + } else { + g.p("ret%v = result[%v].(%v)", i, i, returnType.String(g.packageMap, pkgOverride)) + } g.p("}") returnValues[i] = fmt.Sprintf("ret%v", i) } @@ -353,8 +363,7 @@ func argDataFor(method *model.Method, packageMap map[string]string, pkgOverride args []string, argNames []string, argTypes []string, - signatureReturnTypes []string, - returnTypes []string, + returnTypes []model.Type, ) { args = make([]string, len(method.In)) argNames = make([]string, len(method.In)) @@ -379,21 +388,21 @@ func argDataFor(method *model.Method, packageMap map[string]string, pkgOverride argNames = append(argNames, argName) argTypes = append(argTypes, "[]"+argType) } - signatureReturnTypes = make([]string, len(method.Out)) - returnTypes = make([]string, len(method.Out)) + returnTypes = make([]model.Type, len(method.Out)) for i, ret := range method.Out { - if chanType, isChanType := ret.Type.(*model.ChanType); isChanType { - chanTypeNoDir := *chanType - chanTypeNoDir.Dir = 0 - returnTypes[i] = chanTypeNoDir.String(packageMap, pkgOverride) - } else { - returnTypes[i] = ret.Type.String(packageMap, pkgOverride) - } - signatureReturnTypes[i] = ret.Type.String(packageMap, pkgOverride) + returnTypes[i] = ret.Type } return } +func stringSliceFrom(types []model.Type, packageMap map[string]string, pkgOverride string) []string { + result := make([]string, len(types)) + for i, t := range types { + result[i] = t.String(packageMap, pkgOverride) + } + return result +} + func addTypesFromMethodParamsTo(typesSet map[string]string, params []*model.Parameter, packageMap map[string]string) { for _, param := range params { switch typedType := param.Type.(type) {