Skip to content

Commit

Permalink
Fix generic expansions and handling of generic array return types
Browse files Browse the repository at this point in the history
  • Loading branch information
stanhebben committed Jan 27, 2024
1 parent b1d89fa commit c9407d9
Show file tree
Hide file tree
Showing 20 changed files with 113 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ public boolean isSimilarTo(FunctionHeader other) {
public FunctionHeader instanceForCall(CallArguments arguments) {
if (arguments.getNumberOfTypeArguments() > 0) {
Map<TypeParameter, TypeID> typeParameters = TypeID.getMapping(this.typeParameters, arguments.typeArguments);
return instance(new GenericMapper(typeParameters));
return instance(new GenericMapper(typeParameters, arguments.expansionTypeArguments));
} else {
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@

public class GenericMapper {

public static final GenericMapper EMPTY = new GenericMapper(Collections.emptyMap());
public static final GenericMapper EMPTY = new GenericMapper(Collections.emptyMap(), TypeID.NONE);

private final Map<TypeParameter, TypeID> mapping;
private final TypeID[] expansionTypeArguments;

public GenericMapper(Map<TypeParameter, TypeID> mapping) {
public GenericMapper(Map<TypeParameter, TypeID> mapping, TypeID[] expansionTypeArguments) {
if (mapping == null)
throw new IllegalArgumentException();

this.mapping = mapping;
this.expansionTypeArguments = expansionTypeArguments;
}

public static GenericMapper create(TypeParameter[] typeParameters, TypeID[] typeArguments) {
Expand All @@ -33,11 +35,11 @@ public static GenericMapper create(TypeParameter[] typeParameters, TypeID[] type
for (int i = 0; i < typeParameters.length; i++) {
mapping.put(typeParameters[i], typeArguments[i]);
}
return new GenericMapper(mapping);
return new GenericMapper(mapping, TypeID.NONE);
}

public static GenericMapper single(TypeParameter parameter, TypeID argument) {
return new GenericMapper(Collections.singletonMap(parameter, argument));
return new GenericMapper(Collections.singletonMap(parameter, argument), TypeID.NONE);
}

public Map<TypeParameter, TypeID> getMapping() {
Expand Down Expand Up @@ -74,7 +76,7 @@ public FieldInstance map(FieldSymbol field) {
}

public MethodInstance map(TypeID target, MethodSymbol method) {
return new MethodInstance(method, map(method.getHeader()), target);
return new MethodInstance(method, map(method.getHeader()), target, expansionTypeArguments);
}

public GenericMapper getInner(Map<TypeParameter, TypeID> mapping) {
Expand All @@ -87,14 +89,14 @@ public GenericMapper getInner(Map<TypeParameter, TypeID> mapping) {
resultMap.put(typeParameter, type);
});

return new GenericMapper(resultMap);
return new GenericMapper(resultMap, expansionTypeArguments);
}

public GenericMapper getInner(TypeParameter[] parameters) {
Map<TypeParameter, TypeID> resultMap = new HashMap<>(this.mapping);
for (TypeParameter parameter : parameters)
resultMap.put(parameter, new GenericTypeID(parameter));
return new GenericMapper(resultMap);
return new GenericMapper(resultMap, expansionTypeArguments);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.openzen.zenscript.codemodel.type.member.ExpandedResolvedType;

import java.util.*;
import java.util.stream.Stream;

public class CompileContext extends AbstractTypeBuilder implements TypeResolver {
private final ZSPackage rootPackage;
Expand Down Expand Up @@ -93,8 +94,9 @@ public ResolvedType resolve(TypeID type) {
if (mapping == null)
continue;

TypeID[] expansionTypeArguments = Stream.of(expansion.typeParameters).map(mapping::get).toArray(TypeID[]::new);
MemberSet.Builder resolution = MemberSet.create();
GenericMapper mapper = new GenericMapper(mapping);
GenericMapper mapper = new GenericMapper(mapping, expansionTypeArguments);
for (IDefinitionMember member : expansion.members)
member.registerTo(type, resolution, mapper);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.openzen.zenscript.codemodel.expression.CallArguments;
import org.openzen.zenscript.codemodel.expression.Expression;
import org.openzen.zenscript.codemodel.generic.TypeParameter;
import org.openzen.zenscript.codemodel.identifiers.instances.MethodInstance;
import org.openzen.zenscript.codemodel.type.ArrayTypeID;
import org.openzen.zenscript.codemodel.type.BasicTypeID;
import org.openzen.zenscript.codemodel.type.TypeID;
Expand Down Expand Up @@ -109,9 +110,12 @@ private static CallArguments match(
TypeID[] typeArguments,
CompilingExpression... arguments
) {
TypeID[] expansionTypeArguments = method.asMethod().map(MethodInstance::getExpansionTypeArguments).orElse(TypeID.NONE);

if (!method.getHeader().accepts(arguments.length))
return new CallArguments(
CastedExpression.Level.INVALID,
expansionTypeArguments,
typeArguments,
Expression.NONE);

Expand All @@ -124,7 +128,7 @@ private static CallArguments match(

// create a mapping with everything found so far
// NOTE - this means that inference is sensitive to order of parameters
GenericMapper mapper = new GenericMapper(typeArgumentMap);
GenericMapper mapper = new GenericMapper(typeArgumentMap, expansionTypeArguments);

// now try to infer type arguments from the arguments
for (int i = 0; i < arguments.length; i++) {
Expand Down Expand Up @@ -152,7 +156,7 @@ private static CallArguments match(
}
if (hasUnknowns) {
// TODO: improve type inference
return new CallArguments(CastedExpression.Level.INVALID, typeArguments, Expression.NONE);
return new CallArguments(CastedExpression.Level.INVALID, TypeID.NONE, typeArguments, Expression.NONE);
}

typeArguments = typeArguments2;
Expand Down Expand Up @@ -195,6 +199,6 @@ private static CallArguments match(
cArguments[cArguments.length - 1] = compiler.at(position).newArray(new ArrayTypeID(variadicType), varargArguments.toArray(Expression.NONE));
levelNormalCall = levelVarargCall;
}
return new CallArguments(levelNormalCall, typeArguments, cArguments);
return new CallArguments(levelNormalCall, expansionTypeArguments, typeArguments, cArguments);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,44 @@ public class CallArguments {
public static final CallArguments EMPTY = new CallArguments(Expression.NONE);

public final CastedExpression.Level level;
public final TypeID[] expansionTypeArguments;
public final TypeID[] typeArguments;
public final Expression[] arguments;

public CallArguments(Expression... arguments) {
this.level = CastedExpression.Level.EXACT;
this.expansionTypeArguments = TypeID.NONE;
this.typeArguments = TypeID.NONE;
this.arguments = arguments;
}

public CallArguments(TypeID[] typeArguments, Expression[] arguments) {
public CallArguments(TypeID[] expansionTypeArguments, TypeID[] typeArguments, Expression[] arguments) {
if (typeArguments == null)
typeArguments = TypeID.NONE;
if (arguments == null)
throw new IllegalArgumentException("Arguments cannot be null!");

this.level = CastedExpression.Level.EXACT;
this.expansionTypeArguments = expansionTypeArguments;
this.typeArguments = typeArguments;
this.arguments = arguments;
}

public CallArguments(CastedExpression.Level level, TypeID[] typeArguments, Expression[] arguments) {
public CallArguments(CastedExpression.Level level, TypeID[] expansionTypeArguments, TypeID[] typeArguments, Expression[] arguments) {
if (typeArguments == null)
typeArguments = TypeID.NONE;
if (arguments == null)
throw new IllegalArgumentException("Arguments cannot be null!");

this.level = level;
this.expansionTypeArguments = expansionTypeArguments;
this.typeArguments = typeArguments;
this.arguments = arguments;
}

public CallArguments(TypeID... dummy) {
this.level = CastedExpression.Level.EXACT;
this.expansionTypeArguments = TypeID.NONE;
this.typeArguments = TypeID.NONE;
this.arguments = new Expression[dummy.length];
for (int i = 0; i < dummy.length; i++)
Expand All @@ -50,7 +55,7 @@ public CallArguments bind(Expression target) {
Expression[] newArguments = new Expression[arguments.length + 1];
newArguments[0] = target;
System.arraycopy(arguments, 0, newArguments, 1, arguments.length);
return new CallArguments(typeArguments, newArguments);
return new CallArguments(level, expansionTypeArguments, typeArguments, newArguments);
}

public int getNumberOfTypeArguments() {
Expand All @@ -59,6 +64,6 @@ public int getNumberOfTypeArguments() {

public CallArguments transform(ExpressionTransformer transformer) {
Expression[] tArguments = Expression.transform(arguments, transformer);
return tArguments == arguments ? this : new CallArguments(level, typeArguments, tArguments);
return tArguments == arguments ? this : new CallArguments(level, expansionTypeArguments, typeArguments, tArguments);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,27 @@ public class MethodInstance implements InstanceCallableMethod, StaticCallableMet
public final MethodSymbol method;
private final FunctionHeader header;
private final TypeID target;
private final TypeID[] expansionTypeArguments;

public MethodInstance(MethodSymbol method) {
this.method = method;
this.header = method.getHeader();
this.target = method.getTargetType();
this.expansionTypeArguments = TypeID.NONE;
}

public MethodInstance(MethodSymbol method, FunctionHeader header, TypeID target) {
this.method = method;
this.header = header;
this.target = target;
this.expansionTypeArguments = TypeID.NONE;
}

public MethodInstance(MethodSymbol method, FunctionHeader header, TypeID target, TypeID[] expansionTypeArguments) {
this.method = method;
this.header = header;
this.target = target;
this.expansionTypeArguments = expansionTypeArguments;
}

public TypeID getTarget() {
Expand Down Expand Up @@ -66,4 +76,8 @@ public Expression call(ExpressionBuilder builder, CallArguments arguments) {
public boolean isImplicit() {
return method.getModifiers().isImplicit();
}

public TypeID[] getExpansionTypeArguments() {
return expansionTypeArguments;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public ResolvedType resolve(TypeID[] typeArguments) {
Map<TypeParameter, TypeID> parameterFilled = new HashMap<>();
parameterFilled.put(PARAMETER, typeArguments[0]);
parameterFilled.put(VALUE, value);
GenericMapper mapper = new GenericMapper(parameterFilled);
GenericMapper mapper = new GenericMapper(parameterFilled, TypeID.NONE);

TypeID valueType = mapper.map(value);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ public static boolean isLarge(TypeID type) {
return type == BasicTypeID.DOUBLE || type == BasicTypeID.LONG || type == BasicTypeID.ULONG;
}

private boolean isGenericReturn(TypeID type) {
return type.isGeneric() || type.asArray().map(array -> isGenericReturn(type)).orElse(false);
}

public static int calcAccess(Modifiers modifiers) {
int out = 0;
if (modifiers.isStatic())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.openzen.zenscript.javashared.expressions.JavaFunctionInterfaceCastExpression;
import org.openzen.zenscript.javashared.types.JavaFunctionalInterfaceTypeID;

import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
Expand All @@ -31,6 +32,7 @@

public class JavaExpressionVisitor implements ExpressionVisitor<Void> {
private static final JavaNativeMethod MAP_PUT = JavaNativeMethod.getInterface(JavaClass.MAP, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
private static final JavaNativeMethod ARRAY_NEWINSTANCE = JavaNativeMethod.getNativeStatic(JavaClass.ARRAY, "newInstance", "(Ljava/lang/Class;I)Ljava/lang/Object;");
private static final MethodID CONSTRUCTOR = MethodID.staticOperator(OperatorType.CONSTRUCTOR);

final JavaWriter javaWriter;
Expand Down Expand Up @@ -108,9 +110,17 @@ public Void visitAndAnd(AndAndExpression expression) {

@Override
public Void visitArray(ArrayExpression expression) {
javaWriter.constant(expression.expressions.length);
Type type = context.getType(((ArrayTypeID) expression.type).elementType);
javaWriter.newArray(type);
Type type = context.getType(expression.arrayType.elementType);
if (expression.arrayType.elementType.isGeneric()) {

expression.arrayType.elementType.accept(javaWriter, new JavaTypeExpressionVisitor(context));
javaWriter.constant(expression.expressions.length);
javaWriter.invokeStatic(ARRAY_NEWINSTANCE);
javaWriter.checkCast(context.getInternalName(expression.arrayType));
} else {
javaWriter.constant(expression.expressions.length);
javaWriter.newArray(type);
}
for (int i = 0; i < expression.expressions.length; i++) {
javaWriter.dup();
javaWriter.constant(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ private void handleTypeArguments(JavaNativeMethod method, CallArguments argument
if (arguments.typeArguments.length != method.typeParameterArguments.length)
throw new IllegalArgumentException("Number of type parameters doesn't match");

for (int i = 0; i < arguments.expansionTypeArguments.length; i++) {
arguments.expansionTypeArguments[i].accept(javaWriter, javaTypeExpressionVisitor);
}
for (int i = 0; i < arguments.typeArguments.length; i++) {
if (method.typeParameterArguments[i])
arguments.typeArguments[i].accept(javaWriter, javaTypeExpressionVisitor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ private TypeID loadAnnotatedParameterizedType(TypeVariableContext context, Annot
map.put(typeParameter, codeParameters[i]);
}

return rawTypeId.instance(new GenericMapper(map));
return rawTypeId.instance(new GenericMapper(map, TypeID.NONE));
}
return this.loadType(context, JavaAnnotatedType.of(type), unsigned);
}
Expand Down Expand Up @@ -302,7 +302,7 @@ private TypeID loadFunctionalInterface(TypeVariableContext loadContext, Class<?>
mapping.put(context.get(javaParameters[i]), loadType(loadContext, parameters[i], false));
}

header = header.withGenericArguments(new GenericMapper(mapping));
header = header.withGenericArguments(new GenericMapper(mapping, TypeID.NONE));
JavaNativeMethod method = new JavaNativeMethod(
JavaClass.fromInternalName(getInternalName(cls), JavaClass.Kind.INTERFACE),
JavaNativeMethod.Kind.INTERFACE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class JavaClass implements Comparable<JavaClass> {
public static final JavaClass COLLECTION = new JavaClass("java.util", "Collection", Kind.INTERFACE);
public static final JavaClass COLLECTIONS = new JavaClass("java.util", "Collections", Kind.CLASS);
public static final JavaClass STRINGBUILDER = new JavaClass("java.lang", "StringBuilder", Kind.CLASS);
public static final JavaClass ARRAY = new JavaClass("java.lang.reflect", "Array", Kind.CLASS);

public static final JavaClass SHARED = new JavaClass("zsynthetic", "Shared", Kind.CLASS);
public final JavaClass outer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@ public String getMethodDescriptor(FunctionHeader header) {
return getMethodDescriptor(header, false, "");
}

public boolean isGenericReturn(TypeID type) {
return type.isGeneric() || type.asArray().map(array -> isGenericReturn(array.elementType)).orElse(false);
}

public String getMethodDescriptorExpansion(FunctionHeader header, TypeID expandedType) {
StringBuilder startBuilder = new StringBuilder(getDescriptor(expandedType));
final List<TypeParameter> typeParameters = new ArrayList<>();
Expand Down Expand Up @@ -579,7 +583,11 @@ public String getMethodDescriptor(FunctionHeader header, boolean isEnumConstruct
descBuilder.append(getDescriptor(parameter.type));
}
descBuilder.append(")");
descBuilder.append(getDescriptor(header.getReturnType()));
if (isGenericReturn(header.getReturnType())) {
descBuilder.append("Ljava/lang/Object;");
} else {
descBuilder.append(getDescriptor(header.getReturnType()));
}
return descBuilder.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public String visitBasic(BasicTypeID basic) {

@Override
public String visitArray(ArrayTypeID array) {
return "[" + array.elementType.accept(this);
return "[" + array.elementType.accept(new JavaTypeDescriptorVisitor(context));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public void addMethod(MethodSymbol method, NativeTag native_) {
true,
descriptor,
JavaModifiers.getJavaModifiers(method.getModifiers()),
method.getHeader().getReturnType() instanceof GenericTypeID,
getContext().isGenericReturn(method.getHeader().getReturnType()),
method.getHeader().useTypeParameters());
addMethod(method, new JavaCompilingMethod(compiled, javaMethod, signature));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.openzen.zenscript.codemodel.member.*;
import org.openzen.zenscript.codemodel.type.BasicTypeID;
import org.openzen.zenscript.codemodel.type.GenericTypeID;
import org.openzen.zenscript.codemodel.type.TypeID;
import org.openzen.zenscript.javashared.*;
import org.openzen.zenscript.javashared.compiling.JavaCompilingClass;
import org.openzen.zenscript.javashared.compiling.JavaCompilingMethod;
Expand Down Expand Up @@ -261,7 +262,7 @@ private void visitFunctional(FunctionalMember member, FunctionHeader header, Str
true,
context.getMethodDescriptor(header),
modifiers | JavaModifiers.getJavaModifiers(member.getEffectiveModifiers()),
header.getReturnType() instanceof GenericTypeID,
context.isGenericReturn(header.getReturnType()),
header.useTypeParameters()),
signature);
} else if (method == null) {
Expand Down Expand Up @@ -302,7 +303,7 @@ private void visitFunctional(FunctionalMember member, FunctionHeader header, Str
true,
context.getMethodDescriptor(header),
modifiers | JavaModifiers.getJavaModifiers(member.getEffectiveModifiers()),
header.getReturnType() instanceof GenericTypeID,
context.isGenericReturn(header.getReturnType()),
header.useTypeParameters()),
signature);
}
Expand Down
Loading

0 comments on commit c9407d9

Please sign in to comment.