diff --git a/shared/src/main/scala/mlscript/JSBackend.scala b/shared/src/main/scala/mlscript/JSBackend.scala index 3ee699d06..9ea71bf8e 100644 --- a/shared/src/main/scala/mlscript/JSBackend.scala +++ b/shared/src/main/scala/mlscript/JSBackend.scala @@ -277,6 +277,12 @@ class JSBackend(allowUnresolvedSymbols: Boolean) { name -> translateTerm(value) }) :: Nil ) + // Only parenthesize binary operators + // Custom operators do not need special handling since they are desugared to plain methods + case Bra(false, trm) => trm match { + case App(Var(op), _) if JSBinary.operators.contains(op) => JSParenthesis(translateTerm(trm)) + case trm => translateTerm(trm) + } case Bra(_, trm) => translateTerm(trm) case Tup(terms) => JSArray(terms map { case (_, Fld(_, term)) => translateTerm(term) }) diff --git a/shared/src/main/scala/mlscript/codegen/Codegen.scala b/shared/src/main/scala/mlscript/codegen/Codegen.scala index 0639db3b7..3efacdef7 100644 --- a/shared/src/main/scala/mlscript/codegen/Codegen.scala +++ b/shared/src/main/scala/mlscript/codegen/Codegen.scala @@ -924,6 +924,11 @@ final case class JSComment(text: Str) extends JSStmt { def toSourceCode: SourceCode = SourceCode(s"// $text") } +final case class JSParenthesis(exp: JSExpr) extends JSExpr { + implicit def precedence: Int = 0 + def toSourceCode: SourceCode = exp.embed +} + object JSCodeHelpers { def id(name: Str): JSIdent = JSIdent(name) def lit(value: Int): JSLit = JSLit(value.toString()) diff --git a/shared/src/test/diff/codegen/NuParentheses.mls b/shared/src/test/diff/codegen/NuParentheses.mls new file mode 100644 index 000000000..356a47a2f --- /dev/null +++ b/shared/src/test/diff/codegen/NuParentheses.mls @@ -0,0 +1,61 @@ +:NewDefs + + +:js +16 / (2 / 2) +//│ Num +//│ // Prelude +//│ let res; +//│ class TypingUnit {} +//│ const typing_unit = new TypingUnit; +//│ // Query 1 +//│ res = 16 / (2 / 2); +//│ // End of generated code +//│ res +//│ = 16 + +:js +1 - (3 - 5) +//│ Int +//│ // Prelude +//│ class TypingUnit1 {} +//│ const typing_unit1 = new TypingUnit1; +//│ // Query 1 +//│ res = 1 - (3 - 5); +//│ // End of generated code +//│ res +//│ = 3 + + +fun (--) minusminus(a, b) = a - b +//│ fun (--) minusminus: (Int, Int) -> Int + +:js +1 -- (3 -- 5) +//│ Int +//│ // Prelude +//│ class TypingUnit3 {} +//│ const typing_unit3 = new TypingUnit3; +//│ // Query 1 +//│ res = minusminus(1, minusminus(3, 5)); +//│ // End of generated code +//│ res +//│ = 3 + + +fun (-+-) complex(a, b) = a - 2*b +//│ fun (-+-) complex: (Int, Int) -> Int + +:js +1 -+- (3 -+- 5) +//│ Int +//│ // Prelude +//│ class TypingUnit5 {} +//│ const typing_unit5 = new TypingUnit5; +//│ // Query 1 +//│ res = complex(1, complex(3, 5)); +//│ // End of generated code +//│ res +//│ = 15 + + diff --git a/shared/src/test/diff/nu/UnaryMinus.mls b/shared/src/test/diff/nu/UnaryMinus.mls index 389426bfe..bb0341d0a 100644 --- a/shared/src/test/diff/nu/UnaryMinus.mls +++ b/shared/src/test/diff/nu/UnaryMinus.mls @@ -28,7 +28,7 @@ 1 - (3 - 5) //│ Int //│ res -//│ = -7 +//│ = 3 3 - 1 //│ Int @@ -53,7 +53,7 @@ 1 - (1 - 1) //│ Int //│ res -//│ = -1 +//│ = 1 1 - 1 //│ Int