diff --git a/build.gradle b/build.gradle index ed8a5bec..c69fbadd 100644 --- a/build.gradle +++ b/build.gradle @@ -7,7 +7,7 @@ apply plugin: 'maven-publish' apply plugin: 'maven' group = 'org.squiddev' -version = '0.5.1-SNAPSHOT' +version = '0.6.0-SNAPSHOT' targetCompatibility = sourceCompatibility = 1.8 compileTestJava.sourceCompatibility = compileTestJava.targetCompatibility = 1.8 diff --git a/src/main/java/org/squiddev/cobalt/Lua.java b/src/main/java/org/squiddev/cobalt/Lua.java index 11e5158a..46ff8ab9 100644 --- a/src/main/java/org/squiddev/cobalt/Lua.java +++ b/src/main/java/org/squiddev/cobalt/Lua.java @@ -35,7 +35,7 @@ public class Lua { /** * Version is supplied by ant build task */ - public static final String _VERSION = "Luaj 0.0"; + public static final String _VERSION = "Lua 5.1"; /** * Use return values from previous op diff --git a/src/main/java/org/squiddev/cobalt/LuaDouble.java b/src/main/java/org/squiddev/cobalt/LuaDouble.java index fdacc260..86b10441 100644 --- a/src/main/java/org/squiddev/cobalt/LuaDouble.java +++ b/src/main/java/org/squiddev/cobalt/LuaDouble.java @@ -224,11 +224,6 @@ public double checkDouble() { return v; } - @Override - public String checkString() { - return toString(); - } - @Override public LuaString checkLuaString() { return strvalue(); @@ -239,9 +234,4 @@ public LuaValue checkValidKey() throws LuaError { if (Double.isNaN(v)) throw new LuaError("table index is NaN"); return this; } - - @Override - public double checkArith() { - return v; - } } diff --git a/src/main/java/org/squiddev/cobalt/LuaInteger.java b/src/main/java/org/squiddev/cobalt/LuaInteger.java index cf87f837..b6f18992 100644 --- a/src/main/java/org/squiddev/cobalt/LuaInteger.java +++ b/src/main/java/org/squiddev/cobalt/LuaInteger.java @@ -205,18 +205,8 @@ public double checkDouble() { return v; } - @Override - public String checkString() { - return String.valueOf(v); - } - @Override public LuaString checkLuaString() { return ValueFactory.valueOf(String.valueOf(v)); } - - @Override - public double checkArith() { - return v; - } } diff --git a/src/main/java/org/squiddev/cobalt/LuaNumber.java b/src/main/java/org/squiddev/cobalt/LuaNumber.java index 6d272f24..060deb04 100644 --- a/src/main/java/org/squiddev/cobalt/LuaNumber.java +++ b/src/main/java/org/squiddev/cobalt/LuaNumber.java @@ -41,37 +41,42 @@ public LuaNumber() { } @Override - public LuaNumber checkNumber() { + public final LuaNumber checkNumber() { return this; } @Override - public LuaNumber checkNumber(String errmsg) { + public final LuaNumber checkNumber(String errmsg) { return this; } @Override - public LuaNumber optNumber(LuaNumber defval) { + public final LuaNumber optNumber(LuaNumber defval) { return this; } @Override - public LuaValue toNumber() { + public final LuaValue toNumber() { return this; } @Override - public boolean isNumber() { + public final boolean isNumber() { return true; } @Override - public LuaTable getMetatable(LuaState state) { + public final LuaTable getMetatable(LuaState state) { return state.numberMetatable; } @Override - public double checkArith() { + public final double checkArith() { return toDouble(); } + + @Override + public final String checkString() { + return toString(); + } } diff --git a/src/main/java/org/squiddev/cobalt/LuaState.java b/src/main/java/org/squiddev/cobalt/LuaState.java index b5bcbe2a..65616362 100644 --- a/src/main/java/org/squiddev/cobalt/LuaState.java +++ b/src/main/java/org/squiddev/cobalt/LuaState.java @@ -29,9 +29,12 @@ import org.squiddev.cobalt.debug.DebugHandler; import org.squiddev.cobalt.lib.platform.FileResourceManipulator; import org.squiddev.cobalt.lib.platform.ResourceManipulator; +import org.squiddev.cobalt.persist.Serializer; import java.io.InputStream; import java.io.PrintStream; +import java.util.HashMap; +import java.util.Map; import java.util.Random; import java.util.TimeZone; import java.util.concurrent.Executor; @@ -134,6 +137,9 @@ public final class LuaState { */ boolean abandoned; + private final Map> serializers = new HashMap<>(); + private final Map permanent = new HashMap<>(); + public LuaState() { this(new LuaState.Builder()); } @@ -201,6 +207,30 @@ public void setupThread(LuaTable environment) { currentThread = thread; } + public void addSerializer(Serializer serializer) { + String name = serializer.getName(); + Serializer current = serializers.get(name); + if (serializer == current) return; + if (current != null) throw new IllegalArgumentException("Duplicate serializers for " + name); + serializers.put(name, serializer); + } + + public Serializer getSerializer(String name) { + return serializers.get(name); + } + + public void addPermanent(String key, Object value) { + Object current = permanent.get(key); + if (value == current) return; + // TODO: Do we want this + // if (current != null) throw new IllegalArgumentException("Duplicate object for " + key); + permanent.put(key, value); + } + + public Object getPermanent(String name) { + return permanent.get(name); + } + public static LuaState.Builder builder() { return new LuaState.Builder(); } diff --git a/src/main/java/org/squiddev/cobalt/LuaThread.java b/src/main/java/org/squiddev/cobalt/LuaThread.java index 7954f6cc..17f30b47 100644 --- a/src/main/java/org/squiddev/cobalt/LuaThread.java +++ b/src/main/java/org/squiddev/cobalt/LuaThread.java @@ -30,7 +30,10 @@ import org.squiddev.cobalt.function.LuaFunction; import org.squiddev.cobalt.lib.CoroutineLib; import org.squiddev.cobalt.lib.jse.JsePlatform; +import org.squiddev.cobalt.persist.ValueReader; +import org.squiddev.cobalt.persist.ValueWriter; +import java.io.IOException; import java.lang.ref.WeakReference; import java.util.Objects; import java.util.concurrent.TimeUnit; @@ -674,6 +677,22 @@ static class State { } } + public void writeInternalState(ValueWriter writer) throws IOException { + if (function != null && state.status == STATUS_RUNNING) { + throw new IOException("Cannot serialize running coroutine"); + } + if (state.needsThreadedResume) throw new IOException("Cannot serialize coroutine which needs threaded resume"); + if (state.javaCount != 0) throw new IOException("Cannot serialize coroutine with Java functions on stack"); + writer.writeVarInt(state.status); + writer.write(state.previousThread == null ? Constants.NIL : state.previousThread); + } + + public void readInternalState(ValueReader reader) throws IOException { + state.status = reader.readVarInt(); + LuaValue previous = (LuaValue) reader.read(); + state.previousThread = previous.isNil() ? null : (LuaThread) previous; + } + /** * Used inside {@link #loop(LuaState, LuaThread, LuaFunction, Varargs)} when * this particular thread has transferred control elsewhere. diff --git a/src/main/java/org/squiddev/cobalt/debug/DebugHandler.java b/src/main/java/org/squiddev/cobalt/debug/DebugHandler.java index 7c7cd268..3fdb4291 100644 --- a/src/main/java/org/squiddev/cobalt/debug/DebugHandler.java +++ b/src/main/java/org/squiddev/cobalt/debug/DebugHandler.java @@ -42,9 +42,6 @@ public static DebugState getDebugState(LuaState state) { return state.getCurrentThread().getDebugState(); } - protected DebugHandler() { - } - /** * Called by Closures and recurring Java functions on return * diff --git a/src/main/java/org/squiddev/cobalt/debug/DebugState.java b/src/main/java/org/squiddev/cobalt/debug/DebugState.java index 4743a03b..902806e2 100644 --- a/src/main/java/org/squiddev/cobalt/debug/DebugState.java +++ b/src/main/java/org/squiddev/cobalt/debug/DebugState.java @@ -24,10 +24,14 @@ */ package org.squiddev.cobalt.debug; -import org.squiddev.cobalt.LuaError; -import org.squiddev.cobalt.LuaState; -import org.squiddev.cobalt.UnwindThrowable; +import org.squiddev.cobalt.*; +import org.squiddev.cobalt.function.LuaClosure; import org.squiddev.cobalt.function.LuaFunction; +import org.squiddev.cobalt.function.Upvalue; +import org.squiddev.cobalt.persist.ValueReader; +import org.squiddev.cobalt.persist.ValueWriter; + +import java.io.IOException; import static org.squiddev.cobalt.debug.DebugFrame.*; @@ -104,9 +108,9 @@ public LuaState getLuaState() { /** * Push a new debug frame onto the stack, marking it as also consuming one or more Java stack frames. * - * @return The created info. This should be marked with {@link DebugFrame#FLAG_JAVA} or { + * @return The created info. This should be marked with {@link DebugFrame#FLAG_JAVA} or + * {@link DebugFrame#FLAG_FRESH} by the calling function. * @throws LuaError On a stack overflow - * @link DebugFrame#FLAG_FRESH} by the calling function. */ public DebugFrame pushJavaInfo() throws LuaError { int javaTop = this.javaCount + 1; @@ -272,4 +276,108 @@ void hookLine(DebugFrame frame, int newLine) throws LuaError, UnwindThrowable { inhook = false; frame.flags &= ~FLAG_HOOKED; } + + public int getTop() { + return top; + } + + private static final int HOOK_CALL = 1 << 0; + private static final int HOOK_RETURN = 1 << 1; + private static final int HOOK_LINE = 1 << 2; + + public void writeInternalState(ValueWriter writer) throws IOException { + if (hookfunc instanceof LuaValue) { + writer.write((LuaValue) hookfunc); + writer.writeByte((hookcall ? HOOK_CALL : 0) + | (hookrtrn ? HOOK_RETURN : 0) + | (hookline ? HOOK_LINE : 0)); + writer.writeVarInt(hookcount); + writer.writeVarInt(hookcodes); + } else { + writer.write(Constants.NIL); + } + + writer.writeVarInt(javaCount); + + // Call stack + writer.writeVarInt(top); + for (int i = 0; i <= top; i++) { + DebugFrame frame = stack[i]; + assert frame != null; + + writer.write(frame.func); + writer.writeVarInt(frame.flags); + + writer.writeByte(frame.closure == null ? 0 : 1); + if (frame.closure != null) { + writer.write(frame.stack); + if (frame.stackUpvalues != null) { + for (int j = 0; j < frame.stack.length; j++) { + Upvalue u = frame.stackUpvalues[j]; + if (u == null) continue; + writer.writeVarInt(j); + writer.write(u); + } + } + writer.writeVarInt(-1); + + writer.write(frame.varargs); + writer.write(frame.extras); + writer.writeVarInt(frame.pc); + writer.writeVarInt(frame.oldPc); + writer.writeVarInt(frame.top); + } else if (frame.state == null) { + writer.write(Constants.NIL); + } else { + writer.serialize(frame.state); + } + } + } + + public void readInternalState(ValueReader reader) throws IOException { + // Debug functions + LuaValue hook = (LuaValue) reader.read(); + if (!hook.isNil()) { + hookfunc = (LuaFunction) hook; + int state = reader.readByte(); + hookcall = (state & HOOK_CALL) != 0; + hookrtrn = (state & HOOK_RETURN) != 0; + hookline = (state & HOOK_LINE) != 0; + + hookcount = reader.readVarInt(); + hookcodes = reader.readVarInt(); + } + + javaCount = reader.readVarInt(); + + // Call stack + top = reader.readVarInt(); + stack = new DebugFrame[top + 1]; + for (int i = 0; i <= top; i++) { + DebugFrame frame = stack[i] = new DebugFrame(i > 0 ? stack[i - 1] : null); + + frame.func = (LuaFunction) reader.read(); + frame.flags = reader.readVarInt(); + if (reader.readByte() != 0) { + frame.closure = (LuaClosure) frame.func; + frame.stack = (LuaValue[]) reader.read(); + + frame.stackUpvalues = new Upvalue[frame.stack.length]; + while (true) { + int j = reader.readVarInt(); + if (j == -1) break; + + frame.stackUpvalues[j] = (Upvalue) reader.read(); + } + + frame.varargs = reader.readVarargs(); + frame.extras = reader.readVarargs(); + frame.pc = reader.readVarInt(); + frame.oldPc = reader.readVarInt(); + frame.top = reader.readVarInt(); + } else { + frame.state = reader.read(); + } + } + } } diff --git a/src/main/java/org/squiddev/cobalt/function/LibFunction.java b/src/main/java/org/squiddev/cobalt/function/LibFunction.java index ce1d64e0..d1fe0fda 100644 --- a/src/main/java/org/squiddev/cobalt/function/LibFunction.java +++ b/src/main/java/org/squiddev/cobalt/function/LibFunction.java @@ -29,7 +29,12 @@ import org.squiddev.cobalt.LuaValue; import org.squiddev.cobalt.lib.BaseLib; import org.squiddev.cobalt.lib.TableLib; +import org.squiddev.cobalt.persist.Serializable; +import org.squiddev.cobalt.persist.Serializer; +import org.squiddev.cobalt.persist.ValueReader; +import org.squiddev.cobalt.persist.ValueWriter; +import java.io.IOException; import java.util.function.Supplier; /** @@ -91,7 +96,7 @@ * } * } * } - * + * } * The default constructor is used to instantiate the library * in response to {@code require 'hyperbolic'} statement, * provided it is on Javas class path. @@ -129,12 +134,11 @@ * See the source code in any of the library functions * such as {@link BaseLib} or {@link TableLib} for other examples. */ -public abstract class LibFunction extends LuaFunction { - +public abstract class LibFunction extends LuaFunction implements Serializable { /** * User-defined opcode to differentiate between instances of the library function class. * - * Subclass will typicall switch on this value to provide the specific behavior for each function. + * Subclass will typically switch on this value to provide the specific behavior for each function. */ protected int opcode; @@ -145,6 +149,11 @@ public abstract class LibFunction extends LuaFunction { */ protected String name; + /** + * The library this function belongs to. + */ + private String library; + /** * Default constructor for use by subclasses */ @@ -162,13 +171,15 @@ public String debugName() { * An array of names is provided, and the first name is bound * with opcode = 0, second with 1, etc. * + * @param state The active Lua state. + * @param name The name of the set of the library one is binding. * @param env The environment to apply to each bound function * @param factory The factory to provide a new instance each time * @param names Array of function names - * @see #bind(LuaTable, Supplier, String[], int) + * @see #bind(LuaState, String, LuaTable, Supplier, String[], int) */ - public static void bind(LuaTable env, Supplier factory, String[] names) { - bind(env, factory, names, 0); + public static void bind(LuaState state, String name, LuaTable env, Supplier factory, String[] names) { + bind(state, name, env, factory, names, 0); } /** @@ -177,19 +188,61 @@ public static void bind(LuaTable env, Supplier factory, String[] na * An array of names is provided, and the first name is bound * with opcode = {@code firstopcode}, second with {@code firstopcode+1}, etc. * + * @param state The active Lua state. + * @param name The name of the set of the library one is binding. * @param env The environment to apply to each bound function * @param factory The factory to provide a new instance each time * @param names Array of function names * @param firstOpcode The first opcode to use - * @see #bind(LuaTable, Supplier, String[]) + * @see #bind(LuaState, String, LuaTable, Supplier, String[]) */ - public static void bind(LuaTable env, Supplier factory, String[] names, int firstOpcode) { + public static void bind(LuaState state, String name, LuaTable env, Supplier factory, String[] names, int firstOpcode) { + state.addSerializer(SerializerImpl.INSTANCE); + for (int i = 0; i < names.length; i++) { LibFunction f = factory.get(); + f.setName(state, name, names[i]); f.opcode = firstOpcode + i; - f.name = names[i]; f.env = env; env.rawset(f.name, f); } } + + protected void setName(LuaState state, String library, String name) { + this.name = name; + this.library = library; + if (library != null) state.addPermanent(library + "." + name, this); + } + + @Override + public Serializer getSerializer() { + return name != null && library != null ? SerializerImpl.INSTANCE : null; + } + + private static class SerializerImpl implements Serializer { + static final Serializer INSTANCE = new SerializerImpl(); + + private SerializerImpl() { + } + + @Override + public String getName() { + return "cobalt.permanent"; + } + + @Override + public void save(ValueWriter writer, LibFunction value) throws IOException { + writer.write(value.library); + writer.write(value.name); + } + + @Override + public LibFunction load(ValueReader reader) throws IOException { + String key = reader.readString() + "." + reader.readString(); + Object value = reader.getState().getPermanent(key); + if (value == null) throw new IOException("Cannot find " + key); + if (!(value instanceof LibFunction)) throw new IOException("Malformed value for " + key); + return (LibFunction) value; + } + } } diff --git a/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java b/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java index 7d3ccb44..0d12f3ab 100644 --- a/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java +++ b/src/main/java/org/squiddev/cobalt/function/LuaInterpretedFunction.java @@ -92,14 +92,18 @@ * @see LoadState */ public final class LuaInterpretedFunction extends LuaClosure implements Resumable { - private static final Upvalue[] NO_UPVALUES = new Upvalue[0]; + public static final Upvalue[] NO_UPVALUES = new Upvalue[0]; public final Prototype p; public final Upvalue[] upvalues; - public LuaInterpretedFunction(Prototype p) { + public LuaInterpretedFunction(Prototype p, Upvalue[] upvalues) { this.p = p; - this.upvalues = p.nups > 0 ? new Upvalue[p.nups] : NO_UPVALUES; + this.upvalues = upvalues; + } + + public LuaInterpretedFunction(Prototype p) { + this(p, p.nups > 0 ? new Upvalue[p.nups] : NO_UPVALUES); } /** diff --git a/src/main/java/org/squiddev/cobalt/function/Upvalue.java b/src/main/java/org/squiddev/cobalt/function/Upvalue.java index d072c6a9..532ef854 100644 --- a/src/main/java/org/squiddev/cobalt/function/Upvalue.java +++ b/src/main/java/org/squiddev/cobalt/function/Upvalue.java @@ -88,4 +88,12 @@ public void close() { array = new LuaValue[]{array[index]}; index = 0; } + + public LuaValue[] getArray() { + return array; + } + + public int getIndex() { + return index; + } } diff --git a/src/main/java/org/squiddev/cobalt/lib/BaseLib.java b/src/main/java/org/squiddev/cobalt/lib/BaseLib.java index 937e6f85..0a9774b0 100644 --- a/src/main/java/org/squiddev/cobalt/lib/BaseLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/BaseLib.java @@ -32,7 +32,12 @@ import org.squiddev.cobalt.function.*; import org.squiddev.cobalt.lib.jse.JsePlatform; import org.squiddev.cobalt.lib.platform.ResourceManipulator; +import org.squiddev.cobalt.persist.Serializable; +import org.squiddev.cobalt.persist.Serializer; +import org.squiddev.cobalt.persist.ValueReader; +import org.squiddev.cobalt.persist.ValueWriter; +import java.io.IOException; import java.io.InputStream; import static org.squiddev.cobalt.OperationHelper.noUnwind; @@ -102,16 +107,16 @@ public class BaseLib implements LuaLibrary { public LuaValue add(LuaState state, LuaTable env) { env.rawset("_G", env); env.rawset("_VERSION", valueOf(Lua._VERSION)); - LibFunction.bind(env, BaseLib2::new, LIB2_KEYS); - LibFunction.bind(env, () -> new BaseLibV(this), LIBV_KEYS); - LibFunction.bind(env, BaseLibR::new, LIBR_KEYS); + + state.addSerializer(PCALL_SERIALIZER); + LibFunction.bind(state, "_G", env, BaseLib2::new, LIB2_KEYS); + LibFunction.bind(state, "_G", env, () -> new BaseLibV(this), LIBV_KEYS); + LibFunction.bind(state, "_G", env, BaseLibR::new, LIBR_KEYS); // remember next, and inext for use in pairs and ipairs next = env.rawget("next"); inext = env.rawget("__inext"); - env.rawset("_VERSION", valueOf("Lua 5.1")); - return env; } @@ -298,7 +303,7 @@ public Varargs invoke(LuaState state, Varargs args) throws LuaError, UnwindThrow case 16: { // "pairs" (t) -> iter-func, t, nil LuaValue value = args.checkValue(1); LuaValue pairs = value.metatag(state, Constants.PAIRS); - if(pairs.isNil()) { + if (pairs.isNil()) { return varargsOf(baselib.next, value, Constants.NIL); } else { return OperationHelper.invoke(state, pairs, value); @@ -333,7 +338,7 @@ protected Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws Lua case 0: // "pcall", // (f, arg1, ...) -> status, result1, ... return pcall(state, di, args.checkValue(1), args.subargs(2), null); case 1: // "xpcall", // (f, err) -> result1, ... - return pcall(state, di, args.checkValue(1), Constants.NONE, args.checkValue(2)); + return pcall(state, di, args.checkValue(1), Constants.NONE, args.checkNotNil(2)); case 2: // "load", // ( func|str [,chunkname[, mode[, env]]] ) -> chunk | nil, msg { @@ -416,12 +421,45 @@ private Varargs finish(PCallState pState, Varargs value) { } } - private static final class PCallState { - DebugFrame frame; + private static final class PCallState implements Serializable { + int frame; LuaValue oldErrorFunc; boolean errored = false; + + @Override + public Serializer getSerializer() { + return PCALL_SERIALIZER; + } } + private static final Serializer PCALL_SERIALIZER = new Serializer() { + @Override + public String getName() { + return "cobalt.base$pcall"; + } + + @Override + public void save(ValueWriter writer, PCallState value) throws IOException { + writer.writeByte( + (value.oldErrorFunc != null ? 1 : 0) + | (value.errored ? 2 : 0) + ); + if (value.oldErrorFunc != null) writer.write(value.oldErrorFunc); + writer.writeVarInt(value.frame); + } + + @Override + public PCallState load(ValueReader reader) throws IOException { + int flags = reader.readByte(); + + PCallState state = new PCallState(); + state.errored = (flags & 2) != 0; + state.oldErrorFunc = (flags & 1) != 0 ? (LuaValue)reader.read() : null; + state.frame = reader.readVarInt(); + return state; + } + }; + private static Varargs pcall(LuaState state, DebugFrame di, LuaValue func, Varargs args, LuaValue errFunc) throws UnwindThrowable { // Mark this frame as being an error handler PCallState pState = new PCallState(); @@ -429,7 +467,8 @@ private static Varargs pcall(LuaState state, DebugFrame di, LuaValue func, Varar di.flags |= FLAG_YPCALL; // Store this frame in the current state. - pState.frame = di; + int top = DebugHandler.getDebugState(state).getTop(); + pState.frame = top; LuaValue oldErr = pState.oldErrorFunc = state.getCurrentThread().setErrorFunc(errFunc); try { @@ -447,18 +486,17 @@ private static Varargs pcall(LuaState state, DebugFrame di, LuaValue func, Varar le.fillTraceback(state); state.getCurrentThread().setErrorFunc(oldErr); - closeUntil(state, di); + closeUntil(state, top); return varargsOf(Constants.FALSE, le.value); } } - private static void closeUntil(LuaState state, DebugFrame top) { + private static void closeUntil(LuaState state, int top) { DebugState ds = DebugHandler.getDebugState(state); DebugHandler handler = state.debug; - DebugFrame current; - while ((current = ds.getStackUnsafe()) != top) { - current.cleanup(); + while (ds.getTop() != top) { + ds.getStackUnsafe().cleanup(); handler.onReturnError(ds); } } diff --git a/src/main/java/org/squiddev/cobalt/lib/Bit32Lib.java b/src/main/java/org/squiddev/cobalt/lib/Bit32Lib.java index 305c9200..3a568100 100644 --- a/src/main/java/org/squiddev/cobalt/lib/Bit32Lib.java +++ b/src/main/java/org/squiddev/cobalt/lib/Bit32Lib.java @@ -41,8 +41,8 @@ public class Bit32Lib implements LuaLibrary { @Override public LuaValue add(LuaState state, LuaTable env) { LuaTable t = new LuaTable(); - bind(t, Bit32LibV::new, new String[]{"band", "bnot", "bor", "btest", "bxor", "extract", "replace"}); - bind(t, Bit32Lib2::new, new String[]{"arshift", "lrotate", "lshift", "rrotate", "rshift"}); + bind(state, "bit32", t, Bit32LibV::new, new String[]{"band", "bnot", "bor", "btest", "bxor", "extract", "replace"}); + bind(state, "bit32", t, Bit32Lib2::new, new String[]{"arshift", "lrotate", "lshift", "rrotate", "rshift"}); env.rawset("bit32", t); state.loadedPackages.rawset("bit32", t); return t; diff --git a/src/main/java/org/squiddev/cobalt/lib/CoroutineLib.java b/src/main/java/org/squiddev/cobalt/lib/CoroutineLib.java index 366830b3..2b60fb9d 100644 --- a/src/main/java/org/squiddev/cobalt/lib/CoroutineLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/CoroutineLib.java @@ -73,7 +73,7 @@ private CoroutineLib(LuaThread thread) { @Override public LuaValue add(LuaState state, LuaTable env) { LuaTable t = new LuaTable(); - bind(t, CoroutineLib::new, new String[]{"create", "resume", "running", "status", "yield", "wrap"}); + bind(state, "coroutine", t, CoroutineLib::new, new String[]{"create", "resume", "running", "status", "yield", "wrap"}); env.rawset("coroutine", t); state.loadedPackages.rawset("coroutine", t); return t; diff --git a/src/main/java/org/squiddev/cobalt/lib/DebugLib.java b/src/main/java/org/squiddev/cobalt/lib/DebugLib.java index 919e14d1..b9838cb4 100644 --- a/src/main/java/org/squiddev/cobalt/lib/DebugLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/DebugLib.java @@ -93,7 +93,7 @@ public class DebugLib extends VarArgFunction implements LuaLibrary { @Override public LuaTable add(LuaState state, LuaTable env) { LuaTable t = new LuaTable(); - bind(t, DebugLib::new, NAMES); + bind(state, "debug", t, DebugLib::new, NAMES); env.rawset("debug", t); state.loadedPackages.rawset("debug", t); return t; diff --git a/src/main/java/org/squiddev/cobalt/lib/IoLib.java b/src/main/java/org/squiddev/cobalt/lib/IoLib.java index f21f7e11..613aac74 100644 --- a/src/main/java/org/squiddev/cobalt/lib/IoLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/IoLib.java @@ -224,7 +224,7 @@ public LuaValue add(LuaState state, LuaTable env) { // io lib functions LuaTable t = new LuaTable(); - LibFunction.bind(t, () -> new IoLibV(this), IO_NAMES); + LibFunction.bind(state, "io", t, () -> new IoLibV(this), IO_NAMES); // setup streams try { @@ -237,7 +237,7 @@ public LuaValue add(LuaState state, LuaTable env) { // create file methods table filemethods = new LuaTable(); - LibFunction.bind(filemethods, () -> new IoLibV(this), FILE_NAMES, FILE_CLOSE); + LibFunction.bind(state, "io$file", filemethods, () -> new IoLibV(this), FILE_NAMES, FILE_CLOSE); // setup library and index filemethods.rawset("__index", filemethods); diff --git a/src/main/java/org/squiddev/cobalt/lib/MathLib.java b/src/main/java/org/squiddev/cobalt/lib/MathLib.java index 134b05ae..26ebd9fc 100644 --- a/src/main/java/org/squiddev/cobalt/lib/MathLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/MathLib.java @@ -52,20 +52,16 @@ public LuaValue add(LuaState state, LuaTable env) { LuaTable t = new LuaTable(0, 30); t.rawset("pi", ValueFactory.valueOf(Math.PI)); t.rawset("huge", LuaDouble.POSINF); - LibFunction.bind(t, MathLib1::new, new String[]{ - "abs", "ceil", "cos", "deg", - "exp", "floor", "rad", "sin", - "sqrt", "tan", - "acos", "asin", "atan", "cosh", - "exp", "log10", "sinh", - "tanh" + LibFunction.bind(state, "math", t, MathLib1::new, new String[]{ + "abs", "ceil", "cos", "deg", "exp", "floor", "rad", "sin", + "sqrt", "tan", "acos", "asin", "atan", "cosh", "log10", "sinh", "tanh" }); - LibFunction.bind(t, MathLib2::new, new String[]{ + LibFunction.bind(state, "math", t, MathLib2::new, new String[]{ "fmod", "ldexp", "pow", "atan2", "log" }); - LibFunction.bind(t, MathLibV::new, new String[]{ - "frexp", "max", "min", "modf", - "randomseed", "random",}); + LibFunction.bind(state, "math", t, MathLibV::new, new String[]{ + "frexp", "max", "min", "modf", "randomseed", "random", + }); t.rawset("mod", t.rawget("fmod")); env.rawset("math", t); @@ -106,12 +102,10 @@ public LuaValue call(LuaState state, LuaValue arg) throws LuaError { case 13: return ValueFactory.valueOf(Math.cosh(arg.checkDouble())); case 14: - return ValueFactory.valueOf(Math.exp(arg.checkDouble())); - case 15: return ValueFactory.valueOf(Math.log10(arg.checkDouble())); - case 16: + case 15: return ValueFactory.valueOf(Math.sinh(arg.checkDouble())); - case 17: + case 16: return ValueFactory.valueOf(Math.tanh(arg.checkDouble())); } return Constants.NIL; diff --git a/src/main/java/org/squiddev/cobalt/lib/OsLib.java b/src/main/java/org/squiddev/cobalt/lib/OsLib.java index f8a3faf4..60d18d6b 100644 --- a/src/main/java/org/squiddev/cobalt/lib/OsLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/OsLib.java @@ -86,7 +86,7 @@ public class OsLib extends VarArgFunction implements LuaLibrary { @Override public LuaValue add(LuaState state, LuaTable env) { LuaTable t = new LuaTable(); - LibFunction.bind(t, OsLib::new, NAMES); + LibFunction.bind(state, "os", t, OsLib::new, NAMES); env.rawset("os", t); state.loadedPackages.rawset("os", t); return t; diff --git a/src/main/java/org/squiddev/cobalt/lib/StringLib.java b/src/main/java/org/squiddev/cobalt/lib/StringLib.java index d7898879..144adfe2 100644 --- a/src/main/java/org/squiddev/cobalt/lib/StringLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/StringLib.java @@ -55,13 +55,13 @@ public class StringLib implements LuaLibrary { @Override public LuaValue add(LuaState state, LuaTable env) { LuaTable t = new LuaTable(); - LibFunction.bind(t, StringLib1::new, new String[]{ + LibFunction.bind(state, "string", t, StringLib1::new, new String[]{ "len", "lower", "reverse", "upper", "packsize" }); - LibFunction.bind(t, StringLibV::new, new String[]{ + LibFunction.bind(state, "string", t, StringLibV::new, new String[]{ "dump", "byte", "char", "find", "gmatch", "match", "rep", "sub", "pack", "unpack" }); - LibFunction.bind(t, StringLibR::new, new String[]{"gsub", "format"}); + LibFunction.bind(state, "string", t, StringLibR::new, new String[]{"gsub", "format"}); t.rawset("gfind", t.rawget("gmatch")); env.rawset("string", t); diff --git a/src/main/java/org/squiddev/cobalt/lib/TableLib.java b/src/main/java/org/squiddev/cobalt/lib/TableLib.java index 9608f173..c32df685 100644 --- a/src/main/java/org/squiddev/cobalt/lib/TableLib.java +++ b/src/main/java/org/squiddev/cobalt/lib/TableLib.java @@ -49,9 +49,9 @@ public class TableLib implements LuaLibrary { @Override public LuaTable add(LuaState state, LuaTable env) { LuaTable t = new LuaTable(); - LibFunction.bind(t, TableLib1::new, new String[]{"getn", "maxn",}); - LibFunction.bind(t, TableLibV::new, new String[]{"remove", "concat", "insert", "pack"}); - LibFunction.bind(t, TableLibR::new, new String[]{"sort", "foreach", "foreachi", "unpack"}); + LibFunction.bind(state, "table", t, TableLib1::new, new String[]{"getn", "maxn",}); + LibFunction.bind(state, "table", t, TableLibV::new, new String[]{"remove", "concat", "insert", "pack"}); + LibFunction.bind(state, "table", t, TableLibR::new, new String[]{"sort", "foreach", "foreachi", "unpack"}); env.rawset("table", t); state.loadedPackages.rawset("table", t); return t; diff --git a/src/main/java/org/squiddev/cobalt/lib/Utf8Lib.java b/src/main/java/org/squiddev/cobalt/lib/Utf8Lib.java index 872a2c1c..dc284ad8 100644 --- a/src/main/java/org/squiddev/cobalt/lib/Utf8Lib.java +++ b/src/main/java/org/squiddev/cobalt/lib/Utf8Lib.java @@ -31,12 +31,11 @@ public class Utf8Lib implements LuaLibrary { public LuaValue add(LuaState state, LuaTable environment) { LuaTable t = new LuaTable(0, 6); t.rawset("charpattern", PATTERN); - LibFunction.bind(t, Utf8Char::new, new String[]{"char", "codes", "codepoint", "len", "offset"}); + LibFunction.bind(state, "utf8", t, Utf8Char::new, new String[]{"char", "codes", "codepoint", "len", "offset"}); environment.rawset("utf8", t); state.loadedPackages.rawset("utf8", t); - codesIter = new Utf8CodesIter(); - codesIter.setfenv(environment); + codesIter = new Utf8CodesIter(state, environment); return t; } @@ -207,6 +206,11 @@ private static boolean isCont(LuaString s, int idx) { * invariant state in the hopes that this is the tiniest bit faster. */ private static class Utf8CodesIter extends VarArgFunction { + public Utf8CodesIter(LuaState state, LuaTable env) { + setName(state, "utf8", "$iter"); + setfenv(env); + } + @Override public Varargs invoke(LuaState state, Varargs args) throws LuaError, UnwindThrowable { // Arg 1: invariant state (the string) diff --git a/src/main/java/org/squiddev/cobalt/lib/jse/JsePlatform.java b/src/main/java/org/squiddev/cobalt/lib/jse/JsePlatform.java index c7c97f24..d027238e 100644 --- a/src/main/java/org/squiddev/cobalt/lib/jse/JsePlatform.java +++ b/src/main/java/org/squiddev/cobalt/lib/jse/JsePlatform.java @@ -27,6 +27,7 @@ import org.squiddev.cobalt.LuaState; import org.squiddev.cobalt.LuaTable; import org.squiddev.cobalt.compiler.LuaC; +import org.squiddev.cobalt.function.LibFunction; import org.squiddev.cobalt.lib.*; import org.squiddev.cobalt.lib.platform.ResourceManipulator; @@ -91,7 +92,6 @@ public static LuaTable standardGlobals(LuaState state) { _G.load(state, new MathLib()); _G.load(state, new JseIoLib()); _G.load(state, new OsLib()); - _G.load(state, new Utf8Lib()); return _G; } diff --git a/src/main/java/org/squiddev/cobalt/persist/Persist.java b/src/main/java/org/squiddev/cobalt/persist/Persist.java new file mode 100644 index 00000000..8847cab1 --- /dev/null +++ b/src/main/java/org/squiddev/cobalt/persist/Persist.java @@ -0,0 +1,89 @@ +package org.squiddev.cobalt.persist; + +import org.squiddev.cobalt.LuaState; +import org.squiddev.cobalt.LuaValue; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +/** + * Constants for persistence. + * + * Saving and loading are handled by {@link ValueWriter} and {@link ValueWriter} respectively. The format is relatively + * simple, but worth explaining below. + * + * Say we wish to write value {@code x}. We may take on of several actions: + *
    + *
  • + * If this is a simple value (a number, boolean, nil, etc...) we write the value's type and then the value's + * representation. + *
  • + *
  • + * If this is a reference value, and we've seen it before, we write {@link Persist#TAG_REFERENCE} and the + * value's ID. + *
  • + *
  • + * If we've not seen the value before, we add it to the object->id table. The id is just an incrementing id, and + * so does not need to be written when writing the value. + * + * We now come to write the value's tag. When writing a value we keep track of its depth. Once we're deeper than + * {@link ValueWriter#MAX_DEPTH}, we also write {@link Persist#FLAG_PARTIAL} as part of the tag, add it to a + * queue and abort. Otherwise we recursively visit this value and write its contents (such as a table's fields, + * etc...). + *
  • + *
+ * + * Once all values have been written, we loop through the queue. Each enqueued value's type is written (ored with + * {@link Persist#FLAG_PARTIAL}) and then the body written. + * + * Reading follows much the same principle (read the main body, and then populate later on). Care must be taken that + * values are added to the id->object table before their body is read. + */ +public final class Persist { + private Persist() { + } + + static final int FLAG_PARTIAL = 1 << 7; + static final int FLAG_POPULATE = 1 << 6; + static final int MASK_TAG = ~(FLAG_PARTIAL | FLAG_POPULATE); + + static final int TAG_NIL = 0; + static final int TAG_TRUE = 1; + static final int TAG_FALSE = 2; + + static final int TAG_SHORT_INT = 3; + static final int TAG_INT = 4; + static final int TAG_FLOAT = 5; + + static final int TAG_SHORT_STRING = 6; + static final int TAG_STRING = 7; + + static final int TAG_TABLE = 8; + + static final int TAG_CLOSURE = 9; + static final int TAG_PROTOTYPE = 10; + static final int TAG_UPVALUE = 11; + static final int TAG_STACK = 12; + static final int TAG_THREAD = 13; + + static final int TAG_SERIALIZER = 14; + static final int TAG_SERIALIZED = 15; + + static final int TAG_REFERENCE = 20; + + static final int IN_HOOK = 1 << 0; + static final int HAS_ERRORFUNC = 1 << 1; + + public static void persist(LuaState state, DataOutput stream, LuaValue value) throws IOException { + try (ValueWriter writer = new ValueWriter(state, stream)) { + writer.write(value); + } + } + + public static LuaValue unpersist(LuaState state, DataInput stream) throws IOException { + try (ValueReader reader = new ValueReader(state, stream)) { + return (LuaValue) reader.read(); + } + } +} diff --git a/src/main/java/org/squiddev/cobalt/persist/Serializable.java b/src/main/java/org/squiddev/cobalt/persist/Serializable.java new file mode 100644 index 00000000..e2ab7cc5 --- /dev/null +++ b/src/main/java/org/squiddev/cobalt/persist/Serializable.java @@ -0,0 +1,9 @@ +package org.squiddev.cobalt.persist; + +public interface Serializable> { + Serializer getSerializer(); + + static Serializer getSerializer(Object object) { + return object instanceof Serializable ? ((Serializable) object).getSerializer() : null; + } +} diff --git a/src/main/java/org/squiddev/cobalt/persist/Serializer.java b/src/main/java/org/squiddev/cobalt/persist/Serializer.java new file mode 100644 index 00000000..67e18602 --- /dev/null +++ b/src/main/java/org/squiddev/cobalt/persist/Serializer.java @@ -0,0 +1,11 @@ +package org.squiddev.cobalt.persist; + +import java.io.IOException; + +public interface Serializer> { + String getName(); + + void save(ValueWriter writer, T value) throws IOException; + + T load(ValueReader reader) throws IOException; +} diff --git a/src/main/java/org/squiddev/cobalt/persist/ValueReader.java b/src/main/java/org/squiddev/cobalt/persist/ValueReader.java new file mode 100644 index 00000000..4616ee9c --- /dev/null +++ b/src/main/java/org/squiddev/cobalt/persist/ValueReader.java @@ -0,0 +1,320 @@ +package org.squiddev.cobalt.persist; + +import org.squiddev.cobalt.*; +import org.squiddev.cobalt.debug.DebugState; +import org.squiddev.cobalt.function.LocalVariable; +import org.squiddev.cobalt.function.LuaClosure; +import org.squiddev.cobalt.function.LuaInterpretedFunction; +import org.squiddev.cobalt.function.Upvalue; + +import java.io.Closeable; +import java.io.DataInput; +import java.io.IOException; +import java.util.Arrays; +import java.util.function.IntFunction; + +import static org.squiddev.cobalt.Constants.NIL; +import static org.squiddev.cobalt.Constants.NONE; +import static org.squiddev.cobalt.persist.Persist.*; + +public class ValueReader implements Closeable { + private final LuaState state; + private final DataInput input; + private Object[] output = new Object[128]; + int id = 0; + int partial = 0; + + public ValueReader(LuaState state, DataInput input) { + this.state = state; + this.input = input; + } + + public LuaState getState() { + return state; + } + + public Object read() throws IOException { + int tag = readByte(); + int type = tag & MASK_TAG; + boolean partial = (tag & FLAG_PARTIAL) != 0; + + // If we expect a full object, we can't have a "to populate" instruction. + if ((tag & FLAG_POPULATE) != 0) throw new IllegalStateException("Unexpected FLAG_POPULATE"); + if (partial) this.partial++; + + switch (type) { + case TAG_NIL: return NIL; + case TAG_TRUE: return Constants.TRUE; + case TAG_FALSE: return Constants.FALSE; + case TAG_SHORT_INT: return LuaInteger.valueOf(readShort()); + case TAG_INT: return LuaInteger.valueOf(readInt()); + case TAG_FLOAT: return LuaDouble.valueOf(readDouble()); + + case TAG_SHORT_STRING: { + int length = readByte() & 0xFF; + return LuaString.valueOf(read(length)); + } + + case TAG_STRING: { + int id = this.id++; + int length = readVarInt() & 0xFF; + return set(id, LuaString.valueOf(read(length))); + } + + case TAG_TABLE: { + int id = this.id++; + + LuaTable table = new LuaTable(); + set(id, table); + if (!partial) readTableBody(table); + return table; + } + + case TAG_CLOSURE: { + int id = this.id++; + Prototype prototype = (Prototype) read(); + int nups = readVarInt(); + Upvalue[] upvalues = nups > 0 ? new Upvalue[nups] : LuaInterpretedFunction.NO_UPVALUES; + LuaClosure closure = new LuaInterpretedFunction(prototype, upvalues); + set(id, closure); + + for (int i = 0; i < upvalues.length; i++) upvalues[i] = (Upvalue) read(); + + closure.setfenv((LuaTable) read()); + return closure; + } + case TAG_PROTOTYPE: { + int id = this.id++; + Prototype prototype = new Prototype(); + set(id, prototype); + if (!partial) readPrototypeBody(prototype); + return prototype; + } + + case TAG_UPVALUE: { + int id = this.id++; + LuaValue[] stack = (LuaValue[]) read(); + int index = readVarInt(); + return set(id, new Upvalue(stack, index)); + } + + case TAG_STACK: { + int id = this.id++; + int length = readVarInt(); + + LuaValue[] stack = new LuaValue[length]; + set(id, stack); + if (!partial) { + for (int i = 0; i < length; i++) stack[i] = (LuaValue) read(); + } + return stack; + } + + case TAG_THREAD: { + int id = this.id++; + + LuaThread coroutine = new LuaThread(state, null); + set(id, coroutine); + if (!partial) readThreadBody(coroutine); + return coroutine; + } + + case TAG_SERIALIZER: { + int id = this.id++; + String name = readString(); + + Serializer serializer = state.getSerializer(name); + if (serializer == null) throw new IOException("No such serializer " + name); + return set(id, serializer); + } + + case TAG_SERIALIZED: { + int id = this.id++; + + @SuppressWarnings("rawtypes") + Serializer serializer = (Serializer) read(); + return set(id, serializer.load(this)); + } + + case TAG_REFERENCE: + return this.output[readVarInt()]; + + default: + throw new IOException("Malformed input: unknown tag " + tag); + } + } + + private void readTableBody(LuaTable table) throws IOException { + while (true) { + LuaValue key = (LuaValue) read(); + if (key.isNil()) break; + + LuaValue value = (LuaValue) read(); + table.rawset(key, value); + } + + LuaValue metatable = (LuaValue) read(); + if (!metatable.isNil()) table.setMetatable((LuaTable) metatable); + } + + private void readThreadBody(LuaThread coroutine) throws IOException { + coroutine.readInternalState(this); + coroutine.setfenv((LuaTable) read()); + + DebugState debug = coroutine.getDebugState(); + + int flags = readByte(); + debug.inhook = (flags & IN_HOOK) != 0; + + if ((flags & HAS_ERRORFUNC) != 0) coroutine.setErrorFunc((LuaValue) read()); + debug.readInternalState(this); + } + + public Varargs readVarargs() throws IOException { + int count = readVarInt(); + switch (count) { + case 0: return NONE; + case 1: return (LuaValue) read(); + default: { + LuaValue[] values = new LuaValue[count]; + for (int i = 0; i < count; i++) values[i] = (LuaValue) read(); + return ValueFactory.varargsOf(values); + } + } + } + + private void readPrototypeBody(Prototype prototype) throws IOException { + prototype.source = (LuaString) read(); + prototype.linedefined = readVarInt(); + prototype.lastlinedefined = readVarInt(); + prototype.nups = readVarInt(); + + int flags = readVarInt(); + prototype.is_vararg = flags & 3; + prototype.numparams = flags >> 2; + prototype.maxstacksize = readVarInt(); + + prototype.k = readArrayUnsafe(LuaValue[]::new); + prototype.code = readIntArray(); + prototype.p = readArrayUnsafe(Prototype[]::new); + prototype.lineinfo = readIntArray(); + prototype.locvars = readArray(LocalVariable[]::new, this::readLocalVariable); + prototype.upvalues = readArrayUnsafe(LuaString[]::new); + } + + private LocalVariable readLocalVariable() throws IOException { + LuaString name = (LuaString) read(); + int startPc = readVarInt(); + int endPc = readVarInt(); + return new LocalVariable(name, startPc, endPc); + } + + public void close() throws IOException { + while (partial-- > 0) populateOne(); + } + + private void populateOne() throws IOException { + int tag = readByte(); + int type = tag & MASK_TAG; + + if ((tag & FLAG_PARTIAL) != 0) throw new IllegalStateException("Unexpected FLAG_PARTIAL"); + if ((tag & FLAG_POPULATE) == 0) throw new IllegalStateException("Expected FLAG_PARTIAL"); + + int id = readVarInt(); + Object object = output[id]; + switch (type) { + case TAG_TABLE: + readTableBody((LuaTable) object); + break; + case TAG_THREAD: + readThreadBody((LuaThread) object); + break; + case TAG_PROTOTYPE: + readPrototypeBody((Prototype) object); + break; + case TAG_STACK: { + LuaValue[] stack = (LuaValue[]) object; + for (int i = 0; i < stack.length; i++) stack[i] = (LuaValue) read(); + break; + } + default: + throw new IllegalStateException("Do not know how to reconstruct " + object + " (tag=" + type + ")"); + } + } + + private T set(int id, T value) { + if (id >= output.length) output = Arrays.copyOf(output, Math.max(id, output.length * 2)); + if (output[id] != null) throw new IllegalStateException("Duplicate keys for " + id); + output[id] = value; + return value; + } + + // region Primitives + public final int readVarInt() throws IOException { + int result = 0; + for (int j = 0; j < 5; j++) { + byte b = input.readByte(); + result |= (b & 0x7F) << j * 7; + if ((b & 0x80) != 128) return result; + } + + throw new IOException("readVarInt read more than 5 bytes"); + } + + public final void read(byte[] b) throws IOException { + input.readFully(b); + } + + public void read(byte[] b, int off, int len) throws IOException { + input.readFully(b, off, len); + } + + public final byte[] read(int length) throws IOException { + byte[] bytes = new byte[length]; + input.readFully(bytes, 0, length); + return bytes; + } + + public final int readByte() throws IOException { + return input.readUnsignedByte(); + } + + public final short readShort() throws IOException { + return input.readShort(); + } + + public final int readInt() throws IOException { + return input.readInt(); + } + + public final double readDouble() throws IOException { + return input.readDouble(); + } + + public T[] readArray(IntFunction make, Reader child) throws IOException { + T[] values = make.apply(readVarInt()); + for (int i = 0; i < values.length; i++) values[i] = child.read(); + return values; + } + + public int[] readIntArray() throws IOException { + int[] values = new int[readVarInt()]; + for (int i = 0; i < values.length; i++) values[i] = readVarInt(); + return values; + } + + @SuppressWarnings("unchecked") + public T[] readArrayUnsafe(IntFunction make) throws IOException { + return readArray(make, () -> (T) read()); + } + + public String readString() throws IOException { + int length = readVarInt(); + return LuaString.decode(read(length), 0, length); + } + + public interface Reader { + T read() throws IOException; + } + // endregion +} diff --git a/src/main/java/org/squiddev/cobalt/persist/ValueWriter.java b/src/main/java/org/squiddev/cobalt/persist/ValueWriter.java new file mode 100644 index 00000000..60c25ac5 --- /dev/null +++ b/src/main/java/org/squiddev/cobalt/persist/ValueWriter.java @@ -0,0 +1,360 @@ +package org.squiddev.cobalt.persist; + +import org.squiddev.cobalt.*; +import org.squiddev.cobalt.debug.DebugState; +import org.squiddev.cobalt.function.LocalVariable; +import org.squiddev.cobalt.function.LuaClosure; +import org.squiddev.cobalt.function.Upvalue; + +import java.io.Closeable; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.HashMap; +import java.util.Queue; + +import static org.squiddev.cobalt.Constants.*; +import static org.squiddev.cobalt.persist.Persist.*; + +public class ValueWriter implements Closeable { + private static final int NO_WRITE = -1; + static final int MAX_DEPTH = 16; + + private final DataOutput output; + + private final HashMap ids = new HashMap<>(); + private final LuaState state; + private int nextId = 0; + private int depth = 0; + + private final Queue queue = new ArrayDeque<>(); + + public ValueWriter(LuaState state, DataOutput output) { + this.output = output; + this.state = state; + } + + private int checkWritten(Object object, int type) throws IOException { + Integer index = ids.get(object); + if (index == null) { + int id = nextId++; + ids.put(object, id); + writeByte(type); + return id; + } else { + writeByte(TAG_REFERENCE); + writeVarInt(index); + return NO_WRITE; + } + } + + private int checkWrittenPartial(Object object, int type) throws IOException { + boolean skip = depth > MAX_DEPTH; + int id = checkWritten(object, (skip ? FLAG_PARTIAL : 0) | type); + if (skip && id != NO_WRITE) { + queue.add(new Update(id, object, type)); + return NO_WRITE; + } + + return id; + } + + public void write(LuaValue value) throws IOException { + switch (value.type()) { + case TNIL: + writeByte(TAG_NIL); + break; + case TBOOLEAN: + writeByte(((LuaBoolean) value).v ? TAG_TRUE : TAG_FALSE); + break; + + case TNUMBER: { + if (value instanceof LuaInteger) { + int v = ((LuaInteger) value).v; + if (v >= Short.MIN_VALUE && v <= Short.MIN_VALUE) { + writeByte(TAG_SHORT_INT); + writeShort(v); + } else { + writeByte(TAG_INT); + writeInt(v); + } + } else { + writeByte(TAG_FLOAT); + writeDouble(value.toDouble()); + } + break; + } + + case TSTRING: { + LuaString str = ((LuaBaseString) value).strvalue(); + if (str.length <= 24) { + writeByte(TAG_SHORT_STRING); + writeByte(str.length); + write(str.bytes, str.offset, str.length); + } else { + if (checkWritten(str, TAG_STRING) == NO_WRITE) return; + writeVarInt(str.length); + write(str.bytes, str.offset, str.length); + } + break; + } + + case TTABLE: + if (checkWrittenPartial(value, TAG_TABLE) == NO_WRITE) return; + writeTableBody((LuaTable) value); + break; + + case TTHREAD: + if (checkWrittenPartial(value, TAG_THREAD) == NO_WRITE) return; + writeThreadBody((LuaThread) value); + break; + + + case TFUNCTION: + if (value instanceof LuaClosure) { + LuaClosure closure = (LuaClosure) value; + if (checkWritten(closure, TAG_CLOSURE) == NO_WRITE) return; + + write(closure.getPrototype()); + + int nups = closure.getPrototype().nups; + writeVarInt(nups); + + write(closure.getfenv()); + for (int i = 0; i < nups; i++) write(closure.getUpvalue(i)); + + break; + } + + // fallthrough + default: + serialize(value); + break; + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void serialize(Object value) throws IOException { + if (checkWritten(value, TAG_SERIALIZED) == NO_WRITE) return; + + Serializer serializer = Serializable.getSerializer(value); + if (serializer == null) { + throw new IllegalStateException(String.format("Cannot serialize %s (of type %s)", value, value.getClass().getName())); + } + + if (checkWritten(serializer, TAG_SERIALIZER) != NO_WRITE) write(serializer.getName()); + + depth++; + serializer.save(this, (Serializable) value); + depth--; + } + + public void write(Varargs value) throws IOException { + int count = value == null ? 0 : value.count(); + writeVarInt(count); + for (int i = 0; i < count; i++) write(value.arg(i + 1)); + } + + private void writeTableBody(LuaTable table) throws IOException { + depth++; + { + LuaValue key = NIL; + try { + while (true) { + Varargs pair = table.next(key); + key = pair.arg(1); + if (key.isNil()) break; + + write(key); + write(pair.arg(2)); + } + } catch (LuaError e) { + throw new IOException(e); + } + + writeByte(TAG_NIL); + LuaTable metatable = table.getMetatable(state); + if (metatable == null) { + writeByte(TAG_NIL); + } else { + write(metatable); + } + } + depth--; + } + + private void writeThreadBody(LuaThread coroutine) throws IOException { + depth++; + { + // Thread state + coroutine.writeInternalState(this); + write(coroutine.getfenv()); + + LuaValue errorFunc = coroutine.getErrorFunc(); + + // Debug hooks + DebugState debug = coroutine.getDebugState(); + writeByte( + (debug.inhook ? IN_HOOK : 0) + | (errorFunc != null ? HAS_ERRORFUNC : 0) + ); + if (errorFunc != null) write(errorFunc); + debug.writeInternalState(this); + } + depth--; + } + + public void write(Prototype prototype) throws IOException { + if (checkWrittenPartial(prototype, TAG_PROTOTYPE) == NO_WRITE) return; + writePrototypeBody(prototype); + } + + private void writePrototypeBody(Prototype prototype) throws IOException { + depth++; + { + write(prototype.source); + writeVarInt(prototype.linedefined); + writeVarInt(prototype.lastlinedefined); + writeVarInt(prototype.nups); + writeVarInt((prototype.numparams << 2) | prototype.is_vararg); + writeVarInt(prototype.maxstacksize); + + writeArray(prototype.k, this::write); + writeIntArray(prototype.code); + writeArray(prototype.p, this::write); + writeIntArray(prototype.lineinfo); + writeArray(prototype.locvars, this::write); + writeArray(prototype.upvalues, this::write); + } + depth--; + } + + public void write(Upvalue upvalue) throws IOException { + if (checkWritten(upvalue, TAG_UPVALUE) == NO_WRITE) return; + write(upvalue.getArray()); + writeVarInt(upvalue.getIndex()); + } + + public void write(LuaValue[] stack) throws IOException { + boolean skip = depth > MAX_DEPTH; + int id = checkWritten(stack, (skip ? FLAG_PARTIAL : 0) | TAG_STACK); + if (id == NO_WRITE) return; + + writeVarInt(stack.length); + if (skip) { + queue.add(new Update(id, stack, TAG_STACK)); + return; + } + + for (LuaValue value : stack) write(value); + } + + public void write(LocalVariable var) throws IOException { + write(var.name); + writeVarInt(var.startpc); + writeVarInt(var.endpc); + } + + public void close() throws IOException { + Update update; + while ((update = queue.poll()) != null) { + writeByte(FLAG_POPULATE | update.type); + writeVarInt(update.id); + switch (update.type) { + case TAG_TABLE: + writeTableBody((LuaTable) update.object); + break; + case TAG_PROTOTYPE: + writePrototypeBody((Prototype) update.object); + break; + case TAG_STACK: { + for (LuaValue value : (LuaValue[]) update.object) write(value); + break; + } + default: + throw new IllegalStateException("Cannot resume for " + update.object + " (tag=" + update.type + ")"); + } + } + } + + private static final class Update { + final Object object; + final int id; + final int type; + + Update(int id, Object object, int type) { + this.object = object; + this.id = id; + this.type = type; + } + } + + // region Primitives + public final void writeVarInt(int input) throws IOException { + while ((input & 0xFFFFFF80) != 0) { + output.writeByte(input & 0x7F | 0x80); + input >>>= 7; + } + output.writeByte(input); + } + + public final void write(byte[] b) throws IOException { + output.write(b); + } + + public void write(byte[] b, int off, int len) throws IOException { + output.write(b, off, len); + } + + public void writeIntArray(int[] values) throws IOException { + if (values == null) { + writeByte(0); + return; + } + + writeVarInt(values.length); + for (int x : values) writeVarInt(x); + } + + public boolean writeArrayHeader(T[] values) throws IOException { + if (values == null) { + writeByte(0); + return false; + } + + writeVarInt(values.length); + return true; + } + + public void writeArray(T[] values, Writer child) throws IOException { + if (!writeArrayHeader(values)) return; + for (T value : values) child.write(value); + } + + public final void writeByte(int v) throws IOException { + output.writeByte(v); + } + + public final void writeShort(int v) throws IOException { + output.writeShort(v); + } + + public final void writeInt(int v) throws IOException { + output.writeInt(v); + } + + public final void writeDouble(double v) throws IOException { + output.writeDouble(v); + } + + public final void write(String str) throws IOException { + int length = str.length(); + writeVarInt(length); + for (int i = 0; i < length; i++) writeByte(str.charAt(i) & 0xFF); + } + + public interface Writer { + void write(T value) throws IOException; + } + // endregion +} diff --git a/src/test/java/org/squiddev/cobalt/CoroutineTest.java b/src/test/java/org/squiddev/cobalt/CoroutineTest.java index a736c2b4..5c86f16e 100644 --- a/src/test/java/org/squiddev/cobalt/CoroutineTest.java +++ b/src/test/java/org/squiddev/cobalt/CoroutineTest.java @@ -30,21 +30,18 @@ import org.junit.jupiter.params.provider.MethodSource; import org.squiddev.cobalt.compiler.CompileException; import org.squiddev.cobalt.debug.DebugFrame; -import org.squiddev.cobalt.debug.DebugHandler; import org.squiddev.cobalt.debug.DebugHelpers; -import org.squiddev.cobalt.debug.DebugState; import org.squiddev.cobalt.function.LuaFunction; import org.squiddev.cobalt.function.ResumableVarArgFunction; import org.squiddev.cobalt.function.VarArgFunction; import org.squiddev.cobalt.lib.LuaLibrary; +import org.squiddev.cobalt.support.SuspendingDebug; import java.io.IOException; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; import static org.squiddev.cobalt.OperationHelper.noUnwind; -import static org.squiddev.cobalt.debug.DebugFrame.FLAG_HOOKED; -import static org.squiddev.cobalt.debug.DebugFrame.FLAG_HOOKYIELD; /** * Tests yielding in a whole load of places @@ -130,7 +127,7 @@ public void runSuspendBlocking(String name) throws IOException, CompileException private static class Functions extends ResumableVarArgFunction implements LuaLibrary { @Override public LuaValue add(LuaState state, LuaTable environment) { - bind(environment, Functions::new, new String[]{"suspend", "run", "assertEquals", "fail", "id", "noUnwind"}); + bind(state, "$test", environment, Functions::new, new String[]{"suspend", "run", "assertEquals", "fail", "id", "noUnwind"}); return environment; } @@ -180,36 +177,4 @@ public Varargs resumeThis(LuaState state, LuaThread thread, Varargs value) throw } } - private static class SuspendingDebug extends DebugHandler { - private boolean suspend = true; - - private int flags; - private boolean inHook; - - @Override - public void onInstruction(DebugState ds, DebugFrame di, int pc) throws LuaError, UnwindThrowable { - di.pc = pc; - - if (suspend) { - // Save the current state - flags = di.flags; - inHook = ds.inhook; - - // Set HOOK_YIELD and HOOKED flags so we know its an instruction hook - di.flags |= FLAG_HOOKYIELD | FLAG_HOOKED; - - // We don't want to suspend next tick. - suspend = false; - LuaThread.suspend(ds.getLuaState()); - } - - // Restore the old state - ds.inhook = inHook; - di.flags = flags; - suspend = true; - - // And continue as normal - super.onInstruction(ds, di, pc); - } - } } diff --git a/src/test/java/org/squiddev/cobalt/PersistTests.java b/src/test/java/org/squiddev/cobalt/PersistTests.java new file mode 100644 index 00000000..4e68fa0f --- /dev/null +++ b/src/test/java/org/squiddev/cobalt/PersistTests.java @@ -0,0 +1,145 @@ +package org.squiddev.cobalt; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.squiddev.cobalt.compiler.CompileException; +import org.squiddev.cobalt.compiler.LoadState; +import org.squiddev.cobalt.debug.DebugFrame; +import org.squiddev.cobalt.function.LuaFunction; +import org.squiddev.cobalt.function.ResumableVarArgFunction; +import org.squiddev.cobalt.lib.*; +import org.squiddev.cobalt.lib.jse.JsePlatform; +import org.squiddev.cobalt.persist.Persist; +import org.squiddev.cobalt.support.PairedStream; +import org.squiddev.cobalt.support.PrettyValue; + +import java.io.*; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.squiddev.cobalt.Constants.NONE; +import static org.squiddev.cobalt.ValueFactory.valueOf; + +public class PersistTests { + @ParameterizedTest(name = ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER) + @MethodSource("basicValues") + public void values(PrettyValue value) throws Exception { + LuaState state = new LuaState(); + + LuaValue read = PairedStream.run( + output -> Persist.persist(state, new DataOutputStream(output), value.getValue()), + input -> Persist.unpersist(state, new DataInputStream(input)) + ); + + assertEquals(value, new PrettyValue(read)); + } + + public static Stream basicValues() { + return Stream.of( + Constants.TRUE, Constants.FALSE, Constants.NIL, + valueOf(0), valueOf(1), valueOf(-5), + valueOf(0.1), valueOf(Double.POSITIVE_INFINITY), + valueOf(""), valueOf("hello"), + valueOf("a long string which will not be stored inline"), + get(() -> { + LuaTable table = new LuaTable(); + for (int i = 1; i < 5; i++) table.rawset(i, valueOf(Integer.toString(i))); + table.rawset("a", table); + return table; + }), + get(() -> { + LuaTable first = new LuaTable(); + LuaTable last = first; + for (int i = 0; i < 100; i++) { + LuaTable next = new LuaTable(); + last.rawset(i, next); + last = next; + } + return first; + }) + ).map(PrettyValue::new); + } + + @ParameterizedTest(name = ParameterizedTest.ARGUMENTS_WITH_NAMES_PLACEHOLDER) + @ValueSource(strings = { + "add", + "coroutines", + "pcall", + }) + public void program(String filename) throws IOException, InterruptedException, CompileException, LuaError { + byte[] contents; + try (ByteArrayOutputStream out = new ByteArrayOutputStream()) { + LuaState state = new LuaState(); + LuaTable env = setupGlobals(state); + + LuaFunction function; + try (InputStream stream = getClass().getResourceAsStream("/persist/" + filename + ".lua")) { + if (stream == null) fail("Could not load script for test case: " + filename); + function = LoadState.load(state, stream, "@" + filename + ".lua", env); + } + + LuaThread thread = new LuaThread(state, function, env); + LuaThread.run(thread, NONE); + + assertEquals("suspended", state.getCurrentThread().getStatus()); + Persist.persist(state, new DataOutputStream(out), state.getCurrentThread()); + contents = out.toByteArray(); + } + + LuaState state2 = new LuaState(); + setupGlobals(state2); + LuaThread thread; + try (ByteArrayInputStream out = new ByteArrayInputStream(contents)) { + thread = (LuaThread) Persist.unpersist(state2, new DataInputStream(out)); + } + + Varargs result = LuaThread.run(thread, NONE); + assertEquals(valueOf("OK"), result); + } + + private static T get(Supplier supplier) { + return supplier.get(); + } + + /** + * Like {@link JsePlatform#debugGlobals(LuaState)}, but without os, io and package. Namely, the ones + * ComputerCraft uses (as that's all we need persistence for right now). + * + * @param state The state to setup globals in. + */ + public LuaTable setupGlobals(LuaState state) { + LuaTable _G = new LuaTable(); + state.setupThread(_G); + _G.load(state, new BaseLib()); + _G.load(state, new TableLib()); + _G.load(state, new StringLib()); + _G.load(state, new CoroutineLib()); + _G.load(state, new MathLib()); + _G.load(state, new Utf8Lib()); + _G.load(state, new DebugLib()); + _G.load(state, new Functions()); + return _G; + } + + private static class Functions extends ResumableVarArgFunction implements LuaLibrary { + @Override + public LuaValue add(LuaState state, LuaTable environment) { + bind(state, "$test", environment, Functions::new, new String[]{"suspend"}); + return environment; + } + + @Override + public Varargs invoke(LuaState state, DebugFrame di, Varargs args) throws LuaError, UnwindThrowable { + LuaThread.suspend(state); + return NONE; + } + + @Override + protected Varargs resumeThis(LuaState state, Object object, Varargs value) { + return NONE; + } + } +} diff --git a/src/test/java/org/squiddev/cobalt/support/Equality.java b/src/test/java/org/squiddev/cobalt/support/Equality.java new file mode 100644 index 00000000..9957c3f9 --- /dev/null +++ b/src/test/java/org/squiddev/cobalt/support/Equality.java @@ -0,0 +1,76 @@ +package org.squiddev.cobalt.support; + +import org.squiddev.cobalt.*; + +import java.util.ArrayDeque; +import java.util.HashMap; +import java.util.Map; +import java.util.Queue; + +class Equality { + private final Map map = new HashMap<>(); + private final Queue queue = new ArrayDeque<>(); + + public static boolean equals(LuaValue left, LuaValue right) { + Equality eq = new Equality(); + if (!eq.checkValue(left, right)) return false; + + try { + Pair pair; + while ((pair = eq.queue.poll()) != null) { + if (!eq.checkTables(pair.left, pair.right)) return false; + } + } catch (LuaError e) { + throw new IllegalStateException(e); + } + + return true; + } + + private boolean checkValue(LuaValue left, LuaValue right) { + if (left.equals(right)) return true; + if (!(left instanceof LuaTable) || !(right instanceof LuaTable)) return false; + + queue.add(new Pair((LuaTable) left, (LuaTable) right)); + return true; + } + + private boolean checkTables(LuaTable left, LuaTable right) throws LuaError { + if (map.get(left) == right) return true; + if (map.containsKey(left) || map.containsKey(right)) return false; + + map.put(left, right); + map.put(right, left); + + LuaValue key = Constants.NIL; + while (true) { + Varargs next = left.next(key); + key = next.first(); + if (key.isNil()) break; + + LuaValue value = next.arg(2); + if (!checkValue(value, right.rawget(key))) return false; + } + + key = Constants.NIL; + while (true) { + Varargs next = right.next(key); + key = next.first(); + if (key.isNil()) break; + + if (left.rawget(key).isNil()) return false; + } + + return true; + } + + private static class Pair { + final LuaTable left; + final LuaTable right; + + private Pair(LuaTable left, LuaTable right) { + this.left = left; + this.right = right; + } + } +} diff --git a/src/test/java/org/squiddev/cobalt/Matchers.java b/src/test/java/org/squiddev/cobalt/support/Matchers.java similarity index 97% rename from src/test/java/org/squiddev/cobalt/Matchers.java rename to src/test/java/org/squiddev/cobalt/support/Matchers.java index 56f977c6..ec2bb317 100644 --- a/src/test/java/org/squiddev/cobalt/Matchers.java +++ b/src/test/java/org/squiddev/cobalt/support/Matchers.java @@ -22,7 +22,7 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -package org.squiddev.cobalt; +package org.squiddev.cobalt.support; import org.hamcrest.Matcher; diff --git a/src/test/java/org/squiddev/cobalt/support/PairedStream.java b/src/test/java/org/squiddev/cobalt/support/PairedStream.java new file mode 100644 index 00000000..92beab67 --- /dev/null +++ b/src/test/java/org/squiddev/cobalt/support/PairedStream.java @@ -0,0 +1,119 @@ +package org.squiddev.cobalt.support; + +import java.io.*; +import java.util.concurrent.*; + +/** + * A paired input and ouput stream. Thus operates in two possible modes: + * + *
    + *
  • "Sequential" - this just writes to a byte array and then reads from it.
  • + *
  • "Lockstep": Runs the writer and reader at the same time, and ensure read and write operations happen in sync + * with each other. This is useful when debugging, but obviously has much higher overhead.
  • + *
