Skip to content

Commit

Permalink
Improve compilation of lambdas so we remember the original header and…
Browse files Browse the repository at this point in the history
… use that when compiling the lambda bridge method
  • Loading branch information
stanhebben committed Nov 15, 2024
1 parent a29b19f commit 6b4c963
Show file tree
Hide file tree
Showing 17 changed files with 96 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,31 @@ public static CastedEval implicit(ExpressionCompiler compiler, CodePosition posi
return new CastedEval(compiler, position, type, false, false);
}

public static CastedEval implicit(ExpressionCompiler compiler, CodePosition position, TypeID type, TypeID original) {
return new CastedEval(compiler, position, type, original, false, false);
}

private final ExpressionCompiler compiler;
private final CodePosition position;
public final TypeID type;
public final TypeID original; // used for lambdas
private final boolean explicit;
private final boolean optional;

public CastedEval(ExpressionCompiler compiler, CodePosition position, TypeID type, boolean explicit, boolean optional) {
this.compiler = compiler;
this.position = position;
this.type = type;
this.original = type;
this.explicit = explicit;
this.optional = optional;
}

public CastedEval(ExpressionCompiler compiler, CodePosition position, TypeID type, TypeID original, boolean explicit, boolean optional) {
this.compiler = compiler;
this.position = position;
this.type = type;
this.original = original;
this.explicit = explicit;
this.optional = optional;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ public interface ExpressionBuilder {

Expression lambda(LambdaClosure closure, FunctionHeader header, Statement body);

Expression lambda(LambdaClosure closure, FunctionHeader header, FunctionHeader original, Statement body);

Expression newArray(ArrayTypeID type, Expression[] values);

Expression newAssoc(AssocTypeID type, List<Expression> keys, List<Expression> values);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ private static <T extends AnyMethod> MatchedCallArguments<T> matchNormal(
// invalid
return CastedExpression.invalid(position, CompileErrors.missingParameter(header.getParameter(false, i).name));
}
return argument.cast(CastedEval.implicit(compiler, position, header.getParameterType(false, i)));
TypeID originalType = method.getHeader().getParameterType(false, i);
return argument.cast(CastedEval.implicit(compiler, position, header.getParameterType(false, i), originalType));
})
.toArray(CastedExpression[]::new);

Expand Down Expand Up @@ -293,7 +294,12 @@ private static <T extends AnyMethod> Optional<MatchedCallArguments<T>> matchVarA
}

CastedExpression[] castedExpressions = IntStream.range(0, arguments.length)
.mapToObj(i -> arguments[i].cast(CastedEval.implicit(compiler, position, header.getParameterType(true, i))))
.mapToObj(i -> arguments[i].cast(CastedEval.implicit(
compiler,
position,
header.getParameterType(true, i),
method.getHeader().getParameterType(true, i)
)))
.toArray(CastedExpression[]::new);