+ */ +public class PairedStream { + private static final boolean DEBUG = false; + + private final BlockingQueue value = new ArrayBlockingQueue<>(1); + private boolean closed = false; + + private final OutputStream writer = new OutputStream() { + @Override + public void write(int b) throws IOException { + if (b < 0 || b > 255) throw new IllegalStateException("Byte out of bounds"); + try { + value.put(b); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + @Override + public void close() throws IOException { + super.close(); + try { + value.put(-1); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + }; + + private final InputStream reader = new InputStream() { + @Override + public int read() throws IOException { + if (closed) return -1; + + try { + int read = value.take(); + if (read == -1) closed = true; + return read; + } catch (InterruptedException e) { + throw new IOException(e); + } + } + }; + + public static T run(UncheckedConsumer write, UncheckedFunction read) throws Exception { + return DEBUG ? runLockstep(write, read) : runSequential(write, read); + } + + public static T runSequential(UncheckedConsumer write, UncheckedFunction read) throws Exception { + byte[] contents; + try (ByteArrayOutputStream output = new ByteArrayOutputStream()) { + write.accept(output); + contents = output.toByteArray(); + } + + try (ByteArrayInputStream input = new ByteArrayInputStream(contents)) { + return read.apply(input); + } + } + + public static T runLockstep(UncheckedConsumer write, UncheckedFunction read) throws Exception { + ExecutorService executor = Executors.newCachedThreadPool(); + PairedStream stream = new PairedStream(); + CompletableFuture writer = CompletableFuture.runAsync(() -> { + try { + write.accept(stream.writer); + } catch (Exception e) { + doSneakyThrow(e); + } finally { + try { + stream.writer.close(); + } catch (IOException ignored) { + } + } + }, executor); + CompletableFuture reader = CompletableFuture.supplyAsync(() -> { + try { + return read.apply(stream.reader); + } catch (Exception e) { + doSneakyThrow(e); + return null; + } + }, executor); + + CompletableFuture.anyOf(reader, writer).get(); + T value = reader.get(); + executor.shutdownNow(); + return value; + } + + public interface UncheckedConsumer { + void accept(T value) throws Exception; + } + + public interface UncheckedFunction { + U apply(T value) throws Exception; + } + + public static E sneakyThrow(Throwable e) throws E { + throw (E) e; + } + + public static void doSneakyThrow(Throwable e) { + PairedStream.sneakyThrow(e); + } +} diff --git a/src/test/java/org/squiddev/cobalt/support/PrettyValue.java b/src/test/java/org/squiddev/cobalt/support/PrettyValue.java new file mode 100644 index 00000000..d24f5c97 --- /dev/null +++ b/src/test/java/org/squiddev/cobalt/support/PrettyValue.java @@ -0,0 +1,84 @@ +package org.squiddev.cobalt.support; + +import org.squiddev.cobalt.*; +import org.squiddev.cobalt.lib.FormatDesc; +import org.squiddev.cobalt.lib.UncheckedLuaError; + +import java.util.IdentityHashMap; +import java.util.Map; + +public class PrettyValue { + private static final FormatDesc quote = FormatDesc.ofUnsafe("%q"); + private final LuaValue value; + + public PrettyValue(LuaValue value) { + this.value = value; + } + + public LuaValue getValue() { + return value; + } + + @Override + public String toString() { + StringBuffer buffer = new StringBuffer(); + try { + toString(buffer, new IdentityHashMap<>(), value); + } catch (LuaError e) { + throw new UncheckedLuaError(e); + } + return buffer.toString(); + } + + @Override + public boolean equals(Object obj) { + return obj instanceof PrettyValue && (this == obj || Equality.equals(value, ((PrettyValue) obj).value)); + } + + @Override + public int hashCode() { + return value instanceof LuaTable ? 0 : value.hashCode(); + } + + private static void toString(StringBuffer buffer, Map seen, LuaValue value) throws LuaError { + if (value instanceof LuaString) { + buffer.append('"').append(value).append('"'); + } else if (!(value instanceof LuaTable)) { + buffer.append(value); + } else { + LuaTable table = (LuaTable) value; + + Integer id = seen.get(table); + boolean skip = id != null; + if (!skip) { + id = seen.size(); + seen.put(table, id); + } + + buffer.append("table#").append(id); + if (skip) return; + + buffer.append(":={"); + boolean first = true; + LuaValue key = Constants.NIL; + while (true) { + Varargs next = table.next(key); + key = next.first(); + if (key.isNil()) break; + + if (first) { + first = false; + } else { + buffer.append(", "); + } + + buffer.append("["); + toString(buffer, seen, key); + buffer.append("] = "); + toString(buffer, seen, table.rawget(key)); + } + + buffer.append("}"); + } + } +} diff --git a/src/test/java/org/squiddev/cobalt/support/SuspendingDebug.java b/src/test/java/org/squiddev/cobalt/support/SuspendingDebug.java new file mode 100644 index 00000000..395d6385 --- /dev/null +++ b/src/test/java/org/squiddev/cobalt/support/SuspendingDebug.java @@ -0,0 +1,70 @@ +package org.squiddev.cobalt.support; + +import org.squiddev.cobalt.LuaError; +import org.squiddev.cobalt.LuaThread; +import org.squiddev.cobalt.UnwindThrowable; +import org.squiddev.cobalt.debug.DebugFrame; +import org.squiddev.cobalt.debug.DebugHandler; +import org.squiddev.cobalt.debug.DebugState; + +import static org.squiddev.cobalt.debug.DebugFrame.FLAG_HOOKED; +import static org.squiddev.cobalt.debug.DebugFrame.FLAG_HOOKYIELD; + +public class SuspendingDebug extends DebugHandler { + private boolean suspend = true; + + private int flags; + private boolean inHook; + + @Override + public void onInstruction(DebugState ds, DebugFrame di, int pc) throws LuaError, UnwindThrowable { + di.pc = pc; + + if (suspend) { + // Save the current state + flags = di.flags; + inHook = ds.inhook; + + // Set HOOK_YIELD and HOOKED flags so we know its an instruction hook + di.flags |= FLAG_HOOKYIELD | FLAG_HOOKED; + + // We don't want to suspend next tick. + suspend = false; + LuaThread.suspend(ds.getLuaState()); + } + + // Restore the old state + ds.inhook = inHook; + di.flags = flags; + suspend = true; + + // And continue as normal + super.onInstruction(ds, di, pc); + } + + public DebugHandler justResume() { + return suspend ? new DebugHandler() : new Resuming(flags, inHook); + } + + private static class Resuming extends DebugHandler { + private boolean needsWork = true; + private final int flags; + private final boolean inHook; + + private Resuming(int flags, boolean inHook) { + this.flags = flags; + this.inHook = inHook; + } + + @Override + public void onInstruction(DebugState ds, DebugFrame di, int pc) throws LuaError, UnwindThrowable { + if (needsWork) { + ds.inhook = inHook; + di.flags = flags; + needsWork = false; + } + + super.onInstruction(ds, di, pc); + } + } +} diff --git a/src/test/java/org/squiddev/cobalt/table/TableArrayTest.java b/src/test/java/org/squiddev/cobalt/table/TableArrayTest.java index d6734e99..8b24815b 100644 --- a/src/test/java/org/squiddev/cobalt/table/TableArrayTest.java +++ b/src/test/java/org/squiddev/cobalt/table/TableArrayTest.java @@ -32,7 +32,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.junit.jupiter.api.Assertions.*; -import static org.squiddev.cobalt.Matchers.between; +import static org.squiddev.cobalt.support.Matchers.between; /** * Tests for tables used as lists. diff --git a/src/test/java/org/squiddev/cobalt/table/TableTest.java b/src/test/java/org/squiddev/cobalt/table/TableTest.java index 2b16d16e..878f8e37 100644 --- a/src/test/java/org/squiddev/cobalt/table/TableTest.java +++ b/src/test/java/org/squiddev/cobalt/table/TableTest.java @@ -33,7 +33,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.junit.jupiter.api.Assertions.*; -import static org.squiddev.cobalt.Matchers.between; +import static org.squiddev.cobalt.support.Matchers.between; import static org.squiddev.cobalt.ValueFactory.valueOf; public class TableTest { diff --git a/src/test/resources/persist/add.lua b/src/test/resources/persist/add.lua new file mode 100644 index 00000000..0641e21c --- /dev/null +++ b/src/test/resources/persist/add.lua @@ -0,0 +1,6 @@ +local a, b = 1, 2 + +suspend() + +assert(a + b == 3) +return "OK" diff --git a/src/test/resources/persist/coroutines.lua b/src/test/resources/persist/coroutines.lua new file mode 100644 index 00000000..3d3f2bd9 --- /dev/null +++ b/src/test/resources/persist/coroutines.lua @@ -0,0 +1,18 @@ +local function worker() + local a = coroutine.yield() + suspend() + local b = coroutine.yield() + + return a + b +end + +local h = coroutine.create(worker) +assert(coroutine.resume(h)) +assert(coroutine.resume(h, 1)) + +local ok, res = coroutine.resume(h, 2) +assert(ok, res) +assert(res == 3) +assert(coroutine.status(h) == "dead") + +return "OK" diff --git a/src/test/resources/persist/pcall.lua b/src/test/resources/persist/pcall.lua new file mode 100644 index 00000000..e64dad5f --- /dev/null +++ b/src/test/resources/persist/pcall.lua @@ -0,0 +1,10 @@ +local function go(a, b) + suspend() + return a + b +end + +local ok, res = pcall(go, 1, 2) +assert(ok) +assert(res == 3) + +return "OK"