Expression[] expressions = new Expression[header.parameters.length];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,12 @@ public Expression is(Expression value, TypeID type) {

@Override
public Expression lambda(LambdaClosure closure, FunctionHeader header, Statement body) {
return new FunctionExpression(position, closure, header, body);
return lambda(closure, header, null, body);
}

@Override
public Expression lambda(LambdaClosure closure, FunctionHeader header, FunctionHeader original, Statement body) {
return new FunctionExpression(position, closure, header, original, body);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public final Expression transform(StatementTransformer transformer) {
if (body == function.body)
return function;

return new FunctionExpression(function.position, function.closure, function.header, body);
return new FunctionExpression(function.position, function.closure, function.header, function.original, body);
} else {
return expression;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@ public class FunctionExpression extends Expression {
public final FunctionHeader header;
public final LambdaClosure closure;
public final Statement body;
public final FunctionHeader original;

public FunctionExpression(
CodePosition position,
LambdaClosure closure,
FunctionHeader header,
FunctionHeader original,
Statement body) {
super(position, new FunctionTypeID(header), body.getThrownType());

this.header = header;
this.closure = closure;
this.body = body;
this.original = original;
}

@Override
Expand All @@ -44,7 +47,7 @@ public <C, R> R accept(C context, ExpressionVisitorWithContext<C, R> visitor) {
@Override
public FunctionExpression transform(ExpressionTransformer transformer) {
Statement tBody = body.transform(transformer, ConcatMap.empty(LoopStatement.class, LoopStatement.class));
return tBody == body ? this : new FunctionExpression(position, closure, header, tBody);
return tBody == body ? this : new FunctionExpression(position, closure, header, original, tBody);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,9 @@ public Void visitRange(TypeID context, RangeTypeID range) {

@Override
public Void visitOptional(TypeID context, OptionalTypeID type) {
if (type.baseType == BasicTypeID.USIZE) {
writer.invokeStatic(INTEGER_VALUEOF);
}
//NO-OP
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.objectweb.asm.Type;
import org.openzen.zencode.shared.CodePosition;
import org.openzen.zenscript.codemodel.CompareType;
import org.openzen.zenscript.codemodel.FunctionHeader;
import org.openzen.zenscript.codemodel.OperatorType;
import org.openzen.zenscript.codemodel.definition.ExpansionDefinition;
import org.openzen.zenscript.codemodel.expression.captured.CapturedExpression;
Expand Down Expand Up @@ -474,6 +475,7 @@ public Void visitFunction(FunctionExpression expression) {
}*/

final String[] interfaces;
FunctionHeader header = expression.original == null ? expression.header : expression.original;

if (expression.type instanceof JavaFunctionalInterfaceTypeID) {
//Let's implement the functional Interface instead
Expand All @@ -484,19 +486,18 @@ public Void visitFunction(FunctionExpression expression) {
interfaces = new String[]{Type.getInternalName(functionalInterfaceMethod.getDeclaringClass())};
} else {
//Normal way, no casting to functional interface
interfaces = new String[]{context.getInternalName(new FunctionTypeID(expression.header))};
interfaces = new String[]{context.getInternalName(new FunctionTypeID(header))};
}

final JavaNativeMethod methodInfo;
final String className = this.javaMangler.mangleGeneratedLambdaName(interfaces[0]);
{
final JavaNativeMethod m = context.getFunctionalInterface(expression.type);
final JavaNativeMethod m = context.getFunctionalInterface(expression.original == null ? expression.type : new FunctionTypeID(expression.original));
methodInfo = m.withModifiers(m.modifiers & ~JavaModifiers.ABSTRACT);
}
final ClassWriter lambdaCW = new JavaClassWriter(ClassWriter.COMPUTE_FRAMES);
JavaClass lambdaClass = JavaClass.fromInternalName(className, JavaClass.Kind.CLASS);
lambdaCW.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, className, null, "java/lang/Object", interfaces);
final JavaWriter functionWriter;

JavaCompilingMethod actualCompiling = JavaMemberVisitor.compileBridgeableMethod(
context,
Expand All @@ -507,8 +508,6 @@ public Void visitFunction(FunctionExpression expression) {
expression.header,
null
);
functionWriter = new JavaWriter(context.logger, expression.position, lambdaCW, actualCompiling, null);
functionWriter.clazzVisitor.visitSource(expression.position.getFilename(), null);
javaWriter.newObject(className);
javaWriter.dup();

Expand Down Expand Up @@ -544,7 +543,8 @@ public Void visitFunction(FunctionExpression expression) {
constructorWriter.ret();
constructorWriter.end();


JavaWriter functionWriter = new JavaWriter(context.logger, expression.position, lambdaCW, actualCompiling, null);
functionWriter.clazzVisitor.visitSource(expression.position.getFilename(), null);
functionWriter.start();

JavaExpressionVisitor withCapturedExpressionVisitor = new JavaExpressionVisitor(
Expand All @@ -564,6 +564,7 @@ public Void visitFunction(FunctionExpression expression) {

functionWriter.ret();
functionWriter.end();

lambdaCW.visitEnd();

context.register(className, lambdaCW.toByteArray());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,9 +468,12 @@ public static JavaCompilingMethod compileBridgeableMethod(
bridgeWriter.invokeVirtual(new JavaNativeMethod(localClass, JavaNativeMethod.Kind.INSTANCE, overriddenMethodInfo.name, overriddenMethodInfo.compile, implementationDescriptor, overriddenMethodInfo.modifiers, overriddenMethodInfo.genericResult));
final TypeID returnType = implementationHeader.getReturnType();
if (returnType != BasicTypeID.VOID) {
final Type returnTypeASM = context.getType(returnType);
Type returnTypeASM = context.getType(returnType);
if (!CompilerUtils.isPrimitive(returnType)) {
bridgeWriter.checkCast(returnTypeASM);
} else if (!isPrimitiveReturnType(overriddenMethodInfo.descriptor)) {
returnType.accept(returnType, JavaBoxingTypeVisitor.forJavaBoxing(bridgeWriter));
returnTypeASM = Type.getReturnType(overriddenMethodInfo.descriptor);
}
bridgeWriter.returnType(returnTypeASM);
}
Expand All @@ -484,4 +487,20 @@ public static JavaCompilingMethod compileBridgeableMethod(
return new JavaCompilingMethod(overriddenMethodInfo, implementationSignature);
}
}

private static boolean isPrimitiveReturnType(String descriptor) {
switch (Type.getReturnType(descriptor).getSort()) {
case Type.BYTE:
case Type.SHORT:
case Type.INT:
case Type.LONG:
case Type.FLOAT:
case Type.DOUBLE:
case Type.BOOLEAN:
case Type.CHAR:
return true;
default:
return false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ public CastedExpression cast(CastedEval cast) {
thatOtherHeader.setReturnType(header.getReturnType());
}*/

return cast.of(CastedExpression.Level.EXACT, compiler.at(position).lambda(closure, header, statement));
FunctionHeader originalHeader = cast.original.asFunction().map(f -> f.header).orElse(null);
return cast.of(CastedExpression.Level.EXACT, compiler.at(position).lambda(closure, header, originalHeader, statement));
} else {
return cast.of(eval());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#disabled: Requires primitive specializations, not supported yet
#dependency: stdlib
#output: 1
#output: 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public expand <T> T[] {


var arr = new string[](6, "Hello");
var lengths = arr.map<usize>((element) => element.length);
var lengths = arr.map<string>((element) => element.length);

println(lengths.length);
for length in lengths println(length);
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#disabled: Requires primitive specializations, not supported yet
#output: 6
#output: 5
#output: 5
#output: 5
#output: 5
#output: 5
#output: 5

// One of the tests required to make StdLib (Arrays.zs) work
public expand <T> T[] {
public map<U>(projection as function(value as T) as U) as U[] {
return new U[](length, i => projection(this[i]));
}
}


var arr = new string[](6, "Hello");
var lengths = arr.map<usize>((element) => element.length);

println(lengths.length);
for length in lengths println(length);
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#disabled: Adding signed and unsigned requires explicit casting

val byteValue = 1 as byte;
val sbyteValue = 2 as sbyte;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#disabled: Not supported yet
#output: hello
#output: hello2

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#disabled: Requires primitive specializations, not supported yet
#dependency: stdlib
#output: 5
#output: 4
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#disabled: Requires primitive specializations, not supported yet
#dependency: stdlib
#output: 5
#output: 4
Expand Down

0 comments on commit 6b4c963

Please sign in to comment.