From 46110524efe1cb9e60ceda52d19bce23e5e03486 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 29 Oct 2021 13:25:12 -0700 Subject: [PATCH] Sync to upstream/release/501 (#20) Co-authored-by: Rodactor --- .clang-format | 25 + Analysis/include/Luau/AstQuery.h | 63 + Analysis/include/Luau/Autocomplete.h | 91 + Analysis/include/Luau/BuiltinDefinitions.h | 51 + Analysis/include/Luau/Config.h | 58 + Analysis/include/Luau/Documentation.h | 50 + Analysis/include/Luau/Error.h | 332 + Analysis/include/Luau/FileResolver.h | 103 + Analysis/include/Luau/Frontend.h | 179 + Analysis/include/Luau/IostreamHelpers.h | 46 + Analysis/include/Luau/JsonEncoder.h | 13 + Analysis/include/Luau/Linter.h | 96 + Analysis/include/Luau/Module.h | 111 + Analysis/include/Luau/ModuleResolver.h | 79 + Analysis/include/Luau/Predicate.h | 120 + Analysis/include/Luau/RecursionCounter.h | 39 + Analysis/include/Luau/RequireTracer.h | 28 + Analysis/include/Luau/Substitution.h | 208 + Analysis/include/Luau/Symbol.h | 95 + Analysis/include/Luau/ToString.h | 72 + Analysis/include/Luau/TopoSortStatements.h | 18 + Analysis/include/Luau/Transpiler.h | 30 + Analysis/include/Luau/TxnLog.h | 46 + Analysis/include/Luau/TypeAttach.h | 21 + Analysis/include/Luau/TypeInfer.h | 453 ++ Analysis/include/Luau/TypePack.h | 161 + Analysis/include/Luau/TypeUtils.h | 19 + Analysis/include/Luau/TypeVar.h | 531 ++ Analysis/include/Luau/TypedAllocator.h | 134 + Analysis/include/Luau/Unifiable.h | 123 + Analysis/include/Luau/Unifier.h | 98 + Analysis/include/Luau/Variant.h | 302 + Analysis/include/Luau/VisitTypeVar.h | 200 + Analysis/src/AstQuery.cpp | 411 + Analysis/src/Autocomplete.cpp | 1566 ++++ Analysis/src/BuiltinDefinitions.cpp | 805 ++ Analysis/src/Config.cpp | 278 + Analysis/src/EmbeddedBuiltinDefinitions.cpp | 238 + Analysis/src/Error.cpp | 751 ++ Analysis/src/Frontend.cpp | 967 +++ Analysis/src/IostreamHelpers.cpp | 280 + Analysis/src/JsonEncoder.cpp | 1041 +++ Analysis/src/Linter.cpp | 2568 +++++++ Analysis/src/Module.cpp | 521 ++ Analysis/src/Predicate.cpp | 93 + Analysis/src/RequireTracer.cpp | 190 + Analysis/src/Substitution.cpp | 530 ++ Analysis/src/Symbol.cpp | 18 + Analysis/src/ToString.cpp | 1142 +++ Analysis/src/TopoSortStatements.cpp | 552 ++ Analysis/src/Transpiler.cpp | 1156 +++ Analysis/src/TxnLog.cpp | 72 + Analysis/src/TypeAttach.cpp | 437 ++ Analysis/src/TypeInfer.cpp | 5497 ++++++++++++++ Analysis/src/TypePack.cpp | 277 + Analysis/src/TypeUtils.cpp | 95 + Analysis/src/TypeVar.cpp | 1505 ++++ Analysis/src/TypedAllocator.cpp | 99 + Analysis/src/Unifiable.cpp | 67 + Analysis/src/Unifier.cpp | 1575 ++++ Ast/include/Luau/Ast.h | 1198 +++ Ast/include/Luau/Common.h | 133 + Ast/include/Luau/Confusables.h | 9 + Ast/include/Luau/DenseHash.h | 407 + Ast/include/Luau/Lexer.h | 236 + Ast/include/Luau/Location.h | 109 + Ast/include/Luau/ParseOptions.h | 23 + Ast/include/Luau/Parser.h | 423 ++ Ast/include/Luau/StringUtils.h | 37 + Ast/src/Ast.cpp | 886 +++ Ast/src/Confusables.cpp | 1818 +++++ Ast/src/Lexer.cpp | 1149 +++ Ast/src/Location.cpp | 17 + Ast/src/Parser.cpp | 2693 +++++++ Ast/src/StringUtils.cpp | 228 + CLI/Analyze.cpp | 252 + CLI/FileUtils.cpp | 224 + CLI/FileUtils.h | 14 + CLI/Profiler.cpp | 155 + CLI/Profiler.h | 8 + CLI/Repl.cpp | 495 ++ CMakeLists.txt | 88 + Compiler/include/Luau/Bytecode.h | 478 ++ Compiler/include/Luau/BytecodeBuilder.h | 250 + Compiler/include/Luau/Compiler.h | 67 + Compiler/src/BytecodeBuilder.cpp | 1726 +++++ Compiler/src/Compiler.cpp | 3778 ++++++++++ LICENSE.txt | 21 + Makefile | 169 + Sources.cmake | 215 + VM/include/lua.h | 385 + VM/include/luaconf.h | 124 + VM/include/lualib.h | 129 + VM/src/lapi.cpp | 1273 ++++ VM/src/lapi.h | 8 + VM/src/laux.cpp | 477 ++ VM/src/lbaselib.cpp | 466 ++ VM/src/lbitlib.cpp | 201 + VM/src/lbuiltins.cpp | 1099 +++ VM/src/lbuiltins.h | 9 + VM/src/lbytecode.h | 9 + VM/src/lcommon.h | 52 + VM/src/lcorolib.cpp | 265 + VM/src/ldblib.cpp | 167 + VM/src/ldebug.cpp | 428 ++ VM/src/ldebug.h | 28 + VM/src/ldo.cpp | 554 ++ VM/src/ldo.h | 54 + VM/src/lfunc.cpp | 167 + VM/src/lfunc.h | 18 + VM/src/lgc.cpp | 1696 +++++ VM/src/lgc.h | 150 + VM/src/linit.cpp | 87 + VM/src/lmathlib.cpp | 446 ++ VM/src/lmem.cpp | 340 + VM/src/lmem.h | 21 + VM/src/lnumutils.h | 52 + VM/src/lobject.cpp | 160 + VM/src/lobject.h | 447 ++ VM/src/loslib.cpp | 193 + VM/src/lperf.cpp | 55 + VM/src/lstate.cpp | 199 + VM/src/lstate.h | 271 + VM/src/lstring.cpp | 237 + VM/src/lstring.h | 33 + VM/src/lstrlib.cpp | 1654 +++++ VM/src/ltable.cpp | 799 ++ VM/src/ltable.h | 30 + VM/src/ltablib.cpp | 569 ++ VM/src/ltm.cpp | 140 + VM/src/ltm.h | 57 + VM/src/lutf8lib.cpp | 294 + VM/src/lvm.h | 31 + VM/src/lvmexecute.cpp | 2957 ++++++++ VM/src/lvmload.cpp | 330 + VM/src/lvmutils.cpp | 492 ++ bench/bench.py | 860 +++ bench/bench_support.lua | 50 + bench/color.py | 37 + bench/gc/test_BinaryTree.lua | 56 + bench/gc/test_GC_Boehm_Trees.lua | 77 + bench/gc/test_GC_Tree_Pruning_Eager.lua | 50 + bench/gc/test_GC_Tree_Pruning_Gen.lua | 54 + bench/gc/test_GC_Tree_Pruning_Lazy.lua | 56 + bench/gc/test_GC_hashtable_Keyval.lua | 22 + bench/gc/test_LB_mandel.lua | 97 + bench/gc/test_LargeTableCtor_array.lua | 35 + bench/gc/test_LargeTableCtor_hash.lua | 14 + bench/gc/test_Pcall_pcall_yield.lua | 18 + bench/gc/test_SunSpider_3d-raytrace.lua | 502 ++ bench/gc/test_SunSpider_crypto-aes.lua | 436 ++ bench/gc/test_TableCreate_nil.lua | 12 + bench/gc/test_TableCreate_number.lua | 12 + bench/gc/test_TableCreate_zerofill.lua | 15 + bench/gc/test_TableMarshal_select.lua | 20 + bench/gc/test_TableMarshal_table_pack.lua | 20 + bench/gc/test_TableMarshal_varargs.lua | 20 + bench/influxbench.py | 85 + bench/install.bat | 3 + bench/install.sh | 3 + bench/micro_tests/test_AbsSum_abs.lua | 14 + bench/micro_tests/test_AbsSum_and_or.lua | 13 + bench/micro_tests/test_AbsSum_math_abs.lua | 13 + bench/micro_tests/test_Assert.lua | 12 + bench/micro_tests/test_Factorial.lua | 24 + .../micro_tests/test_Failure_pcall_a_bar.lua | 14 + .../test_Failure_pcall_game_Foo.lua | 14 + .../micro_tests/test_Failure_xpcall_a_bar.lua | 15 + .../test_Failure_xpcall_game_Foo.lua | 15 + .../micro_tests/test_LargeTableCtor_array.lua | 35 + .../micro_tests/test_LargeTableCtor_hash.lua | 14 + .../test_LargeTableSum_loop_index.lua | 17 + .../test_LargeTableSum_loop_ipairs.lua | 17 + .../test_LargeTableSum_loop_pairs.lua | 17 + bench/micro_tests/test_MethodCalls.lua | 100 + bench/micro_tests/test_OOP_constructor.lua | 29 + bench/micro_tests/test_OOP_method_call.lua | 31 + .../test_OOP_virtual_constructor.lua | 29 + bench/micro_tests/test_Pcall_call_return.lua | 14 + bench/micro_tests/test_Pcall_pcall_return.lua | 14 + bench/micro_tests/test_Pcall_pcall_yield.lua | 18 + .../micro_tests/test_Pcall_xpcall_return.lua | 14 + bench/micro_tests/test_SqrtSum_exponent.lua | 13 + bench/micro_tests/test_SqrtSum_math_sqrt.lua | 13 + bench/micro_tests/test_SqrtSum_sqrt.lua | 14 + .../micro_tests/test_SqrtSum_sqrt_getfenv.lua | 15 + .../test_SqrtSum_sqrt_roundabout.lua | 14 + bench/micro_tests/test_TableCreate_nil.lua | 12 + bench/micro_tests/test_TableCreate_number.lua | 12 + .../micro_tests/test_TableCreate_zerofill.lua | 15 + .../test_TableFind_loop_ipairs.lua | 24 + .../micro_tests/test_TableFind_table_find.lua | 14 + .../test_TableInsertion_index_cached.lua | 19 + .../test_TableInsertion_index_len.lua | 19 + .../test_TableInsertion_table_insert.lua | 19 + ...test_TableInsertion_table_insert_index.lua | 30 + bench/micro_tests/test_TableIteration.lua | 19 + .../micro_tests/test_TableMarshal_select.lua | 20 + .../test_TableMarshal_table_pack.lua | 20 + .../test_TableMarshal_table_unpack_array.lua | 21 + .../test_TableMarshal_table_unpack_range.lua | 21 + .../micro_tests/test_TableMarshal_varargs.lua | 20 + .../test_TableMove_empty_table.lua | 23 + .../micro_tests/test_TableMove_same_table.lua | 21 + .../test_TableMove_table_create.lua | 23 + .../test_TableRemoval_table_remove.lua | 22 + bench/micro_tests/test_UpvalueCapture.lua | 19 + bench/micro_tests/test_VariadicSelect.lua | 26 + bench/micro_tests/test_string_lib.lua | 39 + bench/micro_tests/test_table_concat.lua | 27 + bench/tabulate.py | 91 + bench/tests/base64.lua | 81 + bench/tests/deltablue.lua | 934 +++ bench/tests/life.lua | 122 + bench/tests/qsort.lua | 79 + bench/tests/sha256.lua | 142 + bench/tests/shootout/ack.lua | 46 + bench/tests/shootout/binary-trees.lua | 82 + bench/tests/shootout/fannkuch-redux.lua | 78 + bench/tests/shootout/fixpoint-fact.lua | 56 + bench/tests/shootout/heapsort.lua | 79 + bench/tests/shootout/mandel.lua | 97 + bench/tests/shootout/n-body.lua | 131 + bench/tests/shootout/qt.lua | 334 + bench/tests/shootout/queen.lua | 76 + bench/tests/shootout/scimark.lua | 442 ++ bench/tests/shootout/spectral-norm.lua | 74 + bench/tests/sieve.lua | 44 + bench/tests/sunspider/3d-cube.lua | 381 + bench/tests/sunspider/3d-morph.lua | 75 + bench/tests/sunspider/3d-raytrace.lua | 502 ++ bench/tests/sunspider/access-binary-trees.lua | 69 + .../tests/sunspider/controlflow-recursive.lua | 42 + bench/tests/sunspider/crypto-aes.lua | 362 + bench/tests/sunspider/fannkuch.lua | 90 + bench/tests/sunspider/math-cordic.lua | 104 + bench/tests/sunspider/math-partial-sums.lua | 53 + bench/tests/sunspider/math-spectral-norm.lua | 72 + bench/tests/sunspider/n-body-oop.lua | 169 + bench/tests/tictactoe.lua | 228 + bench/tests/trig.lua | 71 + extern/.clang-format | 2 + extern/doctest.h | 6580 +++++++++++++++++ extern/linenoise.hpp | 2415 ++++++ fuzz/basic.lua | 7 + fuzz/compiler.cpp | 10 + fuzz/format.cpp | 20 + fuzz/linter.cpp | 39 + fuzz/luau.proto | 342 + fuzz/parser.cpp | 15 + fuzz/proto.cpp | 266 + fuzz/protoprint.cpp | 951 +++ fuzz/prototest.cpp | 12 + fuzz/syntax.dict | 21 + fuzz/transpiler.cpp | 10 + fuzz/typeck.cpp | 52 + lua_LICENSE.txt | 19 + tests/AstQuery.test.cpp | 81 + tests/AstVisitor.test.cpp | 117 + tests/Autocomplete.test.cpp | 2576 +++++++ tests/BuiltinDefinitions.test.cpp | 45 + tests/Compiler.test.cpp | 3662 +++++++++ tests/Config.test.cpp | 156 + tests/Conformance.test.cpp | 804 ++ tests/Error.test.cpp | 16 + tests/Fixture.cpp | 431 ++ tests/Fixture.h | 205 + tests/Frontend.test.cpp | 965 +++ tests/IostreamOptional.h | 18 + tests/JsonEncoder.test.cpp | 53 + tests/Linter.test.cpp | 1494 ++++ tests/Module.test.cpp | 267 + tests/NonstrictMode.test.cpp | 282 + tests/Parser.test.cpp | 2522 +++++++ tests/Predicate.test.cpp | 123 + tests/RequireTracer.test.cpp | 221 + tests/ScopedFlags.h | 59 + tests/StringUtils.test.cpp | 109 + tests/Symbol.test.cpp | 40 + tests/ToString.test.cpp | 482 ++ tests/TopoSort.test.cpp | 413 ++ tests/Transpiler.test.cpp | 403 + tests/TypeInfer.annotations.test.cpp | 728 ++ tests/TypeInfer.builtins.test.cpp | 847 +++ tests/TypeInfer.classes.test.cpp | 456 ++ tests/TypeInfer.definitions.test.cpp | 300 + tests/TypeInfer.generics.test.cpp | 698 ++ tests/TypeInfer.intersectionTypes.test.cpp | 344 + tests/TypeInfer.provisional.test.cpp | 588 ++ tests/TypeInfer.refinements.test.cpp | 1160 +++ tests/TypeInfer.tables.test.cpp | 1827 +++++ tests/TypeInfer.test.cpp | 5306 +++++++++++++ tests/TypeInfer.tryUnify.test.cpp | 207 + tests/TypeInfer.typePacks.cpp | 297 + tests/TypeInfer.unionTypes.test.cpp | 464 ++ tests/TypePack.test.cpp | 201 + tests/TypeVar.test.cpp | 267 + tests/Variant.test.cpp | 178 + tests/conformance/apicalls.lua | 8 + tests/conformance/assert.lua | 34 + tests/conformance/attrib.lua | 106 + tests/conformance/basic.lua | 879 +++ tests/conformance/bitwise.lua | 140 + tests/conformance/calls.lua | 229 + tests/conformance/clear.lua | 79 + tests/conformance/closure.lua | 429 ++ tests/conformance/constructs.lua | 240 + tests/conformance/coroutine.lua | 322 + tests/conformance/datetime.lua | 77 + tests/conformance/debug.lua | 101 + tests/conformance/debugger.lua | 48 + tests/conformance/errors.lua | 296 + tests/conformance/events.lua | 389 + tests/conformance/exceptions.lua | 35 + tests/conformance/gc.lua | 294 + tests/conformance/ifelseexpr.lua | 80 + tests/conformance/literals.lua | 180 + tests/conformance/locals.lua | 127 + tests/conformance/math.lua | 296 + tests/conformance/move.lua | 77 + tests/conformance/nextvar.lua | 515 ++ tests/conformance/pcall.lua | 147 + tests/conformance/pm.lua | 317 + tests/conformance/sort.lua | 75 + tests/conformance/strings.lua | 185 + tests/conformance/tpack.lua | 327 + tests/conformance/types.lua | 56 + tests/conformance/utf8.lua | 208 + tests/conformance/vararg.lua | 137 + tests/conformance/vector.lua | 74 + tests/main.cpp | 267 + tools/gdb-printers.py | 19 + tools/heapgraph.py | 182 + tools/heapstat.py | 59 + tools/perfgraph.py | 52 + tools/svg.py | 498 ++ 336 files changed, 118239 insertions(+) create mode 100644 .clang-format create mode 100644 Analysis/include/Luau/AstQuery.h create mode 100644 Analysis/include/Luau/Autocomplete.h create mode 100644 Analysis/include/Luau/BuiltinDefinitions.h create mode 100644 Analysis/include/Luau/Config.h create mode 100644 Analysis/include/Luau/Documentation.h create mode 100644 Analysis/include/Luau/Error.h create mode 100644 Analysis/include/Luau/FileResolver.h create mode 100644 Analysis/include/Luau/Frontend.h create mode 100644 Analysis/include/Luau/IostreamHelpers.h create mode 100644 Analysis/include/Luau/JsonEncoder.h create mode 100644 Analysis/include/Luau/Linter.h create mode 100644 Analysis/include/Luau/Module.h create mode 100644 Analysis/include/Luau/ModuleResolver.h create mode 100644 Analysis/include/Luau/Predicate.h create mode 100644 Analysis/include/Luau/RecursionCounter.h create mode 100644 Analysis/include/Luau/RequireTracer.h create mode 100644 Analysis/include/Luau/Substitution.h create mode 100644 Analysis/include/Luau/Symbol.h create mode 100644 Analysis/include/Luau/ToString.h create mode 100644 Analysis/include/Luau/TopoSortStatements.h create mode 100644 Analysis/include/Luau/Transpiler.h create mode 100644 Analysis/include/Luau/TxnLog.h create mode 100644 Analysis/include/Luau/TypeAttach.h create mode 100644 Analysis/include/Luau/TypeInfer.h create mode 100644 Analysis/include/Luau/TypePack.h create mode 100644 Analysis/include/Luau/TypeUtils.h create mode 100644 Analysis/include/Luau/TypeVar.h create mode 100644 Analysis/include/Luau/TypedAllocator.h create mode 100644 Analysis/include/Luau/Unifiable.h create mode 100644 Analysis/include/Luau/Unifier.h create mode 100644 Analysis/include/Luau/Variant.h create mode 100644 Analysis/include/Luau/VisitTypeVar.h create mode 100644 Analysis/src/AstQuery.cpp create mode 100644 Analysis/src/Autocomplete.cpp create mode 100644 Analysis/src/BuiltinDefinitions.cpp create mode 100644 Analysis/src/Config.cpp create mode 100644 Analysis/src/EmbeddedBuiltinDefinitions.cpp create mode 100644 Analysis/src/Error.cpp create mode 100644 Analysis/src/Frontend.cpp create mode 100644 Analysis/src/IostreamHelpers.cpp create mode 100644 Analysis/src/JsonEncoder.cpp create mode 100644 Analysis/src/Linter.cpp create mode 100644 Analysis/src/Module.cpp create mode 100644 Analysis/src/Predicate.cpp create mode 100644 Analysis/src/RequireTracer.cpp create mode 100644 Analysis/src/Substitution.cpp create mode 100644 Analysis/src/Symbol.cpp create mode 100644 Analysis/src/ToString.cpp create mode 100644 Analysis/src/TopoSortStatements.cpp create mode 100644 Analysis/src/Transpiler.cpp create mode 100644 Analysis/src/TxnLog.cpp create mode 100644 Analysis/src/TypeAttach.cpp create mode 100644 Analysis/src/TypeInfer.cpp create mode 100644 Analysis/src/TypePack.cpp create mode 100644 Analysis/src/TypeUtils.cpp create mode 100644 Analysis/src/TypeVar.cpp create mode 100644 Analysis/src/TypedAllocator.cpp create mode 100644 Analysis/src/Unifiable.cpp create mode 100644 Analysis/src/Unifier.cpp create mode 100644 Ast/include/Luau/Ast.h create mode 100644 Ast/include/Luau/Common.h create mode 100644 Ast/include/Luau/Confusables.h create mode 100644 Ast/include/Luau/DenseHash.h create mode 100644 Ast/include/Luau/Lexer.h create mode 100644 Ast/include/Luau/Location.h create mode 100644 Ast/include/Luau/ParseOptions.h create mode 100644 Ast/include/Luau/Parser.h create mode 100644 Ast/include/Luau/StringUtils.h create mode 100644 Ast/src/Ast.cpp create mode 100644 Ast/src/Confusables.cpp create mode 100644 Ast/src/Lexer.cpp create mode 100644 Ast/src/Location.cpp create mode 100644 Ast/src/Parser.cpp create mode 100644 Ast/src/StringUtils.cpp create mode 100644 CLI/Analyze.cpp create mode 100644 CLI/FileUtils.cpp create mode 100644 CLI/FileUtils.h create mode 100644 CLI/Profiler.cpp create mode 100644 CLI/Profiler.h create mode 100644 CLI/Repl.cpp create mode 100644 CMakeLists.txt create mode 100644 Compiler/include/Luau/Bytecode.h create mode 100644 Compiler/include/Luau/BytecodeBuilder.h create mode 100644 Compiler/include/Luau/Compiler.h create mode 100644 Compiler/src/BytecodeBuilder.cpp create mode 100644 Compiler/src/Compiler.cpp create mode 100644 LICENSE.txt create mode 100644 Makefile create mode 100644 Sources.cmake create mode 100644 VM/include/lua.h create mode 100644 VM/include/luaconf.h create mode 100644 VM/include/lualib.h create mode 100644 VM/src/lapi.cpp create mode 100644 VM/src/lapi.h create mode 100644 VM/src/laux.cpp create mode 100644 VM/src/lbaselib.cpp create mode 100644 VM/src/lbitlib.cpp create mode 100644 VM/src/lbuiltins.cpp create mode 100644 VM/src/lbuiltins.h create mode 100644 VM/src/lbytecode.h create mode 100644 VM/src/lcommon.h create mode 100644 VM/src/lcorolib.cpp create mode 100644 VM/src/ldblib.cpp create mode 100644 VM/src/ldebug.cpp create mode 100644 VM/src/ldebug.h create mode 100644 VM/src/ldo.cpp create mode 100644 VM/src/ldo.h create mode 100644 VM/src/lfunc.cpp create mode 100644 VM/src/lfunc.h create mode 100644 VM/src/lgc.cpp create mode 100644 VM/src/lgc.h create mode 100644 VM/src/linit.cpp create mode 100644 VM/src/lmathlib.cpp create mode 100644 VM/src/lmem.cpp create mode 100644 VM/src/lmem.h create mode 100644 VM/src/lnumutils.h create mode 100644 VM/src/lobject.cpp create mode 100644 VM/src/lobject.h create mode 100644 VM/src/loslib.cpp create mode 100644 VM/src/lperf.cpp create mode 100644 VM/src/lstate.cpp create mode 100644 VM/src/lstate.h create mode 100644 VM/src/lstring.cpp create mode 100644 VM/src/lstring.h create mode 100644 VM/src/lstrlib.cpp create mode 100644 VM/src/ltable.cpp create mode 100644 VM/src/ltable.h create mode 100644 VM/src/ltablib.cpp create mode 100644 VM/src/ltm.cpp create mode 100644 VM/src/ltm.h create mode 100644 VM/src/lutf8lib.cpp create mode 100644 VM/src/lvm.h create mode 100644 VM/src/lvmexecute.cpp create mode 100644 VM/src/lvmload.cpp create mode 100644 VM/src/lvmutils.cpp create mode 100644 bench/bench.py create mode 100644 bench/bench_support.lua create mode 100644 bench/color.py create mode 100644 bench/gc/test_BinaryTree.lua create mode 100644 bench/gc/test_GC_Boehm_Trees.lua create mode 100644 bench/gc/test_GC_Tree_Pruning_Eager.lua create mode 100644 bench/gc/test_GC_Tree_Pruning_Gen.lua create mode 100644 bench/gc/test_GC_Tree_Pruning_Lazy.lua create mode 100644 bench/gc/test_GC_hashtable_Keyval.lua create mode 100644 bench/gc/test_LB_mandel.lua create mode 100644 bench/gc/test_LargeTableCtor_array.lua create mode 100644 bench/gc/test_LargeTableCtor_hash.lua create mode 100644 bench/gc/test_Pcall_pcall_yield.lua create mode 100644 bench/gc/test_SunSpider_3d-raytrace.lua create mode 100644 bench/gc/test_SunSpider_crypto-aes.lua create mode 100644 bench/gc/test_TableCreate_nil.lua create mode 100644 bench/gc/test_TableCreate_number.lua create mode 100644 bench/gc/test_TableCreate_zerofill.lua create mode 100644 bench/gc/test_TableMarshal_select.lua create mode 100644 bench/gc/test_TableMarshal_table_pack.lua create mode 100644 bench/gc/test_TableMarshal_varargs.lua create mode 100644 bench/influxbench.py create mode 100644 bench/install.bat create mode 100644 bench/install.sh create mode 100644 bench/micro_tests/test_AbsSum_abs.lua create mode 100644 bench/micro_tests/test_AbsSum_and_or.lua create mode 100644 bench/micro_tests/test_AbsSum_math_abs.lua create mode 100644 bench/micro_tests/test_Assert.lua create mode 100644 bench/micro_tests/test_Factorial.lua create mode 100644 bench/micro_tests/test_Failure_pcall_a_bar.lua create mode 100644 bench/micro_tests/test_Failure_pcall_game_Foo.lua create mode 100644 bench/micro_tests/test_Failure_xpcall_a_bar.lua create mode 100644 bench/micro_tests/test_Failure_xpcall_game_Foo.lua create mode 100644 bench/micro_tests/test_LargeTableCtor_array.lua create mode 100644 bench/micro_tests/test_LargeTableCtor_hash.lua create mode 100644 bench/micro_tests/test_LargeTableSum_loop_index.lua create mode 100644 bench/micro_tests/test_LargeTableSum_loop_ipairs.lua create mode 100644 bench/micro_tests/test_LargeTableSum_loop_pairs.lua create mode 100644 bench/micro_tests/test_MethodCalls.lua create mode 100644 bench/micro_tests/test_OOP_constructor.lua create mode 100644 bench/micro_tests/test_OOP_method_call.lua create mode 100644 bench/micro_tests/test_OOP_virtual_constructor.lua create mode 100644 bench/micro_tests/test_Pcall_call_return.lua create mode 100644 bench/micro_tests/test_Pcall_pcall_return.lua create mode 100644 bench/micro_tests/test_Pcall_pcall_yield.lua create mode 100644 bench/micro_tests/test_Pcall_xpcall_return.lua create mode 100644 bench/micro_tests/test_SqrtSum_exponent.lua create mode 100644 bench/micro_tests/test_SqrtSum_math_sqrt.lua create mode 100644 bench/micro_tests/test_SqrtSum_sqrt.lua create mode 100644 bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua create mode 100644 bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua create mode 100644 bench/micro_tests/test_TableCreate_nil.lua create mode 100644 bench/micro_tests/test_TableCreate_number.lua create mode 100644 bench/micro_tests/test_TableCreate_zerofill.lua create mode 100644 bench/micro_tests/test_TableFind_loop_ipairs.lua create mode 100644 bench/micro_tests/test_TableFind_table_find.lua create mode 100644 bench/micro_tests/test_TableInsertion_index_cached.lua create mode 100644 bench/micro_tests/test_TableInsertion_index_len.lua create mode 100644 bench/micro_tests/test_TableInsertion_table_insert.lua create mode 100644 bench/micro_tests/test_TableInsertion_table_insert_index.lua create mode 100644 bench/micro_tests/test_TableIteration.lua create mode 100644 bench/micro_tests/test_TableMarshal_select.lua create mode 100644 bench/micro_tests/test_TableMarshal_table_pack.lua create mode 100644 bench/micro_tests/test_TableMarshal_table_unpack_array.lua create mode 100644 bench/micro_tests/test_TableMarshal_table_unpack_range.lua create mode 100644 bench/micro_tests/test_TableMarshal_varargs.lua create mode 100644 bench/micro_tests/test_TableMove_empty_table.lua create mode 100644 bench/micro_tests/test_TableMove_same_table.lua create mode 100644 bench/micro_tests/test_TableMove_table_create.lua create mode 100644 bench/micro_tests/test_TableRemoval_table_remove.lua create mode 100644 bench/micro_tests/test_UpvalueCapture.lua create mode 100644 bench/micro_tests/test_VariadicSelect.lua create mode 100644 bench/micro_tests/test_string_lib.lua create mode 100644 bench/micro_tests/test_table_concat.lua create mode 100644 bench/tabulate.py create mode 100644 bench/tests/base64.lua create mode 100644 bench/tests/deltablue.lua create mode 100644 bench/tests/life.lua create mode 100644 bench/tests/qsort.lua create mode 100644 bench/tests/sha256.lua create mode 100644 bench/tests/shootout/ack.lua create mode 100644 bench/tests/shootout/binary-trees.lua create mode 100644 bench/tests/shootout/fannkuch-redux.lua create mode 100644 bench/tests/shootout/fixpoint-fact.lua create mode 100644 bench/tests/shootout/heapsort.lua create mode 100644 bench/tests/shootout/mandel.lua create mode 100644 bench/tests/shootout/n-body.lua create mode 100644 bench/tests/shootout/qt.lua create mode 100644 bench/tests/shootout/queen.lua create mode 100644 bench/tests/shootout/scimark.lua create mode 100644 bench/tests/shootout/spectral-norm.lua create mode 100644 bench/tests/sieve.lua create mode 100644 bench/tests/sunspider/3d-cube.lua create mode 100644 bench/tests/sunspider/3d-morph.lua create mode 100644 bench/tests/sunspider/3d-raytrace.lua create mode 100644 bench/tests/sunspider/access-binary-trees.lua create mode 100644 bench/tests/sunspider/controlflow-recursive.lua create mode 100644 bench/tests/sunspider/crypto-aes.lua create mode 100644 bench/tests/sunspider/fannkuch.lua create mode 100644 bench/tests/sunspider/math-cordic.lua create mode 100644 bench/tests/sunspider/math-partial-sums.lua create mode 100644 bench/tests/sunspider/math-spectral-norm.lua create mode 100644 bench/tests/sunspider/n-body-oop.lua create mode 100644 bench/tests/tictactoe.lua create mode 100644 bench/tests/trig.lua create mode 100644 extern/.clang-format create mode 100644 extern/doctest.h create mode 100644 extern/linenoise.hpp create mode 100644 fuzz/basic.lua create mode 100644 fuzz/compiler.cpp create mode 100644 fuzz/format.cpp create mode 100644 fuzz/linter.cpp create mode 100644 fuzz/luau.proto create mode 100644 fuzz/parser.cpp create mode 100644 fuzz/proto.cpp create mode 100644 fuzz/protoprint.cpp create mode 100644 fuzz/prototest.cpp create mode 100644 fuzz/syntax.dict create mode 100644 fuzz/transpiler.cpp create mode 100644 fuzz/typeck.cpp create mode 100644 lua_LICENSE.txt create mode 100644 tests/AstQuery.test.cpp create mode 100644 tests/AstVisitor.test.cpp create mode 100644 tests/Autocomplete.test.cpp create mode 100644 tests/BuiltinDefinitions.test.cpp create mode 100644 tests/Compiler.test.cpp create mode 100644 tests/Config.test.cpp create mode 100644 tests/Conformance.test.cpp create mode 100644 tests/Error.test.cpp create mode 100644 tests/Fixture.cpp create mode 100644 tests/Fixture.h create mode 100644 tests/Frontend.test.cpp create mode 100644 tests/IostreamOptional.h create mode 100644 tests/JsonEncoder.test.cpp create mode 100644 tests/Linter.test.cpp create mode 100644 tests/Module.test.cpp create mode 100644 tests/NonstrictMode.test.cpp create mode 100644 tests/Parser.test.cpp create mode 100644 tests/Predicate.test.cpp create mode 100644 tests/RequireTracer.test.cpp create mode 100644 tests/ScopedFlags.h create mode 100644 tests/StringUtils.test.cpp create mode 100644 tests/Symbol.test.cpp create mode 100644 tests/ToString.test.cpp create mode 100644 tests/TopoSort.test.cpp create mode 100644 tests/Transpiler.test.cpp create mode 100644 tests/TypeInfer.annotations.test.cpp create mode 100644 tests/TypeInfer.builtins.test.cpp create mode 100644 tests/TypeInfer.classes.test.cpp create mode 100644 tests/TypeInfer.definitions.test.cpp create mode 100644 tests/TypeInfer.generics.test.cpp create mode 100644 tests/TypeInfer.intersectionTypes.test.cpp create mode 100644 tests/TypeInfer.provisional.test.cpp create mode 100644 tests/TypeInfer.refinements.test.cpp create mode 100644 tests/TypeInfer.tables.test.cpp create mode 100644 tests/TypeInfer.test.cpp create mode 100644 tests/TypeInfer.tryUnify.test.cpp create mode 100644 tests/TypeInfer.typePacks.cpp create mode 100644 tests/TypeInfer.unionTypes.test.cpp create mode 100644 tests/TypePack.test.cpp create mode 100644 tests/TypeVar.test.cpp create mode 100644 tests/Variant.test.cpp create mode 100644 tests/conformance/apicalls.lua create mode 100644 tests/conformance/assert.lua create mode 100644 tests/conformance/attrib.lua create mode 100644 tests/conformance/basic.lua create mode 100644 tests/conformance/bitwise.lua create mode 100644 tests/conformance/calls.lua create mode 100644 tests/conformance/clear.lua create mode 100644 tests/conformance/closure.lua create mode 100644 tests/conformance/constructs.lua create mode 100644 tests/conformance/coroutine.lua create mode 100644 tests/conformance/datetime.lua create mode 100644 tests/conformance/debug.lua create mode 100644 tests/conformance/debugger.lua create mode 100644 tests/conformance/errors.lua create mode 100644 tests/conformance/events.lua create mode 100644 tests/conformance/exceptions.lua create mode 100644 tests/conformance/gc.lua create mode 100644 tests/conformance/ifelseexpr.lua create mode 100644 tests/conformance/literals.lua create mode 100644 tests/conformance/locals.lua create mode 100644 tests/conformance/math.lua create mode 100644 tests/conformance/move.lua create mode 100644 tests/conformance/nextvar.lua create mode 100644 tests/conformance/pcall.lua create mode 100644 tests/conformance/pm.lua create mode 100644 tests/conformance/sort.lua create mode 100644 tests/conformance/strings.lua create mode 100644 tests/conformance/tpack.lua create mode 100644 tests/conformance/types.lua create mode 100644 tests/conformance/utf8.lua create mode 100644 tests/conformance/vararg.lua create mode 100644 tests/conformance/vector.lua create mode 100644 tests/main.cpp create mode 100644 tools/gdb-printers.py create mode 100644 tools/heapgraph.py create mode 100644 tools/heapstat.py create mode 100644 tools/perfgraph.py create mode 100644 tools/svg.py diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..5f533ac --- /dev/null +++ b/.clang-format @@ -0,0 +1,25 @@ +BasedOnStyle: LLVM + +AccessModifierOffset: -4 +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortLambdasOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: Never +AllowShortLoopsOnASingleLine: false +BreakBeforeBraces: Allman +BreakConstructorInitializers: BeforeComma +BreakInheritanceList: BeforeComma +ColumnLimit: 150 +IndentCaseLabels: false +SortIncludes: false +IndentWidth: 4 +TabWidth: 4 +ObjCBlockIndentWidth: 4 +AlignAfterOpenBracket: DontAlign +UseTab: Never +PointerAlignment: Left +SpaceAfterTemplateKeyword: false +AlignEscapedNewlines: DontAlign +AlwaysBreakTemplateDeclarations: Yes +MaxEmptyLinesToKeep: 10 diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h new file mode 100644 index 0000000..d38976e --- /dev/null +++ b/Analysis/include/Luau/AstQuery.h @@ -0,0 +1,63 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Documentation.h" + +#include + +namespace Luau +{ + +struct Binding; +struct SourceModule; +struct Module; + +struct TypeVar; +using TypeId = const TypeVar*; + +using ScopePtr = std::shared_ptr; + +struct ExprOrLocal +{ + AstExpr* getExpr() + { + return expr; + } + AstLocal* getLocal() + { + return local; + } + void setExpr(AstExpr* newExpr) + { + expr = newExpr; + local = nullptr; + } + void setLocal(AstLocal* newLocal) + { + local = newLocal; + expr = nullptr; + } + std::optional getLocation() + { + return expr ? expr->location : (local ? local->location : std::optional{}); + } + +private: + AstExpr* expr = nullptr; + AstLocal* local = nullptr; +}; + +std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos); +AstNode* findNodeAtPosition(const SourceModule& source, Position pos); +AstExpr* findExprAtPosition(const SourceModule& source, Position pos); +ScopePtr findScopeAtPosition(const Module& module, Position pos); +std::optional findBindingAtPosition(const Module& module, const SourceModule& source, Position pos); +ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos); + +std::optional findTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos); +std::optional findExpectedTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos); + +std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position); + +} // namespace Luau diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h new file mode 100644 index 0000000..5853429 --- /dev/null +++ b/Analysis/include/Luau/Autocomplete.h @@ -0,0 +1,91 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/TypeVar.h" + +#include +#include +#include +#include + +namespace Luau +{ + +struct Frontend; +struct SourceModule; +struct Module; +struct TypeChecker; + +using ModulePtr = std::shared_ptr; + +enum class AutocompleteEntryKind +{ + Property, + Binding, + Keyword, + String, + Type, + Module, +}; + +enum class ParenthesesRecommendation +{ + None, + CursorAfter, + CursorInside, +}; + +enum class TypeCorrectKind +{ + None, + Correct, + CorrectFunctionResult, +}; + +struct AutocompleteEntry +{ + AutocompleteEntryKind kind = AutocompleteEntryKind::Property; + // Nullopt if kind is Keyword + std::optional type = std::nullopt; + bool deprecated = false; + // Only meaningful if kind is Property. + bool wrongIndexType = false; + // Set if this suggestion matches the type expected in the context + TypeCorrectKind typeCorrect = TypeCorrectKind::None; + + std::optional containingClass = std::nullopt; + std::optional prop = std::nullopt; + std::optional documentationSymbol = std::nullopt; + Tags tags; + ParenthesesRecommendation parens = ParenthesesRecommendation::None; +}; + +using AutocompleteEntryMap = std::unordered_map; +struct AutocompleteResult +{ + AutocompleteEntryMap entryMap; + std::vector ancestry; + + AutocompleteResult() = default; + AutocompleteResult(AutocompleteEntryMap entryMap, std::vector ancestry) + : entryMap(std::move(entryMap)) + , ancestry(std::move(ancestry)) + { + } +}; + +using ModuleName = std::string; +using StringCompletionCallback = std::function(std::string tag, std::optional ctx)>; + +struct OwningAutocompleteResult +{ + AutocompleteResult result; + ModulePtr module; + std::unique_ptr sourceModule; +}; + +AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); +OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback); + +} // namespace Luau diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h new file mode 100644 index 0000000..8f17fff --- /dev/null +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -0,0 +1,51 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "TypeInfer.h" + +namespace Luau +{ + +void registerBuiltinTypes(TypeChecker& typeChecker); + +TypeId makeUnion(TypeArena& arena, std::vector&& types); +TypeId makeIntersection(TypeArena& arena, std::vector&& types); + +/** Build an optional 't' + */ +TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t); + +/** Small utility function for building up type definitions from C++. + */ +TypeId makeFunction( // Monomorphic + TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes); + +TypeId makeFunction( // Polymorphic + TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, + std::initializer_list paramTypes, std::initializer_list retTypes); + +TypeId makeFunction( // Monomorphic + TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list paramNames, + std::initializer_list retTypes); + +TypeId makeFunction( // Polymorphic + TypeArena& arena, std::optional selfType, std::initializer_list generics, std::initializer_list genericPacks, + std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); + +void attachMagicFunction(TypeId ty, MagicFunction fn); +void attachFunctionTag(TypeId ty, std::string constraint); + +Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); +void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); + +std::string getBuiltinDefinitionSource(); + +void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding); +void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding); +std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name); +Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name); +TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name); + +} // namespace Luau diff --git a/Analysis/include/Luau/Config.h b/Analysis/include/Luau/Config.h new file mode 100644 index 0000000..56cdfe7 --- /dev/null +++ b/Analysis/include/Luau/Config.h @@ -0,0 +1,58 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Linter.h" +#include "Luau/ParseOptions.h" + +#include +#include +#include + +namespace Luau +{ + +using ModuleName = std::string; + +constexpr const char* kConfigName = ".luaurc"; + +struct Config +{ + Config() + { + enabledLint.setDefaults(); + } + + Mode mode = Mode::NoCheck; + + ParseOptions parseOptions; + + LintOptions enabledLint; + LintOptions fatalLint; + + bool lintErrors = false; + bool typeErrors = true; + + std::vector globals; +}; + +struct ConfigResolver +{ + virtual ~ConfigResolver() {} + + virtual const Config& getConfig(const ModuleName& name) const = 0; +}; + +struct NullConfigResolver : ConfigResolver +{ + Config defaultConfig; + + virtual const Config& getConfig(const ModuleName& name) const override; +}; + +std::optional parseModeString(Mode& mode, const std::string& modeString, bool compat = false); +std::optional parseLintRuleString( + LintOptions& enabledLints, LintOptions& fatalLints, const std::string& warningName, const std::string& value, bool compat = false); + +std::optional parseConfig(const std::string& contents, Config& config, bool compat = false); + +} // namespace Luau diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h new file mode 100644 index 0000000..7b609b4 --- /dev/null +++ b/Analysis/include/Luau/Documentation.h @@ -0,0 +1,50 @@ +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/Variant.h" + +#include +#include + +namespace Luau +{ + +struct FunctionDocumentation; +struct TableDocumentation; +struct OverloadedFunctionDocumentation; + +using Documentation = Luau::Variant; +using DocumentationSymbol = std::string; + +struct FunctionParameterDocumentation +{ + std::string name; + DocumentationSymbol documentation; +}; + +// Represents documentation for anything callable. This could be a method or a +// callback or a free function. +struct FunctionDocumentation +{ + std::string documentation; + std::vector parameters; + std::vector returns; +}; + +struct OverloadedFunctionDocumentation +{ + // This is a map of function signature to overload symbol name. + Luau::DenseHashMap overloads; +}; + +// Represents documentation for a table-like item, meaning "anything with keys". +// This could be a table or a class. +struct TableDocumentation +{ + std::string documentation; + Luau::DenseHashMap keys; +}; + +using DocumentationDatabase = Luau::DenseHashMap; + +} // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h new file mode 100644 index 0000000..946bc92 --- /dev/null +++ b/Analysis/include/Luau/Error.h @@ -0,0 +1,332 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/FileResolver.h" +#include "Luau/Location.h" +#include "Luau/TypeVar.h" +#include "Luau/Variant.h" + +namespace Luau +{ + +struct TypeMismatch +{ + TypeId wantedType; + TypeId givenType; + + bool operator==(const TypeMismatch& rhs) const; +}; + +struct UnknownSymbol +{ + enum Context + { + Binding, + Type, + Generic + }; + Name name; + Context context; + + bool operator==(const UnknownSymbol& rhs) const; +}; + +struct UnknownProperty +{ + TypeId table; + Name key; + + bool operator==(const UnknownProperty& rhs) const; +}; + +struct NotATable +{ + TypeId ty; + + bool operator==(const NotATable& rhs) const; +}; + +struct CannotExtendTable +{ + enum Context + { + Property, + Indexer, + Metatable + }; + TypeId tableType; + Context context; + Name prop; + + bool operator==(const CannotExtendTable& rhs) const; +}; + +struct OnlyTablesCanHaveMethods +{ + TypeId tableType; + + bool operator==(const OnlyTablesCanHaveMethods& rhs) const; +}; + +struct DuplicateTypeDefinition +{ + Name name; + Location previousLocation; + + bool operator==(const DuplicateTypeDefinition& rhs) const; +}; + +struct CountMismatch +{ + enum Context + { + Arg, + Result, + Return, + }; + size_t expected; + size_t actual; + Context context = Arg; + + bool operator==(const CountMismatch& rhs) const; +}; + +struct FunctionDoesNotTakeSelf +{ + bool operator==(const FunctionDoesNotTakeSelf& rhs) const; +}; + +struct FunctionRequiresSelf +{ + int requiredExtraNils = 0; + + bool operator==(const FunctionRequiresSelf& rhs) const; +}; + +struct OccursCheckFailed +{ + bool operator==(const OccursCheckFailed& rhs) const; +}; + +struct UnknownRequire +{ + std::string modulePath; + + bool operator==(const UnknownRequire& rhs) const; +}; + +struct IncorrectGenericParameterCount +{ + Name name; + TypeFun typeFun; + size_t actualParameters; + + bool operator==(const IncorrectGenericParameterCount& rhs) const; +}; + +struct SyntaxError +{ + std::string message; + + bool operator==(const SyntaxError& rhs) const; +}; + +struct CodeTooComplex +{ + bool operator==(const CodeTooComplex&) const; +}; + +struct UnificationTooComplex +{ + bool operator==(const UnificationTooComplex&) const; +}; + +// Could easily be folded into UnknownProperty with an extra field, std::set candidates. +// But for telemetry purposes, we want to have this be a distinct variant. +struct UnknownPropButFoundLikeProp +{ + TypeId table; + Name key; + std::set candidates; + + bool operator==(const UnknownPropButFoundLikeProp& rhs) const; +}; + +struct GenericError +{ + std::string message; + + bool operator==(const GenericError& rhs) const; +}; + +struct CannotCallNonFunction +{ + TypeId ty; + + bool operator==(const CannotCallNonFunction& rhs) const; +}; + +struct ExtraInformation +{ + std::string message; + bool operator==(const ExtraInformation& rhs) const; +}; + +struct DeprecatedApiUsed +{ + std::string symbol; + std::string useInstead; + bool operator==(const DeprecatedApiUsed& rhs) const; +}; + +struct ModuleHasCyclicDependency +{ + std::vector cycle; + bool operator==(const ModuleHasCyclicDependency& rhs) const; +}; + +struct FunctionExitsWithoutReturning +{ + TypePackId expectedReturnType; + bool operator==(const FunctionExitsWithoutReturning& rhs) const; +}; + +struct IllegalRequire +{ + std::string moduleName; + std::string reason; + + bool operator==(const IllegalRequire& rhs) const; +}; + +struct MissingProperties +{ + enum Context + { + Missing, + Extra + }; + TypeId superType; + TypeId subType; + std::vector properties; + Context context = Missing; + + bool operator==(const MissingProperties& rhs) const; +}; + +struct DuplicateGenericParameter +{ + std::string parameterName; + + bool operator==(const DuplicateGenericParameter& rhs) const; +}; + +struct CannotInferBinaryOperation +{ + enum OpKind + { + Operation, + Comparison, + }; + + AstExprBinary::Op op; + std::optional suggestedToAnnotate; + OpKind kind; + + bool operator==(const CannotInferBinaryOperation& rhs) const; +}; + +struct SwappedGenericTypeParameter +{ + enum Kind + { + Type, + Pack, + }; + + std::string name; + // What was `name` being used as? + Kind kind; + + bool operator==(const SwappedGenericTypeParameter& rhs) const; +}; + +struct OptionalValueAccess +{ + TypeId optional; + + bool operator==(const OptionalValueAccess& rhs) const; +}; + +struct MissingUnionProperty +{ + TypeId type; + std::vector missing; + Name key; + + bool operator==(const MissingUnionProperty& rhs) const; +}; + +using TypeErrorData = Variant; + +struct TypeError +{ + Location location; + ModuleName moduleName; + TypeErrorData data; + + int code() const; + + TypeError() = default; + + TypeError(const Location& location, const ModuleName& moduleName, const TypeErrorData& data) + : location(location) + , moduleName(moduleName) + , data(data) + { + } + + TypeError(const Location& location, const TypeErrorData& data) + : TypeError(location, {}, data) + { + } + + bool operator==(const TypeError& rhs) const; +}; + +template +const T* get(const TypeError& e) +{ + return get_if(&e.data); +} + +template +T* get(TypeError& e) +{ + return get_if(&e.data); +} + +using ErrorVec = std::vector; + +std::string toString(const TypeError& error); + +bool containsParseErrorName(const TypeError& error); + +// Copy any types named in the error into destArena. +void copyErrors(ErrorVec& errors, struct TypeArena& destArena); + +// Internal Compiler Error +struct InternalErrorReporter +{ + std::function onInternalError; + std::string moduleName; + + [[noreturn]] void ice(const std::string& message, const Location& location); + [[noreturn]] void ice(const std::string& message); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h new file mode 100644 index 0000000..71f9464 --- /dev/null +++ b/Analysis/include/Luau/FileResolver.h @@ -0,0 +1,103 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +namespace Luau +{ + +class AstExpr; + +using ModuleName = std::string; + +struct SourceCode +{ + enum Type + { + None, + Module, + Script, + Local + }; + + std::string source; + Type type; +}; + +struct FileResolver +{ + virtual ~FileResolver() {} + + /** Fetch the source code associated with the provided ModuleName. + * + * FIXME: This requires a string copy! + * + * @returns The actual Lua code on success. + * @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error. + */ + virtual std::optional readSource(const ModuleName& name) = 0; + + /** Does the module exist? + * + * Saves a string copy over reading the source and throwing it away. + */ + virtual bool moduleExists(const ModuleName& name) const = 0; + + virtual std::optional fromAstFragment(AstExpr* expr) const = 0; + + /** Given a valid module name and a string of arbitrary data, figure out the concatenation. + */ + virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; + + /** Goes "up" a level in the hierarchy that the ModuleName represents. + * + * For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last + * element of the path. Other ModuleName representations may have other ways of doing this. + * + * @returns The parent ModuleName, if one exists. + * @returns std::nullopt if there is no parent for this module name. + */ + virtual std::optional getParentModuleName(const ModuleName& name) const = 0; + + virtual std::optional getHumanReadableModuleName_(const ModuleName& name) const + { + return name; + } + + virtual std::optional getEnvironmentForModule(const ModuleName& name) const = 0; + + /** LanguageService only: + * std::optional fromInstance(Instance* inst) + */ +}; + +struct NullFileResolver : FileResolver +{ + std::optional readSource(const ModuleName& name) override + { + return std::nullopt; + } + bool moduleExists(const ModuleName& name) const override + { + return false; + } + std::optional fromAstFragment(AstExpr* expr) const override + { + return std::nullopt; + } + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override + { + return lhs; + } + std::optional getParentModuleName(const ModuleName& name) const override + { + return std::nullopt; + } + std::optional getEnvironmentForModule(const ModuleName& name) const override + { + return std::nullopt; + } +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h new file mode 100644 index 0000000..07a0296 --- /dev/null +++ b/Analysis/include/Luau/Frontend.h @@ -0,0 +1,179 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Config.h" +#include "Luau/Module.h" +#include "Luau/ModuleResolver.h" +#include "Luau/RequireTracer.h" +#include "Luau/TypeInfer.h" +#include "Luau/Variant.h" + +#include +#include +#include + +namespace Luau +{ + +class AstStat; +class ParseError; +struct Frontend; +struct TypeError; +struct LintWarning; +struct TypeChecker; +struct FileResolver; +struct ModuleResolver; +struct ParseResult; + +struct LoadDefinitionFileResult +{ + bool success; + ParseResult parseResult; + ModulePtr module; +}; + +LoadDefinitionFileResult loadDefinitionFile( + TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); + +std::optional parseMode(const std::vector& hotcomments); + +std::vector parsePathExpr(const AstExpr& pathExpr); + +// Exported only for convenient testing. +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& expr); + +/** Try to convert an AST fragment into a ModuleName. + * Returns std::nullopt if the expression cannot be resolved. This will most likely happen in cases where + * the import path involves some dynamic computation that we cannot see into at typechecking time. + * + * Unintuitively, weirdly-formulated modules (like game.Parent.Parent.Parent.Foo) will successfully produce a ModuleName + * as long as it falls within the permitted syntax. This is ok because we will fail to find the module and produce an + * error when we try during typechecking. + */ +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& expr); + +struct SourceNode +{ + ModuleName name; + std::unordered_set requires; + std::vector> requireLocations; + bool dirty = true; +}; + +struct FrontendOptions +{ + // When true, we retain full type information about every term in the AST. + // Setting this to false cuts back on RAM and is a good idea for batch + // jobs where the type graph is not deeply inspected after typechecking + // is complete. + bool retainFullTypeGraphs = false; + + // When true, we run typechecking twice, one in the regular mode, ond once in strict mode + // in order to get more precise type information (e.g. for autocomplete). + bool typecheckTwice = false; +}; + +struct CheckResult +{ + std::vector errors; +}; + +struct FrontendModuleResolver : ModuleResolver +{ + FrontendModuleResolver(Frontend* frontend); + + const ModulePtr getModule(const ModuleName& moduleName) const override; + bool moduleExists(const ModuleName& moduleName) const override; + std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override; + std::string getHumanReadableModuleName(const ModuleName& moduleName) const override; + + Frontend* frontend; + std::unordered_map modules; +}; + +struct Frontend +{ + struct Stats + { + size_t files = 0; + size_t lines = 0; + + size_t filesStrict = 0; + size_t filesNonstrict = 0; + + double timeRead = 0; + double timeParse = 0; + double timeCheck = 0; + double timeLint = 0; + }; + + Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); + + CheckResult check(const ModuleName& name); // new shininess + LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); + + /** Lint some code that has no associated DataModel object + * + * Since this source fragment has no name, we cannot cache its AST. Instead, + * we return it to the caller to use as they wish. + */ + std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); + + CheckResult check(const SourceModule& module); // OLD. TODO KILL + LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); + + bool isDirty(const ModuleName& name) const; + void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); + + /** Borrow a pointer into the SourceModule cache. + * + * Returns nullptr if we don't have it. This could mean that the script + * doesn't exist, or simply that its contents have changed since the previous + * check, in which case we do not have its AST. + * + * IMPORTANT: this pointer is only valid until the next call to markDirty. Do not retain it. + */ + SourceModule* getSourceModule(const ModuleName& name); + const SourceModule* getSourceModule(const ModuleName& name) const; + + void clearStats(); + void clear(); + + ScopePtr addEnvironment(const std::string& environmentName); + ScopePtr getEnvironmentScope(const std::string& environmentName); + + void registerBuiltinDefinition(const std::string& name, std::function); + void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); + +private: + std::pair getSourceNode(CheckResult& checkResult, const ModuleName& name); + SourceModule parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions); + + bool parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root); + + static LintResult classifyLints(const std::vector& warnings, const Config& config); + + ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config); + + std::unordered_map environments; + std::unordered_map> builtinDefinitions; + +public: + FileResolver* fileResolver; + FrontendModuleResolver moduleResolver; + FrontendModuleResolver moduleResolverForAutocomplete; + TypeChecker typeChecker; + TypeChecker typeCheckerForAutocomplete; + ConfigResolver* configResolver; + FrontendOptions options; + InternalErrorReporter iceHandler; + TypeArena arenaForAutocomplete; + + std::unordered_map sourceNodes; + std::unordered_map sourceModules; + std::unordered_map requires; + + Stats stats = {}; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h new file mode 100644 index 0000000..f9e9cd4 --- /dev/null +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -0,0 +1,46 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Error.h" +#include "Luau/Location.h" +#include "Luau/TypeVar.h" +#include "Luau/Ast.h" + +#include + +namespace Luau +{ + +std::ostream& operator<<(std::ostream& lhs, const Position& position); +std::ostream& operator<<(std::ostream& lhs, const Location& location); +std::ostream& operator<<(std::ostream& lhs, const AstName& name); + +std::ostream& operator<<(std::ostream& lhs, const TypeError& error); +std::ostream& operator<<(std::ostream& lhs, const TypeMismatch& error); +std::ostream& operator<<(std::ostream& lhs, const UnknownSymbol& error); +std::ostream& operator<<(std::ostream& lhs, const UnknownProperty& error); +std::ostream& operator<<(std::ostream& lhs, const NotATable& error); +std::ostream& operator<<(std::ostream& lhs, const CannotExtendTable& error); +std::ostream& operator<<(std::ostream& lhs, const OnlyTablesCanHaveMethods& error); +std::ostream& operator<<(std::ostream& lhs, const DuplicateTypeDefinition& error); +std::ostream& operator<<(std::ostream& lhs, const CountMismatch& error); +std::ostream& operator<<(std::ostream& lhs, const FunctionDoesNotTakeSelf& error); +std::ostream& operator<<(std::ostream& lhs, const FunctionRequiresSelf& error); +std::ostream& operator<<(std::ostream& lhs, const OccursCheckFailed& error); +std::ostream& operator<<(std::ostream& lhs, const UnknownRequire& error); +std::ostream& operator<<(std::ostream& lhs, const UnknownPropButFoundLikeProp& e); +std::ostream& operator<<(std::ostream& lhs, const GenericError& error); +std::ostream& operator<<(std::ostream& lhs, const FunctionExitsWithoutReturning& error); +std::ostream& operator<<(std::ostream& lhs, const MissingProperties& error); +std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); +std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error); +std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error); +std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error); + +std::ostream& operator<<(std::ostream& lhs, const TableState& tv); +std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); +std::ostream& operator<<(std::ostream& lhs, const TypePackVar& tv); + +std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted); + +} // namespace Luau diff --git a/Analysis/include/Luau/JsonEncoder.h b/Analysis/include/Luau/JsonEncoder.h new file mode 100644 index 0000000..aa00390 --- /dev/null +++ b/Analysis/include/Luau/JsonEncoder.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +class AstNode; + +std::string toJson(AstNode* node); + +} // namespace Luau diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h new file mode 100644 index 0000000..1f7f7f9 --- /dev/null +++ b/Analysis/include/Luau/Linter.h @@ -0,0 +1,96 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" + +#include +#include + +namespace Luau +{ + +struct AstName; +class AstStat; +class AstNameTable; +struct TypeChecker; +struct Module; + +using ScopePtr = std::shared_ptr; + +struct LintWarning +{ + // Make sure any new lint codes are documented here: https://luau-lang.org/lint + // Note that in Studio, the active set of lint warnings is determined by FStringStudioLuauLints + enum Code + { + Code_Unknown = 0, + + Code_UnknownGlobal = 1, // superseded by type checker + Code_DeprecatedGlobal = 2, + Code_GlobalUsedAsLocal = 3, + Code_LocalShadow = 4, // disabled in Studio + Code_SameLineStatement = 5, // disabled in Studio + Code_MultiLineStatement = 6, + Code_LocalUnused = 7, // disabled in Studio + Code_FunctionUnused = 8, // disabled in Studio + Code_ImportUnused = 9, // disabled in Studio + Code_BuiltinGlobalWrite = 10, + Code_PlaceholderRead = 11, + Code_UnreachableCode = 12, + Code_UnknownType = 13, + Code_ForRange = 14, + Code_UnbalancedAssignment = 15, + Code_ImplicitReturn = 16, // disabled in Studio, superseded by type checker in strict mode + Code_DuplicateLocal = 17, + Code_FormatString = 18, + Code_TableLiteral = 19, + Code_UninitializedLocal = 20, + Code_DuplicateFunction = 21, + Code_DeprecatedApi = 22, + Code_TableOperations = 23, + Code_DuplicateCondition = 24, + + Code__Count + }; + + Code code; + Location location; + std::string text; + + static const char* getName(Code code); + static Code parseName(const char* name); + static uint64_t parseMask(const std::vector& hotcomments); +}; + +struct LintResult +{ + std::vector errors; + std::vector warnings; +}; + +struct LintOptions +{ + uint64_t warningMask = 0; + + void enableWarning(LintWarning::Code code) + { + warningMask |= 1ull << code; + } + void disableWarning(LintWarning::Code code) + { + warningMask &= ~(1ull << code); + } + + bool isEnabled(LintWarning::Code code) const + { + return 0 != (warningMask & (1ull << code)); + } + + void setDefaults(); +}; + +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options); + +std::vector getDeprecatedGlobals(const AstNameTable& names); + +} // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h new file mode 100644 index 0000000..413b68f --- /dev/null +++ b/Analysis/include/Luau/Module.h @@ -0,0 +1,111 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/FileResolver.h" +#include "Luau/TypePack.h" +#include "Luau/TypedAllocator.h" +#include "Luau/ParseOptions.h" +#include "Luau/Error.h" +#include "Luau/Parser.h" + +#include +#include +#include +#include + +namespace Luau +{ + +struct Module; + +using ScopePtr = std::shared_ptr; +using ModulePtr = std::shared_ptr; + +/// Root of the AST of a parsed source file +struct SourceModule +{ + ModuleName name; // DataModel path if possible. Filename if not. + SourceCode::Type type = SourceCode::None; + std::optional environmentName; + bool cyclic = false; + + std::unique_ptr allocator; + std::unique_ptr names; + std::vector parseErrors; + + AstStatBlock* root = nullptr; + std::optional mode; + uint64_t ignoreLints = 0; + + std::vector commentLocations; + + SourceModule() + : allocator(new Allocator) + , names(new AstNameTable(*allocator)) + { + } +}; + +bool isWithinComment(const SourceModule& sourceModule, Position pos); + +struct TypeArena +{ + TypedAllocator typeVars; + TypedAllocator typePacks; + + void clear(); + + template + TypeId addType(T tv) + { + return addTV(TypeVar(std::move(tv))); + } + + TypeId addTV(TypeVar&& tv); + + TypeId freshType(TypeLevel level); + + TypePackId addTypePack(std::initializer_list types); + TypePackId addTypePack(std::vector types); + TypePackId addTypePack(TypePack pack); + TypePackId addTypePack(TypePackVar pack); +}; + +void freeze(TypeArena& arena); +void unfreeze(TypeArena& arena); + +// Only exposed so they can be unit tested. +using SeenTypes = std::unordered_map; +using SeenTypePacks = std::unordered_map; + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); +TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); + +struct Module +{ + ~Module(); + + TypeArena interfaceTypes; + TypeArena internalTypes; + + std::vector> scopes; // never empty + std::unordered_map astTypes; + std::unordered_map astExpectedTypes; + std::unordered_map astOriginalCallTypes; + std::unordered_map astOverloadResolvedTypes; + std::unordered_map declaredGlobals; + ErrorVec errors; + Mode mode; + SourceCode::Type type; + + ScopePtr getModuleScope() const; + + // Once a module has been typechecked, we clone its public interface into a separate arena. + // This helps us to force TypeVar ownership into a DAG rather than a DCG. + // Returns true if there were any free types encountered in the public interface. This + // indicates a bug in the type checker that we want to surface. + bool clonePublicInterface(); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h new file mode 100644 index 0000000..a394a21 --- /dev/null +++ b/Analysis/include/Luau/ModuleResolver.h @@ -0,0 +1,79 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/FileResolver.h" + +#include +#include +#include + +namespace Luau +{ + +class AstExpr; +struct Module; + +using ModulePtr = std::shared_ptr; + +struct ModuleInfo +{ + ModuleName name; + bool optional = false; +}; + +struct ModuleResolver +{ + virtual ~ModuleResolver() {} + + /** Compute a ModuleName from an AST fragment. This AST fragment is generally the argument to the require() function. + * + * You probably want to implement this with some variation of pathExprToModuleName. + * + * @returns The ModuleInfo if the expression is a syntactically legal path. + * @returns std::nullopt if we are unable to determine whether or not the expression is a valid path. Type inference will + * silently assume that it could succeed in this case. + * + * FIXME: This is clearly not the right behaviour longterm. We'll want to adust this interface to be able to signal + * a) success, + * b) Definitive failure (this expression will absolutely cause require() to fail at runtime), and + * c) uncertainty + */ + virtual std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) = 0; + + /** Get a typechecked module from its name. + * + * This can return null under two circumstances: the module is unknown at compile time, + * or there's a cycle, and we are still in the middle of typechecking the module. + */ + virtual const ModulePtr getModule(const ModuleName& moduleName) const = 0; + + /** Is a module known at compile time? + * + * This function can be used to distinguish the above two cases. + */ + virtual bool moduleExists(const ModuleName& moduleName) const = 0; + + virtual std::string getHumanReadableModuleName(const ModuleName& moduleName) const = 0; +}; + +struct NullModuleResolver : ModuleResolver +{ + std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override + { + return std::nullopt; + } + const ModulePtr getModule(const ModuleName& moduleName) const override + { + return nullptr; + } + bool moduleExists(const ModuleName& moduleName) const override + { + return false; + } + std::string getHumanReadableModuleName(const ModuleName& moduleName) const override + { + return moduleName; + } +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h new file mode 100644 index 0000000..a5e8b6a --- /dev/null +++ b/Analysis/include/Luau/Predicate.h @@ -0,0 +1,120 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" +#include "Luau/Location.h" +#include "Luau/Symbol.h" + +#include +#include +#include + +namespace Luau +{ + +struct TypeVar; +using TypeId = const TypeVar*; + +struct Field; +using LValue = Variant; + +struct Field +{ + std::shared_ptr parent; // TODO: Eventually use unique_ptr to enforce non-copyable trait. + std::string key; +}; + +std::optional tryGetLValue(const class AstExpr& expr); + +// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. +std::pair> getFullName(const LValue& lvalue); + +std::string toString(const LValue& lvalue); + +template +const T* get(const LValue& lvalue) +{ + return get_if(&lvalue); +} + +// Key is a stringified encoding of an LValue. +using RefinementMap = std::map; + +void merge(RefinementMap& l, const RefinementMap& r, std::function f); +void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); + +struct TruthyPredicate; +struct IsAPredicate; +struct TypeGuardPredicate; +struct EqPredicate; +struct AndPredicate; +struct OrPredicate; +struct NotPredicate; + +using Predicate = Variant; +using PredicateVec = std::vector; + +struct TruthyPredicate +{ + LValue lvalue; + Location location; +}; + +struct IsAPredicate +{ + LValue lvalue; + Location location; + TypeId ty; +}; + +struct TypeGuardPredicate +{ + LValue lvalue; + Location location; + std::string kind; // TODO: When singleton types arrive, replace this with `TypeId ty;` + bool isTypeof; +}; + +struct EqPredicate +{ + LValue lvalue; + TypeId type; + Location location; +}; + +struct AndPredicate +{ + PredicateVec lhs; + PredicateVec rhs; + + AndPredicate(PredicateVec&& lhs, PredicateVec&& rhs) + : lhs(std::move(lhs)) + , rhs(std::move(rhs)) + { + } +}; + +struct OrPredicate +{ + PredicateVec lhs; + PredicateVec rhs; + + OrPredicate(PredicateVec&& lhs, PredicateVec&& rhs) + : lhs(std::move(lhs)) + , rhs(std::move(rhs)) + { + } +}; + +struct NotPredicate +{ + PredicateVec predicates; +}; + +template +const T* get(const Predicate& predicate) +{ + return get_if(&predicate); +} + +} // namespace Luau diff --git a/Analysis/include/Luau/RecursionCounter.h b/Analysis/include/Luau/RecursionCounter.h new file mode 100644 index 0000000..89632ce --- /dev/null +++ b/Analysis/include/Luau/RecursionCounter.h @@ -0,0 +1,39 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +struct RecursionCounter +{ + RecursionCounter(int* count) + : count(count) + { + ++(*count); + } + + ~RecursionCounter() + { + LUAU_ASSERT(*count > 0); + --(*count); + } + +private: + int* count; +}; + +struct RecursionLimiter : RecursionCounter +{ + RecursionLimiter(int* count, int limit) + : RecursionCounter(count) + { + if (limit > 0 && *count > limit) + throw std::runtime_error("Internal recursion counter limit exceeded"); + } +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h new file mode 100644 index 0000000..e977887 --- /dev/null +++ b/Analysis/include/Luau/RequireTracer.h @@ -0,0 +1,28 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/FileResolver.h" +#include "Luau/Location.h" + +#include + +namespace Luau +{ + +class AstStat; +class AstExpr; +class AstStatBlock; +struct AstLocal; + +struct RequireTraceResult +{ + DenseHashMap exprs{0}; + DenseHashMap optional{0}; + + std::vector> requires; +}; + +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName); + +} // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h new file mode 100644 index 0000000..6ac868f --- /dev/null +++ b/Analysis/include/Luau/Substitution.h @@ -0,0 +1,208 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Module.h" +#include "Luau/ModuleResolver.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/DenseHash.h" + +// We provide an implementation of substitution on types, +// which recursively replaces types by other types. +// Examples include quantification (replacing free types by generics) +// and instantiation (replacing generic types by free ones). +// +// To implement a substitution, implement a subclass of `Substitution` +// and provide implementations of `isDirty` (which should be true for types that +// should be replaced) and `clean` which replaces any dirty types. +// +// struct MySubst : Substitution +// { +// bool isDirty(TypeId ty) override { ... } +// bool isDirty(TypePackId tp) override { ... } +// TypeId clean(TypeId ty) override { ... } +// TypePackId clean(TypePackId tp) override { ... } +// bool ignoreChildren(TypeId ty) override { ... } +// bool ignoreChildren(TypePackId tp) override { ... } +// }; +// +// For example, `Instantiation` in `TypeInfer.cpp` uses this. + +// The implementation of substitution tries not to copy types +// unnecessarily. It first finds all the types which can reach +// a dirty type, and either cleans them (if they are dirty) +// or clones them (if they are not). It then updates the children +// of the newly created types. When considering reachability, +// we do not consider the children of any type where ignoreChildren(ty) is true. + +// There is a gotcha for cyclic types, which means we can't just use +// a straightforward DFS. For example: +// +// type T = { f : () -> T, g: () -> number, h: X } +// +// If X is dirty, and is being replaced by X' then the result should be: +// +// type T' = { f : () -> T', g: () -> number, h: X' } +// +// that is the type of `f` is replaced, but the type of `g` is not. +// +// For this reason, we first use Tarjan's algorithm to find strongly +// connected components. If any type in an SCC can reach a dirty type, +// them the whole SCC can. For instance, in the above example, +// `T`, and the type of `f` are in the same SCC, which is why `f` gets +// replaced. + +LUAU_FASTFLAG(DebugLuauTrackOwningArena) + +namespace Luau +{ + +enum class TarjanResult +{ + TooManyChildren, + Ok +}; + +struct TarjanWorklistVertex +{ + int index; + int currEdge; + int lastEdge; +}; + +// Tarjan's algorithm for finding the SCCs in a cyclic structure. +// https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm +struct Tarjan +{ + // Vertices (types and type packs) are indexed, using pre-order traversal. + DenseHashMap typeToIndex{nullptr}; + DenseHashMap packToIndex{nullptr}; + std::vector indexToType; + std::vector indexToPack; + + // Tarjan keeps a stack of vertices where we're still in the process + // of finding their SCC. + std::vector stack; + std::vector onStack; + + // Tarjan calculates the lowlink for each vertex, + // which is the lowest ancestor index reachable from the vertex. + std::vector lowlink; + + int childCount = 0; + + std::vector edgesTy; + std::vector edgesTp; + std::vector worklist; + // This is hot code, so we optimize recursion to a stack. + TarjanResult loop(); + + // Clear the state + void clear(); + + // Find or create the index for a vertex. + // Return a boolean which is `true` if it's a freshly created index. + std::pair indexify(TypeId ty); + std::pair indexify(TypePackId tp); + + // Recursively visit all the children of a vertex + void visitChildren(TypeId ty, int index); + void visitChildren(TypePackId tp, int index); + + void visitChild(TypeId ty); + void visitChild(TypePackId ty); + + // Visit the root vertex. + TarjanResult visitRoot(TypeId ty); + TarjanResult visitRoot(TypePackId ty); + + // Each subclass gets called back once for each edge, + // and once for each SCC. + virtual void visitEdge(int index, int parentIndex) {} + virtual void visitSCC(int index) {} + + // Each subclass can decide to ignore some nodes. + virtual bool ignoreChildren(TypeId ty) + { + return false; + } + virtual bool ignoreChildren(TypePackId ty) + { + return false; + } +}; + +// We use Tarjan to calculate dirty bits. We set `dirty[i]` true +// if the vertex with index `i` can reach a dirty vertex. +struct FindDirty : Tarjan +{ + std::vector dirty; + + // Get/set the dirty bit for an index (grows the vector if needed) + bool getDirty(int index); + void setDirty(int index, bool d); + + // Find all the dirty vertices reachable from `t`. + TarjanResult findDirty(TypeId t); + TarjanResult findDirty(TypePackId t); + + // We find dirty vertices using Tarjan + void visitEdge(int index, int parentIndex) override; + void visitSCC(int index) override; + + // Subclasses should say which vertices are dirty, + // and what to do with dirty vertices. + virtual bool isDirty(TypeId ty) = 0; + virtual bool isDirty(TypePackId tp) = 0; + virtual void foundDirty(TypeId ty) = 0; + virtual void foundDirty(TypePackId tp) = 0; +}; + +// And finally substitution, which finds all the reachable dirty vertices +// and replaces them with clean ones. +struct Substitution : FindDirty +{ + ModulePtr currentModule; + DenseHashMap newTypes{nullptr}; + DenseHashMap newPacks{nullptr}; + + std::optional substitute(TypeId ty); + std::optional substitute(TypePackId tp); + + TypeId replace(TypeId ty); + TypePackId replace(TypePackId tp); + void replaceChildren(TypeId ty); + void replaceChildren(TypePackId tp); + TypeId clone(TypeId ty); + TypePackId clone(TypePackId tp); + + // Substitutions use Tarjan to find dirty nodes and replace them + void foundDirty(TypeId ty) override; + void foundDirty(TypePackId tp) override; + + // Implementing subclasses define how to clean a dirty type. + virtual TypeId clean(TypeId ty) = 0; + virtual TypePackId clean(TypePackId tp) = 0; + + // Helper functions to create new types (used by subclasses) + template + TypeId addType(const T& tv) + { + TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv); + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = ¤tModule->internalTypes; + + return allocated; + } + template + TypePackId addTypePack(const T& tp) + { + TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp); + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = ¤tModule->internalTypes; + + return allocated; + } +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Symbol.h b/Analysis/include/Luau/Symbol.h new file mode 100644 index 0000000..b5dd9c8 --- /dev/null +++ b/Analysis/include/Luau/Symbol.h @@ -0,0 +1,95 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +// TODO Rename this to Name once the old type alias is gone. +struct Symbol +{ + Symbol() + : local(nullptr) + , global() + { + } + + Symbol(AstLocal* local) + : local(local) + , global() + { + } + + Symbol(const AstName& global) + : local(nullptr) + , global(global) + { + } + + AstLocal* local; + AstName global; + + bool operator==(const Symbol& rhs) const + { + if (local) + return local == rhs.local; + if (global.value) + return rhs.global.value && global == rhs.global.value; // Subtlety: AstName::operator==(const char*) uses strcmp, not pointer identity. + return false; + } + + bool operator!=(const Symbol& rhs) const + { + return !(*this == rhs); + } + + bool operator<(const Symbol& rhs) const + { + if (local && rhs.local) + return local < rhs.local; + else if (global.value && rhs.global.value) + return global < rhs.global; + else if (local) + return true; + else + return false; + } + + AstName astName() const + { + if (local) + return local->name; + + LUAU_ASSERT(global.value); + return global; + } + + const char* c_str() const + { + if (local) + return local->name.value; + + LUAU_ASSERT(global.value); + return global.value; + } +}; + +std::string toString(const Symbol& name); + +} // namespace Luau + +namespace std +{ +template<> +struct hash +{ + std::size_t operator()(const Luau::Symbol& s) const noexcept + { + return std::hash()(s.local) ^ (s.global.value ? std::hash()(s.global.value) : 0); + } +}; +} // namespace std diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h new file mode 100644 index 0000000..0897ec8 --- /dev/null +++ b/Analysis/include/Luau/ToString.h @@ -0,0 +1,72 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/TypeVar.h" + +#include +#include +#include +#include + +LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) +LUAU_FASTINT(LuauTypeMaximumStringifierLength) + +namespace Luau +{ + +struct ToStringNameMap +{ + std::unordered_map typeVars; + std::unordered_map typePacks; +}; + +struct ToStringOptions +{ + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output + bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. + bool functionTypeArguments = false; // If true, output function type argument names when they are available + bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars + size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); + std::optional nameMap; + std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' +}; + +struct ToStringResult +{ + std::string name; + ToStringNameMap nameMap; + + bool invalid = false; + bool error = false; + bool cycle = false; + bool truncated = false; +}; + +ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts = {}); +ToStringResult toStringDetailed(TypePackId ty, const ToStringOptions& opts = {}); + +std::string toString(TypeId ty, const ToStringOptions& opts); +std::string toString(TypePackId ty, const ToStringOptions& opts); + +// These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger. +// You can use them in watch expressions! +inline std::string toString(TypeId ty) +{ + return toString(ty, ToStringOptions{}); +} +inline std::string toString(TypePackId ty) +{ + return toString(ty, ToStringOptions{}); +} + +std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); +std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); + +// It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class +// These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression +void dump(TypeId ty); +void dump(TypePackId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/TopoSortStatements.h b/Analysis/include/Luau/TopoSortStatements.h new file mode 100644 index 0000000..751694f --- /dev/null +++ b/Analysis/include/Luau/TopoSortStatements.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +template +struct AstArray; + +class AstStat; + +bool containsFunctionCall(const AstStat& stat); +bool isFunction(const AstStat& stat); +void toposort(std::vector& stats); + +} // namespace Luau diff --git a/Analysis/include/Luau/Transpiler.h b/Analysis/include/Luau/Transpiler.h new file mode 100644 index 0000000..817459f --- /dev/null +++ b/Analysis/include/Luau/Transpiler.h @@ -0,0 +1,30 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/ParseOptions.h" + +#include + +namespace Luau +{ +class AstNode; +class AstStatBlock; + +struct TranspileResult +{ + std::string code; + Location errorLocation; + std::string parseError; // Nonempty if the transpile failed +}; + +void dump(AstNode* node); + +// Never fails on a well-formed AST +std::string transpile(AstStatBlock& ast); +std::string transpileWithTypes(AstStatBlock& block); + +// Only fails when parsing fails +TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}); + +} // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h new file mode 100644 index 0000000..055441c --- /dev/null +++ b/Analysis/include/Luau/TxnLog.h @@ -0,0 +1,46 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +namespace Luau +{ + +// Log of where what TypeIds we are rebinding and what they used to be +struct TxnLog +{ + TxnLog() = default; + + explicit TxnLog(const std::vector>& seen) + : seen(seen) + { + } + + TxnLog(const TxnLog&) = delete; + TxnLog& operator=(const TxnLog&) = delete; + + TxnLog(TxnLog&&) = default; + TxnLog& operator=(TxnLog&&) = default; + + void operator()(TypeId a); + void operator()(TypePackId a); + void operator()(TableTypeVar* a); + + void rollback(); + + void concat(TxnLog rhs); + + bool haveSeen(TypeId lhs, TypeId rhs); + void pushSeen(TypeId lhs, TypeId rhs); + void popSeen(TypeId lhs, TypeId rhs); + +private: + std::vector> typeVarChanges; + std::vector> typePackChanges; + std::vector>> tableChanges; + +public: + std::vector> seen; // used to avoid infinite recursion when types are cyclic +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeAttach.h b/Analysis/include/Luau/TypeAttach.h new file mode 100644 index 0000000..c945805 --- /dev/null +++ b/Analysis/include/Luau/TypeAttach.h @@ -0,0 +1,21 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Module.h" + +#include + +namespace Luau +{ + +struct TypeRehydrationOptions +{ + std::unordered_set bannedNames; + bool expandClassProps = false; +}; + +void attachTypeData(SourceModule& source, Module& result); + +AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options = {}); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h new file mode 100644 index 0000000..ec2a1a2 --- /dev/null +++ b/Analysis/include/Luau/TypeInfer.h @@ -0,0 +1,453 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Predicate.h" +#include "Luau/Error.h" +#include "Luau/Module.h" +#include "Luau/Symbol.h" +#include "Luau/Parser.h" +#include "Luau/Substitution.h" +#include "Luau/TxnLog.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/Unifier.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; +struct TypeChecker; +struct ModuleResolver; + +using Name = std::string; +using ScopePtr = std::shared_ptr; +using OverloadErrorEntry = std::tuple, std::vector, const FunctionTypeVar*>; + +bool doesCallError(const AstExprCall* call); +bool hasBreak(AstStat* node); +const AstStat* getFallthrough(const AstStat* node); + +struct Unifier; + +// A substitution which replaces generic types in a given set by free types. +struct ReplaceGenerics : Substitution +{ + TypeLevel level; + std::vector generics; + std::vector genericPacks; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces generic functions by monomorphic functions +struct Instantiation : Substitution +{ + TypeLevel level; + ReplaceGenerics replaceGenerics; + bool ignoreChildren(TypeId ty) override; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces free types by generic types. +struct Quantification : Substitution +{ + TypeLevel level; + std::vector generics; + std::vector genericPacks; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces free types by any +struct Anyification : Substitution +{ + TypeId anyType; + TypePackId anyTypePack; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// A substitution which replaces the type parameters of a type function by arguments +struct ApplyTypeFunction : Substitution +{ + TypeLevel level; + bool encounteredForwardedType; + std::unordered_map arguments; + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId tp) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId tp) override; +}; + +// All TypeVars are retained via Environment::typeVars. All TypeIds +// within a program are borrowed pointers into this set. +struct TypeChecker +{ + explicit TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler); + TypeChecker(const TypeChecker&) = delete; + TypeChecker& operator=(const TypeChecker&) = delete; + + ModulePtr check(const SourceModule& module, Mode mode, std::optional environmentScope = std::nullopt); + + std::vector> getScopes() const; + + void check(const ScopePtr& scope, const AstStat& statement); + void check(const ScopePtr& scope, const AstStatBlock& statement); + void check(const ScopePtr& scope, const AstStatIf& statement); + void check(const ScopePtr& scope, const AstStatWhile& statement); + void check(const ScopePtr& scope, const AstStatRepeat& statement); + void check(const ScopePtr& scope, const AstStatReturn& return_); + void check(const ScopePtr& scope, const AstStatAssign& assign); + void check(const ScopePtr& scope, const AstStatCompoundAssign& assign); + void check(const ScopePtr& scope, const AstStatLocal& local); + void check(const ScopePtr& scope, const AstStatFor& local); + void check(const ScopePtr& scope, const AstStatForIn& forin); + void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); + void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); + void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false); + void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); + void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); + + void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); + void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); + + ExprResult checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprCall& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexName& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr(const ScopePtr& scope, const AstExprUnary& expr); + TypeId checkRelationalOperation( + const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); + TypeId checkBinaryOperation( + const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates = {}); + ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); + + TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, + std::optional expectedType); + + // Returns the type of the lvalue. + TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); + + // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). + // Note: the binding may be null. + std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); + std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); + std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); + std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); + std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); + + TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName); + std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, + std::optional originalNameLoc, std::optional expectedType); + void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); + + void checkArgumentList( + const ScopePtr& scope, Unifier& state, TypePackId paramPack, TypePackId argPack, const std::vector& argLocations); + + ExprResult checkExprPack(const ScopePtr& scope, const AstExpr& expr); + ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); + std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); + std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + std::vector& overloadsThatMatchArgCount, std::vector& errors); + bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, + const std::vector& errors); + ExprResult reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, + const std::vector& errors); + + ExprResult checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, + bool substituteFreeForNil = false, const std::vector& lhsAnnotations = {}, + const std::vector>& expectedTypes = {}); + + static std::optional matchRequire(const AstExprCall& call); + TypeId checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location); + + // Try to infer that the provided type is a table of some sort. + // Reports an error if the type is already some kind of non-table. + void tablify(TypeId type); + + /** In nonstrict mode, many typevars need to be replaced by any. + */ + TypeId anyIfNonstrict(TypeId ty) const; + + /** Attempt to unify the types left and right. Treat any failures as type errors + * in the final typecheck report. + */ + bool unify(TypeId left, TypeId right, const Location& location); + bool unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); + + /** Attempt to unify the types left and right. + * If this fails, and the right type can be instantiated, do so and try unification again. + */ + bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location); + void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state); + + /** Attempt to unify left with right. + * If there are errors, undo everything and return the errors. + * If there are no errors, commit and return an empty error vector. + */ + ErrorVec tryUnify(TypeId left, TypeId right, const Location& location); + ErrorVec tryUnify(TypePackId left, TypePackId right, const Location& location); + + // Test whether the two type vars unify. Never commits the result. + ErrorVec canUnify(TypeId superTy, TypeId subTy, const Location& location); + ErrorVec canUnify(TypePackId superTy, TypePackId subTy, const Location& location); + + // Variant that takes a preexisting 'seen' set. We need this in certain cases to avoid infinitely recursing + // into cyclic types. + ErrorVec canUnify(const std::vector>& seen, TypeId left, TypeId right, const Location& location); + + std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); + std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); + + std::optional getIndexTypeFromType(const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors); + + // Reduces the union to its simplest possible shape. + // (A | B) | B | C yields A | B | C + std::vector reduceUnion(const std::vector& types); + + std::optional tryStripUnionFromNil(TypeId ty); + TypeId stripFromNilAndReport(TypeId ty, const Location& location); + + template + ErrorVec tryUnify_(Id left, Id right, const Location& location); + + template + ErrorVec canUnify_(Id left, Id right, const Location& location); + +public: + /* + * Convert monotype into a a polytype, by replacing any metavariables in descendant scopes + * by bound generic type variables. This is used to infer that a function is generic. + */ + TypeId quantify(const ScopePtr& scope, TypeId ty, Location location); + + /* + * Convert a polytype into a monotype, by replacing any bound generic types by type metavariables. + * This is used to typecheck particular calls to generic functions, and when generic functions + * are passed as arguments. + * + * The "changed" boolean is used to permit us to return the same TypeId in the case that the instantiated type is unchanged. + * This is important in certain cases, such as methods on objects, where a table contains a function whose first argument is the table. + * Without this property, we can wind up in a situation where a new TypeId is allocated for the outer table. This can cause us to produce + * unfortunate types like + * + * {method: ({method: () -> a}) -> a} + * + */ + TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); + // Removed by FFlag::LuauRankNTypes + TypePackId DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location); + + // Replace any free types or type packs by `any`. + // This is used when exporting types from modules, to make sure free types don't leak. + TypeId anyify(const ScopePtr& scope, TypeId ty, Location location); + TypePackId anyify(const ScopePtr& scope, TypePackId ty, Location location); + + void reportError(const TypeError& error); + void reportError(const Location& location, TypeErrorData error); + void reportErrors(const ErrorVec& errors); + + [[noreturn]] void ice(const std::string& message, const Location& location); + [[noreturn]] void ice(const std::string& message); + + ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); + ScopePtr childScope(const ScopePtr& parent, const Location& location, int subLevel = 0); + + // Wrapper for merge(l, r, toUnion) but without the lambda junk. + void merge(RefinementMap& l, const RefinementMap& r); + +private: + void prepareErrorsForDisplay(ErrorVec& errVec); + void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); + void reportErrorCodeTooComplex(const Location& location); + +private: + Unifier mkUnifier(const Location& location); + Unifier mkUnifier(const std::vector>& seen, const Location& location); + + // These functions are only safe to call when we are in the process of typechecking a module. + + // Produce a new free type var. + TypeId freshType(const ScopePtr& scope); + TypeId freshType(TypeLevel level); + TypeId DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric = false); + TypeId DEPRECATED_freshType(TypeLevel level, bool canBeGeneric = false); + + // Returns nullopt if the predicate filters down the TypeId to 0 options. + std::optional filterMap(TypeId type, TypeIdPredicate predicate); + + TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); + + // ex + // TypeId id = addType(FreeTypeVar()); + template + TypeId addType(const T& tv) + { + return addTV(TypeVar(tv)); + } + + TypeId addType(const UnionTypeVar& utv); + + TypeId addTV(TypeVar&& tv); + + TypePackId addTypePack(TypePackVar&& tp); + TypePackId addTypePack(TypePack&& tp); + + TypePackId addTypePack(const std::vector& ty); + TypePackId addTypePack(const std::vector& ty, std::optional tail); + TypePackId addTypePack(std::initializer_list&& ty); + TypePackId freshTypePack(const ScopePtr& scope); + TypePackId freshTypePack(TypeLevel level); + TypePackId DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric = false); + TypePackId DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric = false); + + TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); + TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); + TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); + TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location); + + // Note: `scope` must be a fresh scope. + std::pair, std::vector> createGenericTypes( + const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + +public: + ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); + +private: + std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); + std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); + + void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); + void resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr); + void resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); + + bool isNonstrictMode() const; + +public: + /** Extract the types in a type pack, given the assumption that the pack must have some exact length. + * TypePacks can have free tails, which means that inference has not yet determined the length of the pack. + * Calling this function means submitting evidence that the pack must have the length provided. + * If the pack is known not to have the correct length, an error will be reported. + * The return vector is always of the exact requested length. In the event that the pack's length does + * not match up, excess TypeIds will be ErrorTypeVars. + */ + std::vector unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location); + + TypeArena globalTypes; + + ModuleResolver* resolver; + SourceModule globalNames; // names for symbols entered into globalScope + ScopePtr globalScope; // shared by all modules + ModulePtr currentModule; + ModuleName currentModuleName; + + Instantiation instantiation; + Quantification quantification; + Anyification anyification; + ApplyTypeFunction applyTypeFunction; + + std::function prepareModuleScope; + InternalErrorReporter* iceHandler; + +public: + const TypeId nilType; + const TypeId numberType; + const TypeId stringType; + const TypeId booleanType; + const TypeId threadType; + const TypeId anyType; + + const TypeId errorType; + const TypeId optionalNumberType; + + const TypePackId anyTypePack; + const TypePackId errorTypePack; + +private: + int checkRecursionCount = 0; + int recursionCount = 0; +}; + +struct Binding +{ + TypeId typeId; + Location location; + bool deprecated = false; + std::string deprecatedSuggestion; + std::optional documentationSymbol; +}; + +struct Scope +{ + explicit Scope(TypePackId returnType); // root scope + explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. + + const ScopePtr parent; // null for the root + std::unordered_map bindings; + TypePackId returnType; + bool breakOk = false; + std::optional varargPack; + + TypeLevel level; + + std::unordered_map exportedTypeBindings; + std::unordered_map privateTypeBindings; + std::unordered_map typeAliasLocations; + + std::unordered_map> importedTypeBindings; + + std::optional lookup(const Symbol& name); + + std::optional lookupType(const Name& name); + std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + + std::unordered_map privateTypePackBindings; + std::optional lookupPack(const Name& name); + + // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) + std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); + + RefinementMap refinements; + + // For mutually recursive type aliases, it's important that + // they use the same types for the same names. + // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` + // we need that the generic type `T` in both cases is the same, so we use a cache. + std::unordered_map typeAliasParameters; +}; + +// Unit test hook +void setPrintLine(void (*pl)(const std::string& s)); +void resetPrintLine(); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h new file mode 100644 index 0000000..0d0adce --- /dev/null +++ b/Analysis/include/Luau/TypePack.h @@ -0,0 +1,161 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" +#include "Luau/Unifiable.h" +#include "Luau/Variant.h" + +#include +#include + +LUAU_FASTFLAG(LuauAddMissingFollow) + +namespace Luau +{ + +struct TypeArena; + +struct TypePack; +struct VariadicTypePack; + +struct TypePackVar; + +using TypePackId = const TypePackVar*; +using FreeTypePack = Unifiable::Free; +using BoundTypePack = Unifiable::Bound; +using GenericTypePack = Unifiable::Generic; +using TypePackVariant = Unifiable::Variant; + +/* A TypePack is a rope-like string of TypeIds. We use this structure to encode + * notions like packs of unknown length and packs of any length, as well as more + * nuanced compositions like "a pack which is a number prepended to this other pack," + * or "a pack that is 2 numbers followed by any number of any other types." + */ +struct TypePack +{ + std::vector head; + std::optional tail; +}; + +struct VariadicTypePack +{ + TypeId ty; +}; + +struct TypePackVar +{ + explicit TypePackVar(const TypePackVariant& ty); + explicit TypePackVar(TypePackVariant&& ty); + TypePackVar(TypePackVariant&& ty, bool persistent); + bool operator==(const TypePackVar& rhs) const; + TypePackVar& operator=(TypePackVariant&& tp); + + TypePackVariant ty; + bool persistent = false; + + // Pointer to the type arena that allocated this type. + // Do not depend on the value of this under any circumstances. This is for + // debugging purposes only. This is only set in debug builds; it is nullptr + // in all other environments. + TypeArena* owningArena = nullptr; +}; + +/* Walk the set of TypeIds in a TypePack. + * + * Like TypeVars, individual TypePacks can be free, generic, or any. + * + * We afford the ability to work with these kinds of packs by giving the + * iterator a .tail() property that yields the tail-most TypePack in the + * rope. + * + * It is very commonplace to want to walk each type in a pack, then handle + * the tail specially. eg when checking parameters, it might be the case + * that the parameter pack ends with a VariadicTypePack. In this case, we + * want to allow any number of extra arguments. + * + * The iterator obtained by calling end(tp) does not have a .tail(), but is + * equivalent with end(tp2) for any two type packs. + */ +struct TypePackIterator +{ + using value_type = Luau::TypeId; + using pointer = value_type*; + using reference = value_type&; + using difference_type = size_t; + using iterator_category = std::input_iterator_tag; + + TypePackIterator() = default; + explicit TypePackIterator(TypePackId tp); + + TypePackIterator& operator++(); + TypePackIterator operator++(int); + bool operator!=(const TypePackIterator& rhs); + bool operator==(const TypePackIterator& rhs); + + const TypeId& operator*(); + + /** Return the tail of a TypePack. + * This may *only* be called on an iterator that has been incremented to the end. + * Returns nullopt if the pack has fixed length. + */ + std::optional tail(); + + friend TypePackIterator end(TypePackId tp); + +private: + TypePackId currentTypePack = nullptr; + const TypePack* tp = nullptr; + size_t currentIndex = 0; +}; + +TypePackIterator begin(TypePackId tp); +TypePackIterator end(TypePackId tp); + +using SeenSet = std::set>; + +bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); + +TypePackId follow(TypePackId tp); + +size_t size(const TypePackId tp); +size_t size(const TypePack& tp); +std::optional first(TypePackId tp); + +TypePackVar* asMutable(TypePackId tp); +TypePack* asMutable(const TypePack* tp); + +template +const T* get(TypePackId tp) +{ + if (FFlag::LuauAddMissingFollow) + { + LUAU_ASSERT(tp); + + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); + } + + return get_if(&(tp->ty)); +} + +template +T* getMutable(TypePackId tp) +{ + if (FFlag::LuauAddMissingFollow) + { + LUAU_ASSERT(tp); + + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); + } + + return get_if(&(asMutable(tp)->ty)); +} + +/// Returns true if the type pack is known to be empty (no types in the head and no/an empty tail). +bool isEmpty(TypePackId tp); + +/// Flattens out a type pack. Also returns a valid TypePackId tail if the type pack's full size is not known +std::pair, std::optional> flatten(TypePackId tp); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h new file mode 100644 index 0000000..ffddfe4 --- /dev/null +++ b/Analysis/include/Luau/TypeUtils.h @@ -0,0 +1,19 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Error.h" +#include "Luau/Location.h" +#include "Luau/TypeVar.h" + +#include +#include + +namespace Luau +{ + +using ScopePtr = std::shared_ptr; + +std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location); +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h new file mode 100644 index 0000000..90a28b2 --- /dev/null +++ b/Analysis/include/Luau/TypeVar.h @@ -0,0 +1,531 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Predicate.h" +#include "Luau/Unifiable.h" +#include "Luau/Variant.h" +#include "Luau/Common.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) +LUAU_FASTINT(LuauTypeMaximumStringifierLength) +LUAU_FASTFLAG(LuauAddMissingFollow) + +namespace Luau +{ + +struct TypeArena; + +/** + * There are three kinds of type variables: + * - `Free` variables are metavariables, which stand for unconstrained types. + * - `Bound` variables are metavariables that have an equality constraint. + * - `Generic` variables are type variables that are bound by generic functions. + * + * For example, consider the program: + * ``` + * function(x, y) x.f = y end + * ``` + * To typecheck this, we first introduce free metavariables for the types of `x` and `y`: + * ``` + * function(x: X, y: Y) x.f = y end + * ``` + * Type inference for the function body then produces the constraint: + * ``` + * X = { f: Y } + * ``` + * so `X` is now a bound metavariable. We can then quantify the metavariables, + * which replaces any bound metavariables by their binding, and free metavariables + * by bound generic variables: + * ``` + * function(x: { f: a }, y: a) x.f = y end + * ``` + */ + +// So... why `const T*` here rather than `T*`? +// It's because we've had problems caused by the type graph being mutated +// in ways it shouldn't be, for example mutating types from other modules. +// To try to control this, we make the use of types immutable by default, +// then provide explicit mutable access via getMutable and asMutable. +// This means we can grep for all the places we're mutating the type graph, +// and it makes it possible to provide other APIs (e.g. the txn log) +// which control mutable access to the type graph. +struct TypePackVar; +using TypePackId = const TypePackVar*; + +// TODO: rename to Type? CLI-39100 +struct TypeVar; + +// Should never be null +using TypeId = const TypeVar*; + +using Name = std::string; + +// A free type var is one whose exact shape has yet to be fully determined. +using FreeTypeVar = Unifiable::Free; + +// When a free type var is unified with any other, it is then "bound" +// to that type var, indicating that the two types are actually the same type. +using BoundTypeVar = Unifiable::Bound; + +using GenericTypeVar = Unifiable::Generic; + +using Tags = std::vector; + +using ModuleName = std::string; + +struct PrimitiveTypeVar +{ + enum Type + { + NilType, // ObjC #defines Nil :( + Boolean, + Number, + String, + Thread, + }; + + Type type; + std::optional metatable; // string has a metatable + + explicit PrimitiveTypeVar(Type type) + : type(type) + { + } + + explicit PrimitiveTypeVar(Type type, TypeId metatable) + : type(type) + , metatable(metatable) + { + } +}; + +struct FunctionArgument +{ + Name name; + Location location; +}; + +struct FunctionDefinition +{ + std::optional definitionModuleName; + Location definitionLocation; + std::optional varargLocation; + Location originalNameLocation; +}; + +// TODO: Come up with a better name. +// TODO: Do we actually need this? We'll find out later if we can delete this. +// Does not exactly belong in TypeVar.h, but this is the only way to appease the compiler. +template +struct ExprResult +{ + T type; + PredicateVec predicates; +}; + +using MagicFunction = std::function>( + struct TypeChecker&, const std::shared_ptr&, const class AstExprCall&, ExprResult)>; + +struct FunctionTypeVar +{ + // Global monomorphic function + FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + + // Global polymorphic function + FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + std::optional defn = {}, bool hasSelf = false); + + // Local monomorphic function + FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn = {}, bool hasSelf = false); + + // Local polymorphic function + FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + std::optional defn = {}, bool hasSelf = false); + + TypeLevel level; + /// These should all be generic + std::vector generics; + std::vector genericPacks; + TypePackId argTypes; + std::vector> argNames; + TypePackId retType; + std::optional definition; + MagicFunction magicFunction = nullptr; // Function pointer, can be nullptr. + bool hasSelf; + Tags tags; +}; + +enum class TableState +{ + // Sealed tables have an exact, known shape + Sealed, + + // An unsealed table can have extra properties added to it + Unsealed, + + // Tables which are not yet fully understood. We are still in the process of learning its shape. + Free, + + // A table which is a generic parameter to a function. We know that certain properties are required, + // but we don't care about the full shape. + Generic, +}; + +struct TableIndexer +{ + TableIndexer(TypeId indexType, TypeId indexResultType) + : indexType(indexType) + , indexResultType(indexResultType) + { + } + + TypeId indexType; + TypeId indexResultType; +}; + +struct Property +{ + TypeId type; + bool deprecated = false; + std::string deprecatedSuggestion; + std::optional location = std::nullopt; + Tags tags; + std::optional documentationSymbol; +}; + +struct TableTypeVar +{ + // We choose std::map over unordered_map here just because we have unit tests that compare + // textual outputs. I don't want to spend the effort making them resilient in the case where + // random events cause the iteration order of the map elements to change. + // If this shows up in a profile, we can revisit it. + using Props = std::map; + + TableTypeVar() = default; + explicit TableTypeVar(TableState state, TypeLevel level); + TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state = TableState::Unsealed); + + Props props; + std::optional indexer; + + TableState state = TableState::Unsealed; + TypeLevel level; + std::optional name; + + // Sometimes we throw a type on a name to make for nicer error messages, but without creating any entry in the type namespace + // We need to know which is which when we stringify types. + std::optional syntheticName; + + std::map methodDefinitionLocations; + std::vector instantiatedTypeParams; + ModuleName definitionModuleName; + + std::optional boundTo; + Tags tags; +}; + +// Represents a metatable attached to a table typevar. Somewhat analogous to a bound typevar. +struct MetatableTypeVar +{ + // Always points to a TableTypeVar. + TypeId table; + // Always points to either a TableTypeVar or a MetatableTypeVar. + TypeId metatable; + + std::optional syntheticName; +}; + +// Custom userdata of a class type +struct ClassUserData +{ + virtual ~ClassUserData() {} +}; + +/** The type of a class. + * + * Classes behave like tables in many ways, but there are some important differences: + * + * The properties of a class are always exactly known. + * Classes optionally have a parent class. + * Two different classes that share the same properties are nevertheless distinct and mutually incompatible. + */ +struct ClassTypeVar +{ + using Props = TableTypeVar::Props; + + Name name; + Props props; + std::optional parent; + std::optional metatable; // metaclass? + Tags tags; + std::shared_ptr userData; + + ClassTypeVar( + Name name, Props props, std::optional parent, std::optional metatable, Tags tags, std::shared_ptr userData) + : name(name) + , props(props) + , parent(parent) + , metatable(metatable) + , tags(tags) + , userData(userData) + { + } +}; + +struct TypeFun +{ + /// These should all be generic + std::vector typeParams; + + /** The underlying type. + * + * WARNING! This is not safe to use as a type if typeParams is not empty!! + * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. + */ + TypeId type; +}; + +// Anything! All static checking is off. +struct AnyTypeVar +{ +}; + +struct UnionTypeVar +{ + std::vector options; +}; + +struct IntersectionTypeVar +{ + std::vector parts; +}; + +struct LazyTypeVar +{ + std::function thunk; +}; + +using ErrorTypeVar = Unifiable::Error; + +using TypeVariant = Unifiable::Variant; + +struct TypeVar final +{ + explicit TypeVar(const TypeVariant& ty) + : ty(ty) + { + } + + explicit TypeVar(TypeVariant&& ty) + : ty(std::move(ty)) + { + } + + TypeVar(const TypeVariant& ty, bool persistent) + : ty(ty) + , persistent(persistent) + { + } + + TypeVariant ty; + + // Kludge: A persistent TypeVar is one that belongs to the global scope. + // Global type bindings are immutable but are reused many times. + // Persistent TypeVars do not get cloned. + bool persistent = false; + + std::optional documentationSymbol; + + // Pointer to the type arena that allocated this type. + // Do not depend on the value of this under any circumstances. This is for + // debugging purposes only. This is only set in debug builds; it is nullptr + // in all other environments. + TypeArena* owningArena = nullptr; + + bool operator==(const TypeVar& rhs) const; + bool operator!=(const TypeVar& rhs) const; + + TypeVar& operator=(const TypeVariant& rhs); + TypeVar& operator=(TypeVariant&& rhs); +}; + +using SeenSet = std::set>; +bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); + +// Follow BoundTypeVars until we get to something real +TypeId follow(TypeId t); + +std::vector flattenIntersection(TypeId ty); + +bool isPrim(TypeId ty, PrimitiveTypeVar::Type primType); +bool isNil(TypeId ty); +bool isBoolean(TypeId ty); +bool isNumber(TypeId ty); +bool isString(TypeId ty); +bool isThread(TypeId ty); +bool isOptional(TypeId ty); +bool isTableIntersection(TypeId ty); +bool isOverloadedFunction(TypeId ty); + +std::optional getMetatable(TypeId type); +TableTypeVar* getMutableTableType(TypeId type); +const TableTypeVar* getTableType(TypeId type); + +// If the type has a name, return that. Else if it has a synthetic name, return that. +// Returns nullptr if the type has no name. +const std::string* getName(TypeId type); + +// Checks whether a union contains all types of another union. +bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); + +// Checks if a type conains generic type binders +bool isGeneric(const TypeId ty); + +// Checks if a type may be instantiated to one containing generic type binders +bool maybeGeneric(const TypeId ty); + +struct SingletonTypes +{ + const TypeId nilType = &nilType_; + const TypeId numberType = &numberType_; + const TypeId stringType = &stringType_; + const TypeId booleanType = &booleanType_; + const TypeId threadType = &threadType_; + const TypeId anyType = &anyType_; + const TypeId errorType = &errorType_; + + SingletonTypes(); + SingletonTypes(const SingletonTypes&) = delete; + void operator=(const SingletonTypes&) = delete; + +private: + std::unique_ptr arena; + TypeVar nilType_; + TypeVar numberType_; + TypeVar stringType_; + TypeVar booleanType_; + TypeVar threadType_; + TypeVar anyType_; + TypeVar errorType_; + + TypeId makeStringMetatable(); +}; + +extern SingletonTypes singletonTypes; + +void persist(TypeId ty); +void persist(TypePackId tp); + +struct ToDotOptions +{ + bool showPointers = true; // Show pointer value in the node label + bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes +}; + +std::string toDot(TypeId ty, const ToDotOptions& opts); +std::string toDot(TypePackId tp, const ToDotOptions& opts); + +std::string toDot(TypeId ty); +std::string toDot(TypePackId tp); + +void dumpDot(TypeId ty); +void dumpDot(TypePackId tp); + +const TypeLevel* getLevel(TypeId ty); +TypeLevel* getMutableLevel(TypeId ty); + +const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); +bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); + +bool hasGeneric(TypeId ty); +bool hasGeneric(TypePackId tp); + +TypeVar* asMutable(TypeId ty); + +template +const T* get(TypeId tv) +{ + if (FFlag::LuauAddMissingFollow) + { + LUAU_ASSERT(tv); + + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); + } + + return get_if(&tv->ty); +} + +template +T* getMutable(TypeId tv) +{ + if (FFlag::LuauAddMissingFollow) + { + LUAU_ASSERT(tv); + + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); + } + + return get_if(&asMutable(tv)->ty); +} + +/* Traverses the UnionTypeVar yielding each TypeId. + * If the iterator encounters a nested UnionTypeVar, it will instead yield each TypeId within. + * + * Beware: the iterator does not currently filter for unique TypeIds. This may change in the future. + */ +struct UnionTypeVarIterator +{ + using value_type = Luau::TypeId; + using pointer = value_type*; + using reference = value_type&; + using difference_type = size_t; + using iterator_category = std::input_iterator_tag; + + explicit UnionTypeVarIterator(const UnionTypeVar* utv); + + UnionTypeVarIterator& operator++(); + UnionTypeVarIterator operator++(int); + bool operator!=(const UnionTypeVarIterator& rhs); + bool operator==(const UnionTypeVarIterator& rhs); + + const TypeId& operator*(); + + friend UnionTypeVarIterator end(const UnionTypeVar* utv); + +private: + UnionTypeVarIterator() = default; + + // (UnionTypeVar* utv, size_t currentIndex) + using SavedIterInfo = std::pair; + + std::deque stack; + std::unordered_set seen; // Only needed to protect the iterator from hanging the thread. + + void advance(); + void descend(); +}; + +UnionTypeVarIterator begin(const UnionTypeVar* utv); +UnionTypeVarIterator end(const UnionTypeVar* utv); + +using TypeIdPredicate = std::function(TypeId)>; +std::vector filterMap(TypeId type, TypeIdPredicate predicate); + +// TEMP: Clip this prototype with FFlag::LuauStringMetatable +std::optional> magicFunctionFormat( + struct TypeChecker& typechecker, const std::shared_ptr& scope, const AstExprCall& expr, ExprResult exprResult); + +} // namespace Luau diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h new file mode 100644 index 0000000..0ded148 --- /dev/null +++ b/Analysis/include/Luau/TypedAllocator.h @@ -0,0 +1,134 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include +#include + +namespace Luau +{ + +void* pagedAllocate(size_t size); +void pagedDeallocate(void* ptr); +void pagedFreeze(void* ptr, size_t size); +void pagedUnfreeze(void* ptr, size_t size); + +template +class TypedAllocator +{ +public: + TypedAllocator() + { + appendBlock(); + } + + ~TypedAllocator() + { + if (frozen) + unfreeze(); + free(); + } + + template + T* allocate(Args&&... args) + { + LUAU_ASSERT(!frozen); + + if (currentBlockSize >= kBlockSize) + { + LUAU_ASSERT(currentBlockSize == kBlockSize); + appendBlock(); + } + + T* block = stuff.back(); + T* res = block + currentBlockSize; + new (res) T(std::forward(args...)); + ++currentBlockSize; + return res; + } + + bool contains(const T* ptr) const + { + for (T* block : stuff) + if (ptr >= block && ptr < block + kBlockSize) + return true; + + return false; + } + + bool empty() const + { + return stuff.size() == 1 && currentBlockSize == 0; + } + + size_t size() const + { + return kBlockSize * (stuff.size() - 1) + currentBlockSize; + } + + void clear() + { + if (frozen) + unfreeze(); + free(); + appendBlock(); + } + + void freeze() + { + for (T* block : stuff) + pagedFreeze(block, kBlockSizeBytes); + frozen = true; + } + + void unfreeze() + { + for (T* block : stuff) + pagedUnfreeze(block, kBlockSizeBytes); + frozen = false; + } + + bool isFrozen() + { + return frozen; + } + +private: + void free() + { + LUAU_ASSERT(!frozen); + + for (T* block : stuff) + { + size_t blockSize = (block == stuff.back()) ? currentBlockSize : kBlockSize; + + for (size_t i = 0; i < blockSize; ++i) + block[i].~T(); + + pagedDeallocate(block); + } + + stuff.clear(); + currentBlockSize = 0; + } + + void appendBlock() + { + void* block = pagedAllocate(kBlockSizeBytes); + if (!block) + throw std::bad_alloc(); + + stuff.emplace_back(static_cast(block)); + currentBlockSize = 0; + } + + bool frozen = false; + std::vector stuff; + size_t currentBlockSize = 0; + + static constexpr size_t kBlockSizeBytes = 32768; + static constexpr size_t kBlockSize = kBlockSizeBytes / sizeof(T); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h new file mode 100644 index 0000000..10dbf33 --- /dev/null +++ b/Analysis/include/Luau/Unifiable.h @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" + +#include + +namespace Luau +{ + +/** + * The 'level' of a TypeVar is an indirect way to talk about the scope that it 'belongs' too. + * To start, read http://okmij.org/ftp/ML/generalization.html + * + * We extend the idea by adding a "sub-level" which helps us to differentiate sibling scopes + * within a single larger scope. + * + * We need this because we try to prototype functions and add them to the type environment before + * we check the function bodies. This allows us to properly typecheck many scenarios where there + * is no single good order in which to typecheck a program. + */ +struct TypeLevel +{ + int level = 0; + int subLevel = 0; + + // Returns true if the typelevel "this" is "bigger" than rhs + bool subsumes(const TypeLevel& rhs) const + { + if (level < rhs.level) + return true; + if (level > rhs.level) + return false; + if (subLevel == rhs.subLevel) + return true; // if level == rhs.level and subLevel == rhs.subLevel, then they are the exact same TypeLevel + + // Sibling TypeLevels (that is, TypeLevels that share a level but have a different subLevel) are not considered to subsume one another + return false; + } + + TypeLevel incr() const + { + TypeLevel result; + result.level = level + 1; + result.subLevel = 0; + return result; + } +}; + +inline TypeLevel min(const TypeLevel& a, const TypeLevel& b) +{ + if (a.subsumes(b)) + return a; + else + return b; +} + +namespace Unifiable +{ + +using Name = std::string; + +struct Free +{ + explicit Free(TypeLevel level); + Free(TypeLevel level, bool DEPRECATED_canBeGeneric); + + int index; + TypeLevel level; + // Removed by FFlag::LuauRankNTypes + bool DEPRECATED_canBeGeneric = false; + // True if this free type variable is part of a mutually + // recursive type alias whose definitions haven't been + // resolved yet. + bool forwardedTypeAlias = false; + +private: + static int nextIndex; +}; + +template +struct Bound +{ + explicit Bound(Id boundTo) + : boundTo(boundTo) + { + } + + Id boundTo; +}; + +struct Generic +{ + // By default, generics are global, with a synthetic name + Generic(); + explicit Generic(TypeLevel level); + explicit Generic(const Name& name); + Generic(TypeLevel level, const Name& name); + + int index; + TypeLevel level; + Name name; + bool explicitName; + +private: + static int nextIndex; +}; + +struct Error +{ + Error(); + + int index; + +private: + static int nextIndex; +}; + +template +using Variant = Variant, Generic, Error, Value...>; + +} // namespace Unifiable +} // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h new file mode 100644 index 0000000..0ddc3cc --- /dev/null +++ b/Analysis/include/Luau/Unifier.h @@ -0,0 +1,98 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Error.h" +#include "Luau/Location.h" +#include "Luau/TxnLog.h" +#include "Luau/TypeInfer.h" +#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. + +#include + +namespace Luau +{ + +enum Variance +{ + Covariant, + Invariant +}; + +struct UnifierCounters +{ + int recursionCount = 0; + int iterationCount = 0; +}; + +struct Unifier +{ + TypeArena* const types; + Mode mode; + ScopePtr globalScope; // sigh. Needed solely to get at string's metatable. + + TxnLog log; + ErrorVec errors; + Location location; + Variance variance = Covariant; + CountMismatch::Context ctx = CountMismatch::Arg; + + std::shared_ptr counters; + InternalErrorReporter* iceHandler; + + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters = nullptr); + + // Test whether the two type vars unify. Never commits the result. + ErrorVec canUnify(TypeId superTy, TypeId subTy); + ErrorVec canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + + /** Attempt to unify left with right. + * Populate the vector errors with any type errors that may arise. + * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. + */ + void tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); + +private: + void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnifyPrimitives(TypeId superTy, TypeId subTy); + void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); + void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); + void tryUnifyFreeTable(TypeId free, TypeId other); + void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); + void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); + void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); + void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + +public: + void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + +private: + void tryUnify_(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + void tryUnifyVariadics(TypePackId superTy, TypePackId subTy, bool reversed, int subOffset = 0); + + void tryUnifyWithAny(TypeId any, TypeId ty); + void tryUnifyWithAny(TypePackId any, TypePackId ty); + + std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); + std::optional findMetatableEntry(TypeId type, std::string entry); + +public: + // Report an "infinite type error" if the type "needle" already occurs within "haystack" + void occursCheck(TypeId needle, TypeId haystack); + void occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack); + void occursCheck(TypePackId needle, TypePackId haystack); + void occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack); + + Unifier makeChildUnifier(); + +private: + bool isNonstrictMode() const; + + void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); + + [[noreturn]] void ice(const std::string& message, const Location& location); + [[noreturn]] void ice(const std::string& message); +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Variant.h b/Analysis/include/Luau/Variant.h new file mode 100644 index 0000000..63d5a65 --- /dev/null +++ b/Analysis/include/Luau/Variant.h @@ -0,0 +1,302 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#ifndef LUAU_USE_STD_VARIANT +#define LUAU_USE_STD_VARIANT 0 +#endif + +#if LUAU_USE_STD_VARIANT +#include +#else +#include +#include +#include +#include +#endif + +namespace Luau +{ + +#if LUAU_USE_STD_VARIANT +template +using Variant = std::variant; + +template +auto visit(Visitor&& vis, Variant&& var) +{ + // This change resolves the ABI issues with std::variant on libc++; std::visit normally throws bad_variant_access + // but it requires an update to libc++.dylib which ships with macOS 10.14. To work around this, we assert on valueless + // variants since we will never generate them and call into a libc++ function that doesn't throw. + LUAU_ASSERT(!var.valueless_by_exception()); + +#ifdef __APPLE__ + // See https://stackoverflow.com/a/53868971/503215 + return std::__variant_detail::__visitation::__variant::__visit_value(vis, var); +#else + return std::visit(vis, var); +#endif +} + +using std::get_if; +#else +template +class Variant +{ + static_assert(sizeof...(Ts) > 0, "variant must have at least 1 type (empty variants are ill-formed)"); + static_assert(std::disjunction_v...> == false, "variant does not allow void as an alternative type"); + static_assert(std::disjunction_v...> == false, "variant does not allow references as an alternative type"); + static_assert(std::disjunction_v...> == false, "variant does not allow arrays as an alternative type"); + +private: + template + static constexpr int getTypeId() + { + using TT = std::decay_t; + + constexpr int N = sizeof...(Ts); + constexpr bool is[N] = {std::is_same_v...}; + + for (int i = 0; i < N; ++i) + if (is[i]) + return i; + + return -1; + } + + template + struct First + { + using type = T; + }; + +public: + using first_alternative = typename First::type; + + Variant() + { + static_assert(std::is_default_constructible_v, "first alternative type must be default constructible"); + typeId = 0; + new (&storage) first_alternative(); + } + + template + Variant(T&& value, std::enable_if_t() >= 0>* = 0) + { + using TT = std::decay_t; + + constexpr int tid = getTypeId(); + typeId = tid; + new (&storage) TT(value); + } + + Variant(const Variant& other) + { + typeId = other.typeId; + tableCopy[typeId](&storage, &other.storage); + } + + Variant(Variant&& other) + { + typeId = other.typeId; + tableMove[typeId](&storage, &other.storage); + } + + ~Variant() + { + tableDtor[typeId](&storage); + } + + Variant& operator=(const Variant& other) + { + Variant copy(other); + // static_cast is equivalent to std::move() but faster in Debug + return *this = static_cast(copy); + } + + Variant& operator=(Variant&& other) + { + if (this != &other) + { + tableDtor[typeId](&storage); + typeId = other.typeId; + tableMove[typeId](&storage, &other.storage); // nothrow + } + return *this; + } + + template + const T* get_if() const + { + constexpr int tid = getTypeId(); + static_assert(tid >= 0, "unsupported T"); + + return tid == typeId ? reinterpret_cast(&storage) : nullptr; + } + + template + T* get_if() + { + constexpr int tid = getTypeId(); + static_assert(tid >= 0, "unsupported T"); + + return tid == typeId ? reinterpret_cast(&storage) : nullptr; + } + + bool valueless_by_exception() const + { + return false; + } + + int index() const + { + return typeId; + } + + bool operator==(const Variant& other) const + { + static constexpr FnPred table[sizeof...(Ts)] = {&fnPredEq...}; + + return typeId == other.typeId && table[typeId](&storage, &other.storage); + } + + bool operator!=(const Variant& other) const + { + return !(*this == other); + } + +private: + static constexpr size_t cmax(std::initializer_list l) + { + size_t res = 0; + for (size_t i : l) + res = (res < i) ? i : res; + return res; + } + + static constexpr size_t storageSize = cmax({sizeof(Ts)...}); + static constexpr size_t storageAlign = cmax({alignof(Ts)...}); + + using FnCopy = void (*)(void*, const void*); + using FnMove = void (*)(void*, void*); + using FnDtor = void (*)(void*); + using FnPred = bool (*)(const void*, const void*); + + template + static void fnCopy(void* dst, const void* src) + { + new (dst) T(*static_cast(src)); + } + + template + static void fnMove(void* dst, void* src) + { + // static_cast is equivalent to std::move() but faster in Debug + new (dst) T(static_cast(*static_cast(src))); + } + + template + static void fnDtor(void* dst) + { + static_cast(dst)->~T(); + } + + template + static bool fnPredEq(const void* lhs, const void* rhs) + { + return *static_cast(lhs) == *static_cast(rhs); + } + + static constexpr FnCopy tableCopy[sizeof...(Ts)] = {&fnCopy...}; + static constexpr FnMove tableMove[sizeof...(Ts)] = {&fnMove...}; + static constexpr FnDtor tableDtor[sizeof...(Ts)] = {&fnDtor...}; + + int typeId; + alignas(storageAlign) char storage[storageSize]; + + template + friend auto visit(Visitor&& vis, const Variant<_Ts...>& var); + template + friend auto visit(Visitor&& vis, Variant<_Ts...>& var); +}; + +template +const T* get_if(const Variant* var) +{ + return var ? var->template get_if() : nullptr; +} + +template +T* get_if(Variant* var) +{ + return var ? var->template get_if() : nullptr; +} + +template +static void fnVisitR(Visitor& vis, Result& dst, std::conditional_t, const void, void>* src) +{ + dst = vis(*static_cast(src)); +} + +template +static void fnVisitV(Visitor& vis, std::conditional_t, const void, void>* src) +{ + vis(*static_cast(src)); +} + +template +auto visit(Visitor&& vis, const Variant& var) +{ + using Result = std::invoke_result_t::first_alternative>; + static_assert(std::conjunction_v>...>, + "visitor result type must be consistent between alternatives"); + + if constexpr (std::is_same_v) + { + using FnVisitV = void (*)(Visitor&, const void*); + static const FnVisitV tableVisit[sizeof...(Ts)] = {&fnVisitV...}; + + tableVisit[var.typeId](vis, &var.storage); + } + else + { + using FnVisitR = void (*)(Visitor&, Result&, const void*); + static const FnVisitR tableVisit[sizeof...(Ts)] = {&fnVisitR...}; + + Result res; + tableVisit[var.typeId](vis, res, &var.storage); + return res; + } +} + +template +auto visit(Visitor&& vis, Variant& var) +{ + using Result = std::invoke_result_t::first_alternative&>; + static_assert(std::conjunction_v>...>, + "visitor result type must be consistent between alternatives"); + + if constexpr (std::is_same_v) + { + using FnVisitV = void (*)(Visitor&, void*); + static const FnVisitV tableVisit[sizeof...(Ts)] = {&fnVisitV...}; + + tableVisit[var.typeId](vis, &var.storage); + } + else + { + using FnVisitR = void (*)(Visitor&, Result&, void*); + static const FnVisitR tableVisit[sizeof...(Ts)] = {&fnVisitR...}; + + Result res; + tableVisit[var.typeId](vis, res, &var.storage); + return res; + } +} +#endif + +template +inline constexpr bool always_false_v = false; + +} // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h new file mode 100644 index 0000000..df0bd42 --- /dev/null +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -0,0 +1,200 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +namespace Luau +{ + +namespace visit_detail +{ +/** + * Apply f(tid, t, seen) if doing so would pass type checking, else apply f(tid, t) + * + * We do this to permit (but not require) TypeVar visitors to accept the seen set as an argument. + */ +template +auto apply(A tid, const B& t, C& c, F& f) -> decltype(f(tid, t, c)) +{ + return f(tid, t, c); +} + +template +auto apply(A tid, const B& t, C&, F& f) -> decltype(f(tid, t)) +{ + return f(tid, t); +} + +inline bool hasSeen(std::unordered_set& seen, const void* tv) +{ + void* ttv = const_cast(tv); + return !seen.insert(ttv).second; +} + +inline void unsee(std::unordered_set& seen, const void* tv) +{ + void* ttv = const_cast(tv); + seen.erase(ttv); +} + +template +void visit(TypePackId tp, F& f, std::unordered_set& seen); + +template +void visit(TypeId ty, F& f, std::unordered_set& seen) +{ + if (visit_detail::hasSeen(seen, ty)) + { + f.cycle(ty); + return; + } + + if (auto btv = get(ty)) + { + if (apply(ty, *btv, seen, f)) + visit(btv->boundTo, f, seen); + } + + else if (auto ftv = get(ty)) + apply(ty, *ftv, seen, f); + + else if (auto gtv = get(ty)) + apply(ty, *gtv, seen, f); + + else if (auto etv = get(ty)) + apply(ty, *etv, seen, f); + + else if (auto ptv = get(ty)) + apply(ty, *ptv, seen, f); + + else if (auto ftv = get(ty)) + { + if (apply(ty, *ftv, seen, f)) + { + visit(ftv->argTypes, f, seen); + visit(ftv->retType, f, seen); + } + } + + else if (auto ttv = get(ty)) + { + if (apply(ty, *ttv, seen, f)) + { + for (auto& [_name, prop] : ttv->props) + visit(prop.type, f, seen); + + if (ttv->indexer) + { + visit(ttv->indexer->indexType, f, seen); + visit(ttv->indexer->indexResultType, f, seen); + } + } + } + + else if (auto mtv = get(ty)) + { + if (apply(ty, *mtv, seen, f)) + { + visit(mtv->table, f, seen); + visit(mtv->metatable, f, seen); + } + } + + else if (auto ctv = get(ty)) + { + if (apply(ty, *ctv, seen, f)) + { + for (const auto& [name, prop] : ctv->props) + visit(prop.type, f, seen); + + if (ctv->parent) + visit(*ctv->parent, f, seen); + + if (ctv->metatable) + visit(*ctv->metatable, f, seen); + } + } + + else if (auto atv = get(ty)) + apply(ty, *atv, seen, f); + + else if (auto utv = get(ty)) + { + if (apply(ty, *utv, seen, f)) + { + for (TypeId optTy : utv->options) + visit(optTy, f, seen); + } + } + + else if (auto itv = get(ty)) + { + if (apply(ty, *itv, seen, f)) + { + for (TypeId partTy : itv->parts) + visit(partTy, f, seen); + } + } + + visit_detail::unsee(seen, ty); +} + +template +void visit(TypePackId tp, F& f, std::unordered_set& seen) +{ + if (visit_detail::hasSeen(seen, tp)) + { + f.cycle(tp); + return; + } + + if (auto btv = get(tp)) + { + if (apply(tp, *btv, seen, f)) + visit(btv->boundTo, f, seen); + } + + else if (auto ftv = get(tp)) + apply(tp, *ftv, seen, f); + + else if (auto gtv = get(tp)) + apply(tp, *gtv, seen, f); + + else if (auto etv = get(tp)) + apply(tp, *etv, seen, f); + + else if (auto pack = get(tp)) + { + apply(tp, *pack, seen, f); + + for (TypeId ty : pack->head) + visit(ty, f, seen); + + if (pack->tail) + visit(*pack->tail, f, seen); + } + else if (auto pack = get(tp)) + { + apply(tp, *pack, seen, f); + visit(pack->ty, f, seen); + } + + visit_detail::unsee(seen, tp); +} +} // namespace visit_detail + +template +void visitTypeVar(TID ty, F& f, std::unordered_set& seen) +{ + visit_detail::visit(ty, f, seen); +} + +template +void visitTypeVar(TID ty, F& f) +{ + std::unordered_set seen; + visit_detail::visit(ty, f, seen); +} + +} // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp new file mode 100644 index 0000000..d3de175 --- /dev/null +++ b/Analysis/src/AstQuery.cpp @@ -0,0 +1,411 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AstQuery.h" + +#include "Luau/Module.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/ToString.h" + +#include "Luau/Common.h" + +#include + +namespace Luau +{ + +namespace +{ + +struct FindNode : public AstVisitor +{ + const Position pos; + const Position documentEnd; + AstNode* best = nullptr; + + explicit FindNode(Position pos, Position documentEnd) + : pos(pos) + , documentEnd(documentEnd) + { + } + + bool visit(AstNode* node) override + { + if (node->location.contains(pos)) + { + best = node; + return true; + } + + // Edge case: If we ask for the node at the position that is the very end of the document + // return the innermost AST element that ends at that position. + + if (node->location.end == documentEnd && pos >= documentEnd) + { + best = node; + return true; + } + + return false; + } + + bool visit(AstStatBlock* block) override + { + visit(static_cast(block)); + + for (AstStat* stat : block->body) + { + if (stat->location.end < pos) + continue; + if (stat->location.begin > pos) + break; + + stat->visit(this); + } + + return false; + } +}; + +struct FindFullAncestry final : public AstVisitor +{ + std::vector nodes; + Position pos; + + explicit FindFullAncestry(Position pos) + : pos(pos) + { + } + + bool visit(AstNode* node) + { + if (node->location.contains(pos)) + { + nodes.push_back(node); + return true; + } + return false; + } +}; + +} // namespace + +std::vector findAstAncestryOfPosition(const SourceModule& source, Position pos) +{ + FindFullAncestry finder(pos); + source.root->visit(&finder); + return std::move(finder.nodes); +} + +AstNode* findNodeAtPosition(const SourceModule& source, Position pos) +{ + const Position end = source.root->location.end; + if (pos < source.root->location.begin) + return source.root; + + if (pos > end) + pos = end; + + FindNode findNode{pos, end}; + findNode.visit(source.root); + return findNode.best; +} + +AstExpr* findExprAtPosition(const SourceModule& source, Position pos) +{ + AstNode* node = findNodeAtPosition(source, pos); + if (node) + return node->asExpr(); + else + return nullptr; +} + +ScopePtr findScopeAtPosition(const Module& module, Position pos) +{ + LUAU_ASSERT(!module.scopes.empty()); + + Location scopeLocation = module.scopes.front().first; + ScopePtr scope = module.scopes.front().second; + for (const auto& s : module.scopes) + { + if (s.first.contains(pos)) + { + if (!scope || scopeLocation.encloses(s.first)) + { + scopeLocation = s.first; + scope = s.second; + } + } + } + return scope; +} + +std::optional findTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos) +{ + if (auto expr = findExprAtPosition(sourceModule, pos)) + { + if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) + return it->second; + } + + return std::nullopt; +} + +std::optional findExpectedTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos) +{ + if (auto expr = findExprAtPosition(sourceModule, pos)) + { + if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end()) + return it->second; + } + + return std::nullopt; +} + +static std::optional findBindingLocalStatement(const SourceModule& source, const Binding& binding) +{ + std::vector nodes = findAstAncestryOfPosition(source, binding.location.begin); + auto iter = std::find_if(nodes.rbegin(), nodes.rend(), [](AstNode* node) { + return node->is(); + }); + return iter != nodes.rend() ? std::make_optional((*iter)->as()) : std::nullopt; +} + +std::optional findBindingAtPosition(const Module& module, const SourceModule& source, Position pos) +{ + AstExpr* expr = findExprAtPosition(source, pos); + if (!expr) + return std::nullopt; + + Symbol name; + if (auto g = expr->as()) + name = g->name; + else if (auto l = expr->as()) + name = l->local; + else + return std::nullopt; + + ScopePtr currentScope = findScopeAtPosition(module, pos); + LUAU_ASSERT(currentScope); + + while (currentScope) + { + auto iter = currentScope->bindings.find(name); + if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos) + { + /* Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope */ + std::optional bindingStatement = findBindingLocalStatement(source, iter->second); + if (!bindingStatement || !(*bindingStatement)->location.contains(pos)) + return iter->second; + } + currentScope = currentScope->parent; + } + + return std::nullopt; +} + +namespace +{ +struct FindExprOrLocal : public AstVisitor +{ + const Position pos; + ExprOrLocal result; + + explicit FindExprOrLocal(Position pos) + : pos(pos) + { + } + + // We want to find the result with the smallest location range. + bool isCloserMatch(Location newLocation) + { + auto current = result.getLocation(); + return newLocation.contains(pos) && (!current || current->encloses(newLocation)); + } + + bool visit(AstStatBlock* block) override + { + for (AstStat* stat : block->body) + { + if (stat->location.end <= pos) + continue; + if (stat->location.begin > pos) + break; + + stat->visit(this); + } + + return false; + } + + bool visit(AstExpr* expr) override + { + if (isCloserMatch(expr->location)) + { + result.setExpr(expr); + return true; + } + return false; + } + + bool visitLocal(AstLocal* local) + { + if (isCloserMatch(local->location)) + { + result.setLocal(local); + return true; + } + return false; + } + + bool visit(AstStatLocalFunction* function) override + { + visitLocal(function->name); + return true; + } + + bool visit(AstStatLocal* al) override + { + for (size_t i = 0; i < al->vars.size; ++i) + { + visitLocal(al->vars.data[i]); + } + return true; + } + + virtual bool visit(AstExprFunction* fn) override + { + for (size_t i = 0; i < fn->args.size; ++i) + { + visitLocal(fn->args.data[i]); + } + return visit((class AstExpr*)fn); + } + + virtual bool visit(AstStatFor* forStat) override + { + visitLocal(forStat->var); + return true; + } + + virtual bool visit(AstStatForIn* forIn) override + { + for (AstLocal* var : forIn->vars) + { + visitLocal(var); + } + return true; + } +}; +}; // namespace + +ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos) +{ + FindExprOrLocal findVisitor{pos}; + findVisitor.visit(source.root); + return findVisitor.result; +} + +std::optional getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position) +{ + std::vector ancestry = findAstAncestryOfPosition(source, position); + + AstExpr* targetExpr = ancestry.size() >= 1 ? ancestry[ancestry.size() - 1]->asExpr() : nullptr; + AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr; + + if (std::optional binding = findBindingAtPosition(module, source, position)) + { + if (binding->documentationSymbol) + { + // This might be an overloaded function binding. + if (get(follow(binding->typeId))) + { + TypeId matchingOverload = nullptr; + if (parentExpr && parentExpr->is()) + { + if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end()) + { + matchingOverload = it->second; + } + } + + if (matchingOverload) + { + std::string overloadSymbol = *binding->documentationSymbol + "/overload/"; + // Default toString options are fine for this purpose. + overloadSymbol += toString(matchingOverload); + return overloadSymbol; + } + } + } + + return binding->documentationSymbol; + } + + if (targetExpr) + { + if (AstExprIndexName* indexName = targetExpr->as()) + { + if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end()) + { + TypeId parentTy = follow(it->second); + if (const TableTypeVar* ttv = get(parentTy)) + { + if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) + { + return propIt->second.documentationSymbol; + } + } + else if (const ClassTypeVar* ctv = get(parentTy)) + { + if (auto propIt = ctv->props.find(indexName->index.value); propIt != ctv->props.end()) + { + return propIt->second.documentationSymbol; + } + } + } + } + else if (AstExprFunction* fn = targetExpr->as()) + { + // Handle event connection-like structures where we have + // something:Connect(function(a, b, c) end) + // In this case, we want to ascribe a documentation symbol to 'a' + // based on the documentation symbol of Connect. + if (parentExpr && parentExpr->is()) + { + AstExprCall* call = parentExpr->as(); + if (std::optional parentSymbol = getDocumentationSymbolAtPosition(source, module, call->func->location.begin)) + { + for (size_t i = 0; i < call->args.size; ++i) + { + AstExpr* callArg = call->args.data[i]; + if (callArg == targetExpr) + { + std::string fnSymbol = *parentSymbol + "/param/" + std::to_string(i); + for (size_t j = 0; j < fn->args.size; ++j) + { + AstLocal* fnArg = fn->args.data[j]; + + if (fnArg->location.contains(position)) + { + return fnSymbol + "/param/" + std::to_string(j); + } + } + } + } + } + } + } + } + + if (std::optional ty = findTypeAtPosition(module, source, position)) + { + if ((*ty)->documentationSymbol) + { + return (*ty)->documentationSymbol; + } + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp new file mode 100644 index 0000000..dce92a0 --- /dev/null +++ b/Analysis/src/Autocomplete.cpp @@ -0,0 +1,1566 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Autocomplete.h" + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Frontend.h" +#include "Luau/ToString.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" + +#include +#include +#include + +LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) +LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); +LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) + +static const std::unordered_set kStatementStartingKeywords = { + "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; + +namespace Luau +{ + +struct NodeFinder : public AstVisitor +{ + const Position pos; + std::vector ancestry; + + explicit NodeFinder(Position pos, AstNode* root) + : pos(pos) + { + } + + bool visit(AstExpr* expr) override + { + if (expr->location.begin < pos && pos <= expr->location.end) + { + ancestry.push_back(expr); + return true; + } + return false; + } + + bool visit(AstStat* stat) override + { + if (stat->location.begin < pos && pos <= stat->location.end) + { + ancestry.push_back(stat); + return true; + } + return false; + } + + bool visit(AstType* type) override + { + if (type->location.begin < pos && pos <= type->location.end) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(AstTypeError* type) override + { + // For a missing type, match the whole range including the start position + if (type->isMissing && type->location.containsClosed(pos)) + { + ancestry.push_back(type); + return true; + } + return false; + } + + bool visit(class AstTypePack* typePack) override + { + return true; + } + + bool visit(AstStatBlock* block) override + { + // If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite. + if (ancestry.empty()) + { + ancestry.push_back(block); + return true; + } + + // AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes. + // ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz} + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // Type annotation error might intersect the block statement when the function header is being written, + // annotation takes priority + if (!ancestry.empty() && ancestry.back()->is()) + return false; + + // If the cursor is at the end of an expression or type and simultaneously at the beginning of a block, + // the expression or type wins out. + // The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to + // be within the block. + if (block->location.begin == pos && !ancestry.empty()) + { + if (ancestry.back()->asExpr() && !ancestry.back()->is()) + return false; + + if (ancestry.back()->asType()) + return false; + } + + if (block->location.begin <= pos && pos <= block->location.end) + { + ancestry.push_back(block); + return true; + } + return false; + } +}; + +static bool alreadyHasParens(const std::vector& nodes) +{ + auto iter = nodes.rbegin(); + while (iter != nodes.rend() && + ((*iter)->is() || (*iter)->is() || (*iter)->is() || (*iter)->is())) + { + iter++; + } + + if (iter == nodes.rend() || iter == nodes.rbegin()) + { + return false; + } + + if (AstExprCall* call = (*iter)->as()) + { + return call->func == *(iter - 1); + } + + return false; +} + +static ParenthesesRecommendation getParenRecommendationForFunc(const FunctionTypeVar* func, const std::vector& nodes) +{ + if (alreadyHasParens(nodes)) + { + return ParenthesesRecommendation::None; + } + + auto idxExpr = nodes.back()->as(); + bool hasImplicitSelf = idxExpr && idxExpr->op == ':'; + auto args = Luau::flatten(func->argTypes); + bool noArgFunction = (args.first.empty() || (hasImplicitSelf && args.first.size() == 1)) && !args.second.has_value(); + return noArgFunction ? ParenthesesRecommendation::CursorAfter : ParenthesesRecommendation::CursorInside; +} + +static ParenthesesRecommendation getParenRecommendationForIntersect(const IntersectionTypeVar* intersect, const std::vector& nodes) +{ + ParenthesesRecommendation rec = ParenthesesRecommendation::None; + for (Luau::TypeId partId : intersect->parts) + { + if (auto partFunc = Luau::get(partId)) + { + rec = std::max(rec, getParenRecommendationForFunc(partFunc, nodes)); + } + else + { + return ParenthesesRecommendation::None; + } + } + return rec; +} + +static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::vector& nodes, TypeCorrectKind typeCorrect) +{ + // If element is already type-correct, even a function should be inserted without parenthesis + if (typeCorrect == TypeCorrectKind::Correct) + return ParenthesesRecommendation::None; + + id = Luau::follow(id); + if (auto func = get(id)) + { + return getParenRecommendationForFunc(func, nodes); + } + else if (auto intersect = get(id)) + { + return getParenRecommendationForIntersect(intersect, nodes); + } + return ParenthesesRecommendation::None; +} + +static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, TypeId ty) +{ + ty = follow(ty); + + auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { + InternalErrorReporter iceReporter; + Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter); + + unifier.tryUnify(expectedType, actualType); + + bool ok = unifier.errors.empty(); + unifier.log.rollback(); + return ok; + }; + + auto expr = node->asExpr(); + if (!expr) + return TypeCorrectKind::None; + + auto it = module.astExpectedTypes.find(expr); + if (it == module.astExpectedTypes.end()) + return TypeCorrectKind::None; + + TypeId expectedType = follow(it->second); + + if (canUnify(expectedType, ty)) + return TypeCorrectKind::Correct; + + // We also want to suggest functions that return compatible result + const FunctionTypeVar* ftv = get(ty); + + if (!ftv) + return TypeCorrectKind::None; + + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty()) + return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail))) + return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + } + + return TypeCorrectKind::None; +} + +enum class PropIndexType +{ + Point, + Colon, + Key, +}; + +static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, + AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) +{ + ty = follow(ty); + + if (seen.count(ty)) + return; + seen.insert(ty); + + auto isWrongIndexer = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { + if (indexType == PropIndexType::Key) + return false; + + bool colonIndex = indexType == PropIndexType::Colon; + + if (const FunctionTypeVar* ftv = get(type)) + { + return useStrictFunctionIndexers ? colonIndex != ftv->hasSelf : false; + } + else if (const IntersectionTypeVar* itv = get(type)) + { + bool allHaveSelf = true; + for (auto subType : itv->parts) + { + if (const FunctionTypeVar* ftv = get(Luau::follow(subType))) + { + allHaveSelf &= ftv->hasSelf; + } + else + { + return colonIndex; + } + } + return useStrictFunctionIndexers ? colonIndex != allHaveSelf : false; + } + else + { + return colonIndex; + } + }; + + auto fillProps = [&](const ClassTypeVar::Props& props) { + for (const auto& [name, prop] : props) + { + // We are walking up the class hierarchy, so if we encounter a property that we have + // already populated, it takes precedence over the property we found just now. + if (result.count(name) == 0 && name != Parser::errorName) + { + Luau::TypeId type = Luau::follow(prop.type); + TypeCorrectKind typeCorrect = + indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, nodes.back(), type); + ParenthesesRecommendation parens = + indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); + + result[name] = AutocompleteEntry{ + AutocompleteEntryKind::Property, + type, + prop.deprecated, + isWrongIndexer(type), + typeCorrect, + containingClass, + &prop, + prop.documentationSymbol, + {}, + parens, + }; + } + } + }; + + if (auto cls = get(ty)) + { + containingClass = containingClass.value_or(cls); + fillProps(cls->props); + if (cls->parent) + autocompleteProps(module, typeArena, *cls->parent, indexType, nodes, result, seen, cls); + } + else if (auto tbl = get(ty)) + fillProps(tbl->props); + else if (auto mt = get(ty)) + { + autocompleteProps(module, typeArena, mt->table, indexType, nodes, result, seen); + + auto mtable = get(mt->metatable); + if (!mtable) + return; + + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) + { + if (get(indexIt->second.type) || get(indexIt->second.type)) + autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); + else if (auto indexFunction = get(indexIt->second.type)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } + } + } + else if (auto i = get(ty)) + { + // Complete all properties in every variant + for (TypeId ty : i->parts) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen = seen; + + autocompleteProps(module, typeArena, ty, indexType, nodes, inner, innerSeen); + + for (auto& pair : inner) + result.insert(pair); + } + } + else if (auto u = get(ty)) + { + // Complete all properties common to all variants + auto iter = begin(u); + auto endIter = end(u); + + while (iter != endIter) + { + if (FFlag::LuauAddMissingFollow) + { + if (isNil(*iter)) + ++iter; + else + break; + } + else + { + if (auto primTy = Luau::get(*iter); primTy && primTy->type == PrimitiveTypeVar::NilType) + ++iter; + else + break; + } + } + + if (iter == endIter) + return; + + autocompleteProps(module, typeArena, *iter, indexType, nodes, result, seen); + + ++iter; + + while (iter != endIter) + { + AutocompleteEntryMap inner; + std::unordered_set innerSeen = seen; + + if (FFlag::LuauAddMissingFollow) + { + if (isNil(*iter)) + { + ++iter; + continue; + } + } + else + { + if (auto innerPrimTy = Luau::get(*iter); innerPrimTy && innerPrimTy->type == PrimitiveTypeVar::NilType) + { + ++iter; + continue; + } + } + + autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen); + + std::unordered_set toRemove; + + for (const auto& [k, v] : result) + { + (void)v; + if (!inner.count(k)) + toRemove.insert(k); + } + + for (const std::string& k : toRemove) + result.erase(k); + + ++iter; + } + } +} + +static void autocompleteKeywords( + const SourceModule& sourceModule, const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.back(); + + if (!node->is() && node->asExpr()) + { + // This is not strictly correct. We should recommend `and` and `or` only after + // another expression, not at the start of a new one. We should only recommend + // `not` at the start of an expression. Detecting either case reliably is quite + // complex, however; this is good enough for now. + + // These are not context-sensitive keywords, so we can unconditionally assign. + result["and"] = {AutocompleteEntryKind::Keyword}; + result["or"] = {AutocompleteEntryKind::Keyword}; + result["not"] = {AutocompleteEntryKind::Keyword}; + } +} + +static void autocompleteProps( + const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result) +{ + std::unordered_set seen; + autocompleteProps(module, typeArena, ty, indexType, nodes, result, seen); +} + +AutocompleteEntryMap autocompleteProps( + const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes) +{ + AutocompleteEntryMap result; + autocompleteProps(module, typeArena, ty, indexType, nodes, result); + return result; +} + +AutocompleteEntryMap autocompleteModuleTypes(const Module& module, Position position, std::string_view moduleName) +{ + AutocompleteEntryMap result; + + for (ScopePtr scope = findScopeAtPosition(module, position); scope; scope = scope->parent) + { + if (auto it = scope->importedTypeBindings.find(std::string(moduleName)); it != scope->importedTypeBindings.end()) + { + for (const auto& [name, ty] : it->second) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type}; + + break; + } + } + + return result; +} + +static bool canSuggestInferredType(ScopePtr scope, TypeId ty) +{ + ty = follow(ty); + + // No point in suggesting 'any', invalid to suggest others + if (get(ty) || get(ty) || get(ty) || get(ty)) + return false; + + // No syntax for unnamed tables with a metatable + if (const MetatableTypeVar* mtv = get(ty)) + return false; + + if (const TableTypeVar* ttv = get(ty)) + { + if (ttv->name) + return true; + + if (ttv->syntheticName) + return false; + } + + // We might still have a type with cycles or one that is too long, we'll check that later + return true; +} + +// Walk complex type trees to find the element that is being edited +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position); + +static std::optional findTypeElementAt(const AstTypeList& astTypeList, TypePackId tp, Position position) +{ + for (size_t i = 0; i < astTypeList.types.size; i++) + { + AstType* type = astTypeList.types.data[i]; + + if (type->location.containsClosed(position)) + { + auto [head, _] = flatten(tp); + + if (i < head.size()) + return findTypeElementAt(type, head[i], position); + } + } + + if (AstTypePack* argTp = astTypeList.tailType) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + auto [_, tail] = flatten(tp); + + if (tail) + { + if (const VariadicTypePack* vtp = get(follow(*tail))) + return findTypeElementAt(variadic->variadicType, vtp->ty, position); + } + } + } + } + + return {}; +} + +static std::optional findTypeElementAt(AstType* astType, TypeId ty, Position position) +{ + ty = follow(ty); + + if (astType->is()) + return ty; + + if (astType->is()) + return ty; + + if (AstTypeFunction* type = astType->as()) + { + const FunctionTypeVar* ftv = get(ty); + + if (!ftv) + return {}; + + if (auto element = findTypeElementAt(type->argTypes, ftv->argTypes, position)) + return element; + + if (auto element = findTypeElementAt(type->returnTypes, ftv->retType, position)) + return element; + } + + // It's possible to walk through other types like intrsection and unions if we find value in doing that + return {}; +} + +std::optional getLocalTypeInScopeAt(const Module& module, Position position, AstLocal* local) +{ + if (ScopePtr scope = findScopeAtPosition(module, position)) + { + for (const auto& [name, binding] : scope->bindings) + { + if (name == local) + return binding.typeId; + } + } + + return {}; +} + +static std::optional tryGetTypeNameInScope(ScopePtr scope, TypeId ty) +{ + if (!canSuggestInferredType(scope, ty)) + return std::nullopt; + + ToStringOptions opts; + opts.useLineBreaks = false; + opts.hideTableKind = true; + opts.scope = scope; + ToStringResult name = toStringDetailed(ty, opts); + + if (name.error || name.invalid || name.cycle || name.truncated) + return std::nullopt; + + return name.name; +} + +static bool tryAddTypeCorrectSuggestion(AutocompleteEntryMap& result, ScopePtr scope, AstType* topType, TypeId inferredType, Position position) +{ + std::optional ty; + + if (topType) + ty = findTypeElementAt(topType, inferredType, position); + else + ty = inferredType; + + if (!ty) + return false; + + if (auto name = tryGetTypeNameInScope(scope, *ty)) + { + if (auto it = result.find(*name); it != result.end()) + it->second.typeCorrect = TypeCorrectKind::Correct; + else + result[*name] = AutocompleteEntry{AutocompleteEntryKind::Type, *ty, false, false, TypeCorrectKind::Correct}; + + return true; + } + + return false; +} + +static std::optional tryGetTypePackTypeAt(TypePackId tp, size_t index) +{ + auto [tpHead, tpTail] = flatten(tp); + + if (index < tpHead.size()) + return tpHead[index]; + + // Infinite tail + if (tpTail) + { + if (const VariadicTypePack* vtp = get(follow(*tpTail))) + return vtp->ty; + } + + return {}; +} + +template +std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) +{ + std::optional ret; + for (TypeId subTy : utv) + { + if (isNil(subTy)) + continue; + + if (const T* ftv = get(follow(subTy))) + { + if (ret.has_value()) + { + return std::nullopt; + } + ret = ftv; + } + else + { + return std::nullopt; + } + } + return ret; +} + +static std::optional functionIsExpectedAt(const Module& module, AstNode* node) +{ + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + auto it = module.astExpectedTypes.find(expr); + if (it == module.astExpectedTypes.end()) + return std::nullopt; + + TypeId expectedType = follow(it->second); + + if (const FunctionTypeVar* ftv = get(expectedType)) + return true; + + if (const IntersectionTypeVar* itv = get(expectedType)) + { + return std::all_of(begin(itv->parts), end(itv->parts), [](auto&& ty) { + return get(Luau::follow(ty)) != nullptr; + }); + } + + if (const UnionTypeVar* utv = get(expectedType)) + return returnFirstNonnullOptionOfType(utv).has_value(); + + return false; +} + +AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position position, const std::vector& ancestry) +{ + AutocompleteEntryMap result; + + ScopePtr startScope = findScopeAtPosition(module, position); + + for (ScopePtr scope = startScope; scope; scope = scope->parent) + { + for (const auto& [name, ty] : scope->exportedTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, + std::nullopt, ty.type->documentationSymbol}; + } + + for (const auto& [name, ty] : scope->privateTypeBindings) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Type, ty.type, false, false, TypeCorrectKind::None, std::nullopt, + std::nullopt, ty.type->documentationSymbol}; + } + + for (const auto& [name, _] : scope->importedTypeBindings) + { + if (auto binding = scope->linearSearchForBinding(name, true)) + { + if (!result.count(name)) + result[name] = AutocompleteEntry{AutocompleteEntryKind::Module, binding->typeId}; + } + } + } + + AstNode* parent = nullptr; + AstType* topType = nullptr; + + for (auto it = ancestry.rbegin(), e = ancestry.rend(); it != e; ++it) + { + if (AstType* asType = (*it)->asType()) + { + topType = asType; + } + else + { + parent = *it; + break; + } + } + + if (!parent) + return result; + + if (AstStatLocal* node = parent->as()) // Try to provide inferred type of the local + { + // Look at which of the variable types we are defining + for (size_t i = 0; i < node->vars.size; i++) + { + AstLocal* var = node->vars.data[i]; + + if (var->annotation && var->annotation->location.containsClosed(position)) + { + if (node->values.size == 0) + break; + + unsigned tailPos = 0; + + // For multiple return values we will try to unpack last function call return type pack + if (i >= node->values.size) + { + tailPos = int(i) - int(node->values.size) + 1; + i = int(node->values.size) - 1; + } + + AstExpr* expr = node->values.data[i]->asExpr(); + + if (!expr) + break; + + TypeId inferredType = nullptr; + + if (AstExprCall* exprCall = expr->as()) + { + if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end()) + { + if (const FunctionTypeVar* ftv = get(follow(it->second))) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) + inferredType = *ty; + } + } + } + else + { + if (tailPos != 0) + break; + + if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) + inferredType = it->second; + } + + if (inferredType) + tryAddTypeCorrectSuggestion(result, startScope, topType, inferredType, position); + + break; + } + } + } + else if (AstExprFunction* node = parent->as()) + { + // For lookup inside expected function type if that's available + auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* { + auto it = module.astExpectedTypes.find(expr); + + if (it == module.astExpectedTypes.end()) + return nullptr; + + TypeId ty = follow(it->second); + + if (const FunctionTypeVar* ftv = get(ty)) + return ftv; + + // Handle optional function type + if (const UnionTypeVar* utv = get(ty)) + { + return returnFirstNonnullOptionOfType(utv).value_or(nullptr); + } + + return nullptr; + }; + + // Find which argument type we are defining + for (size_t i = 0; i < node->args.size; i++) + { + AstLocal* arg = node->args.data[i]; + + if (arg->annotation && arg->annotation->location.containsClosed(position)) + { + if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + // Otherwise, try to use the type inferred by typechecker + else if (auto inferredType = getLocalTypeInScopeAt(module, position, arg)) + { + tryAddTypeCorrectSuggestion(result, startScope, topType, *inferredType, position); + } + + break; + } + } + + if (AstTypePack* argTp = node->varargAnnotation) + { + if (auto variadic = argTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->argTypes, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + + for (size_t i = 0; i < node->returnAnnotation.types.size; i++) + { + AstType* ret = node->returnAnnotation.types.data[i]; + + if (ret->location.containsClosed(position)) + { + if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retType, i)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + + // TODO: with additional type information, we could suggest inferred return type here + break; + } + } + + if (AstTypePack* retTp = node->returnAnnotation.tailType) + { + if (auto variadic = retTp->as()) + { + if (variadic->location.containsClosed(position)) + { + if (const FunctionTypeVar* ftv = tryGetExpectedFunctionType(module, node)) + { + if (auto ty = tryGetTypePackTypeAt(ftv->retType, ~0u)) + tryAddTypeCorrectSuggestion(result, startScope, topType, *ty, position); + } + } + } + } + } + + return result; +} + +static bool isInLocalNames(const std::vector& ancestry, Position position) +{ + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (var->location.containsClosed(position)) + { + return true; + } + } + } + else if (auto funcExpr = (*iter)->as()) + { + if (funcExpr->argLocation && funcExpr->argLocation->contains(position)) + { + return true; + } + } + else if (auto localFunc = (*iter)->as()) + { + return localFunc->name->location.containsClosed(position); + } + else if (auto block = (*iter)->as()) + { + if (block->body.size > 0) + { + return false; + } + } + else if ((*iter)->asStat()) + { + return false; + } + } + return false; +} + +static bool isIdentifier(AstNode* node) +{ + return node->is() || node->is(); +} + +static bool isBeingDefined(const std::vector& ancestry, const Symbol& symbol) +{ + // Current set of rules only check for local binding match + if (!symbol.local) + return false; + + for (auto iter = ancestry.rbegin(); iter != ancestry.rend(); iter++) + { + if (auto statLocal = (*iter)->as()) + { + for (auto var : statLocal->vars) + { + if (symbol.local == var) + return true; + } + } + } + + return false; +} + +template +T* extractStat(const std::vector& ancestry) +{ + AstNode* node = ancestry.size() >= 1 ? ancestry.rbegin()[0] : nullptr; + if (!node) + return nullptr; + + if (T* t = node->as()) + return t; + + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return nullptr; + + if (T* t = parent->as(); t && parent->is()) + return t; + + AstNode* grandParent = ancestry.size() >= 3 ? ancestry.rbegin()[2] : nullptr; + AstNode* greatGrandParent = ancestry.size() >= 4 ? ancestry.rbegin()[3] : nullptr; + if (!grandParent || !greatGrandParent) + return nullptr; + + if (T* t = greatGrandParent->as(); t && grandParent->is() && parent->is() && isIdentifier(node)) + return t; + + return nullptr; +} + +static bool isBindingLegalAtCurrentPosition(const Binding& binding, Position pos) +{ + // Default Location used for global bindings, which are always legal. + return binding.location == Location() || binding.location.end < pos; +} + +static AutocompleteEntryMap autocompleteStatement( + const SourceModule& sourceModule, const Module& module, const std::vector& ancestry, Position position) +{ + // This is inefficient. :( + ScopePtr scope = findScopeAtPosition(module, position); + + AutocompleteEntryMap result; + + if (isInLocalNames(ancestry, position)) + { + autocompleteKeywords(sourceModule, ancestry, position, result); + return result; + } + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(binding, position)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, TypeCorrectKind::None, std::nullopt, + std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, TypeCorrectKind::None)}; + } + + scope = scope->parent; + } + + for (const auto& kw : kStatementStartingKeywords) + result.emplace(kw, AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + for (auto it = ancestry.rbegin(); it != ancestry.rend(); ++it) + { + if (AstStatForIn* statForIn = (*it)->as(); statForIn && !statForIn->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatFor* statFor = (*it)->as(); statFor && !statFor->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatIf* statIf = (*it)->as(); statIf && !statIf->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstStatWhile* statWhile = (*it)->as(); statWhile && !statWhile->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + if (AstExprFunction* exprFunction = (*it)->as(); exprFunction && !exprFunction->hasEnd) + result.emplace("end", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 2) + { + AstNode* parent = ancestry.rbegin()[1]; + if (AstStatIf* statIf = parent->as()) + { + if (!statIf->elsebody || (statIf->hasElse && statIf->elseLocation.containsClosed(position))) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = parent->as(); statRepeat && !statRepeat->hasUntil) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + + if (ancestry.size() >= 4) + { + auto iter = ancestry.rbegin(); + if (AstStatIf* statIf = iter[3]->as(); + statIf != nullptr && !statIf->elsebody && iter[2]->is() && iter[1]->is() && isIdentifier(iter[0])) + { + result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + } + } + + if (AstStatRepeat* statRepeat = extractStat(ancestry); statRepeat && !statRepeat->hasUntil) + result.emplace("until", AutocompleteEntry{AutocompleteEntryKind::Keyword}); + + return result; +} + +// Returns true if completions were generated (completions will be inserted into 'outResult') +// Returns false if no completions were generated +static bool autocompleteIfElseExpression( + const AstNode* node, const std::vector& ancestry, const Position& position, AutocompleteEntryMap& outResult) +{ + AstNode* parent = ancestry.size() >= 2 ? ancestry.rbegin()[1] : nullptr; + if (!parent) + return false; + + AstExprIfElse* ifElseExpr = parent->as(); + if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasThen) + { + outResult["then"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else if (ifElseExpr->trueExpr->location.containsClosed(position)) + { + return false; + } + else if (!ifElseExpr->hasElse) + { + outResult["else"] = {AutocompleteEntryKind::Keyword}; + outResult["elseif"] = {AutocompleteEntryKind::Keyword}; + return true; + } + else + { + return false; + } +} + +static void autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, TypeArena* typeArena, + const std::vector& ancestry, Position position, AutocompleteEntryMap& result) +{ + LUAU_ASSERT(!ancestry.empty()); + + AstNode* node = ancestry.rbegin()[0]; + + if (node->is()) + { + auto it = module.astTypes.find(node->asExpr()); + if (it != module.astTypes.end()) + autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result); + } + else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) + return; + else if (node->is()) + return; + else + { + // This is inefficient. :( + ScopePtr scope = findScopeAtPosition(module, position); + + while (scope) + { + for (const auto& [name, binding] : scope->bindings) + { + if (!isBindingLegalAtCurrentPosition(binding, position)) + continue; + + if (isBeingDefined(ancestry, name)) + continue; + + std::string n = toString(name); + if (!result.count(n)) + { + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, binding.typeId); + + result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, + binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; + } + } + + scope = scope->parent; + } + + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, typeChecker.nilType); + TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, typeChecker.booleanType); + TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + + if (FFlag::LuauIfElseExpressionAnalysisSupport) + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; + result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; + result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; + result["not"] = {AutocompleteEntryKind::Keyword}; + result["function"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false, correctForFunction}; + } +} + +static AutocompleteEntryMap autocompleteExpression(const SourceModule& sourceModule, const Module& module, const TypeChecker& typeChecker, + TypeArena* typeArena, const std::vector& ancestry, Position position) +{ + AutocompleteEntryMap result; + autocompleteExpression(sourceModule, module, typeChecker, typeArena, ancestry, position, result); + return result; +} + +static std::optional getMethodContainingClass(const ModulePtr& module, AstExpr* funcExpr) +{ + AstExpr* parentExpr = nullptr; + if (auto indexName = funcExpr->as()) + { + parentExpr = indexName->expr; + } + else if (auto indexExpr = funcExpr->as()) + { + parentExpr = indexExpr->expr; + } + else + { + return std::nullopt; + } + + auto parentIter = module->astTypes.find(parentExpr); + if (parentIter == module->astTypes.end()) + { + return std::nullopt; + } + + Luau::TypeId parentType = Luau::follow(parentIter->second); + + if (auto parentClass = Luau::get(parentType)) + { + return parentClass; + } + + if (auto parentUnion = Luau::get(parentType)) + { + return returnFirstNonnullOptionOfType(parentUnion); + } + + return std::nullopt; +} + +static std::optional autocompleteStringParams(const SourceModule& sourceModule, const ModulePtr& module, + const std::vector& nodes, Position position, StringCompletionCallback callback) +{ + if (nodes.size() < 2) + { + return std::nullopt; + } + + if (!nodes.back()->is()) + { + return std::nullopt; + } + + AstExprCall* candidate = nodes.at(nodes.size() - 2)->as(); + if (!candidate) + { + return std::nullopt; + } + + // HACK: All current instances of 'magic string' params are the first parameter of their functions, + // so we encode that here rather than putting a useless member on the FunctionTypeVar struct. + if (candidate->args.size > 1 && !candidate->args.data[0]->location.contains(position)) + { + return std::nullopt; + } + + auto iter = module->astTypes.find(candidate->func); + if (iter == module->astTypes.end()) + { + return std::nullopt; + } + + auto performCallback = [&](const FunctionTypeVar* funcType) -> std::optional { + for (const std::string& tag : funcType->tags) + { + if (std::optional ret = callback(tag, getMethodContainingClass(module, candidate->func))) + { + return ret; + } + } + return std::nullopt; + }; + + auto followedId = Luau::follow(iter->second); + if (auto functionType = Luau::get(followedId)) + { + return performCallback(functionType); + } + + if (auto intersect = Luau::get(followedId)) + { + for (TypeId part : intersect->parts) + { + if (auto candidateFunctionType = Luau::get(part)) + { + if (std::optional ret = performCallback(candidateFunctionType)) + { + return ret; + } + } + } + } + + return std::nullopt; +} + +static AutocompleteResult autocomplete(const SourceModule& sourceModule, const ModulePtr& module, const TypeChecker& typeChecker, + TypeArena* typeArena, Position position, StringCompletionCallback callback) +{ + if (isWithinComment(sourceModule, position)) + return {}; + + NodeFinder finder{position, sourceModule.root}; + sourceModule.root->visit(&finder); + LUAU_ASSERT(!finder.ancestry.empty()); + AstNode* node = finder.ancestry.back(); + + AstExprConstantNil dummy{Location{}}; + AstNode* parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + + // If we are inside a body of a function that doesn't have a completed argument list, ignore the body node + if (auto exprFunction = parent->as(); exprFunction && !exprFunction->argLocation && node == exprFunction->body) + { + finder.ancestry.pop_back(); + + node = finder.ancestry.back(); + parent = finder.ancestry.size() >= 2 ? finder.ancestry.rbegin()[1] : &dummy; + } + + if (auto indexName = node->as()) + { + auto it = module->astTypes.find(indexName->expr); + if (it == module->astTypes.end()) + return {}; + + TypeId ty = follow(it->second); + PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; + + if (isString(ty)) + return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), + finder.ancestry}; + else + return {autocompleteProps(*module, typeArena, ty, indexType, finder.ancestry), finder.ancestry}; + } + else if (auto typeReference = node->as()) + { + if (typeReference->hasPrefix) + return {autocompleteModuleTypes(*module, position, typeReference->prefix.value), finder.ancestry}; + else + return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + } + else if (node->is()) + { + return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; + } + else if (AstStatLocal* statLocal = node->as()) + { + if (statLocal->vars.size == 1 && (!statLocal->hasEqualsSign || position < statLocal->equalsSignLocation.begin)) + return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + else if (statLocal->hasEqualsSign && position >= statLocal->equalsSignLocation.end) + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + else + return {}; + } + + else if (AstStatFor* statFor = extractStat(finder.ancestry)) + { + if (!statFor->hasDo || position < statFor->doLocation.begin) + { + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + + return {}; + } + + return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + } + + else if (AstStatForIn* statForIn = parent->as(); statForIn && (node->is() || isIdentifier(node))) + { + if (!statForIn->hasIn || position <= statForIn->inLocation.begin) + { + AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; + if (lastName->name == Parser::errorName || lastName->location.containsClosed(position)) + { + // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or + // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer + // any suggestions. + return {}; + } + + return {{{"in", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + } + + if (!statForIn->hasDo || position <= statForIn->doLocation.begin) + { + LUAU_ASSERT(statForIn->values.size > 0); + AstExpr* lastExpr = statForIn->values.data[statForIn->values.size - 1]; + + if (lastExpr->location.containsClosed(position)) + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + + if (position > lastExpr->location.end) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + + return {}; // Not sure what this means + } + } + else if (AstStatForIn* statForIn = extractStat(finder.ancestry)) + { + // The AST looks a bit differently if the cursor is at a position where only the "do" keyword is allowed. + // ex "for f in f do" + if (!statForIn->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + + return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + } + + else if (AstStatWhile* statWhile = parent->as(); node->is() && statWhile) + { + if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + + if (!statWhile->hasDo || position < statWhile->doLocation.begin) + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + + if (statWhile->hasDo && position > statWhile->doLocation.end) + return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + } + + else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + + else if (AstStatIf* statIf = node->as(); FFlag::ElseElseIfCompletionImprovements && statIf && !statIf->hasElse) + { + return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, + finder.ancestry}; + } + else if (AstStatIf* statIf = parent->as(); statIf && node->is()) + { + if (statIf->condition->is()) + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + else if (!statIf->hasThen || statIf->thenLocation.containsClosed(position)) + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + } + else if (AstStatIf* statIf = extractStat(finder.ancestry); + statIf && (!statIf->hasThen || statIf->thenLocation.containsClosed(position))) + return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; + else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + else if (AstStatRepeat* statRepeat = extractStat(finder.ancestry); statRepeat) + return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + else if (AstExprTable* exprTable = parent->as(); exprTable && (node->is() || node->is())) + { + for (const auto& [kind, key, value] : exprTable->items) + { + // If item doesn't have a key, maybe the value is actually the key + if (key ? key == node : node->is() && value == node) + { + if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end()) + { + auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry); + + // Remove keys that are already completed + for (const auto& item : exprTable->items) + { + if (!item.key) + continue; + + if (auto stringKey = item.key->as()) + result.erase(std::string(stringKey->value.data, stringKey->value.size)); + } + + // If we know for sure that a key is being written, do not offer general epxression suggestions + if (!key) + autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); + + return {result, finder.ancestry}; + } + + break; + } + } + } + else if (isIdentifier(node) && (parent->is() || parent->is())) + return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + + if (std::optional ret = autocompleteStringParams(sourceModule, module, finder.ancestry, position, callback)) + { + return {*ret, finder.ancestry}; + } + else if (node->is()) + { + if (finder.ancestry.size() >= 2) + { + if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) + { + if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end()) + { + return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry}; + } + } + } + return {}; + } + + if (node->is()) + { + return {}; + } + + if (node->asExpr()) + return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; + else if (node->asStat()) + return {autocompleteStatement(sourceModule, *module, finder.ancestry, position), finder.ancestry}; + + return {}; +} + +AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback) +{ + // FIXME: We can improve performance here by parsing without checking. + // The old type graph is probably fine. (famous last words!) + // FIXME: We don't need to typecheck for script analysis here, just for autocomplete. + frontend.check(moduleName); + + const SourceModule* sourceModule = frontend.getSourceModule(moduleName); + if (!sourceModule) + return {}; + + TypeChecker& typeChecker = + (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = + (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); + + if (!module) + return {}; + + AutocompleteResult autocompleteResult = autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback); + + frontend.arenaForAutocomplete.clear(); + + return autocompleteResult; +} + +OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback) +{ + auto sourceModule = std::make_unique(); + ParseOptions parseOptions; + parseOptions.captureComments = true; + ParseResult result = Parser::parse(source.data(), source.size(), *sourceModule->names, *sourceModule->allocator, parseOptions); + + if (!result.root) + return {AutocompleteResult{}, {}, nullptr}; + + sourceModule->name = "FRAGMENT_SCRIPT"; + sourceModule->root = result.root; + sourceModule->mode = Mode::Strict; + sourceModule->commentLocations = std::move(result.commentLocations); + + TypeChecker& typeChecker = + (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + + ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); + + OwningAutocompleteResult autocompleteResult = { + autocomplete(*sourceModule, module, typeChecker, &frontend.arenaForAutocomplete, position, callback), std::move(module), + std::move(sourceModule)}; + + frontend.arenaForAutocomplete.clear(); + + return autocompleteResult; +} + +} // namespace Luau diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp new file mode 100644 index 0000000..68ad5ac --- /dev/null +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -0,0 +1,805 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" + +#include "Luau/Frontend.h" +#include "Luau/Symbol.h" +#include "Luau/Common.h" +#include "Luau/ToString.h" + +#include + +LUAU_FASTFLAG(LuauParseGenericFunctions) +LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAG(LuauStringMetatable) + +/** FIXME: Many of these type definitions are not quite completely accurate. + * + * Some of them require richer generics than we have. For instance, we do not yet have a way to talk + * about a function that takes any number of values, but where each value must have some specific type. + */ + +namespace Luau +{ + +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); + +TypeId makeUnion(TypeArena& arena, std::vector&& types) +{ + return arena.addType(UnionTypeVar{std::move(types)}); +} + +TypeId makeIntersection(TypeArena& arena, std::vector&& types) +{ + return arena.addType(IntersectionTypeVar{std::move(types)}); +} + +TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t) +{ + return makeUnion(arena, {typeChecker.nilType, t}); +} + +TypeId makeFunction( + TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, std::initializer_list retTypes) +{ + return makeFunction(arena, selfType, {}, {}, paramTypes, {}, retTypes); +} + +TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, + std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list retTypes) +{ + return makeFunction(arena, selfType, generics, genericPacks, paramTypes, {}, retTypes); +} + +TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list paramTypes, + std::initializer_list paramNames, std::initializer_list retTypes) +{ + return makeFunction(arena, selfType, {}, {}, paramTypes, paramNames, retTypes); +} + +TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, + std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, + std::initializer_list retTypes) +{ + std::vector params; + if (selfType) + params.push_back(*selfType); + for (auto&& p : paramTypes) + params.push_back(p); + + TypePackId paramPack = arena.addTypePack(std::move(params)); + TypePackId retPack = arena.addTypePack(std::vector(retTypes)); + FunctionTypeVar ftv{generics, genericPacks, paramPack, retPack, {}, selfType.has_value()}; + + if (selfType) + ftv.argNames.push_back(Luau::FunctionArgument{"self", {}}); + + if (paramNames.size() != 0) + { + for (auto&& p : paramNames) + ftv.argNames.push_back(Luau::FunctionArgument{std::move(p), {}}); + } + else if (selfType) + { + // If argument names were not provided, but we have already added a name for 'self' argument, we have to fill remaining slots as well + for (size_t i = 0; i < paramTypes.size(); i++) + ftv.argNames.push_back(std::nullopt); + } + + return arena.addType(std::move(ftv)); +} + +void attachMagicFunction(TypeId ty, MagicFunction fn) +{ + if (auto ftv = getMutable(ty)) + ftv->magicFunction = fn; + else + LUAU_ASSERT(!"Got a non functional type"); +} + +void attachFunctionTag(TypeId ty, std::string tag) +{ + if (auto ftv = getMutable(ty)) + { + ftv->tags.emplace_back(std::move(tag)); + } + else + { + LUAU_ASSERT(!"Got a non functional type"); + } +} + +Property makeProperty(TypeId ty, std::optional documentationSymbol) +{ + return { + /* type */ ty, + /* deprecated */ false, + /* deprecatedSuggestion */ {}, + /* location */ std::nullopt, + /* tags */ {}, + documentationSymbol, + }; +} + +void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName) +{ + addGlobalBinding(typeChecker, typeChecker.globalScope, name, ty, packageName); +} + +void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding) +{ + addGlobalBinding(typeChecker, typeChecker.globalScope, name, binding); +} + +void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +{ + std::string documentationSymbol = packageName + "/global/" + name; + addGlobalBinding(typeChecker, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); +} + +void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding) +{ + scope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = binding; +} + +TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) +{ + auto t = tryGetGlobalBinding(typeChecker, name); + LUAU_ASSERT(t.has_value()); + return t->typeId; +} + +std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name) +{ + AstName astName = typeChecker.globalNames.names->getOrAdd(name.c_str()); + auto it = typeChecker.globalScope->bindings.find(astName); + if (it != typeChecker.globalScope->bindings.end()) + return it->second; + + return std::nullopt; +} + +Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name) +{ + AstName astName = typeChecker.globalNames.names->get(name.c_str()); + if (astName == AstName()) + return nullptr; + + auto it = typeChecker.globalScope->bindings.find(astName); + if (it != typeChecker.globalScope->bindings.end()) + return &it->second; + + return nullptr; +} + +void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName) +{ + for (auto& [name, prop] : props) + { + prop.documentationSymbol = baseName + "." + name; + } +} + +void registerBuiltinTypes(TypeChecker& typeChecker) +{ + LUAU_ASSERT(!typeChecker.globalTypes.typeVars.isFrozen()); + LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); + + TypeId numberType = typeChecker.numberType; + TypeId booleanType = typeChecker.booleanType; + TypeId nilType = typeChecker.nilType; + TypeId stringType = typeChecker.stringType; + TypeId threadType = typeChecker.threadType; + TypeId anyType = typeChecker.anyType; + + TypeArena& arena = typeChecker.globalTypes; + + TypeId optionalNumber = makeOption(typeChecker, arena, numberType); + TypeId optionalString = makeOption(typeChecker, arena, stringType); + TypeId optionalBoolean = makeOption(typeChecker, arena, booleanType); + + TypeId stringOrNumber = makeUnion(arena, {stringType, numberType}); + + TypePackId emptyPack = arena.addTypePack({}); + TypePackId oneNumberPack = arena.addTypePack({numberType}); + TypePackId oneStringPack = arena.addTypePack({stringType}); + TypePackId oneBooleanPack = arena.addTypePack({booleanType}); + TypePackId oneAnyPack = arena.addTypePack({anyType}); + + TypePackId anyTypePack = typeChecker.anyTypePack; + + TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); + TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}}); + TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); + + TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ + listOfAtLeastOneNumber, + oneNumberPack, + }); + + TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); + + TypeId stringToAnyMap = arena.addType(TableTypeVar{{}, TableIndexer(stringType, anyType), typeChecker.globalScope->level}); + + LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); + LUAU_ASSERT(loadResult.success); + + TypeId mathLibType = getGlobalBinding(typeChecker, "math"); + if (TableTypeVar* ttv = getMutable(mathLibType)) + { + ttv->props["min"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.min"); + ttv->props["max"] = makeProperty(listOfAtLeastOneNumberToNumberType, "@luau/global/math.max"); + } + + TypeId bit32LibType = getGlobalBinding(typeChecker, "bit32"); + if (TableTypeVar* ttv = getMutable(bit32LibType)) + { + ttv->props["band"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.band"); + ttv->props["bor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bor"); + ttv->props["bxor"] = makeProperty(listOfAtLeastZeroNumbersToNumberType, "@luau/global/bit32.bxor"); + ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); + } + + TypeId anyFunction = arena.addType(FunctionTypeVar{anyTypePack, anyTypePack}); + + TypeId genericK = arena.addType(GenericTypeVar{"K"}); + TypeId genericV = arena.addType(GenericTypeVar{"V"}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); + + if (FFlag::LuauStringMetatable) + { + std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); + + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); + + TypeId stringLib = it->second.type; + addGlobalBinding(typeChecker, "string", stringLib, "@luau"); + } + + if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) + { + if (!FFlag::LuauStringMetatable) + { + TypeId stringLibTy = getGlobalBinding(typeChecker, "string"); + TableTypeVar* stringLib = getMutable(stringLibTy); + TypeId replArgType = makeUnion( + arena, {stringType, + arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), + makeFunction(arena, std::nullopt, {stringType}, {stringType})}); + TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); + + stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub"); + } + } + else + { + if (!FFlag::LuauStringMetatable) + { + TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType}); + + TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + + TypeId replArgType = makeUnion( + arena, {stringType, + arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), + makeFunction(arena, std::nullopt, {stringType}, {stringType})}); + TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); + + TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}); + + TableTypeVar::Props stringLib = { + // FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied + {"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}}, + // FIXME char takes a variadic pack of numbers + {"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}}, + {"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(arena, stringType, {}, {numberType})}}, + {"lower", {stringToStringType}}, + {"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}}, + {"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(arena, stringType, {stringType, optionalString}, + {arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}}, + {"pack", {arena.addType(FunctionTypeVar{ + arena.addTypePack(TypePack{{stringType}, anyTypePack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(arena, stringType, {}, {numberType})}}, + {"unpack", {arena.addType(FunctionTypeVar{ + arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + anyTypePack, + })}}, + }; + + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + addGlobalBinding(typeChecker, "string", + arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); + } + + TableTypeVar::Props debugLib{ + {"info", {makeIntersection(arena, + { + arena.addType(FunctionTypeVar{arena.addTypePack({typeChecker.threadType, numberType, stringType}), anyTypePack}), + arena.addType(FunctionTypeVar{arena.addTypePack({numberType, stringType}), anyTypePack}), + arena.addType(FunctionTypeVar{arena.addTypePack({anyFunction, stringType}), anyTypePack}), + })}}, + {"traceback", {makeIntersection(arena, + { + makeFunction(arena, std::nullopt, {optionalString, optionalNumber}, {stringType}), + makeFunction(arena, std::nullopt, {typeChecker.threadType, optionalString, optionalNumber}, {stringType}), + })}}, + }; + + assignPropDocumentationSymbols(debugLib, "@luau/global/debug"); + addGlobalBinding(typeChecker, "debug", + arena.addType(TableTypeVar{debugLib, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}), "@luau"); + + TableTypeVar::Props utf8Lib = { + {"char", {arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneStringPack})}}, // FIXME + {"charpattern", {stringType}}, + {"codes", {makeFunction(arena, std::nullopt, {stringType}, + {makeFunction(arena, std::nullopt, {stringType, numberType}, {numberType, numberType}), stringType, numberType})}}, + {"codepoint", + {arena.addType(FunctionTypeVar{arena.addTypePack({stringType, optionalNumber, optionalNumber}), listOfAtLeastOneNumber})}}, // FIXME + {"len", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {optionalNumber, numberType})}}, + {"offset", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {numberType})}}, + {"nfdnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, + {"graphemes", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, + {makeFunction(arena, std::nullopt, {}, {numberType, numberType})})}}, + {"nfcnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, + }; + + assignPropDocumentationSymbols(utf8Lib, "@luau/global/utf8"); + addGlobalBinding( + typeChecker, "utf8", arena.addType(TableTypeVar{utf8Lib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); + + TypeId optionalV = makeOption(typeChecker, arena, genericV); + + TypeId arrayOfV = arena.addType(TableTypeVar{{}, TableIndexer(numberType, genericV), typeChecker.globalScope->level}); + + TypePackId unpackArgsPack = arena.addTypePack(TypePack{{arrayOfV, optionalNumber, optionalNumber}}); + TypePackId unpackReturnPack = arena.addTypePack(TypePack{{}, anyTypePack}); + TypeId unpackFunc = arena.addType(FunctionTypeVar{{genericV}, {}, unpackArgsPack, unpackReturnPack}); + + TypeId packResult = arena.addType(TableTypeVar{ + TableTypeVar::Props{{"n", {numberType}}}, TableIndexer{numberType, numberType}, typeChecker.globalScope->level, TableState::Sealed}); + TypePackId packArgsPack = arena.addTypePack(TypePack{{}, anyTypePack}); + TypePackId packReturnPack = arena.addTypePack(TypePack{{packResult}}); + + TypeId comparator = makeFunction(arena, std::nullopt, {genericV, genericV}, {booleanType}); + TypeId optionalComparator = makeOption(typeChecker, arena, comparator); + + TypeId packFn = arena.addType(FunctionTypeVar(packArgsPack, packReturnPack)); + + TableTypeVar::Props tableLib = { + {"concat", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalString, optionalNumber, optionalNumber}, {stringType})}}, + {"insert", {makeIntersection(arena, {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV}, {}), + makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, genericV}, {})})}}, + {"maxn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, + {"remove", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalNumber}, {optionalV})}}, + {"sort", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalComparator}, {})}}, + {"create", {makeFunction(arena, std::nullopt, {genericV}, {}, {numberType, optionalV}, {arrayOfV})}}, + {"find", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV, optionalNumber}, {optionalNumber})}}, + + {"unpack", {unpackFunc}}, // FIXME + {"pack", {packFn}}, + + // Lua 5.0 compat + {"getn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, + {"foreach", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, + {mapOfKtoV, makeFunction(arena, std::nullopt, {genericK, genericV}, {})}, {})}}, + {"foreachi", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, makeFunction(arena, std::nullopt, {genericV}, {})}, {})}}, + + // backported from Lua 5.3 + {"move", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, numberType, numberType, arrayOfV}, {})}}, + + // added in Luau (borrowed from LuaJIT) + {"clear", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {})}}, + + {"freeze", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {mapOfKtoV})}}, + {"isfrozen", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {booleanType})}}, + }; + + assignPropDocumentationSymbols(tableLib, "@luau/global/table"); + addGlobalBinding( + typeChecker, "table", arena.addType(TableTypeVar{tableLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); + + TableTypeVar::Props coroutineLib = { + {"create", {makeFunction(arena, std::nullopt, {anyFunction}, {threadType})}}, + {"resume", {arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{threadType}, anyTypePack}), anyTypePack})}}, + {"running", {makeFunction(arena, std::nullopt, {}, {threadType})}}, + {"status", {makeFunction(arena, std::nullopt, {threadType}, {stringType})}}, + {"wrap", {makeFunction( + arena, std::nullopt, {anyFunction}, {anyType})}}, // FIXME this technically returns a function, but we can't represent this + // atm since it can be called with different arg types at different times + {"yield", {arena.addType(FunctionTypeVar{anyTypePack, anyTypePack})}}, + {"isyieldable", {makeFunction(arena, std::nullopt, {}, {booleanType})}}, + }; + + assignPropDocumentationSymbols(coroutineLib, "@luau/global/coroutine"); + addGlobalBinding(typeChecker, "coroutine", + arena.addType(TableTypeVar{coroutineLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); + + TypeId genericT = arena.addType(GenericTypeVar{"T"}); + TypeId genericR = arena.addType(GenericTypeVar{"R"}); + + // assert returns all arguments + TypePackId assertArgs = arena.addTypePack({genericT, optionalString}); + TypePackId assertRets = arena.addTypePack({genericT}); + addGlobalBinding(typeChecker, "assert", arena.addType(FunctionTypeVar{assertArgs, assertRets}), "@luau"); + + addGlobalBinding(typeChecker, "print", arena.addType(FunctionTypeVar{anyTypePack, emptyPack}), "@luau"); + + addGlobalBinding(typeChecker, "type", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); + addGlobalBinding(typeChecker, "typeof", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); + + addGlobalBinding(typeChecker, "error", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {}), "@luau"); + + addGlobalBinding(typeChecker, "tostring", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); + addGlobalBinding( + typeChecker, "tonumber", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {numberType}), "@luau"); + + addGlobalBinding( + typeChecker, "rawequal", makeFunction(arena, std::nullopt, {genericT, genericR}, {}, {genericT, genericR}, {booleanType}), "@luau"); + addGlobalBinding( + typeChecker, "rawget", makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK}, {genericV}), "@luau"); + addGlobalBinding(typeChecker, "rawset", + makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK, genericV}, {mapOfKtoV}), "@luau"); + + TypePackId genericTPack = arena.addTypePack({genericT}); + TypePackId genericRPack = arena.addTypePack({genericR}); + TypeId genericArgsToReturnFunction = arena.addType( + FunctionTypeVar{{genericT, genericR}, {}, arena.addTypePack(TypePack{{}, genericTPack}), arena.addTypePack(TypePack{{}, genericRPack})}); + + TypeId setfenvArgType = makeUnion(arena, {numberType, genericArgsToReturnFunction}); + TypeId setfenvReturnType = makeOption(typeChecker, arena, genericArgsToReturnFunction); + addGlobalBinding(typeChecker, "setfenv", makeFunction(arena, std::nullopt, {setfenvArgType, stringToAnyMap}, {setfenvReturnType}), "@luau"); + + TypePackId ipairsArgsTypePack = arena.addTypePack({arrayOfV}); + + TypeId ipairsNextFunctionType = arena.addType( + FunctionTypeVar{{genericK, genericV}, {}, arena.addTypePack({arrayOfV, numberType}), arena.addTypePack({numberType, genericV})}); + + // ipairs returns 'next, Array, 0' so we would need type-level primitives and change to + // again, we have a direct reference to 'next' because ipairs returns it + // ipairs(t: Array) -> ((Array) -> (number, V), Array, 0) + TypePackId ipairsReturnTypePack = arena.addTypePack(TypePack{{ipairsNextFunctionType, arrayOfV, numberType}}); + + // ipairs(t: Array) -> ((Array) -> (number, V), Array, number) + addGlobalBinding(typeChecker, "ipairs", arena.addType(FunctionTypeVar{{genericV}, {}, ipairsArgsTypePack, ipairsReturnTypePack}), "@luau"); + + TypePackId pcallArg0FnArgs = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); + TypePackId pcallArg0FnRet = arena.addTypePack(TypePackVar{GenericTypeVar{"R"}}); + TypeId pcallArg0 = arena.addType(FunctionTypeVar{pcallArg0FnArgs, pcallArg0FnRet}); + TypePackId pcallArgsTypePack = arena.addTypePack(TypePack{{pcallArg0}, pcallArg0FnArgs}); + + TypePackId pcallReturnTypePack = arena.addTypePack(TypePack{{booleanType}, pcallArg0FnRet}); + + // pcall(f: (A...) -> R..., args: A...) -> boolean, R... + addGlobalBinding(typeChecker, "pcall", + arena.addType(FunctionTypeVar{{}, {pcallArg0FnArgs, pcallArg0FnRet}, pcallArgsTypePack, pcallReturnTypePack}), "@luau"); + + // errors thrown by the function 'f' are propagated onto the function 'err' that accepts it. + // and either 'f' or 'err' are valid results of this xpcall + // if 'err' did throw an error, then it returns: false, "error in error handling" + // TODO: the above is not represented (nor representable) in the type annotation below. + // + // The real type of xpcall is as such: (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, + // R2...) + TypePackId genericAPack = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); + TypePackId genericR1Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R1"}}); + TypePackId genericR2Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R2"}}); + + TypeId genericE = arena.addType(GenericTypeVar{"E"}); + + TypeId xpcallFArg = arena.addType(FunctionTypeVar{genericAPack, genericR1Pack}); + TypeId xpcallErrArg = arena.addType(FunctionTypeVar{arena.addTypePack({genericE}), genericR2Pack}); + + TypePackId xpcallArgsPack = arena.addTypePack({{xpcallFArg, xpcallErrArg}, genericAPack}); + TypePackId xpcallRetPack = arena.addTypePack({{booleanType}, genericR1Pack}); // FIXME + + addGlobalBinding(typeChecker, "xpcall", + arena.addType(FunctionTypeVar{{genericE}, {genericAPack, genericR1Pack, genericR2Pack}, xpcallArgsPack, xpcallRetPack}), "@luau"); + + addGlobalBinding(typeChecker, "unpack", unpackFunc, "@luau"); + + TypePackId selectArgsTypePack = arena.addTypePack(TypePack{ + {stringOrNumber}, + anyTypePack // FIXME? select() is tricky. + }); + + addGlobalBinding(typeChecker, "select", arena.addType(FunctionTypeVar{selectArgsTypePack, anyTypePack}), "@luau"); + + // TODO: not completely correct. loadstring's return type should be a function or (nil, string) + TypeId loadstringFunc = arena.addType(FunctionTypeVar{anyTypePack, oneAnyPack}); + + addGlobalBinding(typeChecker, "loadstring", + makeFunction(arena, std::nullopt, {stringType, optionalString}, + { + makeOption(typeChecker, arena, loadstringFunc), + makeOption(typeChecker, arena, stringType), + }), + "@luau"); + + // a userdata object is "roughly" the same as a sealed empty table + // except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. + // another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT + // setmetatable. + // TODO: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. + TypeId sealedTable = arena.addType(TableTypeVar(TableState::Sealed, typeChecker.globalScope->level)); + addGlobalBinding(typeChecker, "newproxy", makeFunction(arena, std::nullopt, {optionalBoolean}, {sealedTable}), "@luau"); + } + + // next(t: Table, i: K | nil) -> (K, V) + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); + addGlobalBinding(typeChecker, "next", + arena.addType(FunctionTypeVar{{genericK, genericV}, {}, nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}), "@luau"); + + TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); + + TypeId pairsNext = (FFlag::LuauRankNTypes ? arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}) + : getGlobalBinding(typeChecker, "next")); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + + // NOTE we are missing 'i: K | nil' argument in the first return types' argument. + // pairs(t: Table) -> ((Table) -> (K, V), Table, nil) + addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionTypeVar{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + + TypeId genericMT = arena.addType(GenericTypeVar{"MT"}); + + TableTypeVar tab{TableState::Generic, typeChecker.globalScope->level}; + TypeId tabTy = arena.addType(tab); + + TypeId tableMetaMT = arena.addType(MetatableTypeVar{tabTy, genericMT}); + + addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); + + // setmetatable({ @metatable MT }, MT) -> { @metatable MT } + // clang-format off + addGlobalBinding(typeChecker, "setmetatable", + arena.addType( + FunctionTypeVar{ + {genericMT}, + {}, + arena.addTypePack(TypePack{{tableMetaMT, genericMT}}), + arena.addTypePack(TypePack{{tableMetaMT}}) + } + ), "@luau" + ); + // clang-format on + + for (const auto& pair : typeChecker.globalScope->bindings) + { + persist(pair.second.typeId); + + if (TableTypeVar* ttv = getMutable(pair.second.typeId)) + ttv->name = toString(pair.first); + } + + attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); + + auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); + attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); + + auto stringLib = getMutable(getGlobalBinding(typeChecker, "string")); + attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat); + + attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); +} + +static std::optional> magicFunctionSelect( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + auto [paramPack, _predicates] = exprResult; + + (void)scope; + + if (expr.args.size <= 0) + { + typechecker.reportError(TypeError{expr.location, GenericError{"select should take 1 or more arguments"}}); + return std::nullopt; + } + + AstExpr* arg1 = expr.args.data[0]; + if (AstExprConstantNumber* num = arg1->as()) + { + const auto& [v, tail] = flatten(paramPack); + + int offset = int(num->value); + if (offset > 0) + { + if (size_t(offset) < v.size()) + { + std::vector result(v.begin() + offset, v.end()); + return ExprResult{typechecker.currentModule->internalTypes.addTypePack(TypePack{std::move(result), tail})}; + } + else if (tail) + return ExprResult{*tail}; + } + + typechecker.reportError(TypeError{arg1->location, GenericError{"bad argument #1 to select (index out of range)"}}); + } + else if (AstExprConstantString* str = arg1->as()) + { + if (str->value.size == 1 && str->value.data[0] == '#') + return ExprResult{typechecker.currentModule->internalTypes.addTypePack({typechecker.numberType})}; + } + + return std::nullopt; +} + +static std::optional> magicFunctionSetMetaTable( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + auto [paramPack, _predicates] = exprResult; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + std::vector expectedArgs = typechecker.unTypePack(scope, paramPack, 2, expr.location); + + TypeId target = follow(expectedArgs[0]); + TypeId mt = follow(expectedArgs[1]); + + if (const auto& tab = get(target)) + { + if (target->persistent) + { + typechecker.reportError(TypeError{expr.location, CannotExtendTable{target, CannotExtendTable::Metatable}}); + } + else + { + typechecker.tablify(mt); + + const TableTypeVar* mtTtv = get(mt); + MetatableTypeVar mtv{target, mt}; + if ((tab->name || tab->syntheticName) && (mtTtv && (mtTtv->name || mtTtv->syntheticName))) + { + std::string tableName = tab->name ? *tab->name : *tab->syntheticName; + std::string metatableName = mtTtv->name ? *mtTtv->name : *mtTtv->syntheticName; + + if (tableName == metatableName) + mtv.syntheticName = tableName; + else + mtv.syntheticName = "{ @metatable: " + metatableName + ", " + tableName + " }"; + } + + TypeId mtTy = arena.addType(mtv); + + AstExpr* targetExpr = expr.args.data[0]; + if (AstExprLocal* targetLocal = targetExpr->as()) + { + const Name targetName(targetLocal->local->name.value); + scope->bindings[targetLocal->local] = Binding{mtTy, expr.location}; + } + + return ExprResult{arena.addTypePack({mtTy})}; + } + } + else if (get(target) || get(target) || isTableIntersection(target)) + { + } + else + { + typechecker.reportError(TypeError{expr.location, GenericError{"setmetatable should take a table"}}); + } + + return ExprResult{arena.addTypePack({target})}; +} + +static std::optional> magicFunctionAssert( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + auto [paramPack, predicates] = exprResult; + + if (expr.args.size < 1) + return ExprResult{paramPack}; + + typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); + + return ExprResult{paramPack}; +} + +static std::optional> magicFunctionPack( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + auto [paramPack, _predicates] = exprResult; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + const auto& [paramTypes, paramTail] = flatten(paramPack); + + std::vector options; + options.reserve(paramTypes.size()); + for (auto type : paramTypes) + options.push_back(type); + + if (paramTail) + { + if (const VariadicTypePack* vtp = get(*paramTail)) + options.push_back(vtp->ty); + } + + options = typechecker.reduceUnion(options); + + // table.pack() -> {| n: number, [number]: nil |} + // table.pack(1) -> {| n: number, [number]: number |} + // table.pack(1, "foo") -> {| n: number, [number]: number | string |} + TypeId result = nullptr; + if (options.empty()) + result = typechecker.nilType; + else if (options.size() == 1) + result = options[0]; + else + result = arena.addType(UnionTypeVar{std::move(options)}); + + TypeId packedTable = arena.addType( + TableTypeVar{{{"n", {typechecker.numberType}}}, TableIndexer(typechecker.numberType, result), scope->level, TableState::Sealed}); + + return ExprResult{arena.addTypePack({packedTable})}; +} + +static bool checkRequirePath(TypeChecker& typechecker, AstExpr* expr) +{ + // require(foo.parent.bar) will technically work, but it depends on legacy goop that + // Luau does not and could not support without a bunch of work. It's deprecated anyway, so + // we'll warn here if we see it. + bool good = true; + AstExprIndexName* indexExpr = expr->as(); + + while (indexExpr) + { + if (indexExpr->index == "parent") + { + typechecker.reportError(indexExpr->indexLocation, DeprecatedApiUsed{"parent", "Parent"}); + good = false; + } + + indexExpr = indexExpr->expr->as(); + } + + return good; +} + +static std::optional> magicFunctionRequire( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + TypeArena& arena = typechecker.currentModule->internalTypes; + + if (expr.args.size != 1) + { + typechecker.reportError(TypeError{expr.location, GenericError{"require takes 1 argument"}}); + return std::nullopt; + } + + AstExpr* require = expr.args.data[0]; + + if (!checkRequirePath(typechecker, require)) + return std::nullopt; + + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) + return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/Config.cpp b/Analysis/src/Config.cpp new file mode 100644 index 0000000..d9fc44f --- /dev/null +++ b/Analysis/src/Config.cpp @@ -0,0 +1,278 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Config.h" + +#include "Luau/Parser.h" +#include "Luau/StringUtils.h" + +namespace +{ + +using Error = std::optional; + +} + +namespace Luau +{ + +static Error parseBoolean(bool& result, const std::string& value) +{ + if (value == "true") + result = true; + else if (value == "false") + result = false; + else + return Error{"Bad setting '" + value + "'. Valid options are true and false"}; + + return std::nullopt; +} + +Error parseModeString(Mode& mode, const std::string& modeString, bool compat) +{ + if (modeString == "nocheck") + mode = Mode::NoCheck; + else if (modeString == "strict") + mode = Mode::Strict; + else if (modeString == "nonstrict") + mode = Mode::Nonstrict; + else if (modeString == "noinfer" && compat) + mode = Mode::NoCheck; + else + return Error{"Bad mode \"" + modeString + "\". Valid options are nocheck, nonstrict, and strict"}; + + return std::nullopt; +} + +static Error parseLintRuleStringForCode( + LintOptions& enabledLints, LintOptions& fatalLints, LintWarning::Code code, const std::string& value, bool compat) +{ + if (value == "true") + { + enabledLints.enableWarning(code); + } + else if (value == "false") + { + enabledLints.disableWarning(code); + } + else if (compat) + { + if (value == "enabled") + { + enabledLints.enableWarning(code); + fatalLints.disableWarning(code); + } + else if (value == "disabled") + { + enabledLints.disableWarning(code); + fatalLints.disableWarning(code); + } + else if (value == "fatal") + { + enabledLints.enableWarning(code); + fatalLints.enableWarning(code); + } + else + { + return Error{"Bad setting '" + value + "'. Valid options are enabled, disabled, and fatal"}; + } + } + else + { + return Error{"Bad setting '" + value + "'. Valid options are true and false"}; + } + + return std::nullopt; +} + +Error parseLintRuleString(LintOptions& enabledLints, LintOptions& fatalLints, const std::string& warningName, const std::string& value, bool compat) +{ + if (warningName == "*") + { + for (int code = LintWarning::Code_Unknown; code < LintWarning::Code__Count; ++code) + { + if (auto err = parseLintRuleStringForCode(enabledLints, fatalLints, LintWarning::Code(code), value, compat)) + return Error{"In key " + warningName + ": " + *err}; + } + } + else + { + LintWarning::Code code = LintWarning::parseName(warningName.c_str()); + + if (code == LintWarning::Code_Unknown) + return Error{"Unknown lint " + warningName}; + + if (auto err = parseLintRuleStringForCode(enabledLints, fatalLints, code, value, compat)) + return Error{"In key " + warningName + ": " + *err}; + } + + return std::nullopt; +} + +static void next(Lexer& lexer) +{ + lexer.next(); + + // skip C-style comments as Lexer only understands Lua-style comments atm + while (lexer.current().type == '/') + { + Lexeme peek = lexer.lookahead(); + + if (peek.type != '/' || peek.location.begin != lexer.current().location.end) + break; + + lexer.nextline(); + } +} + +static Error fail(Lexer& lexer, const char* message) +{ + Lexeme cur = lexer.current(); + + return format("Expected %s at line %d, got %s instead", message, cur.location.begin.line + 1, cur.toString().c_str()); +} + +template +static Error parseJson(const std::string& contents, Action action) +{ + Allocator allocator; + AstNameTable names(allocator); + Lexer lexer(contents.data(), contents.size(), names); + next(lexer); + + std::vector keys; + bool arrayTop = false; // we don't support nested arrays + + if (lexer.current().type != '{') + return fail(lexer, "'{'"); + next(lexer); + + for (;;) + { + if (arrayTop) + { + if (lexer.current().type == ']') + { + next(lexer); + arrayTop = false; + + LUAU_ASSERT(!keys.empty()); + keys.pop_back(); + + if (lexer.current().type == ',') + next(lexer); + else if (lexer.current().type != '}') + return fail(lexer, "',' or '}'"); + } + else if (lexer.current().type == Lexeme::QuotedString) + { + std::string value(lexer.current().data, lexer.current().length); + next(lexer); + + if (Error err = action(keys, value)) + return err; + + if (lexer.current().type == ',') + next(lexer); + else if (lexer.current().type != ']') + return fail(lexer, "',' or ']'"); + } + else + return fail(lexer, "array element or ']'"); + } + else + { + if (lexer.current().type == '}') + { + next(lexer); + + if (keys.empty()) + { + if (lexer.current().type != Lexeme::Eof) + return fail(lexer, "end of file"); + + return {}; + } + else + keys.pop_back(); + + if (lexer.current().type == ',') + next(lexer); + else if (lexer.current().type != '}') + return fail(lexer, "',' or '}'"); + } + else if (lexer.current().type == Lexeme::QuotedString) + { + std::string key(lexer.current().data, lexer.current().length); + next(lexer); + + keys.push_back(key); + + if (lexer.current().type != ':') + return fail(lexer, "':'"); + next(lexer); + + if (lexer.current().type == '{' || lexer.current().type == '[') + { + arrayTop = (lexer.current().type == '['); + next(lexer); + } + else if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::ReservedTrue || + lexer.current().type == Lexeme::ReservedFalse) + { + std::string value = lexer.current().type == Lexeme::QuotedString + ? std::string(lexer.current().data, lexer.current().length) + : (lexer.current().type == Lexeme::ReservedTrue ? "true" : "false"); + next(lexer); + + if (Error err = action(keys, value)) + return err; + + keys.pop_back(); + + if (lexer.current().type == ',') + next(lexer); + else if (lexer.current().type != '}') + return fail(lexer, "',' or '}'"); + } + else + return fail(lexer, "field value"); + } + else + return fail(lexer, "field key"); + } + } + + return {}; +} + +Error parseConfig(const std::string& contents, Config& config, bool compat) +{ + return parseJson(contents, [&](const std::vector& keys, const std::string& value) -> Error { + if (keys.size() == 1 && keys[0] == "languageMode") + return parseModeString(config.mode, value, compat); + else if (keys.size() == 2 && keys[0] == "lint") + return parseLintRuleString(config.enabledLint, config.fatalLint, keys[1], value, compat); + else if (keys.size() == 1 && keys[0] == "lintErrors") + return parseBoolean(config.lintErrors, value); + else if (keys.size() == 1 && keys[0] == "typeErrors") + return parseBoolean(config.typeErrors, value); + else if (keys.size() == 1 && keys[0] == "globals") + { + config.globals.push_back(value); + return std::nullopt; + } + else if (compat && keys.size() == 2 && keys[0] == "language" && keys[1] == "mode") + return parseModeString(config.mode, value, compat); + else + { + std::vector keysv(keys.begin(), keys.end()); + return "Unknown key " + join(keysv, "/"); + } + }); +} + +const Config& NullConfigResolver::getConfig(const ModuleName& name) const +{ + return defaultConfig; +} + +} // namespace Luau diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp new file mode 100644 index 0000000..61a63f0 --- /dev/null +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -0,0 +1,238 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" + +LUAU_FASTFLAG(LuauParseGenericFunctions) +LUAU_FASTFLAG(LuauGenericFunctions) + +namespace Luau +{ + +static const std::string kBuiltinDefinitionLuaSrc = R"BUILTIN_SRC( + +declare bit32: { + -- band, bor, bxor, and btest are declared in C++ + rrotate: (number, number) -> number, + lrotate: (number, number) -> number, + lshift: (number, number) -> number, + arshift: (number, number) -> number, + rshift: (number, number) -> number, + bnot: (number) -> number, + extract: (number, number, number?) -> number, + replace: (number, number, number, number?) -> number, +} + +declare math: { + frexp: (number) -> (number, number), + ldexp: (number, number) -> number, + fmod: (number, number) -> number, + modf: (number) -> (number, number), + pow: (number, number) -> number, + exp: (number) -> number, + + ceil: (number) -> number, + floor: (number) -> number, + abs: (number) -> number, + sqrt: (number) -> number, + + log: (number, number?) -> number, + log10: (number) -> number, + + rad: (number) -> number, + deg: (number) -> number, + + sin: (number) -> number, + cos: (number) -> number, + tan: (number) -> number, + sinh: (number) -> number, + cosh: (number) -> number, + tanh: (number) -> number, + atan: (number) -> number, + acos: (number) -> number, + asin: (number) -> number, + atan2: (number, number) -> number, + + -- min and max are declared in C++. + + pi: number, + huge: number, + + randomseed: (number) -> (), + random: (number?, number?) -> number, + + sign: (number) -> number, + clamp: (number, number, number) -> number, + noise: (number, number?, number?) -> number, + round: (number) -> number, +} + +type DateTypeArg = { + year: number, + month: number, + day: number, + hour: number?, + min: number?, + sec: number?, + isdst: boolean?, +} + +type DateTypeResult = { + year: number, + month: number, + wday: number, + yday: number, + day: number, + hour: number, + min: number, + sec: number, + isdst: boolean, +} + +declare os: { + time: (DateTypeArg?) -> number, + date: (string?, number?) -> DateTypeResult | string, + difftime: (DateTypeResult | number, DateTypeResult | number) -> number, + clock: () -> number, +} + +declare function require(target: any): any + +declare function getfenv(target: any?): { [string]: any } + +declare _G: any +declare _VERSION: string + +declare function gcinfo(): number + +)BUILTIN_SRC"; + +std::string getBuiltinDefinitionSource() +{ + std::string src = kBuiltinDefinitionLuaSrc; + + if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) + { + src += R"( + declare function print(...: T...) + + declare function type(value: T): string + declare function typeof(value: T): string + + -- `assert` has a magic function attached that will give more detailed type information + declare function assert(value: T, errorMessage: string?): T + + declare function error(message: T, level: number?) + + declare function tostring(value: T): string + declare function tonumber(value: T, radix: number?): number + + declare function rawequal(a: T1, b: T2): boolean + declare function rawget(tab: {[K]: V}, k: K): V + declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} + + declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? + + declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) + + declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) + + -- FIXME: The actual type of `xpcall` is: + -- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) + -- Since we can't represent the return value, we use (boolean, R1...). + declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) + + -- `select` has a magic function attached to provide more detailed type information + declare function select(i: string | number, ...: A...): ...any + + -- FIXME: This type is not entirely correct - `loadstring` returns a function or + -- (nil, string). + declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) + + -- a userdata object is "roughly" the same as a sealed empty table + -- except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. + -- another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT + -- setmetatable. + -- FIXME: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. + declare function newproxy(mt: boolean?): {} + + declare coroutine: { + create: ((A...) -> R...) -> thread, + resume: (thread, A...) -> (boolean, R...), + running: () -> thread, + status: (thread) -> string, + -- FIXME: This technically returns a function, but we can't represent this yet. + wrap: ((A...) -> R...) -> any, + yield: (A...) -> R..., + isyieldable: () -> boolean, + } + + declare table: { + concat: ({V}, string?, number?, number?) -> string, + insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), + maxn: ({V}) -> number, + remove: ({V}, number?) -> V?, + sort: ({V}, ((V, V) -> boolean)?) -> (), + create: (number, V?) -> {V}, + find: ({V}, V, number?) -> number?, + + unpack: ({V}, number?, number?) -> ...V, + pack: (...V) -> { n: number, [number]: V }, + + getn: ({V}) -> number, + foreach: ({[K]: V}, (K, V) -> ()) -> (), + foreachi: ({V}, (number, V) -> ()) -> (), + + move: ({V}, number, number, number, {V}?) -> (), + clear: ({[K]: V}) -> (), + + freeze: ({[K]: V}) -> {[K]: V}, + isfrozen: ({[K]: V}) -> boolean, + } + + declare debug: { + info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), + traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), + } + + declare utf8: { + char: (number, ...number) -> string, + charpattern: string, + codes: (string) -> ((string, number) -> (number, number), string, number), + -- FIXME + codepoint: (string, number?, number?) -> (number, ...number), + len: (string, number?, number?) -> (number?, number?), + offset: (string, number?, number?) -> number, + nfdnormalize: (string) -> string, + nfcnormalize: (string) -> string, + graphemes: (string, number?, number?) -> (() -> (number, number)), + } + + declare string: { + byte: (string, number?, number?) -> ...number, + char: (number, ...number) -> string, + find: (string, string, number?, boolean?) -> (number?, number?), + -- `string.format` has a magic function attached that will provide more type information for literal format strings. + format: (string, A...) -> string, + gmatch: (string, string) -> () -> (...string), + -- gsub is defined in C++ because we don't have syntax for describing a generic table. + len: (string) -> number, + lower: (string) -> string, + match: (string, string, number?) -> string?, + rep: (string, number) -> string, + reverse: (string) -> string, + sub: (string, number, number?) -> string, + upper: (string) -> string, + split: (string, string, string?) -> {string}, + pack: (string, A...) -> string, + packsize: (string) -> number, + unpack: (string, string, number?) -> R..., + } + + -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. + declare function unpack(tab: {V}, i: number?, j: number?): ...V + )"; + } + + return src; +} + +} // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp new file mode 100644 index 0000000..680bcf3 --- /dev/null +++ b/Analysis/src/Error.cpp @@ -0,0 +1,751 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Error.h" + +#include "Luau/Module.h" +#include "Luau/StringUtils.h" +#include "Luau/ToString.h" + +#include + +LUAU_FASTFLAG(LuauFasterStringifier) + +static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) +{ + std::string s = "expects " + std::to_string(expectedCount) + " "; + + if (isTypeArgs) + s += "type "; + + s += "argument"; + if (expectedCount != 1) + s += "s"; + + s += ", but "; + + if (actualCount == 0) + { + s += "none"; + } + else + { + if (actualCount < expectedCount) + s += "only "; + + s += std::to_string(actualCount); + } + + s += (actualCount == 1) ? " is" : " are"; + + s += " specified"; + + return s; +} + +namespace Luau +{ + +struct ErrorConverter +{ + std::string operator()(const Luau::TypeMismatch& tm) const + { + ToStringOptions opts; + return "Type '" + Luau::toString(tm.givenType, opts) + "' could not be converted into '" + Luau::toString(tm.wantedType, opts) + "'"; + } + + std::string operator()(const Luau::UnknownSymbol& e) const + { + switch (e.context) + { + case UnknownSymbol::Binding: + return "Unknown global '" + e.name + "'"; + case UnknownSymbol::Type: + return "Unknown type '" + e.name + "'"; + case UnknownSymbol::Generic: + return "Unknown generic '" + e.name + "'"; + } + + LUAU_ASSERT(!"Unexpected context for UnknownSymbol"); + return ""; + } + + std::string operator()(const Luau::UnknownProperty& e) const + { + TypeId t = follow(e.table); + if (get(t)) + return "Key '" + e.key + "' not found in table '" + Luau::toString(t) + "'"; + else if (get(t)) + return "Key '" + e.key + "' not found in class '" + Luau::toString(t) + "'"; + else + return "Type '" + Luau::toString(e.table) + "' does not have key '" + e.key + "'"; + } + + std::string operator()(const Luau::NotATable& e) const + { + return "Expected type table, got '" + Luau::toString(e.ty) + "' instead"; + } + + std::string operator()(const Luau::CannotExtendTable& e) const + { + switch (e.context) + { + case Luau::CannotExtendTable::Property: + return "Cannot add property '" + e.prop + "' to table '" + Luau::toString(e.tableType) + "'"; + case Luau::CannotExtendTable::Metatable: + return "Cannot add metatable to table '" + Luau::toString(e.tableType) + "'"; + case Luau::CannotExtendTable::Indexer: + return "Cannot add indexer to table '" + Luau::toString(e.tableType) + "'"; + } + + LUAU_ASSERT(!"Unknown context"); + return ""; + } + + std::string operator()(const Luau::OnlyTablesCanHaveMethods& e) const + { + return "Cannot add method to non-table type '" + Luau::toString(e.tableType) + "'"; + } + + std::string operator()(const Luau::DuplicateTypeDefinition& e) const + { + return "Redefinition of type '" + e.name + "', previously defined at line " + std::to_string(e.previousLocation.begin.line + 1); + } + + std::string operator()(const Luau::CountMismatch& e) const + { + switch (e.context) + { + case CountMismatch::Return: + { + const std::string expectedS = e.expected == 1 ? "" : "s"; + const std::string actualS = e.actual == 1 ? "is" : "are"; + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + actualS + + " returned here"; + } + case CountMismatch::Result: + if (e.expected > e.actual) + return "Function returns " + std::to_string(e.expected) + " values but there are only " + std::to_string(e.expected) + + " values to unpack them into."; + else + return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; + case CountMismatch::Arg: + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + } + + LUAU_ASSERT(!"Unknown context"); + return ""; + } + + std::string operator()(const Luau::FunctionDoesNotTakeSelf&) const + { + return std::string("This function does not take self. Did you mean to use a dot instead of a colon?"); + } + + std::string operator()(const Luau::FunctionRequiresSelf& e) const + { + if (e.requiredExtraNils) + { + const char* plural = e.requiredExtraNils == 1 ? "" : "s"; + return format("This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a dot or " + "pass %i extra nil%s to suppress this warning", + e.requiredExtraNils, plural); + } + else + return "This function must be called with self. Did you mean to use a colon instead of a dot?"; + } + + std::string operator()(const Luau::OccursCheckFailed&) const + { + return "Type contains a self-recursive construct that cannot be resolved"; + } + + std::string operator()(const Luau::UnknownRequire& e) const + { + return "Unknown require: " + e.modulePath; + } + + std::string operator()(const Luau::IncorrectGenericParameterCount& e) const + { + std::string name = e.name; + if (!e.typeFun.typeParams.empty()) + { + name += "<"; + bool first = true; + for (TypeId t : e.typeFun.typeParams) + { + if (first) + first = false; + else + name += ", "; + + name += toString(t); + } + name += ">"; + } + + return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + } + + std::string operator()(const Luau::SyntaxError& e) const + { + return "Syntax error: " + e.message; + } + + std::string operator()(const Luau::CodeTooComplex&) const + { + return "Code is too complex to typecheck! Consider simplifying the code around this area"; + } + + std::string operator()(const Luau::UnificationTooComplex&) const + { + return "Internal error: Code is too complex to typecheck! Consider adding type annotations around this area"; + } + + std::string operator()(const Luau::UnknownPropButFoundLikeProp& e) const + { + std::string candidatesSuggestion = "Did you mean "; + if (e.candidates.size() != 1) + candidatesSuggestion += "one of "; + + bool first = true; + for (Name name : e.candidates) + { + if (first) + first = false; + else + candidatesSuggestion += ", "; + + candidatesSuggestion += "'" + name + "'"; + } + + std::string s = "Key '" + e.key + "' not found in "; + + TypeId t = follow(e.table); + if (get(t)) + s += "class"; + else + s += "table"; + + s += " '" + toString(e.table) + "'. " + candidatesSuggestion + "?"; + return s; + } + + std::string operator()(const Luau::GenericError& e) const + { + return e.message; + } + + std::string operator()(const Luau::CannotCallNonFunction& e) const + { + return "Cannot call non-function " + toString(e.ty); + } + std::string operator()(const Luau::ExtraInformation& e) const + { + return e.message; + } + + std::string operator()(const Luau::DeprecatedApiUsed& e) const + { + return "The property ." + e.symbol + " is deprecated. Use ." + e.useInstead + " instead."; + } + + std::string operator()(const Luau::ModuleHasCyclicDependency& e) const + { + if (e.cycle.empty()) + return "Cyclic module dependency detected"; + + std::string s = "Cyclic module dependency: "; + + bool first = true; + for (const ModuleName& name : e.cycle) + { + if (first) + first = false; + else + s += " -> "; + + s += name; + } + + return s; + } + + std::string operator()(const Luau::FunctionExitsWithoutReturning& e) const + { + return "Not all codepaths in this function return '" + toString(e.expectedReturnType) + "'."; + } + + std::string operator()(const Luau::IllegalRequire& e) const + { + return "Cannot require module " + e.moduleName + ": " + e.reason; + } + + std::string operator()(const Luau::MissingProperties& e) const + { + std::string s = "Table type '" + toString(e.subType) + "' not compatible with type '" + toString(e.superType) + "' because the former"; + + switch (e.context) + { + case MissingProperties::Missing: + s += " is missing field"; + break; + case MissingProperties::Extra: + s += " has extra field"; + break; + } + + if (e.properties.size() > 1) + s += "s"; + + s += " "; + + for (size_t i = 0; i < e.properties.size(); ++i) + { + if (i > 0) + s += ", "; + + if (i > 0 && i == e.properties.size() - 1) + s += "and "; + + s += "'" + e.properties[i] + "'"; + } + + return s; + } + + std::string operator()(const Luau::DuplicateGenericParameter& e) const + { + return "Duplicate type parameter '" + e.parameterName + "'"; + } + + std::string operator()(const Luau::CannotInferBinaryOperation& e) const + { + std::string ss = "Unknown type used in " + toString(e.op); + + switch (e.kind) + { + case Luau::CannotInferBinaryOperation::Comparison: + ss += " comparison"; + break; + case Luau::CannotInferBinaryOperation::Operation: + ss += " operation"; + } + + if (e.suggestedToAnnotate) + ss += "; consider adding a type annotation to '" + *e.suggestedToAnnotate + "'"; + + return ss; + } + + std::string operator()(const Luau::SwappedGenericTypeParameter& e) const + { + switch (e.kind) + { + case Luau::SwappedGenericTypeParameter::Type: + return "Variadic type parameter '" + e.name + "...' is used as a regular generic type; consider changing '" + e.name + "...' to '" + + e.name + "' in the generic argument list"; + case Luau::SwappedGenericTypeParameter::Pack: + return "Generic type '" + e.name + "' is used as a variadic type parameter; consider changing '" + e.name + "' to '" + e.name + + "...' in the generic argument list"; + default: + LUAU_ASSERT(!"Unknown kind"); + return ""; + } + } + + std::string operator()(const Luau::OptionalValueAccess& e) const + { + return "Value of type '" + toString(e.optional) + "' could be nil"; + } + + std::string operator()(const Luau::MissingUnionProperty& e) const + { + std::string ss = "Key '" + e.key + "' is missing from "; + + bool first = true; + for (auto ty : e.missing) + { + if (first) + first = false; + else + ss += ", "; + + ss += "'" + toString(ty) + "'"; + } + + return ss + " in the type '" + toString(e.type) + "'"; + } +}; + +struct InvalidNameChecker +{ + std::string invalidName = "%error-id%"; + + bool operator()(const Luau::UnknownProperty& e) const + { + return e.key == invalidName; + } + bool operator()(const Luau::CannotExtendTable& e) const + { + return e.prop == invalidName; + } + bool operator()(const Luau::DuplicateTypeDefinition& e) const + { + return e.name == invalidName; + } + + template + bool operator()(const T& other) const + { + return false; + } +}; + +bool TypeMismatch::operator==(const TypeMismatch& rhs) const +{ + return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType; +} + +bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const +{ + return name == rhs.name; +} + +bool UnknownProperty::operator==(const UnknownProperty& rhs) const +{ + return *table == *rhs.table && key == rhs.key; +} + +bool NotATable::operator==(const NotATable& rhs) const +{ + return ty == rhs.ty; +} + +bool CannotExtendTable::operator==(const CannotExtendTable& rhs) const +{ + return *tableType == *rhs.tableType && prop == rhs.prop && context == rhs.context; +} + +bool OnlyTablesCanHaveMethods::operator==(const OnlyTablesCanHaveMethods& rhs) const +{ + return *tableType == *rhs.tableType; +} + +bool DuplicateTypeDefinition::operator==(const DuplicateTypeDefinition& rhs) const +{ + return name == rhs.name && previousLocation == rhs.previousLocation; +} + +bool CountMismatch::operator==(const CountMismatch& rhs) const +{ + return expected == rhs.expected && actual == rhs.actual && context == rhs.context; +} + +bool FunctionDoesNotTakeSelf::operator==(const FunctionDoesNotTakeSelf&) const +{ + return true; +} + +bool FunctionRequiresSelf::operator==(const FunctionRequiresSelf& e) const +{ + return requiredExtraNils == e.requiredExtraNils; +} + +bool OccursCheckFailed::operator==(const OccursCheckFailed&) const +{ + return true; +} + +bool UnknownRequire::operator==(const UnknownRequire& rhs) const +{ + return modulePath == rhs.modulePath; +} + +bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterCount& rhs) const +{ + if (name != rhs.name) + return false; + + if (typeFun.type != rhs.typeFun.type) + return false; + + if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) + return false; + + for (size_t i = 0; i < typeFun.typeParams.size(); ++i) + if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) + return false; + + return true; +} + +bool SyntaxError::operator==(const SyntaxError& rhs) const +{ + return message == rhs.message; +} + +bool CodeTooComplex::operator==(const CodeTooComplex&) const +{ + return true; +} + +bool UnificationTooComplex::operator==(const UnificationTooComplex&) const +{ + return true; +} + +bool UnknownPropButFoundLikeProp::operator==(const UnknownPropButFoundLikeProp& rhs) const +{ + return *table == *rhs.table && key == rhs.key && candidates.size() == rhs.candidates.size() && + std::equal(candidates.begin(), candidates.end(), rhs.candidates.begin()); +} + +bool GenericError::operator==(const GenericError& rhs) const +{ + return message == rhs.message; +} + +bool CannotCallNonFunction::operator==(const CannotCallNonFunction& rhs) const +{ + return ty == rhs.ty; +} + +bool ExtraInformation::operator==(const ExtraInformation& rhs) const +{ + return message == rhs.message; +} + +bool DeprecatedApiUsed::operator==(const DeprecatedApiUsed& rhs) const +{ + return symbol == rhs.symbol && useInstead == rhs.useInstead; +} + +bool FunctionExitsWithoutReturning::operator==(const FunctionExitsWithoutReturning& rhs) const +{ + return expectedReturnType == rhs.expectedReturnType; +} + +int TypeError::code() const +{ + return 1000 + int(data.index()); +} + +bool TypeError::operator==(const TypeError& rhs) const +{ + return location == rhs.location && data == rhs.data; +} + +bool ModuleHasCyclicDependency::operator==(const ModuleHasCyclicDependency& rhs) const +{ + return cycle.size() == rhs.cycle.size() && std::equal(cycle.begin(), cycle.end(), rhs.cycle.begin()); +} + +bool IllegalRequire::operator==(const IllegalRequire& rhs) const +{ + return moduleName == rhs.moduleName && reason == rhs.reason; +} + +bool MissingProperties::operator==(const MissingProperties& rhs) const +{ + return *superType == *rhs.superType && *subType == *rhs.subType && properties.size() == rhs.properties.size() && + std::equal(properties.begin(), properties.end(), rhs.properties.begin()) && context == rhs.context; +} + +bool DuplicateGenericParameter::operator==(const DuplicateGenericParameter& rhs) const +{ + return parameterName == rhs.parameterName; +} + +bool CannotInferBinaryOperation::operator==(const CannotInferBinaryOperation& rhs) const +{ + return op == rhs.op && suggestedToAnnotate == rhs.suggestedToAnnotate && kind == rhs.kind; +} + +bool SwappedGenericTypeParameter::operator==(const SwappedGenericTypeParameter& rhs) const +{ + return name == rhs.name && kind == rhs.kind; +} + +bool OptionalValueAccess::operator==(const OptionalValueAccess& rhs) const +{ + return *optional == *rhs.optional; +} + +bool MissingUnionProperty::operator==(const MissingUnionProperty& rhs) const +{ + if (missing.size() != rhs.missing.size()) + return false; + + for (size_t i = 0; i < missing.size(); ++i) + { + if (*missing[i] != *rhs.missing[i]) + return false; + } + + return *type == *rhs.type && key == rhs.key; +} + +std::string toString(const TypeError& error) +{ + ErrorConverter converter; + return Luau::visit(converter, error.data); +} + +bool containsParseErrorName(const TypeError& error) +{ + return Luau::visit(InvalidNameChecker{}, error.data); +} + +void copyErrors(ErrorVec& errors, struct TypeArena& destArena) +{ + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + auto clone = [&](auto&& ty) { + return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); + }; + + auto visitErrorData = [&](auto&& e) { + using T = std::decay_t; + + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + { + e.wantedType = clone(e.wantedType); + e.givenType = clone(e.givenType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.typeFun = clone(e.typeFun); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.expectedReturnType = clone(e.expectedReturnType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.superType = clone(e.superType); + e.subType = clone(e.subType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.optional = clone(e.optional); + } + else if constexpr (std::is_same_v) + { + e.type = clone(e.type); + + for (auto& ty : e.missing) + ty = clone(ty); + } + else + static_assert(always_false_v, "Non-exhaustive type switch"); + }; + + LUAU_ASSERT(!destArena.typeVars.isFrozen()); + LUAU_ASSERT(!destArena.typePacks.isFrozen()); + + for (TypeError& error : errors) + visit(visitErrorData, error.data); +} + +void InternalErrorReporter::ice(const std::string& message, const Location& location) +{ + std::runtime_error error("Internal error in " + moduleName + " at " + toString(location) + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; +} + +void InternalErrorReporter::ice(const std::string& message) +{ + std::runtime_error error("Internal error in " + moduleName + ": " + message); + + if (onInternalError) + onInternalError(error.what()); + + throw error; +} + +} // namespace Luau diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp new file mode 100644 index 0000000..4d385ec --- /dev/null +++ b/Analysis/src/Frontend.cpp @@ -0,0 +1,967 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Frontend.h" + +#include "Luau/Config.h" +#include "Luau/FileResolver.h" +#include "Luau/StringUtils.h" +#include "Luau/TypeInfer.h" +#include "Luau/Variant.h" +#include "Luau/Common.h" + +#include +#include +#include + +LUAU_FASTFLAG(LuauInferInNoCheckMode) +LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) +LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) +LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) +LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) +LUAU_FASTFLAG(LuauTraceRequireLookupChild) +LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) + +namespace Luau +{ + +std::optional parseMode(const std::vector& hotcomments) +{ + for (const std::string& hc : hotcomments) + { + if (hc == "nocheck") + return Mode::NoCheck; + + if (hc == "nonstrict") + return Mode::Nonstrict; + + if (hc == "strict") + return Mode::Strict; + } + + return std::nullopt; +} + +static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) +{ + // TODO: What do we do in this situation? This means that the definition + // file is exporting a type that is also a persistent type. + if (ty->persistent) + { + return; + } + + asMutable(ty)->documentationSymbol = rootName; + + if (TableTypeVar* ttv = getMutable(ty)) + { + for (auto& [name, prop] : ttv->props) + { + prop.documentationSymbol = rootName + "." + name; + } + } + else if (ClassTypeVar* ctv = getMutable(ty)) + { + for (auto& [name, prop] : ctv->props) + { + prop.documentationSymbol = rootName + "." + name; + } + } +} + +LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) +{ + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + ParseOptions options; + options.allowDeclarationSyntax = true; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); + + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, nullptr}; + + Luau::SourceModule module; + module.root = parseResult.root; + module.mode = Mode::Definition; + + ModulePtr checkedModule = typeChecker.check(module, Mode::Definition); + + if (checkedModule->errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, checkedModule}; + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + if (FFlag::LuauPersistDefinitionFileTypes) + persist(globalTy); + } + + for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + if (FFlag::LuauPersistDefinitionFileTypes) + persist(globalTy.type); + } + + return LoadDefinitionFileResult{true, parseResult, checkedModule}; +} + +std::vector parsePathExpr(const AstExpr& pathExpr) +{ + const AstExprIndexName* indexName = pathExpr.as(); + if (!indexName) + return {}; + + std::vector segments{indexName->index.value}; + + while (true) + { + if (AstExprIndexName* in = indexName->expr->as()) + { + segments.push_back(in->index.value); + indexName = in; + continue; + } + else if (AstExprGlobal* indexNameAsGlobal = indexName->expr->as()) + { + segments.push_back(indexNameAsGlobal->name.value); + break; + } + else if (AstExprLocal* indexNameAsLocal = indexName->expr->as()) + { + segments.push_back(indexNameAsLocal->local->name.value); + break; + } + else + return {}; + } + + std::reverse(segments.begin(), segments.end()); + return segments; +} + +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const std::vector& segments) +{ + if (segments.empty()) + return std::nullopt; + + std::vector result; + + auto it = segments.begin(); + + if (*it == "script" && !currentModuleName.empty()) + { + result = split(currentModuleName, '/'); + ++it; + } + + for (; it != segments.end(); ++it) + { + if (result.size() > 1 && *it == "Parent") + result.pop_back(); + else + result.push_back(*it); + } + + return join(result, "/"); +} + +std::optional pathExprToModuleName(const ModuleName& currentModuleName, const AstExpr& pathExpr) +{ + std::vector segments = parsePathExpr(pathExpr); + return pathExprToModuleName(currentModuleName, segments); +} + +namespace +{ + +ErrorVec accumulateErrors( + const std::unordered_map& sourceNodes, const std::unordered_map& modules, const ModuleName& name) +{ + std::unordered_set seen; + std::vector queue{name}; + + ErrorVec result; + + while (!queue.empty()) + { + ModuleName next = std::move(queue.back()); + queue.pop_back(); + + if (seen.count(next)) + continue; + seen.insert(next); + + auto it = sourceNodes.find(next); + if (it == sourceNodes.end()) + continue; + + const SourceNode& sourceNode = it->second; + queue.insert(queue.end(), sourceNode.requires.begin(), sourceNode.requires.end()); + + // FIXME: If a module has a syntax error, we won't be able to re-report it here. + // The solution is probably to move errors from Module to SourceNode + + auto it2 = modules.find(next); + if (it2 == modules.end()) + continue; + + Module& module = *it2->second; + + std::sort(module.errors.begin(), module.errors.end(), [](const TypeError& e1, const TypeError& e2) -> bool { + return e1.location.begin > e2.location.begin; + }); + + result.insert(result.end(), module.errors.begin(), module.errors.end()); + } + + std::reverse(result.begin(), result.end()); + + return result; +} + +struct RequireCycle +{ + Location location; + std::vector path; // one of the paths for a require() to go all the way back to the originating module +}; + +// Given a source node (start), find all requires that start a transitive dependency path that ends back at start +// For each such path, record the full path and the location of the require in the starting module. +// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) +// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) +std::vector getRequireCycles( + const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) +{ + std::vector result; + + DenseHashSet seen(nullptr); + std::vector stack; + std::vector path; + + for (const auto& [depName, depLocation] : start->requireLocations) + { + std::vector cycle; + + auto dit = sourceNodes.find(depName); + if (dit == sourceNodes.end()) + continue; + + stack.push_back(&dit->second); + + while (!stack.empty()) + { + const SourceNode* top = stack.back(); + stack.pop_back(); + + if (top == nullptr) + { + // special marker for post-order processing + LUAU_ASSERT(!path.empty()); + top = path.back(); + path.pop_back(); + + // we reached the node! path must form a cycle now + if (top == start) + { + for (const SourceNode* node : path) + cycle.push_back(node->name); + + cycle.push_back(top->name); + break; + } + } + else if (!seen.contains(top)) + { + seen.insert(top); + + // push marker for post-order processing + path.push_back(top); + stack.push_back(nullptr); + + // note: we push require edges in the opposite order + // because it's a stack, the last edge to be pushed gets processed first + // this ensures that the cyclic path we report is the first one in DFS order + for (size_t i = top->requireLocations.size(); i > 0; --i) + { + const ModuleName& reqName = top->requireLocations[i - 1].first; + + auto rit = sourceNodes.find(reqName); + if (rit != sourceNodes.end()) + stack.push_back(&rit->second); + } + } + } + + path.clear(); + stack.clear(); + + if (!cycle.empty()) + { + result.push_back({depLocation, std::move(cycle)}); + + if (stopAtFirst) + return result; + + // note: if we didn't find a cycle, all nodes that we've seen don't depend [transitively] on start + // so it's safe to *only* clear seen vector when we find a cycle + // if we don't do it, we will not have correct reporting for some cycles + seen.clear(); + } + } + + return result; +} + +double getTimestamp() +{ + using namespace std::chrono; + return double(duration_cast(high_resolution_clock::now().time_since_epoch()).count()) / 1e9; +} + +} // namespace + +Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options) + : fileResolver(fileResolver) + , moduleResolver(this) + , moduleResolverForAutocomplete(this) + , typeChecker(&moduleResolver, &iceHandler) + , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, &iceHandler) + , configResolver(configResolver) + , options(options) +{ +} + +FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) + : frontend(frontend) +{ +} + +CheckResult Frontend::check(const ModuleName& name) +{ + CheckResult checkResult; + + auto it = sourceNodes.find(name); + if (it != sourceNodes.end() && !it->second.dirty) + { + // No recheck required. + auto it2 = moduleResolver.modules.find(name); + if (it2 == moduleResolver.modules.end() || it2->second == nullptr) + throw std::runtime_error("Frontend::modules does not have data for " + name); + + return CheckResult{accumulateErrors(sourceNodes, moduleResolver.modules, name)}; + } + + std::vector buildQueue; + bool cycleDetected = parseGraph(buildQueue, checkResult, name); + + // Keep track of which AST nodes we've reported cycles in + std::unordered_set reportedCycles; + + for (const ModuleName& moduleName : buildQueue) + { + LUAU_ASSERT(sourceNodes.count(moduleName)); + SourceNode& sourceNode = sourceNodes[moduleName]; + + if (!sourceNode.dirty) + continue; + + LUAU_ASSERT(sourceModules.count(moduleName)); + SourceModule& sourceModule = sourceModules[moduleName]; + + const Config& config = configResolver->getConfig(moduleName); + + Mode mode = sourceModule.mode.value_or(config.mode); + + ScopePtr environmentScope = getModuleEnvironment(sourceModule, config); + + double timestamp = getTimestamp(); + + std::vector requireCycles; + + // in NoCheck mode we only need to compute the value of .cyclic for typeck + // in the future we could replace toposort with an algorithm that can flag cyclic nodes by itself + // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term + // all correct programs must be acyclic so this code triggers rarely + if (cycleDetected) + requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck); + + // This is used by the type checker to replace the resulting type of cyclic modules with any + sourceModule.cyclic = !requireCycles.empty(); + + ModulePtr module = typeChecker.check(sourceModule, mode, environmentScope); + + // If we're typechecking twice, we do so. + // The second typecheck is always in strict mode with DM awareness + // to provide better typen information for IDE features. + if (options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel) + { + ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); + moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; + } + else if (options.retainFullTypeGraphs && options.typecheckTwice && mode != Mode::Strict) + { + ModulePtr strictModule = typeChecker.check(sourceModule, Mode::Strict, environmentScope); + module->astTypes.clear(); + module->astOriginalCallTypes.clear(); + module->astExpectedTypes.clear(); + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + for (const auto& [expr, strictTy] : strictModule->astTypes) + module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + + for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) + module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + + for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) + module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + } + + stats.timeCheck += getTimestamp() - timestamp; + stats.filesStrict += mode == Mode::Strict; + stats.filesNonstrict += mode == Mode::Nonstrict; + + if (module == nullptr) + throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); + + if (!options.retainFullTypeGraphs) + { + // copyErrors needs to allocate into interfaceTypes as it copies + // types out of internalTypes, so we unfreeze it here. + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + + module->internalTypes.clear(); + module->astTypes.clear(); + module->astExpectedTypes.clear(); + module->astOriginalCallTypes.clear(); + } + + if (mode != Mode::NoCheck) + { + for (const RequireCycle& cyc : requireCycles) + { + TypeError te{cyc.location, moduleName, ModuleHasCyclicDependency{cyc.path}}; + + module->errors.push_back(te); + } + } + + ErrorVec parseErrors; + + for (const ParseError& pe : sourceModule.parseErrors) + parseErrors.push_back(TypeError{pe.getLocation(), moduleName, SyntaxError{pe.what()}}); + + module->errors.insert(module->errors.begin(), parseErrors.begin(), parseErrors.end()); + + checkResult.errors.insert(checkResult.errors.end(), module->errors.begin(), module->errors.end()); + + moduleResolver.modules[moduleName] = std::move(module); + sourceNode.dirty = false; + } + + return checkResult; +} + +bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root) +{ + // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search + enum Mark + { + None, + Temporary, + Permanent + }; + + DenseHashMap seen(nullptr); + std::vector stack; + std::vector path; + bool cyclic = false; + + { + auto [sourceNode, _] = getSourceNode(checkResult, root); + if (sourceNode) + stack.push_back(sourceNode); + } + + while (!stack.empty()) + { + SourceNode* top = stack.back(); + stack.pop_back(); + + if (top == nullptr) + { + // special marker for post-order processing + LUAU_ASSERT(!path.empty()); + + top = path.back(); + path.pop_back(); + + // note: topseen ref gets invalidated in any seen[] access, beware - only one seen[] access per iteration! + Mark& topseen = seen[top]; + LUAU_ASSERT(topseen == Temporary); + topseen = Permanent; + + buildQueue.push_back(top->name); + } + else + { + // note: topseen ref gets invalidated in any seen[] access, beware - only one seen[] access per iteration! + Mark& topseen = seen[top]; + + if (topseen != None) + { + cyclic |= topseen == Temporary; + continue; + } + + topseen = Temporary; + + // push marker for post-order processing + stack.push_back(nullptr); + path.push_back(top); + + // push children + for (const ModuleName& dep : top->requires) + { + auto it = sourceNodes.find(dep); + if (it != sourceNodes.end()) + { + // this is a critical optimization: we do *not* traverse non-dirty subtrees. + // this relies on the fact that markDirty marks reverse-dependencies dirty as well + // thus if a node is not dirty, all its transitive deps aren't dirty, which means that they won't ever need + // to be built, *and* can't form a cycle with any nodes we did process. + if (!it->second.dirty) + continue; + + // note: this check is technically redundant *except* that getSourceNode has somewhat broken memoization + // calling getSourceNode twice in succession will reparse the file, since getSourceNode leaves dirty flag set + if (seen.contains(&it->second)) + { + stack.push_back(&it->second); + continue; + } + } + + auto [sourceNode, _] = getSourceNode(checkResult, dep); + if (sourceNode) + { + stack.push_back(sourceNode); + + // note: this assignment is paired with .contains() check above and effectively deduplicates getSourceNode() + seen[sourceNode] = None; + } + } + } + } + + return cyclic; +} + +ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config) +{ + ScopePtr result = typeChecker.globalScope; + + if (module.environmentName) + result = getEnvironmentScope(*module.environmentName); + + if (!config.globals.empty()) + { + result = std::make_shared(result); + + for (const std::string& global : config.globals) + { + AstName name = module.names->get(global.c_str()); + + if (name.value) + result->bindings[name].typeId = typeChecker.anyType; + } + } + + return result; +} + +LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) +{ + CheckResult checkResult; + auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name); + + if (!sourceModule) + return LintResult{}; // FIXME: We really should do something a bit more obvious when a file is too broken to lint. + + return lint(*sourceModule, enabledLintWarnings); +} + +std::pair Frontend::lintFragment(std::string_view source, std::optional enabledLintWarnings) +{ + const Config& config = configResolver->getConfig(""); + + SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); + + Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint); + lintOptions.warningMask &= sourceModule.ignoreLints; + + double timestamp = getTimestamp(); + + std::vector warnings = + Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, enabledLintWarnings.value_or(config.enabledLint)); + + stats.timeLint += getTimestamp() - timestamp; + + return {std::move(sourceModule), classifyLints(warnings, config)}; +} + +CheckResult Frontend::check(const SourceModule& module) +{ + const Config& config = configResolver->getConfig(module.name); + + Mode mode = module.mode.value_or(config.mode); + + double timestamp = getTimestamp(); + + ModulePtr checkedModule = typeChecker.check(module, mode); + + stats.timeCheck += getTimestamp() - timestamp; + stats.filesStrict += mode == Mode::Strict; + stats.filesNonstrict += mode == Mode::Nonstrict; + + if (checkedModule == nullptr) + throw std::runtime_error("Frontend::check produced a nullptr module for module " + module.name); + moduleResolver.modules[module.name] = checkedModule; + + return CheckResult{checkedModule->errors}; +} + +LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) +{ + const Config& config = configResolver->getConfig(module.name); + + LintOptions options = enabledLintWarnings.value_or(config.enabledLint); + options.warningMask &= ~module.ignoreLints; + + Mode mode = module.mode.value_or(config.mode); + if (mode != Mode::NoCheck) + { + options.disableWarning(Luau::LintWarning::Code_UnknownGlobal); + } + + if (mode == Mode::Strict) + { + options.disableWarning(Luau::LintWarning::Code_ImplicitReturn); + } + + ScopePtr environmentScope = getModuleEnvironment(module, config); + + ModulePtr modulePtr = moduleResolver.getModule(module.name); + + double timestamp = getTimestamp(); + + std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), options); + + stats.timeLint += getTimestamp() - timestamp; + + return classifyLints(warnings, config); +} + +bool Frontend::isDirty(const ModuleName& name) const +{ + auto it = sourceNodes.find(name); + return it == sourceNodes.end() || it->second.dirty; +} + +/* + * Mark a file as requiring rechecking before its type information can be safely used again. + * + * I am not particularly pleased with the way each dirty() operation involves a BFS on reverse dependencies. + * It would be nice for this function to be O(1) + */ +void Frontend::markDirty(const ModuleName& name, std::vector* markedDirty) +{ + if (!moduleResolver.modules.count(name)) + return; + + std::unordered_map> reverseDeps; + for (const auto& module : sourceNodes) + { + for (const auto& dep : module.second.requires) + reverseDeps[dep].push_back(module.first); + } + + std::vector queue{name}; + + while (!queue.empty()) + { + ModuleName next = std::move(queue.back()); + queue.pop_back(); + + LUAU_ASSERT(sourceNodes.count(next) > 0); + SourceNode& sourceNode = sourceNodes[next]; + + if (markedDirty) + markedDirty->push_back(next); + + if (sourceNode.dirty) + continue; + + sourceNode.dirty = true; + + if (0 == reverseDeps.count(name)) + continue; + + sourceModules.erase(name); + + const std::vector& dependents = reverseDeps[name]; + queue.insert(queue.end(), dependents.begin(), dependents.end()); + } +} + +SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) +{ + auto it = sourceModules.find(moduleName); + if (it != sourceModules.end()) + return &it->second; + else + return nullptr; +} + +const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) const +{ + return const_cast(this)->getSourceModule(moduleName); +} + +// Read AST into sourceModules if necessary. Trace require()s. Report parse errors. +std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) +{ + auto it = sourceNodes.find(name); + if (it != sourceNodes.end() && !it->second.dirty) + { + auto moduleIt = sourceModules.find(name); + if (moduleIt != sourceModules.end()) + return {&it->second, &moduleIt->second}; + else + { + LUAU_ASSERT(!"Everything in sourceNodes should also be in sourceModules"); + return {&it->second, nullptr}; + } + } + + double timestamp = getTimestamp(); + + std::optional source = fileResolver->readSource(name); + std::optional environmentName = fileResolver->getEnvironmentForModule(name); + + stats.timeRead += getTimestamp() - timestamp; + + if (!source) + { + sourceModules.erase(name); + return {nullptr, nullptr}; + } + + const Config& config = configResolver->getConfig(name); + ParseOptions opts = config.parseOptions; + opts.captureComments = true; + SourceModule result = parse(name, source->source, opts); + result.type = source->type; + + RequireTraceResult& requireTrace = requires[name]; + requireTrace = traceRequires(fileResolver, result.root, name); + + SourceNode& sourceNode = sourceNodes[name]; + SourceModule& sourceModule = sourceModules[name]; + + sourceModule = std::move(result); + sourceModule.environmentName = environmentName; + + sourceNode.name = name; + sourceNode.requires.clear(); + sourceNode.requireLocations.clear(); + sourceNode.dirty = true; + + for (const auto& [moduleName, location] : requireTrace.requires) + sourceNode.requires.insert(moduleName); + + sourceNode.requireLocations = requireTrace.requires; + + return {&sourceNode, &sourceModule}; +} + +/** Try to parse a source file into a SourceModule. + * + * The logic here is a little bit more complicated than we'd like it to be. + * + * If a file does not exist, we return none to prevent the Frontend from creating knowledge that this module exists. + * If the Frontend thinks that the file exists, it will not produce an "Unknown require" error. + * + * If the file has syntax errors, we report them and synthesize an empty AST if it's not available. + * This suppresses the Unknown require error and allows us to make a best effort to typecheck code that require()s + * something that has broken syntax. + * We also translate Luau::ParseError into a Luau::TypeError so that we can use a vector to describe the + * result of the check() + */ +SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions) +{ + SourceModule sourceModule; + + double timestamp = getTimestamp(); + + auto parseResult = Luau::Parser::parse(src.data(), src.size(), *sourceModule.names, *sourceModule.allocator, parseOptions); + + stats.timeParse += getTimestamp() - timestamp; + stats.files++; + stats.lines += std::count(src.begin(), src.end(), '\n') + (src.size() && src.back() != '\n'); + + if (!parseResult.errors.empty()) + sourceModule.parseErrors.insert(sourceModule.parseErrors.end(), parseResult.errors.begin(), parseResult.errors.end()); + + if (parseResult.errors.empty() || parseResult.root) + { + sourceModule.root = parseResult.root; + sourceModule.mode = parseMode(parseResult.hotcomments); + sourceModule.ignoreLints = LintWarning::parseMask(parseResult.hotcomments); + } + else + { + sourceModule.root = sourceModule.allocator->alloc(Location{}, AstArray{nullptr, 0}); + sourceModule.mode = Mode::NoCheck; + } + + sourceModule.name = name; + if (parseOptions.captureComments) + sourceModule.commentLocations = std::move(parseResult.commentLocations); + return sourceModule; +} + +std::optional FrontendModuleResolver::resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) +{ + // FIXME I think this can be pushed into the FileResolver. + auto it = frontend->requires.find(currentModuleName); + if (it == frontend->requires.end()) + { + // CLI-43699 + // If we can't find the current module name, that's because we bypassed the frontend's initializer + // and called typeChecker.check directly. (This is done by autocompleteSource, for example). + // In that case, requires will always fail. + if (FFlag::LuauResolveModuleNameWithoutACurrentModule) + return std::nullopt; + else + throw std::runtime_error("Frontend::resolveModuleName: Unknown currentModuleName '" + currentModuleName + "'"); + } + + const auto& exprs = it->second.exprs; + + const ModuleName* relativeName = exprs.find(&pathExpr); + if (!relativeName || relativeName->empty()) + return std::nullopt; + + if (FFlag::LuauTraceRequireLookupChild) + { + const bool* optional = it->second.optional.find(&pathExpr); + + return {{*relativeName, optional ? *optional : false}}; + } + else + { + return {{*relativeName, false}}; + } +} + +const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const +{ + auto it = modules.find(moduleName); + if (it != modules.end()) + return it->second; + else + return nullptr; +} + +bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const +{ + return frontend->fileResolver->moduleExists(moduleName); +} + +std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const +{ + return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName); +} + +ScopePtr Frontend::addEnvironment(const std::string& environmentName) +{ + LUAU_ASSERT(environments.count(environmentName) == 0); + + if (environments.count(environmentName) == 0) + { + ScopePtr scope = std::make_shared(typeChecker.globalScope); + environments[environmentName] = scope; + return scope; + } + else + return environments[environmentName]; +} + +ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) +{ + LUAU_ASSERT(environments.count(environmentName) > 0); + + return environments[environmentName]; +} + +void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) +{ + LUAU_ASSERT(builtinDefinitions.count(name) == 0); + + if (builtinDefinitions.count(name) == 0) + builtinDefinitions[name] = applicator; +} + +void Frontend::applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName) +{ + LUAU_ASSERT(builtinDefinitions.count(definitionName) > 0); + + if (builtinDefinitions.count(definitionName) > 0) + builtinDefinitions[definitionName](typeChecker, getEnvironmentScope(environmentName)); +} + +LintResult Frontend::classifyLints(const std::vector& warnings, const Config& config) +{ + LintResult result; + for (const auto& w : warnings) + { + if (config.lintErrors || config.fatalLint.isEnabled(w.code)) + result.errors.push_back(w); + else + result.warnings.push_back(w); + } + + return result; +} + +void Frontend::clearStats() +{ + stats = {}; +} + +void Frontend::clear() +{ + sourceNodes.clear(); + sourceModules.clear(); + moduleResolver.modules.clear(); + moduleResolverForAutocomplete.modules.clear(); + requires.clear(); +} + +} // namespace Luau diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp new file mode 100644 index 0000000..84e9b77 --- /dev/null +++ b/Analysis/src/IostreamHelpers.cpp @@ -0,0 +1,280 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/IostreamHelpers.h" +#include "Luau/ToString.h" + +namespace Luau +{ + +std::ostream& operator<<(std::ostream& stream, const Position& position) +{ + return stream << "{ line = " << position.line << ", col = " << position.column << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const Location& location) +{ + return stream << "Location { " << location.begin << ", " << location.end << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const AstName& name) +{ + if (name.value) + return stream << name.value; + else + return stream << ""; +} + +std::ostream& operator<<(std::ostream& stream, const TypeMismatch& tm) +{ + return stream << "TypeMismatch { " << toString(tm.wantedType) << ", " << toString(tm.givenType) << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const TypeError& error) +{ + return stream << "TypeError { \"" << error.moduleName << "\", " << error.location << ", " << error.data << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const UnknownSymbol& error) +{ + return stream << "UnknownSymbol { " << error.name << " , context " << error.context << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const UnknownProperty& error) +{ + return stream << "UnknownProperty { " << toString(error.table) << ", key = " << error.key << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const NotATable& ge) +{ + return stream << "NotATable { " << toString(ge.ty) << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const CannotExtendTable& error) +{ + return stream << "CannotExtendTable { " << toString(error.tableType) << ", context " << error.context << ", prop \"" << error.prop << "\" }"; +} + +std::ostream& operator<<(std::ostream& stream, const OnlyTablesCanHaveMethods& error) +{ + return stream << "OnlyTablesCanHaveMethods { " << toString(error.tableType) << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const DuplicateTypeDefinition& error) +{ + return stream << "DuplicateTypeDefinition { " << error.name << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const CountMismatch& error) +{ + return stream << "CountMismatch { expected " << error.expected << ", got " << error.actual << ", context " << error.context << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const FunctionDoesNotTakeSelf&) +{ + return stream << "FunctionDoesNotTakeSelf { }"; +} + +std::ostream& operator<<(std::ostream& stream, const FunctionRequiresSelf& error) +{ + return stream << "FunctionRequiresSelf { extraNils " << error.requiredExtraNils << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const OccursCheckFailed&) +{ + return stream << "OccursCheckFailed { }"; +} + +std::ostream& operator<<(std::ostream& stream, const UnknownRequire& error) +{ + return stream << "UnknownRequire { " << error.modulePath << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCount& error) +{ + stream << "IncorrectGenericParameterCount { name = " << error.name; + + if (!error.typeFun.typeParams.empty()) + { + stream << "<"; + bool first = true; + for (TypeId t : error.typeFun.typeParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(t); + } + stream << ">"; + } + + stream << ", typeFun = " << toString(error.typeFun.type) << ", actualCount = " << error.actualParameters << " }"; + return stream; +} + +std::ostream& operator<<(std::ostream& stream, const SyntaxError& ge) +{ + return stream << "SyntaxError { " << ge.message << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const CodeTooComplex&) +{ + return stream << "CodeTooComplex {}"; +} + +std::ostream& operator<<(std::ostream& stream, const UnificationTooComplex&) +{ + return stream << "UnificationTooComplex {}"; +} + +std::ostream& operator<<(std::ostream& stream, const UnknownPropButFoundLikeProp& e) +{ + stream << "UnknownPropButFoundLikeProp { key = '" << e.key << "', suggested = { "; + + bool first = true; + for (Name name : e.candidates) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + return stream << " }, table = " << toString(e.table) << " } "; +} + +std::ostream& operator<<(std::ostream& stream, const GenericError& ge) +{ + return stream << "GenericError { " << ge.message << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const CannotCallNonFunction& e) +{ + return stream << "CannotCallNonFunction { " << toString(e.ty) << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const FunctionExitsWithoutReturning& error) +{ + return stream << "FunctionExitsWithoutReturning {" << toString(error.expectedReturnType) << "}"; +} + +std::ostream& operator<<(std::ostream& stream, const ExtraInformation& e) +{ + return stream << "ExtraInformation { " << e.message << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const DeprecatedApiUsed& e) +{ + return stream << "DeprecatedApiUsed { " << e.symbol << ", useInstead = " << e.useInstead << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const ModuleHasCyclicDependency& e) +{ + stream << "ModuleHasCyclicDependency {"; + + bool first = true; + for (const ModuleName& name : e.cycle) + { + if (first) + first = false; + else + stream << ", "; + + stream << name; + } + + return stream << "}"; +} + +std::ostream& operator<<(std::ostream& stream, const IllegalRequire& e) +{ + return stream << "IllegalRequire { " << e.moduleName << ", reason = " << e.reason << " }"; +} + +std::ostream& operator<<(std::ostream& stream, const MissingProperties& e) +{ + stream << "MissingProperties { superType = '" << toString(e.superType) << "', subType = '" << toString(e.subType) << "', properties = { "; + + bool first = true; + for (Name name : e.properties) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << name << "'"; + } + + return stream << " }, context " << e.context << " } "; +} + +std::ostream& operator<<(std::ostream& stream, const DuplicateGenericParameter& error) +{ + return stream << "DuplicateGenericParameter { " + error.parameterName + " }"; +} + +std::ostream& operator<<(std::ostream& stream, const CannotInferBinaryOperation& error) +{ + return stream << "CannotInferBinaryOperation { op = " + toString(error.op) + ", suggested = '" + + (error.suggestedToAnnotate ? *error.suggestedToAnnotate : "") + "', kind " + << error.kind << "}"; +} + +std::ostream& operator<<(std::ostream& stream, const SwappedGenericTypeParameter& error) +{ + return stream << "SwappedGenericTypeParameter { name = '" + error.name + "', kind = " + std::to_string(error.kind) + " }"; +} + +std::ostream& operator<<(std::ostream& stream, const OptionalValueAccess& error) +{ + return stream << "OptionalValueAccess { optional = '" + toString(error.optional) + "' }"; +} + +std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error) +{ + stream << "MissingUnionProperty { type = '" + toString(error.type) + "', missing = { "; + + bool first = true; + for (auto ty : error.missing) + { + if (first) + first = false; + else + stream << ", "; + + stream << "'" << toString(ty) << "'"; + } + + return stream << " }, key = '" + error.key + "' }"; +} + +std::ostream& operator<<(std::ostream& stream, const TableState& tv) +{ + return stream << static_cast::type>(tv); +} + +std::ostream& operator<<(std::ostream& stream, const TypeVar& tv) +{ + return stream << toString(tv); +} + +std::ostream& operator<<(std::ostream& stream, const TypePackVar& tv) +{ + return stream << toString(tv); +} + +std::ostream& operator<<(std::ostream& lhs, const TypeErrorData& ted) +{ + Luau::visit( + [&](const auto& a) { + lhs << a; + }, + ted); + + return lhs; +} + +} // namespace Luau diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp new file mode 100644 index 0000000..a101829 --- /dev/null +++ b/Analysis/src/JsonEncoder.cpp @@ -0,0 +1,1041 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/JsonEncoder.h" + +#include "Luau/Ast.h" +#include "Luau/StringUtils.h" + +namespace Luau +{ + +struct AstJsonEncoder : public AstVisitor +{ + static constexpr int CHUNK_SIZE = 1024; + std::vector chunks; + bool comma = false; + + AstJsonEncoder() + { + newChunk(); + } + + std::string str() + { + return join(chunks, ""); + } + + bool pushComma() + { + bool c = comma; + comma = false; + return c; + } + + void popComma(bool c) + { + comma = c; + } + + void newChunk() + { + chunks.emplace_back(); + chunks.back().reserve(CHUNK_SIZE); + } + + void appendChunk(std::string_view sv) + { + if (sv.size() > CHUNK_SIZE) + { + chunks.emplace_back(sv); + newChunk(); + return; + } + + auto& chunk = chunks.back(); + if (chunk.size() + sv.size() < CHUNK_SIZE) + { + chunk.append(sv.data(), sv.size()); + return; + } + + size_t prefix = CHUNK_SIZE - chunk.size(); + chunk.append(sv.data(), prefix); + newChunk(); + + chunks.back().append(sv.data() + prefix, sv.size() - prefix); + } + + void writeRaw(std::string_view sv) + { + appendChunk(sv); + } + + void writeRaw(char c) + { + writeRaw(std::string_view{&c, 1}); + } + + template + void write(std::string_view propName, const T& value) + { + if (comma) + writeRaw(","); + comma = true; + writeRaw("\""); + writeRaw(propName); + writeRaw("\":"); + write(value); + } + + void write(bool b) + { + if (b) + writeRaw("true"); + else + writeRaw("false"); + } + + void write(double d) + { + char b[256]; + sprintf(b, "%g", d); + writeRaw(b); + } + + void writeString(std::string_view sv) + { + // TODO escape more accurately? + writeRaw("\""); + + for (char c : sv) + { + if (c == '"') + writeRaw("\\\""); + else if (c == '\0') + writeRaw("\\\0"); + else + writeRaw(c); + } + + writeRaw("\""); + } + + void write(char c) + { + writeString(std::string_view(&c, 1)); + } + void write(int i) + { + writeRaw(std::to_string(i)); + } + void write(long i) + { + writeRaw(std::to_string(i)); + } + void write(long long i) + { + writeRaw(std::to_string(i)); + } + void write(unsigned int i) + { + writeRaw(std::to_string(i)); + } + void write(unsigned long i) + { + writeRaw(std::to_string(i)); + } + void write(unsigned long long i) + { + writeRaw(std::to_string(i)); + } + void write(std::string_view str) + { + writeString(str); + } + void write(AstName name) + { + writeString(name.value ? name.value : ""); + } + + void write(const Position& position) + { + write(position.line); + writeRaw(","); + write(position.column); + } + + void write(const Location& location) + { + writeRaw("\""); + write(location.begin); + writeRaw(" - "); + write(location.end); + writeRaw("\""); + } + + void write(AstLocal* local) + { + write(local->name); + } + + void writeNode(AstNode* node) + { + write("location", node->location); + } + + template + void writeNode(AstNode* node, std::string_view name, F&& f) + { + writeRaw("{"); + bool c = pushComma(); + write("type", name); + writeNode(node); + f(); + popComma(c); + writeRaw("}"); + } + + void write(AstNode* node) + { + node->visit(this); + } + + void write(class AstExprGroup* node) + { + writeNode(node, "AstExprGroup", [&]() { + write("expr", node->expr); + }); + } + + void write(class AstExprConstantNil* node) + { + writeNode(node, "AstExprConstantNil", []() {}); + } + + void write(class AstExprConstantBool* node) + { + writeNode(node, "AstExprConstantBool", [&]() { + write("value", node->value); + }); + } + + void write(class AstExprConstantNumber* node) + { + writeNode(node, "AstExprConstantNumber", [&]() { + write("value", node->value); + }); + } + + void write(class AstExprConstantString* node) + { + writeNode(node, "AstExprConstantString", [&]() { + write("value", node->value); + }); + } + + void write(class AstExprLocal* node) + { + writeNode(node, "AstExprLocal", [&]() { + write("local", node->local); + }); + } + + void write(class AstExprGlobal* node) + { + writeNode(node, "AstExprGlobal", [&]() { + write("global", node->name); + }); + } + + void write(class AstExprVarargs* node) + { + writeNode(node, "AstExprVarargs", []() {}); + } + + template + void write(AstArray arr) + { + writeRaw("["); + bool comma = false; + for (const auto& a : arr) + { + if (comma) + writeRaw(","); + else + comma = false; + + write(a); + } + writeRaw("]"); + } + + void write(AstArray arr) + { + write(std::string_view{arr.data, arr.size}); + } + +#define PROP(prop) write(#prop, node->prop) + + void write(class AstExprCall* node) + { + writeNode(node, "AstExprCall", [&]() { + PROP(func); + PROP(args); + PROP(self); + PROP(argLocation); + }); + } + + void write(class AstExprIndexName* node) + { + writeNode(node, "AstExprIndexName", [&]() { + PROP(expr); + PROP(index); + PROP(indexLocation); + PROP(op); + }); + } + + void write(class AstExprIndexExpr* node) + { + writeNode(node, "AstExprIndexExpr", [&]() { + PROP(expr); + PROP(index); + }); + } + + void write(class AstExprFunction* node) + { + writeNode(node, "AstExprFunction", [&]() { + PROP(generics); + PROP(genericPacks); + if (node->self) + PROP(self); + PROP(args); + if (node->hasReturnAnnotation) + PROP(returnAnnotation); + PROP(vararg); + PROP(varargLocation); + if (node->varargAnnotation) + PROP(varargAnnotation); + + PROP(body); + PROP(functionDepth); + PROP(debugname); + PROP(hasEnd); + }); + } + + void write(const AstTypeList& typeList) + { + writeRaw("{"); + bool c = pushComma(); + write("types", typeList.types); + if (typeList.tailType) + write("tailType", typeList.tailType); + popComma(c); + writeRaw("}"); + } + + void write(AstExprTable::Item::Kind kind) + { + switch (kind) + { + case AstExprTable::Item::List: + return writeString("item"); + case AstExprTable::Item::Record: + return writeString("record"); + case AstExprTable::Item::General: + return writeString("general"); + } + } + + void write(const AstExprTable::Item& item) + { + writeRaw("{"); + bool comma = pushComma(); + write("kind", item.kind); + switch (item.kind) + { + case AstExprTable::Item::List: + write(item.value); + break; + default: + write(item.key); + writeRaw(","); + write(item.value); + break; + } + popComma(comma); + writeRaw("}"); + } + + void write(class AstExprTable* node) + { + writeNode(node, "AstExprTable", [&]() { + bool comma = false; + for (const auto& prop : node->items) + { + if (comma) + writeRaw(","); + else + comma = false; + write(prop); + } + }); + } + + void write(AstExprUnary::Op op) + { + switch (op) + { + case AstExprUnary::Not: + return writeString("not"); + case AstExprUnary::Minus: + return writeString("minus"); + case AstExprUnary::Len: + return writeString("len"); + } + } + + void write(class AstExprUnary* node) + { + writeNode(node, "AstExprUnary", [&]() { + PROP(op); + PROP(expr); + }); + } + + void write(AstExprBinary::Op op) + { + switch (op) + { + case AstExprBinary::Add: + return writeString("Add"); + case AstExprBinary::Sub: + return writeString("Sub"); + case AstExprBinary::Mul: + return writeString("Mul"); + case AstExprBinary::Div: + return writeString("Div"); + case AstExprBinary::Mod: + return writeString("Mod"); + case AstExprBinary::Pow: + return writeString("Pow"); + case AstExprBinary::Concat: + return writeString("Concat"); + case AstExprBinary::CompareNe: + return writeString("CompareNe"); + case AstExprBinary::CompareEq: + return writeString("CompareEq"); + case AstExprBinary::CompareLt: + return writeString("CompareLt"); + case AstExprBinary::CompareLe: + return writeString("CompareLe"); + case AstExprBinary::CompareGt: + return writeString("CompareGt"); + case AstExprBinary::CompareGe: + return writeString("CompareGe"); + case AstExprBinary::And: + return writeString("And"); + case AstExprBinary::Or: + return writeString("Or"); + } + } + + void write(class AstExprBinary* node) + { + writeNode(node, "AstExprBinary", [&]() { + PROP(op); + PROP(left); + PROP(right); + }); + } + + void write(class AstExprTypeAssertion* node) + { + writeNode(node, "AstExprTypeAssertion", [&]() { + PROP(expr); + PROP(annotation); + }); + } + + void write(class AstExprError* node) + { + writeNode(node, "AstExprError", [&]() { + PROP(expressions); + PROP(messageIndex); + }); + } + + void write(class AstStatBlock* node) + { + writeNode(node, "AstStatBlock", [&]() { + writeRaw(",\"body\":["); + bool comma = false; + for (AstStat* stat : node->body) + { + if (comma) + writeRaw(","); + else + comma = true; + + write(stat); + } + writeRaw("]"); + }); + } + + void write(class AstStatIf* node) + { + writeNode(node, "AstStatIf", [&]() { + PROP(condition); + PROP(thenbody); + if (node->elsebody) + PROP(elsebody); + PROP(hasThen); + PROP(hasEnd); + }); + } + + void write(class AstStatWhile* node) + { + writeNode(node, "AtStatWhile", [&]() { + PROP(condition); + PROP(body); + PROP(hasDo); + PROP(hasEnd); + }); + } + + void write(class AstStatRepeat* node) + { + writeNode(node, "AstStatRepeat", [&]() { + PROP(condition); + PROP(body); + PROP(hasUntil); + }); + } + + void write(class AstStatBreak* node) + { + writeNode(node, "AstStatBreak", []() {}); + } + + void write(class AstStatContinue* node) + { + writeNode(node, "AstStatContinue", []() {}); + } + + void write(class AstStatReturn* node) + { + writeNode(node, "AstStatReturn", [&]() { + PROP(list); + }); + } + + void write(class AstStatExpr* node) + { + writeNode(node, "AstStatExpr", [&]() { + PROP(expr); + }); + } + + void write(class AstStatLocal* node) + { + writeNode(node, "AstStatLocal", [&]() { + PROP(vars); + PROP(values); + }); + } + + void write(class AstStatFor* node) + { + writeNode(node, "AstStatFor", [&]() { + PROP(var); + PROP(from); + PROP(to); + if (node->step) + PROP(step); + PROP(body); + PROP(hasDo); + PROP(hasEnd); + }); + } + + void write(class AstStatForIn* node) + { + writeNode(node, "AstStatForIn", [&]() { + PROP(vars); + PROP(values); + PROP(body); + PROP(hasIn); + PROP(hasDo); + PROP(hasEnd); + }); + } + + void write(class AstStatAssign* node) + { + writeNode(node, "AstStatAssign", [&]() { + PROP(vars); + PROP(values); + }); + } + + void write(class AstStatCompoundAssign* node) + { + writeNode(node, "AstStatCompoundAssign", [&]() { + PROP(op); + PROP(var); + PROP(value); + }); + } + + void write(class AstStatFunction* node) + { + writeNode(node, "AstStatFunction", [&]() { + PROP(name); + PROP(func); + }); + } + + void write(class AstStatLocalFunction* node) + { + writeNode(node, "AstStatLocalFunction", [&]() { + PROP(name); + PROP(func); + }); + } + + void write(class AstStatTypeAlias* node) + { + writeNode(node, "AstStatTypeAlias", [&]() { + PROP(name); + PROP(generics); + PROP(type); + PROP(exported); + }); + } + + void write(class AstStatDeclareFunction* node) + { + writeNode(node, "AstStatDeclareFunction", [&]() { + PROP(name); + PROP(params); + PROP(retTypes); + PROP(generics); + PROP(genericPacks); + }); + } + + void write(class AstStatDeclareGlobal* node) + { + writeNode(node, "AstStatDeclareGlobal", [&]() { + PROP(name); + PROP(type); + }); + } + + void write(const AstDeclaredClassProp& prop) + { + writeRaw("{"); + bool c = pushComma(); + write("name", prop.name); + write("type", prop.ty); + popComma(c); + writeRaw("}"); + } + + void write(class AstStatDeclareClass* node) + { + writeNode(node, "AstStatDeclareClass", [&]() { + PROP(name); + if (node->superName) + write("superName", *node->superName); + PROP(props); + }); + } + + void write(class AstStatError* node) + { + writeNode(node, "AstStatError", [&]() { + PROP(expressions); + PROP(statements); + }); + } + + void write(class AstTypeReference* node) + { + writeNode(node, "AstTypeReference", [&]() { + if (node->hasPrefix) + PROP(prefix); + PROP(name); + PROP(generics); + }); + } + + void write(const AstTableProp& prop) + { + writeRaw("{"); + bool c = pushComma(); + + write("name", prop.name); + write("location", prop.location); + write("type", prop.type); + + popComma(c); + writeRaw("}"); + } + + void write(class AstTypeTable* node) + { + writeNode(node, "AstTypeTable", [&]() { + PROP(props); + PROP(indexer); + }); + } + + void write(class AstTypeFunction* node) + { + writeNode(node, "AstTypeFunction", [&]() { + PROP(generics); + PROP(genericPacks); + PROP(argTypes); + PROP(returnTypes); + }); + } + + void write(class AstTypeTypeof* node) + { + writeNode(node, "AstTypeTypeof", [&]() { + PROP(expr); + }); + } + + void write(class AstTypeUnion* node) + { + writeNode(node, "AstTypeUnion", [&]() { + PROP(types); + }); + } + + void write(class AstTypeIntersection* node) + { + writeNode(node, "AstTypeIntersection", [&]() { + PROP(types); + }); + } + + void write(class AstTypeError* node) + { + writeNode(node, "AstTypeError", [&]() { + PROP(types); + PROP(messageIndex); + }); + } + + void write(class AstTypePackVariadic* node) + { + writeNode(node, "AstTypePackVariadic", [&]() { + PROP(variadicType); + }); + } + + void write(class AstTypePackGeneric* node) + { + writeNode(node, "AstTypePackGeneric", [&]() { + PROP(genericName); + }); + } + + bool visit(class AstExprGroup* node) override + { + write(node); + return false; + } + + bool visit(class AstExprConstantNil* node) override + { + write(node); + return false; + } + + bool visit(class AstExprConstantBool* node) override + { + write(node); + return false; + } + + bool visit(class AstExprConstantNumber* node) override + { + write(node); + return false; + } + + bool visit(class AstExprConstantString* node) override + { + write(node); + return false; + } + + bool visit(class AstExprLocal* node) override + { + write(node); + return false; + } + + bool visit(class AstExprGlobal* node) override + { + write(node); + return false; + } + + bool visit(class AstExprVarargs* node) override + { + write(node); + return false; + } + + bool visit(class AstExprCall* node) override + { + write(node); + return false; + } + + bool visit(class AstExprIndexName* node) override + { + write(node); + return false; + } + + bool visit(class AstExprIndexExpr* node) override + { + write(node); + return false; + } + + bool visit(class AstExprFunction* node) override + { + write(node); + return false; + } + + bool visit(class AstExprTable* node) override + { + write(node); + return false; + } + + bool visit(class AstExprUnary* node) override + { + write(node); + return false; + } + + bool visit(class AstExprBinary* node) override + { + write(node); + return false; + } + + bool visit(class AstExprTypeAssertion* node) override + { + write(node); + return false; + } + + bool visit(class AstExprError* node) override + { + write(node); + return false; + } + + bool visit(class AstStatBlock* node) override + { + write(node); + return false; + } + + bool visit(class AstStatIf* node) override + { + write(node); + return false; + } + + bool visit(class AstStatWhile* node) override + { + write(node); + return false; + } + + bool visit(class AstStatRepeat* node) override + { + write(node); + return false; + } + + bool visit(class AstStatBreak* node) override + { + write(node); + return false; + } + + bool visit(class AstStatContinue* node) override + { + write(node); + return false; + } + + bool visit(class AstStatReturn* node) override + { + write(node); + return false; + } + + bool visit(class AstStatExpr* node) override + { + write(node); + return false; + } + + bool visit(class AstStatLocal* node) override + { + write(node); + return false; + } + + bool visit(class AstStatFor* node) override + { + write(node); + return false; + } + + bool visit(class AstStatForIn* node) override + { + write(node); + return false; + } + + bool visit(class AstStatAssign* node) override + { + write(node); + return false; + } + + bool visit(class AstStatCompoundAssign* node) override + { + write(node); + return false; + } + + bool visit(class AstStatFunction* node) override + { + write(node); + return false; + } + + bool visit(class AstStatLocalFunction* node) override + { + write(node); + return false; + } + + bool visit(class AstStatTypeAlias* node) override + { + write(node); + return false; + } + + bool visit(class AstStatDeclareFunction* node) override + { + write(node); + return false; + } + + bool visit(class AstStatDeclareGlobal* node) override + { + write(node); + return false; + } + + bool visit(class AstStatDeclareClass* node) override + { + write(node); + return false; + } + + bool visit(class AstStatError* node) override + { + write(node); + return false; + } + + bool visit(class AstTypeReference* node) override + { + write(node); + return false; + } + + bool visit(class AstTypeTable* node) override + { + write(node); + return false; + } + + bool visit(class AstTypeFunction* node) override + { + write(node); + return false; + } + + bool visit(class AstTypeTypeof* node) override + { + write(node); + return false; + } + + bool visit(class AstTypeUnion* node) override + { + write(node); + return false; + } + + bool visit(class AstTypeIntersection* node) override + { + write(node); + return false; + } + + bool visit(class AstTypeError* node) override + { + write(node); + return false; + } + + bool visit(class AstTypePack* node) override + { + write(node); + return false; + } + + bool visit(class AstTypePackVariadic* node) override + { + write(node); + return false; + } + + bool visit(class AstTypePackGeneric* node) override + { + write(node); + return false; + } +}; + +std::string toJson(AstNode* node) +{ + AstJsonEncoder encoder; + node->visit(&encoder); + return encoder.str(); +} + +} // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp new file mode 100644 index 0000000..f97f6a4 --- /dev/null +++ b/Analysis/src/Linter.cpp @@ -0,0 +1,2568 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Linter.h" + +#include "Luau/AstQuery.h" +#include "Luau/Module.h" +#include "Luau/TypeInfer.h" +#include "Luau/StringUtils.h" +#include "Luau/Common.h" + +#include +#include +#include + +LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) + +namespace Luau +{ + +// clang-format off +static const char* kWarningNames[] = { + "Unknown", + + "UnknownGlobal", + "DeprecatedGlobal", + "GlobalUsedAsLocal", + "LocalShadow", + "SameLineStatement", + "MultiLineStatement", + "LocalUnused", + "FunctionUnused", + "ImportUnused", + "BuiltinGlobalWrite", + "PlaceholderRead", + "UnreachableCode", + "UnknownType", + "ForRange", + "UnbalancedAssignment", + "ImplicitReturn", + "DuplicateLocal", + "FormatString", + "TableLiteral", + "UninitializedLocal", + "DuplicateFunction", + "DeprecatedApi", + "TableOperations", + "DuplicateCondition", +}; +// clang-format on + +static_assert(std::size(kWarningNames) == unsigned(LintWarning::Code__Count), "did you forget to add warning to the list?"); + +struct LintContext +{ + struct Global + { + TypeId type = nullptr; + std::optional deprecated; + }; + + std::vector result; + LintOptions options; + + AstStat* root; + + AstName placeholder; + DenseHashMap builtinGlobals; + ScopePtr scope; + const Module* module; + + LintContext() + : root(nullptr) + , builtinGlobals(AstName()) + , module(nullptr) + { + } + + bool warningEnabled(LintWarning::Code code) + { + return (options.warningMask & (1ull << code)) != 0; + } + + std::optional getType(AstExpr* expr) + { + if (!module) + return std::nullopt; + + auto it = module->astTypes.find(expr); + if (it == module->astTypes.end()) + return std::nullopt; + + return it->second; + } +}; + +struct WarningComparator +{ + int compare(const Position& lhs, const Position& rhs) const + { + if (lhs.line != rhs.line) + return lhs.line < rhs.line ? -1 : 1; + if (lhs.column != rhs.column) + return lhs.column < rhs.column ? -1 : 1; + return 0; + } + + int compare(const Location& lhs, const Location& rhs) const + { + if (int c = compare(lhs.begin, rhs.begin)) + return c; + if (int c = compare(lhs.end, rhs.end)) + return c; + return 0; + } + + bool operator()(const LintWarning& lhs, const LintWarning& rhs) const + { + if (int c = compare(lhs.location, rhs.location)) + return c < 0; + + return lhs.code < rhs.code; + } +}; + +LUAU_PRINTF_ATTR(4, 5) +static void emitWarning(LintContext& context, LintWarning::Code code, const Location& location, const char* format, ...) +{ + if (!context.warningEnabled(code)) + return; + + va_list args; + va_start(args, format); + std::string message = vformat(format, args); + va_end(args); + + LintWarning warning = {code, location, message}; + context.result.push_back(warning); +} + +static bool similar(AstExpr* lhs, AstExpr* rhs) +{ + if (lhs->classIndex != rhs->classIndex) + return false; + +#define CASE(T) else if (T* le = lhs->as(), *re = rhs->as(); le && re) + + if (false) + return false; + CASE(AstExprGroup) return similar(le->expr, re->expr); + CASE(AstExprConstantNil) return true; + CASE(AstExprConstantBool) return le->value == re->value; + CASE(AstExprConstantNumber) return le->value == re->value; + CASE(AstExprConstantString) return le->value.size == re->value.size && memcmp(le->value.data, re->value.data, le->value.size) == 0; + CASE(AstExprLocal) return le->local == re->local; + CASE(AstExprGlobal) return le->name == re->name; + CASE(AstExprVarargs) return true; + CASE(AstExprIndexName) return le->index == re->index && similar(le->expr, re->expr); + CASE(AstExprIndexExpr) return similar(le->expr, re->expr) && similar(le->index, re->index); + CASE(AstExprFunction) return false; // rarely meaningful in context of this pass, avoids having to process statement nodes + CASE(AstExprUnary) return le->op == re->op && similar(le->expr, re->expr); + CASE(AstExprBinary) return le->op == re->op && similar(le->left, re->left) && similar(le->right, re->right); + CASE(AstExprTypeAssertion) return le->expr == re->expr; // the type doesn't affect execution semantics, avoids having to process type nodes + CASE(AstExprError) return false; + CASE(AstExprCall) + { + if (le->args.size != re->args.size || le->self != re->self) + return false; + + if (!similar(le->func, re->func)) + return false; + + for (size_t i = 0; i < le->args.size; ++i) + if (!similar(le->args.data[i], re->args.data[i])) + return false; + + return true; + } + CASE(AstExprTable) + { + if (le->items.size != re->items.size) + return false; + + for (size_t i = 0; i < le->items.size; ++i) + { + const AstExprTable::Item& li = le->items.data[i]; + const AstExprTable::Item& ri = re->items.data[i]; + + if (li.kind != ri.kind) + return false; + + if (bool(li.key) != bool(ri.key)) + return false; + else if (li.key && !similar(li.key, ri.key)) + return false; + + if (!similar(li.value, ri.value)) + return false; + } + + return true; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + return false; + } + +#undef CASE +} + +class LintGlobalLocal : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintGlobalLocal pass; + pass.context = &context; + + for (auto& global : context.builtinGlobals) + { + Global& g = pass.globals[global.first]; + + g.builtin = true; + g.deprecated = global.second.deprecated; + } + + context.root->visit(&pass); + + pass.report(); + } + +private: + struct Global + { + AstExprGlobal* firstRef = nullptr; + + std::vector functionRef; + + bool assigned = false; + bool builtin = false; + std::optional deprecated; + }; + + LintContext* context; + + DenseHashMap globals; + std::vector globalRefs; + std::vector functionStack; + + LintGlobalLocal() + : globals(AstName()) + { + } + + void report() + { + for (size_t i = 0; i < globalRefs.size(); ++i) + { + AstExprGlobal* gv = globalRefs[i]; + Global* g = globals.find(gv->name); + + if (!g || (!g->assigned && !g->builtin)) + emitWarning(*context, LintWarning::Code_UnknownGlobal, gv->location, "Unknown global '%s'", gv->name.value); + else if (g->deprecated) + { + if (*g->deprecated) + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", + gv->name.value, *g->deprecated); + else + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); + } + } + + for (auto& global : globals) + { + const Global& g = global.second; + + if (g.functionRef.size() && g.assigned && g.firstRef->name != context->placeholder) + { + AstExprFunction* top = g.functionRef.back(); + + if (top->debugname.value) + emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, + "Global '%s' is only used in the enclosing function '%s'; consider changing it to local", g.firstRef->name.value, + top->debugname.value); + else + emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, + "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", + g.firstRef->name.value, top->location.begin.line + 1); + } + } + } + + bool visit(AstExprFunction* node) override + { + functionStack.push_back(node); + + node->body->visit(this); + + functionStack.pop_back(); + + return false; + } + + bool visit(AstExprGlobal* node) override + { + trackGlobalRef(node); + + if (node->name == context->placeholder) + emitWarning( + *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); + + return true; + } + + bool visit(AstExprLocal* node) override + { + if (node->local->name == context->placeholder) + emitWarning( + *context, LintWarning::Code_PlaceholderRead, node->location, "Placeholder value '_' is read here; consider using a named variable"); + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + { + AstExpr* var = node->vars.data[i]; + + if (AstExprGlobal* gv = var->as()) + { + Global& g = globals[gv->name]; + + if (g.builtin) + emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, + "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); + else + g.assigned = true; + + trackGlobalRef(gv); + } + else if (var->is()) + { + // We don't visit locals here because it's a local *write*, and visit(AstExprLocal*) assumes it's a local *read* + } + else + { + var->visit(this); + } + } + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatFunction* node) override + { + if (AstExprGlobal* gv = node->name->as()) + { + Global& g = globals[gv->name]; + + if (g.builtin) + emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, + "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); + else + g.assigned = true; + + trackGlobalRef(gv); + } + + return true; + } + + void trackGlobalRef(AstExprGlobal* node) + { + Global& g = globals[node->name]; + + globalRefs.push_back(node); + + if (!g.firstRef) + { + g.firstRef = node; + + // to reduce the cost of tracking we only track this for user globals + if (!g.builtin) + { + g.functionRef = functionStack; + } + } + else + { + // to reduce the cost of tracking we only track this for user globals + if (!g.builtin) + { + // we need to find a common prefix between all uses of a global + size_t prefix = 0; + + while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix]) + prefix++; + + g.functionRef.resize(prefix); + } + } + } +}; + +class LintSameLineStatement : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintSameLineStatement pass; + + pass.context = &context; + pass.lastLine = ~0u; + + context.root->visit(&pass); + } + +private: + LintContext* context; + unsigned int lastLine; + + bool visit(AstStatBlock* node) override + { + for (size_t i = 1; i < node->body.size; ++i) + { + const Location& last = node->body.data[i - 1]->location; + const Location& location = node->body.data[i]->location; + + if (location.begin.line != last.end.line) + continue; + + // We warn once per line with multiple statements + if (location.begin.line == lastLine) + continue; + + // There's a common pattern where local variables are computed inside a do block that starts on the same line; we white-list this pattern + if (node->body.data[i - 1]->is() && node->body.data[i]->is()) + continue; + + // Another common pattern is using multiple statements on the same line with semi-colons on each of them. White-list this pattern too. + if (node->body.data[i - 1]->hasSemicolon) + continue; + + emitWarning(*context, LintWarning::Code_SameLineStatement, location, + "A new statement is on the same line; add semi-colon on previous statement to silence"); + + lastLine = location.begin.line; + } + + return true; + } +}; + +class LintMultiLineStatement : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintMultiLineStatement pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + struct Statement + { + Location start; + unsigned int lastLine; + bool flagged; + }; + + std::vector stack; + + bool visit(AstExpr* node) override + { + Statement& top = stack.back(); + + if (!top.flagged) + { + Location location = node->location; + + if (location.begin.line > top.lastLine) + { + top.lastLine = location.begin.line; + + if (location.begin.column <= top.start.begin.column) + { + emitWarning( + *context, LintWarning::Code_MultiLineStatement, location, "Statement spans multiple lines; use indentation to silence"); + + top.flagged = true; + } + } + } + + return true; + } + + bool visit(AstExprTable* node) override + { + (void)node; + + return false; + } + + bool visit(AstStatRepeat* node) override + { + node->body->visit(this); + + return false; + } + + bool visit(AstStatBlock* node) override + { + for (size_t i = 0; i < node->body.size; ++i) + { + AstStat* stmt = node->body.data[i]; + + Statement s = {stmt->location, stmt->location.begin.line, false}; + stack.push_back(s); + + stmt->visit(this); + + stack.pop_back(); + } + + return false; + } +}; + +class LintLocalHygiene : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintLocalHygiene pass; + pass.context = &context; + + for (auto& global : context.builtinGlobals) + pass.globals[global.first].builtin = true; + + context.root->visit(&pass); + + pass.report(); + } + +private: + LintContext* context; + + struct Local + { + AstNode* defined = nullptr; + bool function; + bool import; + bool used; + bool arg; + }; + + struct Global + { + bool used; + bool builtin; + AstExprGlobal* firstRef; + }; + + DenseHashMap locals; + DenseHashMap imports; + DenseHashMap globals; + + LintLocalHygiene() + : locals(NULL) + , imports(AstName()) + , globals(AstName()) + { + } + + void report() + { + for (auto& l : locals) + { + if (l.second.used) + reportUsedLocal(l.first, l.second); + else if (l.second.defined) + reportUnusedLocal(l.first, l.second); + } + } + + void reportUsedLocal(AstLocal* local, const Local& info) + { + if (AstLocal* shadow = local->shadow) + { + // LintDuplicateFunctions will catch this. + Local* shadowLocal = locals.find(shadow); + if (context->options.isEnabled(LintWarning::Code_DuplicateFunction) && info.function && shadowLocal && shadowLocal->function) + return; + + // LintDuplicateLocal will catch this. + if (context->options.isEnabled(LintWarning::Code_DuplicateLocal) && shadowLocal && shadowLocal->defined == info.defined) + return; + + // don't warn on inter-function shadowing since it is much more fragile wrt refactoring + if (shadow->functionDepth == local->functionDepth) + emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows previous declaration at line %d", + local->name.value, shadow->location.begin.line + 1); + } + else if (Global* global = globals.find(local->name)) + { + if (global->builtin) + ; // there are many builtins with common names like 'table'; some of them are deprecated as well + else if (global->firstRef) + { + emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows a global variable used at line %d", + local->name.value, global->firstRef->location.begin.line + 1); + } + else + { + emitWarning(*context, LintWarning::Code_LocalShadow, local->location, "Variable '%s' shadows a global variable", local->name.value); + } + } + } + + void reportUnusedLocal(AstLocal* local, const Local& info) + { + if (local->name.value[0] == '_') + return; + + if (info.function) + emitWarning(*context, LintWarning::Code_FunctionUnused, local->location, "Function '%s' is never used; prefix with '_' to silence", + local->name.value); + else if (info.import) + emitWarning(*context, LintWarning::Code_ImportUnused, local->location, "Import '%s' is never used; prefix with '_' to silence", + local->name.value); + else + emitWarning(*context, LintWarning::Code_LocalUnused, local->location, "Variable '%s' is never used; prefix with '_' to silence", + local->name.value); + } + + bool isRequireCall(AstExpr* expr) + { + AstExprCall* call = expr->as(); + if (!call) + return false; + + AstExprGlobal* glob = call->func->as(); + if (!glob) + return false; + + return glob->name == "require"; + } + + bool visit(AstStatAssign* node) override + { + for (AstExpr* var : node->vars) + { + // We don't visit locals here because it's a local *write*, and visit(AstExprLocal*) assumes it's a local *read* + if (!var->is()) + var->visit(this); + } + + for (AstExpr* value : node->values) + value->visit(this); + + return false; + } + + bool visit(AstStatLocal* node) override + { + if (node->vars.size == 1 && node->values.size == 1) + { + Local& l = locals[node->vars.data[0]]; + + l.defined = node; + l.import = isRequireCall(node->values.data[0]); + + if (l.import) + imports[node->vars.data[0]->name] = node->vars.data[0]; + } + else + { + for (size_t i = 0; i < node->vars.size; ++i) + { + Local& l = locals[node->vars.data[i]]; + + l.defined = node; + } + } + + return true; + } + + bool visit(AstStatLocalFunction* node) override + { + Local& l = locals[node->name]; + + l.defined = node; + l.function = true; + + return true; + } + + bool visit(AstExprLocal* node) override + { + Local& l = locals[node->local]; + + l.used = true; + + return true; + } + + bool visit(AstExprGlobal* node) override + { + Global& global = globals[node->name]; + + global.used = true; + if (!global.firstRef) + global.firstRef = node; + + return true; + } + + bool visit(AstType* node) override + { + return true; + } + + bool visit(AstTypeReference* node) override + { + if (!node->hasPrefix) + return true; + + if (!imports.contains(node->prefix)) + return true; + + AstLocal* astLocal = imports[node->prefix]; + Local& local = locals[astLocal]; + LUAU_ASSERT(local.import); + local.used = true; + + return true; + } + + bool visit(AstExprFunction* node) override + { + if (node->self) + locals[node->self].arg = true; + + for (size_t i = 0; i < node->args.size; ++i) + locals[node->args.data[i]].arg = true; + + return true; + } +}; + +class LintUnusedFunction : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintUnusedFunction pass; + pass.context = &context; + + context.root->visit(&pass); + + pass.report(); + } + +private: + LintContext* context; + + struct Global + { + Location location; + bool function; + bool used; + }; + + DenseHashMap globals; + + LintUnusedFunction() + : globals(AstName()) + { + } + + void report() + { + for (auto& g : globals) + { + if (g.second.function && !g.second.used && g.first.value[0] != '_') + emitWarning(*context, LintWarning::Code_FunctionUnused, g.second.location, "Function '%s' is never used; prefix with '_' to silence", + g.first.value); + } + } + + bool visit(AstStatFunction* node) override + { + if (AstExprGlobal* expr = node->name->as()) + { + Global& g = globals[expr->name]; + + g.function = true; + g.location = expr->location; + + node->func->visit(this); + + return false; + } + + return true; + } + + bool visit(AstExprGlobal* node) override + { + Global& g = globals[node->name]; + + g.used = true; + + return true; + } +}; + +class LintUnreachableCode : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintUnreachableCode pass; + pass.context = &context; + + pass.analyze(context.root); + context.root->visit(&pass); + } + +private: + LintContext* context; + + // Note: this enum is order-sensitive! + // The order is in the "severity" of the termination and affects merging of status codes from different branches + // For example, if one branch breaks and one returns, the merged result is "break" + enum Status + { + Unknown, + Continue, + Break, + Return, + Error, + }; + + const char* getReason(Status status) + { + switch (status) + { + case Continue: + return "continue"; + + case Break: + return "break"; + + case Return: + return "return"; + + case Error: + return "error"; + + default: + return "unknown"; + } + } + + Status analyze(AstStat* node) + { + if (AstStatBlock* stat = node->as()) + { + for (size_t i = 0; i < stat->body.size; ++i) + { + AstStat* si = stat->body.data[i]; + Status step = analyze(si); + + if (step != Unknown) + { + if (i + 1 == stat->body.size) + return step; + + AstStat* next = stat->body.data[i + 1]; + + // silence the warning for common pattern of Error (coming from error()) + Return + if (step == Error && si->is() && next->is() && i + 2 == stat->body.size) + return Error; + + emitWarning(*context, LintWarning::Code_UnreachableCode, next->location, "Unreachable code (previous statement always %ss)", + getReason(step)); + return step; + } + } + + return Unknown; + } + else if (AstStatIf* stat = node->as()) + { + Status ifs = analyze(stat->thenbody); + Status elses = stat->elsebody ? analyze(stat->elsebody) : Unknown; + + return std::min(ifs, elses); + } + else if (AstStatWhile* stat = node->as()) + { + analyze(stat->body); + + return Unknown; + } + else if (AstStatRepeat* stat = node->as()) + { + analyze(stat->body); + + return Unknown; + } + else if (node->is()) + { + return Break; + } + else if (node->is()) + { + return Continue; + } + else if (node->is()) + { + return Return; + } + else if (AstStatExpr* stat = node->as()) + { + if (AstExprCall* call = stat->expr->as()) + if (doesCallError(call)) + return Error; + + return Unknown; + } + else if (AstStatFor* stat = node->as()) + { + analyze(stat->body); + + return Unknown; + } + else if (AstStatForIn* stat = node->as()) + { + analyze(stat->body); + + return Unknown; + } + else + { + return Unknown; + } + } + + bool visit(AstExprFunction* node) override + { + analyze(node->body); + + return true; + } +}; + +class LintUnknownType : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintUnknownType pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + enum TypeKind + { + Kind_Invalid, + Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. + Kind_Vector, // For 'vector' but only used when type is used + Kind_Userdata, // custom userdata type - Vector3/etc. + Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. + Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. + }; + + bool containsPropName(TypeId ty, const std::string& propName) + { + if (auto ctv = get(ty)) + return lookupClassProp(ctv, propName) != nullptr; + + if (auto ttv = get(ty)) + return ttv->props.find(propName) != ttv->props.end(); + + return false; + } + + TypeKind getTypeKind(const std::string& name) + { + if (name == "nil" || name == "boolean" || name == "userdata" || name == "number" || name == "string" || name == "table" || + name == "function" || name == "thread") + return Kind_Primitive; + + if (name == "vector") + return Kind_Vector; + + if (std::optional maybeTy = context->scope->lookupType(name)) + // Kind_Userdata is probably not 100% precise but is close enough + return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; + else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) + return Kind_Enum; + + return Kind_Invalid; + } + + void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) + { + std::string name(expr->value.data, expr->value.size); + TypeKind kind = getTypeKind(name); + + if (kind == Kind_Invalid) + { + emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s'", name.c_str()); + return; + } + + for (TypeKind ek : expected) + { + if (kind == ek) + return; + + // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type + if (ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) + return; + } + + emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s' (expected %s)", name.c_str(), expectedString); + } + + bool acceptsClassName(AstName method) + { + return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || + method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); + } + + bool visit(AstExprCall* node) override + { + if (AstExprIndexName* index = node->func->as()) + { + AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; + + if (arg0) + { + if (node->self && index->index == "IsA" && node->args.size == 1) + { + validateType(arg0, {Kind_Class, Kind_Enum}, "class or enum type"); + } + else if (node->self && (index->index == "GetService" || index->index == "FindService") && node->args.size == 1) + { + AstExprGlobal* g = index->expr->as(); + + if (g && (g->name == "game" || g->name == "Game")) + { + validateType(arg0, {Kind_Class}, "class type"); + } + } + else if (node->self && acceptsClassName(index->index) && node->args.size == 1) + { + validateType(arg0, {Kind_Class}, "class type"); + } + else if (!node->self && index->index == "new" && node->args.size <= 2) + { + AstExprGlobal* g = index->expr->as(); + + if (g && g->name == "Instance") + { + validateType(arg0, {Kind_Class}, "class type"); + } + } + } + } + + return true; + } + + bool visit(AstExprBinary* node) override + { + if (node->op == AstExprBinary::CompareNe || node->op == AstExprBinary::CompareEq) + { + AstExpr* lhs = node->left; + AstExpr* rhs = node->right; + + if (!rhs->is()) + std::swap(lhs, rhs); + + AstExprCall* call = lhs->as(); + AstExprConstantString* arg = rhs->as(); + + if (call && arg) + { + AstExprGlobal* g = call->func->as(); + + if (g && g->name == "type") + { + if (FFlag::LuauLinterUnknownTypeVectorAware) + validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); + else + validateType(arg, {Kind_Primitive}, "primitive type"); + } + else if (g && g->name == "typeof") + { + validateType(arg, {Kind_Primitive, Kind_Userdata}, "primitive or userdata type"); + } + } + } + + return true; + } +}; + +class LintForRange : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintForRange pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + double getLoopEnd(double from, double to) + { + return from + floor(to - from); + } + + bool visit(AstStatFor* node) override + { + // note: we silence all warnings below if *any* step is specified, assuming that the user knows best + if (!node->step) + { + AstExprConstantNumber* fc = node->from->as(); + AstExprUnary* fu = node->from->as(); + AstExprConstantNumber* tc = node->to->as(); + AstExprUnary* tu = node->to->as(); + + Location rangeLocation(node->from->location, node->to->location); + + // for i=#t,1 do + if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 1.0) + emitWarning( + *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); + // for i=8,1 do + else if (fc && tc && fc->value > tc->value) + emitWarning( + *context, LintWarning::Code_ForRange, rangeLocation, "For loop should iterate backwards; did you forget to specify -1 as step?"); + // for i=1,8.75 do + else if (fc && tc && getLoopEnd(fc->value, tc->value) != tc->value) + emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop ends at %g instead of %g; did you forget to specify step?", + getLoopEnd(fc->value, tc->value), tc->value); + // for i=0,#t do + else if (fc && tu && fc->value == 0.0 && tu->op == AstExprUnary::Len) + emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, "For loop starts at 0, but arrays start at 1"); + // for i=#t,0 do + else if (fu && fu->op == AstExprUnary::Len && tc && tc->value == 0.0) + emitWarning(*context, LintWarning::Code_ForRange, rangeLocation, + "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"); + } + + return true; + } +}; + +class LintUnbalancedAssignment : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintUnbalancedAssignment pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + void assign(size_t vars, const AstArray& values, const Location& location) + { + if (vars != values.size && values.size > 0) + { + AstExpr* last = values.data[values.size - 1]; + + if (vars < values.size) + emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, + "Assigning %d values to %d variables leaves some values unused", int(values.size), int(vars)); + else if (last->is() || last->is()) + ; // we don't know how many values the last expression returns + else if (last->is()) + ; // last expression is nil which explicitly silences the nil-init warning + else + emitWarning(*context, LintWarning::Code_UnbalancedAssignment, location, + "Assigning %d values to %d variables initializes extra variables with nil; add 'nil' to value list to silence", int(values.size), + int(vars)); + } + } + + bool visit(AstStatLocal* node) override + { + assign(node->vars.size, node->values, node->location); + + return true; + } + + bool visit(AstStatAssign* node) override + { + assign(node->vars.size, node->values, node->location); + + return true; + } +}; + +class LintImplicitReturn : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintImplicitReturn pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + Location getEndLocation(const AstStat* node) + { + Location loc = node->location; + + if (node->is() || node->is() || node->is()) + return loc; + + if (loc.begin.line == loc.end.line) + return loc; + + // assume that we're in context of a statement that has an "end" block + return Location(Position(loc.end.line, std::max(0, int(loc.end.column) - 3)), loc.end); + } + + AstStatReturn* getValueReturn(AstStat* node) + { + struct Visitor : AstVisitor + { + AstStatReturn* result = nullptr; + + bool visit(AstExpr* node) override + { + (void)node; + return false; + } + + bool visit(AstStatReturn* node) override + { + if (!result && node->list.size > 0) + result = node; + + return false; + } + }; + + Visitor visitor; + node->visit(&visitor); + return visitor.result; + } + + bool visit(AstExprFunction* node) override + { + const AstStat* bodyf = getFallthrough(node->body); + AstStat* vret = getValueReturn(node->body); + + if (bodyf && vret) + { + Location location = getEndLocation(bodyf); + + if (node->debugname.value) + emitWarning(*context, LintWarning::Code_ImplicitReturn, location, + "Function '%s' can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", + node->debugname.value, vret->location.begin.line + 1); + else + emitWarning(*context, LintWarning::Code_ImplicitReturn, location, + "Function can implicitly return no values even though there's an explicit return at line %d; add explicit return to silence", + vret->location.begin.line + 1); + } + + return true; + } +}; + +class LintFormatString : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintFormatString pass; + pass.context = &context; + + context.root->visit(&pass); + } + + static void fuzz(const char* data, size_t size) + { + LintContext context; + + LintFormatString pass; + pass.context = &context; + + pass.checkStringFormat(data, size); + pass.checkStringPack(data, size, false); + pass.checkStringMatch(data, size); + pass.checkStringReplace(data, size, -1); + pass.checkDateFormat(data, size); + } + +private: + LintContext* context; + + static inline bool isAlpha(char ch) + { + // use or trick to convert to lower case and unsigned comparison to do range check + return unsigned((ch | ' ') - 'a') < 26; + } + + static inline bool isDigit(char ch) + { + // use unsigned comparison to do range check for performance + return unsigned(ch - '0') < 10; + } + + const char* checkStringFormat(const char* data, size_t size) + { + const char* flags = "-+ #0"; + const char* options = "cdiouxXeEfgGqs"; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + // escaped % doesn't allow for flags/etc. + if (i < size && data[i] == '%') + continue; + + // skip flags + while (i < size && strchr(flags, data[i])) + i++; + + // skip width (up to two digits) + if (i < size && isDigit(data[i])) + i++; + if (i < size && isDigit(data[i])) + i++; + + // skip precision + if (i < size && data[i] == '.') + { + i++; + + // up to two digits + if (i < size && isDigit(data[i])) + i++; + if (i < size && isDigit(data[i])) + i++; + } + + if (i == size) + return "unfinished format specifier"; + + if (!strchr(options, data[i])) + return "invalid format specifier: must be a string format specifier or %"; + } + } + + return nullptr; + } + + const char* checkStringPack(const char* data, size_t size, bool fixed) + { + const char* options = "<>=!bBhHlLjJTiIfdnczsxX "; + const char* unsized = "<>=!zX "; + + for (size_t i = 0; i < size; ++i) + { + if (!strchr(options, data[i])) + return "unexpected character; must be a pack specifier or space"; + + if (data[i] == 'c' && (i + 1 == size || !isDigit(data[i + 1]))) + return "fixed-sized string format must specify the size"; + + if (data[i] == 'X' && (i + 1 == size || strchr(unsized, data[i + 1]))) + return "X must be followed by a size specifier"; + + if (fixed && (data[i] == 'z' || data[i] == 's')) + return "pack specifier must be fixed-size"; + + if ((data[i] == '!' || data[i] == 'i' || data[i] == 'I' || data[i] == 'c' || data[i] == 's') && i + 1 < size && isDigit(data[i + 1])) + { + bool isc = data[i] == 'c'; + + unsigned int v = 0; + while (i + 1 < size && isDigit(data[i + 1]) && v <= (INT_MAX - 9) / 10) + { + v = v * 10 + (data[i + 1] - '0'); + i++; + } + + if (i + 1 < size && isDigit(data[i + 1])) + return "size specifier is too large"; + + if (!isc && (v == 0 || v > 16)) + return "integer size must be in range [1,16]"; + } + } + + return nullptr; + } + + const char* checkStringMatchSet(const char* data, size_t size, const char* magic, const char* classes) + { + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i == size) + return "unfinished character class"; + + if (isDigit(data[i])) + { + return "sets can not contain capture references"; + } + else if (isAlpha(data[i])) + { + // lower case lookup - upper case for every character class is defined as its inverse + if (!strchr(classes, data[i] | ' ')) + return "invalid character class, must refer to a defined class or its inverse"; + } + else + { + // technically % can escape any non-alphanumeric character but this is error-prone + if (!strchr(magic, data[i])) + return "expected a magic character after %"; + } + + if (i + 1 < size && data[i + 1] == '-') + return "character range can't include character sets"; + } + else if (data[i] == '-') + { + if (i + 1 < size && data[i + 1] == '%') + return "character range can't include character sets"; + } + } + + return nullptr; + } + + const char* checkStringMatch(const char* data, size_t size, int* outCaptures = nullptr) + { + const char* magic = "^$()%.[]*+-?)"; + const char* classes = "acdglpsuwxz"; + + std::vector openCaptures; + int totalCaptures = 0; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i == size) + return "unfinished character class"; + + if (isDigit(data[i])) + { + if (data[i] == '0') + return "invalid capture reference, must be 1-9"; + + int captureIndex = data[i] - '0'; + + if (captureIndex > totalCaptures) + return "invalid capture reference, must refer to a valid capture"; + + for (int open : openCaptures) + if (open == captureIndex) + return "invalid capture reference, must refer to a closed capture"; + } + else if (isAlpha(data[i])) + { + if (data[i] == 'b') + { + if (i + 2 >= size) + return "missing brace characters for balanced match"; + + i += 2; + } + else if (data[i] == 'f') + { + if (i + 1 >= size || data[i + 1] != '[') + return "missing set after a frontier pattern"; + + // we can parse the set with the regular logic + } + else + { + // lower case lookup - upper case for every character class is defined as its inverse + if (!strchr(classes, data[i] | ' ')) + return "invalid character class, must refer to a defined class or its inverse"; + } + } + else + { + // technically % can escape any non-alphanumeric character but this is error-prone + if (!strchr(magic, data[i])) + return "expected a magic character after %"; + } + } + else if (data[i] == '[') + { + size_t j = i + 1; + + // empty patterns don't exist as per grammar rules, so we skip leading ^ and ] + if (j < size && data[j] == '^') + j++; + + if (j < size && data[j] == ']') + j++; + + // scan for the end of the pattern + while (j < size && data[j] != ']') + { + // % escapes the next character + if (j + 1 < size && data[j] == '%') + j++; + + j++; + } + + if (j == size) + return "expected ] at the end of the string to close a set"; + + if (const char* error = checkStringMatchSet(data + i + 1, j - i - 1, magic, classes)) + return error; + + LUAU_ASSERT(data[j] == ']'); + i = j; + } + else if (data[i] == '(') + { + totalCaptures++; + openCaptures.push_back(totalCaptures); + } + else if (data[i] == ')') + { + if (openCaptures.empty()) + return "unexpected ) without a matching ("; + openCaptures.pop_back(); + } + } + + if (!openCaptures.empty()) + return "expected ) at the end of the string to close a capture"; + + if (outCaptures) + *outCaptures = totalCaptures; + + return nullptr; + } + + const char* checkStringReplace(const char* data, size_t size, int captures) + { + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i == size) + return "unfinished replacement"; + + if (data[i] != '%' && !isDigit(data[i])) + return "unexpected replacement character; must be a digit or %"; + + if (isDigit(data[i]) && captures >= 0 && data[i] - '0' > captures) + return "invalid capture index, must refer to pattern capture"; + } + } + + return nullptr; + } + + const char* checkDateFormat(const char* data, size_t size) + { + const char* options = "aAbBcdHIjmMpSUwWxXyYzZ"; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i == size) + return "unfinished replacement"; + + if (data[i] != '%' && !strchr(options, data[i])) + return "unexpected replacement character; must be a date format specifier or %"; + } + + if (data[i] == 0) + return "date format can not contain null characters"; + } + + return nullptr; + } + + void matchStringCall(AstName name, AstExpr* self, AstArray args) + { + if (name == "format") + { + if (AstExprConstantString* fmt = self->as()) + if (const char* error = checkStringFormat(fmt->value.data, fmt->value.size)) + emitWarning(*context, LintWarning::Code_FormatString, fmt->location, "Invalid format string: %s", error); + } + else if (name == "pack" || name == "packsize" || name == "unpack") + { + if (AstExprConstantString* fmt = self->as()) + if (const char* error = checkStringPack(fmt->value.data, fmt->value.size, name == "packsize")) + emitWarning(*context, LintWarning::Code_FormatString, fmt->location, "Invalid pack format: %s", error); + } + else if ((name == "match" || name == "gmatch") && args.size > 0) + { + if (AstExprConstantString* pat = args.data[0]->as()) + if (const char* error = checkStringMatch(pat->value.data, pat->value.size)) + emitWarning(*context, LintWarning::Code_FormatString, pat->location, "Invalid match pattern: %s", error); + } + else if (name == "find" && args.size > 0 && args.size <= 2) + { + if (AstExprConstantString* pat = args.data[0]->as()) + if (const char* error = checkStringMatch(pat->value.data, pat->value.size)) + emitWarning(*context, LintWarning::Code_FormatString, pat->location, "Invalid match pattern: %s", error); + } + else if (name == "find" && args.size >= 3) + { + AstExprConstantBool* mode = args.data[2]->as(); + + // find(_, _, _, true) is a raw string find, not a pattern match + if (mode && !mode->value) + if (AstExprConstantString* pat = args.data[0]->as()) + if (const char* error = checkStringMatch(pat->value.data, pat->value.size)) + emitWarning(*context, LintWarning::Code_FormatString, pat->location, "Invalid match pattern: %s", error); + } + else if (name == "gsub" && args.size > 1) + { + int captures = -1; + + if (AstExprConstantString* pat = args.data[0]->as()) + if (const char* error = checkStringMatch(pat->value.data, pat->value.size, &captures)) + emitWarning(*context, LintWarning::Code_FormatString, pat->location, "Invalid match pattern: %s", error); + + if (AstExprConstantString* rep = args.data[1]->as()) + if (const char* error = checkStringReplace(rep->value.data, rep->value.size, captures)) + emitWarning(*context, LintWarning::Code_FormatString, rep->location, "Invalid match replacement: %s", error); + } + } + + void matchCall(AstExprCall* node) + { + AstExprIndexName* func = node->func->as(); + if (!func) + return; + + if (node->self) + { + AstExprGroup* group = func->expr->as(); + AstExpr* self = group ? group->expr : func->expr; + + if (self->is()) + matchStringCall(func->index, self, node->args); + else if (std::optional type = context->getType(self)) + if (isString(*type)) + matchStringCall(func->index, self, node->args); + return; + } + + AstExprGlobal* lib = func->expr->as(); + if (!lib) + return; + + if (lib->name == "string") + { + if (node->args.size > 0) + { + AstArray rest = {node->args.data + 1, node->args.size - 1}; + + matchStringCall(func->index, node->args.data[0], rest); + } + } + else if (lib->name == "os") + { + if (func->index == "date" && node->args.size > 0) + { + if (AstExprConstantString* fmt = node->args.data[0]->as()) + if (const char* error = checkDateFormat(fmt->value.data, fmt->value.size)) + emitWarning(*context, LintWarning::Code_FormatString, fmt->location, "Invalid date format: %s", error); + } + } + } + + bool visit(AstExprCall* node) override + { + matchCall(node); + return true; + } +}; + +class LintTableLiteral : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintTableLiteral pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprTable* node) override + { + int count = 0; + + for (const AstExprTable::Item& item : node->items) + if (item.kind == AstExprTable::Item::List) + count++; + + DenseHashMap*, int, AstArrayPredicate, AstArrayPredicate> names(nullptr); + DenseHashMap indices(-1); + + for (const AstExprTable::Item& item : node->items) + { + if (!item.key) + continue; + + if (AstExprConstantString* expr = item.key->as()) + { + int& line = names[&expr->value]; + + if (line) + emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, + "Table field '%.*s' is a duplicate; previously defined at line %d", int(expr->value.size), expr->value.data, line); + else + line = expr->location.begin.line + 1; + } + else if (AstExprConstantNumber* expr = item.key->as()) + { + if (expr->value >= 1 && expr->value <= double(count) && double(int(expr->value)) == expr->value) + emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, + "Table index %d is a duplicate; previously defined as a list entry", int(expr->value)); + else if (expr->value >= 0 && expr->value <= double(INT_MAX) && double(int(expr->value)) == expr->value) + { + int& line = indices[int(expr->value)]; + + if (line) + emitWarning(*context, LintWarning::Code_TableLiteral, expr->location, + "Table index %d is a duplicate; previously defined at line %d", int(expr->value), line); + else + line = expr->location.begin.line + 1; + } + } + } + + return true; + } + + bool visit(AstType* node) override + { + return true; + } + + bool visit(AstTypeTable* node) override + { + DenseHashMap names(AstName{}); + + for (const AstTableProp& item : node->props) + { + int& line = names[item.name]; + + if (line) + emitWarning(*context, LintWarning::Code_TableLiteral, item.location, + "Table type field '%s' is a duplicate; previously defined at line %d", item.name.value, line); + else + line = item.location.begin.line + 1; + } + + return true; + } + + struct AstArrayPredicate + { + size_t operator()(const AstArray* value) const + { + return hashRange(value->data, value->size); + } + + bool operator()(const AstArray* lhs, const AstArray* rhs) const + { + return (lhs && rhs) ? lhs->size == rhs->size && memcmp(lhs->data, rhs->data, lhs->size) == 0 : lhs == rhs; + } + }; +}; + +class LintUninitializedLocal : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintUninitializedLocal pass; + pass.context = &context; + + context.root->visit(&pass); + + pass.report(); + } + +private: + struct Local + { + bool defined; + bool initialized; + bool assigned; + AstExprLocal* firstUse; + }; + + LintContext* context; + DenseHashMap locals; + + LintUninitializedLocal() + : locals(NULL) + { + } + + void report() + { + for (auto& lp : locals) + { + AstLocal* local = lp.first; + const Local& l = lp.second; + + if (l.defined && !l.initialized && !l.assigned && l.firstUse) + { + emitWarning(*context, LintWarning::Code_UninitializedLocal, l.firstUse->location, + "Variable '%s' defined at line %d is never initialized or assigned; initialize with 'nil' to silence", local->name.value, + local->location.begin.line + 1); + } + } + } + + bool visit(AstStatLocal* node) override + { + AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; + bool vararg = last && (last->is() || last->is()); + + for (size_t i = 0; i < node->vars.size; ++i) + { + Local& l = locals[node->vars.data[i]]; + + l.defined = true; + l.initialized = vararg || i < node->values.size; + } + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + visitAssign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatFunction* node) override + { + visitAssign(node->name); + node->func->visit(this); + + return false; + } + + bool visit(AstExprLocal* node) override + { + Local& l = locals[node->local]; + + if (!l.firstUse) + l.firstUse = node; + + return false; + } + + void visitAssign(AstExpr* var) + { + if (AstExprLocal* lv = var->as()) + { + Local& l = locals[lv->local]; + + l.assigned = true; + } + else + { + var->visit(this); + } + } +}; + +class LintDuplicateFunction : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintDuplicateFunction pass{&context}; + context.root->visit(&pass); + } + +private: + LintContext* context; + DenseHashMap defns; + + LintDuplicateFunction(LintContext* context) + : context(context) + , defns("") + { + } + + bool visit(AstStatBlock* block) override + { + defns.clear(); + + for (AstStat* stat : block->body) + { + if (AstStatFunction* func = stat->as()) + trackFunction(func->name->location, buildName(func->name)); + else if (AstStatLocalFunction* func = stat->as()) + trackFunction(func->name->location, func->name->name.value); + } + + return true; + } + + void trackFunction(Location location, const std::string& name) + { + if (name.empty()) + return; + + Location& defn = defns[name]; + + if (defn.end.line == 0 && defn.end.column == 0) + defn = location; + else + report(name, location, defn); + } + + std::string buildName(AstExpr* expr) + { + if (AstExprLocal* local = expr->as()) + return local->local->name.value; + else if (AstExprGlobal* global = expr->as()) + return global->name.value; + else if (AstExprIndexName* indexName = expr->as()) + { + std::string lhs = buildName(indexName->expr); + if (lhs.empty()) + return lhs; + + lhs += '.'; + lhs += indexName->index.value; + return lhs; + } + else + return std::string(); + } + + void report(const std::string& name, Location location, Location otherLocation) + { + emitWarning(*context, LintWarning::Code_DuplicateFunction, location, "Duplicate function definition: '%s' also defined on line %d", + name.c_str(), otherLocation.begin.line + 1); + } +}; + +class LintDeprecatedApi : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + if (!context.module) + return; + + LintDeprecatedApi pass{&context}; + context.root->visit(&pass); + } + +private: + LintContext* context; + + LintDeprecatedApi(LintContext* context) + : context(context) + { + } + + bool visit(AstExprIndexName* node) override + { + std::optional ty = context->getType(node->expr); + if (!ty) + return true; + + if (const ClassTypeVar* cty = get(follow(*ty))) + { + const Property* prop = lookupClassProp(cty, node->index.value); + + if (prop && prop->deprecated) + { + if (!prop->deprecatedSuggestion.empty()) + emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated, use '%s' instead", + cty->name.c_str(), node->index.value, prop->deprecatedSuggestion.c_str()); + else + emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated", cty->name.c_str(), + node->index.value); + } + } + + return true; + } +}; + +class LintTableOperations : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + if (!context.module) + return; + + LintTableOperations pass{&context}; + context.root->visit(&pass); + } + +private: + LintContext* context; + + LintTableOperations(LintContext* context) + : context(context) + { + } + + bool visit(AstExprCall* node) override + { + AstExprIndexName* func = node->func->as(); + if (!func) + return true; + + AstExprGlobal* tablib = func->expr->as(); + if (!tablib || tablib->name != "table") + return true; + + AstExpr** args = node->args.data; + + if (func->index == "insert" && node->args.size == 2) + { + if (AstExprCall* tail = args[1]->as()) + { + if (std::optional funty = context->getType(tail->func)) + { + size_t ret = getReturnCount(follow(*funty)); + + if (ret > 1) + emitWarning(*context, LintWarning::Code_TableOperations, tail->location, + "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second " + "argument"); + } + } + } + + if (func->index == "insert" && node->args.size >= 3) + { + // table.insert(t, 0, ?) + if (isConstant(args[1], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"); + + // table.insert(t, #t, ?) + if (isLength(args[1], args[0])) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.insert will insert the value before the last element, which is likely a bug; consider removing the second argument or " + "wrap it in parentheses to silence"); + + // table.insert(t, #t+1, ?) + if (AstExprBinary* add = args[1]->as(); + add && add->op == AstExprBinary::Add && isLength(add->left, args[0]) && isConstant(add->right, 1.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.insert will append the value to the table; consider removing the second argument for efficiency"); + } + + if (func->index == "remove" && node->args.size >= 2) + { + // table.remove(t, 0) + if (isConstant(args[1], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"); + + // note: it's tempting to check for table.remove(t, #t), which is equivalent to table.remove(t), but it's correct, occurs frequently, + // and also reads better. + + // table.remove(t, #t-1) + if (AstExprBinary* sub = args[1]->as(); + sub && sub->op == AstExprBinary::Sub && isLength(sub->left, args[0]) && isConstant(sub->right, 1.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.remove will remove the value before the last element, which is likely a bug; consider removing the second argument or " + "wrap it in parentheses to silence"); + } + + return true; + } + + bool isConstant(AstExpr* expr, double value) + { + AstExprConstantNumber* n = expr->as(); + return n && n->value == value; + } + + bool isLength(AstExpr* expr, AstExpr* table) + { + AstExprUnary* n = expr->as(); + return n && n->op == AstExprUnary::Len && similar(n->expr, table); + } + + size_t getReturnCount(TypeId ty) + { + if (auto ftv = get(ty)) + return size(ftv->retType); + + if (auto itv = get(ty)) + { + // We don't process the type recursively to avoid having to deal with self-recursive intersection types + size_t result = 0; + + for (TypeId part : itv->parts) + if (auto ftv = get(follow(part))) + result = std::max(result, size(ftv->retType)); + + return result; + } + + return 0; + } +}; + +class LintDuplicateCondition : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintDuplicateCondition pass{&context}; + context.root->visit(&pass); + } + +private: + LintContext* context; + + LintDuplicateCondition(LintContext* context) + : context(context) + { + } + + bool visit(AstStatIf* stat) override + { + if (!stat->elsebody) + return true; + + if (!stat->elsebody->is()) + return true; + + // if..elseif chain detected, we need to unroll it + std::vector conditions; + conditions.reserve(2); + + AstStatIf* head = stat; + while (head) + { + head->condition->visit(this); + head->thenbody->visit(this); + + conditions.push_back(head->condition); + + if (head->elsebody && head->elsebody->is()) + { + head = head->elsebody->as(); + continue; + } + + if (head->elsebody) + head->elsebody->visit(this); + + break; + } + + detectDuplicates(conditions); + + // block recursive visits so that we only analyze each chain once + return false; + } + + bool visit(AstExprBinary* expr) override + { + if (expr->op != AstExprBinary::And && expr->op != AstExprBinary::Or) + return true; + + // for And expressions, it's idiomatic to use "a and a or b" as a ternary replacement, so we detect this pattern + if (expr->op == AstExprBinary::Or) + { + AstExprBinary* la = expr->left->as(); + + if (la && la->op == AstExprBinary::And) + { + AstExprBinary* lb = la->left->as(); + AstExprBinary* rb = la->right->as(); + + // check that the length of and-chain is exactly 2 + if (!(lb && lb->op == AstExprBinary::And) && !(rb && rb->op == AstExprBinary::And)) + { + la->left->visit(this); + la->right->visit(this); + expr->right->visit(this); + return false; + } + } + } + + // unroll condition chain + std::vector conditions; + conditions.reserve(2); + + extractOpChain(conditions, expr, expr->op); + + detectDuplicates(conditions); + + // block recursive visits so that we only analyze each chain once + return false; + } + + void extractOpChain(std::vector& conditions, AstExpr* expr, AstExprBinary::Op op) + { + if (AstExprBinary* bin = expr->as(); bin && bin->op == op) + { + extractOpChain(conditions, bin->left, op); + extractOpChain(conditions, bin->right, op); + } + else if (AstExprGroup* group = expr->as()) + { + extractOpChain(conditions, group->expr, op); + } + else + { + conditions.push_back(expr); + } + } + + void detectDuplicates(const std::vector& conditions) + { + // Limit the distance at which we consider duplicates to reduce N^2 complexity to KN + const size_t kMaxDistance = 5; + + for (size_t i = 0; i < conditions.size(); ++i) + { + for (size_t j = std::max(i, kMaxDistance) - kMaxDistance; j < i; ++j) + { + if (similar(conditions[j], conditions[i])) + { + if (conditions[i]->location.begin.line == conditions[j]->location.begin.line) + emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, + "Condition has already been checked on column %d", conditions[j]->location.begin.column + 1); + else + emitWarning(*context, LintWarning::Code_DuplicateCondition, conditions[i]->location, + "Condition has already been checked on line %d", conditions[j]->location.begin.line + 1); + break; + } + } + } + } +}; + +class LintDuplicateLocal : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintDuplicateLocal pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + DenseHashMap locals; + + LintDuplicateLocal() + : locals(nullptr) + { + } + + bool visit(AstStatLocal* node) override + { + // early out for performance + if (node->vars.size == 1) + return true; + + for (size_t i = 0; i < node->vars.size; ++i) + locals[node->vars.data[i]] = node; + + for (size_t i = 0; i < node->vars.size; ++i) + { + AstLocal* local = node->vars.data[i]; + + if (local->shadow && locals[local->shadow] == node && !ignoreDuplicate(local)) + { + if (local->shadow->location.begin.line == local->location.begin.line) + emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on column %d", + local->name.value, local->shadow->location.begin.column + 1); + else + emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Variable '%s' already defined on line %d", + local->name.value, local->shadow->location.begin.line + 1); + } + } + + return true; + } + + bool visit(AstExprFunction* node) override + { + if (node->self) + locals[node->self] = node; + + for (size_t i = 0; i < node->args.size; ++i) + locals[node->args.data[i]] = node; + + for (size_t i = 0; i < node->args.size; ++i) + { + AstLocal* local = node->args.data[i]; + + if (local->shadow && locals[local->shadow] == node && !ignoreDuplicate(local)) + { + if (local->shadow == node->self) + emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter 'self' already defined implicitly"); + else if (local->shadow->location.begin.line == local->location.begin.line) + emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on column %d", + local->name.value, local->shadow->location.begin.column + 1); + else + emitWarning(*context, LintWarning::Code_DuplicateLocal, local->location, "Function parameter '%s' already defined on line %d", + local->name.value, local->shadow->location.begin.line + 1); + } + } + + return true; + } + + bool ignoreDuplicate(AstLocal* local) + { + return local->name == "_"; + } +}; + +static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env) +{ + ScopePtr current = env; + while (true) + { + for (auto& [global, binding] : current->bindings) + { + AstName name = names.get(global.c_str()); + + if (name.value) + { + auto& g = context.builtinGlobals[name]; + g.type = binding.typeId; + if (binding.deprecated) + g.deprecated = binding.deprecatedSuggestion.c_str(); + } + } + + if (current->parent) + current = current->parent; + else + break; + } +} + +void LintOptions::setDefaults() +{ + // By default, we enable all warnings + warningMask = ~0ull; +} + +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options) +{ + LintContext context; + + context.options = options; + context.root = root; + context.placeholder = names.get("_"); + context.scope = env; + context.module = module; + + fillBuiltinGlobals(context, names, env); + + if (context.warningEnabled(LintWarning::Code_UnknownGlobal) || context.warningEnabled(LintWarning::Code_DeprecatedGlobal) || + context.warningEnabled(LintWarning::Code_GlobalUsedAsLocal) || context.warningEnabled(LintWarning::Code_PlaceholderRead) || + context.warningEnabled(LintWarning::Code_BuiltinGlobalWrite)) + { + LintGlobalLocal::process(context); + } + + if (context.warningEnabled(LintWarning::Code_MultiLineStatement)) + LintMultiLineStatement::process(context); + + if (context.warningEnabled(LintWarning::Code_SameLineStatement)) + LintSameLineStatement::process(context); + + if (context.warningEnabled(LintWarning::Code_LocalShadow) || context.warningEnabled(LintWarning::Code_FunctionUnused) || + context.warningEnabled(LintWarning::Code_ImportUnused) || context.warningEnabled(LintWarning::Code_LocalUnused)) + { + LintLocalHygiene::process(context); + } + + if (context.warningEnabled(LintWarning::Code_FunctionUnused)) + LintUnusedFunction::process(context); + + if (context.warningEnabled(LintWarning::Code_UnreachableCode)) + LintUnreachableCode::process(context); + + if (context.warningEnabled(LintWarning::Code_UnknownType)) + LintUnknownType::process(context); + + if (context.warningEnabled(LintWarning::Code_ForRange)) + LintForRange::process(context); + + if (context.warningEnabled(LintWarning::Code_UnbalancedAssignment)) + LintUnbalancedAssignment::process(context); + + if (context.warningEnabled(LintWarning::Code_ImplicitReturn)) + LintImplicitReturn::process(context); + + if (context.warningEnabled(LintWarning::Code_FormatString)) + LintFormatString::process(context); + + if (context.warningEnabled(LintWarning::Code_TableLiteral)) + LintTableLiteral::process(context); + + if (context.warningEnabled(LintWarning::Code_UninitializedLocal)) + LintUninitializedLocal::process(context); + + if (context.warningEnabled(LintWarning::Code_DuplicateFunction)) + LintDuplicateFunction::process(context); + + if (context.warningEnabled(LintWarning::Code_DeprecatedApi)) + LintDeprecatedApi::process(context); + + if (context.warningEnabled(LintWarning::Code_TableOperations)) + LintTableOperations::process(context); + + if (context.warningEnabled(LintWarning::Code_DuplicateCondition)) + LintDuplicateCondition::process(context); + + if (context.warningEnabled(LintWarning::Code_DuplicateLocal)) + LintDuplicateLocal::process(context); + + std::sort(context.result.begin(), context.result.end(), WarningComparator()); + + return context.result; +} + +const char* LintWarning::getName(Code code) +{ + LUAU_ASSERT(unsigned(code) < Code__Count); + + return kWarningNames[code]; +} + +LintWarning::Code LintWarning::parseName(const char* name) +{ + for (int code = Code_Unknown; code < Code__Count; ++code) + if (strcmp(name, getName(Code(code))) == 0) + return Code(code); + + return Code_Unknown; +} + +uint64_t LintWarning::parseMask(const std::vector& hotcomments) +{ + uint64_t result = 0; + + for (const std::string& hc : hotcomments) + { + if (hc.compare(0, 6, "nolint") != 0) + continue; + + std::string::size_type name = hc.find_first_not_of(" \t", 6); + + // --!nolint disables everything + if (name == std::string::npos) + return ~0ull; + + // --!nolint name disables the specific lint + LintWarning::Code code = LintWarning::parseName(hc.c_str() + name); + + if (code != LintWarning::Code_Unknown) + result |= 1ull << int(code); + } + + return result; +} + +std::vector getDeprecatedGlobals(const AstNameTable& names) +{ + LintContext context; + + std::vector result; + result.reserve(context.builtinGlobals.size()); + + for (auto& p : context.builtinGlobals) + if (p.second.deprecated) + result.push_back(p.first); + + return result; +} + +void fuzzFormatString(const char* data, size_t size) +{ + LintFormatString::fuzz(data, size); +} + +} // namespace Luau diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp new file mode 100644 index 0000000..f1d975f --- /dev/null +++ b/Analysis/src/Module.cpp @@ -0,0 +1,521 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Module.h" + +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" +#include "Luau/Common.h" + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) +LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) +LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) +LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) + +namespace Luau +{ + +static bool contains(Position pos, Comment comment) +{ + if (comment.location.contains(pos)) + return true; + else if (FFlag::LuauCaptureBrokenCommentSpans && comment.type == Lexeme::BrokenComment && + comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end + return true; + else if (comment.type == Lexeme::Comment && comment.location.end == pos) + return true; + else + return false; +} + +bool isWithinComment(const SourceModule& sourceModule, Position pos) +{ + auto iter = std::lower_bound(sourceModule.commentLocations.begin(), sourceModule.commentLocations.end(), + Comment{Lexeme::Comment, Location{pos, pos}}, [](const Comment& a, const Comment& b) { + return a.location.end < b.location.end; + }); + + if (iter == sourceModule.commentLocations.end()) + return false; + + if (contains(pos, *iter)) + return true; + + // Due to the nature of std::lower_bound, it is possible that iter points at a comment that ends + // at pos. We'll try the next comment, if it exists. + ++iter; + if (iter == sourceModule.commentLocations.end()) + return false; + + return contains(pos, *iter); +} + +void TypeArena::clear() +{ + typeVars.clear(); + typePacks.clear(); +} + +TypeId TypeArena::addTV(TypeVar&& tv) +{ + TypeId allocated = typeVars.allocate(std::move(tv)); + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypeId TypeArena::freshType(TypeLevel level) +{ + TypeId allocated = typeVars.allocate(FreeTypeVar{level}); + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::initializer_list types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(std::vector types) +{ + TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePack tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = this; + + return allocated; +} + +TypePackId TypeArena::addTypePack(TypePackVar tp) +{ + TypePackId allocated = typePacks.allocate(std::move(tp)); + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = this; + + return allocated; +} + +using SeenTypes = std::unordered_map; +using SeenTypePacks = std::unordered_map; + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); +TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); + +namespace +{ + +struct TypePackCloner; + +/* + * Both TypeCloner and TypePackCloner work by depositing the requested type variable into the appropriate 'seen' set. + * They do not return anything because their sole consumer (the deepClone function) already has a pointer into this storage. + */ + +struct TypeCloner +{ + TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + : dest(dest) + , typeId(typeId) + , seenTypes(seenTypes) + , seenTypePacks(seenTypePacks) + { + } + + TypeArena& dest; + TypeId typeId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + + bool* encounteredFreeType = nullptr; + + template + void defaultClone(const T& t); + + void operator()(const Unifiable::Free& t); + void operator()(const Unifiable::Generic& t); + void operator()(const Unifiable::Bound& t); + void operator()(const Unifiable::Error& t); + void operator()(const PrimitiveTypeVar& t); + void operator()(const FunctionTypeVar& t); + void operator()(const TableTypeVar& t); + void operator()(const MetatableTypeVar& t); + void operator()(const ClassTypeVar& t); + void operator()(const AnyTypeVar& t); + void operator()(const UnionTypeVar& t); + void operator()(const IntersectionTypeVar& t); + void operator()(const LazyTypeVar& t); +}; + +struct TypePackCloner +{ + TypeArena& dest; + TypePackId typePackId; + SeenTypes& seenTypes; + SeenTypePacks& seenTypePacks; + bool* encounteredFreeType = nullptr; + + TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + : dest(dest) + , typePackId(typePackId) + , seenTypes(seenTypes) + , seenTypePacks(seenTypePacks) + { + } + + template + void defaultClone(const T& t) + { + TypePackId cloned = dest.typePacks.allocate(t); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const Unifiable::Free& t) + { + if (encounteredFreeType) + *encounteredFreeType = true; + + seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}}); + } + + void operator()(const Unifiable::Generic& t) + { + defaultClone(t); + } + void operator()(const Unifiable::Error& t) + { + defaultClone(t); + } + + // While we are a-cloning, we can flatten out bound TypeVars and make things a bit tighter. + // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. + void operator()(const Unifiable::Bound& t) + { + TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const VariadicTypePack& t) + { + TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}); + seenTypePacks[typePackId] = cloned; + } + + void operator()(const TypePack& t) + { + TypePackId cloned = dest.typePacks.allocate(TypePack{}); + TypePack* destTp = getMutable(cloned); + LUAU_ASSERT(destTp != nullptr); + seenTypePacks[typePackId] = cloned; + + for (TypeId ty : t.head) + destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + if (t.tail) + destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, encounteredFreeType); + } +}; + +template +void TypeCloner::defaultClone(const T& t) +{ + TypeId cloned = dest.typeVars.allocate(t); + seenTypes[typeId] = cloned; +} + +void TypeCloner::operator()(const Unifiable::Free& t) +{ + if (encounteredFreeType) + *encounteredFreeType = true; + + seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{}); +} + +void TypeCloner::operator()(const Unifiable::Generic& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const Unifiable::Bound& t) +{ + TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + seenTypes[typeId] = boundTo; +} + +void TypeCloner::operator()(const Unifiable::Error& t) +{ + defaultClone(t); +} +void TypeCloner::operator()(const PrimitiveTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const FunctionTypeVar& t) +{ + TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + FunctionTypeVar* ftv = getMutable(result); + LUAU_ASSERT(ftv != nullptr); + + seenTypes[typeId] = result; + + for (TypeId generic : t.generics) + ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + for (TypePackId genericPack : t.genericPacks) + ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) + ftv->tags = t.tags; + + ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); + ftv->argNames = t.argNames; + ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); +} + +void TypeCloner::operator()(const TableTypeVar& t) +{ + TypeId result = dest.typeVars.allocate(TableTypeVar{}); + TableTypeVar* ttv = getMutable(result); + LUAU_ASSERT(ttv != nullptr); + + *ttv = t; + + seenTypes[typeId] = result; + + ttv->level = TypeLevel{0, 0}; + + for (const auto& [name, prop] : t.props) + { + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + else + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; + } + + if (t.indexer) + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), + clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; + + if (t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + + for (TypeId& arg : ttv->instantiatedTypeParams) + arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + if (ttv->state == TableState::Free) + { + if (!t.boundTo) + { + if (encounteredFreeType) + *encounteredFreeType = true; + } + + ttv->state = TableState::Sealed; + } + + ttv->definitionModuleName = t.definitionModuleName; + ttv->methodDefinitionLocations = t.methodDefinitionLocations; + ttv->tags = t.tags; +} + +void TypeCloner::operator()(const MetatableTypeVar& t) +{ + TypeId result = dest.typeVars.allocate(MetatableTypeVar{}); + MetatableTypeVar* mtv = getMutable(result); + seenTypes[typeId] = result; + + mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, encounteredFreeType); + mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); +} + +void TypeCloner::operator()(const ClassTypeVar& t) +{ + TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + ClassTypeVar* ctv = getMutable(result); + + seenTypes[typeId] = result; + + for (const auto& [name, prop] : t.props) + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + else + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; + + if (t.parent) + ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); + + if (t.metatable) + ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); +} + +void TypeCloner::operator()(const AnyTypeVar& t) +{ + defaultClone(t); +} + +void TypeCloner::operator()(const UnionTypeVar& t) +{ + TypeId result = dest.typeVars.allocate(UnionTypeVar{}); + seenTypes[typeId] = result; + + UnionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.options) + option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); +} + +void TypeCloner::operator()(const IntersectionTypeVar& t) +{ + TypeId result = dest.typeVars.allocate(IntersectionTypeVar{}); + seenTypes[typeId] = result; + + IntersectionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.parts) + option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); +} + +void TypeCloner::operator()(const LazyTypeVar& t) +{ + defaultClone(t); +} + +} // anonymous namespace + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +{ + if (tp->persistent) + return tp; + + TypePackId& res = seenTypePacks[tp]; + + if (res == nullptr) + { + TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks}; + cloner.encounteredFreeType = encounteredFreeType; + Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. + } + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(res)->owningArena = &dest; + + return res; +} + +TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +{ + if (typeId->persistent) + return typeId; + + TypeId& res = seenTypes[typeId]; + + if (res == nullptr) + { + TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks}; + cloner.encounteredFreeType = encounteredFreeType; + Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } + + if (FFlag::DebugLuauTrackOwningArena) + asMutable(res)->owningArena = &dest; + + return res; +} + +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +{ + TypeFun result; + for (TypeId param : typeFun.typeParams) + result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); + + return result; +} + +ScopePtr Module::getModuleScope() const +{ + LUAU_ASSERT(!scopes.empty()); + return scopes.front().second; +} + +void freeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.freeze(); + arena.typePacks.freeze(); +} + +void unfreeze(TypeArena& arena) +{ + if (!FFlag::DebugLuauFreezeArena) + return; + + arena.typeVars.unfreeze(); + arena.typePacks.unfreeze(); +} + +Module::~Module() +{ + unfreeze(interfaceTypes); + unfreeze(internalTypes); +} + +bool Module::clonePublicInterface() +{ + LUAU_ASSERT(interfaceTypes.typeVars.empty()); + LUAU_ASSERT(interfaceTypes.typePacks.empty()); + + bool encounteredFreeType = false; + + SeenTypePacks seenTypePacks; + SeenTypes seenTypes; + + ScopePtr moduleScope = getModuleScope(); + + moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + if (moduleScope->varargPack) + moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + + for (auto& pair : moduleScope->exportedTypeBindings) + pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + + for (TypeId ty : moduleScope->returnType) + if (get(follow(ty))) + *asMutable(ty) = AnyTypeVar{}; + + freeze(internalTypes); + freeze(interfaceTypes); + + return encounteredFreeType; +} + +} // namespace Luau diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp new file mode 100644 index 0000000..25e63bf --- /dev/null +++ b/Analysis/src/Predicate.cpp @@ -0,0 +1,93 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Predicate.h" + +#include "Luau/Ast.h" + +LUAU_FASTFLAG(LuauOrPredicate) + +namespace Luau +{ + +std::optional tryGetLValue(const AstExpr& node) +{ + const AstExpr* expr = &node; + while (auto e = expr->as()) + expr = e->expr; + + if (auto local = expr->as()) + return Symbol{local->local}; + else if (auto global = expr->as()) + return Symbol{global->name}; + else if (auto indexname = expr->as()) + { + if (auto lvalue = tryGetLValue(*indexname->expr)) + return Field{std::make_shared(*lvalue), indexname->index.value}; + } + else if (auto indexexpr = expr->as()) + { + if (auto lvalue = tryGetLValue(*indexexpr->expr)) + if (auto string = indexexpr->expr->as()) + return Field{std::make_shared(*lvalue), std::string(string->value.data, string->value.size)}; + } + + return std::nullopt; +} + +std::pair> getFullName(const LValue& lvalue) +{ + const LValue* current = &lvalue; + std::vector keys; + while (auto field = get(*current)) + { + keys.push_back(field->key); + current = field->parent.get(); + if (!current) + LUAU_ASSERT(!"LValue root is a Field?"); + } + + const Symbol* symbol = get(*current); + return {*symbol, std::vector(keys.rbegin(), keys.rend())}; +} + +std::string toString(const LValue& lvalue) +{ + auto [symbol, keys] = getFullName(lvalue); + std::string s = toString(symbol); + for (std::string key : keys) + s += "." + key; + return s; +} + +void merge(RefinementMap& l, const RefinementMap& r, std::function f) +{ + LUAU_ASSERT(FFlag::LuauOrPredicate); + + auto itL = l.begin(); + auto itR = r.begin(); + while (itL != l.end() && itR != r.end()) + { + const auto& [k, a] = *itR; + if (itL->first == k) + { + l[k] = f(itL->second, a); + ++itL; + ++itR; + } + else if (itL->first > k) + { + l[k] = a; + ++itR; + } + else + ++itL; + } + + l.insert(itR, r.end()); +} + +void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) +{ + refis[toString(lvalue)] = ty; +} + +} // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp new file mode 100644 index 0000000..5b3997e --- /dev/null +++ b/Analysis/src/RequireTracer.cpp @@ -0,0 +1,190 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/RequireTracer.h" + +#include "Luau/Ast.h" +#include "Luau/Module.h" + +LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) + +namespace Luau +{ + +namespace +{ + +struct RequireTracer : AstVisitor +{ + explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName) + : fileResolver(fileResolver) + , currentModuleName(std::move(currentModuleName)) + { + } + + FileResolver* const fileResolver; + ModuleName currentModuleName; + DenseHashMap locals{0}; + RequireTraceResult result; + + std::optional fromAstFragment(AstExpr* expr) + { + if (auto g = expr->as(); g && g->name == "script") + return currentModuleName; + + return fileResolver->fromAstFragment(expr); + } + + bool visit(AstStatLocal* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + { + AstLocal* local = stat->vars.data[i]; + + if (local->annotation) + { + if (AstTypeTypeof* ann = local->annotation->as()) + ann->expr->visit(this); + } + + if (i < stat->values.size) + { + AstExpr* expr = stat->values.data[i]; + expr->visit(this); + + const ModuleName* name = result.exprs.find(expr); + if (name) + locals[local] = *name; + } + } + + return false; + } + + bool visit(AstExprGlobal* global) override + { + std::optional name = fromAstFragment(global); + if (name) + result.exprs[global] = *name; + + return false; + } + + bool visit(AstExprLocal* local) override + { + const ModuleName* name = locals.find(local->local); + if (name) + result.exprs[local] = *name; + + return false; + } + + bool visit(AstExprIndexName* indexName) override + { + indexName->expr->visit(this); + + const ModuleName* name = result.exprs.find(indexName->expr); + if (name) + { + if (indexName->index == "parent" || indexName->index == "Parent") + { + if (auto parent = fileResolver->getParentModuleName(*name)) + result.exprs[indexName] = *parent; + } + else + result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value); + } + + return false; + } + + bool visit(AstExprIndexExpr* indexExpr) override + { + indexExpr->expr->visit(this); + + const ModuleName* name = result.exprs.find(indexExpr->expr); + const AstExprConstantString* str = indexExpr->index->as(); + if (name && str) + { + result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size)); + } + + indexExpr->index->visit(this); + + return false; + } + + bool visit(AstExprTypeAssertion* expr) override + { + return false; + } + + // If we see game:GetService("StringLiteral") or Game:GetService("StringLiteral"), then rewrite to game.StringLiteral. + // Else traverse arguments and trace requires to them. + bool visit(AstExprCall* call) override + { + for (AstExpr* arg : call->args) + arg->visit(this); + + call->func->visit(this); + + AstExprGlobal* globalName = call->func->as(); + if (globalName && globalName->name == "require" && call->args.size >= 1) + { + if (const ModuleName* moduleName = result.exprs.find(call->args.data[0])) + result.requires.push_back({*moduleName, call->location}); + + return false; + } + + AstExprIndexName* indexName = call->func->as(); + if (!indexName) + return false; + + std::optional rootName = fromAstFragment(indexName->expr); + + if (FFlag::LuauTraceRequireLookupChild && !rootName) + { + if (const ModuleName* moduleName = result.exprs.find(indexName->expr)) + rootName = *moduleName; + } + + if (!rootName) + return false; + + bool supportedLookup = indexName->index == "GetService" || + (FFlag::LuauTraceRequireLookupChild && (indexName->index == "FindFirstChild" || indexName->index == "WaitForChild")); + + if (!supportedLookup) + return false; + + if (call->args.size != 1) + return false; + + AstExprConstantString* name = call->args.data[0]->as(); + if (!name) + return false; + + std::string_view v{name->value.data, name->value.size}; + if (v.end() != std::find(v.begin(), v.end(), '/')) + return false; + + result.exprs[call] = fileResolver->concat(*rootName, v); + + // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime + // If we fail to find such module, we will not report an UnknownRequire error + if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") + result.optional[call] = true; + + return false; + } +}; + +} // anonymous namespace + +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName) +{ + RequireTracer tracer{fileResolver, std::move(currentModuleName)}; + root->visit(&tracer); + return tracer.result; +} + +} // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp new file mode 100644 index 0000000..7223998 --- /dev/null +++ b/Analysis/src/Substitution.cpp @@ -0,0 +1,530 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Substitution.h" + +#include "Luau/Common.h" + +#include +#include + +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0) +LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) +LUAU_FASTFLAG(LuauRankNTypes) + +namespace Luau +{ + +void Tarjan::visitChildren(TypeId ty, int index) +{ + ty = follow(ty); + + if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + return; + + if (const FunctionTypeVar* ftv = get(ty)) + { + visitChild(ftv->argTypes); + visitChild(ftv->retType); + } + else if (const TableTypeVar* ttv = get(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + for (const auto& [name, prop] : ttv->props) + visitChild(prop.type); + if (ttv->indexer) + { + visitChild(ttv->indexer->indexType); + visitChild(ttv->indexer->indexResultType); + } + for (TypeId itp : ttv->instantiatedTypeParams) + visitChild(itp); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + visitChild(mtv->table); + visitChild(mtv->metatable); + } + else if (const UnionTypeVar* utv = get(ty)) + { + for (TypeId opt : utv->options) + visitChild(opt); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + for (TypeId part : itv->parts) + visitChild(part); + } +} + +void Tarjan::visitChildren(TypePackId tp, int index) +{ + tp = follow(tp); + + if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + return; + + if (const TypePack* tpp = get(tp)) + { + for (TypeId tv : tpp->head) + visitChild(tv); + if (tpp->tail) + visitChild(*tpp->tail); + } + else if (const VariadicTypePack* vtp = get(tp)) + { + visitChild(vtp->ty); + } +} + +std::pair Tarjan::indexify(TypeId ty) +{ + ty = follow(ty); + + bool fresh = !typeToIndex.contains(ty); + int& index = typeToIndex[ty]; + + if (fresh) + { + index = int(indexToType.size()); + indexToType.push_back(ty); + indexToPack.push_back(nullptr); + onStack.push_back(false); + lowlink.push_back(index); + } + return {index, fresh}; +} + +std::pair Tarjan::indexify(TypePackId tp) +{ + tp = follow(tp); + + bool fresh = !packToIndex.contains(tp); + int& index = packToIndex[tp]; + + if (fresh) + { + index = int(indexToPack.size()); + indexToType.push_back(nullptr); + indexToPack.push_back(tp); + onStack.push_back(false); + lowlink.push_back(index); + } + return {index, fresh}; +} + +void Tarjan::visitChild(TypeId ty) +{ + ty = follow(ty); + + edgesTy.push_back(ty); + edgesTp.push_back(nullptr); +} + +void Tarjan::visitChild(TypePackId tp) +{ + tp = follow(tp); + + edgesTy.push_back(nullptr); + edgesTp.push_back(tp); +} + +TarjanResult Tarjan::loop() +{ + // Normally Tarjan is presented recursively, but this is a hot loop, so worth optimizing + while (!worklist.empty()) + { + auto [index, currEdge, lastEdge] = worklist.back(); + + // First visit + if (currEdge == -1) + { + ++childCount; + if (FInt::LuauTarjanChildLimit > 0 && FInt::LuauTarjanChildLimit < childCount) + return TarjanResult::TooManyChildren; + + stack.push_back(index); + onStack[index] = true; + + currEdge = int(edgesTy.size()); + + // Fill in edge list of this vertex + if (TypeId ty = indexToType[index]) + visitChildren(ty, index); + else if (TypePackId tp = indexToPack[index]) + visitChildren(tp, index); + + lastEdge = int(edgesTy.size()); + } + + // Visit children + bool foundFresh = false; + + for (; currEdge < lastEdge; currEdge++) + { + int childIndex = -1; + bool fresh = false; + + if (auto ty = edgesTy[currEdge]) + std::tie(childIndex, fresh) = indexify(ty); + else if (auto tp = edgesTp[currEdge]) + std::tie(childIndex, fresh) = indexify(tp); + else + LUAU_ASSERT(false); + + if (fresh) + { + // Original recursion point, update the parent continuation point and start the new element + worklist.back() = {index, currEdge + 1, lastEdge}; + worklist.push_back({childIndex, -1, -1}); + + // We need to continue the top-level loop from the start with the new worklist element + foundFresh = true; + break; + } + else if (onStack[childIndex]) + { + lowlink[index] = std::min(lowlink[index], childIndex); + } + + visitEdge(childIndex, index); + } + + if (foundFresh) + continue; + + if (lowlink[index] == index) + { + visitSCC(index); + while (!stack.empty()) + { + int popped = stack.back(); + stack.pop_back(); + onStack[popped] = false; + if (popped == index) + break; + } + } + + worklist.pop_back(); + + // Original return from recursion into a child + if (!worklist.empty()) + { + auto [parentIndex, _, parentEndEdge] = worklist.back(); + + // No need to keep child edges around + edgesTy.resize(parentEndEdge); + edgesTp.resize(parentEndEdge); + + lowlink[parentIndex] = std::min(lowlink[parentIndex], lowlink[index]); + visitEdge(index, parentIndex); + } + } + + return TarjanResult::Ok; +} + +void Tarjan::clear() +{ + typeToIndex.clear(); + indexToType.clear(); + packToIndex.clear(); + indexToPack.clear(); + lowlink.clear(); + stack.clear(); + onStack.clear(); + + edgesTy.clear(); + edgesTp.clear(); + worklist.clear(); +} + +TarjanResult Tarjan::visitRoot(TypeId ty) +{ + childCount = 0; + ty = follow(ty); + + clear(); + auto [index, fresh] = indexify(ty); + worklist.push_back({index, -1, -1}); + return loop(); +} + +TarjanResult Tarjan::visitRoot(TypePackId tp) +{ + childCount = 0; + tp = follow(tp); + + clear(); + auto [index, fresh] = indexify(tp); + worklist.push_back({index, -1, -1}); + return loop(); +} + +bool FindDirty::getDirty(int index) +{ + if (dirty.size() <= size_t(index)) + dirty.resize(index + 1, false); + return dirty[index]; +} + +void FindDirty::setDirty(int index, bool d) +{ + if (dirty.size() <= size_t(index)) + dirty.resize(index + 1, false); + dirty[index] = d; +} + +void FindDirty::visitEdge(int index, int parentIndex) +{ + if (getDirty(index)) + setDirty(parentIndex, true); +} + +void FindDirty::visitSCC(int index) +{ + bool d = getDirty(index); + + for (auto it = stack.rbegin(); !d && it != stack.rend(); it++) + { + if (TypeId ty = indexToType[*it]) + d = isDirty(ty); + else if (TypePackId tp = indexToPack[*it]) + d = isDirty(tp); + if (*it == index) + break; + } + + if (!d) + return; + + for (auto it = stack.rbegin(); it != stack.rend(); it++) + { + setDirty(*it, true); + if (TypeId ty = indexToType[*it]) + foundDirty(ty); + else if (TypePackId tp = indexToPack[*it]) + foundDirty(tp); + if (*it == index) + return; + } +} + +TarjanResult FindDirty::findDirty(TypeId ty) +{ + dirty.clear(); + return visitRoot(ty); +} + +TarjanResult FindDirty::findDirty(TypePackId tp) +{ + dirty.clear(); + return visitRoot(tp); +} + +std::optional Substitution::substitute(TypeId ty) +{ + ty = follow(ty); + newTypes.clear(); + newPacks.clear(); + + auto result = findDirty(ty); + if (result != TarjanResult::Ok) + return std::nullopt; + + for (auto [oldTy, newTy] : newTypes) + replaceChildren(newTy); + for (auto [oldTp, newTp] : newPacks) + replaceChildren(newTp); + TypeId newTy = replace(ty); + return newTy; +} + +std::optional Substitution::substitute(TypePackId tp) +{ + tp = follow(tp); + newTypes.clear(); + newPacks.clear(); + + auto result = findDirty(tp); + if (result != TarjanResult::Ok) + return std::nullopt; + + for (auto [oldTy, newTy] : newTypes) + replaceChildren(newTy); + for (auto [oldTp, newTp] : newPacks) + replaceChildren(newTp); + TypePackId newTp = replace(tp); + return newTp; +} + +TypeId Substitution::clone(TypeId ty) +{ + ty = follow(ty); + + TypeId result = ty; + + if (const FunctionTypeVar* ftv = get(ty)) + { + FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.generics = ftv->generics; + clone.genericPacks = ftv->genericPacks; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + result = addType(std::move(clone)); + } + else if (const TableTypeVar* ttv = get(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) + clone.tags = ttv->tags; + result = addType(std::move(clone)); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; + clone.syntheticName = mtv->syntheticName; + result = addType(std::move(clone)); + } + else if (const UnionTypeVar* utv = get(ty)) + { + UnionTypeVar clone; + clone.options = utv->options; + result = addType(std::move(clone)); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + IntersectionTypeVar clone; + clone.parts = itv->parts; + result = addType(std::move(clone)); + } + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; +} + +TypePackId Substitution::clone(TypePackId tp) +{ + tp = follow(tp); + if (const TypePack* tpp = get(tp)) + { + TypePack clone; + clone.head = tpp->head; + clone.tail = tpp->tail; + return addTypePack(std::move(clone)); + } + else if (const VariadicTypePack* vtp = get(tp)) + { + VariadicTypePack clone; + clone.ty = vtp->ty; + return addTypePack(std::move(clone)); + } + else + return tp; +} + +void Substitution::foundDirty(TypeId ty) +{ + ty = follow(ty); + if (isDirty(ty)) + newTypes[ty] = clean(ty); + else + newTypes[ty] = clone(ty); +} + +void Substitution::foundDirty(TypePackId tp) +{ + tp = follow(tp); + if (isDirty(tp)) + newPacks[tp] = clean(tp); + else + newPacks[tp] = clone(tp); +} + +TypeId Substitution::replace(TypeId ty) +{ + ty = follow(ty); + if (TypeId* prevTy = newTypes.find(ty)) + return *prevTy; + else + return ty; +} + +TypePackId Substitution::replace(TypePackId tp) +{ + tp = follow(tp); + if (TypePackId* prevTp = newPacks.find(tp)) + return *prevTp; + else + return tp; +} + +void Substitution::replaceChildren(TypeId ty) +{ + ty = follow(ty); + + if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + return; + + if (FunctionTypeVar* ftv = getMutable(ty)) + { + ftv->argTypes = replace(ftv->argTypes); + ftv->retType = replace(ftv->retType); + } + else if (TableTypeVar* ttv = getMutable(ty)) + { + LUAU_ASSERT(!ttv->boundTo); + for (auto& [name, prop] : ttv->props) + prop.type = replace(prop.type); + if (ttv->indexer) + { + ttv->indexer->indexType = replace(ttv->indexer->indexType); + ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType); + } + for (TypeId& itp : ttv->instantiatedTypeParams) + itp = replace(itp); + } + else if (MetatableTypeVar* mtv = getMutable(ty)) + { + mtv->table = replace(mtv->table); + mtv->metatable = replace(mtv->metatable); + } + else if (UnionTypeVar* utv = getMutable(ty)) + { + for (TypeId& opt : utv->options) + opt = replace(opt); + } + else if (IntersectionTypeVar* itv = getMutable(ty)) + { + for (TypeId& part : itv->parts) + part = replace(part); + } +} + +void Substitution::replaceChildren(TypePackId tp) +{ + tp = follow(tp); + + if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + return; + + if (TypePack* tpp = getMutable(tp)) + { + for (TypeId& tv : tpp->head) + tv = replace(tv); + if (tpp->tail) + tpp->tail = replace(*tpp->tail); + } + else if (VariadicTypePack* vtp = getMutable(tp)) + { + vtp->ty = replace(vtp->ty); + } +} + +} // namespace Luau diff --git a/Analysis/src/Symbol.cpp b/Analysis/src/Symbol.cpp new file mode 100644 index 0000000..5922bb5 --- /dev/null +++ b/Analysis/src/Symbol.cpp @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Symbol.h" + +#include "Luau/Common.h" + +namespace Luau +{ + +std::string toString(const Symbol& name) +{ + if (name.local) + return name.local->name.value; + + LUAU_ASSERT(name.global.value); + return name.global.value; +} + +} // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp new file mode 100644 index 0000000..9d2f47b --- /dev/null +++ b/Analysis/src/ToString.cpp @@ -0,0 +1,1142 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ToString.h" + +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include +#include + +LUAU_FASTFLAG(LuauToStringFollowsBoundTo) +LUAU_FASTFLAG(LuauExtraNilRecovery) +LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) +LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) + +namespace Luau +{ + +namespace +{ + +struct FindCyclicTypes +{ + FindCyclicTypes() = default; + FindCyclicTypes(const FindCyclicTypes&) = delete; + FindCyclicTypes& operator=(const FindCyclicTypes&) = delete; + + bool exhaustive = false; + std::unordered_set visited; + std::unordered_set visitedPacks; + std::unordered_set cycles; + std::unordered_set cycleTPs; + + void cycle(TypeId ty) + { + cycles.insert(ty); + } + + void cycle(TypePackId tp) + { + cycleTPs.insert(tp); + } + + template + bool operator()(TypeId ty, const T&) + { + return visited.insert(ty).second; + } + + bool operator()(TypeId ty, const TableTypeVar& ttv) = delete; + + bool operator()(TypeId ty, const TableTypeVar& ttv, std::unordered_set& seen) + { + if (!visited.insert(ty).second) + return false; + + if (ttv.name || ttv.syntheticName) + { + for (TypeId itp : ttv.instantiatedTypeParams) + visitTypeVar(itp, *this, seen); + return exhaustive; + } + + return true; + } + + bool operator()(TypeId, const ClassTypeVar&) + { + return false; + } + + template + bool operator()(TypePackId tp, const T&) + { + return visitedPacks.insert(tp).second; + } +}; + +template +void findCyclicTypes(std::unordered_set& cycles, std::unordered_set& cycleTPs, TID ty, bool exhaustive) +{ + FindCyclicTypes fct; + fct.exhaustive = exhaustive; + visitTypeVar(ty, fct); + + cycles = std::move(fct.cycles); + cycleTPs = std::move(fct.cycleTPs); +} + +} // namespace + +static std::pair> canUseTypeNameInScope(ScopePtr scope, const std::string& name) +{ + for (ScopePtr curr = scope; curr; curr = curr->parent) + { + for (const auto& [importName, nameTable] : curr->importedTypeBindings) + { + if (nameTable.count(name)) + return {true, importName}; + } + + if (curr->exportedTypeBindings.count(name)) + return {true, std::nullopt}; + } + + return {false, std::nullopt}; +} + +struct StringifierState +{ + const ToStringOptions& opts; + ToStringResult& result; + + std::unordered_map cycleNames; + std::unordered_map cycleTpNames; + std::unordered_set seen; + std::unordered_set usedNames; + + bool exhaustive; + + StringifierState(const ToStringOptions& opts, ToStringResult& result, const std::optional& nameMap) + : opts(opts) + , result(result) + , exhaustive(opts.exhaustive) + { + if (nameMap) + result.nameMap = *nameMap; + + for (const auto& [_, v] : result.nameMap.typeVars) + usedNames.insert(v); + for (const auto& [_, v] : result.nameMap.typePacks) + usedNames.insert(v); + } + + bool hasSeen(const void* tv) + { + void* ttv = const_cast(tv); + if (seen.find(ttv) != seen.end()) + return true; + + seen.insert(ttv); + return false; + } + + void unsee(const void* tv) + { + void* ttv = const_cast(tv); + auto iter = seen.find(ttv); + if (iter != seen.end()) + seen.erase(iter); + } + + static std::string generateName(size_t i) + { + std::string n; + n = char('a' + i % 26); + if (i >= 26) + n += std::to_string(i / 26); + return n; + } + + std::string getName(TypeId ty) + { + const size_t s = result.nameMap.typeVars.size(); + std::string& n = result.nameMap.typeVars[ty]; + if (!n.empty()) + return n; + + for (int count = 0; count < 256; ++count) + { + std::string candidate = generateName(usedNames.size() + count); + if (!usedNames.count(candidate)) + { + usedNames.insert(candidate); + n = candidate; + return candidate; + } + } + + return generateName(s); + } + + std::string getName(TypePackId ty) + { + const size_t s = result.nameMap.typePacks.size(); + std::string& n = result.nameMap.typePacks[ty]; + if (!n.empty()) + return n; + + for (int count = 0; count < 256; ++count) + { + std::string candidate = generateName(usedNames.size() + count); + if (!usedNames.count(candidate)) + { + usedNames.insert(candidate); + n = candidate; + return candidate; + } + } + + return generateName(s); + } + + void emit(const std::string& s) + { + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + return; + + result.name += s; + } +}; + +struct TypeVarStringifier +{ + StringifierState& state; + + explicit TypeVarStringifier(StringifierState& state) + : state(state) + { + } + + void stringify(TypeId tv) + { + if (state.opts.maxTypeLength > 0 && state.result.name.length() > state.opts.maxTypeLength) + return; + + if (tv->ty.valueless_by_exception()) + { + state.result.error = true; + state.emit("< VALUELESS BY EXCEPTION >"); + return; + } + + auto it = state.cycleNames.find(tv); + if (it != state.cycleNames.end()) + { + state.emit(it->second); + return; + } + + if (!FFlag::LuauAddMissingFollow) + { + if (get(tv)) + { + state.emit(state.getName(tv)); + return; + } + } + + Luau::visit( + [this, tv](auto&& t) { + return (*this)(tv, t); + }, + tv->ty); + } + + void stringify(TypePackId tp); + void stringify(TypePackId tpid, const std::vector>& names); + + void stringify(const std::vector& types) + { + if (types.size() == 0) + return; + + if (types.size()) + state.emit("<"); + + for (size_t i = 0; i < types.size(); ++i) + { + if (i > 0) + state.emit(", "); + + stringify(types[i]); + } + + if (types.size()) + state.emit(">"); + } + + void operator()(TypeId ty, const Unifiable::Free& ftv) + { + state.result.invalid = true; + + if (FFlag::LuauAddMissingFollow) + state.emit(state.getName(ty)); + else + state.emit(""); + } + + void operator()(TypeId, const BoundTypeVar& btv) + { + stringify(btv.boundTo); + } + + void operator()(TypeId ty, const Unifiable::Generic& gtv) + { + if (gtv.explicitName) + { + state.result.nameMap.typeVars[ty] = gtv.name; + state.emit(gtv.name); + } + else + state.emit(state.getName(ty)); + } + + void operator()(TypeId, const PrimitiveTypeVar& ptv) + { + switch (ptv.type) + { + case PrimitiveTypeVar::NilType: + state.emit("nil"); + return; + case PrimitiveTypeVar::Boolean: + state.emit("boolean"); + return; + case PrimitiveTypeVar::Number: + state.emit("number"); + return; + case PrimitiveTypeVar::String: + state.emit("string"); + return; + case PrimitiveTypeVar::Thread: + state.emit("thread"); + return; + default: + LUAU_ASSERT(!"Unknown primitive type"); + throw std::runtime_error("Unknown primitive type " + std::to_string(ptv.type)); + } + } + + void operator()(TypeId, const FunctionTypeVar& ftv) + { + if (state.hasSeen(&ftv)) + { + state.result.cycle = true; + state.emit(""); + return; + } + + if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) + { + state.emit("<"); + bool comma = false; + for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) + { + if (comma) + state.emit(", "); + comma = true; + stringify(*it); + } + for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) + { + if (comma) + state.emit(", "); + comma = true; + stringify(*it); + } + state.emit(">"); + } + + state.emit("("); + + if (state.opts.functionTypeArguments) + stringify(ftv.argTypes, ftv.argNames); + else + stringify(ftv.argTypes); + + state.emit(") -> "); + + bool plural = true; + if (auto retPack = get(follow(ftv.retType))) + { + if (retPack->head.size() == 1 && !retPack->tail) + plural = false; + } + + if (plural) + state.emit("("); + + stringify(ftv.retType); + + if (plural) + state.emit(")"); + + state.unsee(&ftv); + } + + void operator()(TypeId, const TableTypeVar& ttv) + { + if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo) + return stringify(*ttv.boundTo); + + if (!state.exhaustive) + { + if (ttv.name) + { + // If scope if provided, add module name and check visibility + if (state.opts.scope) + { + auto [success, moduleName] = canUseTypeNameInScope(state.opts.scope, *ttv.name); + + if (!success) + state.result.invalid = true; + + if (moduleName) + { + state.emit(*moduleName); + state.emit("."); + } + } + + state.emit(*ttv.name); + stringify(ttv.instantiatedTypeParams); + return; + } + if (ttv.syntheticName) + { + state.result.invalid = true; + state.emit(*ttv.syntheticName); + stringify(ttv.instantiatedTypeParams); + return; + } + } + + if (state.hasSeen(&ttv)) + { + state.result.cycle = true; + state.emit(""); + return; + } + + std::string openbrace = "@@@"; + std::string closedbrace = "@@@?!"; + switch (state.opts.hideTableKind ? TableState::Unsealed : ttv.state) + { + case TableState::Sealed: + state.result.invalid = true; + openbrace = "{| "; + closedbrace = " |}"; + break; + case TableState::Unsealed: + openbrace = "{ "; + closedbrace = " }"; + break; + case TableState::Free: + state.result.invalid = true; + openbrace = "{- "; + closedbrace = " -}"; + break; + case TableState::Generic: + state.result.invalid = true; + openbrace = "{+ "; + closedbrace = " +}"; + break; + } + + // If this appears to be an array, we want to stringify it using the {T} syntax. + if (ttv.indexer && ttv.props.empty() && isNumber(ttv.indexer->indexType)) + { + state.emit("{"); + stringify(ttv.indexer->indexResultType); + state.emit("}"); + return; + } + + state.emit(openbrace); + + bool comma = false; + if (ttv.indexer) + { + state.emit("["); + stringify(ttv.indexer->indexType); + state.emit("]: "); + stringify(ttv.indexer->indexResultType); + comma = true; + } + + size_t index = 0; + size_t oldLength = state.result.name.length(); + for (const auto& [name, prop] : ttv.props) + { + if (comma) + state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + + size_t length = state.result.name.length() - oldLength; + + if (state.opts.maxTableLength > 0 && (length - 2 * index) >= state.opts.maxTableLength) + { + state.emit("... "); + state.emit(std::to_string(ttv.props.size() - index)); + state.emit(" more ..."); + break; + } + + state.emit(name); + state.emit(": "); + stringify(prop.type); + comma = true; + ++index; + } + + state.emit(closedbrace); + + state.unsee(&ttv); + } + + void operator()(TypeId, const MetatableTypeVar& mtv) + { + state.result.invalid = true; + state.emit("{ @metatable "); + stringify(mtv.metatable); + state.emit(state.opts.useLineBreaks ? ",\n" : ", "); + stringify(mtv.table); + state.emit(" }"); + } + + void operator()(TypeId, const ClassTypeVar& ctv) + { + state.emit(ctv.name); + } + + void operator()(TypeId, const AnyTypeVar&) + { + state.emit("any"); + } + + void operator()(TypeId, const UnionTypeVar& uv) + { + if (state.hasSeen(&uv)) + { + state.result.cycle = true; + state.emit(""); + return; + } + + bool optional = false; + + std::vector results = {}; + for (auto el : &uv) + { + if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) + el = follow(el); + + if (isNil(el)) + { + optional = true; + continue; + } + + std::string saved = std::move(state.result.name); + + bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions + ? !state.cycleNames.count(el) && (get(el) || get(el)) + : get(el) || get(el); + + if (needParens) + state.emit("("); + + stringify(el); + + if (needParens) + state.emit(")"); + + results.push_back(std::move(state.result.name)); + state.result.name = std::move(saved); + } + + state.unsee(&uv); + + std::sort(results.begin(), results.end()); + + if (optional && results.size() > 1) + state.emit("("); + + bool first = true; + for (std::string& ss : results) + { + if (!first) + state.emit(" | "); + state.emit(ss); + first = false; + } + + if (optional) + { + const char* s = "?"; + if (results.size() > 1) + s = ")?"; + + state.emit(s); + } + } + + void operator()(TypeId, const IntersectionTypeVar& uv) + { + if (state.hasSeen(&uv)) + { + state.result.cycle = true; + state.emit(""); + return; + } + + std::vector results = {}; + for (auto el : uv.parts) + { + if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) + el = follow(el); + + std::string saved = std::move(state.result.name); + + bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions + ? !state.cycleNames.count(el) && (get(el) || get(el)) + : get(el) || get(el); + + if (needParens) + state.emit("("); + + stringify(el); + + if (needParens) + state.emit(")"); + + results.push_back(std::move(state.result.name)); + state.result.name = std::move(saved); + } + + state.unsee(&uv); + + std::sort(results.begin(), results.end()); + + bool first = true; + for (std::string& ss : results) + { + if (!first) + state.emit(" & "); + state.emit(ss); + first = false; + } + } + + void operator()(TypeId, const ErrorTypeVar& tv) + { + state.result.error = true; + state.emit("*unknown*"); + } + + void operator()(TypeId, const LazyTypeVar& ltv) + { + state.result.invalid = true; + state.emit("lazy?"); + } + +}; // namespace + +struct TypePackStringifier +{ + StringifierState& state; + + const std::vector> elemNames; + static inline const std::vector> dummyElemNames = {}; + unsigned elemIndex = 0; + + explicit TypePackStringifier(StringifierState& state, const std::vector>& elemNames) + : state(state) + , elemNames(elemNames) + { + } + + explicit TypePackStringifier(StringifierState& state) + : state(state) + , elemNames(dummyElemNames) + { + } + + void stringify(TypeId tv) + { + TypeVarStringifier tvs{state}; + tvs.stringify(tv); + } + + void stringify(TypePackId tp) + { + if (state.opts.maxTypeLength > 0 && state.result.name.length() > state.opts.maxTypeLength) + return; + + if (tp->ty.valueless_by_exception()) + { + state.result.error = true; + state.emit("< VALUELESS TP BY EXCEPTION >"); + return; + } + + if (!FFlag::LuauAddMissingFollow) + { + if (get(tp)) + { + state.emit(state.getName(tp)); + state.emit("..."); + return; + } + } + + auto it = state.cycleTpNames.find(tp); + if (it != state.cycleTpNames.end()) + { + state.emit(it->second); + return; + } + + Luau::visit( + [this, tp](auto&& t) { + return (*this)(tp, t); + }, + tp->ty); + } + + void operator()(TypePackId, const TypePack& tp) + { + if (state.hasSeen(&tp)) + { + state.result.cycle = true; + state.emit(""); + return; + } + + bool first = true; + + for (const auto& typeId : tp.head) + { + if (first) + first = false; + else + state.emit(", "); + + LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); + + if (!elemNames.empty() && elemNames[elemIndex]) + { + state.emit(elemNames[elemIndex]->name); + state.emit(": "); + } + elemIndex++; + + stringify(typeId); + } + + if (tp.tail && !isEmpty(*tp.tail)) + { + const auto& tail = *tp.tail; + if (first) + first = false; + else + state.emit(", "); + + stringify(tail); + } + + state.unsee(&tp); + } + + void operator()(TypePackId, const Unifiable::Error& error) + { + state.result.error = true; + state.emit("*unknown*"); + } + + void operator()(TypePackId, const VariadicTypePack& pack) + { + state.emit("..."); + stringify(pack.ty); + } + + void operator()(TypePackId tp, const GenericTypePack& pack) + { + if (pack.explicitName) + { + state.result.nameMap.typePacks[tp] = pack.name; + state.emit(pack.name); + } + else + { + state.emit(state.getName(tp)); + } + state.emit("..."); + } + + void operator()(TypePackId tp, const FreeTypePack& pack) + { + state.result.invalid = true; + + if (FFlag::LuauAddMissingFollow) + { + state.emit(state.getName(tp)); + state.emit("..."); + } + else + { + state.emit(""); + } + } + + void operator()(TypePackId, const BoundTypePack& btv) + { + stringify(btv.boundTo); + } +}; + +void TypeVarStringifier::stringify(TypePackId tp) +{ + TypePackStringifier tps(state); + tps.stringify(tp); +} + +void TypeVarStringifier::stringify(TypePackId tpid, const std::vector>& names) +{ + TypePackStringifier tps(state, names); + tps.stringify(tpid); +} + +static void assignCycleNames(const std::unordered_set& cycles, const std::unordered_set& cycleTPs, + std::unordered_map& cycleNames, std::unordered_map& cycleTpNames, bool exhaustive) +{ + int nextIndex = 1; + + std::vector sortedCycles{cycles.begin(), cycles.end()}; + std::sort(sortedCycles.begin(), sortedCycles.end(), std::less{}); + + for (TypeId cycleTy : sortedCycles) + { + std::string name; + + // TODO: use the stringified type list if there are no cycles + if (FFlag::LuauInstantiatedTypeParamRecursion) + { + if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) + { + // If we have a cycle type in type parameters, assign a cycle name for this named table + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { + return cycles.count(follow(el)); + }) != ttv->instantiatedTypeParams.end()) + cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; + + continue; + } + } + else + { + if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) + continue; + } + + name = "t" + std::to_string(nextIndex); + ++nextIndex; + + cycleNames[cycleTy] = std::move(name); + } + + std::vector sortedCycleTps{cycleTPs.begin(), cycleTPs.end()}; + std::sort(sortedCycleTps.begin(), sortedCycleTps.end(), std::less()); + + for (TypePackId tp : sortedCycleTps) + { + std::string name = "tp" + std::to_string(nextIndex); + ++nextIndex; + cycleTpNames[tp] = std::move(name); + } +} + +ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) +{ + /* + * 1. Walk the TypeVar and track seen TypeIds. When you reencounter a TypeId, add it to a set of seen cycles. + * 2. Generate some names for each cycle. For a starting point, we can just call them t0, t1 and so on. + * 3. For each seen cycle, stringify it like we do now, but replace each known cycle with its name. + * 4. Print out the root of the type using the same algorithm as step 3. + */ + ty = follow(ty); + + ToStringResult result; + + if (!FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) + { + if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) + { + if (ttv->syntheticName) + result.invalid = true; + + // If scope if provided, add module name and check visibility + if (ttv->name && opts.scope) + { + auto [success, moduleName] = canUseTypeNameInScope(opts.scope, *ttv->name); + + if (!success) + result.invalid = true; + + if (moduleName) + result.name = format("%s.", moduleName->c_str()); + } + + result.name += ttv->name ? *ttv->name : *ttv->syntheticName; + + if (ttv->instantiatedTypeParams.empty()) + return result; + + std::vector params; + for (TypeId tp : ttv->instantiatedTypeParams) + params.push_back(toString(tp)); + + result.name += "<" + join(params, ", ") + ">"; + return result; + } + else if (auto mtv = get(ty); mtv && mtv->syntheticName) + { + result.invalid = true; + result.name = *mtv->syntheticName; + return result; + } + } + + StringifierState state{opts, result, opts.nameMap}; + + std::unordered_set cycles; + std::unordered_set cycleTPs; + + findCyclicTypes(cycles, cycleTPs, ty, opts.exhaustive); + + assignCycleNames(cycles, cycleTPs, state.cycleNames, state.cycleTpNames, opts.exhaustive); + + TypeVarStringifier tvs{state}; + + if (FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) + { + if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) + { + if (ttv->syntheticName) + result.invalid = true; + + // If scope if provided, add module name and check visibility + if (ttv->name && opts.scope) + { + auto [success, moduleName] = canUseTypeNameInScope(opts.scope, *ttv->name); + + if (!success) + result.invalid = true; + + if (moduleName) + result.name = format("%s.", moduleName->c_str()); + } + + result.name += ttv->name ? *ttv->name : *ttv->syntheticName; + + if (ttv->instantiatedTypeParams.empty()) + return result; + + result.name += "<"; + + bool first = true; + for (TypeId ty : ttv->instantiatedTypeParams) + { + if (!first) + result.name += ", "; + else + first = false; + + tvs.stringify(ty); + } + + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + { + result.truncated = true; + result.name += "... "; + } + else + { + result.name += ">"; + } + + return result; + } + else if (auto mtv = get(ty); mtv && mtv->syntheticName) + { + result.invalid = true; + result.name = *mtv->syntheticName; + return result; + } + } + + /* If the root itself is a cycle, we special case a little. + * We go out of our way to print the following: + * + * t1 where t1 = the_whole_root_type + */ + auto it = state.cycleNames.find(ty); + if (it != state.cycleNames.end()) + state.emit(it->second); + else + tvs.stringify(ty); + + if (!state.cycleNames.empty()) + { + result.cycle = true; + state.emit(" where "); + } + + state.exhaustive = true; + + std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); + + bool semi = false; + for (const auto& [cycleTy, name] : sortedCycleNames) + { + if (semi) + state.emit(" ; "); + + state.emit(name); + state.emit(" = "); + Luau::visit( + [&tvs, cycleTy = cycleTy](auto&& t) { + return tvs(cycleTy, t); + }, + cycleTy->ty); + + semi = true; + } + + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + { + result.truncated = true; + result.name += "... "; + } + + return result; +} + +ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) +{ + /* + * 1. Walk the TypeVar and track seen TypeIds. When you reencounter a TypeId, add it to a set of seen cycles. + * 2. Generate some names for each cycle. For a starting point, we can just call them t0, t1 and so on. + * 3. For each seen cycle, stringify it like we do now, but replace each known cycle with its name. + * 4. Print out the root of the type using the same algorithm as step 3. + */ + ToStringResult result; + StringifierState state{opts, result, opts.nameMap}; + + std::unordered_set cycles; + std::unordered_set cycleTPs; + + findCyclicTypes(cycles, cycleTPs, tp, opts.exhaustive); + + assignCycleNames(cycles, cycleTPs, state.cycleNames, state.cycleTpNames, opts.exhaustive); + + TypeVarStringifier tvs{state}; + + /* If the root itself is a cycle, we special case a little. + * We go out of our way to print the following: + * + * t1 where t1 = the_whole_root_type + */ + auto it = state.cycleTpNames.find(tp); + if (it != state.cycleTpNames.end()) + state.emit(it->second); + else + tvs.stringify(tp); + + if (!cycles.empty()) + { + result.cycle = true; + state.emit(" where "); + } + + state.exhaustive = true; + + std::vector> sortedCycleNames{state.cycleNames.begin(), state.cycleNames.end()}; + std::sort(sortedCycleNames.begin(), sortedCycleNames.end(), [](const auto& a, const auto& b) { + return a.second < b.second; + }); + + bool semi = false; + for (const auto& [cycleTy, name] : sortedCycleNames) + { + if (semi) + state.emit(" ; "); + + state.emit(name); + state.emit(" = "); + Luau::visit( + [&tvs, cycleTy = cycleTy](auto&& t) { + return tvs(cycleTy, t); + }, + cycleTy->ty); + + semi = true; + } + + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + result.name += "... "; + + return result; +} + +std::string toString(TypeId ty, const ToStringOptions& opts) +{ + return toStringDetailed(ty, opts).name; +} + +std::string toString(TypePackId tp, const ToStringOptions& opts) +{ + return toStringDetailed(tp, opts).name; +} + +std::string toString(const TypeVar& tv, const ToStringOptions& opts) +{ + return toString(const_cast(&tv), std::move(opts)); +} + +std::string toString(const TypePackVar& tp, const ToStringOptions& opts) +{ + return toString(const_cast(&tp), std::move(opts)); +} + +void dump(TypeId ty) +{ + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + printf("%s\n", toString(ty, opts).c_str()); +} + +void dump(TypePackId ty) +{ + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + printf("%s\n", toString(ty, opts).c_str()); +} + +} // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp new file mode 100644 index 0000000..2d35638 --- /dev/null +++ b/Analysis/src/TopoSortStatements.cpp @@ -0,0 +1,552 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TopoSortStatements.h" + +/* Decide the order in which we typecheck Lua statements in a block. + * + * Algorithm: + * + * 1. Build up a dependency graph. + * i. An AstStat is said to depend on another AstStat if it refers to it in any child node. + * A dependency is the relationship between the declaration of a symbol and its uses. + * ii. Additionally, statements that do not define functions have a dependency on the previous non-function statement. We do this + * to prevent the algorithm from checking imperative statements out-of-order. + * 2. Walk each node in the graph in lexical order. For each node: + * i. Select the next thing `t` + * ii. If `t` has no dependencies at all and is not a function definition, check it now + * iii. If `t` is a function definition or an expression that does not include a function call, add it to a queue `Q`. + * iv. Else, toposort `Q` and check things until it is possible to check `t` + * * If this fails, we expect the Lua runtime to also fail, as the code is trying to use a symbol before it has been defined. + * 3. Toposort whatever remains in `Q` and check it all. + * + * The end result that we want satisfies a few qualities: + * + * 1. Things are generally checked in lexical order. + * 2. If a function F calls another function G that is declared out-of-order, but in a way that will work when the code is actually run, we want + * to check G before F. + * 3. Cyclic dependencies can be resolved by picking an arbitrary statement to check first. + */ + +#include "Luau/Parser.h" +#include "Luau/DenseHash.h" +#include "Luau/Common.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Luau +{ + +// For some reason, natvis interacts really poorly with anonymous data types +namespace detail +{ + +struct Identifier +{ + std::string name; // A nice textual name + const AstLocal* ctx; // Only used to disambiguate potentially shadowed names +}; + +bool operator==(const Identifier& lhs, const Identifier& rhs) +{ + return lhs.name == rhs.name && lhs.ctx == rhs.ctx; +} + +struct IdentifierHash +{ + size_t operator()(const Identifier& ident) const + { + return std::hash()(ident.name) ^ std::hash()(ident.ctx); + } +}; + +struct Node; + +struct Arcs +{ + std::set provides; + std::set depends; +}; + +struct Node : Arcs +{ + std::optional name; + AstStat* element; + + Node(const std::optional& name, AstStat* el) + : name(name) + , element(el) + { + } +}; + +using NodeQueue = std::deque>; +using NodeList = std::list>; + +std::optional mkName(const AstExpr& expr); + +Identifier mkName(const AstLocal& local) +{ + return {local.name.value, &local}; +} + +Identifier mkName(const AstExprLocal& local) +{ + return mkName(*local.local); +} + +Identifier mkName(const AstExprGlobal& global) +{ + return {global.name.value, nullptr}; +} + +Identifier mkName(const AstName& name) +{ + return {name.value, nullptr}; +} + +std::optional mkName(const AstExprIndexName& expr) +{ + auto lhs = mkName(*expr.expr); + if (lhs) + { + std::string s = std::move(lhs->name); + s += "."; + s += expr.index.value; + return Identifier{std::move(s), lhs->ctx}; + } + else + return std::nullopt; +} + +Identifier mkName(const AstExprError& expr) +{ + return {format("error#%d", expr.messageIndex), nullptr}; +} + +std::optional mkName(const AstExpr& expr) +{ + if (auto l = expr.as()) + return mkName(*l); + else if (auto g = expr.as()) + return mkName(*g); + else if (auto i = expr.as()) + return mkName(*i); + else if (auto e = expr.as()) + return mkName(*e); + + return std::nullopt; +} + +Identifier mkName(const AstStatFunction& function) +{ + auto name = mkName(*function.name); + LUAU_ASSERT(bool(name)); + if (!name) + throw std::runtime_error("Internal error: Function declaration has a bad name"); + + return *name; +} + +Identifier mkName(const AstStatLocalFunction& function) +{ + return mkName(*function.name); +} + +std::optional mkName(const AstStatAssign& assign) +{ + if (assign.vars.size != 1) + return std::nullopt; + + return mkName(*assign.vars.data[0]); +} + +std::optional mkName(const AstStatLocal& local) +{ + if (local.vars.size != 1) + return std::nullopt; + + return mkName(*local.vars.data[0]); +} + +Identifier mkName(const AstStatTypeAlias& typealias) +{ + return mkName(typealias.name); +} + +std::optional mkName(AstStat* const el) +{ + if (auto function = el->as()) + return mkName(*function); + else if (auto function = el->as()) + return mkName(*function); + else if (auto assign = el->as()) + return mkName(*assign); + else if (auto local = el->as()) + return mkName(*local); + else if (auto typealias = el->as()) + return mkName(*typealias); + + return std::nullopt; +} + +struct ArcCollector : public AstVisitor +{ + NodeQueue& queue; + DenseHashMap map; + + Node* currentArc; + + ArcCollector(NodeQueue& queue) + : queue(queue) + , map(Identifier{std::string{}, 0}) + , currentArc(nullptr) + { + for (const auto& node : queue) + { + if (node->name && !map.contains(*node->name)) + map[*node->name] = node.get(); + } + } + + void add(const Identifier& name) + { + Node** it = map.find(name); + if (it == nullptr) + return; + + Node* n = *it; + + if (n == currentArc) + return; + + n->provides.insert(currentArc); + currentArc->depends.insert(n); + } + + bool visit(AstExprGlobal* node) override + { + add(mkName(*node)); + return true; + } + + bool visit(AstExprLocal* node) override + { + add(mkName(*node)); + return true; + } + + bool visit(AstExprIndexName* node) override + { + auto name = mkName(*node); + if (name) + add(*name); + return true; + } + + bool visit(AstStatFunction* node) override + { + auto name = mkName(*node->name); + if (!name) + throw std::runtime_error("Internal error: AstStatFunction has a bad name"); + + add(*name); + return true; + } + + bool visit(AstStatLocalFunction* node) override + { + add(mkName(*node->name)); + return true; + } + + bool visit(AstStatAssign* node) override + { + return true; + } + + bool visit(AstStatTypeAlias* node) override + { + add(mkName(*node)); + return true; + } + + bool visit(AstType* node) override + { + return true; + } + + bool visit(AstTypeReference* node) override + { + add(mkName(node->name)); + return true; + } + + bool visit(AstTypeTypeof* node) override + { + std::optional name = mkName(*node->expr); + if (name) + add(*name); + return true; + } +}; + +struct ContainsFunctionCall : public AstVisitor +{ + bool result = false; + + bool visit(AstExpr*) override + { + return !result; // short circuit if result is true + } + + bool visit(AstExprCall*) override + { + result = true; + return false; + } + + bool visit(AstStatForIn*) override + { + // for in loops perform an implicit function call as part of the iterator protocol + result = true; + return false; + } + + bool visit(AstExprFunction*) override + { + return false; + } + bool visit(AstStatFunction*) override + { + return false; + } + bool visit(AstStatLocalFunction*) override + { + return false; + } + + bool visit(AstType* ta) override + { + return true; + } +}; + +bool isToposortableNode(const AstStat& stat) +{ + return isFunction(stat) || stat.is(); +} + +bool containsToposortableNode(const std::vector& block) +{ + for (AstStat* stat : block) + if (isToposortableNode(*stat)) + return true; + + return false; +} + +bool isBlockTerminator(const AstStat& stat) +{ + return stat.is() || stat.is() || stat.is(); +} + +// Clip arcs to and from the node +void prune(Node* next) +{ + for (const auto& node : next->provides) + { + auto it = node->depends.find(next); + LUAU_ASSERT(it != node->depends.end()); + node->depends.erase(it); + } + + for (const auto& node : next->depends) + { + auto it = node->provides.find(next); + LUAU_ASSERT(it != node->provides.end()); + node->provides.erase(it); + } +} + +// Drain Q until the target's depends arcs are satisfied. target is always added to the result. +void drain(NodeList& Q, std::vector& result, Node* target) +{ + // Trying to toposort a subgraph is a pretty big hassle. :( + // Some of the nodes in .depends and .provides aren't present in our subgraph + + std::map allArcs; + + for (auto& node : Q) + { + // Copy the connectivity information but filter out any provides or depends arcs that are not in Q + Arcs& arcs = allArcs[node.get()]; + + DenseHashSet elements{nullptr}; + for (const auto& q : Q) + elements.insert(q.get()); + + for (Node* node : node->depends) + { + if (elements.contains(node)) + arcs.depends.insert(node); + } + for (Node* node : node->provides) + { + if (elements.contains(node)) + arcs.provides.insert(node); + } + } + + while (!Q.empty()) + { + if (target && target->depends.empty()) + { + prune(target); + result.push_back(target->element); + return; + } + + std::unique_ptr nextNode; + + for (auto iter = Q.begin(); iter != Q.end(); ++iter) + { + if (isBlockTerminator(*iter->get()->element)) + continue; + + LUAU_ASSERT(allArcs.end() != allArcs.find(iter->get())); + const Arcs& arcs = allArcs[iter->get()]; + + if (arcs.depends.empty()) + { + nextNode = std::move(*iter); + Q.erase(iter); + break; + } + } + + if (!nextNode) + { + // We've hit a cycle or a terminator. Pick an arbitrary node. + nextNode = std::move(Q.front()); + Q.pop_front(); + } + + for (const auto& node : nextNode->provides) + { + auto it = allArcs.find(node); + if (allArcs.end() != it) + { + auto i2 = it->second.depends.find(nextNode.get()); + LUAU_ASSERT(i2 != it->second.depends.end()); + it->second.depends.erase(i2); + } + } + + for (const auto& node : nextNode->depends) + { + auto it = allArcs.find(node); + if (allArcs.end() != it) + { + auto i2 = it->second.provides.find(nextNode.get()); + LUAU_ASSERT(i2 != it->second.provides.end()); + it->second.provides.erase(i2); + } + } + + prune(nextNode.get()); + result.push_back(nextNode->element); + } + + if (target) + { + prune(target); + result.push_back(target->element); + } +} + +} // namespace detail + +bool containsFunctionCall(const AstStat& stat) +{ + detail::ContainsFunctionCall cfc; + const_cast(stat).visit(&cfc); + return cfc.result; +} + +bool isFunction(const AstStat& stat) +{ + return stat.is() || stat.is(); +} + +void toposort(std::vector& stats) +{ + using namespace detail; + + if (stats.empty()) + return; + + if (!containsToposortableNode(stats)) + return; + + std::vector result; + result.reserve(stats.size()); + + NodeQueue nodes; + NodeList Q; + + for (AstStat* stat : stats) + nodes.push_back(std::unique_ptr(new Node(mkName(stat), stat))); + + ArcCollector collector{nodes}; + + for (const auto& node : nodes) + { + collector.currentArc = node.get(); + node->element->visit(&collector); + } + + { + auto it = nodes.begin(); + auto prev = it; + + while (it != nodes.end()) + { + if (it != prev && !isToposortableNode(*(*it)->element)) + { + (*it)->depends.insert(prev->get()); + (*prev)->provides.insert(it->get()); + prev = it; + } + ++it; + } + } + + while (!nodes.empty()) + { + Node* next = nodes.front().get(); + + if (next->depends.empty() && !isBlockTerminator(*next->element)) + { + prune(next); + result.push_back(next->element); + } + else if (!containsFunctionCall(*next->element)) + Q.push_back(std::move(nodes.front())); + else + drain(Q, result, next); + + nodes.pop_front(); + } + + drain(Q, result, nullptr); + + std::swap(stats, result); +} + +} // namespace Luau diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp new file mode 100644 index 0000000..462c70f --- /dev/null +++ b/Analysis/src/Transpiler.cpp @@ -0,0 +1,1156 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Transpiler.h" + +#include "Luau/Parser.h" +#include "Luau/StringUtils.h" +#include "Luau/Common.h" + +#include +#include +#include +#include + +LUAU_FASTFLAG(LuauGenericFunctions) + +namespace +{ + +std::string escape(std::string_view s) +{ + std::string r; + r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting + + for (uint8_t c : s) + { + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') + r += c; + else + { + r += '\\'; + + switch (c) + { + case '\a': + r += 'a'; + break; + case '\b': + r += 'b'; + break; + case '\f': + r += 'f'; + break; + case '\n': + r += 'n'; + break; + case '\r': + r += 'r'; + break; + case '\t': + r += 't'; + break; + case '\v': + r += 'v'; + break; + case '\'': + r += '\''; + break; + case '\"': + r += '\"'; + break; + case '\\': + r += '\\'; + break; + default: + Luau::formatAppend(r, "%03u", c); + } + } + } + + return r; +} + +bool isIdentifierStartChar(char c) +{ + return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; +} + +bool isDigit(char c) +{ + return c >= '0' && c <= '9'; +} + +bool isIdentifierChar(char c) +{ + return isIdentifierStartChar(c) || isDigit(c); +} + +const std::vector keywords = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", + "not", "or", "repeat", "return", "then", "true", "until", "while"}; + +} // namespace + +namespace Luau +{ + +struct Writer +{ + virtual ~Writer() {} + + virtual void begin() {} + virtual void end() {} + + virtual void advance(const Position&) = 0; + virtual void newline() = 0; + virtual void space() = 0; + virtual void maybeSpace(const Position& newPos, int reserve) = 0; + virtual void write(std::string_view) = 0; + virtual void identifier(std::string_view name) = 0; + virtual void keyword(std::string_view) = 0; + virtual void symbol(std::string_view) = 0; + virtual void literal(std::string_view) = 0; + virtual void string(std::string_view) = 0; +}; + +struct StringWriter : Writer +{ + std::string ss; + Position pos{0, 0}; + char lastChar = '\0'; // used to determine whether we need to inject an extra space to preserve grammatical correctness. + + const std::string& str() const + { + return ss; + } + + void advance(const Position& newPos) override + { + while (pos.line < newPos.line) + newline(); + + if (pos.column < newPos.column) + write(std::string(newPos.column - pos.column, ' ')); + } + void maybeSpace(const Position& newPos, int reserve) override + { + if (pos.column + reserve < newPos.column) + space(); + } + + void newline() override + { + ss += '\n'; + pos.column = 0; + ++pos.line; + lastChar = '\n'; + } + + void space() override + { + ss += ' '; + ++pos.column; + lastChar = ' '; + } + + void write(std::string_view s) override + { + if (s.empty()) + return; + + ss.append(s.data(), s.size()); + pos.column += unsigned(s.size()); + lastChar = s[s.size() - 1]; + } + + void write(char c) + { + ss += c; + pos.column += 1; + lastChar = c; + } + + void identifier(std::string_view s) override + { + if (s.empty()) + return; + + if (isIdentifierChar(lastChar)) + space(); + + write(s); + } + + void keyword(std::string_view s) override + { + if (s.empty()) + return; + + if (isIdentifierChar(lastChar)) + space(); + + write(s); + } + + void symbol(std::string_view s) override + { + if (isDigit(lastChar) && s[0] == '.') + space(); + + write(s); + } + + void literal(std::string_view s) override + { + if (s.empty()) + return; + + else if (isIdentifierChar(lastChar) && isDigit(s[0])) + space(); + + write(s); + } + + void string(std::string_view s) override + { + char quote = '\''; + if (std::string::npos != s.find(quote)) + quote = '\"'; + + write(quote); + write(escape(s)); + write(quote); + } +}; + +class CommaSeparatorInserter +{ +public: + CommaSeparatorInserter(Writer& w) + : first(true) + , writer(w) + { + } + void operator()() + { + if (first) + first = !first; + else + writer.symbol(","); + } + +private: + bool first; + Writer& writer; +}; + +struct Printer +{ + explicit Printer(Writer& writer) + : writer(writer) + { + } + + bool writeTypes = false; + Writer& writer; + + void visualize(const AstLocal& local) + { + advance(local.location.begin); + + writer.identifier(local.name.value); + if (writeTypes && local.annotation) + { + writer.symbol(":"); + visualizeTypeAnnotation(*local.annotation); + } + } + + void visualizeWithSelf(AstExpr& expr, bool self) + { + if (!self) + return visualize(expr); + + AstExprIndexName* func = expr.as(); + LUAU_ASSERT(func); + + visualize(*func->expr); + writer.symbol(":"); + advance(func->indexLocation.begin); + writer.identifier(func->index.value); + } + + void visualizeTypePackAnnotation(const AstTypePack& annotation) + { + if (const AstTypePackVariadic* variadic = annotation.as()) + { + writer.symbol("..."); + visualizeTypeAnnotation(*variadic->variadicType); + } + else + { + LUAU_ASSERT(!"Unknown TypePackAnnotation kind"); + } + } + + void visualizeTypeList(const AstTypeList& list, bool unconditionallyParenthesize) + { + size_t typeCount = list.types.size + (list.tailType != nullptr ? 1 : 0); + if (typeCount == 0) + { + writer.symbol("("); + writer.symbol(")"); + } + else if (typeCount == 1) + { + if (unconditionallyParenthesize) + writer.symbol("("); + + // Only variadic tail + if (list.types.size == 0) + { + visualizeTypePackAnnotation(*list.tailType); + } + else + { + visualizeTypeAnnotation(*list.types.data[0]); + } + + if (unconditionallyParenthesize) + writer.symbol(")"); + } + else + { + writer.symbol("("); + + bool first = true; + for (const auto& el : list.types) + { + if (first) + first = false; + else + writer.symbol(","); + + visualizeTypeAnnotation(*el); + } + + if (list.tailType) + { + writer.symbol(","); + visualizeTypePackAnnotation(*list.tailType); + } + + writer.symbol(")"); + } + } + + bool isIntegerish(double d) + { + if (d <= std::numeric_limits::max() && d >= std::numeric_limits::min()) + return double(int(d)) == d && !(d == 0.0 && signbit(d)); + else + return false; + } + + void visualize(AstExpr& expr) + { + advance(expr.location.begin); + + if (const auto& a = expr.as()) + { + writer.symbol("("); + visualize(*a->expr); + writer.symbol(")"); + } + else if (expr.is()) + { + writer.keyword("nil"); + } + else if (const auto& a = expr.as()) + { + if (a->value) + writer.keyword("true"); + else + writer.keyword("false"); + } + else if (const auto& a = expr.as()) + { + if (isinf(a->value)) + { + if (a->value > 0) + writer.literal("1e500"); + else + writer.literal("-1e500"); + } + else if (isnan(a->value)) + writer.literal("0/0"); + else + { + if (isIntegerish(a->value)) + writer.literal(std::to_string(int(a->value))); + else + { + char buffer[100]; + size_t len = snprintf(buffer, sizeof(buffer), "%.17g", a->value); + writer.literal(std::string_view{buffer, len}); + } + } + } + else if (const auto& a = expr.as()) + { + writer.string(std::string_view(a->value.data, a->value.size)); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->local->name.value); + } + else if (const auto& a = expr.as()) + { + writer.identifier(a->name.value); + } + else if (expr.is()) + { + writer.symbol("..."); + } + else if (const auto& a = expr.as()) + { + visualizeWithSelf(*a->func, a->self); + writer.symbol("("); + + bool first = true; + for (const auto& arg : a->args) + { + if (first) + first = false; + else + writer.symbol(","); + + visualize(*arg); + } + + writer.symbol(")"); + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + writer.symbol("."); + writer.write(a->index.value); + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + writer.symbol("["); + visualize(*a->index); + writer.symbol("]"); + } + else if (const auto& a = expr.as()) + { + writer.keyword("function"); + visualizeFunctionBody(*a); + } + else if (const auto& a = expr.as()) + { + writer.symbol("{"); + + bool first = true; + + for (const auto& item : a->items) + { + if (first) + first = false; + else + writer.symbol(","); + + switch (item.kind) + { + case AstExprTable::Item::List: + break; + + case AstExprTable::Item::Record: + { + const auto& value = item.key->as()->value; + advance(item.key->location.begin); + writer.identifier(std::string_view(value.data, value.size)); + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + case AstExprTable::Item::General: + { + writer.symbol("["); + visualize(*item.key); + writer.symbol("]"); + writer.maybeSpace(item.value->location.begin, 1); + writer.symbol("="); + } + break; + + default: + LUAU_ASSERT(!"Unknown table item kind"); + } + + advance(item.value->location.begin); + visualize(*item.value); + } + + Position endPos = expr.location.end; + if (endPos.column > 0) + --endPos.column; + + advance(endPos); + + writer.symbol("}"); + advance(expr.location.end); + } + else if (const auto& a = expr.as()) + { + switch (a->op) + { + case AstExprUnary::Not: + writer.keyword("not"); + break; + case AstExprUnary::Minus: + writer.symbol("-"); + break; + case AstExprUnary::Len: + writer.symbol("#"); + break; + } + visualize(*a->expr); + } + else if (const auto& a = expr.as()) + { + visualize(*a->left); + + switch (a->op) + { + case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGt: + writer.maybeSpace(a->right->location.begin, 2); + break; + case AstExprBinary::Concat: + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGe: + case AstExprBinary::Or: + writer.maybeSpace(a->right->location.begin, 3); + break; + case AstExprBinary::And: + writer.maybeSpace(a->right->location.begin, 4); + break; + } + + writer.symbol(toString(a->op)); + + visualize(*a->right); + } + else if (const auto& a = expr.as()) + { + visualize(*a->expr); + } + else if (const auto& a = expr.as()) + { + writer.symbol("(error-expr"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstExpr"); + } + } + + void writeEnd(const Location& loc) + { + Position endPos = loc.end; + if (endPos.column >= 3) + endPos.column -= 3; + advance(endPos); + writer.keyword("end"); + } + + void advance(const Position& newPos) + { + writer.advance(newPos); + } + + void visualize(AstStat& program) + { + advance(program.location.begin); + + if (const auto& block = program.as()) + { + writer.keyword("do"); + for (const auto& s : block->body) + visualize(*s); + writer.advance(block->location.end); + writeEnd(program.location); + } + else if (const auto& a = program.as()) + { + writer.keyword("if"); + visualizeElseIf(*a); + } + else if (const auto& a = program.as()) + { + writer.keyword("while"); + visualize(*a->condition); + writer.keyword("do"); + visualizeBlock(*a->body); + writeEnd(program.location); + } + else if (const auto& a = program.as()) + { + writer.keyword("repeat"); + visualizeBlock(*a->body); + if (a->condition->location.begin.column > 5) + writer.advance(Position{a->condition->location.begin.line, a->condition->location.begin.column - 6}); + writer.keyword("until"); + visualize(*a->condition); + } + else if (program.is()) + writer.keyword("break"); + else if (program.is()) + writer.keyword("continue"); + else if (const auto& a = program.as()) + { + writer.keyword("return"); + + bool first = true; + for (const auto& expr : a->list) + { + if (first) + first = false; + else + writer.symbol(","); + visualize(*expr); + } + } + else if (const auto& a = program.as()) + { + visualize(*a->expr); + } + else if (const auto& a = program.as()) + { + writer.keyword("local"); + + bool first = true; + for (const auto& local : a->vars) + { + if (first) + first = false; + else + writer.write(","); + + visualize(*local); + } + + first = true; + for (const auto& value : a->values) + { + if (first) + { + first = false; + writer.maybeSpace(value->location.begin, 2); + writer.symbol("="); + } + else + writer.symbol(","); + + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + writer.keyword("for"); + + visualize(*a->var); + writer.symbol("="); + visualize(*a->from); + writer.symbol(","); + visualize(*a->to); + if (a->step) + { + writer.symbol(","); + visualize(*a->step); + } + writer.keyword("do"); + visualizeBlock(*a->body); + + writeEnd(program.location); + } + else if (const auto& a = program.as()) + { + writer.keyword("for"); + + bool first = true; + for (const auto& var : a->vars) + { + if (first) + first = false; + else + writer.symbol(","); + + visualize(*var); + } + + writer.keyword("in"); + + first = true; + for (const auto& val : a->values) + { + if (first) + first = false; + else + writer.symbol(","); + + visualize(*val); + } + + writer.keyword("do"); + + visualizeBlock(*a->body); + + writeEnd(program.location); + } + else if (const auto& a = program.as()) + { + bool first = true; + for (const auto& var : a->vars) + { + if (first) + first = false; + else + writer.symbol(","); + visualize(*var); + } + + first = true; + for (const auto& value : a->values) + { + if (first) + { + writer.maybeSpace(value->location.begin, 1); + writer.symbol("="); + first = false; + } + else + writer.symbol(","); + + visualize(*value); + } + } + else if (const auto& a = program.as()) + { + visualize(*a->var); + + switch (a->op) + { + case AstExprBinary::Add: + writer.symbol("+="); + break; + case AstExprBinary::Sub: + writer.symbol("-="); + break; + case AstExprBinary::Mul: + writer.symbol("*="); + break; + case AstExprBinary::Div: + writer.symbol("/="); + break; + case AstExprBinary::Mod: + writer.symbol("%="); + break; + case AstExprBinary::Pow: + writer.symbol("^="); + break; + case AstExprBinary::Concat: + writer.symbol("..="); + break; + default: + LUAU_ASSERT(!"Unexpected compound assignment op"); + } + + visualize(*a->value); + } + else if (const auto& a = program.as()) + { + writer.keyword("function"); + visualizeWithSelf(*a->name, a->func->self != nullptr); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + writer.keyword("local function"); + advance(a->name->location.begin); + writer.identifier(a->name->name.value); + visualizeFunctionBody(*a->func); + } + else if (const auto& a = program.as()) + { + if (writeTypes) + { + if (a->exported) + writer.keyword("export"); + + writer.keyword("type"); + writer.identifier(a->name.value); + if (a->generics.size > 0) + { + writer.symbol("<"); + CommaSeparatorInserter comma(writer); + + for (auto o : a->generics) + { + comma(); + writer.identifier(o.value); + } + writer.symbol(">"); + } + writer.maybeSpace(a->type->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*a->type); + } + } + else if (const auto& a = program.as()) + { + writer.symbol("(error-stat"); + + for (size_t i = 0; i < a->expressions.size; i++) + { + writer.symbol(i == 0 ? ": " : ", "); + visualize(*a->expressions.data[i]); + } + + for (size_t i = 0; i < a->statements.size; i++) + { + writer.symbol(i == 0 && a->expressions.size == 0 ? ": " : ", "); + visualize(*a->statements.data[i]); + } + + writer.symbol(")"); + } + else + { + LUAU_ASSERT(!"Unknown AstStat"); + } + + if (program.hasSemicolon) + writer.symbol(";"); + } + + void visualizeFunctionBody(AstExprFunction& func) + { + if (FFlag::LuauGenericFunctions && (func.generics.size > 0 || func.genericPacks.size > 0)) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : func.generics) + { + comma(); + writer.identifier(o.value); + } + for (const auto& o : func.genericPacks) + { + comma(); + writer.identifier(o.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + writer.symbol("("); + CommaSeparatorInserter comma(writer); + + for (size_t i = 0; i < func.args.size; ++i) + { + AstLocal* local = func.args.data[i]; + + comma(); + + advance(local->location.begin); + writer.identifier(local->name.value); + if (writeTypes && local->annotation) + { + writer.symbol(":"); + visualizeTypeAnnotation(*local->annotation); + } + } + + if (func.vararg) + { + comma(); + writer.symbol("..."); + + if (func.varargAnnotation) + { + writer.symbol(":"); + visualizeTypePackAnnotation(*func.varargAnnotation); + } + } + + writer.symbol(")"); + + if (writeTypes && func.hasReturnAnnotation) + { + writer.symbol(":"); + writer.space(); + + visualizeTypeList(func.returnAnnotation, false); + } + + visualizeBlock(*func.body); + writeEnd(func.location); + } + + void visualizeBlock(AstStatBlock& block) + { + for (const auto& s : block.body) + visualize(*s); + writer.advance(block.location.end); + } + + void visualizeBlock(AstStat& stat) + { + if (AstStatBlock* block = stat.as()) + visualizeBlock(*block); + else + LUAU_ASSERT(!"visualizeBlock was expecting an AstStatBlock"); + } + + void visualizeElseIf(AstStatIf& elseif) + { + visualize(*elseif.condition); + writer.keyword("then"); + visualizeBlock(*elseif.thenbody); + + if (elseif.elsebody == nullptr) + { + writeEnd(elseif.location); + } + else if (auto elseifelseif = elseif.elsebody->as()) + { + writer.keyword("elseif"); + visualizeElseIf(*elseifelseif); + } + else + { + writer.keyword("else"); + + visualizeBlock(*elseif.elsebody); + writeEnd(elseif.location); + } + } + + void visualizeTypeAnnotation(const AstType& typeAnnotation) + { + advance(typeAnnotation.location.begin); + if (const auto& a = typeAnnotation.as()) + { + writer.write(a->name.value); + if (a->generics.size > 0) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (auto o : a->generics) + { + comma(); + visualizeTypeAnnotation(*o); + } + writer.symbol(">"); + } + } + else if (const auto& a = typeAnnotation.as()) + { + if (FFlag::LuauGenericFunctions && (a->generics.size > 0 || a->genericPacks.size > 0)) + { + CommaSeparatorInserter comma(writer); + writer.symbol("<"); + for (const auto& o : a->generics) + { + comma(); + writer.identifier(o.value); + } + for (const auto& o : a->genericPacks) + { + comma(); + writer.identifier(o.value); + writer.symbol("..."); + } + writer.symbol(">"); + } + + { + visualizeTypeList(a->argTypes, true); + } + + writer.symbol("->"); + visualizeTypeList(a->returnTypes, true); + } + else if (const auto& a = typeAnnotation.as()) + { + CommaSeparatorInserter comma(writer); + + writer.symbol("{"); + + for (std::size_t i = 0; i < a->props.size; ++i) + { + comma(); + advance(a->props.data[i].location.begin); + writer.identifier(a->props.data[i].name.value); + if (a->props.data[i].type) + { + writer.symbol(":"); + visualizeTypeAnnotation(*a->props.data[i].type); + } + } + if (a->indexer) + { + comma(); + writer.symbol("["); + visualizeTypeAnnotation(*a->indexer->indexType); + writer.symbol("]"); + writer.symbol(":"); + visualizeTypeAnnotation(*a->indexer->resultType); + } + writer.symbol("}"); + } + else if (auto a = typeAnnotation.as()) + { + writer.keyword("typeof"); + writer.symbol("("); + visualize(*a->expr); + writer.symbol(")"); + } + else if (const auto& a = typeAnnotation.as()) + { + if (a->types.size == 2) + { + AstType* l = a->types.data[0]; + AstType* r = a->types.data[1]; + + auto lta = l->as(); + if (lta && lta->name == "nil") + std::swap(l, r); + + // it's still possible that we had a (T | U) or (T | nil) and not (nil | T) + auto rta = r->as(); + if (rta && rta->name == "nil") + { + visualizeTypeAnnotation(*l); + writer.symbol("?"); + return; + } + } + + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("|"); + } + + visualizeTypeAnnotation(*a->types.data[i]); + } + } + else if (const auto& a = typeAnnotation.as()) + { + for (size_t i = 0; i < a->types.size; ++i) + { + if (i > 0) + { + writer.maybeSpace(a->types.data[i]->location.begin, 2); + writer.symbol("&"); + } + + visualizeTypeAnnotation(*a->types.data[i]); + } + } + else if (typeAnnotation.is()) + { + writer.symbol("%error-type%"); + } + else + { + LUAU_ASSERT(!"Unknown AstType"); + } + } +}; + +void dump(AstNode* node) +{ + StringWriter writer; + Printer printer(writer); + printer.writeTypes = true; + + if (auto statNode = dynamic_cast(node)) + { + printer.visualize(*statNode); + printf("%s\n", writer.str().c_str()); + } + else if (auto exprNode = dynamic_cast(node)) + { + printer.visualize(*exprNode); + printf("%s\n", writer.str().c_str()); + } + else if (auto typeNode = dynamic_cast(node)) + { + printer.visualizeTypeAnnotation(*typeNode); + printf("%s\n", writer.str().c_str()); + } + else + { + printf("Can't dump this node\n"); + } +} + +std::string transpile(AstStatBlock& block) +{ + StringWriter writer; + Printer(writer).visualizeBlock(block); + return writer.str(); +} +std::string transpileWithTypes(AstStatBlock& block) +{ + StringWriter writer; + Printer printer(writer); + printer.writeTypes = true; + printer.visualizeBlock(block); + return writer.str(); +} + +TranspileResult transpile(std::string_view source, ParseOptions options) +{ + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(source.data(), source.size(), names, allocator, options); + + if (!parseResult.errors.empty()) + { + // TranspileResult keeps track of only a single error + const ParseError& error = parseResult.errors.front(); + + return TranspileResult{"", error.getLocation(), error.what()}; + } + + LUAU_ASSERT(parseResult.root); + if (!parseResult.root) + return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; + + return TranspileResult{transpile(*parseResult.root)}; +} + +} // namespace Luau diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp new file mode 100644 index 0000000..702d0ca --- /dev/null +++ b/Analysis/src/TxnLog.cpp @@ -0,0 +1,72 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TxnLog.h" + +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ + +void TxnLog::operator()(TypeId a) +{ + typeVarChanges.emplace_back(a, *a); +} + +void TxnLog::operator()(TypePackId a) +{ + typePackChanges.emplace_back(a, *a); +} + +void TxnLog::operator()(TableTypeVar* a) +{ + tableChanges.emplace_back(a, a->boundTo); +} + +void TxnLog::rollback() +{ + for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) + std::swap(*asMutable(it->first), it->second); + + for (auto it = typePackChanges.rbegin(); it != typePackChanges.rend(); ++it) + std::swap(*asMutable(it->first), it->second); + + for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) + std::swap(it->first->boundTo, it->second); +} + +void TxnLog::concat(TxnLog rhs) +{ + typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); + rhs.typeVarChanges.clear(); + + typePackChanges.insert(typePackChanges.end(), rhs.typePackChanges.begin(), rhs.typePackChanges.end()); + rhs.typePackChanges.clear(); + + tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); + rhs.tableChanges.clear(); + + seen.swap(rhs.seen); + rhs.seen.clear(); +} + +bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) +{ + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair)); +} + +void TxnLog::pushSeen(TypeId lhs, TypeId rhs) +{ + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + seen.push_back(sortedPair); +} + +void TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + LUAU_ASSERT(sortedPair == seen.back()); + seen.pop_back(); +} + +} // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp new file mode 100644 index 0000000..17c57c8 --- /dev/null +++ b/Analysis/src/TypeAttach.cpp @@ -0,0 +1,437 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeAttach.h" + +#include "Luau/Error.h" +#include "Luau/Module.h" +#include "Luau/Parser.h" +#include "Luau/RecursionCounter.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" + +#include + +LUAU_FASTFLAG(LuauGenericFunctions) + +static char* allocateString(Luau::Allocator& allocator, std::string_view contents) +{ + char* result = (char*)allocator.allocate(contents.size() + 1); + memcpy(result, contents.data(), contents.size()); + result[contents.size()] = '\0'; + + return result; +} + +template +static char* allocateString(Luau::Allocator& allocator, const char* format, Data... data) +{ + int len = snprintf(nullptr, 0, format, data...); + char* result = (char*)allocator.allocate(len + 1); + snprintf(result, len + 1, format, data...); + return result; +} + +namespace Luau +{ + +class TypeRehydrationVisitor +{ + mutable std::map seen; + mutable int count = 0; + + bool hasSeen(const void* tv) const + { + void* ttv = const_cast(tv); + auto it = seen.find(ttv); + if (it != seen.end() && it->second < count) + return true; + + seen[ttv] = count; + return false; + } + +public: + TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions()) + : allocator(alloc) + , options(options) + { + } + + AstType* operator()(const PrimitiveTypeVar& ptv) const + { + switch (ptv.type) + { + case PrimitiveTypeVar::NilType: + return allocator->alloc(Location(), std::nullopt, AstName("nil")); + case PrimitiveTypeVar::Boolean: + return allocator->alloc(Location(), std::nullopt, AstName("boolean")); + case PrimitiveTypeVar::Number: + return allocator->alloc(Location(), std::nullopt, AstName("number")); + case PrimitiveTypeVar::String: + return allocator->alloc(Location(), std::nullopt, AstName("string")); + case PrimitiveTypeVar::Thread: + return allocator->alloc(Location(), std::nullopt, AstName("thread")); + default: + return nullptr; + } + } + AstType* operator()(const AnyTypeVar&) const + { + return allocator->alloc(Location(), std::nullopt, AstName("any")); + } + AstType* operator()(const TableTypeVar& ttv) const + { + RecursionCounter counter(&count); + + if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end()) + { + AstArray generics; + generics.size = ttv.instantiatedTypeParams.size(); + generics.data = static_cast(allocator->allocate(sizeof(AstType*) * generics.size)); + + for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i) + { + generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty); + } + + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), generics); + } + + if (hasSeen(&ttv)) + { + if (ttv.name) + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str())); + else + return allocator->alloc(Location(), std::nullopt, AstName("")); + } + + AstArray props; + props.size = ttv.props.size(); + props.data = static_cast(allocator->allocate(sizeof(AstTableProp) * props.size)); + int idx = 0; + for (const auto& [propName, prop] : ttv.props) + { + RecursionCounter counter(&count); + + char* name = allocateString(*allocator, propName); + + props.data[idx].name = AstName(name); + props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].location = Location(); + idx++; + } + + AstTableIndexer* indexer = nullptr; + if (ttv.indexer) + { + RecursionCounter counter(&count); + + indexer = allocator->alloc(); + indexer->indexType = Luau::visit(*this, ttv.indexer->indexType->ty); + indexer->resultType = Luau::visit(*this, ttv.indexer->indexResultType->ty); + } + return allocator->alloc(Location(), props, indexer); + } + + AstType* operator()(const MetatableTypeVar& mtv) const + { + return Luau::visit(*this, mtv.table->ty); + } + + AstType* operator()(const ClassTypeVar& ctv) const + { + RecursionCounter counter(&count); + + char* name = allocateString(*allocator, ctv.name); + + if (!options.expandClassProps || hasSeen(&ctv) || count > 1) + return allocator->alloc(Location(), std::nullopt, AstName{name}); + + AstArray props; + props.size = ctv.props.size(); + props.data = static_cast(allocator->allocate(sizeof(AstTableProp) * props.size)); + + int idx = 0; + for (const auto& [propName, prop] : ctv.props) + { + char* name = allocateString(*allocator, propName); + + props.data[idx].name = AstName{name}; + props.data[idx].type = Luau::visit(*this, prop.type->ty); + props.data[idx].location = Location(); + idx++; + } + + return allocator->alloc(Location(), props); + } + + AstType* operator()(const FunctionTypeVar& ftv) const + { + RecursionCounter counter(&count); + + if (hasSeen(&ftv)) + return allocator->alloc(Location(), std::nullopt, AstName("")); + + AstArray generics; + if (FFlag::LuauGenericFunctions) + { + generics.size = ftv.generics.size(); + generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + size_t i = 0; + for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) + { + if (auto gtv = get(*it)) + generics.data[i++] = AstName(gtv->name.c_str()); + } + } + else + { + generics.size = 0; + generics.data = nullptr; + } + + AstArray genericPacks; + if (FFlag::LuauGenericFunctions) + { + genericPacks.size = ftv.genericPacks.size(); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + size_t i = 0; + for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) + { + if (auto gtv = get(*it)) + genericPacks.data[i++] = AstName(gtv->name.c_str()); + } + } + else + { + generics.size = 0; + generics.data = nullptr; + } + + AstArray argTypes; + const auto& [argVector, argTail] = flatten(ftv.argTypes); + argTypes.size = argVector.size(); + argTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * argTypes.size)); + for (size_t i = 0; i < argTypes.size; ++i) + { + RecursionCounter counter(&count); + + argTypes.data[i] = Luau::visit(*this, (argVector[i])->ty); + } + + AstTypePack* argTailAnnotation = nullptr; + if (argTail) + { + TypePackId tail = *argTail; + if (const VariadicTypePack* vtp = get(tail)) + { + argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } + } + + AstArray> argNames; + argNames.size = ftv.argNames.size(); + argNames.data = static_cast*>(allocator->allocate(sizeof(std::optional) * argNames.size)); + size_t i = 0; + for (const auto& el : ftv.argNames) + { + if (el) + argNames.data[i++] = {AstName(el->name.c_str()), el->location}; + else + argNames.data[i++] = {}; + } + + AstArray returnTypes; + const auto& [retVector, retTail] = flatten(ftv.retType); + returnTypes.size = retVector.size(); + returnTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * returnTypes.size)); + for (size_t i = 0; i < returnTypes.size; ++i) + { + RecursionCounter counter(&count); + + returnTypes.data[i] = Luau::visit(*this, (retVector[i])->ty); + } + + AstTypePack* retTailAnnotation = nullptr; + if (retTail) + { + TypePackId tail = *retTail; + if (const VariadicTypePack* vtp = get(tail)) + { + retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } + } + + return allocator->alloc( + Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); + } + AstType* operator()(const Unifiable::Error&) const + { + return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); + } + AstType* operator()(const GenericTypeVar& gtv) const + { + return allocator->alloc(Location(), std::nullopt, AstName(gtv.name.c_str())); + } + AstType* operator()(const Unifiable::Bound& bound) const + { + return Luau::visit(*this, bound.boundTo->ty); + } + AstType* operator()(Unifiable::Free ftv) const + { + return allocator->alloc(Location(), std::nullopt, AstName("free")); + } + AstType* operator()(const UnionTypeVar& uv) const + { + AstArray unionTypes; + unionTypes.size = uv.options.size(); + unionTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * unionTypes.size)); + for (size_t i = 0; i < unionTypes.size; ++i) + { + unionTypes.data[i] = Luau::visit(*this, uv.options[i]->ty); + } + return allocator->alloc(Location(), unionTypes); + } + AstType* operator()(const IntersectionTypeVar& uv) const + { + AstArray intersectionTypes; + intersectionTypes.size = uv.parts.size(); + intersectionTypes.data = static_cast(allocator->allocate(sizeof(AstType*) * intersectionTypes.size)); + for (size_t i = 0; i < intersectionTypes.size; ++i) + { + intersectionTypes.data[i] = Luau::visit(*this, uv.parts[i]->ty); + } + return allocator->alloc(Location(), intersectionTypes); + } + AstType* operator()(const LazyTypeVar& ltv) const + { + return allocator->alloc(Location(), std::nullopt, AstName("")); + } + +private: + Allocator* allocator; + const TypeRehydrationOptions& options; +}; + +class TypeAttacher : public AstVisitor +{ +public: + TypeAttacher(Module& checker, Luau::Allocator* alloc) + : module(checker) + , allocator(alloc) + { + } + ScopePtr getScope(const Location& loc) + { + Location scopeLocation; + ScopePtr scope = nullptr; + for (const auto& s : module.scopes) + { + if (s.first.encloses(loc)) + { + if (!scope || scopeLocation.encloses(s.first)) + { + scopeLocation = s.first; + scope = s.second; + } + } + } + + return scope; + } + + AstType* typeAst(std::optional type) + { + if (!type) + return nullptr; + return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty); + } + + AstArray typeAstPack(TypePackId type) + { + const auto& [v, tail] = flatten(type); + + AstArray result; + result.size = v.size(); + result.data = static_cast(allocator->allocate(sizeof(AstType*) * v.size())); + for (size_t i = 0; i < v.size(); ++i) + { + result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty); + } + return result; + } + + virtual bool visit(AstStatLocal* al) override + { + for (size_t i = 0; i < al->vars.size; ++i) + { + visitLocal(al->vars.data[i]); + } + return true; + } + + virtual bool visitLocal(AstLocal* local) + { + AstType* annotation = local->annotation; + if (!annotation) + { + if (auto result = getScope(local->location)->lookup(local)) + local->annotation = typeAst(*result); + } + return true; + } + + virtual bool visit(AstExprLocal* al) override + { + return visitLocal(al->local); + } + virtual bool visit(AstExprFunction* fn) override + { + // TODO: add generics if the inferred type of the function is generic CLI-39908 + for (size_t i = 0; i < fn->args.size; ++i) + { + AstLocal* arg = fn->args.data[i]; + visitLocal(arg); + } + + if (!fn->hasReturnAnnotation) + { + if (auto result = getScope(fn->body->location)) + { + TypePackId ret = result->returnType; + fn->hasReturnAnnotation = true; + + AstTypePack* variadicAnnotation = nullptr; + const auto& [v, tail] = flatten(ret); + + if (tail) + { + TypePackId tailPack = *tail; + if (const VariadicTypePack* vtp = get(tailPack)) + variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + } + + fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; + } + } + + return true; + } + +private: + Module& module; + Allocator* allocator; +}; + +void attachTypeData(SourceModule& source, Module& result) +{ + TypeAttacher ta(result, source.allocator.get()); + source.root->visit(&ta); +} + +AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options) +{ + return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty); +} + +} // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp new file mode 100644 index 0000000..2216881 --- /dev/null +++ b/Analysis/src/TypeInfer.cpp @@ -0,0 +1,5497 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeInfer.h" + +#include "Luau/Common.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" +#include "Luau/RecursionCounter.h" +#include "Luau/Substitution.h" +#include "Luau/TopoSortStatements.h" +#include "Luau/ToString.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" +#include "Luau/TypeVar.h" + +#include +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 0) +LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) +LUAU_FASTFLAGVARIABLE(LuauIndexTablesWithIndexers, false) +LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) +LUAU_FASTFLAG(LuauKnowsTheDataModel3) +LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) +LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) +LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) +LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. +LUAU_FASTFLAGVARIABLE(LuauImprovedTypeGuardPredicate2, false) +LUAU_FASTFLAG(LuauTraceRequireLookupChild) +LUAU_FASTFLAG(DebugLuauTrackOwningArena) +LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) +LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) +LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) +LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) +LUAU_FASTFLAGVARIABLE(LuauFixTableTypeAliasClone, false) +LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) +LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) +LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) +LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) +LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) +LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) +LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) +LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) +LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) + +namespace Luau +{ + +static bool typeCouldHaveMetatable(TypeId ty) +{ + return get(follow(ty)) || get(follow(ty)) || get(follow(ty)); +} + +static void defaultLuauPrintLine(const std::string& s) +{ + printf("%s\n", s.c_str()); +} + +using PrintLineProc = decltype(&defaultLuauPrintLine); + +static PrintLineProc luauPrintLine = &defaultLuauPrintLine; + +void setPrintLine(PrintLineProc pl) +{ + luauPrintLine = pl; +} + +void resetPrintLine() +{ + luauPrintLine = &defaultLuauPrintLine; +} + +bool doesCallError(const AstExprCall* call) +{ + const AstExprGlobal* global = call->func->as(); + if (!global) + return false; + + if (global->name == "error") + return true; + else if (global->name == "assert") + { + // assert() will error because it is missing the first argument + if (call->args.size == 0) + return true; + + if (AstExprConstantBool* expr = call->args.data[0]->as()) + if (!expr->value) + return true; + } + + return false; +} + +bool hasBreak(AstStat* node) +{ + if (AstStatBlock* stat = node->as()) + { + for (size_t i = 0; i < stat->body.size; ++i) + { + if (hasBreak(stat->body.data[i])) + return true; + } + + return false; + } + else if (node->is()) + { + return true; + } + else if (AstStatIf* stat = node->as()) + { + if (hasBreak(stat->thenbody)) + return true; + + if (stat->elsebody && hasBreak(stat->elsebody)) + return true; + + return false; + } + else + { + return false; + } +} + +// returns the last statement before the block exits, or nullptr if the block never exits +const AstStat* getFallthrough(const AstStat* node) +{ + if (const AstStatBlock* stat = node->as()) + { + if (stat->body.size == 0) + return stat; + + for (size_t i = 0; i < stat->body.size - 1; ++i) + { + if (getFallthrough(stat->body.data[i]) == nullptr) + return nullptr; + } + + return getFallthrough(stat->body.data[stat->body.size - 1]); + } + else if (const AstStatIf* stat = node->as()) + { + if (const AstStat* thenf = getFallthrough(stat->thenbody)) + return thenf; + + if (stat->elsebody) + { + if (const AstStat* elsef = getFallthrough(stat->elsebody)) + return elsef; + + return nullptr; + } + else + { + return stat; + } + } + else if (node->is()) + { + return nullptr; + } + else if (const AstStatExpr* stat = node->as()) + { + if (AstExprCall* call = stat->expr->as()) + { + if (doesCallError(call)) + return nullptr; + } + + return stat; + } + else if (const AstStatWhile* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (expr->value && !hasBreak(stat->body)) + return nullptr; + } + + return node; + } + else if (const AstStatRepeat* stat = node->as()) + { + if (AstExprConstantBool* expr = stat->condition->as()) + { + if (!expr->value && !hasBreak(stat->body)) + return nullptr; + } + + if (getFallthrough(stat->body) == nullptr) + return nullptr; + + return node; + } + else + { + return node; + } +} + +static bool isMetamethod(const Name& name) +{ + return name == "__index" || name == "__newindex" || name == "__call" || name == "__concat" || name == "__unm" || name == "__add" || + name == "__sub" || name == "__mul" || name == "__div" || name == "__mod" || name == "__pow" || name == "__tostring" || + name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode"; +} + +TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) + : resolver(resolver) + , iceHandler(iceHandler) + , nilType(singletonTypes.nilType) + , numberType(singletonTypes.numberType) + , stringType(singletonTypes.stringType) + , booleanType( + FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.booleanType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))) + , threadType(FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.threadType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Thread))) + , anyType(singletonTypes.anyType) + , errorType(singletonTypes.errorType) + , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) + , anyTypePack(globalTypes.addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}, true})) + , errorTypePack(globalTypes.addTypePack(TypePackVar{Unifiable::Error{}})) +{ + globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); + + globalScope->exportedTypeBindings["any"] = TypeFun{{}, anyType}; + globalScope->exportedTypeBindings["nil"] = TypeFun{{}, nilType}; + globalScope->exportedTypeBindings["number"] = TypeFun{{}, numberType}; + globalScope->exportedTypeBindings["string"] = TypeFun{{}, stringType}; + globalScope->exportedTypeBindings["boolean"] = TypeFun{{}, booleanType}; + globalScope->exportedTypeBindings["thread"] = TypeFun{{}, threadType}; +} + +ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) +{ + currentModule.reset(new Module()); + currentModule->type = module.type; + + iceHandler->moduleName = module.name; + + ScopePtr parentScope = environmentScope.value_or(globalScope); + ScopePtr moduleScope = std::make_shared(parentScope); + + if (module.cyclic) + moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt}); + else if (FFlag::LuauRankNTypes) + moduleScope->returnType = freshTypePack(moduleScope); + else + moduleScope->returnType = DEPRECATED_freshTypePack(moduleScope, true); + + moduleScope->varargPack = anyTypePack; + + currentModule->scopes.push_back(std::make_pair(module.root->location, moduleScope)); + currentModule->mode = mode; + + currentModuleName = module.name; + + if (prepareModuleScope) + prepareModuleScope(module.name, currentModule->getModuleScope()); + + checkBlock(moduleScope, *module.root); + + if (get(FFlag::LuauAddMissingFollow ? follow(moduleScope->returnType) : moduleScope->returnType)) + moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); + else + moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); + + for (auto& [_, typeFun] : moduleScope->exportedTypeBindings) + typeFun.type = anyify(moduleScope, typeFun.type, Location{}); + + prepareErrorsForDisplay(currentModule->errors); + + bool encounteredFreeType = currentModule->clonePublicInterface(); + if (encounteredFreeType) + { + reportError(TypeError{module.root->location, + GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); + } + + return std::move(currentModule); +} + +void TypeChecker::check(const ScopePtr& scope, const AstStat& program) +{ + if (auto block = program.as()) + check(scope, *block); + else if (auto if_ = program.as()) + check(scope, *if_); + else if (auto while_ = program.as()) + check(scope, *while_); + else if (auto repeat = program.as()) + check(scope, *repeat); + else if (program.is()) + { + } // Nothing to do + else if (program.is()) + { + } // Nothing to do + else if (auto return_ = program.as()) + check(scope, *return_); + else if (auto expr = program.as()) + checkExprPack(scope, *expr->expr); + else if (auto local = program.as()) + check(scope, *local); + else if (auto for_ = program.as()) + check(scope, *for_); + else if (auto forIn = program.as()) + check(scope, *forIn); + else if (auto assign = program.as()) + check(scope, *assign); + else if (auto assign = program.as()) + check(scope, *assign); + else if (program.is()) + ice("Should not be calling two-argument check() on a function statement", program.location); + else if (program.is()) + ice("Should not be calling two-argument check() on a function statement", program.location); + else if (auto typealias = program.as()) + check(scope, *typealias); + else if (auto global = program.as()) + { + TypeId globalType = (FFlag::LuauRankNTypes ? resolveType(scope, *global->type) : resolveType(scope, *global->type, true)); + Name globalName(global->name.value); + + currentModule->declaredGlobals[globalName] = globalType; + currentModule->getModuleScope()->bindings[global->name] = Binding{globalType, global->location}; + } + else if (auto global = program.as()) + check(scope, *global); + else if (auto global = program.as()) + check(scope, *global); + else if (auto errorStatement = program.as()) + { + const size_t oldSize = currentModule->errors.size(); + + for (AstStat* s : errorStatement->statements) + check(scope, *s); + + for (AstExpr* expr : errorStatement->expressions) + checkExpr(scope, *expr); + + // HACK: We want to run typechecking on the contents of the AstStatError, but + // we don't think the type errors will be useful most of the time. + currentModule->errors.resize(oldSize); + } + else + ice("Unknown AstStat"); +} + +// This particular overload is for do...end. If you need to not increase the scope level, use checkBlock directly. +void TypeChecker::check(const ScopePtr& scope, const AstStatBlock& block) +{ + ScopePtr child = childScope(scope, block.location); + checkBlock(child, block); +} + +void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) +{ + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) + { + reportErrorCodeTooComplex(block.location); + return; + } + + std::vector sorted(block.body.data, block.body.data + block.body.size); + toposort(sorted); + + for (const auto& stat : sorted) + { + if (const auto& typealias = stat->as()) + check(scope, *typealias, true); + } + + auto protoIter = sorted.begin(); + auto checkIter = sorted.begin(); + + std::unordered_map> functionDecls; + + auto checkBody = [&](AstStat* stat) { + if (auto fun = stat->as()) + { + LUAU_ASSERT(functionDecls.count(stat)); + auto [funTy, funScope] = functionDecls[stat]; + check(scope, funTy, funScope, *fun); + } + else if (auto fun = stat->as()) + { + LUAU_ASSERT(functionDecls.count(stat)); + auto [funTy, funScope] = functionDecls[stat]; + check(scope, funTy, funScope, *fun); + } + }; + + int subLevel = 0; + + while (protoIter != sorted.end()) + { + // protoIter walks forward + // If it contains a function call (function bodies don't count), walk checkIter forward until it catches up with protoIter + // For each element checkIter sees, check function bodies and unify the computed type with the prototype + // If it is a function definition, add its prototype to the environment + // If it is anything else, check it. + + // A subtlety is caused by mutually recursive functions, e.g. + // ``` + // function f(x) return g(x) end + // function g(x) return f(x) end + // ``` + // These both call each other, so `f` will be ordered before `g`, so the call to `g` + // is typechecked before `g` has had its body checked. For this reason, there's three + // types for each functuion: before its body is checked, during checking its body, + // and after its body is checked. + // + // We currently treat the before-type and the during-type as the same, + // which can result in some oddness, as the before-type is usually a monotype, + // and the after-type is often a polytype. For example: + // + // ``` + // function f(x) local x: number = g(37) return x end + // function g(x) return f(x) end + // ``` + // The before-type of g is `(X)->Y...` but during type-checking of `f` we will + // unify that with `(number)->number`. The types end up being + // ``` + // function f(x:a):a local x: number = g(37) return x end + // function g(x:number):number return f(x) end + // ``` + if (containsFunctionCall(**protoIter)) + { + while (checkIter != protoIter) + { + checkBody(*checkIter); + ++checkIter; + } + + // We do check the current element, so advance checkIter beyond it. + ++checkIter; + check(scope, **protoIter); + } + else if (auto fun = (*protoIter)->as()) + { + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + auto [funTy, funScope] = pair; + + functionDecls[*protoIter] = pair; + ++subLevel; + + TypeId leftType = checkFunctionName(scope, *fun->name); + unify(leftType, funTy, fun->location); + } + else if (auto fun = (*protoIter)->as()) + { + auto pair = checkFunctionSignature(scope, subLevel, *fun->func, fun->name->location, std::nullopt); + auto [funTy, funScope] = pair; + + functionDecls[*protoIter] = pair; + ++subLevel; + + scope->bindings[fun->name] = {funTy, fun->name->location}; + } + else + check(scope, **protoIter); + + ++protoIter; + } + + while (checkIter != sorted.end()) + { + checkBody(*checkIter); + ++checkIter; + } + + checkBlockTypeAliases(scope, sorted); +} + +LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted) +{ + for (const auto& stat : sorted) + { + if (const auto& typealias = stat->as()) + { + auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; + + Name name = typealias->name.value; + TypeId type = bindings[name].type; + if (get(FFlag::LuauAddMissingFollow ? follow(type) : type)) + { + *asMutable(type) = ErrorTypeVar{}; + reportError(TypeError{typealias->location, OccursCheckFailed{}}); + } + } + } +} + +static std::optional tryGetTypeGuardPredicate(const AstExprBinary& expr) +{ + if (expr.op != AstExprBinary::Op::CompareEq && expr.op != AstExprBinary::Op::CompareNe) + return std::nullopt; + + AstExpr* left = expr.left; + AstExpr* right = expr.right; + + if (left->as()) + std::swap(left, right); + + AstExprConstantString* str = right->as(); + if (!str) + return std::nullopt; + + AstExprCall* call = left->as(); + if (!call) + return std::nullopt; + + AstExprGlobal* callee = call->func->as(); + if (!callee) + return std::nullopt; + + if (callee->name != "type" && callee->name != "typeof") + return std::nullopt; + + if (call->args.size != 1) + return std::nullopt; + + // If ssval is not a valid constant string, we'll find out later when resolving predicate. + Name ssval(str->value.data, str->value.size); + bool isTypeof = callee->name == "typeof"; + + std::optional lvalue = tryGetLValue(*call->args.data[0]); + if (!lvalue) + return std::nullopt; + + Predicate predicate{TypeGuardPredicate{std::move(*lvalue), expr.location, ssval, isTypeof}}; + if (expr.op == AstExprBinary::Op::CompareNe) + return NotPredicate{{std::move(predicate)}}; + + return predicate; +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) +{ + ExprResult result = checkExpr(scope, *statement.condition); + + ScopePtr ifScope = childScope(scope, statement.thenbody->location); + reportErrors(resolve(result.predicates, ifScope, true)); + check(ifScope, *statement.thenbody); + + if (statement.elsebody) + { + ScopePtr elseScope = childScope(scope, statement.elsebody->location); + resolve(result.predicates, elseScope, false); + check(elseScope, *statement.elsebody); + } +} + +ErrorVec TypeChecker::canUnify(TypeId left, TypeId right, const Location& location) +{ + return canUnify_(left, right, location); +} + +ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location& location) +{ + return canUnify_(left, right, location); +} + +ErrorVec TypeChecker::canUnify(const std::vector>& seen, TypeId superTy, TypeId subTy, const Location& location) +{ + Unifier state = mkUnifier(seen, location); + return state.canUnify(superTy, subTy); +} + +template +ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) +{ + Unifier state = mkUnifier(location); + return state.canUnify(superTy, subTy); +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) +{ + ExprResult result = checkExpr(scope, *statement.condition); + + ScopePtr whileScope = childScope(scope, statement.body->location); + reportErrors(resolve(result.predicates, whileScope, true)); + check(whileScope, *statement.body); +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) +{ + ScopePtr repScope = childScope(scope, statement.location); + + checkBlock(repScope, *statement.body); + + checkExpr(repScope, *statement.condition); +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) +{ + std::vector> expectedTypes; + + if (FFlag::LuauInferReturnAssertAssign) + { + expectedTypes.reserve(return_.list.size); + + TypePackIterator expectedRetCurr = begin(scope->returnType); + TypePackIterator expectedRetEnd = end(scope->returnType); + + for (size_t i = 0; i < return_.list.size; ++i) + { + if (expectedRetCurr != expectedRetEnd) + { + expectedTypes.push_back(*expectedRetCurr); + ++expectedRetCurr; + } + else if (auto expectedArgsTail = expectedRetCurr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) + expectedTypes.push_back(vtp->ty); + } + } + } + + TypePackId retPack = checkExprList(scope, return_.location, return_.list, false, {}, expectedTypes).type; + + // HACK: Nonstrict mode gets a bit too smart and strict for us when we + // start typechecking everything across module boundaries. + if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) + { + ErrorVec errors = tryUnify(scope->returnType, retPack, return_.location); + + if (!errors.empty()) + currentModule->getModuleScope()->returnType = addTypePack({anyType}); + + return; + } + + unify(scope->returnType, retPack, return_.location, CountMismatch::Context::Return); +} + +ErrorVec TypeChecker::tryUnify(TypeId left, TypeId right, const Location& location) +{ + return tryUnify_(left, right, location); +} + +ErrorVec TypeChecker::tryUnify(TypePackId left, TypePackId right, const Location& location) +{ + return tryUnify_(left, right, location); +} + +template +ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) +{ + Unifier state = mkUnifier(location); + state.tryUnify(left, right); + + if (!state.errors.empty()) + state.log.rollback(); + + return state.errors; +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) +{ + std::vector> expectedTypes; + + if (FFlag::LuauInferReturnAssertAssign) + { + expectedTypes.reserve(assign.vars.size); + + ScopePtr moduleScope = currentModule->getModuleScope(); + + for (size_t i = 0; i < assign.vars.size; ++i) + { + AstExpr* dest = assign.vars.data[i]; + + if (auto a = dest->as()) + { + // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later + expectedTypes.push_back(scope->lookup(a->local)); + } + else if (auto a = dest->as()) + { + // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList + if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) + expectedTypes.push_back(it->second.typeId); + else + expectedTypes.push_back(std::nullopt); + } + else + { + expectedTypes.push_back(checkLValue(scope, *dest)); + } + } + } + + TypePackId valuePack = checkExprList(scope, assign.location, assign.values, false, {}, expectedTypes).type; + + auto valueIter = begin(valuePack); + auto valueEnd = end(valuePack); + + TypePack* growingPack = nullptr; + + for (size_t i = 0; i < assign.vars.size; ++i) + { + AstExpr* dest = assign.vars.data[i]; + TypeId left = nullptr; + + if (!FFlag::LuauInferReturnAssertAssign || dest->is() || dest->is()) + left = checkLValue(scope, *dest); + else + left = *expectedTypes[i]; + + TypeId right = nullptr; + + Location loc = 0 == assign.values.size + ? assign.location + : i < assign.values.size ? assign.values.data[i]->location : assign.values.data[assign.values.size - 1]->location; + + if (valueIter != valueEnd) + { + right = follow(*valueIter); + ++valueIter; + } + else if (growingPack) + { + growingPack->head.push_back(left); + continue; + } + else if (auto tail = valueIter.tail()) + { + if (get(*tail)) + right = errorType; + else if (auto vtp = get(*tail)) + right = vtp->ty; + else if (get(*tail)) + { + *asMutable(*tail) = TypePack{{left}}; + growingPack = getMutable(*tail); + } + } + + if (right) + { + if (FFlag::LuauGenericFunctions && !maybeGeneric(left) && isGeneric(right)) + right = instantiate(scope, right, loc); + + if (!FFlag::LuauGenericFunctions && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && + get(FFlag::LuauAddMissingFollow ? follow(right) : right)) + right = instantiate(scope, right, loc); + + // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry + const TableTypeVar* destTableTypeReceivingNil = nullptr; + if (auto indexExpr = dest->as(); isNil(right) && indexExpr) + destTableTypeReceivingNil = getTableType(checkExpr(scope, *indexExpr->expr).type); + + if (!destTableTypeReceivingNil || !destTableTypeReceivingNil->indexer) + { + // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. + if (isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && !get(follow(right))) + unify(left, anyType, loc); + else + unify(left, right, loc); + } + } + } +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assign) +{ + AstExprBinary expr(assign.location, assign.op, assign.var, assign.value); + + TypeId left = checkExpr(scope, *expr.left).type; + TypeId right = checkExpr(scope, *expr.right).type; + + TypeId result = checkBinaryOperation(scope, expr, left, right); + + unify(left, result, assign.location); +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) +{ + // Important subtlety: A local variable is not in scope while its initializer is being evaluated. + // For instance, you cannot do this: + // local a = function() return a end + + AstLocal** vars = local.vars.data; + + std::vector> varBindings; + varBindings.reserve(local.vars.size); + + std::vector variableTypes; + variableTypes.reserve(local.vars.size); + + std::vector> expectedTypes; + expectedTypes.reserve(local.vars.size); + + std::vector instantiateGenerics; + + for (size_t i = 0; i < local.vars.size; ++i) + { + const AstType* annotation = vars[i]->annotation; + const bool rhsIsTable = local.values.size > i && local.values.data[i]->as(); + + TypeId ty = nullptr; + + if (annotation) + { + ty = (FFlag::LuauRankNTypes ? resolveType(scope, *annotation) : resolveType(scope, *annotation, true)); + + // If the annotation type has an error, treat it as if there was no annotation + if (get(follow(ty))) + ty = nullptr; + } + + if (!ty) + ty = rhsIsTable ? (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)) + : isNonstrictMode() ? anyType : (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + + varBindings.emplace_back(vars[i], Binding{ty, vars[i]->location}); + + variableTypes.push_back(ty); + expectedTypes.push_back(ty); + + if (FFlag::LuauGenericFunctions) + instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); + else + instantiateGenerics.push_back(annotation != nullptr && get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)); + } + + if (local.values.size > 0) + { + TypePackId variablePack = addTypePack(variableTypes, FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, true)); + TypePackId valuePack = + checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; + + Unifier state = mkUnifier(local.location); + state.ctx = CountMismatch::Result; + state.tryUnify(variablePack, valuePack); + reportErrors(state.errors); + + // In the code 'local T = {}', we wish to ascribe the name 'T' to the type of the table for error-reporting purposes. + // We also want to do this for 'local T = setmetatable(...)'. + if (local.vars.size == 1 && local.values.size == 1) + { + const AstExpr* rhs = local.values.data[0]; + std::optional ty = first(valuePack); + + if (ty) + { + if (rhs->is()) + { + TableTypeVar* ttv = getMutable(follow(*ty)); + if (ttv && !ttv->name && scope == currentModule->getModuleScope()) + ttv->syntheticName = vars[0]->name.value; + } + else if (const AstExprCall* call = rhs->as()) + { + if (const AstExprGlobal* global = call->func->as(); global && global->name == "setmetatable") + { + MetatableTypeVar* mtv = getMutable(follow(*ty)); + if (mtv) + mtv->syntheticName = vars[0]->name.value; + } + } + } + } + + // Handle 'require' calls, we need to import exported type bindings into the variable 'namespace' and to update binding type in non-strict + // mode + for (size_t i = 0; i < local.values.size && i < local.vars.size; ++i) + { + const AstExprCall* call = local.values.data[i]->as(); + if (!call) + continue; + + if (auto maybeRequire = matchRequire(*call)) + { + AstExpr* require = *maybeRequire; + + if (auto moduleInfo = resolver->resolveModuleInfo(currentModuleName, *require)) + { + const Name name{local.vars.data[i]->name.value}; + + if (ModulePtr module = resolver->getModule(moduleInfo->name)) + scope->importedTypeBindings[name] = module->getModuleScope()->exportedTypeBindings; + + // In non-strict mode we force the module type on the variable, in strict mode it is already unified + if (isNonstrictMode()) + { + auto [types, tail] = flatten(valuePack); + + if (i < types.size()) + varBindings[i].second.typeId = types[i]; + } + } + } + } + } + + for (const auto& [local, binding] : varBindings) + scope->bindings[local] = binding; +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) +{ + ScopePtr loopScope = childScope(scope, expr.location); + + TypeId loopVarType = numberType; + if (expr.var->annotation) + unify(resolveType(scope, *expr.var->annotation), loopVarType, expr.location); + + loopScope->bindings[expr.var] = {loopVarType, expr.var->location}; + + if (!expr.from) + ice("Bad AstStatFor has no from expr"); + + if (!expr.to) + ice("Bad AstStatFor has no to expr"); + + unify(loopVarType, checkExpr(loopScope, *expr.from).type, expr.from->location); + unify(loopVarType, checkExpr(loopScope, *expr.to).type, expr.to->location); + + if (expr.step) + unify(loopVarType, checkExpr(loopScope, *expr.step).type, expr.step->location); + + check(loopScope, *expr.body); +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) +{ + ScopePtr loopScope = childScope(scope, forin.location); + + AstLocal** vars = forin.vars.data; + + std::vector varTypes; + varTypes.reserve(forin.vars.size); + + for (size_t i = 0; i < forin.vars.size; ++i) + { + AstType* ann = vars[i]->annotation; + TypeId ty = ann ? resolveType(scope, *ann) : anyIfNonstrict(freshType(loopScope)); + + loopScope->bindings[vars[i]] = {ty, vars[i]->location}; + varTypes.push_back(ty); + } + + AstExpr** values = forin.values.data; + AstExpr* firstValue = forin.values.data[0]; + + // next is a function that takes Table and an optional index of type K + // next(t: Table, index: K | nil) -> (K, V) + // however, pairs and ipairs are quite messy, but they both share the same types + // pairs returns 'next, t, nil', thus the type would be + // pairs(t: Table) -> ((Table, K | nil) -> (K, V), Table, K | nil) + // ipairs returns 'next, t, 0', thus ipairs will also share the same type as pairs, except K = number + // + // we can also define our own custom iterators by by returning a wrapped coroutine that calls coroutine.yield + // and most custom iterators does not return a table state, or returns a function that takes no additional arguments, making it optional + // so we up with this catch-all type constraint that works for all use cases + // (free) -> ((free) -> R, Table | nil, K | nil) + + if (!firstValue) + ice("expected at least an iterator function value, but we parsed nothing"); + + TypeId iterTy = nullptr; + TypePackId callRetPack = nullptr; + + if (forin.values.size == 1 && firstValue->is()) + { + AstExprCall* exprCall = firstValue->as(); + callRetPack = checkExprPack(scope, *exprCall).type; + if (!FFlag::LuauRankNTypes) + callRetPack = DEPRECATED_instantiate(scope, callRetPack, exprCall->location); + callRetPack = follow(callRetPack); + + if (get(callRetPack)) + { + iterTy = freshType(scope); + unify(addTypePack({{iterTy}, freshTypePack(scope)}), callRetPack, forin.location); + } + else if (get(callRetPack) || !first(callRetPack)) + { + for (TypeId var : varTypes) + unify(var, errorType, forin.location); + + return check(loopScope, *forin.body); + } + else + { + iterTy = *first(callRetPack); + if (FFlag::LuauRankNTypes) + iterTy = instantiate(scope, iterTy, exprCall->location); + } + } + else + { + iterTy = follow(instantiate(scope, checkExpr(scope, *firstValue).type, firstValue->location)); + } + + const FunctionTypeVar* iterFunc = get(iterTy); + if (!iterFunc) + { + TypeId varTy = get(iterTy) ? anyType : errorType; + + for (TypeId var : varTypes) + unify(var, varTy, forin.location); + + if (!get(iterTy) && !get(iterTy) && !get(iterTy)) + reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); + + return check(loopScope, *forin.body); + } + + if (forin.values.size == 1) + { + TypePackId argPack = nullptr; + if (firstValue->is()) + { + // Extract the remaining return values of the call + // and check them against the parameter types of the iterator function. + auto [types, tail] = flatten(callRetPack); + std::vector argTypes = std::vector(types.begin() + 1, types.end()); + argPack = addTypePack(TypePackVar{TypePack{std::move(argTypes), tail}}); + } + else + { + // Check if iterator function accepts 0 arguments + argPack = addTypePack(TypePack{}); + } + + Unifier state = mkUnifier(firstValue->location); + checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + + reportErrors(state.errors); + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{varTypes, freshTypePack(scope)}}); + + if (forin.values.size >= 2) + { + AstArray arguments{forin.values.data + 1, forin.values.size - 1}; + + Position start = firstValue->location.begin; + Position end = values[forin.values.size - 1]->location.end; + AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; + + TypePackId retPack = checkExprPack(scope, exprCall).type; + unify(varPack, retPack, forin.location); + } + else + unify(varPack, iterFunc->retType, forin.location); + + check(loopScope, *forin.body); +} + +void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function) +{ + if (auto exprName = function.name->as()) + { + auto& globalBindings = currentModule->getModuleScope()->bindings; + Symbol name = exprName->name; + Name globalName = exprName->name.value; + + Binding oldBinding; + bool previouslyDefined = isNonstrictMode() && globalBindings.count(name); + + if (previouslyDefined) + { + oldBinding = globalBindings[name]; + } + + globalBindings[name] = {ty, exprName->location}; + checkFunctionBody(funScope, ty, *function.func); + + // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type + // in case this function has a differing signature. The signature discrepency will be caught in checkBlock. + if (previouslyDefined) + globalBindings[name] = oldBinding; + else + globalBindings[name] = {quantify(funScope, ty, exprName->location), exprName->location}; + + return; + } + else if (auto name = function.name->as()) + { + scope->bindings[name->local] = {ty, name->local->location}; + + checkFunctionBody(funScope, ty, *function.func); + + scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; + return; + } + else if (function.func->self) + { + AstExprIndexName* indexName = function.name->as(); + if (!indexName) + ice("member function declaration has malformed name expression"); + + TypeId selfTy = checkExpr(scope, *indexName->expr).type; + TableTypeVar* tableSelf = getMutableTableType(selfTy); + if (!tableSelf) + { + if (isTableIntersection(selfTy)) + reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); + else if (!get(selfTy) && !get(selfTy)) + reportError(TypeError{function.location, OnlyTablesCanHaveMethods{selfTy}}); + } + else if (tableSelf->state == TableState::Sealed) + reportError(TypeError{function.location, CannotExtendTable{selfTy, CannotExtendTable::Property, indexName->index.value}}); + + ty = follow(ty); + + if (tableSelf && !selfTy->persistent) + tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; + + const FunctionTypeVar* funTy = get(ty); + if (!funTy) + ice("Methods should be functions"); + + std::optional arg0 = first(funTy->argTypes); + if (!arg0) + ice("Methods should always have at least 1 argument (self)"); + + checkFunctionBody(funScope, ty, *function.func); + + if (tableSelf && !selfTy->persistent) + tableSelf->props[indexName->index.value] = { + follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; + } + else + { + auto [leftType, leftTypeBinding] = checkLValueBinding(scope, *function.name); + + checkFunctionBody(funScope, ty, *function.func); + + unify(leftType, ty, function.location); + + if (leftTypeBinding) + *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + } +} + +void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function) +{ + Name name = function.name->name.value; + + scope->bindings[function.name] = {ty, function.location}; + + checkFunctionBody(funScope, ty, *function.func); + + if (FFlag::LuauGenericFunctions) + scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; + else + scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare) +{ + // This function should be called at most twice for each type alias. + // Once with forwardDeclare, and once without. + Name name = typealias.name.value; + + std::optional binding; + if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) + binding = it->second; + else if (auto it = scope->privateTypeBindings.find(name); it != scope->privateTypeBindings.end()) + binding = it->second; + + auto& bindingsMap = typealias.exported ? scope->exportedTypeBindings : scope->privateTypeBindings; + + if (forwardDeclare) + { + if (binding) + { + Location location = scope->typeAliasLocations[name]; + reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); + bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + } + else + { + ScopePtr aliasScope = childScope(scope, typealias.location); + + std::vector generics; + for (AstName generic : typealias.generics) + { + Name n = generic.value; + + // These generics are the only thing that will ever be added to aliasScope, so we can be certain that + // a collision can only occur when two generic typevars have the same name. + if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) + { + // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. + reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); + } + + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction) + { + TypeId& cached = scope->typeAliasParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{aliasScope->level, n}); + g = cached; + } + else + g = addType(GenericTypeVar{aliasScope->level, n}); + generics.push_back(g); + aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; + } + + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), ty}; + } + } + else + { + if (!binding) + ice("Not predeclared"); + + ScopePtr aliasScope = childScope(scope, typealias.location); + + for (TypeId ty : binding->typeParams) + { + auto generic = get(ty); + LUAU_ASSERT(generic); + aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; + } + + TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); + if (auto ttv = getMutable(follow(ty))) + { + // If the table is already named and we want to rename the type function, we have to bind new alias to a copy + if (ttv->name) + { + // Copy can be skipped if this is an identical alias + if (!FFlag::LuauFixTableTypeAliasClone || ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams) + { + // This is a shallow clone, original recursive links to self are not updated + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; + + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + + clone.name = name; + clone.instantiatedTypeParams = binding->typeParams; + + ty = addType(std::move(clone)); + } + } + else + { + ttv->name = name; + ttv->instantiatedTypeParams = binding->typeParams; + } + } + else if (auto mtv = getMutable(follow(ty))) + mtv->syntheticName = name; + + unify(bindingsMap[name].type, ty, typealias.location); + } +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass) +{ + std::optional superTy = std::nullopt; + if (declaredClass.superName) + { + Name superName = Name(declaredClass.superName->value); + std::optional lookupType = scope->lookupType(superName); + + if (!lookupType) + { + reportError(declaredClass.location, UnknownSymbol{superName, UnknownSymbol::Type}); + return; + } + + // We don't have generic classes, so this assertion _should_ never be hit. + LUAU_ASSERT(lookupType->typeParams.size() == 0); + superTy = lookupType->type; + + if (FFlag::LuauAddMissingFollow) + { + if (!get(follow(*superTy))) + { + reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", + superName.c_str(), declaredClass.name.value)}); + + return; + } + } + else + { + if (const ClassTypeVar* superCtv = get(*superTy); !superCtv) + { + reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", + superName.c_str(), declaredClass.name.value)}); + + return; + } + } + } + + Name className(declaredClass.name.value); + + TypeId classTy = addType(ClassTypeVar(className, {}, superTy, std::nullopt, {}, {})); + ClassTypeVar* ctv = getMutable(classTy); + + TypeId metaTy = addType(TableTypeVar{TableState::Sealed, scope->level}); + TableTypeVar* metatable = getMutable(metaTy); + + ctv->metatable = metaTy; + + scope->exportedTypeBindings[className] = TypeFun{{}, classTy}; + + for (const AstDeclaredClassProp& prop : declaredClass.props) + { + Name propName(prop.name.value); + TypeId propTy = resolveType(scope, *prop.ty); + + bool assignToMetatable = isMetamethod(propName); + + // Function types always take 'self', but this isn't reflected in the + // parsed annotation. Add it here. + if (prop.isMethod) + { + if (FunctionTypeVar* ftv = getMutable(propTy)) + { + ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); + ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + } + } + + if (ctv->props.count(propName) == 0) + { + if (assignToMetatable) + metatable->props[propName] = {propTy}; + else + ctv->props[propName] = {propTy}; + } + else + { + TypeId currentTy = assignToMetatable ? metatable->props[propName].type : ctv->props[propName].type; + + // We special-case this logic to keep the intersection flat; otherwise we + // would create a ton of nested intersection types. + if (const IntersectionTypeVar* itv = get(currentTy)) + { + std::vector options = itv->parts; + options.push_back(propTy); + TypeId newItv = addType(IntersectionTypeVar{std::move(options)}); + + if (assignToMetatable) + metatable->props[propName] = {newItv}; + else + ctv->props[propName] = {newItv}; + } + else if (get(currentTy)) + { + TypeId intersection = addType(IntersectionTypeVar{{currentTy, propTy}}); + + if (assignToMetatable) + metatable->props[propName] = {intersection}; + else + ctv->props[propName] = {intersection}; + } + else + { + reportError(declaredClass.location, GenericError{format("Cannot overload non-function class member '%s'", propName.c_str())}); + } + } + } +} + +void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& global) +{ + ScopePtr funScope = childFunctionScope(scope, global.location); + + auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks); + + TypePackId argPack = resolveTypePack(funScope, global.params); + TypePackId retPack = resolveTypePack(funScope, global.retTypes); + TypeId fnType = addType(FunctionTypeVar{funScope->level, generics, genericPacks, argPack, retPack}); + FunctionTypeVar* ftv = getMutable(fnType); + + ftv->argNames.reserve(global.paramNames.size); + for (const auto& el : global.paramNames) + ftv->argNames.push_back(FunctionArgument{el.first.value, el.second}); + + Name fnName(global.name.value); + + currentModule->declaredGlobals[fnName] = fnType; + currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType) +{ + RecursionCounter _rc(&checkRecursionCount); + if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) + { + reportErrorCodeTooComplex(expr.location); + return {errorType}; + } + + ExprResult result; + + if (auto a = expr.as()) + result = checkExpr(scope, *a->expr); + else if (expr.is()) + result = {nilType}; + else if (expr.is()) + result = {booleanType}; + else if (expr.is()) + result = {numberType}; + else if (expr.is()) + result = {stringType}; + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a, expectedType); + else if (auto a = expr.as()) + result = checkExpr(scope, *a, expectedType); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + result = checkExpr(scope, *a); + else if (auto a = expr.as()) + { + if (FFlag::LuauIfElseExpressionAnalysisSupport) + { + result = checkExpr(scope, *a); + } + else + { + // Note: When the fast flag is disabled we can't skip the handling of AstExprIfElse + // because we would generate an ICE. We also can't use the default value + // of result, because it will lead to a compiler crash. + // Note: LuauIfElseExpressionBaseSupport can be used to disable parser support + // for if-else expressions which will mean this node type is never created. + result = {anyType}; + } + } + else + ice("Unhandled AstExpr?"); + + result.type = follow(result.type); + + if (FFlag::LuauStoreMatchingOverloadFnType) + { + currentModule->astTypes.try_emplace(&expr, result.type); + } + else + { + currentModule->astTypes[&expr] = result.type; + } + + if (expectedType) + currentModule->astExpectedTypes[&expr] = *expectedType; + + return result; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLocal& expr) +{ + std::optional lvalue = tryGetLValue(expr); + LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprLocal is an LValue. + + if (std::optional ty = resolveLValue(scope, *lvalue)) + return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + + // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint + // ice("AstExprLocal exists but no binding definition for it?", expr.location); + reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); + return {errorType}; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) +{ + std::optional lvalue = tryGetLValue(expr); + LUAU_ASSERT(lvalue); // Guaranteed to not be nullopt - AstExprGlobal is an LValue. + + if (std::optional ty = resolveLValue(scope, *lvalue)) + return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + + reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); + return {errorType}; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) +{ + TypePackId varargPack = checkExprPack(scope, expr).type; + + if (get(varargPack)) + { + std::vector types = flatten(varargPack).first; + return {!types.empty() ? types[0] : nilType}; + } + else if (auto ftp = get(varargPack)) + { + TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); + TypePackId tail = (FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, ftp->DEPRECATED_canBeGeneric)); + *asMutable(varargPack) = TypePack{{head}, tail}; + return {head}; + } + if (get(varargPack)) + return {errorType}; + else if (auto vtp = get(varargPack)) + return {vtp->ty}; + else if (FFlag::LuauGenericVariadicsUnification && get(varargPack)) + { + // TODO: Better error? + reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); + return {errorType}; + } + else + ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCall& expr) +{ + ExprResult result = checkExprPack(scope, expr); + TypePackId retPack = follow(result.type); + + if (auto pack = get(retPack)) + { + return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; + } + else if (auto ftp = get(retPack)) + { + TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); + TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); + unify(retPack, pack, expr.location); + return {head, std::move(result.predicates)}; + } + if (get(retPack)) + return {errorType, std::move(result.predicates)}; + else if (auto vtp = get(retPack)) + return {vtp->ty, std::move(result.predicates)}; + else if (get(retPack)) + ice("Unexpected abstract type pack!"); + else + ice("Unknown TypePack type!"); +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) +{ + Name name = expr.index.value; + + // Redundant call if we find a refined lvalue, but this function must be called in order to recursively populate astTypes. + TypeId lhsType = checkExpr(scope, *expr.expr).type; + + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional ty = resolveLValue(scope, *lvalue)) + return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + + if (FFlag::LuauExtraNilRecovery) + lhsType = stripFromNilAndReport(lhsType, expr.expr->location); + + if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) + return {*ty}; + + if (!FFlag::LuauMissingUnionPropertyError) + reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value}); + + if (!FFlag::LuauExtraNilRecovery) + { + // Try to recover using a union without 'nil' options + if (std::optional strippedUnion = tryStripUnionFromNil(lhsType)) + { + if (std::optional ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false)) + return {*ty}; + } + } + + return {errorType}; +} + +std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) +{ + ErrorVec errors; + auto result = Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); + reportErrors(errors); + return result; +} + +std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location) +{ + ErrorVec errors; + auto result = Luau::findMetatableEntry(errors, globalScope, type, entry, location); + reportErrors(errors); + return result; +} + +std::optional TypeChecker::getIndexTypeFromType( + const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) +{ + type = follow(type); + + if (get(type) || get(type)) + return type; + + tablify(type); + + const PrimitiveTypeVar* primitiveType = get(type); + if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) + { + if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) + type = *mtIndex; + } + + if (TableTypeVar* tableType = getMutableTableType(type)) + { + const auto& it = tableType->props.find(name); + if (it != tableType->props.end()) + return it->second.type; + else if (auto indexer = tableType->indexer) + { + tryUnify(indexer->indexType, stringType, location); + return indexer->indexResultType; + } + else if (tableType->state == TableState::Free) + { + TypeId result = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + tableType->props[name] = {result}; + return result; + } + + auto found = findTablePropertyRespectingMeta(type, name, location); + if (found) + return *found; + } + else if (const ClassTypeVar* cls = get(type)) + { + const Property* prop = lookupClassProp(cls, name); + if (prop) + return prop->type; + } + else if (const UnionTypeVar* utv = get(type)) + { + if (FFlag::LuauMissingUnionPropertyError) + { + std::vector goodOptions; + std::vector badOptions; + + for (TypeId t : utv) + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + goodOptions.push_back(*ty); + else + badOptions.push_back(t); + } + + if (!badOptions.empty()) + { + if (addErrors) + { + if (goodOptions.empty()) + reportError(location, UnknownProperty{type, name}); + else + reportError(location, MissingUnionProperty{type, badOptions, name}); + } + return std::nullopt; + } + + std::vector result = reduceUnion(goodOptions); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); + } + else + { + std::vector options; + + for (TypeId t : utv->options) + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + options.push_back(*ty); + else + return std::nullopt; + } + + std::vector result = reduceUnion(options); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); + } + } + else if (const IntersectionTypeVar* itv = get(type)) + { + std::vector parts; + + for (TypeId t : itv->parts) + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + parts.push_back(*ty); + } + + // If no parts of the intersection had the property we looked up for, it never existed at all. + if (parts.empty()) + { + if (FFlag::LuauMissingUnionPropertyError && addErrors) + reportError(location, UnknownProperty{type, name}); + return std::nullopt; + } + + // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 + std::vector result = reduceUnion(parts); + + if (result.size() == 1) + return result[0]; + + return addType(IntersectionTypeVar{result}); + } + + if (FFlag::LuauMissingUnionPropertyError && addErrors) + reportError(location, UnknownProperty{type, name}); + + return std::nullopt; +} + +std::vector TypeChecker::reduceUnion(const std::vector& types) +{ + std::set s; + + for (TypeId t : types) + { + if (const UnionTypeVar* utv = get(follow(t))) + { + std::vector r = reduceUnion(utv->options); + for (TypeId ty : r) + s.insert(ty); + } + else + s.insert(t); + } + + // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. + for (TypeId t : s) + { + t = follow(t); + if (get(t) || get(t)) + return {t}; + } + + std::vector r(s.begin(), s.end()); + std::sort(r.begin(), r.end()); + return r; +} + +std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) +{ + if (const UnionTypeVar* utv = get(ty)) + { + bool hasNil = false; + + for (TypeId option : utv) + { + if (isNil(option)) + { + hasNil = true; + break; + } + } + + if (!hasNil) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + } + + return std::nullopt; +} + +TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) +{ + if (isOptional(ty)) + { + if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) + { + reportError(location, OptionalValueAccess{ty}); + return follow(*strippedUnion); + } + } + + return ty; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) +{ + return {checkLValue(scope, expr)}; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) +{ + auto [funTy, funScope] = checkFunctionSignature(scope, 0, expr, std::nullopt, expectedType); + + checkFunctionBody(funScope, funTy, expr); + + return {quantify(funScope, funTy, expr.location)}; +} + +TypeId TypeChecker::checkExprTable( + const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType) +{ + TableTypeVar::Props props; + std::optional indexer; + + const TableTypeVar* expectedTable = nullptr; + + if (expectedType) + { + if (auto ttv = get(follow(*expectedType))) + { + if (ttv->state == TableState::Sealed) + expectedTable = ttv; + } + } + + for (size_t i = 0; i < expr.items.size; ++i) + { + const AstExprTable::Item& item = expr.items.data[i]; + + AstExpr* k = item.key; + AstExpr* value = item.value; + + auto [keyType, valueType] = fieldTypes[i]; + + if (item.kind == AstExprTable::Item::List) + { + if (expectedTable && !indexer) + indexer = expectedTable->indexer; + + if (indexer) + unify(indexer->indexResultType, valueType, value->location); + else + indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; + } + else if (item.kind == AstExprTable::Item::Record || item.kind == AstExprTable::Item::General) + { + if (auto key = k->as()) + { + TypeId exprType = follow(valueType); + if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) + exprType = anyType; + + props[key->value.data] = {exprType, /* deprecated */ false, {}, k->location}; + } + else + { + if (expectedTable && !indexer) + indexer = expectedTable->indexer; + + if (indexer) + { + unify(indexer->indexType, keyType, k->location); + unify(indexer->indexResultType, valueType, value->location); + } + else if (isNonstrictMode()) + { + indexer = TableIndexer{anyType, anyType}; + } + else + { + indexer = TableIndexer{keyType, valueType}; + } + } + } + } + + TableState state = (expr.items.size == 0 || isNonstrictMode()) ? TableState::Unsealed : TableState::Sealed; + TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; + table.definitionModuleName = currentModuleName; + return addType(table); +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTable& expr, std::optional expectedType) +{ + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + std::vector> fieldTypes(expr.items.size); + + const TableTypeVar* expectedTable = nullptr; + std::optional expectedIndexType; + std::optional expectedIndexResultType; + + if (expectedType) + { + if (auto ttv = get(follow(*expectedType))) + { + if (ttv->state == TableState::Sealed) + { + expectedTable = ttv; + + if (ttv->indexer) + { + expectedIndexType = ttv->indexer->indexType; + expectedIndexResultType = ttv->indexer->indexResultType; + } + } + } + } + + for (size_t i = 0; i < expr.items.size; ++i) + { + AstExprTable::Item& item = expr.items.data[i]; + std::optional expectedResultType; + bool isIndexedItem = false; + + if (item.kind == AstExprTable::Item::List) + { + expectedResultType = expectedIndexResultType; + isIndexedItem = true; + } + else if (item.kind == AstExprTable::Item::Record || item.kind == AstExprTable::Item::General) + { + if (auto key = item.key->as()) + { + if (expectedTable) + { + if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) + expectedResultType = prop->second.type; + } + } + else + { + expectedResultType = expectedIndexResultType; + isIndexedItem = true; + } + } + + fieldTypes[i].first = item.key ? checkExpr(scope, *item.key, expectedIndexType).type : nullptr; + fieldTypes[i].second = checkExpr(scope, *item.value, expectedResultType).type; + + // Indexer keys after the first are unified with the first one + // If we don't have an expected indexer type yet, take this first item type + if (isIndexedItem && !expectedIndexResultType) + expectedIndexResultType = fieldTypes[i].second; + } + + return {checkExprTable(scope, expr, fieldTypes, expectedType)}; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUnary& expr) +{ + ExprResult result = checkExpr(scope, *expr.expr); + TypeId operandType = follow(result.type); + + switch (expr.op) + { + case AstExprUnary::Not: + return {booleanType, {NotPredicate{std::move(result.predicates)}}}; + case AstExprUnary::Minus: + { + const bool operandIsAny = get(operandType) || get(operandType); + + if (operandIsAny) + return {operandType}; + + if (typeCouldHaveMetatable(operandType)) + { + if (auto fnt = findMetatableEntry(operandType, "__unm", expr.location)) + { + TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); + TypePackId arguments = addTypePack({operandType}); + TypePackId retType = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + + Unifier state = mkUnifier(expr.location); + state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + + if (!state.errors.empty()) + return {errorType}; + + return {first(retType).value_or(nilType)}; + } + + reportError(expr.location, + GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); + return {errorType}; + } + + reportErrors(tryUnify(numberType, operandType, expr.location)); + return {numberType}; + } + case AstExprUnary::Len: + tablify(operandType); + + if (FFlag::LuauExtraNilRecovery) + operandType = stripFromNilAndReport(operandType, expr.location); + + if (get(operandType)) + return {errorType}; + + if (get(operandType)) + return {numberType}; // Not strictly correct: metatables permit overriding this + + if (auto p = get(operandType)) + { + if (p->type == PrimitiveTypeVar::String) + return {numberType}; + } + + if (!getTableType(operandType)) + reportError(TypeError{expr.location, NotATable{operandType}}); + + return {numberType}; + + default: + ice("Unknown AstExprUnary " + std::to_string(int(expr.op))); + } +} + +std::string opToMetaTableEntry(const AstExprBinary::Op& op) +{ + switch (op) + { + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + return "__eq"; + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGe: + return "__lt"; + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGt: + return "__le"; + case AstExprBinary::Add: + return "__add"; + case AstExprBinary::Sub: + return "__sub"; + case AstExprBinary::Mul: + return "__mul"; + case AstExprBinary::Div: + return "__div"; + case AstExprBinary::Mod: + return "__mod"; + case AstExprBinary::Pow: + return "__pow"; + case AstExprBinary::Concat: + return "__concat"; + default: + return ""; + } +} + +TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes) +{ + if (unifyFreeTypes && (get(a) || get(b))) + { + if (unify(a, b, location)) + return a; + + return errorType; + } + + if (*a == *b) + return a; + + std::vector types = reduceUnion({a, b}); + if (types.size() == 1) + return types[0]; + + return addType(UnionTypeVar{types}); +} + +static std::optional getIdentifierOfBaseVar(AstExpr* node) +{ + if (AstExprGlobal* expr = node->as()) + return expr->name.value; + + if (AstExprLocal* expr = node->as()) + return expr->local->name.value; + + if (AstExprIndexExpr* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + if (AstExprIndexName* expr = node->as()) + return getIdentifierOfBaseVar(expr->expr); + + return std::nullopt; +} + +TypeId TypeChecker::checkRelationalOperation( + const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) +{ + auto stripNil = [this](TypeId ty, bool isOrOp = false) { + ty = follow(ty); + if (!isNonstrictMode() && !isOrOp) + return ty; + + if (auto i = get(ty)) + { + std::optional cleaned = tryStripUnionFromNil(ty); + + // If there is no union option without 'nil' + if (!cleaned) + return nilType; + + return follow(*cleaned); + } + + return follow(ty); + }; + + bool isEquality = expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe; + + lhsType = stripNil(lhsType, expr.op == AstExprBinary::Or); + rhsType = stripNil(rhsType); + + // If we know nothing at all about the lhs type, we can usually say nothing about the result. + // The notable exception to this is the equality and inequality operators, which always produce a boolean. + const bool lhsIsAny = get(lhsType) || get(lhsType); + + // Peephole check for `cond and a or b -> type(a)|type(b)` + // TODO: Kill this when singleton types arrive. :( + if (AstExprBinary* subexp = expr.left->as()) + { + if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) + { + if (FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) + { + ScopePtr subScope = childScope(scope, subexp->location); + reportErrors(resolve(predicates, subScope, true)); + return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); + } + else + { + return unionOfTypes(rhsType, checkExpr(scope, *subexp->right).type, expr.location); + } + } + } + + // Lua casts the results of these to boolean + switch (expr.op) + { + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + { + if (isNonstrictMode() && (isNil(lhsType) || isNil(rhsType))) + return booleanType; + + const bool rhsIsAny = get(rhsType) || get(rhsType); + if (lhsIsAny || rhsIsAny) + return booleanType; + + // Fallthrough here is intentional + } + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGt: + case AstExprBinary::CompareGe: + case AstExprBinary::CompareLe: + { + /* Subtlety here: + * We need to do this unification first, but there are situations where we don't actually want to + * report any problems that might have been surfaced as a result of this step because we might already + * have a better, more descriptive error teed up. + */ + Unifier state = mkUnifier(expr.location); + if (!FFlag::LuauEqConstraint || !isEquality) + state.tryUnify(lhsType, rhsType); + + bool needsMetamethod = !isEquality; + + TypeId leftType = follow(lhsType); + if (get(leftType) || get(leftType) || get(leftType) || get(leftType)) + { + reportErrors(state.errors); + + const PrimitiveTypeVar* ptv = get(leftType); + if (!isEquality && state.errors.empty() && (get(leftType) || (ptv && ptv->type == PrimitiveTypeVar::Boolean))) + reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), + toString(expr.op).c_str())}); + + return booleanType; + } + + std::string metamethodName = opToMetaTableEntry(expr.op); + + std::optional leftMetatable = + isString(lhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType); + std::optional rightMetatable = + isString(rhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(rhsType) : rhsType); + + if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) + { + reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", + toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); + return errorType; + } + + if (leftMetatable) + { + std::optional metamethod = findMetatableEntry(lhsType, metamethodName, expr.location); + if (metamethod) + { + if (const FunctionTypeVar* ftv = get(*metamethod)) + { + if (isEquality) + { + Unifier state = mkUnifier(expr.location); + state.tryUnify(ftv->retType, addTypePack({booleanType})); + + if (!state.errors.empty()) + { + reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); + return errorType; + } + } + } + + reportErrors(state.errors); + + TypeId actualFunctionType = addType(FunctionTypeVar(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); + state.tryUnify( + instantiate(scope, *metamethod, expr.location), instantiate(scope, actualFunctionType, expr.location), /*isFunctionCall*/ true); + + reportErrors(state.errors); + return booleanType; + } + else if (needsMetamethod) + { + reportError( + expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); + return errorType; + } + } + + if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && (!FFlag::LuauEqConstraint || !isEquality)) + { + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); + return errorType; + } + + if (needsMetamethod) + { + reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", + toString(lhsType).c_str(), toString(expr.op).c_str())}); + return errorType; + } + + if (!FFlag::LuauEqConstraint) + { + if (isEquality) + { + ErrorVec errVec = tryUnify(rhsType, lhsType, expr.location); + if (!state.errors.empty() && !errVec.empty()) + reportError(expr.location, TypeMismatch{lhsType, rhsType}); + } + else + reportErrors(state.errors); + } + + return booleanType; + } + + case AstExprBinary::And: + if (lhsIsAny) + return lhsType; + return unionOfTypes(rhsType, booleanType, expr.location, false); + case AstExprBinary::Or: + if (lhsIsAny) + return lhsType; + return unionOfTypes(lhsType, rhsType, expr.location); + default: + LUAU_ASSERT(0); + ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); + } +} + +TypeId TypeChecker::checkBinaryOperation( + const ScopePtr& scope, const AstExprBinary& expr, TypeId lhsType, TypeId rhsType, const PredicateVec& predicates) +{ + switch (expr.op) + { + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGt: + case AstExprBinary::CompareGe: + case AstExprBinary::CompareLe: + case AstExprBinary::And: + case AstExprBinary::Or: + return checkRelationalOperation(scope, expr, lhsType, rhsType, predicates); + default: + break; + } + + lhsType = follow(lhsType); + rhsType = follow(rhsType); + + if (!isNonstrictMode() && get(lhsType)) + { + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + return errorType; + } + + // If we know nothing at all about the lhs type, we can usually say nothing about the result. + // The notable exception to this is the equality and inequality operators, which always produce a boolean. + const bool lhsIsAny = get(lhsType) || get(lhsType); + const bool rhsIsAny = get(rhsType) || get(rhsType); + + if (lhsIsAny) + return lhsType; + if (rhsIsAny) + return rhsType; + + if (get(lhsType)) + { + // Inferring this accurately will get a bit weird. + // If the lhs type is not known, it could be assumed that it is a table or class that has a metatable + // that defines the required method, but we don't know which. + // For now, we'll give up and hope for the best. + return anyType; + } + + if (get(rhsType)) + unify(lhsType, rhsType, expr.location); + + if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) + { + auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { + TypeId actualFunctionType = instantiate(scope, fnt, expr.location); + TypePackId arguments = addTypePack({lhst, rhst}); + TypePackId retType = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + + Unifier state = mkUnifier(expr.location); + state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + + reportErrors(state.errors); + + if (!state.errors.empty()) + return errorType; + + return first(retType).value_or(nilType); + }; + + std::string op = opToMetaTableEntry(expr.op); + if (auto fnt = findMetatableEntry(lhsType, op, expr.location)) + return checkMetatableCall(*fnt, lhsType, rhsType); + if (auto fnt = findMetatableEntry(rhsType, op, expr.location)) + { + // Note the intentionally reversed arguments here. + return checkMetatableCall(*fnt, rhsType, lhsType); + } + + reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), + toString(lhsType).c_str(), toString(rhsType).c_str())}); + return errorType; + } + + switch (expr.op) + { + case AstExprBinary::Concat: + reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), lhsType, expr.left->location)); + reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), rhsType, expr.right->location)); + return stringType; + case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + reportErrors(tryUnify(numberType, lhsType, expr.left->location)); + reportErrors(tryUnify(numberType, rhsType, expr.right->location)); + return numberType; + default: + // These should have been handled with checkRelationalOperation + LUAU_ASSERT(0); + return anyType; + } +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBinary& expr) +{ + if (expr.op == AstExprBinary::And) + { + ExprResult lhs = checkExpr(scope, *expr.left); + + // We can't just report errors here. + // This function can be called from AstStatLocal or from AstStatIf, or even from AstExprBinary (and others). + // For now, ignore the errors returned by the predicate resolver. + // We may need an extra property for each predicate set that indicates it has been resolved. + // Requires a slight modification to the data structure. + ScopePtr innerScope = childScope(scope, expr.location); + resolve(lhs.predicates, innerScope, true); + + ExprResult rhs = checkExpr(innerScope, *expr.right); + if (!FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) + resolve(rhs.predicates, innerScope, true); + + return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + } + else if (FFlag::LuauOrPredicate && expr.op == AstExprBinary::Or) + { + ExprResult lhs = checkExpr(scope, *expr.left); + + ScopePtr innerScope = childScope(scope, expr.location); + resolve(lhs.predicates, innerScope, false); + + ExprResult rhs = checkExpr(innerScope, *expr.right); + + // Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation. + TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); + return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + } + else if (FFlag::LuauEqConstraint && (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)) + { + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; + + ExprResult lhs = checkExpr(scope, *expr.left); + ExprResult rhs = checkExpr(scope, *expr.right); + + PredicateVec predicates; + + if (auto lvalue = tryGetLValue(*expr.left)) + predicates.push_back(EqPredicate{std::move(*lvalue), rhs.type, expr.location}); + + if (auto lvalue = tryGetLValue(*expr.right)) + predicates.push_back(EqPredicate{std::move(*lvalue), lhs.type, expr.location}); + + if (!predicates.empty() && expr.op == AstExprBinary::CompareNe) + predicates = {NotPredicate{std::move(predicates)}}; + + return {checkBinaryOperation(scope, expr, lhs.type, rhs.type), std::move(predicates)}; + } + else + { + // Once we have EqPredicate, we should break this else branch into its' own branch. + // For now, fall through is intentional. + if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) + { + if (auto predicate = tryGetTypeGuardPredicate(expr)) + return {booleanType, {std::move(*predicate)}}; + } + + ExprResult lhs = checkExpr(scope, *expr.left); + ExprResult rhs = checkExpr(scope, *expr.right); + + // Intentionally discarding predicates with other operators. + return {checkBinaryOperation(scope, expr, lhs.type, rhs.type, lhs.predicates)}; + } +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) +{ + ExprResult result; + TypeId annotationType; + + if (FFlag::LuauInferReturnAssertAssign) + { + annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); + result = checkExpr(scope, *expr.expr, annotationType); + } + else + { + result = checkExpr(scope, *expr.expr); + annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); + } + + ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + if (!errorVec.empty()) + { + reportErrors(errorVec); + return {errorType, std::move(result.predicates)}; + } + + return {annotationType, std::move(result.predicates)}; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) +{ + const size_t oldSize = currentModule->errors.size(); + + for (AstExpr* expr : expr.expressions) + checkExpr(scope, *expr); + + // HACK: We want to check the contents of the AstExprError, but + // any type errors that may arise from it are going to be useless. + currentModule->errors.resize(oldSize); + + return {errorType}; +} + +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) +{ + ExprResult result = checkExpr(scope, *expr.condition); + ScopePtr trueScope = childScope(scope, expr.trueExpr->location); + reportErrors(resolve(result.predicates, trueScope, true)); + ExprResult trueType = checkExpr(trueScope, *expr.trueExpr); + + ScopePtr falseScope = childScope(scope, expr.falseExpr->location); + // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. + resolve(result.predicates, falseScope, false); + ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); + + unify(trueType.type, falseType.type, expr.location); + + // TODO: normalize(UnionTypeVar{{trueType, falseType}}) + // For now both trueType and falseType must be the same type. + return {trueType.type}; +} + +TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) +{ + auto [ty, binding] = checkLValueBinding(scope, expr); + return ty; +} + +std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) +{ + if (auto a = expr.as()) + return checkLValueBinding(scope, *a); + else if (auto a = expr.as()) + return checkLValueBinding(scope, *a); + else if (auto a = expr.as()) + return checkLValueBinding(scope, *a); + else if (auto a = expr.as()) + return checkLValueBinding(scope, *a); + else if (auto a = expr.as()) + { + for (AstExpr* expr : a->expressions) + checkExpr(scope, *expr); + return std::pair(errorType, nullptr); + } + else + ice("Unexpected AST node in checkLValue", expr.location); +} + +std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) +{ + if (std::optional ty = scope->lookup(expr.local)) + return {*ty, nullptr}; + + reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); + return {errorType, nullptr}; +} + +std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) +{ + Name name = expr.name.value; + ScopePtr moduleScope = currentModule->getModuleScope(); + + const auto it = moduleScope->bindings.find(expr.name); + + if (it != moduleScope->bindings.end()) + return std::pair(it->second.typeId, &it->second.typeId); + + if (isNonstrictMode() || FFlag::LuauSecondTypecheckKnowsTheDataModel) + { + TypeId result = (FFlag::LuauGenericFunctions && FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(moduleScope, true)); + + Binding& binding = moduleScope->bindings[expr.name]; + binding = {result, expr.location}; + + // If we're in strict mode, we want to report defining a global as an error, + // but still add it to the bindings, so that autocomplete includes it in completions. + if (!isNonstrictMode()) + reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); + + return std::pair(result, &binding.typeId); + } + + reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); + return std::pair(errorType, nullptr); +} + +std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) +{ + TypeId lhs = checkExpr(scope, *expr.expr).type; + + if (get(lhs) || get(lhs)) + return std::pair(lhs, nullptr); + + tablify(lhs); + + Name name = expr.index.value; + + if (FFlag::LuauExtraNilRecovery) + lhs = stripFromNilAndReport(lhs, expr.expr->location); + + if (TableTypeVar* lhsTable = getMutableTableType(lhs)) + { + const auto& it = lhsTable->props.find(name); + if (it != lhsTable->props.end()) + { + return std::pair(it->second.type, &it->second.type); + } + else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) + { + TypeId theType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + Property& property = lhsTable->props[name]; + property.type = theType; + property.location = expr.indexLocation; + return std::pair(theType, &property.type); + } + else if (auto indexer = lhsTable->indexer) + { + Unifier state = mkUnifier(expr.location); + state.tryUnify(indexer->indexType, stringType); + if (!state.errors.empty()) + { + state.log.rollback(); + reportError(expr.location, UnknownProperty{lhs, name}); + return std::pair(errorType, nullptr); + } + + return std::pair(indexer->indexResultType, nullptr); + } + else if (lhsTable->state == TableState::Sealed) + { + reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); + return std::pair(errorType, nullptr); + } + else + { + reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); + return std::pair(errorType, nullptr); + } + } + else if (const ClassTypeVar* lhsClass = get(lhs)) + { + const Property* prop = lookupClassProp(lhsClass, name); + if (!prop) + { + reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); + return std::pair(errorType, nullptr); + } + + return std::pair(prop->type, nullptr); + } + else if (get(lhs)) + { + if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) + return std::pair(*ty, nullptr); + + // If intersection has a table part, report that it cannot be extended just as a sealed table + if (isTableIntersection(lhs)) + { + reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); + return std::pair(errorType, nullptr); + } + } + + reportError(TypeError{expr.location, NotATable{lhs}}); + return std::pair(errorType, nullptr); +} + +std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) +{ + TypeId exprType = checkExpr(scope, *expr.expr).type; + tablify(exprType); + + if (FFlag::LuauExtraNilRecovery) + exprType = stripFromNilAndReport(exprType, expr.expr->location); + + TypeId indexType = checkExpr(scope, *expr.index).type; + + if (get(exprType) || get(exprType)) + return std::pair(exprType, nullptr); + + AstExprConstantString* value = expr.index->as(); + + if (value && FFlag::LuauClassPropertyAccessAsString) + { + if (const ClassTypeVar* exprClass = get(exprType)) + { + const Property* prop = lookupClassProp(exprClass, value->value.data); + if (!prop) + { + reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); + return std::pair(errorType, nullptr); + } + return std::pair(prop->type, nullptr); + } + } + + TableTypeVar* exprTable = getMutableTableType(exprType); + + if (!exprTable) + { + if (FFlag::LuauExtraNilRecovery) + reportError(TypeError{expr.expr->location, NotATable{exprType}}); + else + reportError(TypeError{expr.location, NotATable{exprType}}); + return std::pair(errorType, nullptr); + } + + if (value) + { + const auto& it = exprTable->props.find(value->value.data); + if (it != exprTable->props.end()) + { + return std::pair(it->second.type, &it->second.type); + } + else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) + { + TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + Property& property = exprTable->props[value->value.data]; + property.type = resultType; + property.location = expr.index->location; + return std::pair(resultType, &property.type); + } + } + + if (exprTable->indexer) + { + const TableIndexer& indexer = *exprTable->indexer; + unify(indexer.indexType, indexType, expr.index->location); + return std::pair(indexer.indexResultType, nullptr); + } + else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) + { + TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; + return std::pair(resultType, nullptr); + } + else if (FFlag::LuauIndexTablesWithIndexers) + { + // We allow t[x] where x:string for tables without an indexer + unify(indexType, stringType, expr.location); + return std::pair(anyType, nullptr); + } + else + { + TypeId resultType = freshType(scope); + return std::pair(resultType, nullptr); + } +} + +// Answers the question: "Can I define another function with this name?" +// Primarily about detecting duplicates. +TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) +{ + if (auto globalName = funName.as()) + { + const ScopePtr& globalScope = currentModule->getModuleScope(); + Symbol name = globalName->name; + if (globalScope->bindings.count(name)) + { + if (isNonstrictMode()) + return globalScope->bindings[name].typeId; + + return errorType; + } + else + { + TypeId ty = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + globalScope->bindings[name] = {ty, funName.location}; + return ty; + } + } + else if (auto localName = funName.as()) + { + Symbol name = localName->local; + Binding& binding = scope->bindings[name]; + if (binding.typeId == nullptr) + binding = {(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)), funName.location}; + + return binding.typeId; + } + else if (auto indexName = funName.as()) + { + TypeId lhsType = checkExpr(scope, *indexName->expr).type; + if (get(lhsType) || get(lhsType)) + return lhsType; + + TableTypeVar* ttv = getMutableTableType(lhsType); + if (!ttv) + { + if (!isTableIntersection(lhsType)) + reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); + + return errorType; + } + + // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check + if (lhsType->persistent || ttv->state == TableState::Sealed) + return errorType; + + Name name = indexName->index.value; + + if (ttv->props.count(name)) + return errorType; + + Property& property = ttv->props[name]; + + property.type = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + property.location = indexName->indexLocation; + ttv->methodDefinitionLocations[name] = funName.location; + return property.type; + } + else if (funName.is()) + { + return errorType; + } + else + { + ice("Unexpected AST node type", funName.location); + } +} + +// This returns a pair `[funType, funScope]` where +// - funType is the prototype type of the function +// - funScope is the scope for the function, which is a child scope with bindings added for +// parameters (and generic types if there were explicit generic annotations). +// +// The function type is a prototype, in that it may be missing some generic types which +// can only be inferred from type inference after typechecking the function body. +// For example the function `function id(x) return x end` has prototype +// `(X) -> Y...`, but after typechecking the body, we cam unify `Y...` with `X` +// to get type `(X) -> X`, then we quantify the free types to get the final +// generic type `(a) -> a`. +std::pair TypeChecker::checkFunctionSignature( + const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalName, std::optional expectedType) +{ + ScopePtr funScope = childFunctionScope(scope, expr.location, subLevel); + + const FunctionTypeVar* expectedFunctionType = nullptr; + + if (expectedType) + { + LUAU_ASSERT(!expr.self); + + if (auto ftv = get(follow(*expectedType))) + { + expectedFunctionType = ftv; + } + else if (auto utv = get(follow(*expectedType))) + { + // Look for function type in a union. Other types can be ignored since current expression is a function + for (auto option : utv) + { + if (auto ftv = get(follow(option))) + { + if (!expectedFunctionType) + { + expectedFunctionType = ftv; + } + else + { + // Do not infer argument types when multiple overloads are expected + expectedFunctionType = nullptr; + break; + } + } + } + } + + // We do not infer type binders, so if a generic function is required we do not propagate + if (expectedFunctionType && !(expectedFunctionType->generics.empty() && expectedFunctionType->genericPacks.empty())) + expectedFunctionType = nullptr; + } + + std::vector generics; + std::vector genericPacks; + + if (FFlag::LuauGenericFunctions) + { + std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks); + } + + TypePackId retPack; + if (expr.hasReturnAnnotation) + { + if (FFlag::LuauGenericFunctions) + retPack = resolveTypePack(funScope, expr.returnAnnotation); + else + retPack = resolveTypePack(scope, expr.returnAnnotation); + } + else if (isNonstrictMode()) + retPack = anyTypePack; + else if (expectedFunctionType) + { + auto [head, tail] = flatten(expectedFunctionType->retType); + + // Do not infer 'nil' as function return type + if (!tail && head.size() == 1 && isNil(head[0])) + retPack = FFlag::LuauGenericFunctions ? freshTypePack(funScope) : freshTypePack(scope); + else + retPack = addTypePack(head, tail); + } + else if (FFlag::LuauGenericFunctions) + retPack = freshTypePack(funScope); + else + retPack = freshTypePack(scope); + + if (expr.vararg) + { + if (expr.varargAnnotation) + { + if (FFlag::LuauGenericFunctions) + funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); + else + funScope->varargPack = resolveTypePack(scope, *expr.varargAnnotation); + } + else + { + if (expectedFunctionType && !isNonstrictMode()) + { + auto [head, tail] = flatten(expectedFunctionType->argTypes); + + if (expr.args.size <= head.size()) + { + head.erase(head.begin(), head.begin() + expr.args.size); + + funScope->varargPack = addTypePack(head, tail); + } + else if (tail) + { + if (get(follow(*tail))) + funScope->varargPack = addTypePack({}, tail); + } + else + { + funScope->varargPack = addTypePack({}); + } + } + + // TODO: should this be a free type pack? CLI-39910 + if (!funScope->varargPack) + funScope->varargPack = anyTypePack; + } + } + + std::vector argTypes; + + funScope->returnType = retPack; + + if (expr.self) + { + // TODO: generic self types: CLI-39906 + TypeId selfType = anyIfNonstrict(freshType(funScope)); + funScope->bindings[expr.self] = {selfType, expr.self->location}; + argTypes.push_back(selfType); + } + + // Prepare expected argument type iterators if we have an expected function type + TypePackIterator expectedArgsCurr, expectedArgsEnd; + + if (expectedFunctionType && !isNonstrictMode()) + { + expectedArgsCurr = begin(expectedFunctionType->argTypes); + expectedArgsEnd = end(expectedFunctionType->argTypes); + } + + for (AstLocal* local : expr.args) + { + TypeId argType = nullptr; + + if (local->annotation) + { + argType = resolveType((FFlag::LuauGenericFunctions ? funScope : scope), *local->annotation); + + // If the annotation type has an error, treat it as if there was no annotation + if (get(follow(argType))) + argType = anyIfNonstrict(freshType(funScope)); + } + else + { + if (expectedFunctionType && !isNonstrictMode()) + { + if (expectedArgsCurr != expectedArgsEnd) + { + argType = *expectedArgsCurr; + + if (!FFlag::LuauInferFunctionArgsFix) + ++expectedArgsCurr; + } + else if (auto expectedArgsTail = expectedArgsCurr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) + argType = vtp->ty; + } + } + + if (!argType) + argType = anyIfNonstrict(freshType(funScope)); + } + + funScope->bindings[local] = {argType, local->location}; + argTypes.push_back(argType); + + if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd) + ++expectedArgsCurr; + } + + TypePackId argPack = addTypePack(TypePackVar(TypePack{argTypes, funScope->varargPack})); + + FunctionDefinition defn; + defn.definitionModuleName = currentModuleName; + defn.definitionLocation = expr.location; + defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; + defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); + + TypeId funTy = addType(FunctionTypeVar(funScope->level, generics, genericPacks, argPack, retPack, std::move(defn), bool(expr.self))); + + FunctionTypeVar* ftv = getMutable(funTy); + + ftv->argNames.reserve(expr.args.size + (expr.self ? 1 : 0)); + + if (expr.self) + ftv->argNames.push_back(FunctionArgument{"self", {}}); + + for (AstLocal* local : expr.args) + ftv->argNames.push_back(FunctionArgument{local->name.value, local->location}); + + return std::make_pair(funTy, funScope); +} + +static bool allowsNoReturnValues(const TypePackId tp) +{ + for (TypeId ty : tp) + { + if (!get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + { + return false; + } + } + + return true; +} + +static Location getEndLocation(const AstExprFunction& function) +{ + Location loc = function.location; + if (loc.begin.line != loc.end.line) + { + Position begin = loc.end; + begin.column = std::max(0u, begin.column - 3); + loc = Location(begin, 3); + } + + return loc; +} + +void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstExprFunction& function) +{ + if (FunctionTypeVar* funTy = getMutable(ty)) + { + check(scope, *function.body); + + // We explicitly don't follow here to check if we have a 'true' free type instead of bound one + if (FFlag::LuauAddMissingFollow ? get_if(&funTy->retType->ty) : get(funTy->retType)) + *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; + + bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; + + if (reachesImplicitReturn && !allowsNoReturnValues(follow(funTy->retType))) + { + // If we're in nonstrict mode we want to only report this missing return + // statement if there are type annotations on the function. In strict mode + // we report it regardless. + if (!isNonstrictMode() || function.hasReturnAnnotation) + { + reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); + } + } + } + else + ice("Checking non functional type"); +} + +ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExpr& expr) +{ + if (auto a = expr.as()) + return checkExprPack(scope, *a); + else if (expr.is()) + { + if (!scope->varargPack) + return {addTypePack({addType(ErrorTypeVar())})}; + + return {*scope->varargPack}; + } + else + { + TypeId type = checkExpr(scope, expr).type; + return {addTypePack({type})}; + } +} + +// Returns the minimum number of arguments the argument list can accept. +static size_t getMinParameterCount(TypePackId tp) +{ + size_t minCount = 0; + size_t optionalCount = 0; + + auto it = begin(tp); + auto endIter = end(tp); + + while (it != endIter) + { + TypeId ty = *it; + if (isOptional(ty)) + ++optionalCount; + else + { + minCount += optionalCount; + optionalCount = 0; + minCount++; + } + + ++it; + } + + return minCount; +} + +void TypeChecker::checkArgumentList( + const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) +{ + /* Important terminology refresher: + * A function requires paramaters. + * To call a function, you supply arguments. + */ + TypePackIterator argIter = begin(argPack); + TypePackIterator paramIter = begin(paramPack); + TypePackIterator endIter = end(argPack); // Important subtlety: All end TypePackIterators are equivalent + + size_t paramIndex = 0; + + size_t minParams = getMinParameterCount(paramPack); + + while (true) + { + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; + + if (argIter == endIter && paramIter == endIter) + { + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); + + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) + { + if (get(*argTail)) + { + if (paramTail) + state.tryUnify(*argTail, *paramTail); + else + { + state.log(*argTail); + *asMutable(*argTail) = TypePack{{}}; + } + } + } + else if (paramTail) + { + // argTail is definitely empty + if (get(*paramTail)) + { + state.log(*paramTail); + *asMutable(*paramTail) = TypePack{{}}; + } + } + + return; + } + else if (argIter == endIter) + { + // Not enough arguments. + + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) + { + TypePackId tail = *argIter.tail(); + if (get(tail)) + { + // Unify remaining parameters so we don't leave any free-types hanging around. + TypeId argTy = errorType; + while (paramIter != endIter) + { + state.tryUnify(*paramIter, argTy); + ++paramIter; + } + return; + } + else if (auto vtp = get(tail)) + { + while (paramIter != endIter) + { + state.tryUnify(*paramIter, vtp->ty); + ++paramIter; + } + + return; + } + else if (get(tail)) + { + std::vector rest; + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) + { + rest.push_back(*paramIter); + ++paramIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + state.tryUnify(tail, varPack); + return; + } + } + + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (get(t)) + { + } // ok + else if (isNonstrictMode() && get(t)) + { + } // ok + else + { + state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } + ++paramIter; + } + } + else if (paramIter == endIter) + { + // too many parameters passed + if (!paramIter.tail()) + { + while (argIter != endIter) + { + unify(*argIter, errorType, state.location); + ++argIter; + } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + TypePackId tail = *paramIter.tail(); + + if (get(tail)) + { + // Function is variadic. Ok. + return; + } + else if (auto vtp = get(tail)) + { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) + { + Location location = state.location; + + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(vtp->ty, *argIter, location); + ++argIter; + ++argIndex; + } + + return; + } + else if (FFlag::LuauGenericVariadicsUnification && get(tail)) + { + // Create a type pack out of the remaining argument types + // and unify it with the tail. + std::vector rest; + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) + { + rest.push_back(*argIter); + ++argIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(tail, varPack); + return; + } + else if (get(tail)) + { + state.log(tail); + *asMutable(tail) = TypePack{}; + + return; + } + else if (FFlag::LuauRankNTypes && get(tail)) + { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + } + else + { + if (FFlag::LuauRankNTypes) + unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); + else + state.tryUnify(*paramIter, *argIter, /*isFunctionCall*/ false); + ++argIter; + ++paramIter; + } + + ++paramIndex; + } +} + +ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const AstExprCall& expr) +{ + // evaluate type of function + // decompose an intersection into its component overloads + // Compute types of parameters + // For each overload + // Compare parameter and argument types + // Report any errors (also speculate dot vs colon warnings!) + // If there are no errors, return the resulting return type + + TypeId selfType = nullptr; + TypeId functionType = nullptr; + TypeId actualFunctionType = nullptr; + + if (expr.self) + { + AstExprIndexName* indexExpr = expr.func->as(); + if (!indexExpr) + ice("method call expression has no 'self'"); + + selfType = checkExpr(scope, *indexExpr->expr).type; + if (!FFlag::LuauRankNTypes) + instantiate(scope, selfType, expr.func->location); + + if (FFlag::LuauExtraNilRecovery) + selfType = stripFromNilAndReport(selfType, expr.func->location); + + if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) + { + functionType = *propTy; + actualFunctionType = instantiate(scope, functionType, expr.func->location); + } + else + { + if (!FFlag::LuauMissingUnionPropertyError) + reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value}); + + if (!FFlag::LuauExtraNilRecovery) + { + // Try to recover using a union without 'nil' options + if (std::optional strippedUnion = tryStripUnionFromNil(selfType)) + { + if (std::optional propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false)) + { + selfType = *strippedUnion; + + functionType = *propTy; + actualFunctionType = instantiate(scope, functionType, expr.func->location); + } + } + + if (!actualFunctionType) + { + functionType = errorType; + actualFunctionType = errorType; + } + } + else + { + functionType = errorType; + actualFunctionType = errorType; + } + } + } + else + { + functionType = checkExpr(scope, *expr.func).type; + actualFunctionType = instantiate(scope, functionType, expr.func->location); + } + + actualFunctionType = follow(actualFunctionType); + + // checkExpr will log the pre-instantiated type of the function. + // That's not nearly as interesting as the instantiated type, which will include details about how + // generic functions are being instantiated for this particular callsite. + currentModule->astOriginalCallTypes[expr.func] = follow(functionType); + currentModule->astTypes[expr.func] = actualFunctionType; + + std::vector overloads = flattenIntersection(actualFunctionType); + + TypePackId retPack = freshTypePack(scope->level); + + std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); + + ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); + TypePackId argList = argListResult.type; + TypePackId argPack = (FFlag::LuauRankNTypes ? argList : DEPRECATED_instantiate(scope, argList, expr.location)); + + if (get(argPack)) + return ExprResult{errorTypePack}; + + TypePack* args = getMutable(argPack); + LUAU_ASSERT(args != nullptr); + + if (expr.self) + args->head.insert(args->head.begin(), selfType); + + std::vector argLocations; + argLocations.reserve(expr.args.size + 1); + if (expr.self) + argLocations.push_back(expr.func->as()->expr->location); + for (AstExpr* arg : expr.args) + argLocations.push_back(arg->location); + + std::vector errors; // errors encountered for each overload + + std::vector overloadsThatMatchArgCount; + + for (TypeId fn : overloads) + { + fn = follow(fn); + + if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) + return *ret; + } + + if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) + return {retPack}; + + return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); +} + +std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) +{ + std::vector> expectedTypes; + + auto assignOption = [this, &expectedTypes](size_t index, std::optional ty) { + if (index == expectedTypes.size()) + { + expectedTypes.push_back(ty); + } + else if (ty) + { + auto& el = expectedTypes[index]; + + if (!el) + { + el = ty; + } + else + { + std::vector result = reduceUnion({*el, *ty}); + el = result.size() == 1 ? result[0] : addType(UnionTypeVar{std::move(result)}); + } + } + }; + + for (const TypeId overload : overloads) + { + if (const FunctionTypeVar* ftv = get(overload)) + { + auto [argsHead, argsTail] = flatten(ftv->argTypes); + + size_t start = selfCall ? 1 : 0; + size_t index = 0; + + for (size_t i = start; i < argsHead.size(); ++i) + assignOption(index++, argsHead[i]); + + if (argsTail) + { + if (const VariadicTypePack* vtp = get(follow(*argsTail))) + { + while (index < argumentCount) + assignOption(index++, vtp->ty); + } + } + } + } + + return expectedTypes; +} + +std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, + TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + std::vector& overloadsThatMatchArgCount, std::vector& errors) +{ + if (FFlag::LuauExtraNilRecovery) + fn = stripFromNilAndReport(fn, expr.func->location); + + if (get(fn)) + { + unify(argPack, anyTypePack, expr.location); + return {{anyTypePack}}; + } + + if (get(fn)) + { + return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; + } + + if (get(fn)) + { + // fn is one of the overloads of actualFunctionType, which + // has been instantiated, so is a monotype. We can therefore + // unify it with a monomorphic function. + TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); + unify(r, fn, expr.location); + return {{retPack}}; + } + + const FunctionTypeVar* ftv = get(fn); + if (!ftv) + { + // Might be a callable table + if (const MetatableTypeVar* mttv = get(fn)) + { + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) + { + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + + std::vector metaArgLocations = argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + + TypeId fn = *ty; + if (FFlag::LuauRankNTypes) + fn = instantiate(scope, fn, expr.func->location); + + return checkCallOverload( + scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); + } + } + + reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); + unify(retPack, errorTypePack, expr.func->location); + return {{errorTypePack}}; + } + + // When this function type has magic functions and did return something, we select that overload instead. + // TODO: pass in a Unifier object to the magic functions? This will allow the magic functions to cooperate with overload resolution. + if (ftv->magicFunction) + { + // TODO: We're passing in the wrong TypePackId. Should be argPack, but a unit test fails otherwise. CLI-40458 + if (std::optional> ret = ftv->magicFunction(*this, scope, expr, argListResult)) + return *ret; + } + + Unifier state = mkUnifier(expr.location); + + // Unify return types + checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); + if (!state.errors.empty()) + { + state.log.rollback(); + return {}; + } + + checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + + if (!state.errors.empty()) + { + bool argMismatch = false; + for (auto error : state.errors) + { + CountMismatch* cm = get(error); + if (!cm) + continue; + + if (cm->context == CountMismatch::Arg) + { + argMismatch = true; + break; + } + } + + if (!argMismatch) + overloadsThatMatchArgCount.push_back(fn); + + errors.emplace_back(std::move(state.errors), args->head, ftv); + state.log.rollback(); + } + else + { + if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) + { + // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND + // the function is declared with colon notation AND we use dot notation, warn. + auto [providedArgs, providedTail] = flatten(argPack); + + // If we have a variadic tail, we can't say how many arguments were actually provided + if (!providedTail) + { + std::vector actualArgs = flatten(ftv->argTypes).first; + + size_t providedCount = providedArgs.size(); + size_t requiredCount = actualArgs.size(); + + // Ignore optional arguments + while (providedCount < requiredCount && requiredCount != 0 && isOptional(actualArgs[requiredCount - 1])) + requiredCount--; + + if (providedCount < requiredCount) + { + int requiredExtraNils = int(requiredCount - providedCount); + reportError(TypeError{expr.func->location, FunctionRequiresSelf{requiredExtraNils}}); + } + } + } + + if (FFlag::LuauStoreMatchingOverloadFnType) + currentModule->astOverloadResolvedTypes[&expr] = fn; + + // We select this overload + return {{retPack}}; + } + + return {}; +} + +bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, + const std::vector& errors) +{ + // No overloads succeeded: Scan for one that would have worked had the user + // used a.b() rather than a:b() or vice versa. + for (const auto& [_, argVec, ftv] : errors) + { + // Did you write foo:bar() when you should have written foo.bar()? + if (expr.self) + { + std::vector editedArgLocations(argLocations.begin() + 1, argLocations.end()); + + std::vector editedParamList(args->head.begin() + 1, args->head.end()); + TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); + + Unifier editedState = mkUnifier(expr.location); + checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + + if (editedState.errors.empty()) + { + reportError(TypeError{expr.location, FunctionDoesNotTakeSelf{}}); + // This is a little bit suspect: If this overload would work with a . replaced by a : + // we eagerly assume that that's what you actually meant and we commit to it. + // This could be incorrect if the function has an additional overload that + // actually works. + // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + return true; + } + else + editedState.log.rollback(); + } + else if (ftv->hasSelf) + { + // Did you write foo.bar() when you should have written foo:bar()? + if (AstExprIndexName* indexName = expr.func->as()) + { + std::vector editedArgLocations; + editedArgLocations.reserve(argLocations.size() + 1); + editedArgLocations.push_back(indexName->expr->location); + editedArgLocations.insert(editedArgLocations.end(), argLocations.begin(), argLocations.end()); + + std::vector editedArgList(args->head); + editedArgList.insert(editedArgList.begin(), checkExpr(scope, *indexName->expr).type); + TypePackId editedArgPack = addTypePack(TypePack{editedArgList}); + + Unifier editedState = mkUnifier(expr.location); + + checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); + + if (editedState.errors.empty()) + { + reportError(TypeError{expr.location, FunctionRequiresSelf{}}); + // This is a little bit suspect: If this overload would work with a : replaced by a . + // we eagerly assume that that's what you actually meant and we commit to it. + // This could be incorrect if the function has an additional overload that + // actually works. + // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); + return true; + } + else + editedState.log.rollback(); + } + } + } + + return false; +} + +ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, + TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, + const std::vector& overloadsThatMatchArgCount, const std::vector& errors) +{ + if (overloads.size() == 1) + { + reportErrors(std::get<0>(errors.front())); + return {errorTypePack}; + } + + std::vector overloadTypes = overloadsThatMatchArgCount; + if (overloadsThatMatchArgCount.size() == 0) + { + reportError(TypeError{expr.location, GenericError{"No overload for function accepts " + std::to_string(size(argPack)) + " arguments."}}); + // If no overloads match argument count, just list all overloads. + overloadTypes = overloads; + } + else + { + // Report errors of the first argument-count-matching, but failing overload + TypeId overload = overloadsThatMatchArgCount[0]; + + // Remove the overload we are reporting errors about, from the list of alternative + overloadTypes.erase(std::remove(overloadTypes.begin(), overloadTypes.end(), overload), overloadTypes.end()); + + const FunctionTypeVar* ftv = get(overload); + + auto error = std::find_if(errors.begin(), errors.end(), [ftv](const OverloadErrorEntry& e) { + return ftv == std::get<2>(e); + }); + + LUAU_ASSERT(error != errors.end()); + reportErrors(std::get<0>(*error)); + + // If only one overload matched, we don't need this error because we provided the previous errors. + if (overloadsThatMatchArgCount.size() == 1) + return {errorTypePack}; + } + + std::string s; + for (size_t i = 0; i < overloadTypes.size(); ++i) + { + TypeId overload = overloadTypes[i]; + Unifier state = mkUnifier(expr.location); + + // Unify return types + if (const FunctionTypeVar* ftv = get(overload)) + { + checkArgumentList(scope, state, retPack, ftv->retType, {}); + checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + } + + if (i > 0) + s += "; "; + + if (i > 0 && i == overloadTypes.size() - 1) + s += "and "; + + s += toString(overload); + + state.log.rollback(); + } + + if (overloadsThatMatchArgCount.size() == 0) + reportError(expr.func->location, ExtraInformation{"Available overloads: " + s}); + else + reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); + + // No viable overload + return {errorTypePack}; +} + +ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, + bool substituteFreeForNil, const std::vector& instantiateGenerics, const std::vector>& expectedTypes) +{ + TypePackId pack = addTypePack(TypePack{}); + PredicateVec predicates; // At the moment we will be pushing all predicate sets into this. Do we need some way to split them up? + + auto insert = [&predicates](PredicateVec& vec) { + for (Predicate& c : vec) + predicates.push_back(std::move(c)); + }; + + if (exprs.size == 0) + return {pack}; + + TypePack* tp = getMutable(pack); + + size_t lastIndex = exprs.size - 1; + tp->head.reserve(lastIndex); + + Unifier state = mkUnifier(location); + + for (size_t i = 0; i < exprs.size; ++i) + { + AstExpr* expr = exprs.data[i]; + + if (i == lastIndex && (expr->is() || expr->is())) + { + auto [typePack, exprPredicates] = checkExprPack(scope, *expr); + insert(exprPredicates); + + tp->tail = typePack; + } + else + { + std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; + auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); + insert(exprPredicates); + + TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; + + if (instantiateGenerics.size() > i && instantiateGenerics[i] && (FFlag::LuauGenericFunctions || get(actualType))) + actualType = instantiate(scope, actualType, expr->location); + + if (expectedType) + state.tryUnify(*expectedType, actualType); + + tp->head.push_back(actualType); + } + } + + state.log.rollback(); + + return {pack, predicates}; +} + +std::optional TypeChecker::matchRequire(const AstExprCall& call) +{ + const char* require = "require"; + + if (call.args.size != 1) + return std::nullopt; + + const AstExprGlobal* funcAsGlobal = call.func->as(); + if (!funcAsGlobal || funcAsGlobal->name != require) + return std::nullopt; + + if (call.args.size != 1) + return std::nullopt; + + return call.args.data[0]; +} + +TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location) +{ + ModulePtr module = resolver->getModule(moduleInfo.name); + if (!module) + { + // There are two reasons why we might fail to find the module: + // either the file does not exist or there's a cycle. If there's a cycle + // we will already have reported the error. + if (!resolver->moduleExists(moduleInfo.name) && (FFlag::LuauTraceRequireLookupChild ? !moduleInfo.optional : true)) + { + std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(TypeError{location, UnknownRequire{reportedModulePath}}); + } + + return errorType; + } + + if (module->type != SourceCode::Module) + { + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); + return errorType; + } + + std::optional moduleType = first(module->getModuleScope()->returnType); + if (!moduleType) + { + std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); + reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); + return errorType; + } + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks); +} + +void TypeChecker::tablify(TypeId type) +{ + type = follow(type); + + if (auto f = get(type)) + *asMutable(type) = TableTypeVar{TableState::Free, f->level}; +} + +TypeId TypeChecker::anyIfNonstrict(TypeId ty) const +{ + if (isNonstrictMode()) + return anyType; + else + return ty; +} + +bool TypeChecker::unify(TypeId left, TypeId right, const Location& location) +{ + Unifier state = mkUnifier(location); + state.tryUnify(left, right); + + reportErrors(state.errors); + + return state.errors.empty(); +} + +bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx) +{ + Unifier state = mkUnifier(location); + state.ctx = ctx; + state.tryUnify(left, right); + + reportErrors(state.errors); + + return state.errors.empty(); +} + +bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) +{ + LUAU_ASSERT(FFlag::LuauRankNTypes); + Unifier state = mkUnifier(location); + unifyWithInstantiationIfNeeded(scope, left, right, state); + + reportErrors(state.errors); + + return state.errors.empty(); +} + +void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) +{ + LUAU_ASSERT(FFlag::LuauRankNTypes); + if (!maybeGeneric(right)) + // Quick check to see if we definitely can't instantiate + state.tryUnify(left, right, /*isFunctionCall*/ false); + else if (!maybeGeneric(left) && isGeneric(right)) + { + // Quick check to see if we definitely have to instantiate + TypeId instantiated = instantiate(scope, right, state.location); + state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + } + else + { + // First try unifying with the original uninstantiated type + // but if that fails, try the instantiated one. + Unifier child = state.makeChildUnifier(); + child.tryUnify(left, right, /*isFunctionCall*/ false); + if (!child.errors.empty()) + { + TypeId instantiated = instantiate(scope, right, state.location); + if (right == instantiated) + { + // Instantiating the argument made no difference, so just report any child errors + state.log.concat(std::move(child.log)); + state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); + } + else + { + child.log.rollback(); + state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + } + } + else + { + state.log.concat(std::move(child.log)); + } + } +} + +bool Instantiation::isDirty(TypeId ty) +{ + if (FFlag::LuauRankNTypes) + { + if (get(ty)) + return true; + else + return false; + } + + if (const FunctionTypeVar* ftv = get(ty)) + return !ftv->generics.empty() || !ftv->genericPacks.empty(); + else if (const TableTypeVar* ttv = get(ty)) + return ttv->state == TableState::Generic; + else if (get(ty)) + return true; + else + return false; +} + +bool Instantiation::isDirty(TypePackId tp) +{ + if (FFlag::LuauRankNTypes) + return false; + + if (get(tp)) + return true; + else + return false; +} + +bool Instantiation::ignoreChildren(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauRankNTypes); + if (get(ty)) + return true; + else + return false; +} + +TypeId Instantiation::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + + if (const FunctionTypeVar* ftv = get(ty)) + { + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); + + if (FFlag::LuauRankNTypes) + { + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + replaceGenerics.level = level; + replaceGenerics.currentModule = currentModule; + replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); + replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + } + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; + } + else if (const TableTypeVar* ttv = get(ty)) + { + LUAU_ASSERT(!FFlag::LuauRankNTypes); + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + TypeId result = addType(std::move(clone)); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; + } + else + { + LUAU_ASSERT(!FFlag::LuauRankNTypes); + TypeId result = addType(FreeTypeVar{level}); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; + } +} + +TypePackId Instantiation::clean(TypePackId tp) +{ + LUAU_ASSERT(!FFlag::LuauRankNTypes); + return addTypePack(TypePackVar(FreeTypePack{level})); +} + +bool ReplaceGenerics::ignoreChildren(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauRankNTypes); + if (const FunctionTypeVar* ftv = get(ty)) + // We aren't recursing in the case of a generic function which + // binds the same generics. This can happen if, for example, there's recursive types. + // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. + // It's OK to use vector equality here, since we always generate fresh generics + // whenever we quantify, so the vectors overlap if and only if they are equal. + return (!generics.empty() || !genericPacks.empty()) && (ftv->generics == generics) && (ftv->genericPacks == genericPacks); + else + return false; +} + +bool ReplaceGenerics::isDirty(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauRankNTypes); + if (const TableTypeVar* ttv = get(ty)) + return ttv->state == TableState::Generic; + else if (get(ty)) + return std::find(generics.begin(), generics.end(), ty) != generics.end(); + else + return false; +} + +bool ReplaceGenerics::isDirty(TypePackId tp) +{ + LUAU_ASSERT(FFlag::LuauRankNTypes); + if (get(tp)) + return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); + else + return false; +} + +TypeId ReplaceGenerics::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = get(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + return addType(std::move(clone)); + } + else + return addType(FreeTypeVar{level}); +} + +TypePackId ReplaceGenerics::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return addTypePack(TypePackVar(FreeTypePack{level})); +} + +bool Quantification::isDirty(TypeId ty) +{ + if (const TableTypeVar* ttv = get(ty)) + return level.subsumes(ttv->level) && ((ttv->state == TableState::Free) || (ttv->state == TableState::Unsealed)); + else if (const FreeTypeVar* ftv = get(ty)) + return level.subsumes(ftv->level); + else + return false; +} + +bool Quantification::isDirty(TypePackId tp) +{ + if (const FreeTypePack* ftv = get(tp)) + return level.subsumes(ftv->level); + else + return false; +} + +TypeId Quantification::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = get(ty)) + { + TableState state = (ttv->state == TableState::Unsealed ? TableState::Sealed : TableState::Generic); + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, state}; + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + return addType(std::move(clone)); + } + else + { + TypeId generic = addType(GenericTypeVar{level}); + generics.push_back(generic); + return generic; + } +} + +TypePackId Quantification::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + TypePackId genericPack = addTypePack(TypePackVar(GenericTypePack{level})); + genericPacks.push_back(genericPack); + return genericPack; +} + +bool Anyification::isDirty(TypeId ty) +{ + if (const TableTypeVar* ttv = get(ty)) + return (ttv->state == TableState::Free); + else if (get(ty)) + return true; + else + return false; +} + +bool Anyification::isDirty(TypePackId tp) +{ + if (get(tp)) + return true; + else + return false; +} + +TypeId Anyification::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + if (const TableTypeVar* ttv = get(ty)) + { + TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; + clone.methodDefinitionLocations = ttv->methodDefinitionLocations; + clone.definitionModuleName = ttv->definitionModuleName; + return addType(std::move(clone)); + } + else + return anyType; +} + +TypePackId Anyification::clean(TypePackId tp) +{ + LUAU_ASSERT(isDirty(tp)); + return anyTypePack; +} + +TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location) +{ + ty = follow(ty); + + const FunctionTypeVar* ftv = get(ty); + if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) + return ty; + + quantification.level = scope->level; + quantification.generics.clear(); + quantification.genericPacks.clear(); + quantification.currentModule = currentModule; + + std::optional qty = quantification.substitute(ty); + + if (!qty.has_value()) + { + reportError(location, UnificationTooComplex{}); + return errorType; + } + + if (ty == *qty) + return ty; + + FunctionTypeVar* qftv = getMutable(*qty); + LUAU_ASSERT(qftv); + qftv->generics = quantification.generics; + qftv->genericPacks = quantification.genericPacks; + return *qty; +} + +TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location) +{ + instantiation.level = scope->level; + instantiation.currentModule = currentModule; + std::optional instantiated = instantiation.substitute(ty); + if (instantiated.has_value()) + return *instantiated; + else + { + reportError(location, UnificationTooComplex{}); + return errorType; + } +} + +TypePackId TypeChecker::DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location) +{ + LUAU_ASSERT(!FFlag::LuauRankNTypes); + instantiation.level = scope->level; + instantiation.currentModule = currentModule; + std::optional instantiated = instantiation.substitute(ty); + if (instantiated.has_value()) + return *instantiated; + else + { + reportError(location, UnificationTooComplex{}); + return errorTypePack; + } +} + +TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) +{ + anyification.anyType = anyType; + anyification.anyTypePack = anyTypePack; + anyification.currentModule = currentModule; + std::optional any = anyification.substitute(ty); + if (any.has_value()) + return *any; + else + { + reportError(location, UnificationTooComplex{}); + return errorType; + } +} + +TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) +{ + anyification.anyType = anyType; + anyification.anyTypePack = anyTypePack; + anyification.currentModule = currentModule; + std::optional any = anyification.substitute(ty); + if (any.has_value()) + return *any; + else + { + reportError(location, UnificationTooComplex{}); + return errorTypePack; + } +} + +void TypeChecker::reportError(const TypeError& error) +{ + if (currentModule->mode == Mode::NoCheck) + return; + currentModule->errors.push_back(error); + currentModule->errors.back().moduleName = currentModuleName; +} + +void TypeChecker::reportError(const Location& location, TypeErrorData errorData) +{ + return reportError(TypeError{location, std::move(errorData)}); +} + +void TypeChecker::reportErrors(const ErrorVec& errors) +{ + for (const auto& err : errors) + reportError(err); +} + +void TypeChecker::ice(const std::string& message, const Location& location) +{ + iceHandler->ice(message, location); +} + +void TypeChecker::ice(const std::string& message) +{ + iceHandler->ice(message); +} + +void TypeChecker::prepareErrorsForDisplay(ErrorVec& errVec) +{ + // Remove errors with names that were generated by recovery from a parse error + errVec.erase(std::remove_if(errVec.begin(), errVec.end(), + [](auto& err) { + return containsParseErrorName(err); + }), + errVec.end()); + + for (auto& err : errVec) + { + if (auto utk = get(err)) + diagnoseMissingTableKey(utk, err.data); + } +} + +void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data) +{ + std::string_view sv(utk->key); + std::set candidates; + + auto accumulate = [&](const TableTypeVar::Props& props) { + for (const auto& [name, ty] : props) + { + if (sv != name && equalsLower(sv, name)) + candidates.insert(name); + } + }; + + if (auto ttv = getTableType(follow(utk->table))) + accumulate(ttv->props); + else if (auto ctv = get(follow(utk->table))) + { + while (ctv) + { + accumulate(ctv->props); + + if (!ctv->parent) + break; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + + if (!candidates.empty()) + data = TypeErrorData(UnknownPropButFoundLikeProp{utk->table, utk->key, candidates}); +} + +LUAU_NOINLINE void TypeChecker::reportErrorCodeTooComplex(const Location& location) +{ + reportError(TypeError{location, CodeTooComplex{}}); +} + +// Creates a new Scope but without carrying forward the varargs from the parent. +ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel) +{ + ScopePtr scope = std::make_shared(parent, subLevel); + currentModule->scopes.push_back(std::make_pair(location, scope)); + return scope; +} + +// Creates a new Scope and carries forward the varargs from the parent. +ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location, int subLevel) +{ + ScopePtr scope = std::make_shared(parent, subLevel); + scope->varargPack = parent->varargPack; + + currentModule->scopes.push_back(std::make_pair(location, scope)); + return scope; +} + +void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) +{ + Luau::merge(l, r, [this](TypeId a, TypeId b) { + // TODO: normalize(UnionTypeVar{{a, b}}) + std::unordered_set set; + + if (auto utv = get(follow(a))) + set.insert(begin(utv), end(utv)); + else + set.insert(a); + + if (auto utv = get(follow(b))) + set.insert(begin(utv), end(utv)); + else + set.insert(b); + + std::vector options(set.begin(), set.end()); + if (set.size() == 1) + return options[0]; + return addType(UnionTypeVar{std::move(options)}); + }); +} + +Unifier TypeChecker::mkUnifier(const Location& location) +{ + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler}; +} + +Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) +{ + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler}; +} + +TypeId TypeChecker::freshType(const ScopePtr& scope) +{ + return freshType(scope->level); +} + +TypeId TypeChecker::freshType(TypeLevel level) +{ + return currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level))); +} + +TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) +{ + return DEPRECATED_freshType(scope->level, canBeGeneric); +} + +TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) +{ + TypeId allocated = currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level, canBeGeneric))); + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = ¤tModule->internalTypes; + + return allocated; +} + +std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) +{ + std::vector types = Luau::filterMap(type, predicate); + if (!types.empty()) + return types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)}); + return std::nullopt; +} + +TypeId TypeChecker::addType(const UnionTypeVar& utv) +{ + LUAU_ASSERT(utv.options.size() > 1); + + return addTV(TypeVar(utv)); +} + +TypeId TypeChecker::addTV(TypeVar&& tv) +{ + TypeId allocated = currentModule->internalTypes.typeVars.allocate(std::move(tv)); + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = ¤tModule->internalTypes; + + return allocated; +} + +TypePackId TypeChecker::addTypePack(TypePackVar&& tv) +{ + TypePackId allocated = currentModule->internalTypes.typePacks.allocate(std::move(tv)); + if (FFlag::DebugLuauTrackOwningArena) + asMutable(allocated)->owningArena = ¤tModule->internalTypes; + + return allocated; +} + +TypePackId TypeChecker::addTypePack(TypePack&& tp) +{ + return addTypePack(TypePackVar(std::move(tp))); +} + +TypePackId TypeChecker::addTypePack(const std::vector& ty) +{ + return addTypePack(ty, std::nullopt); +} + +TypePackId TypeChecker::addTypePack(const std::vector& ty, std::optional tail) +{ + return addTypePack(TypePackVar(TypePack{ty, tail})); +} + +TypePackId TypeChecker::addTypePack(std::initializer_list&& ty) +{ + return addTypePack(TypePackVar(TypePack{std::vector(begin(ty), end(ty)), std::nullopt})); +} + +TypePackId TypeChecker::freshTypePack(const ScopePtr& scope) +{ + return freshTypePack(scope->level); +} + +TypePackId TypeChecker::freshTypePack(TypeLevel level) +{ + return addTypePack(TypePackVar(FreeTypePack(level))); +} + +TypePackId TypeChecker::DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric) +{ + return DEPRECATED_freshTypePack(scope->level, canBeGeneric); +} + +TypePackId TypeChecker::DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric) +{ + return addTypePack(TypePackVar(FreeTypePack(level, canBeGeneric))); +} + +TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation, bool DEPRECATED_canBeGeneric) +{ + if (DEPRECATED_canBeGeneric) + LUAU_ASSERT(!FFlag::LuauRankNTypes); + + if (const auto& lit = annotation.as()) + { + std::optional tf; + if (lit->hasPrefix) + tf = scope->lookupImportedType(lit->prefix.value, lit->name.value); + + else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_ice") + ice("_luau_ice encountered", lit->location); + + else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_print") + { + if (lit->generics.size != 1) + { + reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); + return addType(ErrorTypeVar{}); + } + + ToStringOptions opts; + opts.exhaustive = true; + opts.maxTableLength = 0; + + TypeId param = resolveType(scope, *lit->generics.data[0]); + luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); + return param; + } + + else + tf = scope->lookupType(lit->name.value); + + if (!tf) + { + if (lit->name == Parser::errorName) + return addType(ErrorTypeVar{}); + + std::string typeName; + if (lit->hasPrefix) + typeName = std::string(lit->prefix.value) + "."; + typeName += lit->name.value; + + if (scope->lookupPack(typeName)) + reportError(TypeError{annotation.location, SwappedGenericTypeParameter{typeName, SwappedGenericTypeParameter::Type}}); + else + reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); + + return addType(ErrorTypeVar{}); + } + + if (lit->generics.size == 0 && tf->typeParams.empty()) + return tf->type; + else if (lit->generics.size != tf->typeParams.size()) + { + reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->generics.size}}); + return addType(ErrorTypeVar{}); + } + else + { + std::vector typeParams; + for (AstType* paramAnnot : lit->generics) + typeParams.push_back(resolveType(scope, *paramAnnot)); + + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; + } + + return instantiateTypeFun(scope, *tf, typeParams, annotation.location); + } + } + else if (const auto& table = annotation.as()) + { + TableTypeVar::Props props; + std::optional tableIndexer; + + for (const auto& prop : table->props) + props[prop.name.value] = {resolveType(scope, *prop.type, DEPRECATED_canBeGeneric)}; + + if (const auto& indexer = table->indexer) + tableIndexer = TableIndexer( + resolveType(scope, *indexer->indexType, DEPRECATED_canBeGeneric), resolveType(scope, *indexer->resultType, DEPRECATED_canBeGeneric)); + + return addType(TableTypeVar{ + props, tableIndexer, scope->level, + TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe + }); + } + else if (const auto& func = annotation.as()) + { + ScopePtr funcScope = childScope(scope, func->location); + + std::vector generics; + std::vector genericPacks; + + if (FFlag::LuauGenericFunctions) + { + std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks); + } + + // TODO: better error message CLI-39912 + if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && (generics.size() > 0 || genericPacks.size() > 0)) + reportError(TypeError{annotation.location, GenericError{"generic function where only monotypes are allowed"}}); + + TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); + TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); + + TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(generics), std::move(genericPacks), argTypes, retTypes}); + + FunctionTypeVar* ftv = getMutable(fnType); + + ftv->argNames.reserve(func->argNames.size); + for (const auto& el : func->argNames) + { + if (el) + ftv->argNames.push_back(FunctionArgument{el->first.value, el->second}); + else + ftv->argNames.push_back(std::nullopt); + } + + return fnType; + } + else if (auto typeOf = annotation.as()) + { + TypeId ty = checkExpr(scope, *typeOf->expr).type; + // TODO: better error message CLI-39912 + if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && isGeneric(ty)) + reportError(TypeError{annotation.location, GenericError{"typeof produced a polytype where only monotypes are allowed"}}); + return ty; + } + else if (const auto& un = annotation.as()) + { + std::vector types; + for (AstType* ann : un->types) + types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + + return addType(UnionTypeVar{types}); + } + else if (const auto& un = annotation.as()) + { + std::vector types; + for (AstType* ann : un->types) + types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + + return addType(IntersectionTypeVar{types}); + } + else if (annotation.is()) + { + return addType(ErrorTypeVar{}); + } + else + { + reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); + return addType(ErrorTypeVar{}); + } +} + +TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypeList& types) +{ + if (types.types.size == 0 && types.tailType) + { + return resolveTypePack(scope, *types.tailType); + } + else if (types.types.size > 0) + { + std::vector head; + for (AstType* ann : types.types) + head.push_back(resolveType(scope, *ann)); + + std::optional tail = types.tailType ? std::optional(resolveTypePack(scope, *types.tailType)) : std::nullopt; + return addTypePack(TypePack{head, tail}); + } + + return addTypePack(TypePack{}); +} + +TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation) +{ + if (const AstTypePackVariadic* variadic = annotation.as()) + { + return addTypePack(TypePackVar{VariadicTypePack{resolveType(scope, *variadic->variadicType)}}); + } + else if (const AstTypePackGeneric* generic = annotation.as()) + { + Name genericName = Name(generic->genericName.value); + std::optional genericTy = scope->lookupPack(genericName); + + if (!genericTy) + { + if (scope->lookupType(genericName)) + reportError(TypeError{generic->location, SwappedGenericTypeParameter{genericName, SwappedGenericTypeParameter::Pack}}); + else + reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); + + return addTypePack(TypePackVar{Unifiable::Error{}}); + } + + return *genericTy; + } + else + { + ice("Unknown AstTypePack kind"); + } +} + +bool ApplyTypeFunction::isDirty(TypeId ty) +{ + // Really this should just replace the arguments, + // but for bug-compatibility with existing code, we replace + // all generics. + if (get(ty)) + return true; + else if (const FreeTypeVar* ftv = get(ty)) + { + if (FFlag::LuauRecursiveTypeParameterRestriction && ftv->forwardedTypeAlias) + encounteredForwardedType = true; + return false; + } + else + return false; +} + +bool ApplyTypeFunction::isDirty(TypePackId tp) +{ + // Really this should just replace the arguments, + // but for bug-compatibility with existing code, we replace + // all generics. + if (get(tp)) + return true; + else + return false; +} + +TypeId ApplyTypeFunction::clean(TypeId ty) +{ + // Really this should just replace the arguments, + // but for bug-compatibility with existing code, we replace + // all generics by free type variables. + TypeId& arg = arguments[ty]; + if (arg) + return arg; + else + return addType(FreeTypeVar{level}); +} + +TypePackId ApplyTypeFunction::clean(TypePackId tp) +{ + // Really this should just replace the arguments, + // but for bug-compatibility with existing code, we replace + // all generics by free type variables. + return addTypePack(FreeTypePack{level}); +} + +TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location) +{ + if (tf.typeParams.empty()) + return tf.type; + + applyTypeFunction.arguments.clear(); + for (size_t i = 0; i < tf.typeParams.size(); ++i) + applyTypeFunction.arguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.currentModule = currentModule; + applyTypeFunction.level = scope->level; + applyTypeFunction.encounteredForwardedType = false; + std::optional maybeInstantiated = applyTypeFunction.substitute(tf.type); + if (!maybeInstantiated.has_value()) + { + reportError(location, UnificationTooComplex{}); + return errorType; + } + if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) + { + reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); + return errorType; + } + + TypeId instantiated = *maybeInstantiated; + + if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) + { + // TODO: CLI-46926 it's a bad idea to rename the type whether we follow through the BoundTypeVar or not + TypeId target = FFlag::LuauFollowInTypeFunApply ? follow(instantiated) : instantiated; + + bool needsClone = follow(tf.type) == target; + TableTypeVar* ttv = getMutableTableType(target); + + if (ttv && needsClone) + { + // Substitution::clone is a shallow clone. If this is a metatable type, we + // want to mutate its table, so we need to explicitly clone that table as + // well. If we don't, we will mutate another module's type surface and cause + // a use-after-free. + if (get(target)) + { + instantiated = applyTypeFunction.clone(tf.type); + MetatableTypeVar* mtv = getMutable(instantiated); + mtv->table = applyTypeFunction.clone(mtv->table); + ttv = getMutable(mtv->table); + } + if (get(target)) + { + instantiated = applyTypeFunction.clone(tf.type); + ttv = getMutable(instantiated); + } + } + + if (ttv) + { + ttv->instantiatedTypeParams = typeParams; + } + } + else + { + if (TableTypeVar* ttv = getMutableTableType(instantiated)) + { + if (follow(tf.type) == instantiated) + { + // This can happen if a type alias has generics that it does not use at all. + // ex type FooBar = { a: number } + instantiated = applyTypeFunction.clone(tf.type); + ttv = getMutableTableType(instantiated); + } + + ttv->instantiatedTypeParams = typeParams; + } + } + + return instantiated; +} + +std::pair, std::vector> TypeChecker::createGenericTypes( + const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +{ + std::vector generics; + for (const AstName& generic : genericNames) + { + Name n = generic.value; + + // These generics are the only thing that will ever be added to scope, so we can be certain that + // a collision can only occur when two generic typevars have the same name. + if (scope->privateTypeBindings.count(n) || scope->privateTypePackBindings.count(n)) + { + // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. + reportError(TypeError{node.location, DuplicateGenericParameter{n}}); + } + + TypeId g = addType(Unifiable::Generic{scope->level, n}); + generics.push_back(g); + scope->privateTypeBindings[n] = TypeFun{{}, g}; + } + + std::vector genericPacks; + for (const AstName& genericPack : genericPackNames) + { + Name n = genericPack.value; + + // These generics are the only thing that will ever be added to scope, so we can be certain that + // a collision can only occur when two generic typevars have the same name. + if (scope->privateTypePackBindings.count(n) || scope->privateTypeBindings.count(n)) + { + // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. + reportError(TypeError{node.location, DuplicateGenericParameter{n}}); + } + + TypePackId g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + genericPacks.push_back(g); + scope->privateTypePackBindings[n] = g; + } + + return {generics, genericPacks}; +} + +std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) +{ + std::string path = toString(lvalue); + auto [symbol, keys] = getFullName(lvalue); + + ScopePtr currentScope = scope; + while (currentScope) + { + if (auto it = currentScope->refinements.find(path); it != currentScope->refinements.end()) + return it->second; + + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + { + std::optional currentTy = it->second.typeId; + + for (std::string key : keys) + { + // TODO: This function probably doesn't need Location at all, or at least should hide the argument. + currentTy = getIndexTypeFromType(scope, *currentTy, key, Location(), false); + if (!currentTy) + break; + } + + return currentTy; + } + + currentScope = currentScope->parent; + } + + return std::nullopt; +} + +std::optional TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue) +{ + if (auto it = refis.find(toString(lvalue)); it != refis.end()) + return it->second; + else + return resolveLValue(scope, lvalue); +} + +// Only should be used for refinements! +// This can probably go away once we have something that can limit a free type's type domain. +static bool isUndecidable(TypeId ty) +{ + ty = follow(ty); + return get(ty) || get(ty) || get(ty); +} + +ErrorVec TypeChecker::resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense) +{ + ErrorVec errVec; + resolve(predicates, errVec, scope->refinements, scope, sense); + return errVec; +} + +void TypeChecker::resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +{ + for (const Predicate& c : predicates) + resolve(c, errVec, refis, scope, sense, fromOr); +} + +void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +{ + if (auto truthyP = get(predicate)) + resolve(*truthyP, errVec, refis, scope, sense, fromOr); + else if (auto andP = get(predicate)) + resolve(*andP, errVec, refis, scope, sense); + else if (auto orP = get(predicate)) + resolve(*orP, errVec, refis, scope, sense); + else if (auto notP = get(predicate)) + resolve(notP->predicates, errVec, refis, scope, !sense, fromOr); + else if (auto isaP = get(predicate)) + resolve(*isaP, errVec, refis, scope, sense); + else if (auto typeguardP = get(predicate)) + { + if (FFlag::LuauImprovedTypeGuardPredicate2) + resolve(*typeguardP, errVec, refis, scope, sense); + else + DEPRECATED_resolve(*typeguardP, errVec, refis, scope, sense); + } + else if (auto eqP = get(predicate); eqP && FFlag::LuauEqConstraint) + resolve(*eqP, errVec, refis, scope, sense); + else + ice("Unhandled predicate kind"); +} + +void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) +{ + auto predicate = [sense](TypeId option) -> std::optional { + if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) + return option; + + return std::nullopt; + }; + + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (!ty) + return; + + // This is a hack. :( + // Without this, the expression 'a or b' might refine 'b' to be falsy. + // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. + if (fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, truthyP.lvalue, *result); +} + +void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +{ + if (FFlag::LuauOrPredicate) + { + if (!sense) + { + OrPredicate orP{ + {NotPredicate{std::move(andP.lhs)}}, + {NotPredicate{std::move(andP.rhs)}}, + }; + + return resolve(orP, errVec, refis, scope, !sense); + } + + resolve(andP.lhs, errVec, refis, scope, sense); + resolve(andP.rhs, errVec, refis, scope, sense); + } + else + { + // And predicate is currently not resolvable when sense is false. 'not (a and b)' is synonymous with '(not a) or (not b)'. + // TODO: implement environment merging to permit this case. + if (!sense) + return; + + resolve(andP.lhs, errVec, refis, scope, sense); + resolve(andP.rhs, errVec, refis, scope, sense); + } +} + +void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +{ + if (!sense) + { + AndPredicate andP{ + {NotPredicate{std::move(orP.lhs)}}, + {NotPredicate{std::move(orP.rhs)}}, + }; + + return resolve(andP, errVec, refis, scope, !sense); + } + + ErrorVec discarded; + + RefinementMap leftRefis; + resolve(orP.lhs, errVec, leftRefis, scope, sense); + + RefinementMap rightRefis; + resolve(orP.lhs, discarded, rightRefis, scope, !sense); + resolve(orP.rhs, errVec, rightRefis, scope, sense, true); // :( + + merge(refis, leftRefis); + merge(refis, rightRefis); +} + +void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +{ + auto predicate = [&](TypeId option) -> std::optional { + if (FFlag::LuauTypeGuardPeelsAwaySubclasses) + { + // This by itself is not truly enough to determine that A is stronger than B or vice versa. + // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. + // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) + bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); + bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + + // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. + if (!optionIsSubtype && targetIsSubtype) + return sense ? isaP.ty : option; + + // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. + if (optionIsSubtype && !targetIsSubtype) + return sense ? std::optional(option) : std::nullopt; + + // If neither has any relationship, we only return A if sense is false. + if (!optionIsSubtype && !targetIsSubtype) + return sense ? std::nullopt : std::optional(option); + + // If both are subtypes, then we're in one of the two situations: + // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ + // 2. any <: Instance ∧ Instance <: any + // Right now, we have to look at the types to see if they were undecidables. + // By this point, we also know free tables are also subtypes and supertypes. + if (optionIsSubtype && targetIsSubtype) + { + // We can only have (any, Instance) because the rhs is never undecidable right now. + // So we can just return the right hand side immediately. + + // typeof(x) == "Instance" where x : any + auto ttv = get(option); + if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) + return sense ? isaP.ty : option; + + // typeof(x) == "Instance" where x : Instance + if (sense) + return isaP.ty; + } + } + else if (FFlag::LuauImprovedTypeGuardPredicate2) + { + auto lctv = get(option); + auto rctv = get(isaP.ty); + + if (isSubclass(lctv, rctv) == sense) + return option; + + if (isSubclass(rctv, lctv) == sense) + return isaP.ty; + + if (canUnify(option, isaP.ty, isaP.location).empty() == sense) + return isaP.ty; + } + else + { + auto lctv = get(option); + auto rctv = get(isaP.ty); + + if (lctv && rctv) + { + if (isSubclass(lctv, rctv) == sense) + return option; + else if (isSubclass(rctv, lctv) == sense) + return isaP.ty; + } + } + + return std::nullopt; + }; + + std::optional ty = resolveLValue(refis, scope, isaP.lvalue); + if (!ty) + return; + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, isaP.lvalue, *result); + else + { + addRefinement(refis, isaP.lvalue, errorType); + errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); + } +} + +void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +{ + // Rewrite the predicate 'type(foo) == "vector"' to be 'typeof(foo) == "Vector3"'. They're exactly identical. + // This allows us to avoid writing in edge cases. + if (!typeguardP.isTypeof && typeguardP.kind == "vector") + return resolve(TypeGuardPredicate{std::move(typeguardP.lvalue), typeguardP.location, "Vector3", true}, errVec, refis, scope, sense); + + std::optional ty = resolveLValue(refis, scope, typeguardP.lvalue); + if (!ty) + return; + + // In certain cases, the value may actually be nil, but Luau doesn't know about it. So we whitelist this. + if (sense && typeguardP.kind == "nil") + { + addRefinement(refis, typeguardP.lvalue, nilType); + return; + } + + using ConditionFunc = bool(TypeId); + using SenseToTypeIdPredicate = std::function; + auto mkFilter = [](ConditionFunc f, std::optional other = std::nullopt) -> SenseToTypeIdPredicate { + return [f, other](bool sense) -> TypeIdPredicate { + return [f, other, sense](TypeId ty) -> std::optional { + if (f(ty) == sense) + return ty; + + if (isUndecidable(ty)) + return other.value_or(ty); + + return std::nullopt; + }; + }; + }; + + // Note: "vector" never happens here at this point, so we don't have to write something for it. + // clang-format off + static const std::unordered_map primitives{ + // Trivial primitives. + {"nil", mkFilter(isNil, nilType)}, // This can still happen when sense is false! + {"string", mkFilter(isString, stringType)}, + {"number", mkFilter(isNumber, numberType)}, + {"boolean", mkFilter(isBoolean, booleanType)}, + {"thread", mkFilter(isThread, threadType)}, + + // Non-trivial primitives. + {"table", mkFilter([](TypeId ty) -> bool { return isTableIntersection(ty) || get(ty) || get(ty); })}, + {"function", mkFilter([](TypeId ty) -> bool { return isOverloadedFunction(ty) || get(ty); })}, + + // For now, we don't really care about being accurate with userdata if the typeguard was using typeof. + {"userdata", mkFilter([](TypeId ty) -> bool { return get(ty); })}, + }; + // clang-format on + + if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) + { + if (std::optional result = filterMap(*ty, it->second(sense))) + addRefinement(refis, typeguardP.lvalue, *result); + else + { + addRefinement(refis, typeguardP.lvalue, errorType); + if (sense) + errVec.push_back( + TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); + } + + return; + } + + auto fail = [&](const TypeErrorData& err) { + errVec.push_back(TypeError{typeguardP.location, err}); + addRefinement(refis, typeguardP.lvalue, errorType); + }; + + if (!typeguardP.isTypeof) + return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + + auto typeFun = globalScope->lookupType(typeguardP.kind); + if (!typeFun || !typeFun->typeParams.empty()) + return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + + TypeId type = follow(typeFun->type); + + // We're only interested in the root class of any classes. + if (auto ctv = get(type); !ctv || ctv->parent) + return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); + + // This probably hints at breaking out type filtering functions from the predicate solver so that typeof is not tightly coupled with IsA. + // Until then, we rewrite this to be the same as using IsA. + return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); +} + +void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +{ + if (!sense) + return; + + static std::vector primitives{ + "string", "number", "boolean", "nil", "thread", + "table", // no op. Requires special handling. + "function", // no op. Requires special handling. + "userdata", // no op. Requires special handling. + }; + + if (auto typeFun = globalScope->lookupType(typeguardP.kind); typeFun && typeFun->typeParams.empty()) + { + if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) + addRefinement(refis, typeguardP.lvalue, typeFun->type); + else if (typeguardP.isTypeof) + addRefinement(refis, typeguardP.lvalue, typeFun->type); + } +} + +void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) +{ + // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. + + auto options = [](TypeId ty) -> std::vector { + if (auto utv = get(follow(ty))) + return std::vector(begin(utv), end(utv)); + return {ty}; + }; + + if (FFlag::LuauWeakEqConstraint) + { + if (!sense && isNil(eqP.type)) + resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + + return; + } + + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; + + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); + + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + { + addRefinement(refis, eqP.lvalue, eqP.type); + return; + } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + + std::unordered_set set; + for (TypeId left : lhs) + { + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } + } + + if (set.empty()) + return; + + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); +} + +bool TypeChecker::isNonstrictMode() const +{ + return (currentModule->mode == Mode::Nonstrict) || (currentModule->mode == Mode::NoCheck); +} + +std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp, size_t expectedLength, const Location& location) +{ + TypePackId expectedTypePack = addTypePack({}); + TypePack* expectedPack = getMutable(expectedTypePack); + LUAU_ASSERT(expectedPack); + for (size_t i = 0; i < expectedLength; ++i) + expectedPack->head.push_back(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + + unify(expectedTypePack, tp, location); + + for (TypeId& tp : expectedPack->head) + tp = follow(tp); + + return expectedPack->head; +} + +std::vector> TypeChecker::getScopes() const +{ + return currentModule->scopes; +} + +Scope::Scope(TypePackId returnType) + : parent(nullptr) + , returnType(returnType) + , level(TypeLevel()) +{ +} + +Scope::Scope(const ScopePtr& parent, int subLevel) + : parent(parent) + , returnType(parent->returnType) + , level(parent->level.incr()) +{ + level.subLevel = subLevel; +} + +std::optional Scope::lookup(const Symbol& name) +{ + Scope* scope = this; + + while (scope) + { + auto it = scope->bindings.find(name); + if (it != scope->bindings.end()) + return it->second.typeId; + + scope = scope->parent.get(); + } + + return std::nullopt; +} + +std::optional Scope::lookupType(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->exportedTypeBindings.find(name); + if (it != scope->exportedTypeBindings.end()) + return it->second; + + it = scope->privateTypeBindings.find(name); + if (it != scope->privateTypeBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +{ + const Scope* scope = this; + while (scope) + { + auto it = scope->importedTypeBindings.find(moduleAlias); + if (it == scope->importedTypeBindings.end()) + { + scope = scope->parent.get(); + continue; + } + + auto it2 = it->second.find(name); + if (it2 == it->second.end()) + { + scope = scope->parent.get(); + continue; + } + + return it2->second; + } + + return std::nullopt; +} + +std::optional Scope::lookupPack(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->privateTypePackBindings.find(name); + if (it != scope->privateTypePackBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) +{ + Scope* scope = this; + + while (scope) + { + for (const auto& [n, binding] : scope->bindings) + { + if (n.local && n.local->name == name.c_str()) + return binding; + else if (n.global.value && n.global == name.c_str()) + return binding; + } + + scope = scope->parent.get(); + + if (!traverseScopeChain) + break; + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp new file mode 100644 index 0000000..5970f30 --- /dev/null +++ b/Analysis/src/TypePack.cpp @@ -0,0 +1,277 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ + +TypePackVar::TypePackVar(const TypePackVariant& tp) + : ty(tp) +{ +} + +TypePackVar::TypePackVar(TypePackVariant&& tp) + : ty(std::move(tp)) +{ +} + +TypePackVar::TypePackVar(TypePackVariant&& tp, bool persistent) + : ty(std::move(tp)) + , persistent(persistent) +{ +} + +bool TypePackVar::operator==(const TypePackVar& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + +TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) +{ + ty = std::move(tp); + return *this; +} + +TypePackIterator::TypePackIterator(TypePackId typePack) + : currentTypePack(follow(typePack)) + , tp(get(currentTypePack)) + , currentIndex(0) +{ + while (tp && tp->head.empty()) + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + } +} + +TypePackIterator& TypePackIterator::operator++() +{ + LUAU_ASSERT(tp); + + ++currentIndex; + while (tp && currentIndex >= tp->head.size()) + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + currentIndex = 0; + } + + return *this; +} + +TypePackIterator TypePackIterator::operator++(int) +{ + TypePackIterator copy = *this; + ++*this; + return copy; +} + +bool TypePackIterator::operator!=(const TypePackIterator& rhs) +{ + return !(*this == rhs); +} + +bool TypePackIterator::operator==(const TypePackIterator& rhs) +{ + return tp == rhs.tp && currentIndex == rhs.currentIndex; +} + +const TypeId& TypePackIterator::operator*() +{ + LUAU_ASSERT(tp); + return tp->head[currentIndex]; +} + +std::optional TypePackIterator::tail() +{ + LUAU_ASSERT(!tp); + return currentTypePack ? std::optional{currentTypePack} : std::nullopt; +} + +TypePackIterator begin(TypePackId tp) +{ + return TypePackIterator{tp}; +} + +TypePackIterator end(TypePackId tp) +{ + return FFlag::LuauAddMissingFollow ? TypePackIterator{} : TypePackIterator{nullptr}; +} + +bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) +{ + TypePackId lhsId = const_cast(&lhs); + TypePackId rhsId = const_cast(&rhs); + TypePackIterator lhsIter = begin(lhsId); + TypePackIterator rhsIter = begin(rhsId); + TypePackIterator lhsEnd = end(lhsId); + TypePackIterator rhsEnd = end(rhsId); + while (lhsIter != lhsEnd && rhsIter != rhsEnd) + { + if (!areEqual(seen, **lhsIter, **rhsIter)) + return false; + ++lhsIter; + ++rhsIter; + } + + if (lhsIter != lhsEnd || rhsIter != rhsEnd) + return false; + + if (!lhsIter.tail() && !rhsIter.tail()) + return true; + if (!lhsIter.tail() || !rhsIter.tail()) + return false; + + TypePackId lhsTail = *lhsIter.tail(); + TypePackId rhsTail = *rhsIter.tail(); + + { + const Unifiable::Free* lf = get_if(&lhsTail->ty); + const Unifiable::Free* rf = get_if(&rhsTail->ty); + if (lf && rf) + return lf->index == rf->index; + } + + { + const Unifiable::Bound* lb = get_if>(&lhsTail->ty); + const Unifiable::Bound* rb = get_if>(&rhsTail->ty); + if (lb && rb) + return areEqual(seen, *lb->boundTo, *rb->boundTo); + } + + { + const Unifiable::Generic* lg = get_if(&lhsTail->ty); + const Unifiable::Generic* rg = get_if(&rhsTail->ty); + if (lg && rg) + return lg->index == rg->index; + } + + { + const VariadicTypePack* lv = get_if(&lhsTail->ty); + const VariadicTypePack* rv = get_if(&rhsTail->ty); + if (lv && rv) + return areEqual(seen, *lv->ty, *rv->ty); + } + + return false; +} + +TypePackId follow(TypePackId tp) +{ + auto advance = [](TypePackId ty) -> std::optional { + if (const Unifiable::Bound* btv = get>(ty)) + return btv->boundTo; + else + return std::nullopt; + }; + + TypePackId cycleTester = tp; // Null once we've determined that there is no cycle + if (auto a = advance(cycleTester)) + cycleTester = *a; + else + return tp; + + while (true) + { + auto a1 = advance(tp); + if (a1) + tp = *a1; + else + return tp; + + if (nullptr != cycleTester) + { + auto a2 = advance(cycleTester); + if (a2) + { + auto a3 = advance(*a2); + if (a3) + cycleTester = *a3; + else + cycleTester = nullptr; + } + else + cycleTester = nullptr; + + if (tp == cycleTester) + throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + } + } +} + +size_t size(TypePackId tp) +{ + if (auto pack = get(FFlag::LuauAddMissingFollow ? follow(tp) : tp)) + return size(*pack); + else + return 0; +} + +size_t size(const TypePack& tp) +{ + size_t result = tp.head.size(); + if (tp.tail) + { + const TypePack* tail = get(FFlag::LuauAddMissingFollow ? follow(*tp.tail) : *tp.tail); + if (tail) + result += size(*tail); + } + return result; +} + +std::optional first(TypePackId tp) +{ + auto it = begin(tp); + auto endIter = end(tp); + + if (it != endIter) + return *it; + + if (auto tail = it.tail()) + { + if (auto vtp = get(*tail)) + return vtp->ty; + } + + return std::nullopt; +} + +bool isEmpty(TypePackId tp) +{ + tp = follow(tp); + if (auto tpp = get(tp)) + { + return tpp->head.empty() && (!tpp->tail || isEmpty(*tpp->tail)); + } + + return false; +} + +std::pair, std::optional> flatten(TypePackId tp) +{ + std::vector res; + + auto iter = begin(tp); + auto endIter = end(tp); + while (iter != endIter) + { + res.push_back(*iter); + ++iter; + } + + return {res, iter.tail()}; +} + +TypePackVar* asMutable(TypePackId tp) +{ + return const_cast(tp); +} + +TypePack* asMutable(const TypePack* tp) +{ + return const_cast(tp); +} + +} // namespace Luau diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp new file mode 100644 index 0000000..b9f5097 --- /dev/null +++ b/Analysis/src/TypeUtils.cpp @@ -0,0 +1,95 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeUtils.h" + +#include "Luau/ToString.h" +#include "Luau/TypeInfer.h" + +LUAU_FASTFLAG(LuauStringMetatable) + +namespace Luau +{ + +std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location) +{ + type = follow(type); + + if (!FFlag::LuauStringMetatable) + { + if (const PrimitiveTypeVar* primType = get(type)) + { + if (primType->type != PrimitiveTypeVar::String || "__index" != entry) + return std::nullopt; + + auto it = globalScope->bindings.find(AstName{"string"}); + if (it != globalScope->bindings.end()) + return it->second.typeId; + else + return std::nullopt; + } + } + + std::optional metatable = getMetatable(type); + if (!metatable) + return std::nullopt; + + TypeId unwrapped = follow(*metatable); + + if (get(unwrapped)) + return singletonTypes.anyType; + + const TableTypeVar* mtt = getTableType(unwrapped); + if (!mtt) + { + errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); + return std::nullopt; + } + + auto it = mtt->props.find(entry); + if (it != mtt->props.end()) + return it->second.type; + else + return std::nullopt; +} + +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location) +{ + if (get(ty)) + return ty; + + if (const TableTypeVar* tableType = getTableType(ty)) + { + const auto& it = tableType->props.find(name); + if (it != tableType->props.end()) + return it->second.type; + } + + std::optional mtIndex = findMetatableEntry(errors, globalScope, ty, "__index", location); + while (mtIndex) + { + TypeId index = follow(*mtIndex); + if (const auto& itt = getTableType(index)) + { + const auto& fit = itt->props.find(name); + if (fit != itt->props.end()) + return fit->second.type; + } + else if (const auto& itf = get(index)) + { + std::optional r = first(follow(itf->retType)); + if (!r) + return singletonTypes.nilType; + else + return *r; + } + else if (get(index)) + return singletonTypes.anyType; + else + errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); + + mtIndex = findMetatableEntry(errors, globalScope, *mtIndex, "__index", location); + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp new file mode 100644 index 0000000..111f4f5 --- /dev/null +++ b/Analysis/src/TypeVar.cpp @@ -0,0 +1,1505 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeVar.h" + +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Common.h" +#include "Luau/DenseHash.h" +#include "Luau/Error.h" +#include "Luau/StringUtils.h" +#include "Luau/ToString.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypePack.h" +#include "Luau/VisitTypeVar.h" + +#include +#include +#include +#include +#include + +LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) +LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) +LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) +LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false) +LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false) +LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) + +namespace Luau +{ + +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult); + +TypeId follow(TypeId t) +{ + auto advance = [](TypeId ty) -> std::optional { + if (auto btv = get>(ty)) + return btv->boundTo; + else if (auto ttv = get(ty)) + return ttv->boundTo; + else + return std::nullopt; + }; + + auto force = [](TypeId ty) { + if (auto ltv = FFlag::LuauAddMissingFollow ? get_if(&ty->ty) : get(ty)) + { + TypeId res = ltv->thunk(); + if (get(res)) + throw std::runtime_error("Lazy TypeVar cannot resolve to another Lazy TypeVar"); + + *asMutable(ty) = BoundTypeVar(res); + } + }; + + force(t); + + TypeId cycleTester = t; // Null once we've determined that there is no cycle + if (auto a = advance(cycleTester)) + cycleTester = *a; + else + return t; + + while (true) + { + force(t); + auto a1 = advance(t); + if (a1) + t = *a1; + else + return t; + + if (nullptr != cycleTester) + { + auto a2 = advance(cycleTester); + if (a2) + { + auto a3 = advance(*a2); + if (a3) + cycleTester = *a3; + else + cycleTester = nullptr; + } + else + cycleTester = nullptr; + + if (t == cycleTester) + throw std::runtime_error("Luau::follow detected a TypeVar cycle!!"); + } + } +} + +std::vector flattenIntersection(TypeId ty) +{ + if (!get(follow(ty))) + return {ty}; + + std::unordered_set seen; + std::deque queue{ty}; + + std::vector result; + + while (!queue.empty()) + { + TypeId current = follow(queue.front()); + queue.pop_front(); + + if (seen.find(current) != seen.end()) + continue; + + seen.insert(current); + + if (auto itv = get(current)) + { + for (TypeId ty : itv->parts) + queue.push_back(ty); + } + else + result.push_back(current); + } + + return result; +} + +bool isPrim(TypeId ty, PrimitiveTypeVar::Type primType) +{ + auto p = get(follow(ty)); + return p && p->type == primType; +} + +bool isNil(TypeId ty) +{ + return isPrim(ty, PrimitiveTypeVar::NilType); +} + +bool isBoolean(TypeId ty) +{ + return isPrim(ty, PrimitiveTypeVar::Boolean); +} + +bool isNumber(TypeId ty) +{ + return isPrim(ty, PrimitiveTypeVar::Number); +} + +bool isString(TypeId ty) +{ + return isPrim(ty, PrimitiveTypeVar::String); +} + +bool isThread(TypeId ty) +{ + return isPrim(ty, PrimitiveTypeVar::Thread); +} + +bool isOptional(TypeId ty) +{ + if (isNil(ty)) + return true; + + if (!get(follow(ty))) + return false; + + std::unordered_set seen; + std::deque queue{ty}; + while (!queue.empty()) + { + TypeId current = follow(queue.front()); + queue.pop_front(); + + if (seen.count(current)) + continue; + + seen.insert(current); + + if (isNil(current)) + return true; + + if (auto u = get(current)) + { + for (TypeId option : u->options) + { + if (isNil(option)) + return true; + + queue.push_back(option); + } + } + } + + return false; +} + +bool isTableIntersection(TypeId ty) +{ + if (FFlag::LuauImprovedTypeGuardPredicate2) + { + if (!get(follow(ty))) + return false; + + std::vector parts = flattenIntersection(ty); + return std::all_of(parts.begin(), parts.end(), getTableType); + } + else + { + if (const IntersectionTypeVar* itv = get(ty)) + { + for (TypeId part : itv->parts) + { + if (getTableType(follow(part))) + return true; + } + } + + return false; + } +} + +bool isOverloadedFunction(TypeId ty) +{ + if (!get(follow(ty))) + return false; + + auto isFunction = [](TypeId part) -> bool { + return get(part); + }; + + std::vector parts = flattenIntersection(ty); + return std::all_of(parts.begin(), parts.end(), isFunction); +} + +std::optional getMetatable(TypeId type) +{ + if (const MetatableTypeVar* mtType = get(type)) + return mtType->metatable; + else if (const ClassTypeVar* classType = get(type)) + return classType->metatable; + else if (const PrimitiveTypeVar* primitiveType = get(type); + FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable) + { + LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); + return primitiveType->metatable; + } + else + return std::nullopt; +} + +const TableTypeVar* getTableType(TypeId type) +{ + if (const TableTypeVar* ttv = get(type)) + return ttv; + else if (const MetatableTypeVar* mtv = get(type)) + return get(mtv->table); + else + return nullptr; +} + +TableTypeVar* getMutableTableType(TypeId type) +{ + return const_cast(getTableType(type)); +} + +const std::string* getName(TypeId type) +{ + type = follow(type); + if (auto mtv = get(type)) + { + if (mtv->syntheticName) + return &*mtv->syntheticName; + type = mtv->table; + } + + if (auto ttv = get(type)) + { + if (ttv->name) + return &*ttv->name; + if (ttv->syntheticName) + return &*ttv->syntheticName; + } + + return nullptr; +} + +bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) +{ + std::unordered_set superTypes; + + for (TypeId id : super.options) + superTypes.insert(id); + + for (TypeId id : sub.options) + { + if (superTypes.find(id) == superTypes.end()) + return false; + } + + return true; +} + +// When typechecking an assignment `x = e`, we typecheck `x:T` and `e:U`, +// then instantiate U if `isGeneric(U)` is true, and `maybeGeneric(T)` is false. +bool isGeneric(TypeId ty) +{ + ty = follow(ty); + if (auto ftv = get(ty)) + return ftv->generics.size() > 0 || ftv->genericPacks.size() > 0; + else + // TODO: recurse on type synonyms CLI-39914 + // TODO: recurse on table types CLI-39914 + return false; +} + +bool maybeGeneric(TypeId ty) +{ + ty = follow(ty); + if (auto ftv = get(ty)) + return FFlag::LuauRankNTypes || ftv->DEPRECATED_canBeGeneric; + else if (auto ttv = get(ty)) + { + // TODO: recurse on table types CLI-39914 + (void)ttv; + return true; + } + else + return isGeneric(ty); +} + +FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) + : argTypes(argTypes) + , retType(retType) + , definition(std::move(defn)) + , hasSelf(hasSelf) +{ +} + +FunctionTypeVar::FunctionTypeVar(TypeLevel level, TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) + : level(level) + , argTypes(argTypes) + , retType(retType) + , definition(std::move(defn)) + , hasSelf(hasSelf) +{ +} + +FunctionTypeVar::FunctionTypeVar(std::vector generics, std::vector genericPacks, TypePackId argTypes, TypePackId retType, + std::optional defn, bool hasSelf) + : generics(generics) + , genericPacks(genericPacks) + , argTypes(argTypes) + , retType(retType) + , definition(std::move(defn)) + , hasSelf(hasSelf) +{ +} + +FunctionTypeVar::FunctionTypeVar(TypeLevel level, std::vector generics, std::vector genericPacks, TypePackId argTypes, + TypePackId retType, std::optional defn, bool hasSelf) + : level(level) + , generics(generics) + , genericPacks(genericPacks) + , argTypes(argTypes) + , retType(retType) + , definition(std::move(defn)) + , hasSelf(hasSelf) +{ +} + +TableTypeVar::TableTypeVar(TableState state, TypeLevel level) + : state(state) + , level(level) +{ +} + +TableTypeVar::TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state) + : props(props) + , indexer(indexer) + , state(state) + , level(level) +{ +} + +// Test TypeVars for equivalence +// More complex than we'd like because TypeVars can self-reference. + +bool areSeen(SeenSet& seen, const void* lhs, const void* rhs) +{ + if (lhs == rhs) + return true; + + auto p = std::make_pair(const_cast(lhs), const_cast(rhs)); + if (seen.find(p) != seen.end()) + return true; + + seen.insert(p); + return false; +} + +bool areEqual(SeenSet& seen, const FunctionTypeVar& lhs, const FunctionTypeVar& rhs) +{ + if (areSeen(seen, &lhs, &rhs)) + return true; + + // TODO: check generics CLI-39915 + + if (!areEqual(seen, *lhs.argTypes, *rhs.argTypes)) + return false; + + if (!areEqual(seen, *lhs.retType, *rhs.retType)) + return false; + + return true; +} + +bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) +{ + if (areSeen(seen, &lhs, &rhs)) + return true; + + if (lhs.state != rhs.state) + return false; + + if (lhs.props.size() != rhs.props.size()) + return false; + + if (bool(lhs.indexer) != bool(rhs.indexer)) + return false; + + if (lhs.indexer && rhs.indexer) + { + if (!areEqual(seen, *lhs.indexer->indexType, *rhs.indexer->indexType)) + return false; + + if (!areEqual(seen, *lhs.indexer->indexResultType, *rhs.indexer->indexResultType)) + return false; + } + + auto l = lhs.props.begin(); + auto r = rhs.props.begin(); + + while (l != lhs.props.end()) + { + if (l->first != r->first) + return false; + + if (!areEqual(seen, *l->second.type, *r->second.type)) + return false; + ++l; + ++r; + } + + return true; +} + +static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) +{ + return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); +} + +bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs) +{ + if (auto bound = get_if(&lhs.ty)) + return areEqual(seen, *bound->boundTo, rhs); + + if (auto bound = get_if(&rhs.ty)) + return areEqual(seen, lhs, *bound->boundTo); + + if (lhs.ty.index() != rhs.ty.index()) + return false; + + { + const FreeTypeVar* lf = get_if(&lhs.ty); + const FreeTypeVar* rf = get_if(&rhs.ty); + if (lf && rf) + return lf->index == rf->index; + } + + { + const GenericTypeVar* lg = get_if(&lhs.ty); + const GenericTypeVar* rg = get_if(&rhs.ty); + if (lg && rg) + return lg->index == rg->index; + } + + { + const PrimitiveTypeVar* lp = get_if(&lhs.ty); + const PrimitiveTypeVar* rp = get_if(&rhs.ty); + if (lp && rp) + return lp->type == rp->type; + } + + { + const GenericTypeVar* lg = get_if(&lhs.ty); + const GenericTypeVar* rg = get_if(&rhs.ty); + if (lg && rg) + return lg->index == rg->index; + } + + { + const ErrorTypeVar* le = get_if(&lhs.ty); + const ErrorTypeVar* re = get_if(&rhs.ty); + if (le && re) + return le->index == re->index; + } + + { + const FunctionTypeVar* lf = get_if(&lhs.ty); + const FunctionTypeVar* rf = get_if(&rhs.ty); + if (lf && rf) + return areEqual(seen, *lf, *rf); + } + + { + const TableTypeVar* lt = get_if(&lhs.ty); + const TableTypeVar* rt = get_if(&rhs.ty); + if (lt && rt) + return areEqual(seen, *lt, *rt); + } + + { + const MetatableTypeVar* lmt = get_if(&lhs.ty); + const MetatableTypeVar* rmt = get_if(&rhs.ty); + + if (lmt && rmt) + return areEqual(seen, *lmt, *rmt); + } + + if (get_if(&lhs.ty) && get_if(&rhs.ty)) + return true; + + return false; +} + +TypeVar* asMutable(TypeId ty) +{ + return const_cast(ty); +} + +bool TypeVar::operator==(const TypeVar& rhs) const +{ + SeenSet seen; + return areEqual(seen, *this, rhs); +} + +bool TypeVar::operator!=(const TypeVar& rhs) const +{ + SeenSet seen; + return !areEqual(seen, *this, rhs); +} + +TypeVar& TypeVar::operator=(const TypeVariant& rhs) +{ + ty = rhs; + return *this; +} + +TypeVar& TypeVar::operator=(TypeVariant&& rhs) +{ + ty = std::move(rhs); + return *this; +} + +TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initializer_list generics, + std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, + std::initializer_list retTypes); + +SingletonTypes::SingletonTypes() + : arena(new TypeArena) + , nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true} + , numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true} + , stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true} + , booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true} + , threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true} + , anyType_{AnyTypeVar{}} + , errorType_{ErrorTypeVar{}} +{ + TypeId stringMetatable = makeStringMetatable(); + stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; + persist(stringMetatable); + freeze(*arena); +} + +TypeId SingletonTypes::makeStringMetatable() +{ + const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); + const TypeId optionalString = arena->addType(UnionTypeVar{{nilType, stringType}}); + const TypeId optionalBoolean = arena->addType(UnionTypeVar{{nilType, &booleanType_}}); + + const TypePackId oneStringPack = arena->addTypePack({stringType}); + const TypePackId anyTypePack = arena->addTypePack(TypePackVar{VariadicTypePack{anyType}, true}); + + FunctionTypeVar formatFTV{arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}; + formatFTV.magicFunction = &magicFunctionFormat; + const TypeId formatFn = arena->addType(formatFTV); + + const TypePackId emptyPack = arena->addTypePack({}); + const TypePackId stringVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{stringType}}); + const TypePackId numberVariadicList = arena->addTypePack(TypePackVar{VariadicTypePack{numberType}}); + + const TypeId stringToStringType = makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType}); + + const TypeId replArgType = arena->addType( + UnionTypeVar{{stringType, arena->addType(TableTypeVar({}, TableIndexer(stringType, stringType), TypeLevel{}, TableState::Generic)), + makeFunction(*arena, std::nullopt, {}, {}, {stringType}, {}, {stringType})}}); + const TypeId gsubFunc = makeFunction(*arena, stringType, {}, {}, {stringType, replArgType, optionalNumber}, {}, {stringType, numberType}); + const TypeId gmatchFunc = + makeFunction(*arena, stringType, {}, {}, {stringType}, {}, {arena->addType(FunctionTypeVar{emptyPack, stringVariadicList})}); + + TableTypeVar::Props stringLib = { + {"byte", {arena->addType(FunctionTypeVar{arena->addTypePack({stringType, optionalNumber, optionalNumber}), numberVariadicList})}}, + {"char", {arena->addType(FunctionTypeVar{arena->addTypePack(TypePack{{numberType}, numberVariadicList}), arena->addTypePack({stringType})})}}, + {"find", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber, optionalBoolean}, {}, {optionalNumber, optionalNumber})}}, + {"format", {formatFn}}, // FIXME + {"gmatch", {gmatchFunc}}, + {"gsub", {gsubFunc}}, + {"len", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"lower", {stringToStringType}}, + {"match", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalNumber}, {}, {optionalString})}}, + {"rep", {makeFunction(*arena, stringType, {}, {}, {numberType}, {}, {stringType})}}, + {"reverse", {stringToStringType}}, + {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, + {"upper", {stringToStringType}}, + {"split", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalString}, {}, + {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}}, + {"pack", {arena->addType(FunctionTypeVar{ + arena->addTypePack(TypePack{{stringType}, anyTypePack}), + oneStringPack, + })}}, + {"packsize", {makeFunction(*arena, stringType, {}, {}, {}, {}, {numberType})}}, + {"unpack", {arena->addType(FunctionTypeVar{ + arena->addTypePack(TypePack{{stringType, stringType, optionalNumber}}), + anyTypePack, + })}}, + }; + + assignPropDocumentationSymbols(stringLib, "@luau/global/string"); + + TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + + return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); +} + +SingletonTypes singletonTypes; + +void persist(TypeId ty) +{ + std::deque queue{ty}; + + while (!queue.empty()) + { + TypeId t = queue.front(); + queue.pop_front(); + + if (t->persistent) + continue; + + asMutable(t)->persistent = true; + + if (auto btv = get(t)) + queue.push_back(btv->boundTo); + else if (auto ftv = get(t)) + { + persist(ftv->argTypes); + persist(ftv->retType); + } + else if (auto ttv = get(t)) + { + for (const auto& [_name, prop] : ttv->props) + queue.push_back(prop.type); + + if (ttv->indexer) + { + queue.push_back(ttv->indexer->indexType); + queue.push_back(ttv->indexer->indexResultType); + } + } + else if (auto ctv = get(t)) + { + for (const auto& [_name, prop] : ctv->props) + queue.push_back(prop.type); + } + else if (auto utv = get(t)) + { + for (TypeId opt : utv->options) + queue.push_back(opt); + } + else if (auto itv = get(t)) + { + for (TypeId opt : itv->parts) + queue.push_back(opt); + } + } +} + +void persist(TypePackId tp) +{ + if (tp->persistent) + return; + + asMutable(tp)->persistent = true; + + if (auto p = get(tp)) + { + for (TypeId ty : p->head) + persist(ty); + if (p->tail) + persist(*p->tail); + } +} + +namespace +{ + +struct StateDot +{ + StateDot(ToDotOptions opts) + : opts(opts) + { + } + + ToDotOptions opts; + + std::unordered_set seenTy; + std::unordered_set seenTp; + std::unordered_map tyToIndex; + std::unordered_map tpToIndex; + int nextIndex = 1; + std::string result; + + bool canDuplicatePrimitive(TypeId ty); + + void visitChildren(TypeId ty, int index); + void visitChildren(TypePackId ty, int index); + + void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); + void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); + + void startNode(int index); + void finishNode(); + + void startNodeLabel(); + void finishNodeLabel(TypeId ty); + void finishNodeLabel(TypePackId tp); +}; + +bool StateDot::canDuplicatePrimitive(TypeId ty) +{ + if (get(ty)) + return false; + + return get(ty) || get(ty); +} + +void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) +{ + if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) + tyToIndex[ty] = nextIndex++; + + int index = tyToIndex[ty]; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, index); + } + + if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) + { + if (const PrimitiveTypeVar* ptv = get(ty)) + formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); + else if (const AnyTypeVar* atv = get(ty)) + formatAppend(result, "n%d [label=\"any\"];\n", index); + } + else + { + visitChildren(ty, index); + } +} + +void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) +{ + if (!tpToIndex.count(tp)) + tpToIndex[tp] = nextIndex++; + + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); + + visitChildren(tp, tpToIndex[tp]); +} + +void StateDot::startNode(int index) +{ + formatAppend(result, "n%d [", index); +} + +void StateDot::finishNode() +{ + formatAppend(result, "];\n"); +} + +void StateDot::startNodeLabel() +{ + formatAppend(result, "label=\""); +} + +void StateDot::finishNodeLabel(TypeId ty) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", ty); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::finishNodeLabel(TypePackId tp) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", tp); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::visitChildren(TypeId ty, int index) +{ + if (seenTy.count(ty)) + return; + seenTy.insert(ty); + + startNode(index); + startNodeLabel(); + + if (const BoundTypeVar* btv = get(ty)) + { + formatAppend(result, "BoundTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(btv->boundTo, index); + } + else if (const FunctionTypeVar* ftv = get(ty)) + { + formatAppend(result, "FunctionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(ftv->argTypes, index, "arg"); + visitChild(ftv->retType, index, "ret"); + } + else if (const TableTypeVar* ttv = get(ty)) + { + if (ttv->name) + formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); + else if (ttv->syntheticName) + formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); + else + formatAppend(result, "TableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + if (ttv->boundTo) + return visitChild(*ttv->boundTo, index, "boundTo"); + + for (const auto& [name, prop] : ttv->props) + visitChild(prop.type, index, name.c_str()); + if (ttv->indexer) + { + visitChild(ttv->indexer->indexType, index, "[index]"); + visitChild(ttv->indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : ttv->instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + formatAppend(result, "MetatableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(mtv->table, index, "table"); + visitChild(mtv->metatable, index, "metatable"); + } + else if (const UnionTypeVar* utv = get(ty)) + { + formatAppend(result, "UnionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId opt : utv->options) + visitChild(opt, index); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + formatAppend(result, "IntersectionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : itv->parts) + visitChild(part, index); + } + else if (const GenericTypeVar* gtv = get(ty)) + { + if (gtv->explicitName) + formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); + else + formatAppend(result, "GenericTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const FreeTypeVar* ftv = get(ty)) + { + formatAppend(result, "FreeTypeVar %d", ftv->index); + finishNodeLabel(ty); + finishNode(); + } + else if (const AnyTypeVar* atv = get(ty)) + { + formatAppend(result, "AnyTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const PrimitiveTypeVar* ptv = get(ty)) + { + formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if (const ErrorTypeVar* etv = get(ty)) + { + formatAppend(result, "ErrorTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const ClassTypeVar* ctv = get(ty)) + { + formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); + finishNodeLabel(ty); + finishNode(); + + for (const auto& [name, prop] : ctv->props) + visitChild(prop.type, index, name.c_str()); + + if (ctv->parent) + visitChild(*ctv->parent, index, "[parent]"); + + if (ctv->metatable) + visitChild(*ctv->metatable, index, "[metatable]"); + } + else + { + LUAU_ASSERT(!"unknown type kind"); + finishNodeLabel(ty); + finishNode(); + } +} + +void StateDot::visitChildren(TypePackId tp, int index) +{ + if (seenTp.count(tp)) + return; + seenTp.insert(tp); + + startNode(index); + startNodeLabel(); + + if (const BoundTypePack* btp = get(tp)) + { + formatAppend(result, "BoundTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(btp->boundTo, index); + } + else if (const TypePack* tpp = get(tp)) + { + formatAppend(result, "TypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + for (TypeId tv : tpp->head) + visitChild(tv, index); + if (tpp->tail) + visitChild(*tpp->tail, index, "tail"); + } + else if (const VariadicTypePack* vtp = get(tp)) + { + formatAppend(result, "VariadicTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(vtp->ty, index); + } + else if (const FreeTypePack* ftp = get(tp)) + { + formatAppend(result, "FreeTypePack %d", ftp->index); + finishNodeLabel(tp); + finishNode(); + } + else if (const GenericTypePack* gtp = get(tp)) + { + if (gtp->explicitName) + formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); + else + formatAppend(result, "GenericTypePack %d", gtp->index); + finishNodeLabel(tp); + finishNode(); + } + else if (const Unifiable::Error* etp = get(tp)) + { + formatAppend(result, "ErrorTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else + { + LUAU_ASSERT(!"unknown type pack kind"); + finishNodeLabel(tp); + finishNode(); + } +} + +} // namespace + +std::string toDot(TypeId ty, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(ty, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypePackId tp, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(tp, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypeId ty) +{ + return toDot(ty, {}); +} + +std::string toDot(TypePackId tp) +{ + return toDot(tp, {}); +} + +void dumpDot(TypeId ty) +{ + printf("%s\n", toDot(ty).c_str()); +} + +void dumpDot(TypePackId tp) +{ + printf("%s\n", toDot(tp).c_str()); +} + +const TypeLevel* getLevel(TypeId ty) +{ + ty = follow(ty); + + if (auto ftv = get(ty)) + return &ftv->level; + else if (auto ttv = get(ty)) + return &ttv->level; + else if (auto ftv = get(ty)) + return &ftv->level; + else + return nullptr; +} + +TypeLevel* getMutableLevel(TypeId ty) +{ + return const_cast(getLevel(ty)); +} + +struct QVarFinder +{ + mutable DenseHashSet seen; + + QVarFinder() + : seen(nullptr) + { + } + + bool hasSeen(const void* tv) const + { + if (seen.contains(tv)) + return true; + + seen.insert(tv); + return false; + } + + bool hasGeneric(TypeId tid) const + { + if (hasSeen(&tid->ty)) + return false; + + return Luau::visit(*this, tid->ty); + } + + bool hasGeneric(TypePackId tp) const + { + if (hasSeen(&tp->ty)) + return false; + + return Luau::visit(*this, tp->ty); + } + + bool operator()(const Unifiable::Free&) const + { + return false; + } + + bool operator()(const Unifiable::Bound& bound) const + { + return hasGeneric(bound.boundTo); + } + + bool operator()(const Unifiable::Generic&) const + { + return true; + } + bool operator()(const Unifiable::Error&) const + { + return false; + } + bool operator()(const PrimitiveTypeVar&) const + { + return false; + } + + bool operator()(const FunctionTypeVar& ftv) const + { + if (hasGeneric(ftv.argTypes)) + return true; + return hasGeneric(ftv.retType); + } + + bool operator()(const TableTypeVar& ttv) const + { + if (ttv.state == TableState::Generic) + return true; + + if (ttv.indexer) + { + if (hasGeneric(ttv.indexer->indexType)) + return true; + if (hasGeneric(ttv.indexer->indexResultType)) + return true; + } + + for (const auto& [_name, prop] : ttv.props) + { + if (hasGeneric(prop.type)) + return true; + } + + return false; + } + + bool operator()(const MetatableTypeVar& mtv) const + { + return hasGeneric(mtv.table) || hasGeneric(mtv.metatable); + } + + bool operator()(const ClassTypeVar& ctv) const + { + for (const auto& [name, prop] : ctv.props) + { + if (hasGeneric(prop.type)) + return true; + } + + if (ctv.parent) + return hasGeneric(*ctv.parent); + + return false; + } + + bool operator()(const AnyTypeVar&) const + { + return false; + } + + bool operator()(const UnionTypeVar& utv) const + { + for (TypeId tid : utv.options) + if (hasGeneric(tid)) + return true; + + return false; + } + + bool operator()(const IntersectionTypeVar& utv) const + { + for (TypeId tid : utv.parts) + if (hasGeneric(tid)) + return true; + + return false; + } + + bool operator()(const LazyTypeVar&) const + { + return false; + } + + bool operator()(const Unifiable::Bound& bound) const + { + return hasGeneric(bound.boundTo); + } + + bool operator()(const TypePack& pack) const + { + for (TypeId ty : pack.head) + if (hasGeneric(ty)) + return true; + + if (pack.tail) + return hasGeneric(*pack.tail); + + return false; + } + + bool operator()(const VariadicTypePack& pack) const + { + return hasGeneric(pack.ty); + } +}; + +const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) +{ + while (cls) + { + auto it = cls->props.find(name); + if (it != cls->props.end()) + return &it->second; + + if (cls->parent) + cls = get(*cls->parent); + else + return nullptr; + + LUAU_ASSERT(cls); + } + + return nullptr; +} + +bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) +{ + while (cls) + { + if (cls == parent) + return true; + else if (!cls->parent) + return false; + + cls = get(*cls->parent); + LUAU_ASSERT(cls); + } + + return false; +} + +bool hasGeneric(TypeId ty) +{ + return Luau::visit(QVarFinder{}, ty->ty); +} + +bool hasGeneric(TypePackId tp) +{ + return Luau::visit(QVarFinder{}, tp->ty); +} + +UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) +{ + LUAU_ASSERT(utv); + + if (!utv->options.empty()) + stack.push_front({utv, 0}); + + seen.insert(utv); +} + +UnionTypeVarIterator& UnionTypeVarIterator::operator++() +{ + advance(); + descend(); + return *this; +} + +UnionTypeVarIterator UnionTypeVarIterator::operator++(int) +{ + UnionTypeVarIterator copy = *this; + ++copy; + return copy; +} + +bool UnionTypeVarIterator::operator!=(const UnionTypeVarIterator& rhs) +{ + return !(*this == rhs); +} + +bool UnionTypeVarIterator::operator==(const UnionTypeVarIterator& rhs) +{ + if (!stack.empty() && !rhs.stack.empty()) + return stack.front() == rhs.stack.front(); + + return stack.empty() && rhs.stack.empty(); +} + +const TypeId& UnionTypeVarIterator::operator*() +{ + LUAU_ASSERT(!stack.empty()); + + descend(); + + auto [utv, currentIndex] = stack.front(); + LUAU_ASSERT(utv); + LUAU_ASSERT(currentIndex < utv->options.size()); + + const TypeId& ty = utv->options[currentIndex]; + LUAU_ASSERT(!get(follow(ty))); + return ty; +} + +void UnionTypeVarIterator::advance() +{ + while (!stack.empty()) + { + auto& [utv, currentIndex] = stack.front(); + ++currentIndex; + + if (currentIndex >= utv->options.size()) + stack.pop_front(); + else + break; + } +} + +void UnionTypeVarIterator::descend() +{ + while (!stack.empty()) + { + auto [utv, currentIndex] = stack.front(); + if (auto innerUnion = get(follow(utv->options[currentIndex]))) + { + // If we're about to descend into a cyclic UnionTypeVar, we should skip over this. + // Ideally this should never happen, but alas it does from time to time. :( + if (seen.find(innerUnion) != seen.end()) + advance(); + else + { + seen.insert(innerUnion); + stack.push_front({innerUnion, 0}); + } + + continue; + } + + break; + } +} + +UnionTypeVarIterator begin(const UnionTypeVar* utv) +{ + return UnionTypeVarIterator{utv}; +} + +UnionTypeVarIterator end(const UnionTypeVar* utv) +{ + return UnionTypeVarIterator{}; +} + +static std::vector DEPRECATED_filterMap(TypeId type, TypeIdPredicate predicate) +{ + std::vector result; + + if (auto utv = get(follow(type))) + { + for (TypeId option : utv) + { + if (auto out = predicate(follow(option))) + result.push_back(*out); + } + } + else if (auto out = predicate(follow(type))) + return {*out}; + + return result; +} + +static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) +{ + const char* options = "cdiouxXeEfgGqs"; + + std::vector result; + + for (size_t i = 0; i < size; ++i) + { + if (data[i] == '%') + { + i++; + + if (i < size && data[i] == '%') + continue; + + // we just ignore all characters (including flags/precision) up until first alphabetic character + while (i < size && !(data[i] > 0 && isalpha(data[i]))) + i++; + + if (i == size) + break; + + if (data[i] == 'q' || data[i] == 's') + result.push_back(typechecker.stringType); + else if (strchr(options, data[i])) + result.push_back(typechecker.numberType); + else + result.push_back(typechecker.errorType); + } + } + + return result; +} + +std::optional> magicFunctionFormat( + TypeChecker& typechecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + auto [paramPack, _predicates] = exprResult; + + TypeArena& arena = typechecker.currentModule->internalTypes; + + AstExprConstantString* fmt = nullptr; + if (auto index = expr.func->as(); index && expr.self) + { + if (auto group = index->expr->as()) + fmt = group->expr->as(); + else + fmt = index->expr->as(); + } + + if (!expr.self && expr.args.size > 0) + fmt = expr.args.data[0]->as(); + + if (!fmt) + return std::nullopt; + + std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); + const auto& [params, tail] = flatten(paramPack); + + const size_t dataOffset = 1; + + // unify the prefix one argument at a time + for (size_t i = 0; i < expected.size() && i + dataOffset < params.size(); ++i) + { + Location location = expr.args.data[std::min(i, expr.args.size - 1)]->location; + + typechecker.unify(expected[i], params[i + dataOffset], location); + } + + // if we know the argument count or if we have too many arguments for sure, we can issue an error + const size_t actualParamSize = params.size() - dataOffset; + + if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) + typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); + + return ExprResult{arena.addTypePack({typechecker.stringType})}; +} + +std::vector filterMap(TypeId type, TypeIdPredicate predicate) +{ + if (!FFlag::LuauTypeGuardPeelsAwaySubclasses) + return DEPRECATED_filterMap(type, predicate); + + type = follow(type); + + if (auto utv = get(type)) + { + std::set options; + for (TypeId option : utv) + if (auto out = predicate(follow(option))) + options.insert(*out); + + return std::vector(options.begin(), options.end()); + } + else if (auto out = predicate(type)) + return {*out}; + + return {}; +} + +} // namespace Luau diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp new file mode 100644 index 0000000..f037351 --- /dev/null +++ b/Analysis/src/TypedAllocator.cpp @@ -0,0 +1,99 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypedAllocator.h" + +#include "Luau/Common.h" + +#ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include + +const size_t kPageSize = 4096; +#else +#include +#include + +const size_t kPageSize = sysconf(_SC_PAGESIZE); +#endif + +#include + +LUAU_FASTFLAG(DebugLuauFreezeArena) + +namespace Luau +{ + +static void* systemAllocateAligned(size_t size, size_t align) +{ +#ifdef _WIN32 + return _aligned_malloc(size, align); +#elif defined(__ANDROID__) // for Android 4.1 + return memalign(align, size); +#else + void* ptr; + return posix_memalign(&ptr, align, size) == 0 ? ptr : 0; +#endif +} + +static void systemDeallocateAligned(void* ptr) +{ +#ifdef _WIN32 + _aligned_free(ptr); +#else + free(ptr); +#endif +} + +static size_t pageAlign(size_t size) +{ + return (size + kPageSize - 1) & ~(kPageSize - 1); +} + +void* pagedAllocate(size_t size) +{ + if (FFlag::DebugLuauFreezeArena) + return systemAllocateAligned(pageAlign(size), kPageSize); + else + return ::operator new(size, std::nothrow); +} + +void pagedDeallocate(void* ptr) +{ + if (FFlag::DebugLuauFreezeArena) + systemDeallocateAligned(ptr); + else + ::operator delete(ptr); +} + +void pagedFreeze(void* ptr, size_t size) +{ + LUAU_ASSERT(FFlag::DebugLuauFreezeArena); + LUAU_ASSERT(uintptr_t(ptr) % kPageSize == 0); + +#ifdef _WIN32 + DWORD oldProtect; + BOOL rc = VirtualProtect(ptr, pageAlign(size), PAGE_READONLY, &oldProtect); + LUAU_ASSERT(rc); +#else + int rc = mprotect(ptr, pageAlign(size), PROT_READ); + LUAU_ASSERT(rc == 0); +#endif +} + +void pagedUnfreeze(void* ptr, size_t size) +{ + LUAU_ASSERT(FFlag::DebugLuauFreezeArena); + LUAU_ASSERT(uintptr_t(ptr) % kPageSize == 0); + +#ifdef _WIN32 + DWORD oldProtect; + BOOL rc = VirtualProtect(ptr, pageAlign(size), PAGE_READWRITE, &oldProtect); + LUAU_ASSERT(rc); +#else + int rc = mprotect(ptr, pageAlign(size), PROT_READ | PROT_WRITE); + LUAU_ASSERT(rc == 0); +#endif +} + +} // namespace Luau diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp new file mode 100644 index 0000000..cef0783 --- /dev/null +++ b/Analysis/src/Unifiable.cpp @@ -0,0 +1,67 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Unifiable.h" + +LUAU_FASTFLAG(LuauRankNTypes) + +namespace Luau +{ +namespace Unifiable +{ + +Free::Free(TypeLevel level) + : index(++nextIndex) + , level(level) +{ +} + +Free::Free(TypeLevel level, bool DEPRECATED_canBeGeneric) + : index(++nextIndex) + , level(level) + , DEPRECATED_canBeGeneric(DEPRECATED_canBeGeneric) +{ + LUAU_ASSERT(!FFlag::LuauRankNTypes); +} + +int Free::nextIndex = 0; + +Generic::Generic() + : index(++nextIndex) + , name("g" + std::to_string(index)) + , explicitName(false) +{ +} + +Generic::Generic(TypeLevel level) + : index(++nextIndex) + , level(level) + , name("g" + std::to_string(index)) + , explicitName(false) +{ +} + +Generic::Generic(const Name& name) + : index(++nextIndex) + , name(name) + , explicitName(true) +{ +} + +Generic::Generic(TypeLevel level, const Name& name) + : index(++nextIndex) + , level(level) + , name(name) + , explicitName(true) +{ +} + +int Generic::nextIndex = 0; + +Error::Error() + : index(++nextIndex) +{ +} + +int Error::nextIndex = 0; + +} // namespace Unifiable +} // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp new file mode 100644 index 0000000..89c3f80 --- /dev/null +++ b/Analysis/src/Unifier.cpp @@ -0,0 +1,1575 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Unifier.h" + +#include "Luau/Common.h" +#include "Luau/RecursionCounter.h" +#include "Luau/TypePack.h" +#include "Luau/TypeUtils.h" + +#include + +LUAU_FASTINT(LuauTypeInferRecursionLimit); +LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 0); +LUAU_FASTFLAGVARIABLE(LuauLogTableTypeVarBoundTo, false) +LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) +LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAG(LuauStringMetatable) +LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) +LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) +LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) +LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) + +namespace Luau +{ + +static std::optional hasUnificationTooComplex(const ErrorVec& errors) +{ + auto isUnificationTooComplex = [](const TypeError& te) { + return nullptr != get(te); + }; + + auto it = std::find_if(errors.begin(), errors.end(), isUnificationTooComplex); + if (it == errors.end()) + return std::nullopt; + else + return *it; +} + +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler) + : types(types) + , mode(mode) + , globalScope(std::move(globalScope)) + , location(location) + , variance(variance) + , counters(std::make_shared()) + , iceHandler(iceHandler) +{ + LUAU_ASSERT(iceHandler); +} + +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters) + : types(types) + , mode(mode) + , globalScope(std::move(globalScope)) + , log(seen) + , location(location) + , variance(variance) + , counters(counters ? counters : std::make_shared()) + , iceHandler(iceHandler) +{ + LUAU_ASSERT(iceHandler); +} + +void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +{ + counters->iterationCount = 0; + return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); +} + +void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +{ + RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + + ++counters->iterationCount; + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + { + errors.push_back(TypeError{location, UnificationTooComplex{}}); + return; + } + + superTy = follow(superTy); + subTy = follow(subTy); + + if (superTy == subTy) + return; + + auto l = getMutable(superTy); + auto r = getMutable(subTy); + + if (l && r && l->level.subsumes(r->level)) + { + occursCheck(subTy, superTy); + + if (!get(subTy)) + { + log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } + + if (!FFlag::LuauRankNTypes) + l->DEPRECATED_canBeGeneric &= r->DEPRECATED_canBeGeneric; + + return; + } + else if (l && r && FFlag::LuauGenericFunctions) + { + log(superTy); + occursCheck(superTy, subTy); + if (!FFlag::LuauRankNTypes) + r->DEPRECATED_canBeGeneric &= l->DEPRECATED_canBeGeneric; + r->level = min(r->level, l->level); + *asMutable(superTy) = BoundTypeVar(subTy); + return; + } + else if (l) + { + occursCheck(superTy, subTy); + + // Unification can't change the level of a generic. + auto rightGeneric = get(subTy); + if (FFlag::LuauRankNTypes && rightGeneric && !rightGeneric->level.subsumes(l->level)) + { + // TODO: a more informative error message? CLI-39912 + errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + return; + } + + if (!get(superTy)) + { + if (auto rightLevel = getMutableLevel(subTy)) + { + if (!rightLevel->subsumes(l->level)) + *rightLevel = l->level; + } + + log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } + return; + } + else if (r) + { + occursCheck(subTy, superTy); + + // Unification can't change the level of a generic. + auto leftGeneric = get(superTy); + if (FFlag::LuauRankNTypes && leftGeneric && !leftGeneric->level.subsumes(r->level)) + { + // TODO: a more informative error message? CLI-39912 + errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + return; + } + + // This is the old code which is just wrong + auto wrongGeneric = get(subTy); // Guaranteed to be null + if (!FFlag::LuauRankNTypes && FFlag::LuauGenericFunctions && wrongGeneric && r->level.subsumes(wrongGeneric->level)) + { + // This code is unreachable! Should we just remove it? + // TODO: a more informative error message? CLI-39912 + errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + return; + } + + // Check if we're unifying a monotype with a polytype + if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !r->DEPRECATED_canBeGeneric && isGeneric(superTy)) + { + // TODO: a more informative error message? CLI-39912 + errors.push_back(TypeError{location, GenericError{"Failed to unify a polytype with a monotype"}}); + return; + } + + if (!get(subTy)) + { + if (auto leftLevel = getMutableLevel(superTy)) + { + if (!leftLevel->subsumes(r->level)) + *leftLevel = r->level; + } + + log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } + + return; + } + + if (get(superTy) || get(superTy)) + return tryUnifyWithAny(superTy, subTy); + + if (get(subTy) || get(subTy)) + return tryUnifyWithAny(subTy, superTy); + + // If we have seen this pair of types before, we are currently recursing into cyclic types. + // Here, we assume that the types unify. If they do not, we will find out as we roll back + // the stack. + + if (log.haveSeen(superTy, subTy)) + return; + + log.pushSeen(superTy, subTy); + + if (const UnionTypeVar* uv = get(subTy)) + { + // A | B <: T if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + + size_t count = uv->options.size(); + size_t i = 0; + + for (TypeId type : uv->options) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(superTy, type); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + failed = true; + + if (i != count - 1) + innerState.log.rollback(); + else + log.concat(std::move(innerState.log)); + + ++i; + } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (failed) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } + else if (const UnionTypeVar* uv = get(superTy)) + { + // T <: A | B if T <: A or T <: B + bool found = false; + std::optional unificationTooComplex; + + size_t startIndex = 0; + + if (FFlag::LuauUnionHeuristic) + { + const std::string* subName = getName(subTy); + if (subName) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + const std::string* optionName = getName(uv->options[i]); + if (optionName && *optionName == *subName) + { + startIndex = i; + break; + } + } + } + } + + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[(i + startIndex) % uv->options.size()]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, subTy, isFunctionCall); + + if (innerState.errors.empty()) + { + found = true; + log.concat(std::move(innerState.log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + unificationTooComplex = e; + } + + innerState.log.rollback(); + } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (!found) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } + else if (const IntersectionTypeVar* uv = get(superTy)) + { + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + } + } + else if (const IntersectionTypeVar* uv = get(subTy)) + { + // A & B <: T if T <: A or T <: B + bool found = false; + std::optional unificationTooComplex; + + for (TypeId type : uv->parts) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(superTy, type, isFunctionCall); + + if (innerState.errors.empty()) + { + found = true; + log.concat(std::move(innerState.log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + unificationTooComplex = e; + } + + innerState.log.rollback(); + } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (!found) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } + else if (get(superTy) && get(subTy)) + tryUnifyPrimitives(superTy, subTy); + + else if (get(superTy) && get(subTy)) + tryUnifyFunctions(superTy, subTy, isFunctionCall); + + else if (get(superTy) && get(subTy)) + tryUnifyTables(superTy, subTy, isIntersection); + + // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. + else if (get(superTy)) + tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); + else if (get(subTy)) + tryUnifyWithMetatable(subTy, superTy, /*reversed*/ true); + + else if (get(superTy)) + tryUnifyWithClass(superTy, subTy, /*reversed*/ false); + + // Unification of nonclasses with classes is almost, but not quite symmetrical. + // The order in which we perform this test is significant in the case that both types are classes. + else if (get(subTy)) + tryUnifyWithClass(superTy, subTy, /*reversed*/ true); + + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + log.popSeen(superTy, subTy); +} + +struct WeirdIter +{ + TypePackId packId; + const TypePack* pack; + size_t index; + bool growing; + TypeLevel level; + + WeirdIter(TypePackId packId) + : packId(packId) + , pack(get(packId)) + , index(0) + , growing(false) + { + while (pack && pack->head.empty() && pack->tail) + { + packId = *pack->tail; + pack = get(packId); + } + } + + WeirdIter(const WeirdIter&) = default; + + const TypeId& operator*() + { + LUAU_ASSERT(good()); + return pack->head[index]; + } + + bool good() const + { + return pack != nullptr && index < pack->head.size(); + } + + bool advance() + { + if (!pack) + return good(); + + if (index < pack->head.size()) + ++index; + + if (growing || index < pack->head.size()) + return good(); + + if (pack->tail) + { + packId = follow(*pack->tail); + pack = get(packId); + index = 0; + } + + return good(); + } + + bool canGrow() const + { + return nullptr != get(packId); + } + + void grow(TypePackId newTail) + { + LUAU_ASSERT(canGrow()); + level = get(packId)->level; + *asMutable(packId) = Unifiable::Bound(newTail); + packId = newTail; + pack = get(newTail); + index = 0; + growing = true; + } +}; + +ErrorVec Unifier::canUnify(TypeId superTy, TypeId subTy) +{ + Unifier s = makeChildUnifier(); + s.tryUnify_(superTy, subTy); + s.log.rollback(); + return s.errors; +} + +ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall) +{ + Unifier s = makeChildUnifier(); + s.tryUnify_(superTy, subTy, isFunctionCall); + s.log.rollback(); + return s.errors; +} + +void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +{ + counters->iterationCount = 0; + return tryUnify_(superTp, subTp, isFunctionCall); +} + +/* + * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. + * If one is longer than the other, but the short end is free, we grow it to the required length. + */ +void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +{ + RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + + ++counters->iterationCount; + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + { + errors.push_back(TypeError{location, UnificationTooComplex{}}); + return; + } + + superTp = follow(superTp); + subTp = follow(subTp); + + while (auto r = get(subTp)) + { + if (r->head.empty() && r->tail) + subTp = follow(*r->tail); + else + break; + } + + while (auto l = get(superTp)) + { + if (l->head.empty() && l->tail) + superTp = follow(*l->tail); + else + break; + } + + if (superTp == subTp) + return; + + if (get(superTp)) + { + occursCheck(superTp, subTp); + + if (!get(superTp)) + { + log(superTp); + *asMutable(superTp) = Unifiable::Bound(subTp); + } + } + + else if (get(subTp)) + { + occursCheck(subTp, superTp); + + if (!get(subTp)) + { + log(subTp); + *asMutable(subTp) = Unifiable::Bound(superTp); + } + } + + else if (get(superTp)) + tryUnifyWithAny(superTp, subTp); + + else if (get(subTp)) + tryUnifyWithAny(subTp, superTp); + + else if (get(superTp)) + tryUnifyVariadics(superTp, subTp, false); + else if (get(subTp)) + tryUnifyVariadics(subTp, superTp, true); + + else if (get(superTp) && get(subTp)) + { + auto l = get(superTp); + auto r = get(subTp); + + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = flatten(superTp); + auto [subTypes, subTail] = flatten(subTp); + + bool noInfiniteGrowth = + (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + + auto superIter = WeirdIter{superTp}; + auto subIter = WeirdIter{subTp}; + + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do + { + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); + + ++loopCount; + + if (superIter.good() && subIter.growing) + asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); + + if (subIter.good() && superIter.growing) + asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); + + if (superIter.good() && subIter.good()) + { + tryUnify_(*superIter, *subIter); + superIter.advance(); + subIter.advance(); + continue; + } + + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) + { + const bool lFreeTail = l->tail && get(FFlag::LuauAddMissingFollow ? follow(*l->tail) : *l->tail) != nullptr; + const bool rFreeTail = r->tail && get(FFlag::LuauAddMissingFollow ? follow(*r->tail) : *r->tail) != nullptr; + if (lFreeTail && rFreeTail) + tryUnify_(*l->tail, *r->tail); + else if (lFreeTail) + tryUnify_(*l->tail, emptyTp); + else if (rFreeTail) + tryUnify_(*r->tail, emptyTp); + + break; + } + + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*superIter.pack->tail, *subIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) + { + superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { + subIter.advance(); + continue; + } + + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(*superIter) : *superIter)) + { + superIter.advance(); + continue; + } + + if (get(superIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, false, int(subIter.index)); + return; + } + + if (get(subIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking a return value, we swap these to produce + // the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result || ctx == CountMismatch::Return) + std::swap(expectedSize, actualSize); + errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(singletonTypes.errorType, *superIter); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(singletonTypes.errorType, *subIter); + subIter.advance(); + } + + return; + } + + } while (!noInfiniteGrowth); + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + } +} + +void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) +{ + const PrimitiveTypeVar* lp = get(superTy); + const PrimitiveTypeVar* rp = get(subTy); + if (!lp || !rp) + ice("passed non primitive types to unifyPrimitives"); + + if (lp->type != rp->type) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); +} + +void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) +{ + FunctionTypeVar* lf = getMutable(superTy); + FunctionTypeVar* rf = getMutable(subTy); + if (!lf || !rf) + ice("passed non-function types to unifyFunction"); + + size_t numGenerics = lf->generics.size(); + if (FFlag::LuauGenericFunctions && numGenerics != rf->generics.size()) + { + numGenerics = std::min(lf->generics.size(), rf->generics.size()); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } + + size_t numGenericPacks = lf->genericPacks.size(); + if (FFlag::LuauGenericFunctions && numGenericPacks != rf->genericPacks.size()) + { + numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } + + if (FFlag::LuauGenericFunctions) + { + for (size_t i = 0; i < numGenerics; i++) + log.pushSeen(lf->generics[i], rf->generics[i]); + } + + CountMismatch::Context context = ctx; + + if (!isFunctionCall) + { + Unifier innerState = makeChildUnifier(); + + ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + + ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + + log.concat(std::move(innerState.log)); + } + else + { + ctx = CountMismatch::Arg; + tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + + ctx = CountMismatch::Result; + tryUnify_(lf->retType, rf->retType); + } + + if (lf->definition && !rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !subTy->persistent)) + { + rf->definition = lf->definition; + } + else if (!lf->definition && rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !superTy->persistent)) + { + lf->definition = rf->definition; + } + + ctx = context; + + if (FFlag::LuauGenericFunctions) + { + for (int i = int(numGenerics) - 1; 0 <= i; i--) + log.popSeen(lf->generics[i], rf->generics[i]); + } +} + +namespace +{ + +struct Resetter +{ + explicit Resetter(Variance* variance) + : oldValue(*variance) + , variance(variance) + { + } + + Variance oldValue; + Variance* variance; + + ~Resetter() + { + *variance = oldValue; + } +}; + +} // namespace + +void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +{ + std::unique_ptr resetter; + + resetter.reset(new Resetter{&variance}); + variance = Invariant; + + TableTypeVar* lt = getMutable(left); + TableTypeVar* rt = getMutable(right); + if (!lt || !rt) + ice("passed non-table types to unifyTables"); + + if (lt->state == TableState::Sealed && rt->state == TableState::Sealed) + return tryUnifySealedTables(left, right, isIntersection); + else if ((lt->state == TableState::Sealed && rt->state == TableState::Unsealed) || + (lt->state == TableState::Unsealed && rt->state == TableState::Sealed)) + return tryUnifySealedTables(left, right, isIntersection); + else if ((lt->state == TableState::Sealed && rt->state == TableState::Generic) || + (lt->state == TableState::Generic && rt->state == TableState::Sealed)) + errors.push_back(TypeError{location, TypeMismatch{left, right}}); + else if ((lt->state == TableState::Free) != (rt->state == TableState::Free)) // one table is free and the other is not + { + TypeId freeTypeId = rt->state == TableState::Free ? right : left; + TypeId otherTypeId = rt->state == TableState::Free ? left : right; + + return tryUnifyFreeTable(freeTypeId, otherTypeId); + } + else if (lt->state == TableState::Free && rt->state == TableState::Free) + { + tryUnifyFreeTable(left, right); + + // avoid creating a cycle when the types are already pointing at each other + if (follow(left) != follow(right)) + { + log(lt); + lt->boundTo = right; + } + return; + } + else if (lt->state != TableState::Sealed && rt->state != TableState::Sealed) + { + // All free tables are checked in one of the branches above + LUAU_ASSERT(lt->state != TableState::Free); + LUAU_ASSERT(rt->state != TableState::Free); + + // Tables must have exactly the same props and their types must all unify + // I honestly have no idea if this is remotely close to reasonable. + for (const auto& [name, prop] : lt->props) + { + const auto& r = rt->props.find(name); + if (r == rt->props.end()) + errors.push_back(TypeError{location, UnknownProperty{right, name}}); + else + tryUnify_(prop.type, r->second.type); + } + + if (lt->indexer && rt->indexer) + tryUnify(*lt->indexer, *rt->indexer); + else if (lt->indexer) + { + // passing/assigning a table without an indexer to something that has one + // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. + if (rt->state == TableState::Unsealed) + rt->indexer = lt->indexer; + else + errors.push_back(TypeError{location, CannotExtendTable{right, CannotExtendTable::Indexer}}); + } + } + else if (lt->state == TableState::Sealed) + { + // lt is sealed and so it must be possible for rt to have precisely the same shape + // Verify that this is the case, then bind rt to lt. + ice("unsealed tables are not working yet", location); + } + else if (rt->state == TableState::Sealed) + return tryUnifyTables(right, left, isIntersection); + else + ice("tryUnifyTables"); +} + +void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) +{ + TableTypeVar* freeTable = getMutable(freeTypeId); + TableTypeVar* otherTable = getMutable(otherTypeId); + if (!freeTable || !otherTable) + ice("passed non-table types to tryUnifyFreeTable"); + + // Any properties in freeTable must unify with those in otherTable. + // Then bind freeTable to otherTable. + for (const auto& [freeName, freeProp] : freeTable->props) + { + if (auto otherProp = findTablePropertyRespectingMeta(otherTypeId, freeName)) + { + tryUnify_(*otherProp, freeProp.type); + + /* + * TypeVars are commonly cyclic, so it is entirely possible + * for unifying a property of a table to change the table itself! + * We need to check for this and start over if we notice this occurring. + * + * I believe this is guaranteed to terminate eventually because this will + * only happen when a free table is bound to another table. + */ + if (!get(freeTypeId) || !get(otherTypeId)) + return tryUnify_(freeTypeId, otherTypeId); + + if (freeTable->boundTo) + return tryUnify_(freeTypeId, otherTypeId); + } + else + { + // If the other table is also free, then we are learning that it has more + // properties than we previously thought. Else, it is an error. + if (otherTable->state == TableState::Free) + otherTable->props.insert({freeName, freeProp}); + else + errors.push_back(TypeError{location, UnknownProperty{otherTypeId, freeName}}); + } + } + + if (freeTable->indexer && otherTable->indexer) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(*freeTable->indexer, *otherTable->indexer); + + checkChildUnifierTypeMismatch(innerState.errors, freeTypeId, otherTypeId); + + log.concat(std::move(innerState.log)); + } + else if (otherTable->state == TableState::Free && freeTable->indexer) + freeTable->indexer = otherTable->indexer; + + if (!freeTable->boundTo && otherTable->state != TableState::Free) + { + if (FFlag::LuauLogTableTypeVarBoundTo) + log(freeTable); + else + log(freeTypeId); + freeTable->boundTo = otherTypeId; + } +} + +void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection) +{ + TableTypeVar* lt = getMutable(left); + TableTypeVar* rt = getMutable(right); + if (!lt || !rt) + ice("passed non-table types to unifySealedTables"); + + Unifier innerState = makeChildUnifier(); + + std::vector missingPropertiesInSuper; + bool isUnnamedTable = rt->name == std::nullopt && rt->syntheticName == std::nullopt; + bool errorReported = false; + + // Optimization: First test that the property sets are compatible without doing any recursive unification + if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer) + { + for (const auto& [propName, superProp] : lt->props) + { + auto subIter = rt->props.find(propName); + if (subIter == rt->props.end() && !isOptional(superProp.type)) + missingPropertiesInSuper.push_back(propName); + } + + if (!missingPropertiesInSuper.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + return; + } + } + + // Tables must have exactly the same props and their types must all unify + for (const auto& it : lt->props) + { + const auto& r = rt->props.find(it.first); + if (r == rt->props.end()) + { + if (FFlag::LuauSealedTableUnifyOptionalFix) + { + if (isOptional(it.second.type)) + continue; + } + else + { + if (get(it.second.type)) + { + const UnionTypeVar* possiblyOptional = get(it.second.type); + const std::vector& options = possiblyOptional->options; + if (options.end() != std::find_if(options.begin(), options.end(), isNil)) + continue; + } + } + + missingPropertiesInSuper.push_back(it.first); + + innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); + } + else + { + if (isUnnamedTable && r->second.location) + { + size_t oldErrorSize = innerState.errors.size(); + Location old = innerState.location; + innerState.location = *r->second.location; + innerState.tryUnify_(it.second.type, r->second.type); + innerState.location = old; + + if (oldErrorSize != innerState.errors.size() && !errorReported) + { + errorReported = true; + errors.push_back(innerState.errors.back()); + } + } + else + { + innerState.tryUnify_(it.second.type, r->second.type); + } + } + } + + if (lt->indexer || rt->indexer) + { + if (lt->indexer && rt->indexer) + innerState.tryUnify(*lt->indexer, *rt->indexer); + else if (rt->state == TableState::Unsealed) + { + if (lt->indexer && !rt->indexer) + rt->indexer = lt->indexer; + } + else if (lt->state == TableState::Unsealed) + { + if (rt->indexer && !lt->indexer) + lt->indexer = rt->indexer; + } + else if (lt->indexer) + { + innerState.tryUnify_(lt->indexer->indexType, singletonTypes.stringType); + // We already try to unify properties in both tables. + // Skip those and just look for the ones remaining and see if they fit into the indexer. + for (const auto& [name, type] : rt->props) + { + const auto& it = lt->props.find(name); + if (it == lt->props.end()) + innerState.tryUnify_(lt->indexer->indexResultType, type.type); + } + } + else + innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); + } + + log.concat(std::move(innerState.log)); + + if (errorReported) + return; + + if (!missingPropertiesInSuper.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + return; + } + + // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. + // Otherwise, we would falsely generate an extra-property-error for 's' in this code: + // local a: {n: number} & {s: string} = {n=1, s=""} + // When checking agaist the table '{n: number}'. + if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) + { + // Check for extra properties in the subTy + std::vector extraPropertiesInSub; + + for (const auto& it : rt->props) + { + const auto& r = lt->props.find(it.first); + if (r == lt->props.end()) + { + if (FFlag::LuauSealedTableUnifyOptionalFix) + { + if (isOptional(it.second.type)) + continue; + } + else + { + if (get(it.second.type)) + { + const UnionTypeVar* possiblyOptional = get(it.second.type); + const std::vector& options = possiblyOptional->options; + if (options.end() != std::find_if(options.begin(), options.end(), isNil)) + continue; + } + } + + extraPropertiesInSub.push_back(it.first); + } + } + + if (!extraPropertiesInSub.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraPropertiesInSub), MissingProperties::Extra}}); + return; + } + } + + checkChildUnifierTypeMismatch(innerState.errors, left, right); +} + +void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed) +{ + const MetatableTypeVar* lhs = get(metatable); + if (!lhs) + ice("tryUnifyMetatable invoked with non-metatable TypeVar"); + + TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other}}; + + if (const MetatableTypeVar* rhs = get(other)) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(lhs->table, rhs->table); + innerState.tryUnify_(lhs->metatable, rhs->metatable); + + checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + + log.concat(std::move(innerState.log)); + } + else if (TableTypeVar* rhs = getMutable(other)) + { + switch (rhs->state) + { + case TableState::Free: + { + tryUnify_(lhs->table, other); + rhs->boundTo = metatable; + + break; + } + // We know the shape of sealed, unsealed, and generic tables; you can't add a metatable on to any of these. + case TableState::Sealed: + case TableState::Unsealed: + case TableState::Generic: + errors.push_back(mismatchError); + } + } + else if (get(other) || get(other)) + { + } + else + { + errors.push_back(mismatchError); + } +} + +// Class unification is almost, but not quite symmetrical. We use the 'reversed' boolean to indicate which scenario we are evaluating. +void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) +{ + if (reversed) + std::swap(superTy, subTy); + + auto fail = [&]() { + if (!reversed) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + else + errors.push_back(TypeError{location, TypeMismatch{subTy, superTy}}); + }; + + const ClassTypeVar* superClass = get(superTy); + if (!superClass) + ice("tryUnifyClass invoked with non-class TypeVar"); + + if (const ClassTypeVar* subClass = get(subTy)) + { + switch (variance) + { + case Covariant: + if (!isSubclass(subClass, superClass)) + return fail(); + return; + case Invariant: + if (subClass != superClass) + return fail(); + return; + } + ice("Illegal variance setting!"); + } + else if (TableTypeVar* table = getMutable(subTy)) + { + /** + * A free table is something whose shape we do not exactly know yet. + * Thus, it is entirely reasonable that we might discover that it is being used as some class type. + * In this case, the free table must indeed be that exact class. + * For this to hold, the table must not have any properties that the class does not. + * Further, all properties of the table should unify cleanly with the matching class properties. + * TODO: What does it mean for the table to have an indexer? (probably failure?) + * + * Tables that are not free are known to be actual tables. + */ + if (table->state != TableState::Free) + return fail(); + + bool ok = true; + + for (const auto& [propName, prop] : table->props) + { + const Property* classProp = lookupClassProp(superClass, propName); + if (!classProp) + { + ok = false; + errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); + tryUnify_(prop.type, singletonTypes.errorType); + } + else + tryUnify_(prop.type, classProp->type); + } + + if (table->indexer) + { + ok = false; + std::string msg = "Class " + superClass->name + " does not have an indexer"; + errors.push_back(TypeError{location, GenericError{msg}}); + } + + if (!ok) + return; + + log(table); + table->boundTo = superTy; + } + else + return fail(); +} + +void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer) +{ + tryUnify_(superIndexer.indexType, subIndexer.indexType); + tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); +} + +static void queueTypePack( + std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) +{ + while (true) + { + if (FFlag::LuauAddMissingFollow) + a = follow(a); + + if (seenTypePacks.count(a)) + break; + seenTypePacks.insert(a); + + if (FFlag::LuauAddMissingFollow) + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + else + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + + if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + } +} + +void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) +{ + const VariadicTypePack* lv = get(superTp); + if (!lv) + ice("passed non-variadic pack to tryUnifyVariadics"); + + if (const VariadicTypePack* rv = get(subTp)) + tryUnify_(reversed ? rv->ty : lv->ty, reversed ? lv->ty : rv->ty); + else if (get(subTp)) + { + TypePackIterator rIter = begin(subTp); + TypePackIterator rEnd = end(subTp); + + std::advance(rIter, subOffset); + + while (rIter != rEnd) + { + tryUnify_(reversed ? *rIter : lv->ty, reversed ? lv->ty : *rIter); + ++rIter; + } + + if (std::optional maybeTail = rIter.tail()) + { + TypePackId tail = follow(*maybeTail); + if (get(tail)) + { + log(tail); + *asMutable(tail) = BoundTypePack{superTp}; + } + else if (const VariadicTypePack* vtp = get(tail)) + { + tryUnify_(lv->ty, vtp->ty); + } + else if (get(tail)) + { + errors.push_back(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + } + else if (get(tail)) + { + // Nothing to do here. + } + else + { + ice("Unknown TypePack kind"); + } + } + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + } +} + +static void tryUnifyWithAny( + std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) +{ + std::unordered_set seen; + + while (!queue.empty()) + { + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.count(ty)) + continue; + seen.insert(ty); + + if (get(ty)) + { + state.log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } +} + +void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) +{ + LUAU_ASSERT(get(any) || get(any)); + + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); + + const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + + std::unordered_set seenTypePacks; + std::vector queue = {ty}; + + Luau::tryUnifyWithAny(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); +} + +void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) +{ + LUAU_ASSERT(get(any)); + + const TypeId anyTy = singletonTypes.errorType; + + std::unordered_set seenTypePacks; + std::vector queue; + + queueTypePack(queue, seenTypePacks, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, seenTypePacks, anyTy, any); +} + +std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) +{ + return Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); +} + +std::optional Unifier::findMetatableEntry(TypeId type, std::string entry) +{ + type = follow(type); + + if (!FFlag::LuauStringMetatable) + { + if (const PrimitiveTypeVar* primType = get(type)) + { + if (primType->type != PrimitiveTypeVar::String || "__index" != entry) + return std::nullopt; + + auto found = globalScope->bindings.find(AstName{"string"}); + if (found == globalScope->bindings.end()) + return std::nullopt; + else + return found->second.typeId; + } + } + + std::optional metatable = getMetatable(type); + if (!metatable) + return std::nullopt; + + TypeId unwrapped = follow(*metatable); + + if (get(unwrapped)) + return singletonTypes.anyType; + + const TableTypeVar* mtt = getTableType(unwrapped); + if (!mtt) + { + errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); + return std::nullopt; + } + + auto it = mtt->props.find(entry); + if (it != mtt->props.end()) + return it->second.type; + else + return std::nullopt; +} + +void Unifier::occursCheck(TypeId needle, TypeId haystack) +{ + std::unordered_set seen; + return occursCheck(seen, needle, haystack); +} + +void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack) +{ + RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + + needle = follow(needle); + haystack = follow(haystack); + + if (seen.end() != seen.find(haystack)) + return; + + seen.insert(haystack); + + if (get(needle)) + return; + + if (!get(needle)) + ice("Expected needle to be free"); + + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log(needle); + *asMutable(needle) = ErrorTypeVar{}; + return; + } + + auto check = [&](TypeId tv) { + occursCheck(seen, needle, tv); + }; + + if (get(haystack)) + return; + else if (auto a = get(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (TypeId ty : a->argTypes) + check(ty); + + for (TypeId ty : a->retType) + check(ty); + } + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } +} + +void Unifier::occursCheck(TypePackId needle, TypePackId haystack) +{ + std::unordered_set seen; + return occursCheck(seen, needle, haystack); +} + +void Unifier::occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack) +{ + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack) != seen.end()) + return; + + seen.insert(haystack); + + if (get(needle)) + return; + + if (!get(needle)) + ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + + while (!get(haystack)) + { + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log(needle); + *asMutable(needle) = ErrorTypeVar{}; + return; + } + + if (auto a = get(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (const auto& ty : a->head) + { + if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + { + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); + } + } + } + + if (a->tail) + { + haystack = follow(*a->tail); + continue; + } + } + break; + } +} + +Unifier Unifier::makeChildUnifier() +{ + return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters}; +} + +bool Unifier::isNonstrictMode() const +{ + return (mode == Mode::Nonstrict) || (mode == Mode::NoCheck); +} + +void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) +{ + if (auto e = hasUnificationTooComplex(innerErrors)) + errors.push_back(*e); + else if (!innerErrors.empty()) + errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); +} + +void Unifier::ice(const std::string& message, const Location& location) +{ + iceHandler->ice(message, location); +} + +void Unifier::ice(const std::string& message) +{ + iceHandler->ice(message); +} + +} // namespace Luau diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h new file mode 100644 index 0000000..df38cfe --- /dev/null +++ b/Ast/include/Luau/Ast.h @@ -0,0 +1,1198 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" + +#include +#include + +#include + +namespace Luau +{ + +struct AstName +{ + const char* value; + + AstName() + : value(nullptr) + { + } + + explicit AstName(const char* value) + : value(value) + { + } + + bool operator==(const AstName& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const AstName& rhs) const + { + return value != rhs.value; + } + + bool operator==(const char* rhs) const + { + return value && strcmp(value, rhs) == 0; + } + + bool operator!=(const char* rhs) const + { + return !value || strcmp(value, rhs) != 0; + } + + bool operator<(const AstName& rhs) const + { + return (value && rhs.value) ? strcmp(value, rhs.value) < 0 : value < rhs.value; + } +}; + +class AstVisitor +{ +public: + virtual ~AstVisitor() {} + + virtual bool visit(class AstNode*) + { + return true; + } + + virtual bool visit(class AstExpr* node) + { + return visit((class AstNode*)node); + } + + virtual bool visit(class AstExprGroup* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprConstantNil* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprConstantBool* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprConstantNumber* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprConstantString* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprLocal* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprGlobal* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprVarargs* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprCall* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprIndexName* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprIndexExpr* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprFunction* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprTable* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprUnary* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprBinary* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprTypeAssertion* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprIfElse* node) + { + return visit((class AstExpr*)node); + } + virtual bool visit(class AstExprError* node) + { + return visit((class AstExpr*)node); + } + + virtual bool visit(class AstStat* node) + { + return visit((class AstNode*)node); + } + + virtual bool visit(class AstStatBlock* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatIf* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatWhile* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatRepeat* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatBreak* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatContinue* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatReturn* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatExpr* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatLocal* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatFor* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatForIn* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatAssign* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatCompoundAssign* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatFunction* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatLocalFunction* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatTypeAlias* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatDeclareFunction* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatDeclareGlobal* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatDeclareClass* node) + { + return visit((class AstStat*)node); + } + virtual bool visit(class AstStatError* node) + { + return visit((class AstStat*)node); + } + + // By default visiting type annotations is disabled; override this in your visitor if you need to! + virtual bool visit(class AstType* node) + { + return false; + } + + virtual bool visit(class AstTypeReference* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeTable* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeFunction* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeTypeof* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeUnion* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeIntersection* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeError* node) + { + return visit((class AstType*)node); + } + + virtual bool visit(class AstTypePack* node) + { + return false; + } + virtual bool visit(class AstTypePackVariadic* node) + { + return visit((class AstTypePack*)node); + } + virtual bool visit(class AstTypePackGeneric* node) + { + return visit((class AstTypePack*)node); + } +}; + +class AstType; + +struct AstLocal +{ + AstName name; + Location location; + AstLocal* shadow; + size_t functionDepth; + size_t loopDepth; + + AstType* annotation; + + AstLocal(const AstName& name, const Location& location, AstLocal* shadow, size_t functionDepth, size_t loopDepth, AstType* annotation) + : name(name) + , location(location) + , shadow(shadow) + , functionDepth(functionDepth) + , loopDepth(loopDepth) + , annotation(annotation) + { + } +}; + +template +struct AstArray +{ + T* data; + std::size_t size; + + const T* begin() const + { + return data; + } + const T* end() const + { + return data + size; + } +}; + +struct AstTypeList +{ + AstArray types; + // Null indicates no tail, not an untyped tail. + AstTypePack* tailType = nullptr; +}; + +using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName + +extern int gAstRttiIndex; + +template +struct AstRtti +{ + static const int value; +}; + +template +const int AstRtti::value = ++gAstRttiIndex; + +#define LUAU_RTTI(Class) \ + static int ClassIndex() \ + { \ + return AstRtti::value; \ + } + +class AstNode +{ +public: + explicit AstNode(int classIndex, const Location& location) + : classIndex(classIndex) + , location(location) + { + } + + virtual void visit(AstVisitor* visitor) = 0; + + virtual AstExpr* asExpr() + { + return nullptr; + } + virtual AstStat* asStat() + { + return nullptr; + } + virtual AstType* asType() + { + return nullptr; + } + + template + bool is() const + { + return classIndex == T::ClassIndex(); + } + template + T* as() + { + return classIndex == T::ClassIndex() ? static_cast(this) : nullptr; + } + template + const T* as() const + { + return classIndex == T::ClassIndex() ? static_cast(this) : nullptr; + } + + const int classIndex; + Location location; +}; + +class AstExpr : public AstNode +{ +public: + explicit AstExpr(int classIndex, const Location& location) + : AstNode(classIndex, location) + { + } + + AstExpr* asExpr() override + { + return this; + } +}; + +class AstStat : public AstNode +{ +public: + explicit AstStat(int classIndex, const Location& location) + : AstNode(classIndex, location) + , hasSemicolon(false) + { + } + + AstStat* asStat() override + { + return this; + } + + bool hasSemicolon; +}; + +class AstExprGroup : public AstExpr +{ +public: + LUAU_RTTI(AstExprGroup) + + explicit AstExprGroup(const Location& location, AstExpr* expr); + + void visit(AstVisitor* visitor) override; + + AstExpr* expr; +}; + +class AstExprConstantNil : public AstExpr +{ +public: + LUAU_RTTI(AstExprConstantNil) + + explicit AstExprConstantNil(const Location& location); + + void visit(AstVisitor* visitor) override; +}; + +class AstExprConstantBool : public AstExpr +{ +public: + LUAU_RTTI(AstExprConstantBool) + + AstExprConstantBool(const Location& location, bool value); + + void visit(AstVisitor* visitor) override; + + bool value; +}; + +class AstExprConstantNumber : public AstExpr +{ +public: + LUAU_RTTI(AstExprConstantNumber) + + AstExprConstantNumber(const Location& location, double value); + + void visit(AstVisitor* visitor) override; + + double value; +}; + +class AstExprConstantString : public AstExpr +{ +public: + LUAU_RTTI(AstExprConstantString) + + AstExprConstantString(const Location& location, const AstArray& value); + + void visit(AstVisitor* visitor) override; + + AstArray value; +}; + +class AstExprLocal : public AstExpr +{ +public: + LUAU_RTTI(AstExprLocal) + + AstExprLocal(const Location& location, AstLocal* local, bool upvalue); + + void visit(AstVisitor* visitor) override; + + AstLocal* local; + bool upvalue; +}; + +class AstExprGlobal : public AstExpr +{ +public: + LUAU_RTTI(AstExprGlobal) + + AstExprGlobal(const Location& location, const AstName& name); + + void visit(AstVisitor* visitor) override; + + AstName name; +}; + +class AstExprVarargs : public AstExpr +{ +public: + LUAU_RTTI(AstExprVarargs) + + AstExprVarargs(const Location& location); + + void visit(AstVisitor* visitor) override; +}; + +class AstExprCall : public AstExpr +{ +public: + LUAU_RTTI(AstExprCall) + + AstExprCall(const Location& location, AstExpr* func, const AstArray& args, bool self, const Location& argLocation); + + void visit(AstVisitor* visitor) override; + + AstExpr* func; + AstArray args; + bool self; + Location argLocation; +}; + +class AstExprIndexName : public AstExpr +{ +public: + LUAU_RTTI(AstExprIndexName) + + AstExprIndexName( + const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op); + + void visit(AstVisitor* visitor) override; + + AstExpr* expr; + AstName index; + Location indexLocation; + Position opPosition; + char op = '.'; +}; + +class AstExprIndexExpr : public AstExpr +{ +public: + LUAU_RTTI(AstExprIndexExpr) + + AstExprIndexExpr(const Location& location, AstExpr* expr, AstExpr* index); + + void visit(AstVisitor* visitor) override; + + AstExpr* expr; + AstExpr* index; +}; + +class AstExprFunction : public AstExpr +{ +public: + LUAU_RTTI(AstExprFunction) + + AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, + const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, + std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, + std::optional argLocation = std::nullopt); + + void visit(AstVisitor* visitor) override; + + AstArray generics; + AstArray genericPacks; + AstLocal* self; + AstArray args; + bool hasReturnAnnotation; + AstTypeList returnAnnotation; + bool vararg = false; + Location varargLocation; + AstTypePack* varargAnnotation; + + AstStatBlock* body; + + size_t functionDepth; + + AstName debugname; + + bool hasEnd = false; + std::optional argLocation; +}; + +class AstExprTable : public AstExpr +{ +public: + LUAU_RTTI(AstExprTable) + + struct Item + { + enum Kind + { + List, // foo, in which case key is a nullptr + Record, // foo=bar, in which case key is a AstExprConstantString + General, // [foo]=bar + }; + + Kind kind; + + AstExpr* key; // can be nullptr! + AstExpr* value; + }; + + AstExprTable(const Location& location, const AstArray& items); + + void visit(AstVisitor* visitor) override; + + AstArray items; +}; + +class AstExprUnary : public AstExpr +{ +public: + LUAU_RTTI(AstExprUnary) + + enum Op + { + Not, + Minus, + Len + }; + + AstExprUnary(const Location& location, Op op, AstExpr* expr); + + void visit(AstVisitor* visitor) override; + + Op op; + AstExpr* expr; +}; + +std::string toString(AstExprUnary::Op op); + +class AstExprBinary : public AstExpr +{ +public: + LUAU_RTTI(AstExprBinary) + + enum Op + { + Add, + Sub, + Mul, + Div, + Mod, + Pow, + Concat, + CompareNe, + CompareEq, + CompareLt, + CompareLe, + CompareGt, + CompareGe, + And, + Or + }; + + AstExprBinary(const Location& location, Op op, AstExpr* left, AstExpr* right); + + void visit(AstVisitor* visitor) override; + + Op op; + AstExpr* left; + AstExpr* right; +}; + +std::string toString(AstExprBinary::Op op); + +class AstExprTypeAssertion : public AstExpr +{ +public: + LUAU_RTTI(AstExprTypeAssertion) + + AstExprTypeAssertion(const Location& location, AstExpr* expr, AstType* annotation); + + void visit(AstVisitor* visitor) override; + + AstExpr* expr; + AstType* annotation; +}; + +class AstExprIfElse : public AstExpr +{ +public: + LUAU_RTTI(AstExprIfElse) + + AstExprIfElse(const Location& location, AstExpr* condition, bool hasThen, AstExpr* trueExpr, bool hasElse, AstExpr* falseExpr); + + void visit(AstVisitor* visitor) override; + + AstExpr* condition; + bool hasThen; + AstExpr* trueExpr; + bool hasElse; + AstExpr* falseExpr; +}; + +class AstStatBlock : public AstStat +{ +public: + LUAU_RTTI(AstStatBlock) + + AstStatBlock(const Location& location, const AstArray& body); + + void visit(AstVisitor* visitor) override; + + AstArray body; +}; + +class AstStatIf : public AstStat +{ +public: + LUAU_RTTI(AstStatIf) + + AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, const Location& thenLocation, + const std::optional& elseLocation, bool hasEnd); + + void visit(AstVisitor* visitor) override; + + AstExpr* condition; + AstStatBlock* thenbody; + AstStat* elsebody; + + bool hasThen = false; + Location thenLocation; + + // Active for 'elseif' as well + bool hasElse = false; + Location elseLocation; + + bool hasEnd = false; +}; + +class AstStatWhile : public AstStat +{ +public: + LUAU_RTTI(AstStatWhile) + + AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool hasEnd); + + void visit(AstVisitor* visitor) override; + + AstExpr* condition; + AstStatBlock* body; + + bool hasDo = false; + Location doLocation; + + bool hasEnd = false; +}; + +class AstStatRepeat : public AstStat +{ +public: + LUAU_RTTI(AstStatRepeat) + + AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasUntil); + + void visit(AstVisitor* visitor) override; + + AstExpr* condition; + AstStatBlock* body; + + bool hasUntil = false; +}; + +class AstStatBreak : public AstStat +{ +public: + LUAU_RTTI(AstStatBreak) + + AstStatBreak(const Location& location); + + void visit(AstVisitor* visitor) override; +}; + +class AstStatContinue : public AstStat +{ +public: + LUAU_RTTI(AstStatContinue) + + AstStatContinue(const Location& location); + + void visit(AstVisitor* visitor) override; +}; + +class AstStatReturn : public AstStat +{ +public: + LUAU_RTTI(AstStatReturn) + + AstStatReturn(const Location& location, const AstArray& list); + + void visit(AstVisitor* visitor) override; + + AstArray list; +}; + +class AstStatExpr : public AstStat +{ +public: + LUAU_RTTI(AstStatExpr) + + AstStatExpr(const Location& location, AstExpr* expr); + + void visit(AstVisitor* visitor) override; + + AstExpr* expr; +}; + +class AstStatLocal : public AstStat +{ +public: + LUAU_RTTI(AstStatLocal) + + AstStatLocal(const Location& location, const AstArray& vars, const AstArray& values, + const std::optional& equalsSignLocation); + + void visit(AstVisitor* visitor) override; + + AstArray vars; + AstArray values; + + bool hasEqualsSign = false; + Location equalsSignLocation; +}; + +class AstStatFor : public AstStat +{ +public: + LUAU_RTTI(AstStatFor) + + AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, + const Location& doLocation, bool hasEnd); + + void visit(AstVisitor* visitor) override; + + AstLocal* var; + AstExpr* from; + AstExpr* to; + AstExpr* step; + AstStatBlock* body; + + bool hasDo = false; + Location doLocation; + + bool hasEnd = false; +}; + +class AstStatForIn : public AstStat +{ +public: + LUAU_RTTI(AstStatForIn) + + AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, bool hasIn, + const Location& inLocation, bool hasDo, const Location& doLocation, bool hasEnd); + + void visit(AstVisitor* visitor) override; + + AstArray vars; + AstArray values; + AstStatBlock* body; + + bool hasIn = false; + Location inLocation; + + bool hasDo = false; + Location doLocation; + + bool hasEnd = false; +}; + +class AstStatAssign : public AstStat +{ +public: + LUAU_RTTI(AstStatAssign) + + AstStatAssign(const Location& location, const AstArray& vars, const AstArray& values); + + void visit(AstVisitor* visitor) override; + + AstArray vars; + AstArray values; +}; + +class AstStatCompoundAssign : public AstStat +{ +public: + LUAU_RTTI(AstStatCompoundAssign) + + AstStatCompoundAssign(const Location& location, AstExprBinary::Op op, AstExpr* var, AstExpr* value); + + void visit(AstVisitor* visitor) override; + + AstExprBinary::Op op; + AstExpr* var; + AstExpr* value; +}; + +class AstStatFunction : public AstStat +{ +public: + LUAU_RTTI(AstStatFunction) + + AstStatFunction(const Location& location, AstExpr* name, AstExprFunction* func); + + void visit(AstVisitor* visitor) override; + + AstExpr* name; + AstExprFunction* func; +}; + +class AstStatLocalFunction : public AstStat +{ +public: + LUAU_RTTI(AstStatLocalFunction) + + AstStatLocalFunction(const Location& location, AstLocal* name, AstExprFunction* func); + + void visit(AstVisitor* visitor) override; + + AstLocal* name; + AstExprFunction* func; +}; + +class AstStatTypeAlias : public AstStat +{ +public: + LUAU_RTTI(AstStatTypeAlias) + + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported); + + void visit(AstVisitor* visitor) override; + + AstName name; + AstArray generics; + AstType* type; + bool exported; +}; + +class AstStatDeclareGlobal : public AstStat +{ +public: + LUAU_RTTI(AstStatDeclareGlobal) + + AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type); + + void visit(AstVisitor* visitor) override; + + AstName name; + AstType* type; +}; + +class AstStatDeclareFunction : public AstStat +{ +public: + LUAU_RTTI(AstStatDeclareFunction) + + AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); + + void visit(AstVisitor* visitor) override; + + AstName name; + AstArray generics; + AstArray genericPacks; + AstTypeList params; + AstArray paramNames; + AstTypeList retTypes; +}; + +struct AstDeclaredClassProp +{ + AstName name; + AstType* ty = nullptr; + bool isMethod = false; +}; + +class AstStatDeclareClass : public AstStat +{ +public: + LUAU_RTTI(AstStatDeclareClass) + + AstStatDeclareClass(const Location& location, const AstName& name, std::optional superName, const AstArray& props); + + void visit(AstVisitor* visitor) override; + + AstName name; + std::optional superName; + + AstArray props; +}; + +class AstType : public AstNode +{ +public: + AstType(int classIndex, const Location& location) + : AstNode(classIndex, location) + { + } + + AstType* asType() override + { + return this; + } +}; + +class AstTypeReference : public AstType +{ +public: + LUAU_RTTI(AstTypeReference) + + AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics = {}); + + void visit(AstVisitor* visitor) override; + + bool hasPrefix; + AstName prefix; + AstName name; + AstArray generics; +}; + +struct AstTableProp +{ + AstName name; + Location location; + AstType* type; +}; + +struct AstTableIndexer +{ + AstType* indexType; + AstType* resultType; + Location location; +}; + +class AstTypeTable : public AstType +{ +public: + LUAU_RTTI(AstTypeTable) + + AstTypeTable(const Location& location, const AstArray& props, AstTableIndexer* indexer = nullptr); + + void visit(AstVisitor* visitor) override; + + AstArray props; + AstTableIndexer* indexer; +}; + +class AstTypeFunction : public AstType +{ +public: + LUAU_RTTI(AstTypeFunction) + + AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, + const AstArray>& argNames, const AstTypeList& returnTypes); + + void visit(AstVisitor* visitor) override; + + AstArray generics; + AstArray genericPacks; + AstTypeList argTypes; + AstArray> argNames; + AstTypeList returnTypes; +}; + +class AstTypeTypeof : public AstType +{ +public: + LUAU_RTTI(AstTypeTypeof) + + AstTypeTypeof(const Location& location, AstExpr* expr); + + void visit(AstVisitor* visitor) override; + + AstExpr* expr; +}; + +class AstTypeUnion : public AstType +{ +public: + LUAU_RTTI(AstTypeUnion) + + AstTypeUnion(const Location& location, const AstArray& types); + + void visit(AstVisitor* visitor) override; + + AstArray types; +}; + +class AstTypeIntersection : public AstType +{ +public: + LUAU_RTTI(AstTypeIntersection) + + AstTypeIntersection(const Location& location, const AstArray& types); + + void visit(AstVisitor* visitor) override; + + AstArray types; +}; + +class AstExprError : public AstExpr +{ +public: + LUAU_RTTI(AstExprError) + + AstExprError(const Location& location, const AstArray& expressions, unsigned messageIndex); + + void visit(AstVisitor* visitor) override; + + AstArray expressions; + unsigned messageIndex; +}; + +class AstStatError : public AstStat +{ +public: + LUAU_RTTI(AstStatError) + + AstStatError(const Location& location, const AstArray& expressions, const AstArray& statements, unsigned messageIndex); + + void visit(AstVisitor* visitor) override; + + AstArray expressions; + AstArray statements; + unsigned messageIndex; +}; + +class AstTypeError : public AstType +{ +public: + LUAU_RTTI(AstTypeError) + + AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex); + + void visit(AstVisitor* visitor) override; + + AstArray types; + bool isMissing; + unsigned messageIndex; +}; + +class AstTypePack : public AstNode +{ +public: + AstTypePack(int classIndex, const Location& location) + : AstNode(classIndex, location) + { + } +}; + +class AstTypePackVariadic : public AstTypePack +{ +public: + LUAU_RTTI(AstTypePackVariadic) + + AstTypePackVariadic(const Location& location, AstType* variadicType); + + void visit(AstVisitor* visitor) override; + + AstType* variadicType; +}; + +class AstTypePackGeneric : public AstTypePack +{ +public: + LUAU_RTTI(AstTypePackGeneric) + + AstTypePackGeneric(const Location& location, AstName name); + + void visit(AstVisitor* visitor) override; + + AstName genericName; +}; + +AstName getIdentifier(AstExpr*); + +#undef LUAU_RTTI + +} // namespace Luau + +namespace std +{ + +template<> +struct hash +{ + size_t operator()(const Luau::AstName& value) const + { + // note: since operator== uses pointer identity, hashing function uses it as well + return value.value ? std::hash()(value.value) : 0; + } +}; + +} // namespace std diff --git a/Ast/include/Luau/Common.h b/Ast/include/Luau/Common.h new file mode 100644 index 0000000..63cd3df --- /dev/null +++ b/Ast/include/Luau/Common.h @@ -0,0 +1,133 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +// Compiler codegen control macros +#ifdef _MSC_VER +#define LUAU_NORETURN __declspec(noreturn) +#define LUAU_NOINLINE __declspec(noinline) +#define LUAU_FORCEINLINE __forceinline +#define LUAU_LIKELY(x) x +#define LUAU_UNLIKELY(x) x +#define LUAU_UNREACHABLE() __assume(false) +#define LUAU_DEBUGBREAK() __debugbreak() +#else +#define LUAU_NORETURN __attribute__((__noreturn__)) +#define LUAU_NOINLINE __attribute__((noinline)) +#define LUAU_FORCEINLINE inline __attribute__((always_inline)) +#define LUAU_LIKELY(x) __builtin_expect(x, 1) +#define LUAU_UNLIKELY(x) __builtin_expect(x, 0) +#define LUAU_UNREACHABLE() __builtin_unreachable() +#define LUAU_DEBUGBREAK() __builtin_trap() +#endif + + + + + + + +namespace Luau +{ + +using AssertHandler = int (*)(const char* expression, const char* file, int line); + +inline AssertHandler& assertHandler() +{ + static AssertHandler handler = nullptr; + return handler; +} + +inline int assertCallHandler(const char* expression, const char* file, int line) +{ + if (AssertHandler handler = assertHandler()) + return handler(expression, file, line); + + return 1; +} + +} // namespace Luau + +#if !defined(NDEBUG) || defined(LUAU_ENABLE_ASSERT) +#define LUAU_ASSERT(expr) ((void)(!!(expr) || (Luau::assertCallHandler(#expr, __FILE__, __LINE__) && (LUAU_DEBUGBREAK(), 0)))) +#define LUAU_ASSERTENABLED +#else +#define LUAU_ASSERT(expr) (void)sizeof(!!(expr)) +#endif + +namespace Luau +{ + +template +struct FValue +{ + static FValue* list; + + T value; + bool dynamic; + const char* name; + FValue* next; + + FValue(const char* name, T def, bool dynamic, void (*reg)(const char*, T*, bool) = nullptr) + : value(def) + , dynamic(dynamic) + , name(name) + , next(list) + { + list = this; + + if (reg) + reg(name, &value, dynamic); + } + + operator T() const + { + return value; + } +}; + +template +FValue* FValue::list = nullptr; + +} // namespace Luau + +#define LUAU_FASTFLAG(flag) \ + namespace FFlag \ + { \ + extern Luau::FValue flag; \ + } +#define LUAU_FASTFLAGVARIABLE(flag, def) \ + namespace FFlag \ + { \ + Luau::FValue flag(#flag, def, false, nullptr); \ + } +#define LUAU_FASTINT(flag) \ + namespace FInt \ + { \ + extern Luau::FValue flag; \ + } +#define LUAU_FASTINTVARIABLE(flag, def) \ + namespace FInt \ + { \ + Luau::FValue flag(#flag, def, false, nullptr); \ + } + +#define LUAU_DYNAMIC_FASTFLAG(flag) \ + namespace DFFlag \ + { \ + extern Luau::FValue flag; \ + } +#define LUAU_DYNAMIC_FASTFLAGVARIABLE(flag, def) \ + namespace DFFlag \ + { \ + Luau::FValue flag(#flag, def, true, nullptr); \ + } +#define LUAU_DYNAMIC_FASTINT(flag) \ + namespace DFInt \ + { \ + extern Luau::FValue flag; \ + } +#define LUAU_DYNAMIC_FASTINTVARIABLE(flag, def) \ + namespace DFInt \ + { \ + Luau::FValue flag(#flag, def, true, nullptr); \ + } diff --git a/Ast/include/Luau/Confusables.h b/Ast/include/Luau/Confusables.h new file mode 100644 index 0000000..13f3497 --- /dev/null +++ b/Ast/include/Luau/Confusables.h @@ -0,0 +1,9 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ +const char* findConfusable(uint32_t codepoint); +} diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h new file mode 100644 index 0000000..02924e8 --- /dev/null +++ b/Ast/include/Luau/DenseHash.h @@ -0,0 +1,407 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include +#include +#include +#include +#include + +namespace Luau +{ + +// Internal implementation of DenseHashSet and DenseHashMap +namespace detail +{ + +struct DenseHashPointer +{ + size_t operator()(const void* key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } +}; + +template +using DenseHashDefault = std::conditional_t, DenseHashPointer, std::hash>; + +template +class DenseHashTable +{ +public: + class const_iterator; + + DenseHashTable(const Key& empty_key, size_t buckets = 0) + : count(0) + , empty_key(empty_key) + { + // buckets has to be power-of-two or zero + LUAU_ASSERT((buckets & (buckets - 1)) == 0); + + // don't move this to initializer list! this works around an MSVC codegen issue on AMD CPUs: + // https://developercommunity.visualstudio.com/t/stdvector-constructor-from-size-t-is-25-times-slow/1546547 + if (buckets) + data.resize(buckets, ItemInterface::create(empty_key)); + } + + void clear() + { + data.clear(); + count = 0; + } + + Item* insert_unsafe(const Key& key) + { + // It is invalid to insert empty_key into the table since it acts as a "entry does not exist" marker + LUAU_ASSERT(!eq(key, empty_key)); + + size_t hashmod = data.size() - 1; + size_t bucket = hasher(key) & hashmod; + + for (size_t probe = 0; probe <= hashmod; ++probe) + { + Item& probe_item = data[bucket]; + + // Element does not exist, insert here + if (eq(ItemInterface::getKey(probe_item), empty_key)) + { + ItemInterface::setKey(probe_item, key); + count++; + return &probe_item; + } + + // Element already exists + if (eq(ItemInterface::getKey(probe_item), key)) + { + return &probe_item; + } + + // Hash collision, quadratic probing + bucket = (bucket + probe + 1) & hashmod; + } + + // Hash table is full - this should not happen + LUAU_ASSERT(false); + return NULL; + } + + const Item* find(const Key& key) const + { + if (data.empty()) + return 0; + if (eq(key, empty_key)) + return 0; + + size_t hashmod = data.size() - 1; + size_t bucket = hasher(key) & hashmod; + + for (size_t probe = 0; probe <= hashmod; ++probe) + { + const Item& probe_item = data[bucket]; + + // Element exists + if (eq(ItemInterface::getKey(probe_item), key)) + return &probe_item; + + // Element does not exist + if (eq(ItemInterface::getKey(probe_item), empty_key)) + return NULL; + + // Hash collision, quadratic probing + bucket = (bucket + probe + 1) & hashmod; + } + + // Hash table is full - this should not happen + LUAU_ASSERT(false); + return NULL; + } + + void rehash() + { + size_t newsize = data.empty() ? 16 : data.size() * 2; + + if (data.empty() && data.capacity() >= newsize) + { + LUAU_ASSERT(count == 0); + data.resize(newsize, ItemInterface::create(empty_key)); + return; + } + + DenseHashTable newtable(empty_key, newsize); + + for (size_t i = 0; i < data.size(); ++i) + { + const Key& key = ItemInterface::getKey(data[i]); + + if (!eq(key, empty_key)) + *newtable.insert_unsafe(key) = data[i]; + } + + LUAU_ASSERT(count == newtable.count); + data.swap(newtable.data); + } + + void rehash_if_full() + { + if (count >= data.size() * 3 / 4) + { + rehash(); + } + } + + const_iterator begin() const + { + size_t start = 0; + + while (start < data.size() && eq(ItemInterface::getKey(data[start]), empty_key)) + start++; + + return const_iterator(this, start); + } + + const_iterator end() const + { + return const_iterator(this, data.size()); + } + + size_t size() const + { + return count; + } + + class const_iterator + { + public: + const_iterator() + : set(0) + , index(0) + { + } + + const_iterator(const DenseHashTable* set, size_t index) + : set(set) + , index(index) + { + } + + const Item& operator*() const + { + return set->data[index]; + } + + const Item* operator->() const + { + return &set->data[index]; + } + + bool operator==(const const_iterator& other) const + { + return set == other.set && index == other.index; + } + + bool operator!=(const const_iterator& other) const + { + return set != other.set || index != other.index; + } + + const_iterator& operator++() + { + size_t size = set->data.size(); + + do + { + index++; + } while (index < size && set->eq(ItemInterface::getKey(set->data[index]), set->empty_key)); + + return *this; + } + + const_iterator operator++(int) + { + const_iterator res = *this; + ++*this; + return res; + } + + private: + const DenseHashTable* set; + size_t index; + }; + +private: + std::vector data; + size_t count; + Key empty_key; + Hash hasher; + Eq eq; +}; + +template +struct ItemInterfaceSet +{ + static const Key& getKey(const Key& item) + { + return item; + } + + static void setKey(Key& item, const Key& key) + { + item = key; + } + + static Key create(const Key& key) + { + return key; + } +}; + +template +struct ItemInterfaceMap +{ + static const Key& getKey(const std::pair& item) + { + return item.first; + } + + static void setKey(std::pair& item, const Key& key) + { + item.first = key; + } + + static std::pair create(const Key& key) + { + return std::pair(key, Value()); + } +}; + +} // namespace detail + +// This is a faster alternative of unordered_set, but it does not implement the same interface (i.e. it does not support erasing) +template, typename Eq = std::equal_to> +class DenseHashSet +{ + typedef detail::DenseHashTable, Hash, Eq> Impl; + Impl impl; + +public: + typedef typename Impl::const_iterator const_iterator; + + DenseHashSet(const Key& empty_key, size_t buckets = 0) + : impl(empty_key, buckets) + { + } + + void clear() + { + impl.clear(); + } + + const Key& insert(const Key& key) + { + impl.rehash_if_full(); + return *impl.insert_unsafe(key); + } + + const Key* find(const Key& key) const + { + return impl.find(key); + } + + bool contains(const Key& key) const + { + return impl.find(key) != 0; + } + + size_t size() const + { + return impl.size(); + } + + bool empty() const + { + return impl.size() == 0; + } + + const_iterator begin() const + { + return impl.begin(); + } + + const_iterator end() const + { + return impl.end(); + } +}; + +// This is a faster alternative of unordered_map, but it does not implement the same interface (i.e. it does not support erasing and has +// contains() instead of find()) +template, typename Eq = std::equal_to> +class DenseHashMap +{ + typedef detail::DenseHashTable, std::pair, detail::ItemInterfaceMap, Hash, Eq> Impl; + Impl impl; + +public: + typedef typename Impl::const_iterator const_iterator; + + DenseHashMap(const Key& empty_key, size_t buckets = 0) + : impl(empty_key, buckets) + { + } + + void clear() + { + impl.clear(); + } + + // Note: this reference is invalidated by any insert operation (i.e. operator[]) + Value& operator[](const Key& key) + { + impl.rehash_if_full(); + return impl.insert_unsafe(key)->second; + } + + // Note: this pointer is invalidated by any insert operation (i.e. operator[]) + const Value* find(const Key& key) const + { + const std::pair* result = impl.find(key); + + return result ? &result->second : NULL; + } + + // Note: this pointer is invalidated by any insert operation (i.e. operator[]) + Value* find(const Key& key) + { + const std::pair* result = impl.find(key); + + return result ? const_cast(&result->second) : NULL; + } + + bool contains(const Key& key) const + { + return impl.find(key) != 0; + } + + size_t size() const + { + return impl.size(); + } + + bool empty() const + { + return impl.size() == 0; + } + + const_iterator begin() const + { + return impl.begin(); + } + const_iterator end() const + { + return impl.end(); + } +}; + +} // namespace Luau diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h new file mode 100644 index 0000000..460ef05 --- /dev/null +++ b/Ast/include/Luau/Lexer.h @@ -0,0 +1,236 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Location.h" +#include "Luau/DenseHash.h" +#include "Luau/Common.h" + +namespace Luau +{ + +class Allocator +{ +public: + Allocator(); + Allocator(Allocator&&); + + Allocator& operator=(Allocator&&) = delete; + + ~Allocator(); + + void* allocate(size_t size); + + template + T* alloc(Args&&... args) + { + static_assert(std::is_trivially_destructible::value, "Objects allocated with this allocator will never have their destructors run!"); + + T* t = static_cast(allocate(sizeof(T))); + new (t) T(std::forward(args)...); + return t; + } + +private: + struct Page + { + Page* next; + + char data[8192]; + }; + + Page* root; + size_t offset; +}; + +struct Lexeme +{ + enum Type + { + Eof = 0, + + // 1..255 means actual character values + Char_END = 256, + + Equal, + LessEqual, + GreaterEqual, + NotEqual, + Dot2, + Dot3, + SkinnyArrow, + DoubleColon, + + AddAssign, + SubAssign, + MulAssign, + DivAssign, + ModAssign, + PowAssign, + ConcatAssign, + + RawString, + QuotedString, + Number, + Name, + + Comment, + BlockComment, + + BrokenString, + BrokenComment, + BrokenUnicode, + Error, + + Reserved_BEGIN, + ReservedAnd = Reserved_BEGIN, + ReservedBreak, + ReservedDo, + ReservedElse, + ReservedElseif, + ReservedEnd, + ReservedFalse, + ReservedFor, + ReservedFunction, + ReservedIf, + ReservedIn, + ReservedLocal, + ReservedNil, + ReservedNot, + ReservedOr, + ReservedRepeat, + ReservedReturn, + ReservedThen, + ReservedTrue, + ReservedUntil, + ReservedWhile, + Reserved_END + }; + + Type type; + Location location; + unsigned int length; + + union + { + const char* data; // String, Number, Comment + const char* name; // Name + unsigned int codepoint; // BrokenUnicode + }; + + Lexeme(const Location& location, Type type); + Lexeme(const Location& location, char character); + Lexeme(const Location& location, Type type, const char* data, size_t size); + Lexeme(const Location& location, Type type, const char* name); + + std::string toString() const; +}; + +class AstNameTable +{ +public: + AstNameTable(Allocator& allocator); + + AstName addStatic(const char* name, Lexeme::Type type = Lexeme::Name); + + std::pair getOrAddWithType(const char* name, size_t length); + std::pair getWithType(const char* name, size_t length) const; + + AstName getOrAdd(const char* name); + AstName get(const char* name) const; + +private: + struct Entry + { + AstName value; + uint32_t length; + Lexeme::Type type; + + bool operator==(const Entry& other) const; + }; + + struct EntryHash + { + size_t operator()(const Entry& e) const; + }; + + DenseHashSet data; + + Allocator& allocator; +}; + +class Lexer +{ +public: + Lexer(const char* buffer, std::size_t bufferSize, AstNameTable& names); + + void setSkipComments(bool skip); + void setReadNames(bool read); + + const Location& previousLocation() const + { + return prevLocation; + } + + const Lexeme& next(); + const Lexeme& next(bool skipComments); + void nextline(); + + Lexeme lookahead(); + + const Lexeme& current() const + { + return lexeme; + } + + static bool isReserved(const std::string& word); + + static bool fixupQuotedString(std::string& data); + static void fixupMultilineString(std::string& data); + +private: + char peekch() const; + char peekch(unsigned int lookahead) const; + + Position position() const; + + void consume(); + + Lexeme readCommentBody(); + + // Given a sequence [===[ or ]===], returns: + // 1. number of equal signs (or 0 if none present) between the brackets + // 2. -1 if this is not a long comment/string separator + // 3. -N if this is a malformed separator + // Does *not* consume the closing brace. + int skipLongSeparator(); + + Lexeme readLongString(const Position& start, int sep, Lexeme::Type ok, Lexeme::Type broken); + Lexeme readQuotedString(); + + std::pair readName(); + + Lexeme readNumber(const Position& start, unsigned int startOffset); + + Lexeme readUtf8Error(); + Lexeme readNext(); + + const char* buffer; + std::size_t bufferSize; + + unsigned int offset; + + unsigned int line; + unsigned int lineOffset; + + Lexeme lexeme; + + Location prevLocation; + + AstNameTable& names; + + bool skipComments; + bool readNames; +}; + +} // namespace Luau diff --git a/Ast/include/Luau/Location.h b/Ast/include/Luau/Location.h new file mode 100644 index 0000000..d3c0a46 --- /dev/null +++ b/Ast/include/Luau/Location.h @@ -0,0 +1,109 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +namespace Luau +{ + +struct Position +{ + unsigned int line, column; + + Position(unsigned int line, unsigned int column) + : line(line) + , column(column) + { + } + + bool operator==(const Position& rhs) const + { + return this->column == rhs.column && this->line == rhs.line; + } + bool operator!=(const Position& rhs) const + { + return !(*this == rhs); + } + + bool operator<(const Position& rhs) const + { + if (line == rhs.line) + return column < rhs.column; + else + return line < rhs.line; + } + + bool operator>(const Position& rhs) const + { + if (line == rhs.line) + return column > rhs.column; + else + return line > rhs.line; + } + + bool operator<=(const Position& rhs) const + { + return *this == rhs || *this < rhs; + } + + bool operator>=(const Position& rhs) const + { + return *this == rhs || *this > rhs; + } +}; + +struct Location +{ + Position begin, end; + + Location() + : begin(0, 0) + , end(0, 0) + { + } + + Location(const Position& begin, const Position& end) + : begin(begin) + , end(end) + { + } + + Location(const Position& begin, unsigned int length) + : begin(begin) + , end(begin.line, begin.column + length) + { + } + + Location(const Location& begin, const Location& end) + : begin(begin.begin) + , end(end.end) + { + } + + bool operator==(const Location& rhs) const + { + return this->begin == rhs.begin && this->end == rhs.end; + } + bool operator!=(const Location& rhs) const + { + return !(*this == rhs); + } + + bool encloses(const Location& l) const + { + return begin <= l.begin && end >= l.end; + } + bool contains(const Position& p) const + { + return begin <= p && p < end; + } + bool containsClosed(const Position& p) const + { + return begin <= p && p <= end; + } +}; + +std::string toString(const Position& position); +std::string toString(const Location& location); + +} // namespace Luau diff --git a/Ast/include/Luau/ParseOptions.h b/Ast/include/Luau/ParseOptions.h new file mode 100644 index 0000000..89e7952 --- /dev/null +++ b/Ast/include/Luau/ParseOptions.h @@ -0,0 +1,23 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +namespace Luau +{ + +enum class Mode +{ + NoCheck, // Do not perform any inference + Nonstrict, // Unannotated symbols are any + Strict, // Unannotated symbols are inferred + Definition, // Type definition module, has special parsing rules +}; + +struct ParseOptions +{ + bool allowTypeAnnotations = true; + bool supportContinueStatement = true; + bool allowDeclarationSyntax = false; + bool captureComments = false; +}; + +} // namespace Luau diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h new file mode 100644 index 0000000..e6ebd50 --- /dev/null +++ b/Ast/include/Luau/Parser.h @@ -0,0 +1,423 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/Lexer.h" +#include "Luau/ParseOptions.h" +#include "Luau/StringUtils.h" +#include "Luau/DenseHash.h" +#include "Luau/Common.h" + +#include +#include + +namespace Luau +{ + +class ParseError : public std::exception +{ +public: + ParseError(const Location& location, const std::string& message); + + virtual const char* what() const throw(); + + const Location& getLocation() const; + const std::string& getMessage() const; + + static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); + +private: + Location location; + std::string message; +}; + +class ParseErrors : public std::exception +{ +public: + ParseErrors(std::vector errors); + + virtual const char* what() const throw(); + + const std::vector& getErrors() const; + +private: + std::vector errors; + std::string message; +}; + +template +class TempVector +{ +public: + explicit TempVector(std::vector& storage); + + ~TempVector(); + + const T& operator[](std::size_t index) const; + + const T& front() const; + + const T& back() const; + + bool empty() const; + + std::size_t size() const; + + void push_back(const T& item); + + typename std::vector::const_iterator begin() const + { + return storage.begin() + offset; + } + typename std::vector::const_iterator end() const + { + return storage.begin() + offset + size_; + } + +private: + std::vector& storage; + size_t offset; + size_t size_; +}; + +struct Comment +{ + Lexeme::Type type; // Comment, BlockComment, or BrokenComment + Location location; +}; + +struct ParseResult +{ + AstStatBlock* root; + std::vector hotcomments; + std::vector errors; + + std::vector commentLocations; +}; + +class Parser +{ +public: + static ParseResult parse( + const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions()); + + static constexpr const char* errorName = "%error-id%"; + +private: + struct Name; + struct Binding; + + Parser(const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator); + + bool blockFollow(const Lexeme& l); + + AstStatBlock* parseChunk(); + + // chunk ::= {stat [`;']} [laststat [`;']] + // block ::= chunk + AstStatBlock* parseBlock(); + + AstStatBlock* parseBlockNoScope(); + + // stat ::= + // varlist `=' explist | + // functioncall | + // do block end | + // while exp do block end | + // repeat block until exp | + // if exp then block {elseif exp then block} [else block] end | + // for Name `=' exp `,' exp [`,' exp] do block end | + // for namelist in explist do block end | + // function funcname funcbody | + // local function Name funcbody | + // local namelist [`=' explist] + // laststat ::= return [explist] | break + AstStat* parseStat(); + + // if exp then block {elseif exp then block} [else block] end + AstStat* parseIf(); + + // while exp do block end + AstStat* parseWhile(); + + // repeat block until exp + AstStat* parseRepeat(); + + // do block end + AstStat* parseDo(); + + // break + AstStat* parseBreak(); + + // continue + AstStat* parseContinue(const Location& start); + + // for Name `=' exp `,' exp [`,' exp] do block end | + // for namelist in explist do block end | + AstStat* parseFor(); + + // function funcname funcbody | + // funcname ::= Name {`.' Name} [`:' Name] + AstStat* parseFunctionStat(); + + // local function Name funcbody | + // local namelist [`=' explist] + AstStat* parseLocal(); + + // return [explist] + AstStat* parseReturn(); + + // type Name `=' typeannotation + AstStat* parseTypeAlias(const Location& start, bool exported); + + AstDeclaredClassProp parseDeclaredClassMethod(); + + // `declare global' Name: typeannotation | + // `declare function' Name`(' [parlist] `)' [`:` TypeAnnotation] + AstStat* parseDeclaration(const Location& start); + + // varlist `=' explist + AstStat* parseAssignment(AstExpr* initial); + + // var [`+=' | `-=' | `*=' | `/=' | `%=' | `^=' | `..='] exp + AstStat* parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op); + + // funcbody ::= `(' [parlist] `)' block end + // parlist ::= namelist [`,' `...'] | `...' + std::pair parseFunctionBody( + bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional localName); + + // explist ::= {exp `,'} exp + void parseExprList(TempVector& result); + + // binding ::= Name [`:` TypeAnnotation] + Binding parseBinding(); + + // bindinglist ::= (binding | `...') {`,' bindinglist} + // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. + std::pair, AstTypePack*> parseBindingList(TempVector& result, bool allowDot3 = false); + + AstType* parseOptionalTypeAnnotation(); + + // TypeList ::= TypeAnnotation [`,' TypeList] + // ReturnType ::= TypeAnnotation | `(' TypeList `)' + // TableProp ::= Name `:' TypeAnnotation + // TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation + // PropList ::= (TableProp | TableIndexer) [`,' PropList] + // TypeAnnotation + // ::= Name + // | `nil` + // | `{' [PropList] `}' + // | `(' [TypeList] `)' `->` ReturnType + + // Returns the variadic annotation, if it exists. + AstTypePack* parseTypeList(TempVector& result, TempVector>& resultNames); + + std::optional parseOptionalReturnTypeAnnotation(); + std::pair parseReturnTypeAnnotation(); + + AstTableIndexer* parseTableIndexerAnnotation(); + + AstType* parseFunctionTypeAnnotation(); + AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); + + AstType* parseTableTypeAnnotation(); + AstType* parseSimpleTypeAnnotation(); + + AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); + AstType* parseTypeAnnotation(); + + AstTypePack* parseTypePackAnnotation(); + AstTypePack* parseVariadicArgumentAnnotation(); + + static std::optional parseUnaryOp(const Lexeme& l); + static std::optional parseBinaryOp(const Lexeme& l); + static std::optional parseCompoundOp(const Lexeme& l); + + struct BinaryOpPriority + { + unsigned char left, right; + }; + + std::optional checkUnaryConfusables(); + std::optional checkBinaryConfusables(const BinaryOpPriority binaryPriority[], unsigned int limit); + + // subexpr -> (asexp | unop subexpr) { binop subexpr } + // where `binop' is any binary operator with a priority higher than `limit' + AstExpr* parseExpr(unsigned int limit = 0); + + // NAME + AstExpr* parseNameExpr(const char* context = nullptr); + + // prefixexp -> NAME | '(' expr ')' + AstExpr* parsePrefixExpr(); + + // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } + AstExpr* parsePrimaryExpr(bool asStatement); + + // asexp -> simpleexp [`::' typeAnnotation] + AstExpr* parseAssertionExpr(); + + // simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp + AstExpr* parseSimpleExpr(); + + // args ::= `(' [explist] `)' | tableconstructor | String + AstExpr* parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation); + + // tableconstructor ::= `{' [fieldlist] `}' + // fieldlist ::= field {fieldsep field} [fieldsep] + // field ::= `[' exp `]' `=' exp | Name `=' exp | exp + // fieldsep ::= `,' | `;' + AstExpr* parseTableConstructor(); + + // TODO: Add grammar rules here? + AstExpr* parseIfElseExpr(); + + // Name + std::optional parseNameOpt(const char* context = nullptr); + Name parseName(const char* context = nullptr); + Name parseIndexName(const char* context, const Position& previous); + + // `<' namelist `>' + std::pair, AstArray> parseGenericTypeList(); + std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); + + // `<' typeAnnotation[, ...] `>' + AstArray parseTypeParams(); + + AstExpr* parseString(); + + AstLocal* pushLocal(const Binding& binding); + + unsigned int saveLocals(); + + void restoreLocals(unsigned int offset); + + // check that parser is at lexeme/symbol, move to next lexeme/symbol on success, report failure and continue on failure + bool expectAndConsume(char value, const char* context = nullptr); + bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); + void expectAndConsumeFail(Lexeme::Type type, const char* context); + + bool expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing = false); + void expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra = nullptr); + + bool expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin); + void expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin); + + template + AstArray copy(const T* data, std::size_t size); + + template + AstArray copy(const TempVector& data); + + template + AstArray copy(std::initializer_list data); + + AstArray copy(const std::string& data); + + void incrementRecursionCounter(const char* context); + + void report(const Location& location, const char* format, va_list args); + void report(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(3, 4); + + void reportNameError(const char* context); + + AstStatError* reportStatError(const Location& location, const AstArray& expressions, const AstArray& statements, + const char* format, ...) LUAU_PRINTF_ATTR(5, 6); + AstExprError* reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) LUAU_PRINTF_ATTR(4, 5); + AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) + LUAU_PRINTF_ATTR(5, 6); + + const Lexeme& nextLexeme(); + + struct Function + { + bool vararg; + unsigned int loopDepth; + + Function() + : vararg(false) + , loopDepth(0) + { + } + }; + + struct Local + { + AstLocal* local; + unsigned int offset; + + Local() + : local(nullptr) + , offset(0) + { + } + }; + + struct Name + { + AstName name; + Location location; + + Name(const AstName& name, const Location& location) + : name(name) + , location(location) + { + } + }; + + struct Binding + { + Name name; + AstType* annotation; + + explicit Binding(const Name& name, AstType* annotation = nullptr) + : name(name) + , annotation(annotation) + { + } + }; + + ParseOptions options; + + Lexer lexer; + Allocator& allocator; + + std::vector commentLocations; + + unsigned int recursionCounter; + + AstName nameSelf; + AstName nameNumber; + AstName nameError; + AstName nameNil; + + Lexeme endMismatchSuspect; + + std::vector functionStack; + + DenseHashMap localMap; + std::vector localStack; + + std::vector parseErrors; + + std::vector matchRecoveryStopOnToken; + + std::vector scratchStat; + std::vector scratchExpr; + std::vector scratchExprAux; + std::vector scratchName; + std::vector scratchPackName; + std::vector scratchBinding; + std::vector scratchLocal; + std::vector scratchTableTypeProps; + std::vector scratchAnnotation; + std::vector scratchDeclaredClassProps; + std::vector scratchItem; + std::vector scratchArgName; + std::vector> scratchOptArgName; + std::string scratchData; +}; + +} // namespace Luau diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h new file mode 100644 index 0000000..4f7673f --- /dev/null +++ b/Ast/include/Luau/StringUtils.h @@ -0,0 +1,37 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include + +#include + +#if defined(__GNUC__) +#define LUAU_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) +#else +#define LUAU_PRINTF_ATTR(fmt, arg) +#endif + +namespace Luau +{ + +std::string format(const char* fmt, ...) LUAU_PRINTF_ATTR(1, 2); +std::string vformat(const char* fmt, va_list args); + +void formatAppend(std::string& str, const char* fmt, ...) LUAU_PRINTF_ATTR(2, 3); + +std::string join(const std::vector& segments, std::string_view delimiter); +std::string join(const std::vector& segments, std::string_view delimiter); + +std::vector split(std::string_view s, char delimiter); + +// Computes the Damerau-Levenshtein distance of A and B. +// https://en.wikipedia.org/wiki/Damerau-Levenshtein_distance#Distance_with_adjacent_transpositions +size_t editDistance(std::string_view a, std::string_view b); + +bool startsWith(std::string_view lhs, std::string_view rhs); +bool equalsLower(std::string_view lhs, std::string_view rhs); + +size_t hashRange(const char* data, size_t size); + +} // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp new file mode 100644 index 0000000..fff1537 --- /dev/null +++ b/Ast/src/Ast.cpp @@ -0,0 +1,886 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Ast.h" + +#include "Luau/Common.h" + +namespace Luau +{ + +static void visitTypeList(AstVisitor* visitor, const AstTypeList& list) +{ + for (AstType* ty : list.types) + ty->visit(visitor); + + if (list.tailType) + list.tailType->visit(visitor); +} + +int gAstRttiIndex = 0; + +AstExprGroup::AstExprGroup(const Location& location, AstExpr* expr) + : AstExpr(ClassIndex(), location) + , expr(expr) +{ +} + +void AstExprGroup::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + expr->visit(visitor); +} + +AstExprConstantNil::AstExprConstantNil(const Location& location) + : AstExpr(ClassIndex(), location) +{ +} + +void AstExprConstantNil::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstExprConstantBool::AstExprConstantBool(const Location& location, bool value) + : AstExpr(ClassIndex(), location) + , value(value) +{ +} + +void AstExprConstantBool::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstExprConstantNumber::AstExprConstantNumber(const Location& location, double value) + : AstExpr(ClassIndex(), location) + , value(value) +{ +} + +void AstExprConstantNumber::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstExprConstantString::AstExprConstantString(const Location& location, const AstArray& value) + : AstExpr(ClassIndex(), location) + , value(value) +{ +} + +void AstExprConstantString::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstExprLocal::AstExprLocal(const Location& location, AstLocal* local, bool upvalue) + : AstExpr(ClassIndex(), location) + , local(local) + , upvalue(upvalue) +{ +} + +void AstExprLocal::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstExprGlobal::AstExprGlobal(const Location& location, const AstName& name) + : AstExpr(ClassIndex(), location) + , name(name) +{ +} + +void AstExprGlobal::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstExprVarargs::AstExprVarargs(const Location& location) + : AstExpr(ClassIndex(), location) +{ +} + +void AstExprVarargs::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstExprCall::AstExprCall(const Location& location, AstExpr* func, const AstArray& args, bool self, const Location& argLocation) + : AstExpr(ClassIndex(), location) + , func(func) + , args(args) + , self(self) + , argLocation(argLocation) +{ +} + +void AstExprCall::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + func->visit(visitor); + + for (AstExpr* arg : args) + arg->visit(visitor); + } +} + +AstExprIndexName::AstExprIndexName( + const Location& location, AstExpr* expr, const AstName& index, const Location& indexLocation, const Position& opPosition, char op) + : AstExpr(ClassIndex(), location) + , expr(expr) + , index(index) + , indexLocation(indexLocation) + , opPosition(opPosition) + , op(op) +{ +} + +void AstExprIndexName::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + expr->visit(visitor); +} + +AstExprIndexExpr::AstExprIndexExpr(const Location& location, AstExpr* expr, AstExpr* index) + : AstExpr(ClassIndex(), location) + , expr(expr) + , index(index) +{ +} + +void AstExprIndexExpr::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + expr->visit(visitor); + index->visit(visitor); + } +} + +AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, + const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, + std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, std::optional argLocation) + : AstExpr(ClassIndex(), location) + , generics(generics) + , genericPacks(genericPacks) + , self(self) + , args(args) + , hasReturnAnnotation(returnAnnotation.has_value()) + , returnAnnotation() + , vararg(vararg.has_value()) + , varargLocation(vararg.value_or(Location())) + , varargAnnotation(varargAnnotation) + , body(body) + , functionDepth(functionDepth) + , debugname(debugname) + , hasEnd(hasEnd) + , argLocation(argLocation) +{ + if (returnAnnotation.has_value()) + this->returnAnnotation = *returnAnnotation; +} + +void AstExprFunction::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstLocal* arg : args) + { + if (arg->annotation) + arg->annotation->visit(visitor); + } + + if (varargAnnotation) + varargAnnotation->visit(visitor); + + if (hasReturnAnnotation) + visitTypeList(visitor, returnAnnotation); + + body->visit(visitor); + } +} + +AstExprTable::AstExprTable(const Location& location, const AstArray& items) + : AstExpr(ClassIndex(), location) + , items(items) +{ +} + +void AstExprTable::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (const Item& item : items) + { + if (item.key) + item.key->visit(visitor); + + item.value->visit(visitor); + } + } +} + +AstExprUnary::AstExprUnary(const Location& location, Op op, AstExpr* expr) + : AstExpr(ClassIndex(), location) + , op(op) + , expr(expr) +{ +} + +void AstExprUnary::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + expr->visit(visitor); +} + +std::string toString(AstExprUnary::Op op) +{ + switch (op) + { + case AstExprUnary::Minus: + return "-"; + case AstExprUnary::Not: + return "not"; + case AstExprUnary::Len: + return "#"; + default: + LUAU_ASSERT(false); + return ""; // MSVC requires this even though the switch/case is exhaustive + } +} + +AstExprBinary::AstExprBinary(const Location& location, Op op, AstExpr* left, AstExpr* right) + : AstExpr(ClassIndex(), location) + , op(op) + , left(left) + , right(right) +{ +} + +void AstExprBinary::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + left->visit(visitor); + right->visit(visitor); + } +} + +std::string toString(AstExprBinary::Op op) +{ + switch (op) + { + case AstExprBinary::Add: + return "+"; + case AstExprBinary::Sub: + return "-"; + case AstExprBinary::Mul: + return "*"; + case AstExprBinary::Div: + return "/"; + case AstExprBinary::Mod: + return "%"; + case AstExprBinary::Pow: + return "^"; + case AstExprBinary::Concat: + return ".."; + case AstExprBinary::CompareNe: + return "~="; + case AstExprBinary::CompareEq: + return "=="; + case AstExprBinary::CompareLt: + return "<"; + case AstExprBinary::CompareLe: + return "<="; + case AstExprBinary::CompareGt: + return ">"; + case AstExprBinary::CompareGe: + return ">="; + case AstExprBinary::And: + return "and"; + case AstExprBinary::Or: + return "or"; + default: + LUAU_ASSERT(false); + return ""; // MSVC requires this even though the switch/case is exhaustive + } +} + +AstExprTypeAssertion::AstExprTypeAssertion(const Location& location, AstExpr* expr, AstType* annotation) + : AstExpr(ClassIndex(), location) + , expr(expr) + , annotation(annotation) +{ +} + +void AstExprTypeAssertion::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + expr->visit(visitor); + annotation->visit(visitor); + } +} + +AstExprIfElse::AstExprIfElse(const Location& location, AstExpr* condition, bool hasThen, AstExpr* trueExpr, bool hasElse, AstExpr* falseExpr) + : AstExpr(ClassIndex(), location) + , condition(condition) + , hasThen(hasThen) + , trueExpr(trueExpr) + , hasElse(hasElse) + , falseExpr(falseExpr) +{ +} + +void AstExprIfElse::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + condition->visit(visitor); + trueExpr->visit(visitor); + falseExpr->visit(visitor); + } +} + +AstExprError::AstExprError(const Location& location, const AstArray& expressions, unsigned messageIndex) + : AstExpr(ClassIndex(), location) + , expressions(expressions) + , messageIndex(messageIndex) +{ +} + +void AstExprError::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstExpr* expression : expressions) + expression->visit(visitor); + } +} + +AstStatBlock::AstStatBlock(const Location& location, const AstArray& body) + : AstStat(ClassIndex(), location) + , body(body) +{ +} + +void AstStatBlock::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstStat* stat : body) + stat->visit(visitor); + } +} + +AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, + const Location& thenLocation, const std::optional& elseLocation, bool hasEnd) + : AstStat(ClassIndex(), location) + , condition(condition) + , thenbody(thenbody) + , elsebody(elsebody) + , hasThen(hasThen) + , thenLocation(thenLocation) + , hasEnd(hasEnd) +{ + if (bool(elseLocation)) + { + hasElse = true; + this->elseLocation = *elseLocation; + } +} + +void AstStatIf::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + condition->visit(visitor); + thenbody->visit(visitor); + + if (elsebody) + elsebody->visit(visitor); + } +} + +AstStatWhile::AstStatWhile(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasDo, const Location& doLocation, bool hasEnd) + : AstStat(ClassIndex(), location) + , condition(condition) + , body(body) + , hasDo(hasDo) + , doLocation(doLocation) + , hasEnd(hasEnd) +{ +} + +void AstStatWhile::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + condition->visit(visitor); + body->visit(visitor); + } +} + +AstStatRepeat::AstStatRepeat(const Location& location, AstExpr* condition, AstStatBlock* body, bool hasUntil) + : AstStat(ClassIndex(), location) + , condition(condition) + , body(body) + , hasUntil(hasUntil) +{ +} + +void AstStatRepeat::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + body->visit(visitor); + condition->visit(visitor); + } +} + +AstStatBreak::AstStatBreak(const Location& location) + : AstStat(ClassIndex(), location) +{ +} + +void AstStatBreak::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstStatContinue::AstStatContinue(const Location& location) + : AstStat(ClassIndex(), location) +{ +} + +void AstStatContinue::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstStatReturn::AstStatReturn(const Location& location, const AstArray& list) + : AstStat(ClassIndex(), location) + , list(list) +{ +} + +void AstStatReturn::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstExpr* expr : list) + expr->visit(visitor); + } +} + +AstStatExpr::AstStatExpr(const Location& location, AstExpr* expr) + : AstStat(ClassIndex(), location) + , expr(expr) +{ +} + +void AstStatExpr::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + expr->visit(visitor); +} + +AstStatLocal::AstStatLocal( + const Location& location, const AstArray& vars, const AstArray& values, const std::optional& equalsSignLocation) + : AstStat(ClassIndex(), location) + , vars(vars) + , values(values) +{ + if (bool(equalsSignLocation)) + { + hasEqualsSign = true; + this->equalsSignLocation = *equalsSignLocation; + } +} + +void AstStatLocal::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstLocal* var : vars) + { + if (var->annotation) + var->annotation->visit(visitor); + } + + for (AstExpr* expr : values) + expr->visit(visitor); + } +} + +AstStatFor::AstStatFor(const Location& location, AstLocal* var, AstExpr* from, AstExpr* to, AstExpr* step, AstStatBlock* body, bool hasDo, + const Location& doLocation, bool hasEnd) + : AstStat(ClassIndex(), location) + , var(var) + , from(from) + , to(to) + , step(step) + , body(body) + , hasDo(hasDo) + , doLocation(doLocation) + , hasEnd(hasEnd) +{ +} + +void AstStatFor::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + if (var->annotation) + var->annotation->visit(visitor); + + from->visit(visitor); + to->visit(visitor); + + if (step) + step->visit(visitor); + + body->visit(visitor); + } +} + +AstStatForIn::AstStatForIn(const Location& location, const AstArray& vars, const AstArray& values, AstStatBlock* body, + bool hasIn, const Location& inLocation, bool hasDo, const Location& doLocation, bool hasEnd) + : AstStat(ClassIndex(), location) + , vars(vars) + , values(values) + , body(body) + , hasIn(hasIn) + , inLocation(inLocation) + , hasDo(hasDo) + , doLocation(doLocation) + , hasEnd(hasEnd) +{ +} + +void AstStatForIn::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstLocal* var : vars) + { + if (var->annotation) + var->annotation->visit(visitor); + } + + for (AstExpr* expr : values) + expr->visit(visitor); + + body->visit(visitor); + } +} + +AstStatAssign::AstStatAssign(const Location& location, const AstArray& vars, const AstArray& values) + : AstStat(ClassIndex(), location) + , vars(vars) + , values(values) +{ +} + +void AstStatAssign::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstExpr* lvalue : vars) + lvalue->visit(visitor); + + for (AstExpr* expr : values) + expr->visit(visitor); + } +} + +AstStatCompoundAssign::AstStatCompoundAssign(const Location& location, AstExprBinary::Op op, AstExpr* var, AstExpr* value) + : AstStat(ClassIndex(), location) + , op(op) + , var(var) + , value(value) +{ +} + +void AstStatCompoundAssign::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + var->visit(visitor); + value->visit(visitor); + } +} + +AstStatFunction::AstStatFunction(const Location& location, AstExpr* name, AstExprFunction* func) + : AstStat(ClassIndex(), location) + , name(name) + , func(func) +{ +} + +void AstStatFunction::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + name->visit(visitor); + func->visit(visitor); + } +} + +AstStatLocalFunction::AstStatLocalFunction(const Location& location, AstLocal* name, AstExprFunction* func) + : AstStat(ClassIndex(), location) + , name(name) + , func(func) +{ +} + +void AstStatLocalFunction::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + func->visit(visitor); +} + +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported) + : AstStat(ClassIndex(), location) + , name(name) + , generics(generics) + , type(type) + , exported(exported) +{ +} + +void AstStatTypeAlias::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + type->visit(visitor); +} + +AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) + : AstStat(ClassIndex(), location) + , name(name) + , type(type) +{ +} + +void AstStatDeclareGlobal::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + type->visit(visitor); +} + +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) + : AstStat(ClassIndex(), location) + , name(name) + , generics(generics) + , genericPacks(genericPacks) + , params(params) + , paramNames(paramNames) + , retTypes(retTypes) +{ +} + +void AstStatDeclareFunction::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + visitTypeList(visitor, params); + visitTypeList(visitor, retTypes); + } +} + +AstStatDeclareClass::AstStatDeclareClass( + const Location& location, const AstName& name, std::optional superName, const AstArray& props) + : AstStat(ClassIndex(), location) + , name(name) + , superName(superName) + , props(props) +{ +} + +void AstStatDeclareClass::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (const AstDeclaredClassProp& prop : props) + prop.ty->visit(visitor); + } +} + +AstStatError::AstStatError( + const Location& location, const AstArray& expressions, const AstArray& statements, unsigned messageIndex) + : AstStat(ClassIndex(), location) + , expressions(expressions) + , statements(statements) + , messageIndex(messageIndex) +{ +} + +void AstStatError::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstNode* expression : expressions) + expression->visit(visitor); + + for (AstNode* statement : statements) + statement->visit(visitor); + } +} + +AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics) + : AstType(ClassIndex(), location) + , hasPrefix(bool(prefix)) + , prefix(prefix ? *prefix : AstName()) + , name(name) + , generics(generics) +{ +} + +void AstTypeReference::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* generic : generics) + generic->visit(visitor); + } +} + +AstTypeTable::AstTypeTable(const Location& location, const AstArray& props, AstTableIndexer* indexer) + : AstType(ClassIndex(), location) + , props(props) + , indexer(indexer) +{ +} + +void AstTypeTable::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (const AstTableProp& prop : props) + prop.type->visit(visitor); + + if (indexer) + { + indexer->indexType->visit(visitor); + indexer->resultType->visit(visitor); + } + } +} + +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) + : AstType(ClassIndex(), location) + , generics(generics) + , genericPacks(genericPacks) + , argTypes(argTypes) + , argNames(argNames) + , returnTypes(returnTypes) +{ + LUAU_ASSERT(argNames.size == 0 || argNames.size == argTypes.types.size); +} + +void AstTypeFunction::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + visitTypeList(visitor, argTypes); + visitTypeList(visitor, returnTypes); + } +} + +AstTypeTypeof::AstTypeTypeof(const Location& location, AstExpr* expr) + : AstType(ClassIndex(), location) + , expr(expr) +{ +} + +void AstTypeTypeof::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + expr->visit(visitor); +} + +AstTypeUnion::AstTypeUnion(const Location& location, const AstArray& types) + : AstType(ClassIndex(), location) + , types(types) +{ +} + +void AstTypeUnion::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : types) + type->visit(visitor); + } +} + +AstTypeIntersection::AstTypeIntersection(const Location& location, const AstArray& types) + : AstType(ClassIndex(), location) + , types(types) +{ +} + +void AstTypeIntersection::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : types) + type->visit(visitor); + } +} + +AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) + : AstType(ClassIndex(), location) + , types(types) + , isMissing(isMissing) + , messageIndex(messageIndex) +{ +} + +void AstTypeError::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : types) + type->visit(visitor); + } +} + +AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType) + : AstTypePack(ClassIndex(), location) + , variadicType(variadicType) +{ +} + +void AstTypePackVariadic::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + variadicType->visit(visitor); +} + +AstTypePackGeneric::AstTypePackGeneric(const Location& location, AstName name) + : AstTypePack(ClassIndex(), location) + , genericName(name) +{ +} + +void AstTypePackGeneric::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstName getIdentifier(AstExpr* node) +{ + if (AstExprGlobal* expr = node->as()) + return expr->name; + + if (AstExprLocal* expr = node->as()) + return expr->local->name; + + return AstName(); +} + +} // namespace Luau diff --git a/Ast/src/Confusables.cpp b/Ast/src/Confusables.cpp new file mode 100644 index 0000000..1c79215 --- /dev/null +++ b/Ast/src/Confusables.cpp @@ -0,0 +1,1818 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Confusables.h" + +#include +#include + +namespace Luau +{ + +struct Confusable +{ + unsigned codepoint : 24; + char text[5]; +}; + +// Derived from http://www.unicode.org/Public/security/10.0.0/confusables.txt; sorted by codepoint +// clang-format off +static const Confusable kConfusables[] = +{ + {34, "\""}, // MA#* ( " → '' ) QUOTATION MARK → APOSTROPHE, APOSTROPHE# # Converted to a quote. + {48, "O"}, // MA# ( 0 → O ) DIGIT ZERO → LATIN CAPITAL LETTER O# + {49, "l"}, // MA# ( 1 → l ) DIGIT ONE → LATIN SMALL LETTER L# + {73, "l"}, // MA# ( I → l ) LATIN CAPITAL LETTER I → LATIN SMALL LETTER L# + {96, "'"}, // MA#* ( ` → ' ) GRAVE ACCENT → APOSTROPHE# →ˋ→→`→→‘→ + {109, "rn"}, // MA# ( m → rn ) LATIN SMALL LETTER M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# + {124, "l"}, // MA#* ( | → l ) VERTICAL LINE → LATIN SMALL LETTER L# + {160, " "}, // MA#* (   → ) NO-BREAK SPACE → SPACE# + {180, "'"}, // MA#* ( ´ → ' ) ACUTE ACCENT → APOSTROPHE# →΄→→ʹ→ + {184, ","}, // MA#* ( ¸ → , ) CEDILLA → COMMA# + {198, "AE"}, // MA# ( Æ → AE ) LATIN CAPITAL LETTER AE → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER E# + {215, "x"}, // MA#* ( × → x ) MULTIPLICATION SIGN → LATIN SMALL LETTER X# + {230, "ae"}, // MA# ( æ → ae ) LATIN SMALL LETTER AE → LATIN SMALL LETTER A, LATIN SMALL LETTER E# + {305, "i"}, // MA# ( ı → i ) LATIN SMALL LETTER DOTLESS I → LATIN SMALL LETTER I# + {306, "lJ"}, // MA# ( IJ → lJ ) LATIN CAPITAL LIGATURE IJ → LATIN SMALL LETTER L, LATIN CAPITAL LETTER J# →IJ→ + {307, "ij"}, // MA# ( ij → ij ) LATIN SMALL LIGATURE IJ → LATIN SMALL LETTER I, LATIN SMALL LETTER J# + {329, "'n"}, // MA# ( ʼn → 'n ) LATIN SMALL LETTER N PRECEDED BY APOSTROPHE → APOSTROPHE, LATIN SMALL LETTER N# →ʼn→ + {338, "OE"}, // MA# ( Œ → OE ) LATIN CAPITAL LIGATURE OE → LATIN CAPITAL LETTER O, LATIN CAPITAL LETTER E# + {339, "oe"}, // MA# ( œ → oe ) LATIN SMALL LIGATURE OE → LATIN SMALL LETTER O, LATIN SMALL LETTER E# + {383, "f"}, // MA# ( ſ → f ) LATIN SMALL LETTER LONG S → LATIN SMALL LETTER F# + {385, "'B"}, // MA# ( Ɓ → 'B ) LATIN CAPITAL LETTER B WITH HOOK → APOSTROPHE, LATIN CAPITAL LETTER B# →ʽB→ + {388, "b"}, // MA# ( Ƅ → b ) LATIN CAPITAL LETTER TONE SIX → LATIN SMALL LETTER B# + {391, "C'"}, // MA# ( Ƈ → C' ) LATIN CAPITAL LETTER C WITH HOOK → LATIN CAPITAL LETTER C, APOSTROPHE# →Cʽ→ + {394, "'D"}, // MA# ( Ɗ → 'D ) LATIN CAPITAL LETTER D WITH HOOK → APOSTROPHE, LATIN CAPITAL LETTER D# →ʽD→ + {397, "g"}, // MA# ( ƍ → g ) LATIN SMALL LETTER TURNED DELTA → LATIN SMALL LETTER G# + {403, "G'"}, // MA# ( Ɠ → G' ) LATIN CAPITAL LETTER G WITH HOOK → LATIN CAPITAL LETTER G, APOSTROPHE# →Gʽ→ + {406, "l"}, // MA# ( Ɩ → l ) LATIN CAPITAL LETTER IOTA → LATIN SMALL LETTER L# + {408, "K'"}, // MA# ( Ƙ → K' ) LATIN CAPITAL LETTER K WITH HOOK → LATIN CAPITAL LETTER K, APOSTROPHE# →Kʽ→ + {416, "O'"}, // MA# ( Ơ → O' ) LATIN CAPITAL LETTER O WITH HORN → LATIN CAPITAL LETTER O, APOSTROPHE# →Oʼ→ + {417, "o'"}, // MA# ( ơ → o' ) LATIN SMALL LETTER O WITH HORN → LATIN SMALL LETTER O, APOSTROPHE# →oʼ→ + {420, "'P"}, // MA# ( Ƥ → 'P ) LATIN CAPITAL LETTER P WITH HOOK → APOSTROPHE, LATIN CAPITAL LETTER P# →ʽP→ + {422, "R"}, // MA# ( Ʀ → R ) LATIN LETTER YR → LATIN CAPITAL LETTER R# + {423, "2"}, // MA# ( Ƨ → 2 ) LATIN CAPITAL LETTER TONE TWO → DIGIT TWO# + {428, "'T"}, // MA# ( Ƭ → 'T ) LATIN CAPITAL LETTER T WITH HOOK → APOSTROPHE, LATIN CAPITAL LETTER T# →ʽT→ + {435, "'Y"}, // MA# ( Ƴ → 'Y ) LATIN CAPITAL LETTER Y WITH HOOK → APOSTROPHE, LATIN CAPITAL LETTER Y# →ʽY→ + {439, "3"}, // MA# ( Ʒ → 3 ) LATIN CAPITAL LETTER EZH → DIGIT THREE# + {444, "5"}, // MA# ( Ƽ → 5 ) LATIN CAPITAL LETTER TONE FIVE → DIGIT FIVE# + {445, "s"}, // MA# ( ƽ → s ) LATIN SMALL LETTER TONE FIVE → LATIN SMALL LETTER S# + {448, "l"}, // MA# ( ǀ → l ) LATIN LETTER DENTAL CLICK → LATIN SMALL LETTER L# + {449, "ll"}, // MA# ( ǁ → ll ) LATIN LETTER LATERAL CLICK → LATIN SMALL LETTER L, LATIN SMALL LETTER L# →‖→→∥→→||→ + {451, "!"}, // MA# ( ǃ → ! ) LATIN LETTER RETROFLEX CLICK → EXCLAMATION MARK# + {455, "LJ"}, // MA# ( LJ → LJ ) LATIN CAPITAL LETTER LJ → LATIN CAPITAL LETTER L, LATIN CAPITAL LETTER J# + {456, "Lj"}, // MA# ( Lj → Lj ) LATIN CAPITAL LETTER L WITH SMALL LETTER J → LATIN CAPITAL LETTER L, LATIN SMALL LETTER J# + {457, "lj"}, // MA# ( lj → lj ) LATIN SMALL LETTER LJ → LATIN SMALL LETTER L, LATIN SMALL LETTER J# + {458, "NJ"}, // MA# ( NJ → NJ ) LATIN CAPITAL LETTER NJ → LATIN CAPITAL LETTER N, LATIN CAPITAL LETTER J# + {459, "Nj"}, // MA# ( Nj → Nj ) LATIN CAPITAL LETTER N WITH SMALL LETTER J → LATIN CAPITAL LETTER N, LATIN SMALL LETTER J# + {460, "nj"}, // MA# ( nj → nj ) LATIN SMALL LETTER NJ → LATIN SMALL LETTER N, LATIN SMALL LETTER J# + {497, "DZ"}, // MA# ( DZ → DZ ) LATIN CAPITAL LETTER DZ → LATIN CAPITAL LETTER D, LATIN CAPITAL LETTER Z# + {498, "Dz"}, // MA# ( Dz → Dz ) LATIN CAPITAL LETTER D WITH SMALL LETTER Z → LATIN CAPITAL LETTER D, LATIN SMALL LETTER Z# + {499, "dz"}, // MA# ( dz → dz ) LATIN SMALL LETTER DZ → LATIN SMALL LETTER D, LATIN SMALL LETTER Z# + {540, "3"}, // MA# ( Ȝ → 3 ) LATIN CAPITAL LETTER YOGH → DIGIT THREE# →Ʒ→ + {546, "8"}, // MA# ( Ȣ → 8 ) LATIN CAPITAL LETTER OU → DIGIT EIGHT# + {547, "8"}, // MA# ( ȣ → 8 ) LATIN SMALL LETTER OU → DIGIT EIGHT# + {577, "?"}, // MA# ( Ɂ → ? ) LATIN CAPITAL LETTER GLOTTAL STOP → QUESTION MARK# →ʔ→ + {593, "a"}, // MA# ( ɑ → a ) LATIN SMALL LETTER ALPHA → LATIN SMALL LETTER A# + {609, "g"}, // MA# ( ɡ → g ) LATIN SMALL LETTER SCRIPT G → LATIN SMALL LETTER G# + {611, "y"}, // MA# ( ɣ → y ) LATIN SMALL LETTER GAMMA → LATIN SMALL LETTER Y# →γ→ + {617, "i"}, // MA# ( ɩ → i ) LATIN SMALL LETTER IOTA → LATIN SMALL LETTER I# + {618, "i"}, // MA# ( ɪ → i ) LATIN LETTER SMALL CAPITAL I → LATIN SMALL LETTER I# →ı→ + {623, "w"}, // MA# ( ɯ → w ) LATIN SMALL LETTER TURNED M → LATIN SMALL LETTER W# + {651, "u"}, // MA# ( ʋ → u ) LATIN SMALL LETTER V WITH HOOK → LATIN SMALL LETTER U# + {655, "y"}, // MA# ( ʏ → y ) LATIN LETTER SMALL CAPITAL Y → LATIN SMALL LETTER Y# →ү→→γ→ + {660, "?"}, // MA# ( ʔ → ? ) LATIN LETTER GLOTTAL STOP → QUESTION MARK# + {675, "dz"}, // MA# ( ʣ → dz ) LATIN SMALL LETTER DZ DIGRAPH → LATIN SMALL LETTER D, LATIN SMALL LETTER Z# + {678, "ts"}, // MA# ( ʦ → ts ) LATIN SMALL LETTER TS DIGRAPH → LATIN SMALL LETTER T, LATIN SMALL LETTER S# + {682, "ls"}, // MA# ( ʪ → ls ) LATIN SMALL LETTER LS DIGRAPH → LATIN SMALL LETTER L, LATIN SMALL LETTER S# + {683, "lz"}, // MA# ( ʫ → lz ) LATIN SMALL LETTER LZ DIGRAPH → LATIN SMALL LETTER L, LATIN SMALL LETTER Z# + {697, "'"}, // MA# ( ʹ → ' ) MODIFIER LETTER PRIME → APOSTROPHE# + {698, "\""}, // MA# ( ʺ → '' ) MODIFIER LETTER DOUBLE PRIME → APOSTROPHE, APOSTROPHE# →"→# Converted to a quote. + {699, "'"}, // MA# ( ʻ → ' ) MODIFIER LETTER TURNED COMMA → APOSTROPHE# →‘→ + {700, "'"}, // MA# ( ʼ → ' ) MODIFIER LETTER APOSTROPHE → APOSTROPHE# →′→ + {701, "'"}, // MA# ( ʽ → ' ) MODIFIER LETTER REVERSED COMMA → APOSTROPHE# →‘→ + {702, "'"}, // MA# ( ʾ → ' ) MODIFIER LETTER RIGHT HALF RING → APOSTROPHE# →ʼ→→′→ + {706, "<"}, // MA#* ( ˂ → < ) MODIFIER LETTER LEFT ARROWHEAD → LESS-THAN SIGN# + {707, ">"}, // MA#* ( ˃ → > ) MODIFIER LETTER RIGHT ARROWHEAD → GREATER-THAN SIGN# + {708, "^"}, // MA#* ( ˄ → ^ ) MODIFIER LETTER UP ARROWHEAD → CIRCUMFLEX ACCENT# + {710, "^"}, // MA# ( ˆ → ^ ) MODIFIER LETTER CIRCUMFLEX ACCENT → CIRCUMFLEX ACCENT# + {712, "'"}, // MA# ( ˈ → ' ) MODIFIER LETTER VERTICAL LINE → APOSTROPHE# + {714, "'"}, // MA# ( ˊ → ' ) MODIFIER LETTER ACUTE ACCENT → APOSTROPHE# →ʹ→→′→ + {715, "'"}, // MA# ( ˋ → ' ) MODIFIER LETTER GRAVE ACCENT → APOSTROPHE# →`→→‘→ + {720, ":"}, // MA# ( ː → : ) MODIFIER LETTER TRIANGULAR COLON → COLON# + {727, "-"}, // MA#* ( ˗ → - ) MODIFIER LETTER MINUS SIGN → HYPHEN-MINUS# + {731, "i"}, // MA#* ( ˛ → i ) OGONEK → LATIN SMALL LETTER I# →ͺ→→ι→→ι→ + {732, "~"}, // MA#* ( ˜ → ~ ) SMALL TILDE → TILDE# + {733, "\""}, // MA#* ( ˝ → '' ) DOUBLE ACUTE ACCENT → APOSTROPHE, APOSTROPHE# →"→# Converted to a quote. + {750, "\""}, // MA# ( ˮ → '' ) MODIFIER LETTER DOUBLE APOSTROPHE → APOSTROPHE, APOSTROPHE# →″→→"→# Converted to a quote. + {756, "'"}, // MA#* ( ˴ → ' ) MODIFIER LETTER MIDDLE GRAVE ACCENT → APOSTROPHE# →ˋ→→`→→‘→ + {758, "\""}, // MA#* ( ˶ → '' ) MODIFIER LETTER MIDDLE DOUBLE ACUTE ACCENT → APOSTROPHE, APOSTROPHE# →˝→→"→# Converted to a quote. + {760, ":"}, // MA#* ( ˸ → : ) MODIFIER LETTER RAISED COLON → COLON# + {884, "'"}, // MA# ( ʹ → ' ) GREEK NUMERAL SIGN → APOSTROPHE# →′→ + {890, "i"}, // MA#* ( ͺ → i ) GREEK YPOGEGRAMMENI → LATIN SMALL LETTER I# →ι→→ι→ + {894, ";"}, // MA#* ( ; → ; ) GREEK QUESTION MARK → SEMICOLON# + {895, "J"}, // MA# ( Ϳ → J ) GREEK CAPITAL LETTER YOT → LATIN CAPITAL LETTER J# + {900, "'"}, // MA#* ( ΄ → ' ) GREEK TONOS → APOSTROPHE# →ʹ→ + {913, "A"}, // MA# ( Α → A ) GREEK CAPITAL LETTER ALPHA → LATIN CAPITAL LETTER A# + {914, "B"}, // MA# ( Β → B ) GREEK CAPITAL LETTER BETA → LATIN CAPITAL LETTER B# + {917, "E"}, // MA# ( Ε → E ) GREEK CAPITAL LETTER EPSILON → LATIN CAPITAL LETTER E# + {918, "Z"}, // MA# ( Ζ → Z ) GREEK CAPITAL LETTER ZETA → LATIN CAPITAL LETTER Z# + {919, "H"}, // MA# ( Η → H ) GREEK CAPITAL LETTER ETA → LATIN CAPITAL LETTER H# + {921, "l"}, // MA# ( Ι → l ) GREEK CAPITAL LETTER IOTA → LATIN SMALL LETTER L# + {922, "K"}, // MA# ( Κ → K ) GREEK CAPITAL LETTER KAPPA → LATIN CAPITAL LETTER K# + {924, "M"}, // MA# ( Μ → M ) GREEK CAPITAL LETTER MU → LATIN CAPITAL LETTER M# + {925, "N"}, // MA# ( Ν → N ) GREEK CAPITAL LETTER NU → LATIN CAPITAL LETTER N# + {927, "O"}, // MA# ( Ο → O ) GREEK CAPITAL LETTER OMICRON → LATIN CAPITAL LETTER O# + {929, "P"}, // MA# ( Ρ → P ) GREEK CAPITAL LETTER RHO → LATIN CAPITAL LETTER P# + {932, "T"}, // MA# ( Τ → T ) GREEK CAPITAL LETTER TAU → LATIN CAPITAL LETTER T# + {933, "Y"}, // MA# ( Υ → Y ) GREEK CAPITAL LETTER UPSILON → LATIN CAPITAL LETTER Y# + {935, "X"}, // MA# ( Χ → X ) GREEK CAPITAL LETTER CHI → LATIN CAPITAL LETTER X# + {945, "a"}, // MA# ( α → a ) GREEK SMALL LETTER ALPHA → LATIN SMALL LETTER A# + {947, "y"}, // MA# ( γ → y ) GREEK SMALL LETTER GAMMA → LATIN SMALL LETTER Y# + {953, "i"}, // MA# ( ι → i ) GREEK SMALL LETTER IOTA → LATIN SMALL LETTER I# + {957, "v"}, // MA# ( ν → v ) GREEK SMALL LETTER NU → LATIN SMALL LETTER V# + {959, "o"}, // MA# ( ο → o ) GREEK SMALL LETTER OMICRON → LATIN SMALL LETTER O# + {961, "p"}, // MA# ( ρ → p ) GREEK SMALL LETTER RHO → LATIN SMALL LETTER P# + {963, "o"}, // MA# ( σ → o ) GREEK SMALL LETTER SIGMA → LATIN SMALL LETTER O# + {965, "u"}, // MA# ( υ → u ) GREEK SMALL LETTER UPSILON → LATIN SMALL LETTER U# →ʋ→ + {978, "Y"}, // MA# ( ϒ → Y ) GREEK UPSILON WITH HOOK SYMBOL → LATIN CAPITAL LETTER Y# + {988, "F"}, // MA# ( Ϝ → F ) GREEK LETTER DIGAMMA → LATIN CAPITAL LETTER F# + {1000, "2"}, // MA# ( Ϩ → 2 ) COPTIC CAPITAL LETTER HORI → DIGIT TWO# →Ƨ→ + {1009, "p"}, // MA# ( ϱ → p ) GREEK RHO SYMBOL → LATIN SMALL LETTER P# →ρ→ + {1010, "c"}, // MA# ( ϲ → c ) GREEK LUNATE SIGMA SYMBOL → LATIN SMALL LETTER C# + {1011, "j"}, // MA# ( ϳ → j ) GREEK LETTER YOT → LATIN SMALL LETTER J# + {1017, "C"}, // MA# ( Ϲ → C ) GREEK CAPITAL LUNATE SIGMA SYMBOL → LATIN CAPITAL LETTER C# + {1018, "M"}, // MA# ( Ϻ → M ) GREEK CAPITAL LETTER SAN → LATIN CAPITAL LETTER M# + {1029, "S"}, // MA# ( Ѕ → S ) CYRILLIC CAPITAL LETTER DZE → LATIN CAPITAL LETTER S# + {1030, "l"}, // MA# ( І → l ) CYRILLIC CAPITAL LETTER BYELORUSSIAN-UKRAINIAN I → LATIN SMALL LETTER L# + {1032, "J"}, // MA# ( Ј → J ) CYRILLIC CAPITAL LETTER JE → LATIN CAPITAL LETTER J# + {1040, "A"}, // MA# ( А → A ) CYRILLIC CAPITAL LETTER A → LATIN CAPITAL LETTER A# + {1042, "B"}, // MA# ( В → B ) CYRILLIC CAPITAL LETTER VE → LATIN CAPITAL LETTER B# + {1045, "E"}, // MA# ( Е → E ) CYRILLIC CAPITAL LETTER IE → LATIN CAPITAL LETTER E# + {1047, "3"}, // MA# ( З → 3 ) CYRILLIC CAPITAL LETTER ZE → DIGIT THREE# + {1050, "K"}, // MA# ( К → K ) CYRILLIC CAPITAL LETTER KA → LATIN CAPITAL LETTER K# + {1052, "M"}, // MA# ( М → M ) CYRILLIC CAPITAL LETTER EM → LATIN CAPITAL LETTER M# + {1053, "H"}, // MA# ( Н → H ) CYRILLIC CAPITAL LETTER EN → LATIN CAPITAL LETTER H# + {1054, "O"}, // MA# ( О → O ) CYRILLIC CAPITAL LETTER O → LATIN CAPITAL LETTER O# + {1056, "P"}, // MA# ( Р → P ) CYRILLIC CAPITAL LETTER ER → LATIN CAPITAL LETTER P# + {1057, "C"}, // MA# ( С → C ) CYRILLIC CAPITAL LETTER ES → LATIN CAPITAL LETTER C# + {1058, "T"}, // MA# ( Т → T ) CYRILLIC CAPITAL LETTER TE → LATIN CAPITAL LETTER T# + {1059, "Y"}, // MA# ( У → Y ) CYRILLIC CAPITAL LETTER U → LATIN CAPITAL LETTER Y# + {1061, "X"}, // MA# ( Х → X ) CYRILLIC CAPITAL LETTER HA → LATIN CAPITAL LETTER X# + {1067, "bl"}, // MA# ( Ы → bl ) CYRILLIC CAPITAL LETTER YERU → LATIN SMALL LETTER B, LATIN SMALL LETTER L# →ЬІ→→Ь1→ + {1068, "b"}, // MA# ( Ь → b ) CYRILLIC CAPITAL LETTER SOFT SIGN → LATIN SMALL LETTER B# →Ƅ→ + {1070, "lO"}, // MA# ( Ю → lO ) CYRILLIC CAPITAL LETTER YU → LATIN SMALL LETTER L, LATIN CAPITAL LETTER O# →IO→ + {1072, "a"}, // MA# ( а → a ) CYRILLIC SMALL LETTER A → LATIN SMALL LETTER A# + {1073, "6"}, // MA# ( б → 6 ) CYRILLIC SMALL LETTER BE → DIGIT SIX# + {1075, "r"}, // MA# ( г → r ) CYRILLIC SMALL LETTER GHE → LATIN SMALL LETTER R# + {1077, "e"}, // MA# ( е → e ) CYRILLIC SMALL LETTER IE → LATIN SMALL LETTER E# + {1086, "o"}, // MA# ( о → o ) CYRILLIC SMALL LETTER O → LATIN SMALL LETTER O# + {1088, "p"}, // MA# ( р → p ) CYRILLIC SMALL LETTER ER → LATIN SMALL LETTER P# + {1089, "c"}, // MA# ( с → c ) CYRILLIC SMALL LETTER ES → LATIN SMALL LETTER C# + {1091, "y"}, // MA# ( у → y ) CYRILLIC SMALL LETTER U → LATIN SMALL LETTER Y# + {1093, "x"}, // MA# ( х → x ) CYRILLIC SMALL LETTER HA → LATIN SMALL LETTER X# + {1109, "s"}, // MA# ( ѕ → s ) CYRILLIC SMALL LETTER DZE → LATIN SMALL LETTER S# + {1110, "i"}, // MA# ( і → i ) CYRILLIC SMALL LETTER BYELORUSSIAN-UKRAINIAN I → LATIN SMALL LETTER I# + {1112, "j"}, // MA# ( ј → j ) CYRILLIC SMALL LETTER JE → LATIN SMALL LETTER J# + {1121, "w"}, // MA# ( ѡ → w ) CYRILLIC SMALL LETTER OMEGA → LATIN SMALL LETTER W# + {1140, "V"}, // MA# ( Ѵ → V ) CYRILLIC CAPITAL LETTER IZHITSA → LATIN CAPITAL LETTER V# + {1141, "v"}, // MA# ( ѵ → v ) CYRILLIC SMALL LETTER IZHITSA → LATIN SMALL LETTER V# + {1169, "r'"}, // MA# ( ґ → r' ) CYRILLIC SMALL LETTER GHE WITH UPTURN → LATIN SMALL LETTER R, APOSTROPHE# →гˈ→ + {1198, "Y"}, // MA# ( Ү → Y ) CYRILLIC CAPITAL LETTER STRAIGHT U → LATIN CAPITAL LETTER Y# + {1199, "y"}, // MA# ( ү → y ) CYRILLIC SMALL LETTER STRAIGHT U → LATIN SMALL LETTER Y# →γ→ + {1211, "h"}, // MA# ( һ → h ) CYRILLIC SMALL LETTER SHHA → LATIN SMALL LETTER H# + {1213, "e"}, // MA# ( ҽ → e ) CYRILLIC SMALL LETTER ABKHASIAN CHE → LATIN SMALL LETTER E# + {1216, "l"}, // MA# ( Ӏ → l ) CYRILLIC LETTER PALOCHKA → LATIN SMALL LETTER L# + {1231, "i"}, // MA# ( ӏ → i ) CYRILLIC SMALL LETTER PALOCHKA → LATIN SMALL LETTER I# →ı→ + {1236, "AE"}, // MA# ( Ӕ → AE ) CYRILLIC CAPITAL LIGATURE A IE → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER E# →Æ→ + {1237, "ae"}, // MA# ( ӕ → ae ) CYRILLIC SMALL LIGATURE A IE → LATIN SMALL LETTER A, LATIN SMALL LETTER E# →ае→ + {1248, "3"}, // MA# ( Ӡ → 3 ) CYRILLIC CAPITAL LETTER ABKHASIAN DZE → DIGIT THREE# →Ʒ→ + {1281, "d"}, // MA# ( ԁ → d ) CYRILLIC SMALL LETTER KOMI DE → LATIN SMALL LETTER D# + {1292, "G"}, // MA# ( Ԍ → G ) CYRILLIC CAPITAL LETTER KOMI SJE → LATIN CAPITAL LETTER G# + {1307, "q"}, // MA# ( ԛ → q ) CYRILLIC SMALL LETTER QA → LATIN SMALL LETTER Q# + {1308, "W"}, // MA# ( Ԝ → W ) CYRILLIC CAPITAL LETTER WE → LATIN CAPITAL LETTER W# + {1309, "w"}, // MA# ( ԝ → w ) CYRILLIC SMALL LETTER WE → LATIN SMALL LETTER W# + {1357, "U"}, // MA# ( Ս → U ) ARMENIAN CAPITAL LETTER SEH → LATIN CAPITAL LETTER U# + {1359, "S"}, // MA# ( Տ → S ) ARMENIAN CAPITAL LETTER TIWN → LATIN CAPITAL LETTER S# + {1365, "O"}, // MA# ( Օ → O ) ARMENIAN CAPITAL LETTER OH → LATIN CAPITAL LETTER O# + {1370, "'"}, // MA#* ( ՚ → ' ) ARMENIAN APOSTROPHE → APOSTROPHE# →’→ + {1373, "'"}, // MA#* ( ՝ → ' ) ARMENIAN COMMA → APOSTROPHE# →ˋ→→`→→‘→ + {1377, "w"}, // MA# ( ա → w ) ARMENIAN SMALL LETTER AYB → LATIN SMALL LETTER W# →ɯ→ + {1379, "q"}, // MA# ( գ → q ) ARMENIAN SMALL LETTER GIM → LATIN SMALL LETTER Q# + {1382, "q"}, // MA# ( զ → q ) ARMENIAN SMALL LETTER ZA → LATIN SMALL LETTER Q# + {1392, "h"}, // MA# ( հ → h ) ARMENIAN SMALL LETTER HO → LATIN SMALL LETTER H# + {1400, "n"}, // MA# ( ո → n ) ARMENIAN SMALL LETTER VO → LATIN SMALL LETTER N# + {1404, "n"}, // MA# ( ռ → n ) ARMENIAN SMALL LETTER RA → LATIN SMALL LETTER N# + {1405, "u"}, // MA# ( ս → u ) ARMENIAN SMALL LETTER SEH → LATIN SMALL LETTER U# + {1409, "g"}, // MA# ( ց → g ) ARMENIAN SMALL LETTER CO → LATIN SMALL LETTER G# + {1412, "f"}, // MA# ( ք → f ) ARMENIAN SMALL LETTER KEH → LATIN SMALL LETTER F# + {1413, "o"}, // MA# ( օ → o ) ARMENIAN SMALL LETTER OH → LATIN SMALL LETTER O# + {1417, ":"}, // MA#* ( ։ → : ) ARMENIAN FULL STOP → COLON# + {1472, "l"}, // MA#* ( ‎׀‎ → l ) HEBREW PUNCTUATION PASEQ → LATIN SMALL LETTER L# →|→ + {1475, ":"}, // MA#* ( ‎׃‎ → : ) HEBREW PUNCTUATION SOF PASUQ → COLON# + {1493, "l"}, // MA# ( ‎ו‎ → l ) HEBREW LETTER VAV → LATIN SMALL LETTER L# + {1496, "v"}, // MA# ( ‎ט‎ → v ) HEBREW LETTER TET → LATIN SMALL LETTER V# + {1497, "'"}, // MA# ( ‎י‎ → ' ) HEBREW LETTER YOD → APOSTROPHE# + {1503, "l"}, // MA# ( ‎ן‎ → l ) HEBREW LETTER FINAL NUN → LATIN SMALL LETTER L# + {1505, "o"}, // MA# ( ‎ס‎ → o ) HEBREW LETTER SAMEKH → LATIN SMALL LETTER O# + {1520, "ll"}, // MA# ( ‎װ‎ → ll ) HEBREW LIGATURE YIDDISH DOUBLE VAV → LATIN SMALL LETTER L, LATIN SMALL LETTER L# →‎וו‎→ + {1521, "l'"}, // MA# ( ‎ױ‎ → l' ) HEBREW LIGATURE YIDDISH VAV YOD → LATIN SMALL LETTER L, APOSTROPHE# →‎וי‎→ + {1522, "\""}, // MA# ( ‎ײ‎ → '' ) HEBREW LIGATURE YIDDISH DOUBLE YOD → APOSTROPHE, APOSTROPHE# →‎יי‎→# Converted to a quote. + {1523, "'"}, // MA#* ( ‎׳‎ → ' ) HEBREW PUNCTUATION GERESH → APOSTROPHE# + {1524, "\""}, // MA#* ( ‎״‎ → '' ) HEBREW PUNCTUATION GERSHAYIM → APOSTROPHE, APOSTROPHE# →"→# Converted to a quote. + {1549, ","}, // MA#* ( ‎؍‎ → , ) ARABIC DATE SEPARATOR → COMMA# →‎٫‎→ + {1575, "l"}, // MA# ( ‎ا‎ → l ) ARABIC LETTER ALEF → LATIN SMALL LETTER L# →1→ + {1607, "o"}, // MA# ( ‎ه‎ → o ) ARABIC LETTER HEH → LATIN SMALL LETTER O# + {1632, "."}, // MA# ( ‎٠‎ → . ) ARABIC-INDIC DIGIT ZERO → FULL STOP# + {1633, "l"}, // MA# ( ‎١‎ → l ) ARABIC-INDIC DIGIT ONE → LATIN SMALL LETTER L# →1→ + {1637, "o"}, // MA# ( ‎٥‎ → o ) ARABIC-INDIC DIGIT FIVE → LATIN SMALL LETTER O# + {1639, "V"}, // MA# ( ‎٧‎ → V ) ARABIC-INDIC DIGIT SEVEN → LATIN CAPITAL LETTER V# + {1643, ","}, // MA#* ( ‎٫‎ → , ) ARABIC DECIMAL SEPARATOR → COMMA# + {1645, "*"}, // MA#* ( ‎٭‎ → * ) ARABIC FIVE POINTED STAR → ASTERISK# + {1726, "o"}, // MA# ( ‎ھ‎ → o ) ARABIC LETTER HEH DOACHASHMEE → LATIN SMALL LETTER O# →‎ه‎→ + {1729, "o"}, // MA# ( ‎ہ‎ → o ) ARABIC LETTER HEH GOAL → LATIN SMALL LETTER O# →‎ه‎→ + {1748, "-"}, // MA#* ( ‎۔‎ → - ) ARABIC FULL STOP → HYPHEN-MINUS# →‐→ + {1749, "o"}, // MA# ( ‎ە‎ → o ) ARABIC LETTER AE → LATIN SMALL LETTER O# →‎ه‎→ + {1776, "."}, // MA# ( ۰ → . ) EXTENDED ARABIC-INDIC DIGIT ZERO → FULL STOP# →‎٠‎→ + {1777, "l"}, // MA# ( ۱ → l ) EXTENDED ARABIC-INDIC DIGIT ONE → LATIN SMALL LETTER L# →1→ + {1781, "o"}, // MA# ( ۵ → o ) EXTENDED ARABIC-INDIC DIGIT FIVE → LATIN SMALL LETTER O# →‎٥‎→ + {1783, "V"}, // MA# ( ۷ → V ) EXTENDED ARABIC-INDIC DIGIT SEVEN → LATIN CAPITAL LETTER V# →‎٧‎→ + {1793, "."}, // MA#* ( ‎܁‎ → . ) SYRIAC SUPRALINEAR FULL STOP → FULL STOP# + {1794, "."}, // MA#* ( ‎܂‎ → . ) SYRIAC SUBLINEAR FULL STOP → FULL STOP# + {1795, ":"}, // MA#* ( ‎܃‎ → : ) SYRIAC SUPRALINEAR COLON → COLON# + {1796, ":"}, // MA#* ( ‎܄‎ → : ) SYRIAC SUBLINEAR COLON → COLON# + {1984, "O"}, // MA# ( ‎߀‎ → O ) NKO DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {1994, "l"}, // MA# ( ‎ߊ‎ → l ) NKO LETTER A → LATIN SMALL LETTER L# →∣→→ǀ→ + {2036, "'"}, // MA# ( ‎ߴ‎ → ' ) NKO HIGH TONE APOSTROPHE → APOSTROPHE# →’→ + {2037, "'"}, // MA# ( ‎ߵ‎ → ' ) NKO LOW TONE APOSTROPHE → APOSTROPHE# →‘→ + {2042, "_"}, // MA# ( ‎ߺ‎ → _ ) NKO LAJANYALAN → LOW LINE# + {2307, ":"}, // MA# ( ः → : ) DEVANAGARI SIGN VISARGA → COLON# + {2406, "o"}, // MA# ( ० → o ) DEVANAGARI DIGIT ZERO → LATIN SMALL LETTER O# + {2429, "?"}, // MA# ( ॽ → ? ) DEVANAGARI LETTER GLOTTAL STOP → QUESTION MARK# + {2534, "O"}, // MA# ( ০ → O ) BENGALI DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {2538, "8"}, // MA# ( ৪ → 8 ) BENGALI DIGIT FOUR → DIGIT EIGHT# + {2541, "9"}, // MA# ( ৭ → 9 ) BENGALI DIGIT SEVEN → DIGIT NINE# + {2662, "o"}, // MA# ( ੦ → o ) GURMUKHI DIGIT ZERO → LATIN SMALL LETTER O# + {2663, "9"}, // MA# ( ੧ → 9 ) GURMUKHI DIGIT ONE → DIGIT NINE# + {2666, "8"}, // MA# ( ੪ → 8 ) GURMUKHI DIGIT FOUR → DIGIT EIGHT# + {2691, ":"}, // MA# ( ઃ → : ) GUJARATI SIGN VISARGA → COLON# + {2790, "o"}, // MA# ( ૦ → o ) GUJARATI DIGIT ZERO → LATIN SMALL LETTER O# + {2819, "8"}, // MA# ( ଃ → 8 ) ORIYA SIGN VISARGA → DIGIT EIGHT# + {2848, "O"}, // MA# ( ଠ → O ) ORIYA LETTER TTHA → LATIN CAPITAL LETTER O# →୦→→0→ + {2918, "O"}, // MA# ( ୦ → O ) ORIYA DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {2920, "9"}, // MA# ( ୨ → 9 ) ORIYA DIGIT TWO → DIGIT NINE# + {3046, "o"}, // MA# ( ௦ → o ) TAMIL DIGIT ZERO → LATIN SMALL LETTER O# + {3074, "o"}, // MA# ( ం → o ) TELUGU SIGN ANUSVARA → LATIN SMALL LETTER O# + {3174, "o"}, // MA# ( ౦ → o ) TELUGU DIGIT ZERO → LATIN SMALL LETTER O# + {3202, "o"}, // MA# ( ಂ → o ) KANNADA SIGN ANUSVARA → LATIN SMALL LETTER O# + {3302, "o"}, // MA# ( ೦ → o ) KANNADA DIGIT ZERO → LATIN SMALL LETTER O# →౦→ + {3330, "o"}, // MA# ( ം → o ) MALAYALAM SIGN ANUSVARA → LATIN SMALL LETTER O# + {3360, "o"}, // MA# ( ഠ → o ) MALAYALAM LETTER TTHA → LATIN SMALL LETTER O# + {3430, "o"}, // MA# ( ൦ → o ) MALAYALAM DIGIT ZERO → LATIN SMALL LETTER O# + {3437, "9"}, // MA# ( ൭ → 9 ) MALAYALAM DIGIT SEVEN → DIGIT NINE# + {3458, "o"}, // MA# ( ං → o ) SINHALA SIGN ANUSVARAYA → LATIN SMALL LETTER O# + {3664, "o"}, // MA# ( ๐ → o ) THAI DIGIT ZERO → LATIN SMALL LETTER O# + {3792, "o"}, // MA# ( ໐ → o ) LAO DIGIT ZERO → LATIN SMALL LETTER O# + {4125, "o"}, // MA# ( ဝ → o ) MYANMAR LETTER WA → LATIN SMALL LETTER O# + {4160, "o"}, // MA# ( ၀ → o ) MYANMAR DIGIT ZERO → LATIN SMALL LETTER O# + {4327, "y"}, // MA# ( ყ → y ) GEORGIAN LETTER QAR → LATIN SMALL LETTER Y# + {4351, "o"}, // MA# ( ჿ → o ) GEORGIAN LETTER LABIAL SIGN → LATIN SMALL LETTER O# + {4608, "U"}, // MA# ( ሀ → U ) ETHIOPIC SYLLABLE HA → LATIN CAPITAL LETTER U# →Ս→ + {4816, "O"}, // MA# ( ዐ → O ) ETHIOPIC SYLLABLE PHARYNGEAL A → LATIN CAPITAL LETTER O# →Օ→ + {5024, "D"}, // MA# ( Ꭰ → D ) CHEROKEE LETTER A → LATIN CAPITAL LETTER D# + {5025, "R"}, // MA# ( Ꭱ → R ) CHEROKEE LETTER E → LATIN CAPITAL LETTER R# + {5026, "T"}, // MA# ( Ꭲ → T ) CHEROKEE LETTER I → LATIN CAPITAL LETTER T# + {5028, "O'"}, // MA# ( Ꭴ → O' ) CHEROKEE LETTER U → LATIN CAPITAL LETTER O, APOSTROPHE# →Ơ→→Oʼ→ + {5029, "i"}, // MA# ( Ꭵ → i ) CHEROKEE LETTER V → LATIN SMALL LETTER I# + {5033, "Y"}, // MA# ( Ꭹ → Y ) CHEROKEE LETTER GI → LATIN CAPITAL LETTER Y# + {5034, "A"}, // MA# ( Ꭺ → A ) CHEROKEE LETTER GO → LATIN CAPITAL LETTER A# + {5035, "J"}, // MA# ( Ꭻ → J ) CHEROKEE LETTER GU → LATIN CAPITAL LETTER J# + {5036, "E"}, // MA# ( Ꭼ → E ) CHEROKEE LETTER GV → LATIN CAPITAL LETTER E# + {5038, "?"}, // MA# ( Ꭾ → ? ) CHEROKEE LETTER HE → QUESTION MARK# →Ɂ→→ʔ→ + {5043, "W"}, // MA# ( Ꮃ → W ) CHEROKEE LETTER LA → LATIN CAPITAL LETTER W# + {5047, "M"}, // MA# ( Ꮇ → M ) CHEROKEE LETTER LU → LATIN CAPITAL LETTER M# + {5051, "H"}, // MA# ( Ꮋ → H ) CHEROKEE LETTER MI → LATIN CAPITAL LETTER H# + {5053, "Y"}, // MA# ( Ꮍ → Y ) CHEROKEE LETTER MU → LATIN CAPITAL LETTER Y# →Ꭹ→ + {5056, "G"}, // MA# ( Ꮐ → G ) CHEROKEE LETTER NAH → LATIN CAPITAL LETTER G# + {5058, "h"}, // MA# ( Ꮒ → h ) CHEROKEE LETTER NI → LATIN SMALL LETTER H# + {5059, "Z"}, // MA# ( Ꮓ → Z ) CHEROKEE LETTER NO → LATIN CAPITAL LETTER Z# + {5070, "4"}, // MA# ( Ꮞ → 4 ) CHEROKEE LETTER SE → DIGIT FOUR# + {5071, "b"}, // MA# ( Ꮟ → b ) CHEROKEE LETTER SI → LATIN SMALL LETTER B# + {5074, "R"}, // MA# ( Ꮢ → R ) CHEROKEE LETTER SV → LATIN CAPITAL LETTER R# + {5076, "W"}, // MA# ( Ꮤ → W ) CHEROKEE LETTER TA → LATIN CAPITAL LETTER W# + {5077, "S"}, // MA# ( Ꮥ → S ) CHEROKEE LETTER DE → LATIN CAPITAL LETTER S# + {5081, "V"}, // MA# ( Ꮩ → V ) CHEROKEE LETTER DO → LATIN CAPITAL LETTER V# + {5082, "S"}, // MA# ( Ꮪ → S ) CHEROKEE LETTER DU → LATIN CAPITAL LETTER S# + {5086, "L"}, // MA# ( Ꮮ → L ) CHEROKEE LETTER TLE → LATIN CAPITAL LETTER L# + {5087, "C"}, // MA# ( Ꮯ → C ) CHEROKEE LETTER TLI → LATIN CAPITAL LETTER C# + {5090, "P"}, // MA# ( Ꮲ → P ) CHEROKEE LETTER TLV → LATIN CAPITAL LETTER P# + {5094, "K"}, // MA# ( Ꮶ → K ) CHEROKEE LETTER TSO → LATIN CAPITAL LETTER K# + {5095, "d"}, // MA# ( Ꮷ → d ) CHEROKEE LETTER TSU → LATIN SMALL LETTER D# + {5102, "6"}, // MA# ( Ꮾ → 6 ) CHEROKEE LETTER WV → DIGIT SIX# + {5107, "G"}, // MA# ( Ᏻ → G ) CHEROKEE LETTER YU → LATIN CAPITAL LETTER G# + {5108, "B"}, // MA# ( Ᏼ → B ) CHEROKEE LETTER YV → LATIN CAPITAL LETTER B# + {5120, "="}, // MA#* ( ᐀ → = ) CANADIAN SYLLABICS HYPHEN → EQUALS SIGN# + {5167, "V"}, // MA# ( ᐯ → V ) CANADIAN SYLLABICS PE → LATIN CAPITAL LETTER V# + {5171, ">"}, // MA# ( ᐳ → > ) CANADIAN SYLLABICS PO → GREATER-THAN SIGN# + {5176, "<"}, // MA# ( ᐸ → < ) CANADIAN SYLLABICS PA → LESS-THAN SIGN# + {5194, "'"}, // MA# ( ᑊ → ' ) CANADIAN SYLLABICS WEST-CREE P → APOSTROPHE# →ˈ→ + {5196, "U"}, // MA# ( ᑌ → U ) CANADIAN SYLLABICS TE → LATIN CAPITAL LETTER U# + {5223, "U'"}, // MA# ( ᑧ → U' ) CANADIAN SYLLABICS TTE → LATIN CAPITAL LETTER U, APOSTROPHE# →ᑌᑊ→→ᑌ'→ + {5229, "P"}, // MA# ( ᑭ → P ) CANADIAN SYLLABICS KI → LATIN CAPITAL LETTER P# + {5231, "d"}, // MA# ( ᑯ → d ) CANADIAN SYLLABICS KO → LATIN SMALL LETTER D# + {5254, "P'"}, // MA# ( ᒆ → P' ) CANADIAN SYLLABICS SOUTH-SLAVEY KIH → LATIN CAPITAL LETTER P, APOSTROPHE# →ᑭᑊ→ + {5255, "d'"}, // MA# ( ᒇ → d' ) CANADIAN SYLLABICS SOUTH-SLAVEY KOH → LATIN SMALL LETTER D, APOSTROPHE# →ᑯᑊ→ + {5261, "J"}, // MA# ( ᒍ → J ) CANADIAN SYLLABICS CO → LATIN CAPITAL LETTER J# + {5290, "L"}, // MA# ( ᒪ → L ) CANADIAN SYLLABICS MA → LATIN CAPITAL LETTER L# + {5311, "2"}, // MA# ( ᒿ → 2 ) CANADIAN SYLLABICS SAYISI M → DIGIT TWO# + {5441, "x"}, // MA# ( ᕁ → x ) CANADIAN SYLLABICS SAYISI YI → LATIN SMALL LETTER X# →᙮→ + {5500, "H"}, // MA# ( ᕼ → H ) CANADIAN SYLLABICS NUNAVUT H → LATIN CAPITAL LETTER H# + {5501, "x"}, // MA# ( ᕽ → x ) CANADIAN SYLLABICS HK → LATIN SMALL LETTER X# →ᕁ→→᙮→ + {5511, "R"}, // MA# ( ᖇ → R ) CANADIAN SYLLABICS TLHI → LATIN CAPITAL LETTER R# + {5551, "b"}, // MA# ( ᖯ → b ) CANADIAN SYLLABICS AIVILIK B → LATIN SMALL LETTER B# + {5556, "F"}, // MA# ( ᖴ → F ) CANADIAN SYLLABICS BLACKFOOT WE → LATIN CAPITAL LETTER F# + {5573, "A"}, // MA# ( ᗅ → A ) CANADIAN SYLLABICS CARRIER GHO → LATIN CAPITAL LETTER A# + {5598, "D"}, // MA# ( ᗞ → D ) CANADIAN SYLLABICS CARRIER THE → LATIN CAPITAL LETTER D# + {5610, "D"}, // MA# ( ᗪ → D ) CANADIAN SYLLABICS CARRIER PE → LATIN CAPITAL LETTER D# →ᗞ→ + {5616, "M"}, // MA# ( ᗰ → M ) CANADIAN SYLLABICS CARRIER GO → LATIN CAPITAL LETTER M# + {5623, "B"}, // MA# ( ᗷ → B ) CANADIAN SYLLABICS CARRIER KHE → LATIN CAPITAL LETTER B# + {5741, "X"}, // MA#* ( ᙭ → X ) CANADIAN SYLLABICS CHI SIGN → LATIN CAPITAL LETTER X# + {5742, "x"}, // MA#* ( ᙮ → x ) CANADIAN SYLLABICS FULL STOP → LATIN SMALL LETTER X# + {5760, " "}, // MA#* (   → ) OGHAM SPACE MARK → SPACE# + {5810, "<"}, // MA# ( ᚲ → < ) RUNIC LETTER KAUNA → LESS-THAN SIGN# + {5815, "X"}, // MA# ( ᚷ → X ) RUNIC LETTER GEBO GYFU G → LATIN CAPITAL LETTER X# + {5825, "l"}, // MA# ( ᛁ → l ) RUNIC LETTER ISAZ IS ISS I → LATIN SMALL LETTER L# →I→ + {5836, "'"}, // MA# ( ᛌ → ' ) RUNIC LETTER SHORT-TWIG-SOL S → APOSTROPHE# + {5845, "K"}, // MA# ( ᛕ → K ) RUNIC LETTER OPEN-P → LATIN CAPITAL LETTER K# + {5846, "M"}, // MA# ( ᛖ → M ) RUNIC LETTER EHWAZ EH E → LATIN CAPITAL LETTER M# + {5868, ":"}, // MA#* ( ᛬ → : ) RUNIC MULTIPLE PUNCTUATION → COLON# + {5869, "+"}, // MA#* ( ᛭ → + ) RUNIC CROSS PUNCTUATION → PLUS SIGN# + {5941, "/"}, // MA#* ( ᜵ → / ) PHILIPPINE SINGLE PUNCTUATION → SOLIDUS# + {6147, ":"}, // MA#* ( ᠃ → : ) MONGOLIAN FULL STOP → COLON# + {6153, ":"}, // MA#* ( ᠉ → : ) MONGOLIAN MANCHU FULL STOP → COLON# + {7379, "\""}, // MA#* ( ᳓ → '' ) VEDIC SIGN NIHSHVASA → APOSTROPHE, APOSTROPHE# →″→→"→# Converted to a quote. + {7428, "c"}, // MA# ( ᴄ → c ) LATIN LETTER SMALL CAPITAL C → LATIN SMALL LETTER C# + {7439, "o"}, // MA# ( ᴏ → o ) LATIN LETTER SMALL CAPITAL O → LATIN SMALL LETTER O# + {7441, "o"}, // MA# ( ᴑ → o ) LATIN SMALL LETTER SIDEWAYS O → LATIN SMALL LETTER O# + {7452, "u"}, // MA# ( ᴜ → u ) LATIN LETTER SMALL CAPITAL U → LATIN SMALL LETTER U# + {7456, "v"}, // MA# ( ᴠ → v ) LATIN LETTER SMALL CAPITAL V → LATIN SMALL LETTER V# + {7457, "w"}, // MA# ( ᴡ → w ) LATIN LETTER SMALL CAPITAL W → LATIN SMALL LETTER W# + {7458, "z"}, // MA# ( ᴢ → z ) LATIN LETTER SMALL CAPITAL Z → LATIN SMALL LETTER Z# + {7462, "r"}, // MA# ( ᴦ → r ) GREEK LETTER SMALL CAPITAL GAMMA → LATIN SMALL LETTER R# →г→ + {7531, "ue"}, // MA# ( ᵫ → ue ) LATIN SMALL LETTER UE → LATIN SMALL LETTER U, LATIN SMALL LETTER E# + {7555, "g"}, // MA# ( ᶃ → g ) LATIN SMALL LETTER G WITH PALATAL HOOK → LATIN SMALL LETTER G# + {7564, "y"}, // MA# ( ᶌ → y ) LATIN SMALL LETTER V WITH PALATAL HOOK → LATIN SMALL LETTER Y# + {7837, "f"}, // MA# ( ẝ → f ) LATIN SMALL LETTER LONG S WITH HIGH STROKE → LATIN SMALL LETTER F# + {7935, "y"}, // MA# ( ỿ → y ) LATIN SMALL LETTER Y WITH LOOP → LATIN SMALL LETTER Y# + {8125, "'"}, // MA#* ( ᾽ → ' ) GREEK KORONIS → APOSTROPHE# →’→ + {8126, "i"}, // MA# ( ι → i ) GREEK PROSGEGRAMMENI → LATIN SMALL LETTER I# →ι→ + {8127, "'"}, // MA#* ( ᾿ → ' ) GREEK PSILI → APOSTROPHE# →’→ + {8128, "~"}, // MA#* ( ῀ → ~ ) GREEK PERISPOMENI → TILDE# →˜→ + {8175, "'"}, // MA#* ( ` → ' ) GREEK VARIA → APOSTROPHE# →ˋ→→`→→‘→ + {8189, "'"}, // MA#* ( ´ → ' ) GREEK OXIA → APOSTROPHE# →´→→΄→→ʹ→ + {8190, "'"}, // MA#* ( ῾ → ' ) GREEK DASIA → APOSTROPHE# →‛→→′→ + {8192, " "}, // MA#* (   → ) EN QUAD → SPACE# + {8193, " "}, // MA#* (   → ) EM QUAD → SPACE# + {8194, " "}, // MA#* (   → ) EN SPACE → SPACE# + {8195, " "}, // MA#* (   → ) EM SPACE → SPACE# + {8196, " "}, // MA#* (   → ) THREE-PER-EM SPACE → SPACE# + {8197, " "}, // MA#* (   → ) FOUR-PER-EM SPACE → SPACE# + {8198, " "}, // MA#* (   → ) SIX-PER-EM SPACE → SPACE# + {8199, " "}, // MA#* (   → ) FIGURE SPACE → SPACE# + {8200, " "}, // MA#* (   → ) PUNCTUATION SPACE → SPACE# + {8201, " "}, // MA#* (   → ) THIN SPACE → SPACE# + {8202, " "}, // MA#* (   → ) HAIR SPACE → SPACE# + {8208, "-"}, // MA#* ( ‐ → - ) HYPHEN → HYPHEN-MINUS# + {8209, "-"}, // MA#* ( ‑ → - ) NON-BREAKING HYPHEN → HYPHEN-MINUS# + {8210, "-"}, // MA#* ( ‒ → - ) FIGURE DASH → HYPHEN-MINUS# + {8211, "-"}, // MA#* ( – → - ) EN DASH → HYPHEN-MINUS# + {8214, "ll"}, // MA#* ( ‖ → ll ) DOUBLE VERTICAL LINE → LATIN SMALL LETTER L, LATIN SMALL LETTER L# →∥→→||→ + {8216, "'"}, // MA#* ( ‘ → ' ) LEFT SINGLE QUOTATION MARK → APOSTROPHE# + {8217, "'"}, // MA#* ( ’ → ' ) RIGHT SINGLE QUOTATION MARK → APOSTROPHE# + {8218, ","}, // MA#* ( ‚ → , ) SINGLE LOW-9 QUOTATION MARK → COMMA# + {8219, "'"}, // MA#* ( ‛ → ' ) SINGLE HIGH-REVERSED-9 QUOTATION MARK → APOSTROPHE# →′→ + {8220, "\""}, // MA#* ( “ → '' ) LEFT DOUBLE QUOTATION MARK → APOSTROPHE, APOSTROPHE# →"→# Converted to a quote. + {8221, "\""}, // MA#* ( ” → '' ) RIGHT DOUBLE QUOTATION MARK → APOSTROPHE, APOSTROPHE# →"→# Converted to a quote. + {8223, "\""}, // MA#* ( ‟ → '' ) DOUBLE HIGH-REVERSED-9 QUOTATION MARK → APOSTROPHE, APOSTROPHE# →“→→"→# Converted to a quote. + {8228, "."}, // MA#* ( ․ → . ) ONE DOT LEADER → FULL STOP# + {8229, ".."}, // MA#* ( ‥ → .. ) TWO DOT LEADER → FULL STOP, FULL STOP# + {8230, "..."}, // MA#* ( … → ... ) HORIZONTAL ELLIPSIS → FULL STOP, FULL STOP, FULL STOP# + {8232, " "}, // MA#* ( → ) LINE SEPARATOR → SPACE# + {8233, " "}, // MA#* ( → ) PARAGRAPH SEPARATOR → SPACE# + {8239, " "}, // MA#* (   → ) NARROW NO-BREAK SPACE → SPACE# + {8242, "'"}, // MA#* ( ′ → ' ) PRIME → APOSTROPHE# + {8243, "\""}, // MA#* ( ″ → '' ) DOUBLE PRIME → APOSTROPHE, APOSTROPHE# →"→# Converted to a quote. + {8244, "'''"}, // MA#* ( ‴ → ''' ) TRIPLE PRIME → APOSTROPHE, APOSTROPHE, APOSTROPHE# →′′′→ + {8245, "'"}, // MA#* ( ‵ → ' ) REVERSED PRIME → APOSTROPHE# →ʽ→→‘→ + {8246, "\""}, // MA#* ( ‶ → '' ) REVERSED DOUBLE PRIME → APOSTROPHE, APOSTROPHE# →‵‵→# Converted to a quote. + {8247, "'''"}, // MA#* ( ‷ → ''' ) REVERSED TRIPLE PRIME → APOSTROPHE, APOSTROPHE, APOSTROPHE# →‵‵‵→ + {8249, "<"}, // MA#* ( ‹ → < ) SINGLE LEFT-POINTING ANGLE QUOTATION MARK → LESS-THAN SIGN# + {8250, ">"}, // MA#* ( › → > ) SINGLE RIGHT-POINTING ANGLE QUOTATION MARK → GREATER-THAN SIGN# + {8252, "!!"}, // MA#* ( ‼ → !! ) DOUBLE EXCLAMATION MARK → EXCLAMATION MARK, EXCLAMATION MARK# + {8257, "/"}, // MA#* ( ⁁ → / ) CARET INSERTION POINT → SOLIDUS# + {8259, "-"}, // MA#* ( ⁃ → - ) HYPHEN BULLET → HYPHEN-MINUS# →‐→ + {8260, "/"}, // MA#* ( ⁄ → / ) FRACTION SLASH → SOLIDUS# + {8263, "??"}, // MA#* ( ⁇ → ?? ) DOUBLE QUESTION MARK → QUESTION MARK, QUESTION MARK# + {8264, "?!"}, // MA#* ( ⁈ → ?! ) QUESTION EXCLAMATION MARK → QUESTION MARK, EXCLAMATION MARK# + {8265, "!?"}, // MA#* ( ⁉ → !? ) EXCLAMATION QUESTION MARK → EXCLAMATION MARK, QUESTION MARK# + {8270, "*"}, // MA#* ( ⁎ → * ) LOW ASTERISK → ASTERISK# + {8275, "~"}, // MA#* ( ⁓ → ~ ) SWUNG DASH → TILDE# + {8279, "''''"}, // MA#* ( ⁗ → '''' ) QUADRUPLE PRIME → APOSTROPHE, APOSTROPHE, APOSTROPHE, APOSTROPHE# →′′′′→ + {8282, ":"}, // MA#* ( ⁚ → : ) TWO DOT PUNCTUATION → COLON# + {8287, " "}, // MA#* (   → ) MEDIUM MATHEMATICAL SPACE → SPACE# + {8360, "Rs"}, // MA#* ( ₨ → Rs ) RUPEE SIGN → LATIN CAPITAL LETTER R, LATIN SMALL LETTER S# + {8374, "lt"}, // MA#* ( ₶ → lt ) LIVRE TOURNOIS SIGN → LATIN SMALL LETTER L, LATIN SMALL LETTER T# + {8448, "a/c"}, // MA#* ( ℀ → a/c ) ACCOUNT OF → LATIN SMALL LETTER A, SOLIDUS, LATIN SMALL LETTER C# + {8449, "a/s"}, // MA#* ( ℁ → a/s ) ADDRESSED TO THE SUBJECT → LATIN SMALL LETTER A, SOLIDUS, LATIN SMALL LETTER S# + {8450, "C"}, // MA# ( ℂ → C ) DOUBLE-STRUCK CAPITAL C → LATIN CAPITAL LETTER C# + {8453, "c/o"}, // MA#* ( ℅ → c/o ) CARE OF → LATIN SMALL LETTER C, SOLIDUS, LATIN SMALL LETTER O# + {8454, "c/u"}, // MA#* ( ℆ → c/u ) CADA UNA → LATIN SMALL LETTER C, SOLIDUS, LATIN SMALL LETTER U# + {8458, "g"}, // MA# ( ℊ → g ) SCRIPT SMALL G → LATIN SMALL LETTER G# + {8459, "H"}, // MA# ( ℋ → H ) SCRIPT CAPITAL H → LATIN CAPITAL LETTER H# + {8460, "H"}, // MA# ( ℌ → H ) BLACK-LETTER CAPITAL H → LATIN CAPITAL LETTER H# + {8461, "H"}, // MA# ( ℍ → H ) DOUBLE-STRUCK CAPITAL H → LATIN CAPITAL LETTER H# + {8462, "h"}, // MA# ( ℎ → h ) PLANCK CONSTANT → LATIN SMALL LETTER H# + {8464, "l"}, // MA# ( ℐ → l ) SCRIPT CAPITAL I → LATIN SMALL LETTER L# →I→ + {8465, "l"}, // MA# ( ℑ → l ) BLACK-LETTER CAPITAL I → LATIN SMALL LETTER L# →I→ + {8466, "L"}, // MA# ( ℒ → L ) SCRIPT CAPITAL L → LATIN CAPITAL LETTER L# + {8467, "l"}, // MA# ( ℓ → l ) SCRIPT SMALL L → LATIN SMALL LETTER L# + {8469, "N"}, // MA# ( ℕ → N ) DOUBLE-STRUCK CAPITAL N → LATIN CAPITAL LETTER N# + {8470, "No"}, // MA#* ( № → No ) NUMERO SIGN → LATIN CAPITAL LETTER N, LATIN SMALL LETTER O# + {8473, "P"}, // MA# ( ℙ → P ) DOUBLE-STRUCK CAPITAL P → LATIN CAPITAL LETTER P# + {8474, "Q"}, // MA# ( ℚ → Q ) DOUBLE-STRUCK CAPITAL Q → LATIN CAPITAL LETTER Q# + {8475, "R"}, // MA# ( ℛ → R ) SCRIPT CAPITAL R → LATIN CAPITAL LETTER R# + {8476, "R"}, // MA# ( ℜ → R ) BLACK-LETTER CAPITAL R → LATIN CAPITAL LETTER R# + {8477, "R"}, // MA# ( ℝ → R ) DOUBLE-STRUCK CAPITAL R → LATIN CAPITAL LETTER R# + {8481, "TEL"}, // MA#* ( ℡ → TEL ) TELEPHONE SIGN → LATIN CAPITAL LETTER T, LATIN CAPITAL LETTER E, LATIN CAPITAL LETTER L# + {8484, "Z"}, // MA# ( ℤ → Z ) DOUBLE-STRUCK CAPITAL Z → LATIN CAPITAL LETTER Z# + {8488, "Z"}, // MA# ( ℨ → Z ) BLACK-LETTER CAPITAL Z → LATIN CAPITAL LETTER Z# + {8490, "K"}, // MA# ( K → K ) KELVIN SIGN → LATIN CAPITAL LETTER K# + {8492, "B"}, // MA# ( ℬ → B ) SCRIPT CAPITAL B → LATIN CAPITAL LETTER B# + {8493, "C"}, // MA# ( ℭ → C ) BLACK-LETTER CAPITAL C → LATIN CAPITAL LETTER C# + {8494, "e"}, // MA# ( ℮ → e ) ESTIMATED SYMBOL → LATIN SMALL LETTER E# + {8495, "e"}, // MA# ( ℯ → e ) SCRIPT SMALL E → LATIN SMALL LETTER E# + {8496, "E"}, // MA# ( ℰ → E ) SCRIPT CAPITAL E → LATIN CAPITAL LETTER E# + {8497, "F"}, // MA# ( ℱ → F ) SCRIPT CAPITAL F → LATIN CAPITAL LETTER F# + {8499, "M"}, // MA# ( ℳ → M ) SCRIPT CAPITAL M → LATIN CAPITAL LETTER M# + {8500, "o"}, // MA# ( ℴ → o ) SCRIPT SMALL O → LATIN SMALL LETTER O# + {8505, "i"}, // MA# ( ℹ → i ) INFORMATION SOURCE → LATIN SMALL LETTER I# + {8507, "FAX"}, // MA#* ( ℻ → FAX ) FACSIMILE SIGN → LATIN CAPITAL LETTER F, LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER X# + {8509, "y"}, // MA# ( ℽ → y ) DOUBLE-STRUCK SMALL GAMMA → LATIN SMALL LETTER Y# →γ→ + {8517, "D"}, // MA# ( ⅅ → D ) DOUBLE-STRUCK ITALIC CAPITAL D → LATIN CAPITAL LETTER D# + {8518, "d"}, // MA# ( ⅆ → d ) DOUBLE-STRUCK ITALIC SMALL D → LATIN SMALL LETTER D# + {8519, "e"}, // MA# ( ⅇ → e ) DOUBLE-STRUCK ITALIC SMALL E → LATIN SMALL LETTER E# + {8520, "i"}, // MA# ( ⅈ → i ) DOUBLE-STRUCK ITALIC SMALL I → LATIN SMALL LETTER I# + {8521, "j"}, // MA# ( ⅉ → j ) DOUBLE-STRUCK ITALIC SMALL J → LATIN SMALL LETTER J# + {8544, "l"}, // MA# ( Ⅰ → l ) ROMAN NUMERAL ONE → LATIN SMALL LETTER L# →Ӏ→ + {8545, "ll"}, // MA# ( Ⅱ → ll ) ROMAN NUMERAL TWO → LATIN SMALL LETTER L, LATIN SMALL LETTER L# →II→ + {8546, "lll"}, // MA# ( Ⅲ → lll ) ROMAN NUMERAL THREE → LATIN SMALL LETTER L, LATIN SMALL LETTER L, LATIN SMALL LETTER L# →III→ + {8547, "lV"}, // MA# ( Ⅳ → lV ) ROMAN NUMERAL FOUR → LATIN SMALL LETTER L, LATIN CAPITAL LETTER V# →IV→ + {8548, "V"}, // MA# ( Ⅴ → V ) ROMAN NUMERAL FIVE → LATIN CAPITAL LETTER V# + {8549, "Vl"}, // MA# ( Ⅵ → Vl ) ROMAN NUMERAL SIX → LATIN CAPITAL LETTER V, LATIN SMALL LETTER L# →VI→ + {8550, "Vll"}, // MA# ( Ⅶ → Vll ) ROMAN NUMERAL SEVEN → LATIN CAPITAL LETTER V, LATIN SMALL LETTER L, LATIN SMALL LETTER L# →VII→ + {8551, "Vlll"}, // MA# ( Ⅷ → Vlll ) ROMAN NUMERAL EIGHT → LATIN CAPITAL LETTER V, LATIN SMALL LETTER L, LATIN SMALL LETTER L, LATIN SMALL LETTER L# →VIII→ + {8552, "lX"}, // MA# ( Ⅸ → lX ) ROMAN NUMERAL NINE → LATIN SMALL LETTER L, LATIN CAPITAL LETTER X# →IX→ + {8553, "X"}, // MA# ( Ⅹ → X ) ROMAN NUMERAL TEN → LATIN CAPITAL LETTER X# + {8554, "Xl"}, // MA# ( Ⅺ → Xl ) ROMAN NUMERAL ELEVEN → LATIN CAPITAL LETTER X, LATIN SMALL LETTER L# →XI→ + {8555, "Xll"}, // MA# ( Ⅻ → Xll ) ROMAN NUMERAL TWELVE → LATIN CAPITAL LETTER X, LATIN SMALL LETTER L, LATIN SMALL LETTER L# →XII→ + {8556, "L"}, // MA# ( Ⅼ → L ) ROMAN NUMERAL FIFTY → LATIN CAPITAL LETTER L# + {8557, "C"}, // MA# ( Ⅽ → C ) ROMAN NUMERAL ONE HUNDRED → LATIN CAPITAL LETTER C# + {8558, "D"}, // MA# ( Ⅾ → D ) ROMAN NUMERAL FIVE HUNDRED → LATIN CAPITAL LETTER D# + {8559, "M"}, // MA# ( Ⅿ → M ) ROMAN NUMERAL ONE THOUSAND → LATIN CAPITAL LETTER M# + {8560, "i"}, // MA# ( ⅰ → i ) SMALL ROMAN NUMERAL ONE → LATIN SMALL LETTER I# + {8561, "ii"}, // MA# ( ⅱ → ii ) SMALL ROMAN NUMERAL TWO → LATIN SMALL LETTER I, LATIN SMALL LETTER I# + {8562, "iii"}, // MA# ( ⅲ → iii ) SMALL ROMAN NUMERAL THREE → LATIN SMALL LETTER I, LATIN SMALL LETTER I, LATIN SMALL LETTER I# + {8563, "iv"}, // MA# ( ⅳ → iv ) SMALL ROMAN NUMERAL FOUR → LATIN SMALL LETTER I, LATIN SMALL LETTER V# + {8564, "v"}, // MA# ( ⅴ → v ) SMALL ROMAN NUMERAL FIVE → LATIN SMALL LETTER V# + {8565, "vi"}, // MA# ( ⅵ → vi ) SMALL ROMAN NUMERAL SIX → LATIN SMALL LETTER V, LATIN SMALL LETTER I# + {8566, "vii"}, // MA# ( ⅶ → vii ) SMALL ROMAN NUMERAL SEVEN → LATIN SMALL LETTER V, LATIN SMALL LETTER I, LATIN SMALL LETTER I# + {8567, "viii"}, // MA# ( ⅷ → viii ) SMALL ROMAN NUMERAL EIGHT → LATIN SMALL LETTER V, LATIN SMALL LETTER I, LATIN SMALL LETTER I, LATIN SMALL LETTER I# + {8568, "ix"}, // MA# ( ⅸ → ix ) SMALL ROMAN NUMERAL NINE → LATIN SMALL LETTER I, LATIN SMALL LETTER X# + {8569, "x"}, // MA# ( ⅹ → x ) SMALL ROMAN NUMERAL TEN → LATIN SMALL LETTER X# + {8570, "xi"}, // MA# ( ⅺ → xi ) SMALL ROMAN NUMERAL ELEVEN → LATIN SMALL LETTER X, LATIN SMALL LETTER I# + {8571, "xii"}, // MA# ( ⅻ → xii ) SMALL ROMAN NUMERAL TWELVE → LATIN SMALL LETTER X, LATIN SMALL LETTER I, LATIN SMALL LETTER I# + {8572, "l"}, // MA# ( ⅼ → l ) SMALL ROMAN NUMERAL FIFTY → LATIN SMALL LETTER L# + {8573, "c"}, // MA# ( ⅽ → c ) SMALL ROMAN NUMERAL ONE HUNDRED → LATIN SMALL LETTER C# + {8574, "d"}, // MA# ( ⅾ → d ) SMALL ROMAN NUMERAL FIVE HUNDRED → LATIN SMALL LETTER D# + {8575, "rn"}, // MA# ( ⅿ → rn ) SMALL ROMAN NUMERAL ONE THOUSAND → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {8722, "-"}, // MA#* ( − → - ) MINUS SIGN → HYPHEN-MINUS# + {8725, "/"}, // MA#* ( ∕ → / ) DIVISION SLASH → SOLIDUS# + {8726, "\\"}, // MA#* ( ∖ → \ ) SET MINUS → REVERSE SOLIDUS# + {8727, "*"}, // MA#* ( ∗ → * ) ASTERISK OPERATOR → ASTERISK# + {8734, "oo"}, // MA#* ( ∞ → oo ) INFINITY → LATIN SMALL LETTER O, LATIN SMALL LETTER O# →ꝏ→ + {8739, "l"}, // MA#* ( ∣ → l ) DIVIDES → LATIN SMALL LETTER L# →ǀ→ + {8741, "ll"}, // MA#* ( ∥ → ll ) PARALLEL TO → LATIN SMALL LETTER L, LATIN SMALL LETTER L# →||→ + {8744, "v"}, // MA#* ( ∨ → v ) LOGICAL OR → LATIN SMALL LETTER V# + {8746, "U"}, // MA#* ( ∪ → U ) UNION → LATIN CAPITAL LETTER U# →ᑌ→ + {8758, ":"}, // MA#* ( ∶ → : ) RATIO → COLON# + {8764, "~"}, // MA#* ( ∼ → ~ ) TILDE OPERATOR → TILDE# + {8810, "<<"}, // MA#* ( ≪ → << ) MUCH LESS-THAN → LESS-THAN SIGN, LESS-THAN SIGN# + {8811, ">>"}, // MA#* ( ≫ → >> ) MUCH GREATER-THAN → GREATER-THAN SIGN, GREATER-THAN SIGN# + {8868, "T"}, // MA#* ( ⊤ → T ) DOWN TACK → LATIN CAPITAL LETTER T# + {8897, "v"}, // MA#* ( ⋁ → v ) N-ARY LOGICAL OR → LATIN SMALL LETTER V# →∨→ + {8899, "U"}, // MA#* ( ⋃ → U ) N-ARY UNION → LATIN CAPITAL LETTER U# →∪→→ᑌ→ + {8920, "<<<"}, // MA#* ( ⋘ → <<< ) VERY MUCH LESS-THAN → LESS-THAN SIGN, LESS-THAN SIGN, LESS-THAN SIGN# + {8921, ">>>"}, // MA#* ( ⋙ → >>> ) VERY MUCH GREATER-THAN → GREATER-THAN SIGN, GREATER-THAN SIGN, GREATER-THAN SIGN# + {8959, "E"}, // MA#* ( ⋿ → E ) Z NOTATION BAG MEMBERSHIP → LATIN CAPITAL LETTER E# + {9075, "i"}, // MA#* ( ⍳ → i ) APL FUNCTIONAL SYMBOL IOTA → LATIN SMALL LETTER I# →ι→ + {9076, "p"}, // MA#* ( ⍴ → p ) APL FUNCTIONAL SYMBOL RHO → LATIN SMALL LETTER P# →ρ→ + {9082, "a"}, // MA#* ( ⍺ → a ) APL FUNCTIONAL SYMBOL ALPHA → LATIN SMALL LETTER A# →α→ + {9213, "l"}, // MA#* ( ⏽ → l ) POWER ON SYMBOL → LATIN SMALL LETTER L# →I→ + {9290, "\\\\"}, // MA#* ( ⑊ → \\ ) OCR DOUBLE BACKSLASH → REVERSE SOLIDUS, REVERSE SOLIDUS# + {9332, "(l)"}, // MA#* ( ⑴ → (l) ) PARENTHESIZED DIGIT ONE → LEFT PARENTHESIS, LATIN SMALL LETTER L, RIGHT PARENTHESIS# →(1)→ + {9333, "(2)"}, // MA#* ( ⑵ → (2) ) PARENTHESIZED DIGIT TWO → LEFT PARENTHESIS, DIGIT TWO, RIGHT PARENTHESIS# + {9334, "(3)"}, // MA#* ( ⑶ → (3) ) PARENTHESIZED DIGIT THREE → LEFT PARENTHESIS, DIGIT THREE, RIGHT PARENTHESIS# + {9335, "(4)"}, // MA#* ( ⑷ → (4) ) PARENTHESIZED DIGIT FOUR → LEFT PARENTHESIS, DIGIT FOUR, RIGHT PARENTHESIS# + {9336, "(5)"}, // MA#* ( ⑸ → (5) ) PARENTHESIZED DIGIT FIVE → LEFT PARENTHESIS, DIGIT FIVE, RIGHT PARENTHESIS# + {9337, "(6)"}, // MA#* ( ⑹ → (6) ) PARENTHESIZED DIGIT SIX → LEFT PARENTHESIS, DIGIT SIX, RIGHT PARENTHESIS# + {9338, "(7)"}, // MA#* ( ⑺ → (7) ) PARENTHESIZED DIGIT SEVEN → LEFT PARENTHESIS, DIGIT SEVEN, RIGHT PARENTHESIS# + {9339, "(8)"}, // MA#* ( ⑻ → (8) ) PARENTHESIZED DIGIT EIGHT → LEFT PARENTHESIS, DIGIT EIGHT, RIGHT PARENTHESIS# + {9340, "(9)"}, // MA#* ( ⑼ → (9) ) PARENTHESIZED DIGIT NINE → LEFT PARENTHESIS, DIGIT NINE, RIGHT PARENTHESIS# + {9341, "(lO)"}, // MA#* ( ⑽ → (lO) ) PARENTHESIZED NUMBER TEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, LATIN CAPITAL LETTER O, RIGHT PARENTHESIS# →(10)→ + {9342, "(ll)"}, // MA#* ( ⑾ → (ll) ) PARENTHESIZED NUMBER ELEVEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, LATIN SMALL LETTER L, RIGHT PARENTHESIS# →(11)→ + {9343, "(l2)"}, // MA#* ( ⑿ → (l2) ) PARENTHESIZED NUMBER TWELVE → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT TWO, RIGHT PARENTHESIS# →(12)→ + {9344, "(l3)"}, // MA#* ( ⒀ → (l3) ) PARENTHESIZED NUMBER THIRTEEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT THREE, RIGHT PARENTHESIS# →(13)→ + {9345, "(l4)"}, // MA#* ( ⒁ → (l4) ) PARENTHESIZED NUMBER FOURTEEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT FOUR, RIGHT PARENTHESIS# →(14)→ + {9346, "(l5)"}, // MA#* ( ⒂ → (l5) ) PARENTHESIZED NUMBER FIFTEEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT FIVE, RIGHT PARENTHESIS# →(15)→ + {9347, "(l6)"}, // MA#* ( ⒃ → (l6) ) PARENTHESIZED NUMBER SIXTEEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT SIX, RIGHT PARENTHESIS# →(16)→ + {9348, "(l7)"}, // MA#* ( ⒄ → (l7) ) PARENTHESIZED NUMBER SEVENTEEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT SEVEN, RIGHT PARENTHESIS# →(17)→ + {9349, "(l8)"}, // MA#* ( ⒅ → (l8) ) PARENTHESIZED NUMBER EIGHTEEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT EIGHT, RIGHT PARENTHESIS# →(18)→ + {9350, "(l9)"}, // MA#* ( ⒆ → (l9) ) PARENTHESIZED NUMBER NINETEEN → LEFT PARENTHESIS, LATIN SMALL LETTER L, DIGIT NINE, RIGHT PARENTHESIS# →(19)→ + {9351, "(2O)"}, // MA#* ( ⒇ → (2O) ) PARENTHESIZED NUMBER TWENTY → LEFT PARENTHESIS, DIGIT TWO, LATIN CAPITAL LETTER O, RIGHT PARENTHESIS# →(20)→ + {9352, "l."}, // MA#* ( ⒈ → l. ) DIGIT ONE FULL STOP → LATIN SMALL LETTER L, FULL STOP# →1.→ + {9353, "2."}, // MA#* ( ⒉ → 2. ) DIGIT TWO FULL STOP → DIGIT TWO, FULL STOP# + {9354, "3."}, // MA#* ( ⒊ → 3. ) DIGIT THREE FULL STOP → DIGIT THREE, FULL STOP# + {9355, "4."}, // MA#* ( ⒋ → 4. ) DIGIT FOUR FULL STOP → DIGIT FOUR, FULL STOP# + {9356, "5."}, // MA#* ( ⒌ → 5. ) DIGIT FIVE FULL STOP → DIGIT FIVE, FULL STOP# + {9357, "6."}, // MA#* ( ⒍ → 6. ) DIGIT SIX FULL STOP → DIGIT SIX, FULL STOP# + {9358, "7."}, // MA#* ( ⒎ → 7. ) DIGIT SEVEN FULL STOP → DIGIT SEVEN, FULL STOP# + {9359, "8."}, // MA#* ( ⒏ → 8. ) DIGIT EIGHT FULL STOP → DIGIT EIGHT, FULL STOP# + {9360, "9."}, // MA#* ( ⒐ → 9. ) DIGIT NINE FULL STOP → DIGIT NINE, FULL STOP# + {9361, "lO."}, // MA#* ( ⒑ → lO. ) NUMBER TEN FULL STOP → LATIN SMALL LETTER L, LATIN CAPITAL LETTER O, FULL STOP# →10.→ + {9362, "ll."}, // MA#* ( ⒒ → ll. ) NUMBER ELEVEN FULL STOP → LATIN SMALL LETTER L, LATIN SMALL LETTER L, FULL STOP# →11.→ + {9363, "l2."}, // MA#* ( ⒓ → l2. ) NUMBER TWELVE FULL STOP → LATIN SMALL LETTER L, DIGIT TWO, FULL STOP# →12.→ + {9364, "l3."}, // MA#* ( ⒔ → l3. ) NUMBER THIRTEEN FULL STOP → LATIN SMALL LETTER L, DIGIT THREE, FULL STOP# →13.→ + {9365, "l4."}, // MA#* ( ⒕ → l4. ) NUMBER FOURTEEN FULL STOP → LATIN SMALL LETTER L, DIGIT FOUR, FULL STOP# →14.→ + {9366, "l5."}, // MA#* ( ⒖ → l5. ) NUMBER FIFTEEN FULL STOP → LATIN SMALL LETTER L, DIGIT FIVE, FULL STOP# →15.→ + {9367, "l6."}, // MA#* ( ⒗ → l6. ) NUMBER SIXTEEN FULL STOP → LATIN SMALL LETTER L, DIGIT SIX, FULL STOP# →16.→ + {9368, "l7."}, // MA#* ( ⒘ → l7. ) NUMBER SEVENTEEN FULL STOP → LATIN SMALL LETTER L, DIGIT SEVEN, FULL STOP# →17.→ + {9369, "l8."}, // MA#* ( ⒙ → l8. ) NUMBER EIGHTEEN FULL STOP → LATIN SMALL LETTER L, DIGIT EIGHT, FULL STOP# →18.→ + {9370, "l9."}, // MA#* ( ⒚ → l9. ) NUMBER NINETEEN FULL STOP → LATIN SMALL LETTER L, DIGIT NINE, FULL STOP# →19.→ + {9371, "2O."}, // MA#* ( ⒛ → 2O. ) NUMBER TWENTY FULL STOP → DIGIT TWO, LATIN CAPITAL LETTER O, FULL STOP# →20.→ + {9372, "(a)"}, // MA#* ( ⒜ → (a) ) PARENTHESIZED LATIN SMALL LETTER A → LEFT PARENTHESIS, LATIN SMALL LETTER A, RIGHT PARENTHESIS# + {9373, "(b)"}, // MA#* ( ⒝ → (b) ) PARENTHESIZED LATIN SMALL LETTER B → LEFT PARENTHESIS, LATIN SMALL LETTER B, RIGHT PARENTHESIS# + {9374, "(c)"}, // MA#* ( ⒞ → (c) ) PARENTHESIZED LATIN SMALL LETTER C → LEFT PARENTHESIS, LATIN SMALL LETTER C, RIGHT PARENTHESIS# + {9375, "(d)"}, // MA#* ( ⒟ → (d) ) PARENTHESIZED LATIN SMALL LETTER D → LEFT PARENTHESIS, LATIN SMALL LETTER D, RIGHT PARENTHESIS# + {9376, "(e)"}, // MA#* ( ⒠ → (e) ) PARENTHESIZED LATIN SMALL LETTER E → LEFT PARENTHESIS, LATIN SMALL LETTER E, RIGHT PARENTHESIS# + {9377, "(f)"}, // MA#* ( ⒡ → (f) ) PARENTHESIZED LATIN SMALL LETTER F → LEFT PARENTHESIS, LATIN SMALL LETTER F, RIGHT PARENTHESIS# + {9378, "(g)"}, // MA#* ( ⒢ → (g) ) PARENTHESIZED LATIN SMALL LETTER G → LEFT PARENTHESIS, LATIN SMALL LETTER G, RIGHT PARENTHESIS# + {9379, "(h)"}, // MA#* ( ⒣ → (h) ) PARENTHESIZED LATIN SMALL LETTER H → LEFT PARENTHESIS, LATIN SMALL LETTER H, RIGHT PARENTHESIS# + {9380, "(i)"}, // MA#* ( ⒤ → (i) ) PARENTHESIZED LATIN SMALL LETTER I → LEFT PARENTHESIS, LATIN SMALL LETTER I, RIGHT PARENTHESIS# + {9381, "(j)"}, // MA#* ( ⒥ → (j) ) PARENTHESIZED LATIN SMALL LETTER J → LEFT PARENTHESIS, LATIN SMALL LETTER J, RIGHT PARENTHESIS# + {9382, "(k)"}, // MA#* ( ⒦ → (k) ) PARENTHESIZED LATIN SMALL LETTER K → LEFT PARENTHESIS, LATIN SMALL LETTER K, RIGHT PARENTHESIS# + {9383, "(l)"}, // MA#* ( ⒧ → (l) ) PARENTHESIZED LATIN SMALL LETTER L → LEFT PARENTHESIS, LATIN SMALL LETTER L, RIGHT PARENTHESIS# + {9384, "(rn)"}, // MA#* ( ⒨ → (rn) ) PARENTHESIZED LATIN SMALL LETTER M → LEFT PARENTHESIS, LATIN SMALL LETTER R, LATIN SMALL LETTER N, RIGHT PARENTHESIS# →(m)→ + {9385, "(n)"}, // MA#* ( ⒩ → (n) ) PARENTHESIZED LATIN SMALL LETTER N → LEFT PARENTHESIS, LATIN SMALL LETTER N, RIGHT PARENTHESIS# + {9386, "(o)"}, // MA#* ( ⒪ → (o) ) PARENTHESIZED LATIN SMALL LETTER O → LEFT PARENTHESIS, LATIN SMALL LETTER O, RIGHT PARENTHESIS# + {9387, "(p)"}, // MA#* ( ⒫ → (p) ) PARENTHESIZED LATIN SMALL LETTER P → LEFT PARENTHESIS, LATIN SMALL LETTER P, RIGHT PARENTHESIS# + {9388, "(q)"}, // MA#* ( ⒬ → (q) ) PARENTHESIZED LATIN SMALL LETTER Q → LEFT PARENTHESIS, LATIN SMALL LETTER Q, RIGHT PARENTHESIS# + {9389, "(r)"}, // MA#* ( ⒭ → (r) ) PARENTHESIZED LATIN SMALL LETTER R → LEFT PARENTHESIS, LATIN SMALL LETTER R, RIGHT PARENTHESIS# + {9390, "(s)"}, // MA#* ( ⒮ → (s) ) PARENTHESIZED LATIN SMALL LETTER S → LEFT PARENTHESIS, LATIN SMALL LETTER S, RIGHT PARENTHESIS# + {9391, "(t)"}, // MA#* ( ⒯ → (t) ) PARENTHESIZED LATIN SMALL LETTER T → LEFT PARENTHESIS, LATIN SMALL LETTER T, RIGHT PARENTHESIS# + {9392, "(u)"}, // MA#* ( ⒰ → (u) ) PARENTHESIZED LATIN SMALL LETTER U → LEFT PARENTHESIS, LATIN SMALL LETTER U, RIGHT PARENTHESIS# + {9393, "(v)"}, // MA#* ( ⒱ → (v) ) PARENTHESIZED LATIN SMALL LETTER V → LEFT PARENTHESIS, LATIN SMALL LETTER V, RIGHT PARENTHESIS# + {9394, "(w)"}, // MA#* ( ⒲ → (w) ) PARENTHESIZED LATIN SMALL LETTER W → LEFT PARENTHESIS, LATIN SMALL LETTER W, RIGHT PARENTHESIS# + {9395, "(x)"}, // MA#* ( ⒳ → (x) ) PARENTHESIZED LATIN SMALL LETTER X → LEFT PARENTHESIS, LATIN SMALL LETTER X, RIGHT PARENTHESIS# + {9396, "(y)"}, // MA#* ( ⒴ → (y) ) PARENTHESIZED LATIN SMALL LETTER Y → LEFT PARENTHESIS, LATIN SMALL LETTER Y, RIGHT PARENTHESIS# + {9397, "(z)"}, // MA#* ( ⒵ → (z) ) PARENTHESIZED LATIN SMALL LETTER Z → LEFT PARENTHESIS, LATIN SMALL LETTER Z, RIGHT PARENTHESIS# + {9585, "/"}, // MA#* ( ╱ → / ) BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT → SOLIDUS# + {9587, "X"}, // MA#* ( ╳ → X ) BOX DRAWINGS LIGHT DIAGONAL CROSS → LATIN CAPITAL LETTER X# + {10088, "("}, // MA#* ( ❨ → ( ) MEDIUM LEFT PARENTHESIS ORNAMENT → LEFT PARENTHESIS# + {10089, ")"}, // MA#* ( ❩ → ) ) MEDIUM RIGHT PARENTHESIS ORNAMENT → RIGHT PARENTHESIS# + {10094, "<"}, // MA#* ( ❮ → < ) HEAVY LEFT-POINTING ANGLE QUOTATION MARK ORNAMENT → LESS-THAN SIGN# →‹→ + {10095, ">"}, // MA#* ( ❯ → > ) HEAVY RIGHT-POINTING ANGLE QUOTATION MARK ORNAMENT → GREATER-THAN SIGN# →›→ + {10098, "("}, // MA#* ( ❲ → ( ) LIGHT LEFT TORTOISE SHELL BRACKET ORNAMENT → LEFT PARENTHESIS# →〔→ + {10099, ")"}, // MA#* ( ❳ → ) ) LIGHT RIGHT TORTOISE SHELL BRACKET ORNAMENT → RIGHT PARENTHESIS# →〕→ + {10100, "{"}, // MA#* ( ❴ → { ) MEDIUM LEFT CURLY BRACKET ORNAMENT → LEFT CURLY BRACKET# + {10101, "}"}, // MA#* ( ❵ → } ) MEDIUM RIGHT CURLY BRACKET ORNAMENT → RIGHT CURLY BRACKET# + {10133, "+"}, // MA#* ( ➕ → + ) HEAVY PLUS SIGN → PLUS SIGN# + {10134, "-"}, // MA#* ( ➖ → - ) HEAVY MINUS SIGN → HYPHEN-MINUS# →−→ + {10187, "/"}, // MA#* ( ⟋ → / ) MATHEMATICAL RISING DIAGONAL → SOLIDUS# + {10189, "\\"}, // MA#* ( ⟍ → \ ) MATHEMATICAL FALLING DIAGONAL → REVERSE SOLIDUS# + {10201, "T"}, // MA#* ( ⟙ → T ) LARGE DOWN TACK → LATIN CAPITAL LETTER T# + {10539, "x"}, // MA#* ( ⤫ → x ) RISING DIAGONAL CROSSING FALLING DIAGONAL → LATIN SMALL LETTER X# + {10540, "x"}, // MA#* ( ⤬ → x ) FALLING DIAGONAL CROSSING RISING DIAGONAL → LATIN SMALL LETTER X# + {10741, "\\"}, // MA#* ( ⧵ → \ ) REVERSE SOLIDUS OPERATOR → REVERSE SOLIDUS# + {10744, "/"}, // MA#* ( ⧸ → / ) BIG SOLIDUS → SOLIDUS# + {10745, "\\"}, // MA#* ( ⧹ → \ ) BIG REVERSE SOLIDUS → REVERSE SOLIDUS# + {10784, ">>"}, // MA#* ( ⨠ → >> ) Z NOTATION SCHEMA PIPING → GREATER-THAN SIGN, GREATER-THAN SIGN# →≫→ + {10799, "x"}, // MA#* ( ⨯ → x ) VECTOR OR CROSS PRODUCT → LATIN SMALL LETTER X# →×→ + {10868, "::="}, // MA#* ( ⩴ → ::= ) DOUBLE COLON EQUAL → COLON, COLON, EQUALS SIGN# + {10869, "=="}, // MA#* ( ⩵ → == ) TWO CONSECUTIVE EQUALS SIGNS → EQUALS SIGN, EQUALS SIGN# + {10870, "==="}, // MA#* ( ⩶ → === ) THREE CONSECUTIVE EQUALS SIGNS → EQUALS SIGN, EQUALS SIGN, EQUALS SIGN# + {10917, "><"}, // MA#* ( ⪥ → >< ) GREATER-THAN BESIDE LESS-THAN → GREATER-THAN SIGN, LESS-THAN SIGN# + {11003, "///"}, // MA#* ( ⫻ → /// ) TRIPLE SOLIDUS BINARY RELATION → SOLIDUS, SOLIDUS, SOLIDUS# + {11005, "//"}, // MA#* ( ⫽ → // ) DOUBLE SOLIDUS OPERATOR → SOLIDUS, SOLIDUS# + {11397, "r"}, // MA# ( ⲅ → r ) COPTIC SMALL LETTER GAMMA → LATIN SMALL LETTER R# →г→ + {11406, "H"}, // MA# ( Ⲏ → H ) COPTIC CAPITAL LETTER HATE → LATIN CAPITAL LETTER H# →Η→ + {11410, "l"}, // MA# ( Ⲓ → l ) COPTIC CAPITAL LETTER IAUDA → LATIN SMALL LETTER L# →Ӏ→ + {11412, "K"}, // MA# ( Ⲕ → K ) COPTIC CAPITAL LETTER KAPA → LATIN CAPITAL LETTER K# →Κ→ + {11416, "M"}, // MA# ( Ⲙ → M ) COPTIC CAPITAL LETTER MI → LATIN CAPITAL LETTER M# + {11418, "N"}, // MA# ( Ⲛ → N ) COPTIC CAPITAL LETTER NI → LATIN CAPITAL LETTER N# + {11422, "O"}, // MA# ( Ⲟ → O ) COPTIC CAPITAL LETTER O → LATIN CAPITAL LETTER O# + {11423, "o"}, // MA# ( ⲟ → o ) COPTIC SMALL LETTER O → LATIN SMALL LETTER O# + {11426, "P"}, // MA# ( Ⲣ → P ) COPTIC CAPITAL LETTER RO → LATIN CAPITAL LETTER P# + {11427, "p"}, // MA# ( ⲣ → p ) COPTIC SMALL LETTER RO → LATIN SMALL LETTER P# →ρ→ + {11428, "C"}, // MA# ( Ⲥ → C ) COPTIC CAPITAL LETTER SIMA → LATIN CAPITAL LETTER C# →Ϲ→ + {11429, "c"}, // MA# ( ⲥ → c ) COPTIC SMALL LETTER SIMA → LATIN SMALL LETTER C# →ϲ→ + {11430, "T"}, // MA# ( Ⲧ → T ) COPTIC CAPITAL LETTER TAU → LATIN CAPITAL LETTER T# + {11432, "Y"}, // MA# ( Ⲩ → Y ) COPTIC CAPITAL LETTER UA → LATIN CAPITAL LETTER Y# + {11436, "X"}, // MA# ( Ⲭ → X ) COPTIC CAPITAL LETTER KHI → LATIN CAPITAL LETTER X# →Х→ + {11450, "-"}, // MA# ( Ⲻ → - ) COPTIC CAPITAL LETTER DIALECT-P NI → HYPHEN-MINUS# →‒→ + {11462, "/"}, // MA# ( Ⳇ → / ) COPTIC CAPITAL LETTER OLD COPTIC ESH → SOLIDUS# + {11466, "9"}, // MA# ( Ⳋ → 9 ) COPTIC CAPITAL LETTER DIALECT-P HORI → DIGIT NINE# + {11468, "3"}, // MA# ( Ⳍ → 3 ) COPTIC CAPITAL LETTER OLD COPTIC HORI → DIGIT THREE# →Ȝ→→Ʒ→ + {11472, "L"}, // MA# ( Ⳑ → L ) COPTIC CAPITAL LETTER L-SHAPED HA → LATIN CAPITAL LETTER L# + {11474, "6"}, // MA# ( Ⳓ → 6 ) COPTIC CAPITAL LETTER OLD COPTIC HEI → DIGIT SIX# + {11513, "\\\\"}, // MA#* ( ⳹ → \\ ) COPTIC OLD NUBIAN FULL STOP → REVERSE SOLIDUS, REVERSE SOLIDUS# + {11576, "V"}, // MA# ( ⴸ → V ) TIFINAGH LETTER YADH → LATIN CAPITAL LETTER V# + {11577, "E"}, // MA# ( ⴹ → E ) TIFINAGH LETTER YADD → LATIN CAPITAL LETTER E# + {11599, "l"}, // MA# ( ⵏ → l ) TIFINAGH LETTER YAN → LATIN SMALL LETTER L# →Ӏ→ + {11601, "!"}, // MA# ( ⵑ → ! ) TIFINAGH LETTER TUAREG YANG → EXCLAMATION MARK# + {11604, "O"}, // MA# ( ⵔ → O ) TIFINAGH LETTER YAR → LATIN CAPITAL LETTER O# + {11605, "Q"}, // MA# ( ⵕ → Q ) TIFINAGH LETTER YARR → LATIN CAPITAL LETTER Q# + {11613, "X"}, // MA# ( ⵝ → X ) TIFINAGH LETTER YATH → LATIN CAPITAL LETTER X# + {11816, "(("}, // MA#* ( ⸨ → (( ) LEFT DOUBLE PARENTHESIS → LEFT PARENTHESIS, LEFT PARENTHESIS# + {11817, "))"}, // MA#* ( ⸩ → )) ) RIGHT DOUBLE PARENTHESIS → RIGHT PARENTHESIS, RIGHT PARENTHESIS# + {11840, "="}, // MA#* ( ⹀ → = ) DOUBLE HYPHEN → EQUALS SIGN# + {12034, "\\"}, // MA#* ( ⼂ → \ ) KANGXI RADICAL DOT → REVERSE SOLIDUS# + {12035, "/"}, // MA#* ( ⼃ → / ) KANGXI RADICAL SLASH → SOLIDUS# + {12291, "\""}, // MA#* ( 〃 → '' ) DITTO MARK → APOSTROPHE, APOSTROPHE# →″→→"→# Converted to a quote. + {12295, "O"}, // MA# ( 〇 → O ) IDEOGRAPHIC NUMBER ZERO → LATIN CAPITAL LETTER O# + {12308, "("}, // MA#* ( 〔 → ( ) LEFT TORTOISE SHELL BRACKET → LEFT PARENTHESIS# + {12309, ")"}, // MA#* ( 〕 → ) ) RIGHT TORTOISE SHELL BRACKET → RIGHT PARENTHESIS# + {12339, "/"}, // MA# ( 〳 → / ) VERTICAL KANA REPEAT MARK UPPER HALF → SOLIDUS# + {12448, "="}, // MA#* ( ゠ → = ) KATAKANA-HIRAGANA DOUBLE HYPHEN → EQUALS SIGN# + {12494, "/"}, // MA# ( ノ → / ) KATAKANA LETTER NO → SOLIDUS# →⼃→ + {12755, "/"}, // MA#* ( ㇓ → / ) CJK STROKE SP → SOLIDUS# →⼃→ + {12756, "\\"}, // MA#* ( ㇔ → \ ) CJK STROKE D → REVERSE SOLIDUS# →⼂→ + {20022, "\\"}, // MA# ( 丶 → \ ) CJK UNIFIED IDEOGRAPH-4E36 → REVERSE SOLIDUS# →⼂→ + {20031, "/"}, // MA# ( 丿 → / ) CJK UNIFIED IDEOGRAPH-4E3F → SOLIDUS# →⼃→ + {42192, "B"}, // MA# ( ꓐ → B ) LISU LETTER BA → LATIN CAPITAL LETTER B# + {42193, "P"}, // MA# ( ꓑ → P ) LISU LETTER PA → LATIN CAPITAL LETTER P# + {42194, "d"}, // MA# ( ꓒ → d ) LISU LETTER PHA → LATIN SMALL LETTER D# + {42195, "D"}, // MA# ( ꓓ → D ) LISU LETTER DA → LATIN CAPITAL LETTER D# + {42196, "T"}, // MA# ( ꓔ → T ) LISU LETTER TA → LATIN CAPITAL LETTER T# + {42198, "G"}, // MA# ( ꓖ → G ) LISU LETTER GA → LATIN CAPITAL LETTER G# + {42199, "K"}, // MA# ( ꓗ → K ) LISU LETTER KA → LATIN CAPITAL LETTER K# + {42201, "J"}, // MA# ( ꓙ → J ) LISU LETTER JA → LATIN CAPITAL LETTER J# + {42202, "C"}, // MA# ( ꓚ → C ) LISU LETTER CA → LATIN CAPITAL LETTER C# + {42204, "Z"}, // MA# ( ꓜ → Z ) LISU LETTER DZA → LATIN CAPITAL LETTER Z# + {42205, "F"}, // MA# ( ꓝ → F ) LISU LETTER TSA → LATIN CAPITAL LETTER F# + {42207, "M"}, // MA# ( ꓟ → M ) LISU LETTER MA → LATIN CAPITAL LETTER M# + {42208, "N"}, // MA# ( ꓠ → N ) LISU LETTER NA → LATIN CAPITAL LETTER N# + {42209, "L"}, // MA# ( ꓡ → L ) LISU LETTER LA → LATIN CAPITAL LETTER L# + {42210, "S"}, // MA# ( ꓢ → S ) LISU LETTER SA → LATIN CAPITAL LETTER S# + {42211, "R"}, // MA# ( ꓣ → R ) LISU LETTER ZHA → LATIN CAPITAL LETTER R# + {42214, "V"}, // MA# ( ꓦ → V ) LISU LETTER HA → LATIN CAPITAL LETTER V# + {42215, "H"}, // MA# ( ꓧ → H ) LISU LETTER XA → LATIN CAPITAL LETTER H# + {42218, "W"}, // MA# ( ꓪ → W ) LISU LETTER WA → LATIN CAPITAL LETTER W# + {42219, "X"}, // MA# ( ꓫ → X ) LISU LETTER SHA → LATIN CAPITAL LETTER X# + {42220, "Y"}, // MA# ( ꓬ → Y ) LISU LETTER YA → LATIN CAPITAL LETTER Y# + {42222, "A"}, // MA# ( ꓮ → A ) LISU LETTER A → LATIN CAPITAL LETTER A# + {42224, "E"}, // MA# ( ꓰ → E ) LISU LETTER E → LATIN CAPITAL LETTER E# + {42226, "l"}, // MA# ( ꓲ → l ) LISU LETTER I → LATIN SMALL LETTER L# →I→ + {42227, "O"}, // MA# ( ꓳ → O ) LISU LETTER O → LATIN CAPITAL LETTER O# + {42228, "U"}, // MA# ( ꓴ → U ) LISU LETTER U → LATIN CAPITAL LETTER U# + {42232, "."}, // MA# ( ꓸ → . ) LISU LETTER TONE MYA TI → FULL STOP# + {42233, ","}, // MA# ( ꓹ → , ) LISU LETTER TONE NA PO → COMMA# + {42234, ".."}, // MA# ( ꓺ → .. ) LISU LETTER TONE MYA CYA → FULL STOP, FULL STOP# + {42235, ".,"}, // MA# ( ꓻ → ., ) LISU LETTER TONE MYA BO → FULL STOP, COMMA# + {42237, ":"}, // MA# ( ꓽ → : ) LISU LETTER TONE MYA JEU → COLON# + {42238, "-."}, // MA#* ( ꓾ → -. ) LISU PUNCTUATION COMMA → HYPHEN-MINUS, FULL STOP# + {42239, "="}, // MA#* ( ꓿ → = ) LISU PUNCTUATION FULL STOP → EQUALS SIGN# + {42510, "."}, // MA#* ( ꘎ → . ) VAI FULL STOP → FULL STOP# + {42564, "2"}, // MA# ( Ꙅ → 2 ) CYRILLIC CAPITAL LETTER REVERSED DZE → DIGIT TWO# →Ƨ→ + {42567, "i"}, // MA# ( ꙇ → i ) CYRILLIC SMALL LETTER IOTA → LATIN SMALL LETTER I# →ι→ + {42648, "OO"}, // MA# ( Ꚙ → OO ) CYRILLIC CAPITAL LETTER DOUBLE O → LATIN CAPITAL LETTER O, LATIN CAPITAL LETTER O# + {42649, "oo"}, // MA# ( ꚙ → oo ) CYRILLIC SMALL LETTER DOUBLE O → LATIN SMALL LETTER O, LATIN SMALL LETTER O# + {42719, "V"}, // MA# ( ꛟ → V ) BAMUM LETTER KO → LATIN CAPITAL LETTER V# + {42731, "?"}, // MA# ( ꛫ → ? ) BAMUM LETTER NTUU → QUESTION MARK# →ʔ→ + {42735, "2"}, // MA# ( ꛯ → 2 ) BAMUM LETTER KOGHOM → DIGIT TWO# →Ƨ→ + {42792, "T3"}, // MA# ( Ꜩ → T3 ) LATIN CAPITAL LETTER TZ → LATIN CAPITAL LETTER T, DIGIT THREE# →TƷ→ + {42801, "s"}, // MA# ( ꜱ → s ) LATIN LETTER SMALL CAPITAL S → LATIN SMALL LETTER S# + {42802, "AA"}, // MA# ( Ꜳ → AA ) LATIN CAPITAL LETTER AA → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER A# + {42803, "aa"}, // MA# ( ꜳ → aa ) LATIN SMALL LETTER AA → LATIN SMALL LETTER A, LATIN SMALL LETTER A# + {42804, "AO"}, // MA# ( Ꜵ → AO ) LATIN CAPITAL LETTER AO → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER O# + {42805, "ao"}, // MA# ( ꜵ → ao ) LATIN SMALL LETTER AO → LATIN SMALL LETTER A, LATIN SMALL LETTER O# + {42806, "AU"}, // MA# ( Ꜷ → AU ) LATIN CAPITAL LETTER AU → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER U# + {42807, "au"}, // MA# ( ꜷ → au ) LATIN SMALL LETTER AU → LATIN SMALL LETTER A, LATIN SMALL LETTER U# + {42808, "AV"}, // MA# ( Ꜹ → AV ) LATIN CAPITAL LETTER AV → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER V# + {42809, "av"}, // MA# ( ꜹ → av ) LATIN SMALL LETTER AV → LATIN SMALL LETTER A, LATIN SMALL LETTER V# + {42810, "AV"}, // MA# ( Ꜻ → AV ) LATIN CAPITAL LETTER AV WITH HORIZONTAL BAR → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER V# + {42811, "av"}, // MA# ( ꜻ → av ) LATIN SMALL LETTER AV WITH HORIZONTAL BAR → LATIN SMALL LETTER A, LATIN SMALL LETTER V# + {42812, "AY"}, // MA# ( Ꜽ → AY ) LATIN CAPITAL LETTER AY → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER Y# + {42813, "ay"}, // MA# ( ꜽ → ay ) LATIN SMALL LETTER AY → LATIN SMALL LETTER A, LATIN SMALL LETTER Y# + {42830, "OO"}, // MA# ( Ꝏ → OO ) LATIN CAPITAL LETTER OO → LATIN CAPITAL LETTER O, LATIN CAPITAL LETTER O# + {42831, "oo"}, // MA# ( ꝏ → oo ) LATIN SMALL LETTER OO → LATIN SMALL LETTER O, LATIN SMALL LETTER O# + {42842, "2"}, // MA# ( Ꝛ → 2 ) LATIN CAPITAL LETTER R ROTUNDA → DIGIT TWO# + {42858, "3"}, // MA# ( Ꝫ → 3 ) LATIN CAPITAL LETTER ET → DIGIT THREE# + {42862, "9"}, // MA# ( Ꝯ → 9 ) LATIN CAPITAL LETTER CON → DIGIT NINE# + {42871, "tf"}, // MA# ( ꝷ → tf ) LATIN SMALL LETTER TUM → LATIN SMALL LETTER T, LATIN SMALL LETTER F# + {42872, "&"}, // MA# ( ꝸ → & ) LATIN SMALL LETTER UM → AMPERSAND# + {42889, ":"}, // MA#* ( ꞉ → : ) MODIFIER LETTER COLON → COLON# + {42892, "'"}, // MA# ( ꞌ → ' ) LATIN SMALL LETTER SALTILLO → APOSTROPHE# + {42904, "F"}, // MA# ( Ꞙ → F ) LATIN CAPITAL LETTER F WITH STROKE → LATIN CAPITAL LETTER F# + {42905, "f"}, // MA# ( ꞙ → f ) LATIN SMALL LETTER F WITH STROKE → LATIN SMALL LETTER F# + {42911, "u"}, // MA# ( ꞟ → u ) LATIN SMALL LETTER VOLAPUK UE → LATIN SMALL LETTER U# + {42923, "3"}, // MA# ( Ɜ → 3 ) LATIN CAPITAL LETTER REVERSED OPEN E → DIGIT THREE# + {42930, "J"}, // MA# ( Ʝ → J ) LATIN CAPITAL LETTER J WITH CROSSED-TAIL → LATIN CAPITAL LETTER J# + {42931, "X"}, // MA# ( Ꭓ → X ) LATIN CAPITAL LETTER CHI → LATIN CAPITAL LETTER X# + {42932, "B"}, // MA# ( Ꞵ → B ) LATIN CAPITAL LETTER BETA → LATIN CAPITAL LETTER B# + {43826, "e"}, // MA# ( ꬲ → e ) LATIN SMALL LETTER BLACKLETTER E → LATIN SMALL LETTER E# + {43829, "f"}, // MA# ( ꬵ → f ) LATIN SMALL LETTER LENIS F → LATIN SMALL LETTER F# + {43837, "o"}, // MA# ( ꬽ → o ) LATIN SMALL LETTER BLACKLETTER O → LATIN SMALL LETTER O# + {43847, "r"}, // MA# ( ꭇ → r ) LATIN SMALL LETTER R WITHOUT HANDLE → LATIN SMALL LETTER R# + {43848, "r"}, // MA# ( ꭈ → r ) LATIN SMALL LETTER DOUBLE R → LATIN SMALL LETTER R# + {43854, "u"}, // MA# ( ꭎ → u ) LATIN SMALL LETTER U WITH SHORT RIGHT LEG → LATIN SMALL LETTER U# + {43858, "u"}, // MA# ( ꭒ → u ) LATIN SMALL LETTER U WITH LEFT HOOK → LATIN SMALL LETTER U# + {43866, "y"}, // MA# ( ꭚ → y ) LATIN SMALL LETTER Y WITH SHORT RIGHT LEG → LATIN SMALL LETTER Y# + {43875, "uo"}, // MA# ( ꭣ → uo ) LATIN SMALL LETTER UO → LATIN SMALL LETTER U, LATIN SMALL LETTER O# + {43893, "i"}, // MA# ( ꭵ → i ) CHEROKEE SMALL LETTER V → LATIN SMALL LETTER I# + {43905, "r"}, // MA# ( ꮁ → r ) CHEROKEE SMALL LETTER HU → LATIN SMALL LETTER R# →ᴦ→→г→ + {43907, "w"}, // MA# ( ꮃ → w ) CHEROKEE SMALL LETTER LA → LATIN SMALL LETTER W# →ᴡ→ + {43923, "z"}, // MA# ( ꮓ → z ) CHEROKEE SMALL LETTER NO → LATIN SMALL LETTER Z# →ᴢ→ + {43945, "v"}, // MA# ( ꮩ → v ) CHEROKEE SMALL LETTER DO → LATIN SMALL LETTER V# →ᴠ→ + {43946, "s"}, // MA# ( ꮪ → s ) CHEROKEE SMALL LETTER DU → LATIN SMALL LETTER S# →ꜱ→ + {43951, "c"}, // MA# ( ꮯ → c ) CHEROKEE SMALL LETTER TLI → LATIN SMALL LETTER C# →ᴄ→ + {64256, "ff"}, // MA# ( ff → ff ) LATIN SMALL LIGATURE FF → LATIN SMALL LETTER F, LATIN SMALL LETTER F# + {64257, "fi"}, // MA# ( fi → fi ) LATIN SMALL LIGATURE FI → LATIN SMALL LETTER F, LATIN SMALL LETTER I# + {64258, "fl"}, // MA# ( fl → fl ) LATIN SMALL LIGATURE FL → LATIN SMALL LETTER F, LATIN SMALL LETTER L# + {64259, "ffi"}, // MA# ( ffi → ffi ) LATIN SMALL LIGATURE FFI → LATIN SMALL LETTER F, LATIN SMALL LETTER F, LATIN SMALL LETTER I# + {64260, "ffl"}, // MA# ( ffl → ffl ) LATIN SMALL LIGATURE FFL → LATIN SMALL LETTER F, LATIN SMALL LETTER F, LATIN SMALL LETTER L# + {64262, "st"}, // MA# ( st → st ) LATIN SMALL LIGATURE ST → LATIN SMALL LETTER S, LATIN SMALL LETTER T# + {64422, "o"}, // MA# ( ‎ﮦ‎ → o ) ARABIC LETTER HEH GOAL ISOLATED FORM → LATIN SMALL LETTER O# →‎ه‎→ + {64423, "o"}, // MA# ( ‎ﮧ‎ → o ) ARABIC LETTER HEH GOAL FINAL FORM → LATIN SMALL LETTER O# →‎ہ‎→→‎ه‎→ + {64424, "o"}, // MA# ( ‎ﮨ‎ → o ) ARABIC LETTER HEH GOAL INITIAL FORM → LATIN SMALL LETTER O# →‎ہ‎→→‎ه‎→ + {64425, "o"}, // MA# ( ‎ﮩ‎ → o ) ARABIC LETTER HEH GOAL MEDIAL FORM → LATIN SMALL LETTER O# →‎ہ‎→→‎ه‎→ + {64426, "o"}, // MA# ( ‎ﮪ‎ → o ) ARABIC LETTER HEH DOACHASHMEE ISOLATED FORM → LATIN SMALL LETTER O# →‎ه‎→ + {64427, "o"}, // MA# ( ‎ﮫ‎ → o ) ARABIC LETTER HEH DOACHASHMEE FINAL FORM → LATIN SMALL LETTER O# →‎ﻪ‎→→‎ه‎→ + {64428, "o"}, // MA# ( ‎ﮬ‎ → o ) ARABIC LETTER HEH DOACHASHMEE INITIAL FORM → LATIN SMALL LETTER O# →‎ﻫ‎→→‎ه‎→ + {64429, "o"}, // MA# ( ‎ﮭ‎ → o ) ARABIC LETTER HEH DOACHASHMEE MEDIAL FORM → LATIN SMALL LETTER O# →‎ﻬ‎→→‎ه‎→ + {64830, "("}, // MA#* ( ﴾ → ( ) ORNATE LEFT PARENTHESIS → LEFT PARENTHESIS# + {64831, ")"}, // MA#* ( ﴿ → ) ) ORNATE RIGHT PARENTHESIS → RIGHT PARENTHESIS# + {65072, ":"}, // MA#* ( ︰ → : ) PRESENTATION FORM FOR VERTICAL TWO DOT LEADER → COLON# + {65101, "_"}, // MA# ( ﹍ → _ ) DASHED LOW LINE → LOW LINE# + {65102, "_"}, // MA# ( ﹎ → _ ) CENTRELINE LOW LINE → LOW LINE# + {65103, "_"}, // MA# ( ﹏ → _ ) WAVY LOW LINE → LOW LINE# + {65112, "-"}, // MA#* ( ﹘ → - ) SMALL EM DASH → HYPHEN-MINUS# + {65128, "\\"}, // MA#* ( ﹨ → \ ) SMALL REVERSE SOLIDUS → REVERSE SOLIDUS# →∖→ + {65165, "l"}, // MA# ( ‎ﺍ‎ → l ) ARABIC LETTER ALEF ISOLATED FORM → LATIN SMALL LETTER L# →‎ا‎→→1→ + {65166, "l"}, // MA# ( ‎ﺎ‎ → l ) ARABIC LETTER ALEF FINAL FORM → LATIN SMALL LETTER L# →‎ا‎→→1→ + {65257, "o"}, // MA# ( ‎ﻩ‎ → o ) ARABIC LETTER HEH ISOLATED FORM → LATIN SMALL LETTER O# →‎ه‎→ + {65258, "o"}, // MA# ( ‎ﻪ‎ → o ) ARABIC LETTER HEH FINAL FORM → LATIN SMALL LETTER O# →‎ه‎→ + {65259, "o"}, // MA# ( ‎ﻫ‎ → o ) ARABIC LETTER HEH INITIAL FORM → LATIN SMALL LETTER O# →‎ه‎→ + {65260, "o"}, // MA# ( ‎ﻬ‎ → o ) ARABIC LETTER HEH MEDIAL FORM → LATIN SMALL LETTER O# →‎ه‎→ + {65281, "!"}, // MA#* ( ! → ! ) FULLWIDTH EXCLAMATION MARK → EXCLAMATION MARK# →ǃ→ + {65282, "\""}, // MA#* ( " → '' ) FULLWIDTH QUOTATION MARK → APOSTROPHE, APOSTROPHE# →”→→"→# Converted to a quote. + {65287, "'"}, // MA#* ( ' → ' ) FULLWIDTH APOSTROPHE → APOSTROPHE# →’→ + {65306, ":"}, // MA#* ( : → : ) FULLWIDTH COLON → COLON# →︰→ + {65313, "A"}, // MA# ( A → A ) FULLWIDTH LATIN CAPITAL LETTER A → LATIN CAPITAL LETTER A# →А→ + {65314, "B"}, // MA# ( B → B ) FULLWIDTH LATIN CAPITAL LETTER B → LATIN CAPITAL LETTER B# →Β→ + {65315, "C"}, // MA# ( C → C ) FULLWIDTH LATIN CAPITAL LETTER C → LATIN CAPITAL LETTER C# →С→ + {65317, "E"}, // MA# ( E → E ) FULLWIDTH LATIN CAPITAL LETTER E → LATIN CAPITAL LETTER E# →Ε→ + {65320, "H"}, // MA# ( H → H ) FULLWIDTH LATIN CAPITAL LETTER H → LATIN CAPITAL LETTER H# →Η→ + {65321, "l"}, // MA# ( I → l ) FULLWIDTH LATIN CAPITAL LETTER I → LATIN SMALL LETTER L# →Ӏ→ + {65322, "J"}, // MA# ( J → J ) FULLWIDTH LATIN CAPITAL LETTER J → LATIN CAPITAL LETTER J# →Ј→ + {65323, "K"}, // MA# ( K → K ) FULLWIDTH LATIN CAPITAL LETTER K → LATIN CAPITAL LETTER K# →Κ→ + {65325, "M"}, // MA# ( M → M ) FULLWIDTH LATIN CAPITAL LETTER M → LATIN CAPITAL LETTER M# →Μ→ + {65326, "N"}, // MA# ( N → N ) FULLWIDTH LATIN CAPITAL LETTER N → LATIN CAPITAL LETTER N# →Ν→ + {65327, "O"}, // MA# ( O → O ) FULLWIDTH LATIN CAPITAL LETTER O → LATIN CAPITAL LETTER O# →О→ + {65328, "P"}, // MA# ( P → P ) FULLWIDTH LATIN CAPITAL LETTER P → LATIN CAPITAL LETTER P# →Р→ + {65331, "S"}, // MA# ( S → S ) FULLWIDTH LATIN CAPITAL LETTER S → LATIN CAPITAL LETTER S# →Ѕ→ + {65332, "T"}, // MA# ( T → T ) FULLWIDTH LATIN CAPITAL LETTER T → LATIN CAPITAL LETTER T# →Т→ + {65336, "X"}, // MA# ( X → X ) FULLWIDTH LATIN CAPITAL LETTER X → LATIN CAPITAL LETTER X# →Х→ + {65337, "Y"}, // MA# ( Y → Y ) FULLWIDTH LATIN CAPITAL LETTER Y → LATIN CAPITAL LETTER Y# →Υ→ + {65338, "Z"}, // MA# ( Z → Z ) FULLWIDTH LATIN CAPITAL LETTER Z → LATIN CAPITAL LETTER Z# →Ζ→ + {65339, "("}, // MA#* ( [ → ( ) FULLWIDTH LEFT SQUARE BRACKET → LEFT PARENTHESIS# →〔→ + {65340, "\\"}, // MA#* ( \ → \ ) FULLWIDTH REVERSE SOLIDUS → REVERSE SOLIDUS# →∖→ + {65341, ")"}, // MA#* ( ] → ) ) FULLWIDTH RIGHT SQUARE BRACKET → RIGHT PARENTHESIS# →〕→ + {65344, "'"}, // MA#* ( ` → ' ) FULLWIDTH GRAVE ACCENT → APOSTROPHE# →‘→ + {65345, "a"}, // MA# ( a → a ) FULLWIDTH LATIN SMALL LETTER A → LATIN SMALL LETTER A# →а→ + {65347, "c"}, // MA# ( c → c ) FULLWIDTH LATIN SMALL LETTER C → LATIN SMALL LETTER C# →с→ + {65349, "e"}, // MA# ( e → e ) FULLWIDTH LATIN SMALL LETTER E → LATIN SMALL LETTER E# →е→ + {65351, "g"}, // MA# ( g → g ) FULLWIDTH LATIN SMALL LETTER G → LATIN SMALL LETTER G# →ɡ→ + {65352, "h"}, // MA# ( h → h ) FULLWIDTH LATIN SMALL LETTER H → LATIN SMALL LETTER H# →һ→ + {65353, "i"}, // MA# ( i → i ) FULLWIDTH LATIN SMALL LETTER I → LATIN SMALL LETTER I# →і→ + {65354, "j"}, // MA# ( j → j ) FULLWIDTH LATIN SMALL LETTER J → LATIN SMALL LETTER J# →ϳ→ + {65356, "l"}, // MA# ( l → l ) FULLWIDTH LATIN SMALL LETTER L → LATIN SMALL LETTER L# →Ⅰ→→Ӏ→ + {65359, "o"}, // MA# ( o → o ) FULLWIDTH LATIN SMALL LETTER O → LATIN SMALL LETTER O# →о→ + {65360, "p"}, // MA# ( p → p ) FULLWIDTH LATIN SMALL LETTER P → LATIN SMALL LETTER P# →р→ + {65363, "s"}, // MA# ( s → s ) FULLWIDTH LATIN SMALL LETTER S → LATIN SMALL LETTER S# →ѕ→ + {65366, "v"}, // MA# ( v → v ) FULLWIDTH LATIN SMALL LETTER V → LATIN SMALL LETTER V# →ν→ + {65368, "x"}, // MA# ( x → x ) FULLWIDTH LATIN SMALL LETTER X → LATIN SMALL LETTER X# →х→ + {65369, "y"}, // MA# ( y → y ) FULLWIDTH LATIN SMALL LETTER Y → LATIN SMALL LETTER Y# →у→ + {65512, "l"}, // MA#* ( │ → l ) HALFWIDTH FORMS LIGHT VERTICAL → LATIN SMALL LETTER L# →|→ + {66178, "B"}, // MA# ( 𐊂 → B ) LYCIAN LETTER B → LATIN CAPITAL LETTER B# + {66182, "E"}, // MA# ( 𐊆 → E ) LYCIAN LETTER I → LATIN CAPITAL LETTER E# + {66183, "F"}, // MA# ( 𐊇 → F ) LYCIAN LETTER W → LATIN CAPITAL LETTER F# + {66186, "l"}, // MA# ( 𐊊 → l ) LYCIAN LETTER J → LATIN SMALL LETTER L# →I→ + {66192, "X"}, // MA# ( 𐊐 → X ) LYCIAN LETTER MM → LATIN CAPITAL LETTER X# + {66194, "O"}, // MA# ( 𐊒 → O ) LYCIAN LETTER U → LATIN CAPITAL LETTER O# + {66197, "P"}, // MA# ( 𐊕 → P ) LYCIAN LETTER R → LATIN CAPITAL LETTER P# + {66198, "S"}, // MA# ( 𐊖 → S ) LYCIAN LETTER S → LATIN CAPITAL LETTER S# + {66199, "T"}, // MA# ( 𐊗 → T ) LYCIAN LETTER T → LATIN CAPITAL LETTER T# + {66203, "+"}, // MA# ( 𐊛 → + ) LYCIAN LETTER H → PLUS SIGN# + {66208, "A"}, // MA# ( 𐊠 → A ) CARIAN LETTER A → LATIN CAPITAL LETTER A# + {66209, "B"}, // MA# ( 𐊡 → B ) CARIAN LETTER P2 → LATIN CAPITAL LETTER B# + {66210, "C"}, // MA# ( 𐊢 → C ) CARIAN LETTER D → LATIN CAPITAL LETTER C# + {66213, "F"}, // MA# ( 𐊥 → F ) CARIAN LETTER R → LATIN CAPITAL LETTER F# + {66219, "O"}, // MA# ( 𐊫 → O ) CARIAN LETTER O → LATIN CAPITAL LETTER O# + {66224, "M"}, // MA# ( 𐊰 → M ) CARIAN LETTER S → LATIN CAPITAL LETTER M# + {66225, "T"}, // MA# ( 𐊱 → T ) CARIAN LETTER C-18 → LATIN CAPITAL LETTER T# + {66226, "Y"}, // MA# ( 𐊲 → Y ) CARIAN LETTER U → LATIN CAPITAL LETTER Y# + {66228, "X"}, // MA# ( 𐊴 → X ) CARIAN LETTER X → LATIN CAPITAL LETTER X# + {66255, "H"}, // MA# ( 𐋏 → H ) CARIAN LETTER E2 → LATIN CAPITAL LETTER H# + {66293, "Z"}, // MA#* ( 𐋵 → Z ) COPTIC EPACT NUMBER THREE HUNDRED → LATIN CAPITAL LETTER Z# + {66305, "B"}, // MA# ( 𐌁 → B ) OLD ITALIC LETTER BE → LATIN CAPITAL LETTER B# + {66306, "C"}, // MA# ( 𐌂 → C ) OLD ITALIC LETTER KE → LATIN CAPITAL LETTER C# + {66313, "l"}, // MA# ( 𐌉 → l ) OLD ITALIC LETTER I → LATIN SMALL LETTER L# →I→ + {66321, "M"}, // MA# ( 𐌑 → M ) OLD ITALIC LETTER SHE → LATIN CAPITAL LETTER M# + {66325, "T"}, // MA# ( 𐌕 → T ) OLD ITALIC LETTER TE → LATIN CAPITAL LETTER T# + {66327, "X"}, // MA# ( 𐌗 → X ) OLD ITALIC LETTER EKS → LATIN CAPITAL LETTER X# + {66330, "8"}, // MA# ( 𐌚 → 8 ) OLD ITALIC LETTER EF → DIGIT EIGHT# + {66335, "*"}, // MA# ( 𐌟 → * ) OLD ITALIC LETTER ESS → ASTERISK# + {66336, "l"}, // MA#* ( 𐌠 → l ) OLD ITALIC NUMERAL ONE → LATIN SMALL LETTER L# →𐌉→→I→ + {66338, "X"}, // MA#* ( 𐌢 → X ) OLD ITALIC NUMERAL TEN → LATIN CAPITAL LETTER X# →𐌗→ + {66564, "O"}, // MA# ( 𐐄 → O ) DESERET CAPITAL LETTER LONG O → LATIN CAPITAL LETTER O# + {66581, "C"}, // MA# ( 𐐕 → C ) DESERET CAPITAL LETTER CHEE → LATIN CAPITAL LETTER C# + {66587, "L"}, // MA# ( 𐐛 → L ) DESERET CAPITAL LETTER ETH → LATIN CAPITAL LETTER L# + {66592, "S"}, // MA# ( 𐐠 → S ) DESERET CAPITAL LETTER ZHEE → LATIN CAPITAL LETTER S# + {66604, "o"}, // MA# ( 𐐬 → o ) DESERET SMALL LETTER LONG O → LATIN SMALL LETTER O# + {66621, "c"}, // MA# ( 𐐽 → c ) DESERET SMALL LETTER CHEE → LATIN SMALL LETTER C# + {66632, "s"}, // MA# ( 𐑈 → s ) DESERET SMALL LETTER ZHEE → LATIN SMALL LETTER S# + {66740, "R"}, // MA# ( 𐒴 → R ) OSAGE CAPITAL LETTER BRA → LATIN CAPITAL LETTER R# →Ʀ→ + {66754, "O"}, // MA# ( 𐓂 → O ) OSAGE CAPITAL LETTER O → LATIN CAPITAL LETTER O# + {66766, "U"}, // MA# ( 𐓎 → U ) OSAGE CAPITAL LETTER U → LATIN CAPITAL LETTER U# + {66770, "7"}, // MA# ( 𐓒 → 7 ) OSAGE CAPITAL LETTER ZA → DIGIT SEVEN# + {66794, "o"}, // MA# ( 𐓪 → o ) OSAGE SMALL LETTER O → LATIN SMALL LETTER O# + {66806, "u"}, // MA# ( 𐓶 → u ) OSAGE SMALL LETTER U → LATIN SMALL LETTER U# →ᴜ→ + {66835, "N"}, // MA# ( 𐔓 → N ) ELBASAN LETTER NE → LATIN CAPITAL LETTER N# + {66838, "O"}, // MA# ( 𐔖 → O ) ELBASAN LETTER O → LATIN CAPITAL LETTER O# + {66840, "K"}, // MA# ( 𐔘 → K ) ELBASAN LETTER QE → LATIN CAPITAL LETTER K# + {66844, "C"}, // MA# ( 𐔜 → C ) ELBASAN LETTER SHE → LATIN CAPITAL LETTER C# + {66845, "V"}, // MA# ( 𐔝 → V ) ELBASAN LETTER TE → LATIN CAPITAL LETTER V# + {66853, "F"}, // MA# ( 𐔥 → F ) ELBASAN LETTER GHE → LATIN CAPITAL LETTER F# + {66854, "L"}, // MA# ( 𐔦 → L ) ELBASAN LETTER GHAMMA → LATIN CAPITAL LETTER L# + {66855, "X"}, // MA# ( 𐔧 → X ) ELBASAN LETTER KHE → LATIN CAPITAL LETTER X# + {68176, "."}, // MA#* ( ‎𐩐‎ → . ) KHAROSHTHI PUNCTUATION DOT → FULL STOP# + {70864, "O"}, // MA# ( 𑓐 → O ) TIRHUTA DIGIT ZERO → LATIN CAPITAL LETTER O# →০→→0→ + {71424, "rn"}, // MA# ( 𑜀 → rn ) AHOM LETTER KA → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {71430, "v"}, // MA# ( 𑜆 → v ) AHOM LETTER PA → LATIN SMALL LETTER V# + {71434, "w"}, // MA# ( 𑜊 → w ) AHOM LETTER JA → LATIN SMALL LETTER W# + {71438, "w"}, // MA# ( 𑜎 → w ) AHOM LETTER LA → LATIN SMALL LETTER W# + {71439, "w"}, // MA# ( 𑜏 → w ) AHOM LETTER SA → LATIN SMALL LETTER W# + {71840, "V"}, // MA# ( 𑢠 → V ) WARANG CITI CAPITAL LETTER NGAA → LATIN CAPITAL LETTER V# + {71842, "F"}, // MA# ( 𑢢 → F ) WARANG CITI CAPITAL LETTER WI → LATIN CAPITAL LETTER F# + {71843, "L"}, // MA# ( 𑢣 → L ) WARANG CITI CAPITAL LETTER YU → LATIN CAPITAL LETTER L# + {71844, "Y"}, // MA# ( 𑢤 → Y ) WARANG CITI CAPITAL LETTER YA → LATIN CAPITAL LETTER Y# + {71846, "E"}, // MA# ( 𑢦 → E ) WARANG CITI CAPITAL LETTER II → LATIN CAPITAL LETTER E# + {71849, "Z"}, // MA# ( 𑢩 → Z ) WARANG CITI CAPITAL LETTER O → LATIN CAPITAL LETTER Z# + {71852, "9"}, // MA# ( 𑢬 → 9 ) WARANG CITI CAPITAL LETTER KO → DIGIT NINE# + {71854, "E"}, // MA# ( 𑢮 → E ) WARANG CITI CAPITAL LETTER YUJ → LATIN CAPITAL LETTER E# + {71855, "4"}, // MA# ( 𑢯 → 4 ) WARANG CITI CAPITAL LETTER UC → DIGIT FOUR# + {71858, "L"}, // MA# ( 𑢲 → L ) WARANG CITI CAPITAL LETTER TTE → LATIN CAPITAL LETTER L# + {71861, "O"}, // MA# ( 𑢵 → O ) WARANG CITI CAPITAL LETTER AT → LATIN CAPITAL LETTER O# + {71864, "U"}, // MA# ( 𑢸 → U ) WARANG CITI CAPITAL LETTER PU → LATIN CAPITAL LETTER U# + {71867, "5"}, // MA# ( 𑢻 → 5 ) WARANG CITI CAPITAL LETTER HORR → DIGIT FIVE# + {71868, "T"}, // MA# ( 𑢼 → T ) WARANG CITI CAPITAL LETTER HAR → LATIN CAPITAL LETTER T# + {71872, "v"}, // MA# ( 𑣀 → v ) WARANG CITI SMALL LETTER NGAA → LATIN SMALL LETTER V# + {71873, "s"}, // MA# ( 𑣁 → s ) WARANG CITI SMALL LETTER A → LATIN SMALL LETTER S# + {71874, "F"}, // MA# ( 𑣂 → F ) WARANG CITI SMALL LETTER WI → LATIN CAPITAL LETTER F# + {71875, "i"}, // MA# ( 𑣃 → i ) WARANG CITI SMALL LETTER YU → LATIN SMALL LETTER I# →ι→ + {71876, "z"}, // MA# ( 𑣄 → z ) WARANG CITI SMALL LETTER YA → LATIN SMALL LETTER Z# + {71878, "7"}, // MA# ( 𑣆 → 7 ) WARANG CITI SMALL LETTER II → DIGIT SEVEN# + {71880, "o"}, // MA# ( 𑣈 → o ) WARANG CITI SMALL LETTER E → LATIN SMALL LETTER O# + {71882, "3"}, // MA# ( 𑣊 → 3 ) WARANG CITI SMALL LETTER ANG → DIGIT THREE# + {71884, "9"}, // MA# ( 𑣌 → 9 ) WARANG CITI SMALL LETTER KO → DIGIT NINE# + {71893, "6"}, // MA# ( 𑣕 → 6 ) WARANG CITI SMALL LETTER AT → DIGIT SIX# + {71894, "9"}, // MA# ( 𑣖 → 9 ) WARANG CITI SMALL LETTER AM → DIGIT NINE# + {71895, "o"}, // MA# ( 𑣗 → o ) WARANG CITI SMALL LETTER BU → LATIN SMALL LETTER O# + {71896, "u"}, // MA# ( 𑣘 → u ) WARANG CITI SMALL LETTER PU → LATIN SMALL LETTER U# →υ→→ʋ→ + {71900, "y"}, // MA# ( 𑣜 → y ) WARANG CITI SMALL LETTER HAR → LATIN SMALL LETTER Y# →ɣ→→γ→ + {71904, "O"}, // MA# ( 𑣠 → O ) WARANG CITI DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {71907, "rn"}, // MA# ( 𑣣 → rn ) WARANG CITI DIGIT THREE → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {71909, "Z"}, // MA# ( 𑣥 → Z ) WARANG CITI DIGIT FIVE → LATIN CAPITAL LETTER Z# + {71910, "W"}, // MA# ( 𑣦 → W ) WARANG CITI DIGIT SIX → LATIN CAPITAL LETTER W# + {71913, "C"}, // MA# ( 𑣩 → C ) WARANG CITI DIGIT NINE → LATIN CAPITAL LETTER C# + {71916, "X"}, // MA#* ( 𑣬 → X ) WARANG CITI NUMBER THIRTY → LATIN CAPITAL LETTER X# + {71919, "W"}, // MA#* ( 𑣯 → W ) WARANG CITI NUMBER SIXTY → LATIN CAPITAL LETTER W# + {71922, "C"}, // MA#* ( 𑣲 → C ) WARANG CITI NUMBER NINETY → LATIN CAPITAL LETTER C# + {93960, "V"}, // MA# ( 𖼈 → V ) MIAO LETTER VA → LATIN CAPITAL LETTER V# + {93962, "T"}, // MA# ( 𖼊 → T ) MIAO LETTER TA → LATIN CAPITAL LETTER T# + {93974, "L"}, // MA# ( 𖼖 → L ) MIAO LETTER LA → LATIN CAPITAL LETTER L# + {93992, "l"}, // MA# ( 𖼨 → l ) MIAO LETTER GHA → LATIN SMALL LETTER L# →I→ + {94005, "R"}, // MA# ( 𖼵 → R ) MIAO LETTER ZHA → LATIN CAPITAL LETTER R# + {94010, "S"}, // MA# ( 𖼺 → S ) MIAO LETTER SA → LATIN CAPITAL LETTER S# + {94011, "3"}, // MA# ( 𖼻 → 3 ) MIAO LETTER ZA → DIGIT THREE# →Ʒ→ + {94015, ">"}, // MA# ( 𖼿 → > ) MIAO LETTER ARCHAIC ZZA → GREATER-THAN SIGN# + {94016, "A"}, // MA# ( 𖽀 → A ) MIAO LETTER ZZYA → LATIN CAPITAL LETTER A# + {94018, "U"}, // MA# ( 𖽂 → U ) MIAO LETTER WA → LATIN CAPITAL LETTER U# + {94019, "Y"}, // MA# ( 𖽃 → Y ) MIAO LETTER AH → LATIN CAPITAL LETTER Y# + {94033, "'"}, // MA# ( 𖽑 → ' ) MIAO SIGN ASPIRATION → APOSTROPHE# →ʼ→→′→ + {94034, "'"}, // MA# ( 𖽒 → ' ) MIAO SIGN REFORMED VOICING → APOSTROPHE# →ʻ→→‘→ + {119060, "{"}, // MA#* ( 𝄔 → { ) MUSICAL SYMBOL BRACE → LEFT CURLY BRACKET# + {119149, "."}, // MA# ( 𝅭 → . ) MUSICAL SYMBOL COMBINING AUGMENTATION DOT → FULL STOP# + {119302, "3"}, // MA#* ( 𝈆 → 3 ) GREEK VOCAL NOTATION SYMBOL-7 → DIGIT THREE# + {119309, "V"}, // MA#* ( 𝈍 → V ) GREEK VOCAL NOTATION SYMBOL-14 → LATIN CAPITAL LETTER V# + {119311, "\\"}, // MA#* ( 𝈏 → \ ) GREEK VOCAL NOTATION SYMBOL-16 → REVERSE SOLIDUS# + {119314, "7"}, // MA#* ( 𝈒 → 7 ) GREEK VOCAL NOTATION SYMBOL-19 → DIGIT SEVEN# + {119315, "F"}, // MA#* ( 𝈓 → F ) GREEK VOCAL NOTATION SYMBOL-20 → LATIN CAPITAL LETTER F# →Ϝ→ + {119318, "R"}, // MA#* ( 𝈖 → R ) GREEK VOCAL NOTATION SYMBOL-23 → LATIN CAPITAL LETTER R# + {119338, "L"}, // MA#* ( 𝈪 → L ) GREEK INSTRUMENTAL NOTATION SYMBOL-23 → LATIN CAPITAL LETTER L# + {119350, "<"}, // MA#* ( 𝈶 → < ) GREEK INSTRUMENTAL NOTATION SYMBOL-40 → LESS-THAN SIGN# + {119351, ">"}, // MA#* ( 𝈷 → > ) GREEK INSTRUMENTAL NOTATION SYMBOL-42 → GREATER-THAN SIGN# + {119354, "/"}, // MA#* ( 𝈺 → / ) GREEK INSTRUMENTAL NOTATION SYMBOL-47 → SOLIDUS# + {119355, "\\"}, // MA#* ( 𝈻 → \ ) GREEK INSTRUMENTAL NOTATION SYMBOL-48 → REVERSE SOLIDUS# →𝈏→ + {119808, "A"}, // MA# ( 𝐀 → A ) MATHEMATICAL BOLD CAPITAL A → LATIN CAPITAL LETTER A# + {119809, "B"}, // MA# ( 𝐁 → B ) MATHEMATICAL BOLD CAPITAL B → LATIN CAPITAL LETTER B# + {119810, "C"}, // MA# ( 𝐂 → C ) MATHEMATICAL BOLD CAPITAL C → LATIN CAPITAL LETTER C# + {119811, "D"}, // MA# ( 𝐃 → D ) MATHEMATICAL BOLD CAPITAL D → LATIN CAPITAL LETTER D# + {119812, "E"}, // MA# ( 𝐄 → E ) MATHEMATICAL BOLD CAPITAL E → LATIN CAPITAL LETTER E# + {119813, "F"}, // MA# ( 𝐅 → F ) MATHEMATICAL BOLD CAPITAL F → LATIN CAPITAL LETTER F# + {119814, "G"}, // MA# ( 𝐆 → G ) MATHEMATICAL BOLD CAPITAL G → LATIN CAPITAL LETTER G# + {119815, "H"}, // MA# ( 𝐇 → H ) MATHEMATICAL BOLD CAPITAL H → LATIN CAPITAL LETTER H# + {119816, "l"}, // MA# ( 𝐈 → l ) MATHEMATICAL BOLD CAPITAL I → LATIN SMALL LETTER L# →I→ + {119817, "J"}, // MA# ( 𝐉 → J ) MATHEMATICAL BOLD CAPITAL J → LATIN CAPITAL LETTER J# + {119818, "K"}, // MA# ( 𝐊 → K ) MATHEMATICAL BOLD CAPITAL K → LATIN CAPITAL LETTER K# + {119819, "L"}, // MA# ( 𝐋 → L ) MATHEMATICAL BOLD CAPITAL L → LATIN CAPITAL LETTER L# + {119820, "M"}, // MA# ( 𝐌 → M ) MATHEMATICAL BOLD CAPITAL M → LATIN CAPITAL LETTER M# + {119821, "N"}, // MA# ( 𝐍 → N ) MATHEMATICAL BOLD CAPITAL N → LATIN CAPITAL LETTER N# + {119822, "O"}, // MA# ( 𝐎 → O ) MATHEMATICAL BOLD CAPITAL O → LATIN CAPITAL LETTER O# + {119823, "P"}, // MA# ( 𝐏 → P ) MATHEMATICAL BOLD CAPITAL P → LATIN CAPITAL LETTER P# + {119824, "Q"}, // MA# ( 𝐐 → Q ) MATHEMATICAL BOLD CAPITAL Q → LATIN CAPITAL LETTER Q# + {119825, "R"}, // MA# ( 𝐑 → R ) MATHEMATICAL BOLD CAPITAL R → LATIN CAPITAL LETTER R# + {119826, "S"}, // MA# ( 𝐒 → S ) MATHEMATICAL BOLD CAPITAL S → LATIN CAPITAL LETTER S# + {119827, "T"}, // MA# ( 𝐓 → T ) MATHEMATICAL BOLD CAPITAL T → LATIN CAPITAL LETTER T# + {119828, "U"}, // MA# ( 𝐔 → U ) MATHEMATICAL BOLD CAPITAL U → LATIN CAPITAL LETTER U# + {119829, "V"}, // MA# ( 𝐕 → V ) MATHEMATICAL BOLD CAPITAL V → LATIN CAPITAL LETTER V# + {119830, "W"}, // MA# ( 𝐖 → W ) MATHEMATICAL BOLD CAPITAL W → LATIN CAPITAL LETTER W# + {119831, "X"}, // MA# ( 𝐗 → X ) MATHEMATICAL BOLD CAPITAL X → LATIN CAPITAL LETTER X# + {119832, "Y"}, // MA# ( 𝐘 → Y ) MATHEMATICAL BOLD CAPITAL Y → LATIN CAPITAL LETTER Y# + {119833, "Z"}, // MA# ( 𝐙 → Z ) MATHEMATICAL BOLD CAPITAL Z → LATIN CAPITAL LETTER Z# + {119834, "a"}, // MA# ( 𝐚 → a ) MATHEMATICAL BOLD SMALL A → LATIN SMALL LETTER A# + {119835, "b"}, // MA# ( 𝐛 → b ) MATHEMATICAL BOLD SMALL B → LATIN SMALL LETTER B# + {119836, "c"}, // MA# ( 𝐜 → c ) MATHEMATICAL BOLD SMALL C → LATIN SMALL LETTER C# + {119837, "d"}, // MA# ( 𝐝 → d ) MATHEMATICAL BOLD SMALL D → LATIN SMALL LETTER D# + {119838, "e"}, // MA# ( 𝐞 → e ) MATHEMATICAL BOLD SMALL E → LATIN SMALL LETTER E# + {119839, "f"}, // MA# ( 𝐟 → f ) MATHEMATICAL BOLD SMALL F → LATIN SMALL LETTER F# + {119840, "g"}, // MA# ( 𝐠 → g ) MATHEMATICAL BOLD SMALL G → LATIN SMALL LETTER G# + {119841, "h"}, // MA# ( 𝐡 → h ) MATHEMATICAL BOLD SMALL H → LATIN SMALL LETTER H# + {119842, "i"}, // MA# ( 𝐢 → i ) MATHEMATICAL BOLD SMALL I → LATIN SMALL LETTER I# + {119843, "j"}, // MA# ( 𝐣 → j ) MATHEMATICAL BOLD SMALL J → LATIN SMALL LETTER J# + {119844, "k"}, // MA# ( 𝐤 → k ) MATHEMATICAL BOLD SMALL K → LATIN SMALL LETTER K# + {119845, "l"}, // MA# ( 𝐥 → l ) MATHEMATICAL BOLD SMALL L → LATIN SMALL LETTER L# + {119846, "rn"}, // MA# ( 𝐦 → rn ) MATHEMATICAL BOLD SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {119847, "n"}, // MA# ( 𝐧 → n ) MATHEMATICAL BOLD SMALL N → LATIN SMALL LETTER N# + {119848, "o"}, // MA# ( 𝐨 → o ) MATHEMATICAL BOLD SMALL O → LATIN SMALL LETTER O# + {119849, "p"}, // MA# ( 𝐩 → p ) MATHEMATICAL BOLD SMALL P → LATIN SMALL LETTER P# + {119850, "q"}, // MA# ( 𝐪 → q ) MATHEMATICAL BOLD SMALL Q → LATIN SMALL LETTER Q# + {119851, "r"}, // MA# ( 𝐫 → r ) MATHEMATICAL BOLD SMALL R → LATIN SMALL LETTER R# + {119852, "s"}, // MA# ( 𝐬 → s ) MATHEMATICAL BOLD SMALL S → LATIN SMALL LETTER S# + {119853, "t"}, // MA# ( 𝐭 → t ) MATHEMATICAL BOLD SMALL T → LATIN SMALL LETTER T# + {119854, "u"}, // MA# ( 𝐮 → u ) MATHEMATICAL BOLD SMALL U → LATIN SMALL LETTER U# + {119855, "v"}, // MA# ( 𝐯 → v ) MATHEMATICAL BOLD SMALL V → LATIN SMALL LETTER V# + {119856, "w"}, // MA# ( 𝐰 → w ) MATHEMATICAL BOLD SMALL W → LATIN SMALL LETTER W# + {119857, "x"}, // MA# ( 𝐱 → x ) MATHEMATICAL BOLD SMALL X → LATIN SMALL LETTER X# + {119858, "y"}, // MA# ( 𝐲 → y ) MATHEMATICAL BOLD SMALL Y → LATIN SMALL LETTER Y# + {119859, "z"}, // MA# ( 𝐳 → z ) MATHEMATICAL BOLD SMALL Z → LATIN SMALL LETTER Z# + {119860, "A"}, // MA# ( 𝐴 → A ) MATHEMATICAL ITALIC CAPITAL A → LATIN CAPITAL LETTER A# + {119861, "B"}, // MA# ( 𝐵 → B ) MATHEMATICAL ITALIC CAPITAL B → LATIN CAPITAL LETTER B# + {119862, "C"}, // MA# ( 𝐶 → C ) MATHEMATICAL ITALIC CAPITAL C → LATIN CAPITAL LETTER C# + {119863, "D"}, // MA# ( 𝐷 → D ) MATHEMATICAL ITALIC CAPITAL D → LATIN CAPITAL LETTER D# + {119864, "E"}, // MA# ( 𝐸 → E ) MATHEMATICAL ITALIC CAPITAL E → LATIN CAPITAL LETTER E# + {119865, "F"}, // MA# ( 𝐹 → F ) MATHEMATICAL ITALIC CAPITAL F → LATIN CAPITAL LETTER F# + {119866, "G"}, // MA# ( 𝐺 → G ) MATHEMATICAL ITALIC CAPITAL G → LATIN CAPITAL LETTER G# + {119867, "H"}, // MA# ( 𝐻 → H ) MATHEMATICAL ITALIC CAPITAL H → LATIN CAPITAL LETTER H# + {119868, "l"}, // MA# ( 𝐼 → l ) MATHEMATICAL ITALIC CAPITAL I → LATIN SMALL LETTER L# →I→ + {119869, "J"}, // MA# ( 𝐽 → J ) MATHEMATICAL ITALIC CAPITAL J → LATIN CAPITAL LETTER J# + {119870, "K"}, // MA# ( 𝐾 → K ) MATHEMATICAL ITALIC CAPITAL K → LATIN CAPITAL LETTER K# + {119871, "L"}, // MA# ( 𝐿 → L ) MATHEMATICAL ITALIC CAPITAL L → LATIN CAPITAL LETTER L# + {119872, "M"}, // MA# ( 𝑀 → M ) MATHEMATICAL ITALIC CAPITAL M → LATIN CAPITAL LETTER M# + {119873, "N"}, // MA# ( 𝑁 → N ) MATHEMATICAL ITALIC CAPITAL N → LATIN CAPITAL LETTER N# + {119874, "O"}, // MA# ( 𝑂 → O ) MATHEMATICAL ITALIC CAPITAL O → LATIN CAPITAL LETTER O# + {119875, "P"}, // MA# ( 𝑃 → P ) MATHEMATICAL ITALIC CAPITAL P → LATIN CAPITAL LETTER P# + {119876, "Q"}, // MA# ( 𝑄 → Q ) MATHEMATICAL ITALIC CAPITAL Q → LATIN CAPITAL LETTER Q# + {119877, "R"}, // MA# ( 𝑅 → R ) MATHEMATICAL ITALIC CAPITAL R → LATIN CAPITAL LETTER R# + {119878, "S"}, // MA# ( 𝑆 → S ) MATHEMATICAL ITALIC CAPITAL S → LATIN CAPITAL LETTER S# + {119879, "T"}, // MA# ( 𝑇 → T ) MATHEMATICAL ITALIC CAPITAL T → LATIN CAPITAL LETTER T# + {119880, "U"}, // MA# ( 𝑈 → U ) MATHEMATICAL ITALIC CAPITAL U → LATIN CAPITAL LETTER U# + {119881, "V"}, // MA# ( 𝑉 → V ) MATHEMATICAL ITALIC CAPITAL V → LATIN CAPITAL LETTER V# + {119882, "W"}, // MA# ( 𝑊 → W ) MATHEMATICAL ITALIC CAPITAL W → LATIN CAPITAL LETTER W# + {119883, "X"}, // MA# ( 𝑋 → X ) MATHEMATICAL ITALIC CAPITAL X → LATIN CAPITAL LETTER X# + {119884, "Y"}, // MA# ( 𝑌 → Y ) MATHEMATICAL ITALIC CAPITAL Y → LATIN CAPITAL LETTER Y# + {119885, "Z"}, // MA# ( 𝑍 → Z ) MATHEMATICAL ITALIC CAPITAL Z → LATIN CAPITAL LETTER Z# + {119886, "a"}, // MA# ( 𝑎 → a ) MATHEMATICAL ITALIC SMALL A → LATIN SMALL LETTER A# + {119887, "b"}, // MA# ( 𝑏 → b ) MATHEMATICAL ITALIC SMALL B → LATIN SMALL LETTER B# + {119888, "c"}, // MA# ( 𝑐 → c ) MATHEMATICAL ITALIC SMALL C → LATIN SMALL LETTER C# + {119889, "d"}, // MA# ( 𝑑 → d ) MATHEMATICAL ITALIC SMALL D → LATIN SMALL LETTER D# + {119890, "e"}, // MA# ( 𝑒 → e ) MATHEMATICAL ITALIC SMALL E → LATIN SMALL LETTER E# + {119891, "f"}, // MA# ( 𝑓 → f ) MATHEMATICAL ITALIC SMALL F → LATIN SMALL LETTER F# + {119892, "g"}, // MA# ( 𝑔 → g ) MATHEMATICAL ITALIC SMALL G → LATIN SMALL LETTER G# + {119894, "i"}, // MA# ( 𝑖 → i ) MATHEMATICAL ITALIC SMALL I → LATIN SMALL LETTER I# + {119895, "j"}, // MA# ( 𝑗 → j ) MATHEMATICAL ITALIC SMALL J → LATIN SMALL LETTER J# + {119896, "k"}, // MA# ( 𝑘 → k ) MATHEMATICAL ITALIC SMALL K → LATIN SMALL LETTER K# + {119897, "l"}, // MA# ( 𝑙 → l ) MATHEMATICAL ITALIC SMALL L → LATIN SMALL LETTER L# + {119898, "rn"}, // MA# ( 𝑚 → rn ) MATHEMATICAL ITALIC SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {119899, "n"}, // MA# ( 𝑛 → n ) MATHEMATICAL ITALIC SMALL N → LATIN SMALL LETTER N# + {119900, "o"}, // MA# ( 𝑜 → o ) MATHEMATICAL ITALIC SMALL O → LATIN SMALL LETTER O# + {119901, "p"}, // MA# ( 𝑝 → p ) MATHEMATICAL ITALIC SMALL P → LATIN SMALL LETTER P# + {119902, "q"}, // MA# ( 𝑞 → q ) MATHEMATICAL ITALIC SMALL Q → LATIN SMALL LETTER Q# + {119903, "r"}, // MA# ( 𝑟 → r ) MATHEMATICAL ITALIC SMALL R → LATIN SMALL LETTER R# + {119904, "s"}, // MA# ( 𝑠 → s ) MATHEMATICAL ITALIC SMALL S → LATIN SMALL LETTER S# + {119905, "t"}, // MA# ( 𝑡 → t ) MATHEMATICAL ITALIC SMALL T → LATIN SMALL LETTER T# + {119906, "u"}, // MA# ( 𝑢 → u ) MATHEMATICAL ITALIC SMALL U → LATIN SMALL LETTER U# + {119907, "v"}, // MA# ( 𝑣 → v ) MATHEMATICAL ITALIC SMALL V → LATIN SMALL LETTER V# + {119908, "w"}, // MA# ( 𝑤 → w ) MATHEMATICAL ITALIC SMALL W → LATIN SMALL LETTER W# + {119909, "x"}, // MA# ( 𝑥 → x ) MATHEMATICAL ITALIC SMALL X → LATIN SMALL LETTER X# + {119910, "y"}, // MA# ( 𝑦 → y ) MATHEMATICAL ITALIC SMALL Y → LATIN SMALL LETTER Y# + {119911, "z"}, // MA# ( 𝑧 → z ) MATHEMATICAL ITALIC SMALL Z → LATIN SMALL LETTER Z# + {119912, "A"}, // MA# ( 𝑨 → A ) MATHEMATICAL BOLD ITALIC CAPITAL A → LATIN CAPITAL LETTER A# + {119913, "B"}, // MA# ( 𝑩 → B ) MATHEMATICAL BOLD ITALIC CAPITAL B → LATIN CAPITAL LETTER B# + {119914, "C"}, // MA# ( 𝑪 → C ) MATHEMATICAL BOLD ITALIC CAPITAL C → LATIN CAPITAL LETTER C# + {119915, "D"}, // MA# ( 𝑫 → D ) MATHEMATICAL BOLD ITALIC CAPITAL D → LATIN CAPITAL LETTER D# + {119916, "E"}, // MA# ( 𝑬 → E ) MATHEMATICAL BOLD ITALIC CAPITAL E → LATIN CAPITAL LETTER E# + {119917, "F"}, // MA# ( 𝑭 → F ) MATHEMATICAL BOLD ITALIC CAPITAL F → LATIN CAPITAL LETTER F# + {119918, "G"}, // MA# ( 𝑮 → G ) MATHEMATICAL BOLD ITALIC CAPITAL G → LATIN CAPITAL LETTER G# + {119919, "H"}, // MA# ( 𝑯 → H ) MATHEMATICAL BOLD ITALIC CAPITAL H → LATIN CAPITAL LETTER H# + {119920, "l"}, // MA# ( 𝑰 → l ) MATHEMATICAL BOLD ITALIC CAPITAL I → LATIN SMALL LETTER L# →I→ + {119921, "J"}, // MA# ( 𝑱 → J ) MATHEMATICAL BOLD ITALIC CAPITAL J → LATIN CAPITAL LETTER J# + {119922, "K"}, // MA# ( 𝑲 → K ) MATHEMATICAL BOLD ITALIC CAPITAL K → LATIN CAPITAL LETTER K# + {119923, "L"}, // MA# ( 𝑳 → L ) MATHEMATICAL BOLD ITALIC CAPITAL L → LATIN CAPITAL LETTER L# + {119924, "M"}, // MA# ( 𝑴 → M ) MATHEMATICAL BOLD ITALIC CAPITAL M → LATIN CAPITAL LETTER M# + {119925, "N"}, // MA# ( 𝑵 → N ) MATHEMATICAL BOLD ITALIC CAPITAL N → LATIN CAPITAL LETTER N# + {119926, "O"}, // MA# ( 𝑶 → O ) MATHEMATICAL BOLD ITALIC CAPITAL O → LATIN CAPITAL LETTER O# + {119927, "P"}, // MA# ( 𝑷 → P ) MATHEMATICAL BOLD ITALIC CAPITAL P → LATIN CAPITAL LETTER P# + {119928, "Q"}, // MA# ( 𝑸 → Q ) MATHEMATICAL BOLD ITALIC CAPITAL Q → LATIN CAPITAL LETTER Q# + {119929, "R"}, // MA# ( 𝑹 → R ) MATHEMATICAL BOLD ITALIC CAPITAL R → LATIN CAPITAL LETTER R# + {119930, "S"}, // MA# ( 𝑺 → S ) MATHEMATICAL BOLD ITALIC CAPITAL S → LATIN CAPITAL LETTER S# + {119931, "T"}, // MA# ( 𝑻 → T ) MATHEMATICAL BOLD ITALIC CAPITAL T → LATIN CAPITAL LETTER T# + {119932, "U"}, // MA# ( 𝑼 → U ) MATHEMATICAL BOLD ITALIC CAPITAL U → LATIN CAPITAL LETTER U# + {119933, "V"}, // MA# ( 𝑽 → V ) MATHEMATICAL BOLD ITALIC CAPITAL V → LATIN CAPITAL LETTER V# + {119934, "W"}, // MA# ( 𝑾 → W ) MATHEMATICAL BOLD ITALIC CAPITAL W → LATIN CAPITAL LETTER W# + {119935, "X"}, // MA# ( 𝑿 → X ) MATHEMATICAL BOLD ITALIC CAPITAL X → LATIN CAPITAL LETTER X# + {119936, "Y"}, // MA# ( 𝒀 → Y ) MATHEMATICAL BOLD ITALIC CAPITAL Y → LATIN CAPITAL LETTER Y# + {119937, "Z"}, // MA# ( 𝒁 → Z ) MATHEMATICAL BOLD ITALIC CAPITAL Z → LATIN CAPITAL LETTER Z# + {119938, "a"}, // MA# ( 𝒂 → a ) MATHEMATICAL BOLD ITALIC SMALL A → LATIN SMALL LETTER A# + {119939, "b"}, // MA# ( 𝒃 → b ) MATHEMATICAL BOLD ITALIC SMALL B → LATIN SMALL LETTER B# + {119940, "c"}, // MA# ( 𝒄 → c ) MATHEMATICAL BOLD ITALIC SMALL C → LATIN SMALL LETTER C# + {119941, "d"}, // MA# ( 𝒅 → d ) MATHEMATICAL BOLD ITALIC SMALL D → LATIN SMALL LETTER D# + {119942, "e"}, // MA# ( 𝒆 → e ) MATHEMATICAL BOLD ITALIC SMALL E → LATIN SMALL LETTER E# + {119943, "f"}, // MA# ( 𝒇 → f ) MATHEMATICAL BOLD ITALIC SMALL F → LATIN SMALL LETTER F# + {119944, "g"}, // MA# ( 𝒈 → g ) MATHEMATICAL BOLD ITALIC SMALL G → LATIN SMALL LETTER G# + {119945, "h"}, // MA# ( 𝒉 → h ) MATHEMATICAL BOLD ITALIC SMALL H → LATIN SMALL LETTER H# + {119946, "i"}, // MA# ( 𝒊 → i ) MATHEMATICAL BOLD ITALIC SMALL I → LATIN SMALL LETTER I# + {119947, "j"}, // MA# ( 𝒋 → j ) MATHEMATICAL BOLD ITALIC SMALL J → LATIN SMALL LETTER J# + {119948, "k"}, // MA# ( 𝒌 → k ) MATHEMATICAL BOLD ITALIC SMALL K → LATIN SMALL LETTER K# + {119949, "l"}, // MA# ( 𝒍 → l ) MATHEMATICAL BOLD ITALIC SMALL L → LATIN SMALL LETTER L# + {119950, "rn"}, // MA# ( 𝒎 → rn ) MATHEMATICAL BOLD ITALIC SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {119951, "n"}, // MA# ( 𝒏 → n ) MATHEMATICAL BOLD ITALIC SMALL N → LATIN SMALL LETTER N# + {119952, "o"}, // MA# ( 𝒐 → o ) MATHEMATICAL BOLD ITALIC SMALL O → LATIN SMALL LETTER O# + {119953, "p"}, // MA# ( 𝒑 → p ) MATHEMATICAL BOLD ITALIC SMALL P → LATIN SMALL LETTER P# + {119954, "q"}, // MA# ( 𝒒 → q ) MATHEMATICAL BOLD ITALIC SMALL Q → LATIN SMALL LETTER Q# + {119955, "r"}, // MA# ( 𝒓 → r ) MATHEMATICAL BOLD ITALIC SMALL R → LATIN SMALL LETTER R# + {119956, "s"}, // MA# ( 𝒔 → s ) MATHEMATICAL BOLD ITALIC SMALL S → LATIN SMALL LETTER S# + {119957, "t"}, // MA# ( 𝒕 → t ) MATHEMATICAL BOLD ITALIC SMALL T → LATIN SMALL LETTER T# + {119958, "u"}, // MA# ( 𝒖 → u ) MATHEMATICAL BOLD ITALIC SMALL U → LATIN SMALL LETTER U# + {119959, "v"}, // MA# ( 𝒗 → v ) MATHEMATICAL BOLD ITALIC SMALL V → LATIN SMALL LETTER V# + {119960, "w"}, // MA# ( 𝒘 → w ) MATHEMATICAL BOLD ITALIC SMALL W → LATIN SMALL LETTER W# + {119961, "x"}, // MA# ( 𝒙 → x ) MATHEMATICAL BOLD ITALIC SMALL X → LATIN SMALL LETTER X# + {119962, "y"}, // MA# ( 𝒚 → y ) MATHEMATICAL BOLD ITALIC SMALL Y → LATIN SMALL LETTER Y# + {119963, "z"}, // MA# ( 𝒛 → z ) MATHEMATICAL BOLD ITALIC SMALL Z → LATIN SMALL LETTER Z# + {119964, "A"}, // MA# ( 𝒜 → A ) MATHEMATICAL SCRIPT CAPITAL A → LATIN CAPITAL LETTER A# + {119966, "C"}, // MA# ( 𝒞 → C ) MATHEMATICAL SCRIPT CAPITAL C → LATIN CAPITAL LETTER C# + {119967, "D"}, // MA# ( 𝒟 → D ) MATHEMATICAL SCRIPT CAPITAL D → LATIN CAPITAL LETTER D# + {119970, "G"}, // MA# ( 𝒢 → G ) MATHEMATICAL SCRIPT CAPITAL G → LATIN CAPITAL LETTER G# + {119973, "J"}, // MA# ( 𝒥 → J ) MATHEMATICAL SCRIPT CAPITAL J → LATIN CAPITAL LETTER J# + {119974, "K"}, // MA# ( 𝒦 → K ) MATHEMATICAL SCRIPT CAPITAL K → LATIN CAPITAL LETTER K# + {119977, "N"}, // MA# ( 𝒩 → N ) MATHEMATICAL SCRIPT CAPITAL N → LATIN CAPITAL LETTER N# + {119978, "O"}, // MA# ( 𝒪 → O ) MATHEMATICAL SCRIPT CAPITAL O → LATIN CAPITAL LETTER O# + {119979, "P"}, // MA# ( 𝒫 → P ) MATHEMATICAL SCRIPT CAPITAL P → LATIN CAPITAL LETTER P# + {119980, "Q"}, // MA# ( 𝒬 → Q ) MATHEMATICAL SCRIPT CAPITAL Q → LATIN CAPITAL LETTER Q# + {119982, "S"}, // MA# ( 𝒮 → S ) MATHEMATICAL SCRIPT CAPITAL S → LATIN CAPITAL LETTER S# + {119983, "T"}, // MA# ( 𝒯 → T ) MATHEMATICAL SCRIPT CAPITAL T → LATIN CAPITAL LETTER T# + {119984, "U"}, // MA# ( 𝒰 → U ) MATHEMATICAL SCRIPT CAPITAL U → LATIN CAPITAL LETTER U# + {119985, "V"}, // MA# ( 𝒱 → V ) MATHEMATICAL SCRIPT CAPITAL V → LATIN CAPITAL LETTER V# + {119986, "W"}, // MA# ( 𝒲 → W ) MATHEMATICAL SCRIPT CAPITAL W → LATIN CAPITAL LETTER W# + {119987, "X"}, // MA# ( 𝒳 → X ) MATHEMATICAL SCRIPT CAPITAL X → LATIN CAPITAL LETTER X# + {119988, "Y"}, // MA# ( 𝒴 → Y ) MATHEMATICAL SCRIPT CAPITAL Y → LATIN CAPITAL LETTER Y# + {119989, "Z"}, // MA# ( 𝒵 → Z ) MATHEMATICAL SCRIPT CAPITAL Z → LATIN CAPITAL LETTER Z# + {119990, "a"}, // MA# ( 𝒶 → a ) MATHEMATICAL SCRIPT SMALL A → LATIN SMALL LETTER A# + {119991, "b"}, // MA# ( 𝒷 → b ) MATHEMATICAL SCRIPT SMALL B → LATIN SMALL LETTER B# + {119992, "c"}, // MA# ( 𝒸 → c ) MATHEMATICAL SCRIPT SMALL C → LATIN SMALL LETTER C# + {119993, "d"}, // MA# ( 𝒹 → d ) MATHEMATICAL SCRIPT SMALL D → LATIN SMALL LETTER D# + {119995, "f"}, // MA# ( 𝒻 → f ) MATHEMATICAL SCRIPT SMALL F → LATIN SMALL LETTER F# + {119997, "h"}, // MA# ( 𝒽 → h ) MATHEMATICAL SCRIPT SMALL H → LATIN SMALL LETTER H# + {119998, "i"}, // MA# ( 𝒾 → i ) MATHEMATICAL SCRIPT SMALL I → LATIN SMALL LETTER I# + {119999, "j"}, // MA# ( 𝒿 → j ) MATHEMATICAL SCRIPT SMALL J → LATIN SMALL LETTER J# + {120000, "k"}, // MA# ( 𝓀 → k ) MATHEMATICAL SCRIPT SMALL K → LATIN SMALL LETTER K# + {120001, "l"}, // MA# ( 𝓁 → l ) MATHEMATICAL SCRIPT SMALL L → LATIN SMALL LETTER L# + {120002, "rn"}, // MA# ( 𝓂 → rn ) MATHEMATICAL SCRIPT SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120003, "n"}, // MA# ( 𝓃 → n ) MATHEMATICAL SCRIPT SMALL N → LATIN SMALL LETTER N# + {120005, "p"}, // MA# ( 𝓅 → p ) MATHEMATICAL SCRIPT SMALL P → LATIN SMALL LETTER P# + {120006, "q"}, // MA# ( 𝓆 → q ) MATHEMATICAL SCRIPT SMALL Q → LATIN SMALL LETTER Q# + {120007, "r"}, // MA# ( 𝓇 → r ) MATHEMATICAL SCRIPT SMALL R → LATIN SMALL LETTER R# + {120008, "s"}, // MA# ( 𝓈 → s ) MATHEMATICAL SCRIPT SMALL S → LATIN SMALL LETTER S# + {120009, "t"}, // MA# ( 𝓉 → t ) MATHEMATICAL SCRIPT SMALL T → LATIN SMALL LETTER T# + {120010, "u"}, // MA# ( 𝓊 → u ) MATHEMATICAL SCRIPT SMALL U → LATIN SMALL LETTER U# + {120011, "v"}, // MA# ( 𝓋 → v ) MATHEMATICAL SCRIPT SMALL V → LATIN SMALL LETTER V# + {120012, "w"}, // MA# ( 𝓌 → w ) MATHEMATICAL SCRIPT SMALL W → LATIN SMALL LETTER W# + {120013, "x"}, // MA# ( 𝓍 → x ) MATHEMATICAL SCRIPT SMALL X → LATIN SMALL LETTER X# + {120014, "y"}, // MA# ( 𝓎 → y ) MATHEMATICAL SCRIPT SMALL Y → LATIN SMALL LETTER Y# + {120015, "z"}, // MA# ( 𝓏 → z ) MATHEMATICAL SCRIPT SMALL Z → LATIN SMALL LETTER Z# + {120016, "A"}, // MA# ( 𝓐 → A ) MATHEMATICAL BOLD SCRIPT CAPITAL A → LATIN CAPITAL LETTER A# + {120017, "B"}, // MA# ( 𝓑 → B ) MATHEMATICAL BOLD SCRIPT CAPITAL B → LATIN CAPITAL LETTER B# + {120018, "C"}, // MA# ( 𝓒 → C ) MATHEMATICAL BOLD SCRIPT CAPITAL C → LATIN CAPITAL LETTER C# + {120019, "D"}, // MA# ( 𝓓 → D ) MATHEMATICAL BOLD SCRIPT CAPITAL D → LATIN CAPITAL LETTER D# + {120020, "E"}, // MA# ( 𝓔 → E ) MATHEMATICAL BOLD SCRIPT CAPITAL E → LATIN CAPITAL LETTER E# + {120021, "F"}, // MA# ( 𝓕 → F ) MATHEMATICAL BOLD SCRIPT CAPITAL F → LATIN CAPITAL LETTER F# + {120022, "G"}, // MA# ( 𝓖 → G ) MATHEMATICAL BOLD SCRIPT CAPITAL G → LATIN CAPITAL LETTER G# + {120023, "H"}, // MA# ( 𝓗 → H ) MATHEMATICAL BOLD SCRIPT CAPITAL H → LATIN CAPITAL LETTER H# + {120024, "l"}, // MA# ( 𝓘 → l ) MATHEMATICAL BOLD SCRIPT CAPITAL I → LATIN SMALL LETTER L# →I→ + {120025, "J"}, // MA# ( 𝓙 → J ) MATHEMATICAL BOLD SCRIPT CAPITAL J → LATIN CAPITAL LETTER J# + {120026, "K"}, // MA# ( 𝓚 → K ) MATHEMATICAL BOLD SCRIPT CAPITAL K → LATIN CAPITAL LETTER K# + {120027, "L"}, // MA# ( 𝓛 → L ) MATHEMATICAL BOLD SCRIPT CAPITAL L → LATIN CAPITAL LETTER L# + {120028, "M"}, // MA# ( 𝓜 → M ) MATHEMATICAL BOLD SCRIPT CAPITAL M → LATIN CAPITAL LETTER M# + {120029, "N"}, // MA# ( 𝓝 → N ) MATHEMATICAL BOLD SCRIPT CAPITAL N → LATIN CAPITAL LETTER N# + {120030, "O"}, // MA# ( 𝓞 → O ) MATHEMATICAL BOLD SCRIPT CAPITAL O → LATIN CAPITAL LETTER O# + {120031, "P"}, // MA# ( 𝓟 → P ) MATHEMATICAL BOLD SCRIPT CAPITAL P → LATIN CAPITAL LETTER P# + {120032, "Q"}, // MA# ( 𝓠 → Q ) MATHEMATICAL BOLD SCRIPT CAPITAL Q → LATIN CAPITAL LETTER Q# + {120033, "R"}, // MA# ( 𝓡 → R ) MATHEMATICAL BOLD SCRIPT CAPITAL R → LATIN CAPITAL LETTER R# + {120034, "S"}, // MA# ( 𝓢 → S ) MATHEMATICAL BOLD SCRIPT CAPITAL S → LATIN CAPITAL LETTER S# + {120035, "T"}, // MA# ( 𝓣 → T ) MATHEMATICAL BOLD SCRIPT CAPITAL T → LATIN CAPITAL LETTER T# + {120036, "U"}, // MA# ( 𝓤 → U ) MATHEMATICAL BOLD SCRIPT CAPITAL U → LATIN CAPITAL LETTER U# + {120037, "V"}, // MA# ( 𝓥 → V ) MATHEMATICAL BOLD SCRIPT CAPITAL V → LATIN CAPITAL LETTER V# + {120038, "W"}, // MA# ( 𝓦 → W ) MATHEMATICAL BOLD SCRIPT CAPITAL W → LATIN CAPITAL LETTER W# + {120039, "X"}, // MA# ( 𝓧 → X ) MATHEMATICAL BOLD SCRIPT CAPITAL X → LATIN CAPITAL LETTER X# + {120040, "Y"}, // MA# ( 𝓨 → Y ) MATHEMATICAL BOLD SCRIPT CAPITAL Y → LATIN CAPITAL LETTER Y# + {120041, "Z"}, // MA# ( 𝓩 → Z ) MATHEMATICAL BOLD SCRIPT CAPITAL Z → LATIN CAPITAL LETTER Z# + {120042, "a"}, // MA# ( 𝓪 → a ) MATHEMATICAL BOLD SCRIPT SMALL A → LATIN SMALL LETTER A# + {120043, "b"}, // MA# ( 𝓫 → b ) MATHEMATICAL BOLD SCRIPT SMALL B → LATIN SMALL LETTER B# + {120044, "c"}, // MA# ( 𝓬 → c ) MATHEMATICAL BOLD SCRIPT SMALL C → LATIN SMALL LETTER C# + {120045, "d"}, // MA# ( 𝓭 → d ) MATHEMATICAL BOLD SCRIPT SMALL D → LATIN SMALL LETTER D# + {120046, "e"}, // MA# ( 𝓮 → e ) MATHEMATICAL BOLD SCRIPT SMALL E → LATIN SMALL LETTER E# + {120047, "f"}, // MA# ( 𝓯 → f ) MATHEMATICAL BOLD SCRIPT SMALL F → LATIN SMALL LETTER F# + {120048, "g"}, // MA# ( 𝓰 → g ) MATHEMATICAL BOLD SCRIPT SMALL G → LATIN SMALL LETTER G# + {120049, "h"}, // MA# ( 𝓱 → h ) MATHEMATICAL BOLD SCRIPT SMALL H → LATIN SMALL LETTER H# + {120050, "i"}, // MA# ( 𝓲 → i ) MATHEMATICAL BOLD SCRIPT SMALL I → LATIN SMALL LETTER I# + {120051, "j"}, // MA# ( 𝓳 → j ) MATHEMATICAL BOLD SCRIPT SMALL J → LATIN SMALL LETTER J# + {120052, "k"}, // MA# ( 𝓴 → k ) MATHEMATICAL BOLD SCRIPT SMALL K → LATIN SMALL LETTER K# + {120053, "l"}, // MA# ( 𝓵 → l ) MATHEMATICAL BOLD SCRIPT SMALL L → LATIN SMALL LETTER L# + {120054, "rn"}, // MA# ( 𝓶 → rn ) MATHEMATICAL BOLD SCRIPT SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120055, "n"}, // MA# ( 𝓷 → n ) MATHEMATICAL BOLD SCRIPT SMALL N → LATIN SMALL LETTER N# + {120056, "o"}, // MA# ( 𝓸 → o ) MATHEMATICAL BOLD SCRIPT SMALL O → LATIN SMALL LETTER O# + {120057, "p"}, // MA# ( 𝓹 → p ) MATHEMATICAL BOLD SCRIPT SMALL P → LATIN SMALL LETTER P# + {120058, "q"}, // MA# ( 𝓺 → q ) MATHEMATICAL BOLD SCRIPT SMALL Q → LATIN SMALL LETTER Q# + {120059, "r"}, // MA# ( 𝓻 → r ) MATHEMATICAL BOLD SCRIPT SMALL R → LATIN SMALL LETTER R# + {120060, "s"}, // MA# ( 𝓼 → s ) MATHEMATICAL BOLD SCRIPT SMALL S → LATIN SMALL LETTER S# + {120061, "t"}, // MA# ( 𝓽 → t ) MATHEMATICAL BOLD SCRIPT SMALL T → LATIN SMALL LETTER T# + {120062, "u"}, // MA# ( 𝓾 → u ) MATHEMATICAL BOLD SCRIPT SMALL U → LATIN SMALL LETTER U# + {120063, "v"}, // MA# ( 𝓿 → v ) MATHEMATICAL BOLD SCRIPT SMALL V → LATIN SMALL LETTER V# + {120064, "w"}, // MA# ( 𝔀 → w ) MATHEMATICAL BOLD SCRIPT SMALL W → LATIN SMALL LETTER W# + {120065, "x"}, // MA# ( 𝔁 → x ) MATHEMATICAL BOLD SCRIPT SMALL X → LATIN SMALL LETTER X# + {120066, "y"}, // MA# ( 𝔂 → y ) MATHEMATICAL BOLD SCRIPT SMALL Y → LATIN SMALL LETTER Y# + {120067, "z"}, // MA# ( 𝔃 → z ) MATHEMATICAL BOLD SCRIPT SMALL Z → LATIN SMALL LETTER Z# + {120068, "A"}, // MA# ( 𝔄 → A ) MATHEMATICAL FRAKTUR CAPITAL A → LATIN CAPITAL LETTER A# + {120069, "B"}, // MA# ( 𝔅 → B ) MATHEMATICAL FRAKTUR CAPITAL B → LATIN CAPITAL LETTER B# + {120071, "D"}, // MA# ( 𝔇 → D ) MATHEMATICAL FRAKTUR CAPITAL D → LATIN CAPITAL LETTER D# + {120072, "E"}, // MA# ( 𝔈 → E ) MATHEMATICAL FRAKTUR CAPITAL E → LATIN CAPITAL LETTER E# + {120073, "F"}, // MA# ( 𝔉 → F ) MATHEMATICAL FRAKTUR CAPITAL F → LATIN CAPITAL LETTER F# + {120074, "G"}, // MA# ( 𝔊 → G ) MATHEMATICAL FRAKTUR CAPITAL G → LATIN CAPITAL LETTER G# + {120077, "J"}, // MA# ( 𝔍 → J ) MATHEMATICAL FRAKTUR CAPITAL J → LATIN CAPITAL LETTER J# + {120078, "K"}, // MA# ( 𝔎 → K ) MATHEMATICAL FRAKTUR CAPITAL K → LATIN CAPITAL LETTER K# + {120079, "L"}, // MA# ( 𝔏 → L ) MATHEMATICAL FRAKTUR CAPITAL L → LATIN CAPITAL LETTER L# + {120080, "M"}, // MA# ( 𝔐 → M ) MATHEMATICAL FRAKTUR CAPITAL M → LATIN CAPITAL LETTER M# + {120081, "N"}, // MA# ( 𝔑 → N ) MATHEMATICAL FRAKTUR CAPITAL N → LATIN CAPITAL LETTER N# + {120082, "O"}, // MA# ( 𝔒 → O ) MATHEMATICAL FRAKTUR CAPITAL O → LATIN CAPITAL LETTER O# + {120083, "P"}, // MA# ( 𝔓 → P ) MATHEMATICAL FRAKTUR CAPITAL P → LATIN CAPITAL LETTER P# + {120084, "Q"}, // MA# ( 𝔔 → Q ) MATHEMATICAL FRAKTUR CAPITAL Q → LATIN CAPITAL LETTER Q# + {120086, "S"}, // MA# ( 𝔖 → S ) MATHEMATICAL FRAKTUR CAPITAL S → LATIN CAPITAL LETTER S# + {120087, "T"}, // MA# ( 𝔗 → T ) MATHEMATICAL FRAKTUR CAPITAL T → LATIN CAPITAL LETTER T# + {120088, "U"}, // MA# ( 𝔘 → U ) MATHEMATICAL FRAKTUR CAPITAL U → LATIN CAPITAL LETTER U# + {120089, "V"}, // MA# ( 𝔙 → V ) MATHEMATICAL FRAKTUR CAPITAL V → LATIN CAPITAL LETTER V# + {120090, "W"}, // MA# ( 𝔚 → W ) MATHEMATICAL FRAKTUR CAPITAL W → LATIN CAPITAL LETTER W# + {120091, "X"}, // MA# ( 𝔛 → X ) MATHEMATICAL FRAKTUR CAPITAL X → LATIN CAPITAL LETTER X# + {120092, "Y"}, // MA# ( 𝔜 → Y ) MATHEMATICAL FRAKTUR CAPITAL Y → LATIN CAPITAL LETTER Y# + {120094, "a"}, // MA# ( 𝔞 → a ) MATHEMATICAL FRAKTUR SMALL A → LATIN SMALL LETTER A# + {120095, "b"}, // MA# ( 𝔟 → b ) MATHEMATICAL FRAKTUR SMALL B → LATIN SMALL LETTER B# + {120096, "c"}, // MA# ( 𝔠 → c ) MATHEMATICAL FRAKTUR SMALL C → LATIN SMALL LETTER C# + {120097, "d"}, // MA# ( 𝔡 → d ) MATHEMATICAL FRAKTUR SMALL D → LATIN SMALL LETTER D# + {120098, "e"}, // MA# ( 𝔢 → e ) MATHEMATICAL FRAKTUR SMALL E → LATIN SMALL LETTER E# + {120099, "f"}, // MA# ( 𝔣 → f ) MATHEMATICAL FRAKTUR SMALL F → LATIN SMALL LETTER F# + {120100, "g"}, // MA# ( 𝔤 → g ) MATHEMATICAL FRAKTUR SMALL G → LATIN SMALL LETTER G# + {120101, "h"}, // MA# ( 𝔥 → h ) MATHEMATICAL FRAKTUR SMALL H → LATIN SMALL LETTER H# + {120102, "i"}, // MA# ( 𝔦 → i ) MATHEMATICAL FRAKTUR SMALL I → LATIN SMALL LETTER I# + {120103, "j"}, // MA# ( 𝔧 → j ) MATHEMATICAL FRAKTUR SMALL J → LATIN SMALL LETTER J# + {120104, "k"}, // MA# ( 𝔨 → k ) MATHEMATICAL FRAKTUR SMALL K → LATIN SMALL LETTER K# + {120105, "l"}, // MA# ( 𝔩 → l ) MATHEMATICAL FRAKTUR SMALL L → LATIN SMALL LETTER L# + {120106, "rn"}, // MA# ( 𝔪 → rn ) MATHEMATICAL FRAKTUR SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120107, "n"}, // MA# ( 𝔫 → n ) MATHEMATICAL FRAKTUR SMALL N → LATIN SMALL LETTER N# + {120108, "o"}, // MA# ( 𝔬 → o ) MATHEMATICAL FRAKTUR SMALL O → LATIN SMALL LETTER O# + {120109, "p"}, // MA# ( 𝔭 → p ) MATHEMATICAL FRAKTUR SMALL P → LATIN SMALL LETTER P# + {120110, "q"}, // MA# ( 𝔮 → q ) MATHEMATICAL FRAKTUR SMALL Q → LATIN SMALL LETTER Q# + {120111, "r"}, // MA# ( 𝔯 → r ) MATHEMATICAL FRAKTUR SMALL R → LATIN SMALL LETTER R# + {120112, "s"}, // MA# ( 𝔰 → s ) MATHEMATICAL FRAKTUR SMALL S → LATIN SMALL LETTER S# + {120113, "t"}, // MA# ( 𝔱 → t ) MATHEMATICAL FRAKTUR SMALL T → LATIN SMALL LETTER T# + {120114, "u"}, // MA# ( 𝔲 → u ) MATHEMATICAL FRAKTUR SMALL U → LATIN SMALL LETTER U# + {120115, "v"}, // MA# ( 𝔳 → v ) MATHEMATICAL FRAKTUR SMALL V → LATIN SMALL LETTER V# + {120116, "w"}, // MA# ( 𝔴 → w ) MATHEMATICAL FRAKTUR SMALL W → LATIN SMALL LETTER W# + {120117, "x"}, // MA# ( 𝔵 → x ) MATHEMATICAL FRAKTUR SMALL X → LATIN SMALL LETTER X# + {120118, "y"}, // MA# ( 𝔶 → y ) MATHEMATICAL FRAKTUR SMALL Y → LATIN SMALL LETTER Y# + {120119, "z"}, // MA# ( 𝔷 → z ) MATHEMATICAL FRAKTUR SMALL Z → LATIN SMALL LETTER Z# + {120120, "A"}, // MA# ( 𝔸 → A ) MATHEMATICAL DOUBLE-STRUCK CAPITAL A → LATIN CAPITAL LETTER A# + {120121, "B"}, // MA# ( 𝔹 → B ) MATHEMATICAL DOUBLE-STRUCK CAPITAL B → LATIN CAPITAL LETTER B# + {120123, "D"}, // MA# ( 𝔻 → D ) MATHEMATICAL DOUBLE-STRUCK CAPITAL D → LATIN CAPITAL LETTER D# + {120124, "E"}, // MA# ( 𝔼 → E ) MATHEMATICAL DOUBLE-STRUCK CAPITAL E → LATIN CAPITAL LETTER E# + {120125, "F"}, // MA# ( 𝔽 → F ) MATHEMATICAL DOUBLE-STRUCK CAPITAL F → LATIN CAPITAL LETTER F# + {120126, "G"}, // MA# ( 𝔾 → G ) MATHEMATICAL DOUBLE-STRUCK CAPITAL G → LATIN CAPITAL LETTER G# + {120128, "l"}, // MA# ( 𝕀 → l ) MATHEMATICAL DOUBLE-STRUCK CAPITAL I → LATIN SMALL LETTER L# →I→ + {120129, "J"}, // MA# ( 𝕁 → J ) MATHEMATICAL DOUBLE-STRUCK CAPITAL J → LATIN CAPITAL LETTER J# + {120130, "K"}, // MA# ( 𝕂 → K ) MATHEMATICAL DOUBLE-STRUCK CAPITAL K → LATIN CAPITAL LETTER K# + {120131, "L"}, // MA# ( 𝕃 → L ) MATHEMATICAL DOUBLE-STRUCK CAPITAL L → LATIN CAPITAL LETTER L# + {120132, "M"}, // MA# ( 𝕄 → M ) MATHEMATICAL DOUBLE-STRUCK CAPITAL M → LATIN CAPITAL LETTER M# + {120134, "O"}, // MA# ( 𝕆 → O ) MATHEMATICAL DOUBLE-STRUCK CAPITAL O → LATIN CAPITAL LETTER O# + {120138, "S"}, // MA# ( 𝕊 → S ) MATHEMATICAL DOUBLE-STRUCK CAPITAL S → LATIN CAPITAL LETTER S# + {120139, "T"}, // MA# ( 𝕋 → T ) MATHEMATICAL DOUBLE-STRUCK CAPITAL T → LATIN CAPITAL LETTER T# + {120140, "U"}, // MA# ( 𝕌 → U ) MATHEMATICAL DOUBLE-STRUCK CAPITAL U → LATIN CAPITAL LETTER U# + {120141, "V"}, // MA# ( 𝕍 → V ) MATHEMATICAL DOUBLE-STRUCK CAPITAL V → LATIN CAPITAL LETTER V# + {120142, "W"}, // MA# ( 𝕎 → W ) MATHEMATICAL DOUBLE-STRUCK CAPITAL W → LATIN CAPITAL LETTER W# + {120143, "X"}, // MA# ( 𝕏 → X ) MATHEMATICAL DOUBLE-STRUCK CAPITAL X → LATIN CAPITAL LETTER X# + {120144, "Y"}, // MA# ( 𝕐 → Y ) MATHEMATICAL DOUBLE-STRUCK CAPITAL Y → LATIN CAPITAL LETTER Y# + {120146, "a"}, // MA# ( 𝕒 → a ) MATHEMATICAL DOUBLE-STRUCK SMALL A → LATIN SMALL LETTER A# + {120147, "b"}, // MA# ( 𝕓 → b ) MATHEMATICAL DOUBLE-STRUCK SMALL B → LATIN SMALL LETTER B# + {120148, "c"}, // MA# ( 𝕔 → c ) MATHEMATICAL DOUBLE-STRUCK SMALL C → LATIN SMALL LETTER C# + {120149, "d"}, // MA# ( 𝕕 → d ) MATHEMATICAL DOUBLE-STRUCK SMALL D → LATIN SMALL LETTER D# + {120150, "e"}, // MA# ( 𝕖 → e ) MATHEMATICAL DOUBLE-STRUCK SMALL E → LATIN SMALL LETTER E# + {120151, "f"}, // MA# ( 𝕗 → f ) MATHEMATICAL DOUBLE-STRUCK SMALL F → LATIN SMALL LETTER F# + {120152, "g"}, // MA# ( 𝕘 → g ) MATHEMATICAL DOUBLE-STRUCK SMALL G → LATIN SMALL LETTER G# + {120153, "h"}, // MA# ( 𝕙 → h ) MATHEMATICAL DOUBLE-STRUCK SMALL H → LATIN SMALL LETTER H# + {120154, "i"}, // MA# ( 𝕚 → i ) MATHEMATICAL DOUBLE-STRUCK SMALL I → LATIN SMALL LETTER I# + {120155, "j"}, // MA# ( 𝕛 → j ) MATHEMATICAL DOUBLE-STRUCK SMALL J → LATIN SMALL LETTER J# + {120156, "k"}, // MA# ( 𝕜 → k ) MATHEMATICAL DOUBLE-STRUCK SMALL K → LATIN SMALL LETTER K# + {120157, "l"}, // MA# ( 𝕝 → l ) MATHEMATICAL DOUBLE-STRUCK SMALL L → LATIN SMALL LETTER L# + {120158, "rn"}, // MA# ( 𝕞 → rn ) MATHEMATICAL DOUBLE-STRUCK SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120159, "n"}, // MA# ( 𝕟 → n ) MATHEMATICAL DOUBLE-STRUCK SMALL N → LATIN SMALL LETTER N# + {120160, "o"}, // MA# ( 𝕠 → o ) MATHEMATICAL DOUBLE-STRUCK SMALL O → LATIN SMALL LETTER O# + {120161, "p"}, // MA# ( 𝕡 → p ) MATHEMATICAL DOUBLE-STRUCK SMALL P → LATIN SMALL LETTER P# + {120162, "q"}, // MA# ( 𝕢 → q ) MATHEMATICAL DOUBLE-STRUCK SMALL Q → LATIN SMALL LETTER Q# + {120163, "r"}, // MA# ( 𝕣 → r ) MATHEMATICAL DOUBLE-STRUCK SMALL R → LATIN SMALL LETTER R# + {120164, "s"}, // MA# ( 𝕤 → s ) MATHEMATICAL DOUBLE-STRUCK SMALL S → LATIN SMALL LETTER S# + {120165, "t"}, // MA# ( 𝕥 → t ) MATHEMATICAL DOUBLE-STRUCK SMALL T → LATIN SMALL LETTER T# + {120166, "u"}, // MA# ( 𝕦 → u ) MATHEMATICAL DOUBLE-STRUCK SMALL U → LATIN SMALL LETTER U# + {120167, "v"}, // MA# ( 𝕧 → v ) MATHEMATICAL DOUBLE-STRUCK SMALL V → LATIN SMALL LETTER V# + {120168, "w"}, // MA# ( 𝕨 → w ) MATHEMATICAL DOUBLE-STRUCK SMALL W → LATIN SMALL LETTER W# + {120169, "x"}, // MA# ( 𝕩 → x ) MATHEMATICAL DOUBLE-STRUCK SMALL X → LATIN SMALL LETTER X# + {120170, "y"}, // MA# ( 𝕪 → y ) MATHEMATICAL DOUBLE-STRUCK SMALL Y → LATIN SMALL LETTER Y# + {120171, "z"}, // MA# ( 𝕫 → z ) MATHEMATICAL DOUBLE-STRUCK SMALL Z → LATIN SMALL LETTER Z# + {120172, "A"}, // MA# ( 𝕬 → A ) MATHEMATICAL BOLD FRAKTUR CAPITAL A → LATIN CAPITAL LETTER A# + {120173, "B"}, // MA# ( 𝕭 → B ) MATHEMATICAL BOLD FRAKTUR CAPITAL B → LATIN CAPITAL LETTER B# + {120174, "C"}, // MA# ( 𝕮 → C ) MATHEMATICAL BOLD FRAKTUR CAPITAL C → LATIN CAPITAL LETTER C# + {120175, "D"}, // MA# ( 𝕯 → D ) MATHEMATICAL BOLD FRAKTUR CAPITAL D → LATIN CAPITAL LETTER D# + {120176, "E"}, // MA# ( 𝕰 → E ) MATHEMATICAL BOLD FRAKTUR CAPITAL E → LATIN CAPITAL LETTER E# + {120177, "F"}, // MA# ( 𝕱 → F ) MATHEMATICAL BOLD FRAKTUR CAPITAL F → LATIN CAPITAL LETTER F# + {120178, "G"}, // MA# ( 𝕲 → G ) MATHEMATICAL BOLD FRAKTUR CAPITAL G → LATIN CAPITAL LETTER G# + {120179, "H"}, // MA# ( 𝕳 → H ) MATHEMATICAL BOLD FRAKTUR CAPITAL H → LATIN CAPITAL LETTER H# + {120180, "l"}, // MA# ( 𝕴 → l ) MATHEMATICAL BOLD FRAKTUR CAPITAL I → LATIN SMALL LETTER L# →I→ + {120181, "J"}, // MA# ( 𝕵 → J ) MATHEMATICAL BOLD FRAKTUR CAPITAL J → LATIN CAPITAL LETTER J# + {120182, "K"}, // MA# ( 𝕶 → K ) MATHEMATICAL BOLD FRAKTUR CAPITAL K → LATIN CAPITAL LETTER K# + {120183, "L"}, // MA# ( 𝕷 → L ) MATHEMATICAL BOLD FRAKTUR CAPITAL L → LATIN CAPITAL LETTER L# + {120184, "M"}, // MA# ( 𝕸 → M ) MATHEMATICAL BOLD FRAKTUR CAPITAL M → LATIN CAPITAL LETTER M# + {120185, "N"}, // MA# ( 𝕹 → N ) MATHEMATICAL BOLD FRAKTUR CAPITAL N → LATIN CAPITAL LETTER N# + {120186, "O"}, // MA# ( 𝕺 → O ) MATHEMATICAL BOLD FRAKTUR CAPITAL O → LATIN CAPITAL LETTER O# + {120187, "P"}, // MA# ( 𝕻 → P ) MATHEMATICAL BOLD FRAKTUR CAPITAL P → LATIN CAPITAL LETTER P# + {120188, "Q"}, // MA# ( 𝕼 → Q ) MATHEMATICAL BOLD FRAKTUR CAPITAL Q → LATIN CAPITAL LETTER Q# + {120189, "R"}, // MA# ( 𝕽 → R ) MATHEMATICAL BOLD FRAKTUR CAPITAL R → LATIN CAPITAL LETTER R# + {120190, "S"}, // MA# ( 𝕾 → S ) MATHEMATICAL BOLD FRAKTUR CAPITAL S → LATIN CAPITAL LETTER S# + {120191, "T"}, // MA# ( 𝕿 → T ) MATHEMATICAL BOLD FRAKTUR CAPITAL T → LATIN CAPITAL LETTER T# + {120192, "U"}, // MA# ( 𝖀 → U ) MATHEMATICAL BOLD FRAKTUR CAPITAL U → LATIN CAPITAL LETTER U# + {120193, "V"}, // MA# ( 𝖁 → V ) MATHEMATICAL BOLD FRAKTUR CAPITAL V → LATIN CAPITAL LETTER V# + {120194, "W"}, // MA# ( 𝖂 → W ) MATHEMATICAL BOLD FRAKTUR CAPITAL W → LATIN CAPITAL LETTER W# + {120195, "X"}, // MA# ( 𝖃 → X ) MATHEMATICAL BOLD FRAKTUR CAPITAL X → LATIN CAPITAL LETTER X# + {120196, "Y"}, // MA# ( 𝖄 → Y ) MATHEMATICAL BOLD FRAKTUR CAPITAL Y → LATIN CAPITAL LETTER Y# + {120197, "Z"}, // MA# ( 𝖅 → Z ) MATHEMATICAL BOLD FRAKTUR CAPITAL Z → LATIN CAPITAL LETTER Z# + {120198, "a"}, // MA# ( 𝖆 → a ) MATHEMATICAL BOLD FRAKTUR SMALL A → LATIN SMALL LETTER A# + {120199, "b"}, // MA# ( 𝖇 → b ) MATHEMATICAL BOLD FRAKTUR SMALL B → LATIN SMALL LETTER B# + {120200, "c"}, // MA# ( 𝖈 → c ) MATHEMATICAL BOLD FRAKTUR SMALL C → LATIN SMALL LETTER C# + {120201, "d"}, // MA# ( 𝖉 → d ) MATHEMATICAL BOLD FRAKTUR SMALL D → LATIN SMALL LETTER D# + {120202, "e"}, // MA# ( 𝖊 → e ) MATHEMATICAL BOLD FRAKTUR SMALL E → LATIN SMALL LETTER E# + {120203, "f"}, // MA# ( 𝖋 → f ) MATHEMATICAL BOLD FRAKTUR SMALL F → LATIN SMALL LETTER F# + {120204, "g"}, // MA# ( 𝖌 → g ) MATHEMATICAL BOLD FRAKTUR SMALL G → LATIN SMALL LETTER G# + {120205, "h"}, // MA# ( 𝖍 → h ) MATHEMATICAL BOLD FRAKTUR SMALL H → LATIN SMALL LETTER H# + {120206, "i"}, // MA# ( 𝖎 → i ) MATHEMATICAL BOLD FRAKTUR SMALL I → LATIN SMALL LETTER I# + {120207, "j"}, // MA# ( 𝖏 → j ) MATHEMATICAL BOLD FRAKTUR SMALL J → LATIN SMALL LETTER J# + {120208, "k"}, // MA# ( 𝖐 → k ) MATHEMATICAL BOLD FRAKTUR SMALL K → LATIN SMALL LETTER K# + {120209, "l"}, // MA# ( 𝖑 → l ) MATHEMATICAL BOLD FRAKTUR SMALL L → LATIN SMALL LETTER L# + {120210, "rn"}, // MA# ( 𝖒 → rn ) MATHEMATICAL BOLD FRAKTUR SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120211, "n"}, // MA# ( 𝖓 → n ) MATHEMATICAL BOLD FRAKTUR SMALL N → LATIN SMALL LETTER N# + {120212, "o"}, // MA# ( 𝖔 → o ) MATHEMATICAL BOLD FRAKTUR SMALL O → LATIN SMALL LETTER O# + {120213, "p"}, // MA# ( 𝖕 → p ) MATHEMATICAL BOLD FRAKTUR SMALL P → LATIN SMALL LETTER P# + {120214, "q"}, // MA# ( 𝖖 → q ) MATHEMATICAL BOLD FRAKTUR SMALL Q → LATIN SMALL LETTER Q# + {120215, "r"}, // MA# ( 𝖗 → r ) MATHEMATICAL BOLD FRAKTUR SMALL R → LATIN SMALL LETTER R# + {120216, "s"}, // MA# ( 𝖘 → s ) MATHEMATICAL BOLD FRAKTUR SMALL S → LATIN SMALL LETTER S# + {120217, "t"}, // MA# ( 𝖙 → t ) MATHEMATICAL BOLD FRAKTUR SMALL T → LATIN SMALL LETTER T# + {120218, "u"}, // MA# ( 𝖚 → u ) MATHEMATICAL BOLD FRAKTUR SMALL U → LATIN SMALL LETTER U# + {120219, "v"}, // MA# ( 𝖛 → v ) MATHEMATICAL BOLD FRAKTUR SMALL V → LATIN SMALL LETTER V# + {120220, "w"}, // MA# ( 𝖜 → w ) MATHEMATICAL BOLD FRAKTUR SMALL W → LATIN SMALL LETTER W# + {120221, "x"}, // MA# ( 𝖝 → x ) MATHEMATICAL BOLD FRAKTUR SMALL X → LATIN SMALL LETTER X# + {120222, "y"}, // MA# ( 𝖞 → y ) MATHEMATICAL BOLD FRAKTUR SMALL Y → LATIN SMALL LETTER Y# + {120223, "z"}, // MA# ( 𝖟 → z ) MATHEMATICAL BOLD FRAKTUR SMALL Z → LATIN SMALL LETTER Z# + {120224, "A"}, // MA# ( 𝖠 → A ) MATHEMATICAL SANS-SERIF CAPITAL A → LATIN CAPITAL LETTER A# + {120225, "B"}, // MA# ( 𝖡 → B ) MATHEMATICAL SANS-SERIF CAPITAL B → LATIN CAPITAL LETTER B# + {120226, "C"}, // MA# ( 𝖢 → C ) MATHEMATICAL SANS-SERIF CAPITAL C → LATIN CAPITAL LETTER C# + {120227, "D"}, // MA# ( 𝖣 → D ) MATHEMATICAL SANS-SERIF CAPITAL D → LATIN CAPITAL LETTER D# + {120228, "E"}, // MA# ( 𝖤 → E ) MATHEMATICAL SANS-SERIF CAPITAL E → LATIN CAPITAL LETTER E# + {120229, "F"}, // MA# ( 𝖥 → F ) MATHEMATICAL SANS-SERIF CAPITAL F → LATIN CAPITAL LETTER F# + {120230, "G"}, // MA# ( 𝖦 → G ) MATHEMATICAL SANS-SERIF CAPITAL G → LATIN CAPITAL LETTER G# + {120231, "H"}, // MA# ( 𝖧 → H ) MATHEMATICAL SANS-SERIF CAPITAL H → LATIN CAPITAL LETTER H# + {120232, "l"}, // MA# ( 𝖨 → l ) MATHEMATICAL SANS-SERIF CAPITAL I → LATIN SMALL LETTER L# →I→ + {120233, "J"}, // MA# ( 𝖩 → J ) MATHEMATICAL SANS-SERIF CAPITAL J → LATIN CAPITAL LETTER J# + {120234, "K"}, // MA# ( 𝖪 → K ) MATHEMATICAL SANS-SERIF CAPITAL K → LATIN CAPITAL LETTER K# + {120235, "L"}, // MA# ( 𝖫 → L ) MATHEMATICAL SANS-SERIF CAPITAL L → LATIN CAPITAL LETTER L# + {120236, "M"}, // MA# ( 𝖬 → M ) MATHEMATICAL SANS-SERIF CAPITAL M → LATIN CAPITAL LETTER M# + {120237, "N"}, // MA# ( 𝖭 → N ) MATHEMATICAL SANS-SERIF CAPITAL N → LATIN CAPITAL LETTER N# + {120238, "O"}, // MA# ( 𝖮 → O ) MATHEMATICAL SANS-SERIF CAPITAL O → LATIN CAPITAL LETTER O# + {120239, "P"}, // MA# ( 𝖯 → P ) MATHEMATICAL SANS-SERIF CAPITAL P → LATIN CAPITAL LETTER P# + {120240, "Q"}, // MA# ( 𝖰 → Q ) MATHEMATICAL SANS-SERIF CAPITAL Q → LATIN CAPITAL LETTER Q# + {120241, "R"}, // MA# ( 𝖱 → R ) MATHEMATICAL SANS-SERIF CAPITAL R → LATIN CAPITAL LETTER R# + {120242, "S"}, // MA# ( 𝖲 → S ) MATHEMATICAL SANS-SERIF CAPITAL S → LATIN CAPITAL LETTER S# + {120243, "T"}, // MA# ( 𝖳 → T ) MATHEMATICAL SANS-SERIF CAPITAL T → LATIN CAPITAL LETTER T# + {120244, "U"}, // MA# ( 𝖴 → U ) MATHEMATICAL SANS-SERIF CAPITAL U → LATIN CAPITAL LETTER U# + {120245, "V"}, // MA# ( 𝖵 → V ) MATHEMATICAL SANS-SERIF CAPITAL V → LATIN CAPITAL LETTER V# + {120246, "W"}, // MA# ( 𝖶 → W ) MATHEMATICAL SANS-SERIF CAPITAL W → LATIN CAPITAL LETTER W# + {120247, "X"}, // MA# ( 𝖷 → X ) MATHEMATICAL SANS-SERIF CAPITAL X → LATIN CAPITAL LETTER X# + {120248, "Y"}, // MA# ( 𝖸 → Y ) MATHEMATICAL SANS-SERIF CAPITAL Y → LATIN CAPITAL LETTER Y# + {120249, "Z"}, // MA# ( 𝖹 → Z ) MATHEMATICAL SANS-SERIF CAPITAL Z → LATIN CAPITAL LETTER Z# + {120250, "a"}, // MA# ( 𝖺 → a ) MATHEMATICAL SANS-SERIF SMALL A → LATIN SMALL LETTER A# + {120251, "b"}, // MA# ( 𝖻 → b ) MATHEMATICAL SANS-SERIF SMALL B → LATIN SMALL LETTER B# + {120252, "c"}, // MA# ( 𝖼 → c ) MATHEMATICAL SANS-SERIF SMALL C → LATIN SMALL LETTER C# + {120253, "d"}, // MA# ( 𝖽 → d ) MATHEMATICAL SANS-SERIF SMALL D → LATIN SMALL LETTER D# + {120254, "e"}, // MA# ( 𝖾 → e ) MATHEMATICAL SANS-SERIF SMALL E → LATIN SMALL LETTER E# + {120255, "f"}, // MA# ( 𝖿 → f ) MATHEMATICAL SANS-SERIF SMALL F → LATIN SMALL LETTER F# + {120256, "g"}, // MA# ( 𝗀 → g ) MATHEMATICAL SANS-SERIF SMALL G → LATIN SMALL LETTER G# + {120257, "h"}, // MA# ( 𝗁 → h ) MATHEMATICAL SANS-SERIF SMALL H → LATIN SMALL LETTER H# + {120258, "i"}, // MA# ( 𝗂 → i ) MATHEMATICAL SANS-SERIF SMALL I → LATIN SMALL LETTER I# + {120259, "j"}, // MA# ( 𝗃 → j ) MATHEMATICAL SANS-SERIF SMALL J → LATIN SMALL LETTER J# + {120260, "k"}, // MA# ( 𝗄 → k ) MATHEMATICAL SANS-SERIF SMALL K → LATIN SMALL LETTER K# + {120261, "l"}, // MA# ( 𝗅 → l ) MATHEMATICAL SANS-SERIF SMALL L → LATIN SMALL LETTER L# + {120262, "rn"}, // MA# ( 𝗆 → rn ) MATHEMATICAL SANS-SERIF SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120263, "n"}, // MA# ( 𝗇 → n ) MATHEMATICAL SANS-SERIF SMALL N → LATIN SMALL LETTER N# + {120264, "o"}, // MA# ( 𝗈 → o ) MATHEMATICAL SANS-SERIF SMALL O → LATIN SMALL LETTER O# + {120265, "p"}, // MA# ( 𝗉 → p ) MATHEMATICAL SANS-SERIF SMALL P → LATIN SMALL LETTER P# + {120266, "q"}, // MA# ( 𝗊 → q ) MATHEMATICAL SANS-SERIF SMALL Q → LATIN SMALL LETTER Q# + {120267, "r"}, // MA# ( 𝗋 → r ) MATHEMATICAL SANS-SERIF SMALL R → LATIN SMALL LETTER R# + {120268, "s"}, // MA# ( 𝗌 → s ) MATHEMATICAL SANS-SERIF SMALL S → LATIN SMALL LETTER S# + {120269, "t"}, // MA# ( 𝗍 → t ) MATHEMATICAL SANS-SERIF SMALL T → LATIN SMALL LETTER T# + {120270, "u"}, // MA# ( 𝗎 → u ) MATHEMATICAL SANS-SERIF SMALL U → LATIN SMALL LETTER U# + {120271, "v"}, // MA# ( 𝗏 → v ) MATHEMATICAL SANS-SERIF SMALL V → LATIN SMALL LETTER V# + {120272, "w"}, // MA# ( 𝗐 → w ) MATHEMATICAL SANS-SERIF SMALL W → LATIN SMALL LETTER W# + {120273, "x"}, // MA# ( 𝗑 → x ) MATHEMATICAL SANS-SERIF SMALL X → LATIN SMALL LETTER X# + {120274, "y"}, // MA# ( 𝗒 → y ) MATHEMATICAL SANS-SERIF SMALL Y → LATIN SMALL LETTER Y# + {120275, "z"}, // MA# ( 𝗓 → z ) MATHEMATICAL SANS-SERIF SMALL Z → LATIN SMALL LETTER Z# + {120276, "A"}, // MA# ( 𝗔 → A ) MATHEMATICAL SANS-SERIF BOLD CAPITAL A → LATIN CAPITAL LETTER A# + {120277, "B"}, // MA# ( 𝗕 → B ) MATHEMATICAL SANS-SERIF BOLD CAPITAL B → LATIN CAPITAL LETTER B# + {120278, "C"}, // MA# ( 𝗖 → C ) MATHEMATICAL SANS-SERIF BOLD CAPITAL C → LATIN CAPITAL LETTER C# + {120279, "D"}, // MA# ( 𝗗 → D ) MATHEMATICAL SANS-SERIF BOLD CAPITAL D → LATIN CAPITAL LETTER D# + {120280, "E"}, // MA# ( 𝗘 → E ) MATHEMATICAL SANS-SERIF BOLD CAPITAL E → LATIN CAPITAL LETTER E# + {120281, "F"}, // MA# ( 𝗙 → F ) MATHEMATICAL SANS-SERIF BOLD CAPITAL F → LATIN CAPITAL LETTER F# + {120282, "G"}, // MA# ( 𝗚 → G ) MATHEMATICAL SANS-SERIF BOLD CAPITAL G → LATIN CAPITAL LETTER G# + {120283, "H"}, // MA# ( 𝗛 → H ) MATHEMATICAL SANS-SERIF BOLD CAPITAL H → LATIN CAPITAL LETTER H# + {120284, "l"}, // MA# ( 𝗜 → l ) MATHEMATICAL SANS-SERIF BOLD CAPITAL I → LATIN SMALL LETTER L# →I→ + {120285, "J"}, // MA# ( 𝗝 → J ) MATHEMATICAL SANS-SERIF BOLD CAPITAL J → LATIN CAPITAL LETTER J# + {120286, "K"}, // MA# ( 𝗞 → K ) MATHEMATICAL SANS-SERIF BOLD CAPITAL K → LATIN CAPITAL LETTER K# + {120287, "L"}, // MA# ( 𝗟 → L ) MATHEMATICAL SANS-SERIF BOLD CAPITAL L → LATIN CAPITAL LETTER L# + {120288, "M"}, // MA# ( 𝗠 → M ) MATHEMATICAL SANS-SERIF BOLD CAPITAL M → LATIN CAPITAL LETTER M# + {120289, "N"}, // MA# ( 𝗡 → N ) MATHEMATICAL SANS-SERIF BOLD CAPITAL N → LATIN CAPITAL LETTER N# + {120290, "O"}, // MA# ( 𝗢 → O ) MATHEMATICAL SANS-SERIF BOLD CAPITAL O → LATIN CAPITAL LETTER O# + {120291, "P"}, // MA# ( 𝗣 → P ) MATHEMATICAL SANS-SERIF BOLD CAPITAL P → LATIN CAPITAL LETTER P# + {120292, "Q"}, // MA# ( 𝗤 → Q ) MATHEMATICAL SANS-SERIF BOLD CAPITAL Q → LATIN CAPITAL LETTER Q# + {120293, "R"}, // MA# ( 𝗥 → R ) MATHEMATICAL SANS-SERIF BOLD CAPITAL R → LATIN CAPITAL LETTER R# + {120294, "S"}, // MA# ( 𝗦 → S ) MATHEMATICAL SANS-SERIF BOLD CAPITAL S → LATIN CAPITAL LETTER S# + {120295, "T"}, // MA# ( 𝗧 → T ) MATHEMATICAL SANS-SERIF BOLD CAPITAL T → LATIN CAPITAL LETTER T# + {120296, "U"}, // MA# ( 𝗨 → U ) MATHEMATICAL SANS-SERIF BOLD CAPITAL U → LATIN CAPITAL LETTER U# + {120297, "V"}, // MA# ( 𝗩 → V ) MATHEMATICAL SANS-SERIF BOLD CAPITAL V → LATIN CAPITAL LETTER V# + {120298, "W"}, // MA# ( 𝗪 → W ) MATHEMATICAL SANS-SERIF BOLD CAPITAL W → LATIN CAPITAL LETTER W# + {120299, "X"}, // MA# ( 𝗫 → X ) MATHEMATICAL SANS-SERIF BOLD CAPITAL X → LATIN CAPITAL LETTER X# + {120300, "Y"}, // MA# ( 𝗬 → Y ) MATHEMATICAL SANS-SERIF BOLD CAPITAL Y → LATIN CAPITAL LETTER Y# + {120301, "Z"}, // MA# ( 𝗭 → Z ) MATHEMATICAL SANS-SERIF BOLD CAPITAL Z → LATIN CAPITAL LETTER Z# + {120302, "a"}, // MA# ( 𝗮 → a ) MATHEMATICAL SANS-SERIF BOLD SMALL A → LATIN SMALL LETTER A# + {120303, "b"}, // MA# ( 𝗯 → b ) MATHEMATICAL SANS-SERIF BOLD SMALL B → LATIN SMALL LETTER B# + {120304, "c"}, // MA# ( 𝗰 → c ) MATHEMATICAL SANS-SERIF BOLD SMALL C → LATIN SMALL LETTER C# + {120305, "d"}, // MA# ( 𝗱 → d ) MATHEMATICAL SANS-SERIF BOLD SMALL D → LATIN SMALL LETTER D# + {120306, "e"}, // MA# ( 𝗲 → e ) MATHEMATICAL SANS-SERIF BOLD SMALL E → LATIN SMALL LETTER E# + {120307, "f"}, // MA# ( 𝗳 → f ) MATHEMATICAL SANS-SERIF BOLD SMALL F → LATIN SMALL LETTER F# + {120308, "g"}, // MA# ( 𝗴 → g ) MATHEMATICAL SANS-SERIF BOLD SMALL G → LATIN SMALL LETTER G# + {120309, "h"}, // MA# ( 𝗵 → h ) MATHEMATICAL SANS-SERIF BOLD SMALL H → LATIN SMALL LETTER H# + {120310, "i"}, // MA# ( 𝗶 → i ) MATHEMATICAL SANS-SERIF BOLD SMALL I → LATIN SMALL LETTER I# + {120311, "j"}, // MA# ( 𝗷 → j ) MATHEMATICAL SANS-SERIF BOLD SMALL J → LATIN SMALL LETTER J# + {120312, "k"}, // MA# ( 𝗸 → k ) MATHEMATICAL SANS-SERIF BOLD SMALL K → LATIN SMALL LETTER K# + {120313, "l"}, // MA# ( 𝗹 → l ) MATHEMATICAL SANS-SERIF BOLD SMALL L → LATIN SMALL LETTER L# + {120314, "rn"}, // MA# ( 𝗺 → rn ) MATHEMATICAL SANS-SERIF BOLD SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120315, "n"}, // MA# ( 𝗻 → n ) MATHEMATICAL SANS-SERIF BOLD SMALL N → LATIN SMALL LETTER N# + {120316, "o"}, // MA# ( 𝗼 → o ) MATHEMATICAL SANS-SERIF BOLD SMALL O → LATIN SMALL LETTER O# + {120317, "p"}, // MA# ( 𝗽 → p ) MATHEMATICAL SANS-SERIF BOLD SMALL P → LATIN SMALL LETTER P# + {120318, "q"}, // MA# ( 𝗾 → q ) MATHEMATICAL SANS-SERIF BOLD SMALL Q → LATIN SMALL LETTER Q# + {120319, "r"}, // MA# ( 𝗿 → r ) MATHEMATICAL SANS-SERIF BOLD SMALL R → LATIN SMALL LETTER R# + {120320, "s"}, // MA# ( 𝘀 → s ) MATHEMATICAL SANS-SERIF BOLD SMALL S → LATIN SMALL LETTER S# + {120321, "t"}, // MA# ( 𝘁 → t ) MATHEMATICAL SANS-SERIF BOLD SMALL T → LATIN SMALL LETTER T# + {120322, "u"}, // MA# ( 𝘂 → u ) MATHEMATICAL SANS-SERIF BOLD SMALL U → LATIN SMALL LETTER U# + {120323, "v"}, // MA# ( 𝘃 → v ) MATHEMATICAL SANS-SERIF BOLD SMALL V → LATIN SMALL LETTER V# + {120324, "w"}, // MA# ( 𝘄 → w ) MATHEMATICAL SANS-SERIF BOLD SMALL W → LATIN SMALL LETTER W# + {120325, "x"}, // MA# ( 𝘅 → x ) MATHEMATICAL SANS-SERIF BOLD SMALL X → LATIN SMALL LETTER X# + {120326, "y"}, // MA# ( 𝘆 → y ) MATHEMATICAL SANS-SERIF BOLD SMALL Y → LATIN SMALL LETTER Y# + {120327, "z"}, // MA# ( 𝘇 → z ) MATHEMATICAL SANS-SERIF BOLD SMALL Z → LATIN SMALL LETTER Z# + {120328, "A"}, // MA# ( 𝘈 → A ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL A → LATIN CAPITAL LETTER A# + {120329, "B"}, // MA# ( 𝘉 → B ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL B → LATIN CAPITAL LETTER B# + {120330, "C"}, // MA# ( 𝘊 → C ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL C → LATIN CAPITAL LETTER C# + {120331, "D"}, // MA# ( 𝘋 → D ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL D → LATIN CAPITAL LETTER D# + {120332, "E"}, // MA# ( 𝘌 → E ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL E → LATIN CAPITAL LETTER E# + {120333, "F"}, // MA# ( 𝘍 → F ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL F → LATIN CAPITAL LETTER F# + {120334, "G"}, // MA# ( 𝘎 → G ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL G → LATIN CAPITAL LETTER G# + {120335, "H"}, // MA# ( 𝘏 → H ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL H → LATIN CAPITAL LETTER H# + {120336, "l"}, // MA# ( 𝘐 → l ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL I → LATIN SMALL LETTER L# →I→ + {120337, "J"}, // MA# ( 𝘑 → J ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL J → LATIN CAPITAL LETTER J# + {120338, "K"}, // MA# ( 𝘒 → K ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL K → LATIN CAPITAL LETTER K# + {120339, "L"}, // MA# ( 𝘓 → L ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL L → LATIN CAPITAL LETTER L# + {120340, "M"}, // MA# ( 𝘔 → M ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL M → LATIN CAPITAL LETTER M# + {120341, "N"}, // MA# ( 𝘕 → N ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL N → LATIN CAPITAL LETTER N# + {120342, "O"}, // MA# ( 𝘖 → O ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL O → LATIN CAPITAL LETTER O# + {120343, "P"}, // MA# ( 𝘗 → P ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL P → LATIN CAPITAL LETTER P# + {120344, "Q"}, // MA# ( 𝘘 → Q ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL Q → LATIN CAPITAL LETTER Q# + {120345, "R"}, // MA# ( 𝘙 → R ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL R → LATIN CAPITAL LETTER R# + {120346, "S"}, // MA# ( 𝘚 → S ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL S → LATIN CAPITAL LETTER S# + {120347, "T"}, // MA# ( 𝘛 → T ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL T → LATIN CAPITAL LETTER T# + {120348, "U"}, // MA# ( 𝘜 → U ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL U → LATIN CAPITAL LETTER U# + {120349, "V"}, // MA# ( 𝘝 → V ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL V → LATIN CAPITAL LETTER V# + {120350, "W"}, // MA# ( 𝘞 → W ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL W → LATIN CAPITAL LETTER W# + {120351, "X"}, // MA# ( 𝘟 → X ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL X → LATIN CAPITAL LETTER X# + {120352, "Y"}, // MA# ( 𝘠 → Y ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL Y → LATIN CAPITAL LETTER Y# + {120353, "Z"}, // MA# ( 𝘡 → Z ) MATHEMATICAL SANS-SERIF ITALIC CAPITAL Z → LATIN CAPITAL LETTER Z# + {120354, "a"}, // MA# ( 𝘢 → a ) MATHEMATICAL SANS-SERIF ITALIC SMALL A → LATIN SMALL LETTER A# + {120355, "b"}, // MA# ( 𝘣 → b ) MATHEMATICAL SANS-SERIF ITALIC SMALL B → LATIN SMALL LETTER B# + {120356, "c"}, // MA# ( 𝘤 → c ) MATHEMATICAL SANS-SERIF ITALIC SMALL C → LATIN SMALL LETTER C# + {120357, "d"}, // MA# ( 𝘥 → d ) MATHEMATICAL SANS-SERIF ITALIC SMALL D → LATIN SMALL LETTER D# + {120358, "e"}, // MA# ( 𝘦 → e ) MATHEMATICAL SANS-SERIF ITALIC SMALL E → LATIN SMALL LETTER E# + {120359, "f"}, // MA# ( 𝘧 → f ) MATHEMATICAL SANS-SERIF ITALIC SMALL F → LATIN SMALL LETTER F# + {120360, "g"}, // MA# ( 𝘨 → g ) MATHEMATICAL SANS-SERIF ITALIC SMALL G → LATIN SMALL LETTER G# + {120361, "h"}, // MA# ( 𝘩 → h ) MATHEMATICAL SANS-SERIF ITALIC SMALL H → LATIN SMALL LETTER H# + {120362, "i"}, // MA# ( 𝘪 → i ) MATHEMATICAL SANS-SERIF ITALIC SMALL I → LATIN SMALL LETTER I# + {120363, "j"}, // MA# ( 𝘫 → j ) MATHEMATICAL SANS-SERIF ITALIC SMALL J → LATIN SMALL LETTER J# + {120364, "k"}, // MA# ( 𝘬 → k ) MATHEMATICAL SANS-SERIF ITALIC SMALL K → LATIN SMALL LETTER K# + {120365, "l"}, // MA# ( 𝘭 → l ) MATHEMATICAL SANS-SERIF ITALIC SMALL L → LATIN SMALL LETTER L# + {120366, "rn"}, // MA# ( 𝘮 → rn ) MATHEMATICAL SANS-SERIF ITALIC SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120367, "n"}, // MA# ( 𝘯 → n ) MATHEMATICAL SANS-SERIF ITALIC SMALL N → LATIN SMALL LETTER N# + {120368, "o"}, // MA# ( 𝘰 → o ) MATHEMATICAL SANS-SERIF ITALIC SMALL O → LATIN SMALL LETTER O# + {120369, "p"}, // MA# ( 𝘱 → p ) MATHEMATICAL SANS-SERIF ITALIC SMALL P → LATIN SMALL LETTER P# + {120370, "q"}, // MA# ( 𝘲 → q ) MATHEMATICAL SANS-SERIF ITALIC SMALL Q → LATIN SMALL LETTER Q# + {120371, "r"}, // MA# ( 𝘳 → r ) MATHEMATICAL SANS-SERIF ITALIC SMALL R → LATIN SMALL LETTER R# + {120372, "s"}, // MA# ( 𝘴 → s ) MATHEMATICAL SANS-SERIF ITALIC SMALL S → LATIN SMALL LETTER S# + {120373, "t"}, // MA# ( 𝘵 → t ) MATHEMATICAL SANS-SERIF ITALIC SMALL T → LATIN SMALL LETTER T# + {120374, "u"}, // MA# ( 𝘶 → u ) MATHEMATICAL SANS-SERIF ITALIC SMALL U → LATIN SMALL LETTER U# + {120375, "v"}, // MA# ( 𝘷 → v ) MATHEMATICAL SANS-SERIF ITALIC SMALL V → LATIN SMALL LETTER V# + {120376, "w"}, // MA# ( 𝘸 → w ) MATHEMATICAL SANS-SERIF ITALIC SMALL W → LATIN SMALL LETTER W# + {120377, "x"}, // MA# ( 𝘹 → x ) MATHEMATICAL SANS-SERIF ITALIC SMALL X → LATIN SMALL LETTER X# + {120378, "y"}, // MA# ( 𝘺 → y ) MATHEMATICAL SANS-SERIF ITALIC SMALL Y → LATIN SMALL LETTER Y# + {120379, "z"}, // MA# ( 𝘻 → z ) MATHEMATICAL SANS-SERIF ITALIC SMALL Z → LATIN SMALL LETTER Z# + {120380, "A"}, // MA# ( 𝘼 → A ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL A → LATIN CAPITAL LETTER A# + {120381, "B"}, // MA# ( 𝘽 → B ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL B → LATIN CAPITAL LETTER B# + {120382, "C"}, // MA# ( 𝘾 → C ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL C → LATIN CAPITAL LETTER C# + {120383, "D"}, // MA# ( 𝘿 → D ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL D → LATIN CAPITAL LETTER D# + {120384, "E"}, // MA# ( 𝙀 → E ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL E → LATIN CAPITAL LETTER E# + {120385, "F"}, // MA# ( 𝙁 → F ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL F → LATIN CAPITAL LETTER F# + {120386, "G"}, // MA# ( 𝙂 → G ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL G → LATIN CAPITAL LETTER G# + {120387, "H"}, // MA# ( 𝙃 → H ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL H → LATIN CAPITAL LETTER H# + {120388, "l"}, // MA# ( 𝙄 → l ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL I → LATIN SMALL LETTER L# →I→ + {120389, "J"}, // MA# ( 𝙅 → J ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL J → LATIN CAPITAL LETTER J# + {120390, "K"}, // MA# ( 𝙆 → K ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL K → LATIN CAPITAL LETTER K# + {120391, "L"}, // MA# ( 𝙇 → L ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL L → LATIN CAPITAL LETTER L# + {120392, "M"}, // MA# ( 𝙈 → M ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL M → LATIN CAPITAL LETTER M# + {120393, "N"}, // MA# ( 𝙉 → N ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL N → LATIN CAPITAL LETTER N# + {120394, "O"}, // MA# ( 𝙊 → O ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL O → LATIN CAPITAL LETTER O# + {120395, "P"}, // MA# ( 𝙋 → P ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL P → LATIN CAPITAL LETTER P# + {120396, "Q"}, // MA# ( 𝙌 → Q ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL Q → LATIN CAPITAL LETTER Q# + {120397, "R"}, // MA# ( 𝙍 → R ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL R → LATIN CAPITAL LETTER R# + {120398, "S"}, // MA# ( 𝙎 → S ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL S → LATIN CAPITAL LETTER S# + {120399, "T"}, // MA# ( 𝙏 → T ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL T → LATIN CAPITAL LETTER T# + {120400, "U"}, // MA# ( 𝙐 → U ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL U → LATIN CAPITAL LETTER U# + {120401, "V"}, // MA# ( 𝙑 → V ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL V → LATIN CAPITAL LETTER V# + {120402, "W"}, // MA# ( 𝙒 → W ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL W → LATIN CAPITAL LETTER W# + {120403, "X"}, // MA# ( 𝙓 → X ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL X → LATIN CAPITAL LETTER X# + {120404, "Y"}, // MA# ( 𝙔 → Y ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL Y → LATIN CAPITAL LETTER Y# + {120405, "Z"}, // MA# ( 𝙕 → Z ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL Z → LATIN CAPITAL LETTER Z# + {120406, "a"}, // MA# ( 𝙖 → a ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL A → LATIN SMALL LETTER A# + {120407, "b"}, // MA# ( 𝙗 → b ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL B → LATIN SMALL LETTER B# + {120408, "c"}, // MA# ( 𝙘 → c ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL C → LATIN SMALL LETTER C# + {120409, "d"}, // MA# ( 𝙙 → d ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL D → LATIN SMALL LETTER D# + {120410, "e"}, // MA# ( 𝙚 → e ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL E → LATIN SMALL LETTER E# + {120411, "f"}, // MA# ( 𝙛 → f ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL F → LATIN SMALL LETTER F# + {120412, "g"}, // MA# ( 𝙜 → g ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL G → LATIN SMALL LETTER G# + {120413, "h"}, // MA# ( 𝙝 → h ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL H → LATIN SMALL LETTER H# + {120414, "i"}, // MA# ( 𝙞 → i ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL I → LATIN SMALL LETTER I# + {120415, "j"}, // MA# ( 𝙟 → j ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL J → LATIN SMALL LETTER J# + {120416, "k"}, // MA# ( 𝙠 → k ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL K → LATIN SMALL LETTER K# + {120417, "l"}, // MA# ( 𝙡 → l ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL L → LATIN SMALL LETTER L# + {120418, "rn"}, // MA# ( 𝙢 → rn ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120419, "n"}, // MA# ( 𝙣 → n ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL N → LATIN SMALL LETTER N# + {120420, "o"}, // MA# ( 𝙤 → o ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL O → LATIN SMALL LETTER O# + {120421, "p"}, // MA# ( 𝙥 → p ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL P → LATIN SMALL LETTER P# + {120422, "q"}, // MA# ( 𝙦 → q ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL Q → LATIN SMALL LETTER Q# + {120423, "r"}, // MA# ( 𝙧 → r ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL R → LATIN SMALL LETTER R# + {120424, "s"}, // MA# ( 𝙨 → s ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL S → LATIN SMALL LETTER S# + {120425, "t"}, // MA# ( 𝙩 → t ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL T → LATIN SMALL LETTER T# + {120426, "u"}, // MA# ( 𝙪 → u ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL U → LATIN SMALL LETTER U# + {120427, "v"}, // MA# ( 𝙫 → v ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL V → LATIN SMALL LETTER V# + {120428, "w"}, // MA# ( 𝙬 → w ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL W → LATIN SMALL LETTER W# + {120429, "x"}, // MA# ( 𝙭 → x ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL X → LATIN SMALL LETTER X# + {120430, "y"}, // MA# ( 𝙮 → y ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL Y → LATIN SMALL LETTER Y# + {120431, "z"}, // MA# ( 𝙯 → z ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL Z → LATIN SMALL LETTER Z# + {120432, "A"}, // MA# ( 𝙰 → A ) MATHEMATICAL MONOSPACE CAPITAL A → LATIN CAPITAL LETTER A# + {120433, "B"}, // MA# ( 𝙱 → B ) MATHEMATICAL MONOSPACE CAPITAL B → LATIN CAPITAL LETTER B# + {120434, "C"}, // MA# ( 𝙲 → C ) MATHEMATICAL MONOSPACE CAPITAL C → LATIN CAPITAL LETTER C# + {120435, "D"}, // MA# ( 𝙳 → D ) MATHEMATICAL MONOSPACE CAPITAL D → LATIN CAPITAL LETTER D# + {120436, "E"}, // MA# ( 𝙴 → E ) MATHEMATICAL MONOSPACE CAPITAL E → LATIN CAPITAL LETTER E# + {120437, "F"}, // MA# ( 𝙵 → F ) MATHEMATICAL MONOSPACE CAPITAL F → LATIN CAPITAL LETTER F# + {120438, "G"}, // MA# ( 𝙶 → G ) MATHEMATICAL MONOSPACE CAPITAL G → LATIN CAPITAL LETTER G# + {120439, "H"}, // MA# ( 𝙷 → H ) MATHEMATICAL MONOSPACE CAPITAL H → LATIN CAPITAL LETTER H# + {120440, "l"}, // MA# ( 𝙸 → l ) MATHEMATICAL MONOSPACE CAPITAL I → LATIN SMALL LETTER L# →I→ + {120441, "J"}, // MA# ( 𝙹 → J ) MATHEMATICAL MONOSPACE CAPITAL J → LATIN CAPITAL LETTER J# + {120442, "K"}, // MA# ( 𝙺 → K ) MATHEMATICAL MONOSPACE CAPITAL K → LATIN CAPITAL LETTER K# + {120443, "L"}, // MA# ( 𝙻 → L ) MATHEMATICAL MONOSPACE CAPITAL L → LATIN CAPITAL LETTER L# + {120444, "M"}, // MA# ( 𝙼 → M ) MATHEMATICAL MONOSPACE CAPITAL M → LATIN CAPITAL LETTER M# + {120445, "N"}, // MA# ( 𝙽 → N ) MATHEMATICAL MONOSPACE CAPITAL N → LATIN CAPITAL LETTER N# + {120446, "O"}, // MA# ( 𝙾 → O ) MATHEMATICAL MONOSPACE CAPITAL O → LATIN CAPITAL LETTER O# + {120447, "P"}, // MA# ( 𝙿 → P ) MATHEMATICAL MONOSPACE CAPITAL P → LATIN CAPITAL LETTER P# + {120448, "Q"}, // MA# ( 𝚀 → Q ) MATHEMATICAL MONOSPACE CAPITAL Q → LATIN CAPITAL LETTER Q# + {120449, "R"}, // MA# ( 𝚁 → R ) MATHEMATICAL MONOSPACE CAPITAL R → LATIN CAPITAL LETTER R# + {120450, "S"}, // MA# ( 𝚂 → S ) MATHEMATICAL MONOSPACE CAPITAL S → LATIN CAPITAL LETTER S# + {120451, "T"}, // MA# ( 𝚃 → T ) MATHEMATICAL MONOSPACE CAPITAL T → LATIN CAPITAL LETTER T# + {120452, "U"}, // MA# ( 𝚄 → U ) MATHEMATICAL MONOSPACE CAPITAL U → LATIN CAPITAL LETTER U# + {120453, "V"}, // MA# ( 𝚅 → V ) MATHEMATICAL MONOSPACE CAPITAL V → LATIN CAPITAL LETTER V# + {120454, "W"}, // MA# ( 𝚆 → W ) MATHEMATICAL MONOSPACE CAPITAL W → LATIN CAPITAL LETTER W# + {120455, "X"}, // MA# ( 𝚇 → X ) MATHEMATICAL MONOSPACE CAPITAL X → LATIN CAPITAL LETTER X# + {120456, "Y"}, // MA# ( 𝚈 → Y ) MATHEMATICAL MONOSPACE CAPITAL Y → LATIN CAPITAL LETTER Y# + {120457, "Z"}, // MA# ( 𝚉 → Z ) MATHEMATICAL MONOSPACE CAPITAL Z → LATIN CAPITAL LETTER Z# + {120458, "a"}, // MA# ( 𝚊 → a ) MATHEMATICAL MONOSPACE SMALL A → LATIN SMALL LETTER A# + {120459, "b"}, // MA# ( 𝚋 → b ) MATHEMATICAL MONOSPACE SMALL B → LATIN SMALL LETTER B# + {120460, "c"}, // MA# ( 𝚌 → c ) MATHEMATICAL MONOSPACE SMALL C → LATIN SMALL LETTER C# + {120461, "d"}, // MA# ( 𝚍 → d ) MATHEMATICAL MONOSPACE SMALL D → LATIN SMALL LETTER D# + {120462, "e"}, // MA# ( 𝚎 → e ) MATHEMATICAL MONOSPACE SMALL E → LATIN SMALL LETTER E# + {120463, "f"}, // MA# ( 𝚏 → f ) MATHEMATICAL MONOSPACE SMALL F → LATIN SMALL LETTER F# + {120464, "g"}, // MA# ( 𝚐 → g ) MATHEMATICAL MONOSPACE SMALL G → LATIN SMALL LETTER G# + {120465, "h"}, // MA# ( 𝚑 → h ) MATHEMATICAL MONOSPACE SMALL H → LATIN SMALL LETTER H# + {120466, "i"}, // MA# ( 𝚒 → i ) MATHEMATICAL MONOSPACE SMALL I → LATIN SMALL LETTER I# + {120467, "j"}, // MA# ( 𝚓 → j ) MATHEMATICAL MONOSPACE SMALL J → LATIN SMALL LETTER J# + {120468, "k"}, // MA# ( 𝚔 → k ) MATHEMATICAL MONOSPACE SMALL K → LATIN SMALL LETTER K# + {120469, "l"}, // MA# ( 𝚕 → l ) MATHEMATICAL MONOSPACE SMALL L → LATIN SMALL LETTER L# + {120470, "rn"}, // MA# ( 𝚖 → rn ) MATHEMATICAL MONOSPACE SMALL M → LATIN SMALL LETTER R, LATIN SMALL LETTER N# →m→ + {120471, "n"}, // MA# ( 𝚗 → n ) MATHEMATICAL MONOSPACE SMALL N → LATIN SMALL LETTER N# + {120472, "o"}, // MA# ( 𝚘 → o ) MATHEMATICAL MONOSPACE SMALL O → LATIN SMALL LETTER O# + {120473, "p"}, // MA# ( 𝚙 → p ) MATHEMATICAL MONOSPACE SMALL P → LATIN SMALL LETTER P# + {120474, "q"}, // MA# ( 𝚚 → q ) MATHEMATICAL MONOSPACE SMALL Q → LATIN SMALL LETTER Q# + {120475, "r"}, // MA# ( 𝚛 → r ) MATHEMATICAL MONOSPACE SMALL R → LATIN SMALL LETTER R# + {120476, "s"}, // MA# ( 𝚜 → s ) MATHEMATICAL MONOSPACE SMALL S → LATIN SMALL LETTER S# + {120477, "t"}, // MA# ( 𝚝 → t ) MATHEMATICAL MONOSPACE SMALL T → LATIN SMALL LETTER T# + {120478, "u"}, // MA# ( 𝚞 → u ) MATHEMATICAL MONOSPACE SMALL U → LATIN SMALL LETTER U# + {120479, "v"}, // MA# ( 𝚟 → v ) MATHEMATICAL MONOSPACE SMALL V → LATIN SMALL LETTER V# + {120480, "w"}, // MA# ( 𝚠 → w ) MATHEMATICAL MONOSPACE SMALL W → LATIN SMALL LETTER W# + {120481, "x"}, // MA# ( 𝚡 → x ) MATHEMATICAL MONOSPACE SMALL X → LATIN SMALL LETTER X# + {120482, "y"}, // MA# ( 𝚢 → y ) MATHEMATICAL MONOSPACE SMALL Y → LATIN SMALL LETTER Y# + {120483, "z"}, // MA# ( 𝚣 → z ) MATHEMATICAL MONOSPACE SMALL Z → LATIN SMALL LETTER Z# + {120484, "i"}, // MA# ( 𝚤 → i ) MATHEMATICAL ITALIC SMALL DOTLESS I → LATIN SMALL LETTER I# →ı→ + {120488, "A"}, // MA# ( 𝚨 → A ) MATHEMATICAL BOLD CAPITAL ALPHA → LATIN CAPITAL LETTER A# →𝐀→ + {120489, "B"}, // MA# ( 𝚩 → B ) MATHEMATICAL BOLD CAPITAL BETA → LATIN CAPITAL LETTER B# →Β→ + {120492, "E"}, // MA# ( 𝚬 → E ) MATHEMATICAL BOLD CAPITAL EPSILON → LATIN CAPITAL LETTER E# →𝐄→ + {120493, "Z"}, // MA# ( 𝚭 → Z ) MATHEMATICAL BOLD CAPITAL ZETA → LATIN CAPITAL LETTER Z# →Ζ→ + {120494, "H"}, // MA# ( 𝚮 → H ) MATHEMATICAL BOLD CAPITAL ETA → LATIN CAPITAL LETTER H# →Η→ + {120496, "l"}, // MA# ( 𝚰 → l ) MATHEMATICAL BOLD CAPITAL IOTA → LATIN SMALL LETTER L# →Ι→ + {120497, "K"}, // MA# ( 𝚱 → K ) MATHEMATICAL BOLD CAPITAL KAPPA → LATIN CAPITAL LETTER K# →Κ→ + {120499, "M"}, // MA# ( 𝚳 → M ) MATHEMATICAL BOLD CAPITAL MU → LATIN CAPITAL LETTER M# →𝐌→ + {120500, "N"}, // MA# ( 𝚴 → N ) MATHEMATICAL BOLD CAPITAL NU → LATIN CAPITAL LETTER N# →𝐍→ + {120502, "O"}, // MA# ( 𝚶 → O ) MATHEMATICAL BOLD CAPITAL OMICRON → LATIN CAPITAL LETTER O# →𝐎→ + {120504, "P"}, // MA# ( 𝚸 → P ) MATHEMATICAL BOLD CAPITAL RHO → LATIN CAPITAL LETTER P# →𝐏→ + {120507, "T"}, // MA# ( 𝚻 → T ) MATHEMATICAL BOLD CAPITAL TAU → LATIN CAPITAL LETTER T# →Τ→ + {120508, "Y"}, // MA# ( 𝚼 → Y ) MATHEMATICAL BOLD CAPITAL UPSILON → LATIN CAPITAL LETTER Y# →Υ→ + {120510, "X"}, // MA# ( 𝚾 → X ) MATHEMATICAL BOLD CAPITAL CHI → LATIN CAPITAL LETTER X# →Χ→ + {120514, "a"}, // MA# ( 𝛂 → a ) MATHEMATICAL BOLD SMALL ALPHA → LATIN SMALL LETTER A# →α→ + {120516, "y"}, // MA# ( 𝛄 → y ) MATHEMATICAL BOLD SMALL GAMMA → LATIN SMALL LETTER Y# →γ→ + {120522, "i"}, // MA# ( 𝛊 → i ) MATHEMATICAL BOLD SMALL IOTA → LATIN SMALL LETTER I# →ι→ + {120526, "v"}, // MA# ( 𝛎 → v ) MATHEMATICAL BOLD SMALL NU → LATIN SMALL LETTER V# →ν→ + {120528, "o"}, // MA# ( 𝛐 → o ) MATHEMATICAL BOLD SMALL OMICRON → LATIN SMALL LETTER O# →𝐨→ + {120530, "p"}, // MA# ( 𝛒 → p ) MATHEMATICAL BOLD SMALL RHO → LATIN SMALL LETTER P# →ρ→ + {120532, "o"}, // MA# ( 𝛔 → o ) MATHEMATICAL BOLD SMALL SIGMA → LATIN SMALL LETTER O# →σ→ + {120534, "u"}, // MA# ( 𝛖 → u ) MATHEMATICAL BOLD SMALL UPSILON → LATIN SMALL LETTER U# →υ→→ʋ→ + {120544, "p"}, // MA# ( 𝛠 → p ) MATHEMATICAL BOLD RHO SYMBOL → LATIN SMALL LETTER P# →ρ→ + {120546, "A"}, // MA# ( 𝛢 → A ) MATHEMATICAL ITALIC CAPITAL ALPHA → LATIN CAPITAL LETTER A# →Α→ + {120547, "B"}, // MA# ( 𝛣 → B ) MATHEMATICAL ITALIC CAPITAL BETA → LATIN CAPITAL LETTER B# →Β→ + {120550, "E"}, // MA# ( 𝛦 → E ) MATHEMATICAL ITALIC CAPITAL EPSILON → LATIN CAPITAL LETTER E# →Ε→ + {120551, "Z"}, // MA# ( 𝛧 → Z ) MATHEMATICAL ITALIC CAPITAL ZETA → LATIN CAPITAL LETTER Z# →𝑍→ + {120552, "H"}, // MA# ( 𝛨 → H ) MATHEMATICAL ITALIC CAPITAL ETA → LATIN CAPITAL LETTER H# →Η→ + {120554, "l"}, // MA# ( 𝛪 → l ) MATHEMATICAL ITALIC CAPITAL IOTA → LATIN SMALL LETTER L# →Ι→ + {120555, "K"}, // MA# ( 𝛫 → K ) MATHEMATICAL ITALIC CAPITAL KAPPA → LATIN CAPITAL LETTER K# →𝐾→ + {120557, "M"}, // MA# ( 𝛭 → M ) MATHEMATICAL ITALIC CAPITAL MU → LATIN CAPITAL LETTER M# →𝑀→ + {120558, "N"}, // MA# ( 𝛮 → N ) MATHEMATICAL ITALIC CAPITAL NU → LATIN CAPITAL LETTER N# →𝑁→ + {120560, "O"}, // MA# ( 𝛰 → O ) MATHEMATICAL ITALIC CAPITAL OMICRON → LATIN CAPITAL LETTER O# →𝑂→ + {120562, "P"}, // MA# ( 𝛲 → P ) MATHEMATICAL ITALIC CAPITAL RHO → LATIN CAPITAL LETTER P# →Ρ→ + {120565, "T"}, // MA# ( 𝛵 → T ) MATHEMATICAL ITALIC CAPITAL TAU → LATIN CAPITAL LETTER T# →Τ→ + {120566, "Y"}, // MA# ( 𝛶 → Y ) MATHEMATICAL ITALIC CAPITAL UPSILON → LATIN CAPITAL LETTER Y# →Υ→ + {120568, "X"}, // MA# ( 𝛸 → X ) MATHEMATICAL ITALIC CAPITAL CHI → LATIN CAPITAL LETTER X# →Χ→ + {120572, "a"}, // MA# ( 𝛼 → a ) MATHEMATICAL ITALIC SMALL ALPHA → LATIN SMALL LETTER A# →α→ + {120574, "y"}, // MA# ( 𝛾 → y ) MATHEMATICAL ITALIC SMALL GAMMA → LATIN SMALL LETTER Y# →γ→ + {120580, "i"}, // MA# ( 𝜄 → i ) MATHEMATICAL ITALIC SMALL IOTA → LATIN SMALL LETTER I# →ι→ + {120584, "v"}, // MA# ( 𝜈 → v ) MATHEMATICAL ITALIC SMALL NU → LATIN SMALL LETTER V# →ν→ + {120586, "o"}, // MA# ( 𝜊 → o ) MATHEMATICAL ITALIC SMALL OMICRON → LATIN SMALL LETTER O# →𝑜→ + {120588, "p"}, // MA# ( 𝜌 → p ) MATHEMATICAL ITALIC SMALL RHO → LATIN SMALL LETTER P# →ρ→ + {120590, "o"}, // MA# ( 𝜎 → o ) MATHEMATICAL ITALIC SMALL SIGMA → LATIN SMALL LETTER O# →σ→ + {120592, "u"}, // MA# ( 𝜐 → u ) MATHEMATICAL ITALIC SMALL UPSILON → LATIN SMALL LETTER U# →υ→→ʋ→ + {120602, "p"}, // MA# ( 𝜚 → p ) MATHEMATICAL ITALIC RHO SYMBOL → LATIN SMALL LETTER P# →ρ→ + {120604, "A"}, // MA# ( 𝜜 → A ) MATHEMATICAL BOLD ITALIC CAPITAL ALPHA → LATIN CAPITAL LETTER A# →Α→ + {120605, "B"}, // MA# ( 𝜝 → B ) MATHEMATICAL BOLD ITALIC CAPITAL BETA → LATIN CAPITAL LETTER B# →Β→ + {120608, "E"}, // MA# ( 𝜠 → E ) MATHEMATICAL BOLD ITALIC CAPITAL EPSILON → LATIN CAPITAL LETTER E# →Ε→ + {120609, "Z"}, // MA# ( 𝜡 → Z ) MATHEMATICAL BOLD ITALIC CAPITAL ZETA → LATIN CAPITAL LETTER Z# →Ζ→ + {120610, "H"}, // MA# ( 𝜢 → H ) MATHEMATICAL BOLD ITALIC CAPITAL ETA → LATIN CAPITAL LETTER H# →𝑯→ + {120612, "l"}, // MA# ( 𝜤 → l ) MATHEMATICAL BOLD ITALIC CAPITAL IOTA → LATIN SMALL LETTER L# →Ι→ + {120613, "K"}, // MA# ( 𝜥 → K ) MATHEMATICAL BOLD ITALIC CAPITAL KAPPA → LATIN CAPITAL LETTER K# →𝑲→ + {120615, "M"}, // MA# ( 𝜧 → M ) MATHEMATICAL BOLD ITALIC CAPITAL MU → LATIN CAPITAL LETTER M# →𝑴→ + {120616, "N"}, // MA# ( 𝜨 → N ) MATHEMATICAL BOLD ITALIC CAPITAL NU → LATIN CAPITAL LETTER N# →𝑵→ + {120618, "O"}, // MA# ( 𝜪 → O ) MATHEMATICAL BOLD ITALIC CAPITAL OMICRON → LATIN CAPITAL LETTER O# →𝑶→ + {120620, "P"}, // MA# ( 𝜬 → P ) MATHEMATICAL BOLD ITALIC CAPITAL RHO → LATIN CAPITAL LETTER P# →Ρ→ + {120623, "T"}, // MA# ( 𝜯 → T ) MATHEMATICAL BOLD ITALIC CAPITAL TAU → LATIN CAPITAL LETTER T# →Τ→ + {120624, "Y"}, // MA# ( 𝜰 → Y ) MATHEMATICAL BOLD ITALIC CAPITAL UPSILON → LATIN CAPITAL LETTER Y# →Υ→ + {120626, "X"}, // MA# ( 𝜲 → X ) MATHEMATICAL BOLD ITALIC CAPITAL CHI → LATIN CAPITAL LETTER X# →𝑿→ + {120630, "a"}, // MA# ( 𝜶 → a ) MATHEMATICAL BOLD ITALIC SMALL ALPHA → LATIN SMALL LETTER A# →α→ + {120632, "y"}, // MA# ( 𝜸 → y ) MATHEMATICAL BOLD ITALIC SMALL GAMMA → LATIN SMALL LETTER Y# →γ→ + {120638, "i"}, // MA# ( 𝜾 → i ) MATHEMATICAL BOLD ITALIC SMALL IOTA → LATIN SMALL LETTER I# →ι→ + {120642, "v"}, // MA# ( 𝝂 → v ) MATHEMATICAL BOLD ITALIC SMALL NU → LATIN SMALL LETTER V# →ν→ + {120644, "o"}, // MA# ( 𝝄 → o ) MATHEMATICAL BOLD ITALIC SMALL OMICRON → LATIN SMALL LETTER O# →𝒐→ + {120646, "p"}, // MA# ( 𝝆 → p ) MATHEMATICAL BOLD ITALIC SMALL RHO → LATIN SMALL LETTER P# →ρ→ + {120648, "o"}, // MA# ( 𝝈 → o ) MATHEMATICAL BOLD ITALIC SMALL SIGMA → LATIN SMALL LETTER O# →σ→ + {120650, "u"}, // MA# ( 𝝊 → u ) MATHEMATICAL BOLD ITALIC SMALL UPSILON → LATIN SMALL LETTER U# →υ→→ʋ→ + {120660, "p"}, // MA# ( 𝝔 → p ) MATHEMATICAL BOLD ITALIC RHO SYMBOL → LATIN SMALL LETTER P# →ρ→ + {120662, "A"}, // MA# ( 𝝖 → A ) MATHEMATICAL SANS-SERIF BOLD CAPITAL ALPHA → LATIN CAPITAL LETTER A# →Α→ + {120663, "B"}, // MA# ( 𝝗 → B ) MATHEMATICAL SANS-SERIF BOLD CAPITAL BETA → LATIN CAPITAL LETTER B# →Β→ + {120666, "E"}, // MA# ( 𝝚 → E ) MATHEMATICAL SANS-SERIF BOLD CAPITAL EPSILON → LATIN CAPITAL LETTER E# →Ε→ + {120667, "Z"}, // MA# ( 𝝛 → Z ) MATHEMATICAL SANS-SERIF BOLD CAPITAL ZETA → LATIN CAPITAL LETTER Z# →Ζ→ + {120668, "H"}, // MA# ( 𝝜 → H ) MATHEMATICAL SANS-SERIF BOLD CAPITAL ETA → LATIN CAPITAL LETTER H# →Η→ + {120670, "l"}, // MA# ( 𝝞 → l ) MATHEMATICAL SANS-SERIF BOLD CAPITAL IOTA → LATIN SMALL LETTER L# →Ι→ + {120671, "K"}, // MA# ( 𝝟 → K ) MATHEMATICAL SANS-SERIF BOLD CAPITAL KAPPA → LATIN CAPITAL LETTER K# →Κ→ + {120673, "M"}, // MA# ( 𝝡 → M ) MATHEMATICAL SANS-SERIF BOLD CAPITAL MU → LATIN CAPITAL LETTER M# →Μ→ + {120674, "N"}, // MA# ( 𝝢 → N ) MATHEMATICAL SANS-SERIF BOLD CAPITAL NU → LATIN CAPITAL LETTER N# →Ν→ + {120676, "O"}, // MA# ( 𝝤 → O ) MATHEMATICAL SANS-SERIF BOLD CAPITAL OMICRON → LATIN CAPITAL LETTER O# →Ο→ + {120678, "P"}, // MA# ( 𝝦 → P ) MATHEMATICAL SANS-SERIF BOLD CAPITAL RHO → LATIN CAPITAL LETTER P# →Ρ→ + {120681, "T"}, // MA# ( 𝝩 → T ) MATHEMATICAL SANS-SERIF BOLD CAPITAL TAU → LATIN CAPITAL LETTER T# →Τ→ + {120682, "Y"}, // MA# ( 𝝪 → Y ) MATHEMATICAL SANS-SERIF BOLD CAPITAL UPSILON → LATIN CAPITAL LETTER Y# →Υ→ + {120684, "X"}, // MA# ( 𝝬 → X ) MATHEMATICAL SANS-SERIF BOLD CAPITAL CHI → LATIN CAPITAL LETTER X# →Χ→ + {120688, "a"}, // MA# ( 𝝰 → a ) MATHEMATICAL SANS-SERIF BOLD SMALL ALPHA → LATIN SMALL LETTER A# →α→ + {120690, "y"}, // MA# ( 𝝲 → y ) MATHEMATICAL SANS-SERIF BOLD SMALL GAMMA → LATIN SMALL LETTER Y# →γ→ + {120696, "i"}, // MA# ( 𝝸 → i ) MATHEMATICAL SANS-SERIF BOLD SMALL IOTA → LATIN SMALL LETTER I# →ι→ + {120700, "v"}, // MA# ( 𝝼 → v ) MATHEMATICAL SANS-SERIF BOLD SMALL NU → LATIN SMALL LETTER V# →ν→ + {120702, "o"}, // MA# ( 𝝾 → o ) MATHEMATICAL SANS-SERIF BOLD SMALL OMICRON → LATIN SMALL LETTER O# →ο→ + {120704, "p"}, // MA# ( 𝞀 → p ) MATHEMATICAL SANS-SERIF BOLD SMALL RHO → LATIN SMALL LETTER P# →ρ→ + {120706, "o"}, // MA# ( 𝞂 → o ) MATHEMATICAL SANS-SERIF BOLD SMALL SIGMA → LATIN SMALL LETTER O# →σ→ + {120708, "u"}, // MA# ( 𝞄 → u ) MATHEMATICAL SANS-SERIF BOLD SMALL UPSILON → LATIN SMALL LETTER U# →υ→→ʋ→ + {120718, "p"}, // MA# ( 𝞎 → p ) MATHEMATICAL SANS-SERIF BOLD RHO SYMBOL → LATIN SMALL LETTER P# →ρ→ + {120720, "A"}, // MA# ( 𝞐 → A ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL ALPHA → LATIN CAPITAL LETTER A# →Α→ + {120721, "B"}, // MA# ( 𝞑 → B ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL BETA → LATIN CAPITAL LETTER B# →Β→ + {120724, "E"}, // MA# ( 𝞔 → E ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL EPSILON → LATIN CAPITAL LETTER E# →Ε→ + {120725, "Z"}, // MA# ( 𝞕 → Z ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL ZETA → LATIN CAPITAL LETTER Z# →Ζ→ + {120726, "H"}, // MA# ( 𝞖 → H ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL ETA → LATIN CAPITAL LETTER H# →Η→ + {120728, "l"}, // MA# ( 𝞘 → l ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL IOTA → LATIN SMALL LETTER L# →Ι→ + {120729, "K"}, // MA# ( 𝞙 → K ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL KAPPA → LATIN CAPITAL LETTER K# →Κ→ + {120731, "M"}, // MA# ( 𝞛 → M ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL MU → LATIN CAPITAL LETTER M# →Μ→ + {120732, "N"}, // MA# ( 𝞜 → N ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL NU → LATIN CAPITAL LETTER N# →Ν→ + {120734, "O"}, // MA# ( 𝞞 → O ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL OMICRON → LATIN CAPITAL LETTER O# →Ο→ + {120736, "P"}, // MA# ( 𝞠 → P ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL RHO → LATIN CAPITAL LETTER P# →Ρ→ + {120739, "T"}, // MA# ( 𝞣 → T ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL TAU → LATIN CAPITAL LETTER T# →Τ→ + {120740, "Y"}, // MA# ( 𝞤 → Y ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL UPSILON → LATIN CAPITAL LETTER Y# →Υ→ + {120742, "X"}, // MA# ( 𝞦 → X ) MATHEMATICAL SANS-SERIF BOLD ITALIC CAPITAL CHI → LATIN CAPITAL LETTER X# →Χ→ + {120746, "a"}, // MA# ( 𝞪 → a ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL ALPHA → LATIN SMALL LETTER A# →α→ + {120748, "y"}, // MA# ( 𝞬 → y ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL GAMMA → LATIN SMALL LETTER Y# →γ→ + {120754, "i"}, // MA# ( 𝞲 → i ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL IOTA → LATIN SMALL LETTER I# →ι→ + {120758, "v"}, // MA# ( 𝞶 → v ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL NU → LATIN SMALL LETTER V# →ν→ + {120760, "o"}, // MA# ( 𝞸 → o ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL OMICRON → LATIN SMALL LETTER O# →ο→ + {120762, "p"}, // MA# ( 𝞺 → p ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL RHO → LATIN SMALL LETTER P# →ρ→ + {120764, "o"}, // MA# ( 𝞼 → o ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL SIGMA → LATIN SMALL LETTER O# →σ→ + {120766, "u"}, // MA# ( 𝞾 → u ) MATHEMATICAL SANS-SERIF BOLD ITALIC SMALL UPSILON → LATIN SMALL LETTER U# →υ→→ʋ→ + {120776, "p"}, // MA# ( 𝟈 → p ) MATHEMATICAL SANS-SERIF BOLD ITALIC RHO SYMBOL → LATIN SMALL LETTER P# →ρ→ + {120778, "F"}, // MA# ( 𝟊 → F ) MATHEMATICAL BOLD CAPITAL DIGAMMA → LATIN CAPITAL LETTER F# →Ϝ→ + {120782, "O"}, // MA# ( 𝟎 → O ) MATHEMATICAL BOLD DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {120783, "l"}, // MA# ( 𝟏 → l ) MATHEMATICAL BOLD DIGIT ONE → LATIN SMALL LETTER L# →1→ + {120784, "2"}, // MA# ( 𝟐 → 2 ) MATHEMATICAL BOLD DIGIT TWO → DIGIT TWO# + {120785, "3"}, // MA# ( 𝟑 → 3 ) MATHEMATICAL BOLD DIGIT THREE → DIGIT THREE# + {120786, "4"}, // MA# ( 𝟒 → 4 ) MATHEMATICAL BOLD DIGIT FOUR → DIGIT FOUR# + {120787, "5"}, // MA# ( 𝟓 → 5 ) MATHEMATICAL BOLD DIGIT FIVE → DIGIT FIVE# + {120788, "6"}, // MA# ( 𝟔 → 6 ) MATHEMATICAL BOLD DIGIT SIX → DIGIT SIX# + {120789, "7"}, // MA# ( 𝟕 → 7 ) MATHEMATICAL BOLD DIGIT SEVEN → DIGIT SEVEN# + {120790, "8"}, // MA# ( 𝟖 → 8 ) MATHEMATICAL BOLD DIGIT EIGHT → DIGIT EIGHT# + {120791, "9"}, // MA# ( 𝟗 → 9 ) MATHEMATICAL BOLD DIGIT NINE → DIGIT NINE# + {120792, "O"}, // MA# ( 𝟘 → O ) MATHEMATICAL DOUBLE-STRUCK DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {120793, "l"}, // MA# ( 𝟙 → l ) MATHEMATICAL DOUBLE-STRUCK DIGIT ONE → LATIN SMALL LETTER L# →1→ + {120794, "2"}, // MA# ( 𝟚 → 2 ) MATHEMATICAL DOUBLE-STRUCK DIGIT TWO → DIGIT TWO# + {120795, "3"}, // MA# ( 𝟛 → 3 ) MATHEMATICAL DOUBLE-STRUCK DIGIT THREE → DIGIT THREE# + {120796, "4"}, // MA# ( 𝟜 → 4 ) MATHEMATICAL DOUBLE-STRUCK DIGIT FOUR → DIGIT FOUR# + {120797, "5"}, // MA# ( 𝟝 → 5 ) MATHEMATICAL DOUBLE-STRUCK DIGIT FIVE → DIGIT FIVE# + {120798, "6"}, // MA# ( 𝟞 → 6 ) MATHEMATICAL DOUBLE-STRUCK DIGIT SIX → DIGIT SIX# + {120799, "7"}, // MA# ( 𝟟 → 7 ) MATHEMATICAL DOUBLE-STRUCK DIGIT SEVEN → DIGIT SEVEN# + {120800, "8"}, // MA# ( 𝟠 → 8 ) MATHEMATICAL DOUBLE-STRUCK DIGIT EIGHT → DIGIT EIGHT# + {120801, "9"}, // MA# ( 𝟡 → 9 ) MATHEMATICAL DOUBLE-STRUCK DIGIT NINE → DIGIT NINE# + {120802, "O"}, // MA# ( 𝟢 → O ) MATHEMATICAL SANS-SERIF DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {120803, "l"}, // MA# ( 𝟣 → l ) MATHEMATICAL SANS-SERIF DIGIT ONE → LATIN SMALL LETTER L# →1→ + {120804, "2"}, // MA# ( 𝟤 → 2 ) MATHEMATICAL SANS-SERIF DIGIT TWO → DIGIT TWO# + {120805, "3"}, // MA# ( 𝟥 → 3 ) MATHEMATICAL SANS-SERIF DIGIT THREE → DIGIT THREE# + {120806, "4"}, // MA# ( 𝟦 → 4 ) MATHEMATICAL SANS-SERIF DIGIT FOUR → DIGIT FOUR# + {120807, "5"}, // MA# ( 𝟧 → 5 ) MATHEMATICAL SANS-SERIF DIGIT FIVE → DIGIT FIVE# + {120808, "6"}, // MA# ( 𝟨 → 6 ) MATHEMATICAL SANS-SERIF DIGIT SIX → DIGIT SIX# + {120809, "7"}, // MA# ( 𝟩 → 7 ) MATHEMATICAL SANS-SERIF DIGIT SEVEN → DIGIT SEVEN# + {120810, "8"}, // MA# ( 𝟪 → 8 ) MATHEMATICAL SANS-SERIF DIGIT EIGHT → DIGIT EIGHT# + {120811, "9"}, // MA# ( 𝟫 → 9 ) MATHEMATICAL SANS-SERIF DIGIT NINE → DIGIT NINE# + {120812, "O"}, // MA# ( 𝟬 → O ) MATHEMATICAL SANS-SERIF BOLD DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {120813, "l"}, // MA# ( 𝟭 → l ) MATHEMATICAL SANS-SERIF BOLD DIGIT ONE → LATIN SMALL LETTER L# →1→ + {120814, "2"}, // MA# ( 𝟮 → 2 ) MATHEMATICAL SANS-SERIF BOLD DIGIT TWO → DIGIT TWO# + {120815, "3"}, // MA# ( 𝟯 → 3 ) MATHEMATICAL SANS-SERIF BOLD DIGIT THREE → DIGIT THREE# + {120816, "4"}, // MA# ( 𝟰 → 4 ) MATHEMATICAL SANS-SERIF BOLD DIGIT FOUR → DIGIT FOUR# + {120817, "5"}, // MA# ( 𝟱 → 5 ) MATHEMATICAL SANS-SERIF BOLD DIGIT FIVE → DIGIT FIVE# + {120818, "6"}, // MA# ( 𝟲 → 6 ) MATHEMATICAL SANS-SERIF BOLD DIGIT SIX → DIGIT SIX# + {120819, "7"}, // MA# ( 𝟳 → 7 ) MATHEMATICAL SANS-SERIF BOLD DIGIT SEVEN → DIGIT SEVEN# + {120820, "8"}, // MA# ( 𝟴 → 8 ) MATHEMATICAL SANS-SERIF BOLD DIGIT EIGHT → DIGIT EIGHT# + {120821, "9"}, // MA# ( 𝟵 → 9 ) MATHEMATICAL SANS-SERIF BOLD DIGIT NINE → DIGIT NINE# + {120822, "O"}, // MA# ( 𝟶 → O ) MATHEMATICAL MONOSPACE DIGIT ZERO → LATIN CAPITAL LETTER O# →0→ + {120823, "l"}, // MA# ( 𝟷 → l ) MATHEMATICAL MONOSPACE DIGIT ONE → LATIN SMALL LETTER L# →1→ + {120824, "2"}, // MA# ( 𝟸 → 2 ) MATHEMATICAL MONOSPACE DIGIT TWO → DIGIT TWO# + {120825, "3"}, // MA# ( 𝟹 → 3 ) MATHEMATICAL MONOSPACE DIGIT THREE → DIGIT THREE# + {120826, "4"}, // MA# ( 𝟺 → 4 ) MATHEMATICAL MONOSPACE DIGIT FOUR → DIGIT FOUR# + {120827, "5"}, // MA# ( 𝟻 → 5 ) MATHEMATICAL MONOSPACE DIGIT FIVE → DIGIT FIVE# + {120828, "6"}, // MA# ( 𝟼 → 6 ) MATHEMATICAL MONOSPACE DIGIT SIX → DIGIT SIX# + {120829, "7"}, // MA# ( 𝟽 → 7 ) MATHEMATICAL MONOSPACE DIGIT SEVEN → DIGIT SEVEN# + {120830, "8"}, // MA# ( 𝟾 → 8 ) MATHEMATICAL MONOSPACE DIGIT EIGHT → DIGIT EIGHT# + {120831, "9"}, // MA# ( 𝟿 → 9 ) MATHEMATICAL MONOSPACE DIGIT NINE → DIGIT NINE# + {125127, "l"}, // MA#* ( ‎𞣇‎ → l ) MENDE KIKAKUI DIGIT ONE → LATIN SMALL LETTER L# + {125131, "8"}, // MA#* ( ‎𞣋‎ → 8 ) MENDE KIKAKUI DIGIT FIVE → DIGIT EIGHT# + {126464, "l"}, // MA# ( ‎𞸀‎ → l ) ARABIC MATHEMATICAL ALEF → LATIN SMALL LETTER L# →‎ا‎→→1→ + {126500, "o"}, // MA# ( ‎𞸤‎ → o ) ARABIC MATHEMATICAL INITIAL HEH → LATIN SMALL LETTER O# →‎ه‎→ + {126564, "o"}, // MA# ( ‎𞹤‎ → o ) ARABIC MATHEMATICAL STRETCHED HEH → LATIN SMALL LETTER O# →‎ه‎→ + {126592, "l"}, // MA# ( ‎𞺀‎ → l ) ARABIC MATHEMATICAL LOOPED ALEF → LATIN SMALL LETTER L# →‎ا‎→→1→ + {126596, "o"}, // MA# ( ‎𞺄‎ → o ) ARABIC MATHEMATICAL LOOPED HEH → LATIN SMALL LETTER O# →‎ه‎→ + {127232, "O."}, // MA#* ( 🄀 → O. ) DIGIT ZERO FULL STOP → LATIN CAPITAL LETTER O, FULL STOP# →0.→ + {127233, "O,"}, // MA#* ( 🄁 → O, ) DIGIT ZERO COMMA → LATIN CAPITAL LETTER O, COMMA# →0,→ + {127234, "l,"}, // MA#* ( 🄂 → l, ) DIGIT ONE COMMA → LATIN SMALL LETTER L, COMMA# →1,→ + {127235, "2,"}, // MA#* ( 🄃 → 2, ) DIGIT TWO COMMA → DIGIT TWO, COMMA# + {127236, "3,"}, // MA#* ( 🄄 → 3, ) DIGIT THREE COMMA → DIGIT THREE, COMMA# + {127237, "4,"}, // MA#* ( 🄅 → 4, ) DIGIT FOUR COMMA → DIGIT FOUR, COMMA# + {127238, "5,"}, // MA#* ( 🄆 → 5, ) DIGIT FIVE COMMA → DIGIT FIVE, COMMA# + {127239, "6,"}, // MA#* ( 🄇 → 6, ) DIGIT SIX COMMA → DIGIT SIX, COMMA# + {127240, "7,"}, // MA#* ( 🄈 → 7, ) DIGIT SEVEN COMMA → DIGIT SEVEN, COMMA# + {127241, "8,"}, // MA#* ( 🄉 → 8, ) DIGIT EIGHT COMMA → DIGIT EIGHT, COMMA# + {127242, "9,"}, // MA#* ( 🄊 → 9, ) DIGIT NINE COMMA → DIGIT NINE, COMMA# + {127248, "(A)"}, // MA#* ( 🄐 → (A) ) PARENTHESIZED LATIN CAPITAL LETTER A → LEFT PARENTHESIS, LATIN CAPITAL LETTER A, RIGHT PARENTHESIS# + {127249, "(B)"}, // MA#* ( 🄑 → (B) ) PARENTHESIZED LATIN CAPITAL LETTER B → LEFT PARENTHESIS, LATIN CAPITAL LETTER B, RIGHT PARENTHESIS# + {127250, "(C)"}, // MA#* ( 🄒 → (C) ) PARENTHESIZED LATIN CAPITAL LETTER C → LEFT PARENTHESIS, LATIN CAPITAL LETTER C, RIGHT PARENTHESIS# + {127251, "(D)"}, // MA#* ( 🄓 → (D) ) PARENTHESIZED LATIN CAPITAL LETTER D → LEFT PARENTHESIS, LATIN CAPITAL LETTER D, RIGHT PARENTHESIS# + {127252, "(E)"}, // MA#* ( 🄔 → (E) ) PARENTHESIZED LATIN CAPITAL LETTER E → LEFT PARENTHESIS, LATIN CAPITAL LETTER E, RIGHT PARENTHESIS# + {127253, "(F)"}, // MA#* ( 🄕 → (F) ) PARENTHESIZED LATIN CAPITAL LETTER F → LEFT PARENTHESIS, LATIN CAPITAL LETTER F, RIGHT PARENTHESIS# + {127254, "(G)"}, // MA#* ( 🄖 → (G) ) PARENTHESIZED LATIN CAPITAL LETTER G → LEFT PARENTHESIS, LATIN CAPITAL LETTER G, RIGHT PARENTHESIS# + {127255, "(H)"}, // MA#* ( 🄗 → (H) ) PARENTHESIZED LATIN CAPITAL LETTER H → LEFT PARENTHESIS, LATIN CAPITAL LETTER H, RIGHT PARENTHESIS# + {127256, "(l)"}, // MA#* ( 🄘 → (l) ) PARENTHESIZED LATIN CAPITAL LETTER I → LEFT PARENTHESIS, LATIN SMALL LETTER L, RIGHT PARENTHESIS# →(I)→ + {127257, "(J)"}, // MA#* ( 🄙 → (J) ) PARENTHESIZED LATIN CAPITAL LETTER J → LEFT PARENTHESIS, LATIN CAPITAL LETTER J, RIGHT PARENTHESIS# + {127258, "(K)"}, // MA#* ( 🄚 → (K) ) PARENTHESIZED LATIN CAPITAL LETTER K → LEFT PARENTHESIS, LATIN CAPITAL LETTER K, RIGHT PARENTHESIS# + {127259, "(L)"}, // MA#* ( 🄛 → (L) ) PARENTHESIZED LATIN CAPITAL LETTER L → LEFT PARENTHESIS, LATIN CAPITAL LETTER L, RIGHT PARENTHESIS# + {127260, "(M)"}, // MA#* ( 🄜 → (M) ) PARENTHESIZED LATIN CAPITAL LETTER M → LEFT PARENTHESIS, LATIN CAPITAL LETTER M, RIGHT PARENTHESIS# + {127261, "(N)"}, // MA#* ( 🄝 → (N) ) PARENTHESIZED LATIN CAPITAL LETTER N → LEFT PARENTHESIS, LATIN CAPITAL LETTER N, RIGHT PARENTHESIS# + {127262, "(O)"}, // MA#* ( 🄞 → (O) ) PARENTHESIZED LATIN CAPITAL LETTER O → LEFT PARENTHESIS, LATIN CAPITAL LETTER O, RIGHT PARENTHESIS# + {127263, "(P)"}, // MA#* ( 🄟 → (P) ) PARENTHESIZED LATIN CAPITAL LETTER P → LEFT PARENTHESIS, LATIN CAPITAL LETTER P, RIGHT PARENTHESIS# + {127264, "(Q)"}, // MA#* ( 🄠 → (Q) ) PARENTHESIZED LATIN CAPITAL LETTER Q → LEFT PARENTHESIS, LATIN CAPITAL LETTER Q, RIGHT PARENTHESIS# + {127265, "(R)"}, // MA#* ( 🄡 → (R) ) PARENTHESIZED LATIN CAPITAL LETTER R → LEFT PARENTHESIS, LATIN CAPITAL LETTER R, RIGHT PARENTHESIS# + {127266, "(S)"}, // MA#* ( 🄢 → (S) ) PARENTHESIZED LATIN CAPITAL LETTER S → LEFT PARENTHESIS, LATIN CAPITAL LETTER S, RIGHT PARENTHESIS# + {127267, "(T)"}, // MA#* ( 🄣 → (T) ) PARENTHESIZED LATIN CAPITAL LETTER T → LEFT PARENTHESIS, LATIN CAPITAL LETTER T, RIGHT PARENTHESIS# + {127268, "(U)"}, // MA#* ( 🄤 → (U) ) PARENTHESIZED LATIN CAPITAL LETTER U → LEFT PARENTHESIS, LATIN CAPITAL LETTER U, RIGHT PARENTHESIS# + {127269, "(V)"}, // MA#* ( 🄥 → (V) ) PARENTHESIZED LATIN CAPITAL LETTER V → LEFT PARENTHESIS, LATIN CAPITAL LETTER V, RIGHT PARENTHESIS# + {127270, "(W)"}, // MA#* ( 🄦 → (W) ) PARENTHESIZED LATIN CAPITAL LETTER W → LEFT PARENTHESIS, LATIN CAPITAL LETTER W, RIGHT PARENTHESIS# + {127271, "(X)"}, // MA#* ( 🄧 → (X) ) PARENTHESIZED LATIN CAPITAL LETTER X → LEFT PARENTHESIS, LATIN CAPITAL LETTER X, RIGHT PARENTHESIS# + {127272, "(Y)"}, // MA#* ( 🄨 → (Y) ) PARENTHESIZED LATIN CAPITAL LETTER Y → LEFT PARENTHESIS, LATIN CAPITAL LETTER Y, RIGHT PARENTHESIS# + {127273, "(Z)"}, // MA#* ( 🄩 → (Z) ) PARENTHESIZED LATIN CAPITAL LETTER Z → LEFT PARENTHESIS, LATIN CAPITAL LETTER Z, RIGHT PARENTHESIS# + {127274, "(S)"}, // MA#* ( 🄪 → (S) ) TORTOISE SHELL BRACKETED LATIN CAPITAL LETTER S → LEFT PARENTHESIS, LATIN CAPITAL LETTER S, RIGHT PARENTHESIS# →〔S〕→ + {128768, "QE"}, // MA#* ( 🜀 → QE ) ALCHEMICAL SYMBOL FOR QUINTESSENCE → LATIN CAPITAL LETTER Q, LATIN CAPITAL LETTER E# + {128775, "AR"}, // MA#* ( 🜇 → AR ) ALCHEMICAL SYMBOL FOR AQUA REGIA-2 → LATIN CAPITAL LETTER A, LATIN CAPITAL LETTER R# + {128844, "C"}, // MA#* ( 🝌 → C ) ALCHEMICAL SYMBOL FOR CALX → LATIN CAPITAL LETTER C# + {128860, "sss"}, // MA#* ( 🝜 → sss ) ALCHEMICAL SYMBOL FOR STRATUM SUPER STRATUM → LATIN SMALL LETTER S, LATIN SMALL LETTER S, LATIN SMALL LETTER S# + {128872, "T"}, // MA#* ( 🝨 → T ) ALCHEMICAL SYMBOL FOR CRUCIBLE-4 → LATIN CAPITAL LETTER T# + {128875, "MB"}, // MA#* ( 🝫 → MB ) ALCHEMICAL SYMBOL FOR BATH OF MARY → LATIN CAPITAL LETTER M, LATIN CAPITAL LETTER B# + {128876, "VB"}, // MA#* ( 🝬 → VB ) ALCHEMICAL SYMBOL FOR BATH OF VAPOURS → LATIN CAPITAL LETTER V, LATIN CAPITAL LETTER B# +}; +// clang-format on + +const char* findConfusable(uint32_t codepoint) +{ + auto it = std::lower_bound(std::begin(kConfusables), std::end(kConfusables), codepoint, [](const Confusable& lhs, uint32_t rhs) { + return lhs.codepoint < rhs; + }); + + return (it != std::end(kConfusables) && it->codepoint == codepoint) ? it->text : nullptr; +} + +} // namespace Luau diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp new file mode 100644 index 0000000..a7aa24c --- /dev/null +++ b/Ast/src/Lexer.cpp @@ -0,0 +1,1149 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Lexer.h" + +#include "Luau/Confusables.h" +#include "Luau/StringUtils.h" + +#include + +namespace Luau +{ + +Allocator::Allocator() + : root(static_cast(operator new(sizeof(Page)))) + , offset(0) +{ + root->next = nullptr; +} + +Allocator::Allocator(Allocator&& rhs) + : root(rhs.root) + , offset(rhs.offset) +{ + rhs.root = nullptr; + rhs.offset = 0; +} + +Allocator::~Allocator() +{ + Page* page = root; + + while (page) + { + Page* next = page->next; + + operator delete(page); + + page = next; + } +} + +void* Allocator::allocate(size_t size) +{ + constexpr size_t align = alignof(void*) > alignof(double) ? alignof(void*) : alignof(double); + + if (root) + { + uintptr_t data = reinterpret_cast(root->data); + uintptr_t result = (data + offset + align - 1) & ~(align - 1); + if (result + size <= data + sizeof(root->data)) + { + offset = result - data + size; + return reinterpret_cast(result); + } + } + + // allocate new page + size_t pageSize = size > sizeof(root->data) ? size : sizeof(root->data); + void* pageData = operator new(offsetof(Page, data) + pageSize); + + Page* page = static_cast(pageData); + + page->next = root; + + root = page; + offset = size; + + return page->data; +} + +Lexeme::Lexeme(const Location& location, Type type) + : type(type) + , location(location) + , length(0) + , data(nullptr) +{ +} + +Lexeme::Lexeme(const Location& location, char character) + : type(static_cast(static_cast(character))) + , location(location) + , length(0) + , data(nullptr) +{ +} + +Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t size) + : type(type) + , location(location) + , length(unsigned(size)) + , data(data) +{ + LUAU_ASSERT(type == RawString || type == QuotedString || type == Number || type == Comment || type == BlockComment); +} + +Lexeme::Lexeme(const Location& location, Type type, const char* name) + : type(type) + , location(location) + , length(0) + , name(name) +{ + LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); +} + +static bool isComment(const Lexeme& lexeme) +{ + return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; +} + +static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", + "repeat", "return", "then", "true", "until", "while"}; + +std::string Lexeme::toString() const +{ + switch (type) + { + case Eof: + return ""; + + case Equal: + return "'=='"; + + case LessEqual: + return "'<='"; + + case GreaterEqual: + return "'>='"; + + case NotEqual: + return "'~='"; + + case Dot2: + return "'..'"; + + case Dot3: + return "'...'"; + + case SkinnyArrow: + return "'->'"; + + case DoubleColon: + return "'::'"; + + case AddAssign: + return "'+='"; + + case SubAssign: + return "'-='"; + + case MulAssign: + return "'*='"; + + case DivAssign: + return "'/='"; + + case ModAssign: + return "'%='"; + + case PowAssign: + return "'^='"; + + case ConcatAssign: + return "'..='"; + + case RawString: + case QuotedString: + return data ? format("\"%.*s\"", length, data) : "string"; + + case Number: + return data ? format("'%.*s'", length, data) : "number"; + + case Name: + return name ? format("'%s'", name) : "identifier"; + + case Comment: + return "comment"; + + case BrokenString: + return "malformed string"; + + case BrokenComment: + return "unfinished comment"; + + case BrokenUnicode: + if (codepoint) + { + if (const char* confusable = findConfusable(codepoint)) + return format("Unicode character U+%x (did you mean '%s'?)", codepoint, confusable); + + return format("Unicode character U+%x", codepoint); + } + else + { + return "invalid UTF-8 sequence"; + } + + default: + if (type < Char_END) + return format("'%c'", type); + else if (type >= Reserved_BEGIN && type < Reserved_END) + return format("'%s'", kReserved[type - Reserved_BEGIN]); + else + return ""; + } +} + +bool AstNameTable::Entry::operator==(const Entry& other) const +{ + return length == other.length && memcmp(value.value, other.value.value, length) == 0; +} + +size_t AstNameTable::EntryHash::operator()(const Entry& e) const +{ + // FNV1a + uint32_t hash = 2166136261; + + for (size_t i = 0; i < e.length; ++i) + { + hash ^= uint8_t(e.value.value[i]); + hash *= 16777619; + } + + return hash; +} + +AstNameTable::AstNameTable(Allocator& allocator) + : data({AstName(""), 0, Lexeme::Eof}, 128) + , allocator(allocator) +{ + static_assert(sizeof(kReserved) / sizeof(kReserved[0]) == Lexeme::Reserved_END - Lexeme::Reserved_BEGIN); + + for (int i = Lexeme::Reserved_BEGIN; i < Lexeme::Reserved_END; ++i) + addStatic(kReserved[i - Lexeme::Reserved_BEGIN], static_cast(i)); +} + +AstName AstNameTable::addStatic(const char* name, Lexeme::Type type) +{ + AstNameTable::Entry entry = {AstName(name), uint32_t(strlen(name)), type}; + + LUAU_ASSERT(!data.contains(entry)); + data.insert(entry); + + return entry.value; +} + +std::pair AstNameTable::getOrAddWithType(const char* name, size_t length) +{ + AstNameTable::Entry key = {AstName(name), uint32_t(length), Lexeme::Eof}; + const Entry& entry = data.insert(key); + + // entry already was inserted + if (entry.type != Lexeme::Eof) + return std::make_pair(entry.value, entry.type); + + // we just inserted an entry with a non-owned pointer into the map + // we need to correct it, *but* we need to be careful about not disturbing the hash value + char* nameData = static_cast(allocator.allocate(length + 1)); + memcpy(nameData, name, length); + nameData[length] = 0; + + const_cast(entry).value = AstName(nameData); + const_cast(entry).type = Lexeme::Name; + + return std::make_pair(entry.value, entry.type); +} + +std::pair AstNameTable::getWithType(const char* name, size_t length) const +{ + if (const Entry* entry = data.find({AstName(name), uint32_t(length), Lexeme::Eof})) + { + return std::make_pair(entry->value, entry->type); + } + return std::make_pair(AstName(), Lexeme::Name); +} + +AstName AstNameTable::getOrAdd(const char* name) +{ + return getOrAddWithType(name, strlen(name)).first; +} + +AstName AstNameTable::get(const char* name) const +{ + return getWithType(name, strlen(name)).first; +} + +inline bool isSpace(char ch) +{ + return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; +} + +inline bool isAlpha(char ch) +{ + // use or trick to convert to lower case and unsigned comparison to do range check + return unsigned((ch | ' ') - 'a') < 26; +} + +inline bool isDigit(char ch) +{ + return unsigned(ch - '0') < 10; +} + +inline bool isHexDigit(char ch) +{ + // use or trick to convert to lower case and unsigned comparison to do range check + return unsigned(ch - '0') < 10 || unsigned((ch | ' ') - 'a') < 6; +} + +inline bool isNewline(char ch) +{ + return ch == '\n'; +} + +static char unescape(char ch) +{ + switch (ch) + { + case 'a': + return '\a'; + case 'b': + return '\b'; + case 'f': + return '\f'; + case 'n': + return '\n'; + case 'r': + return '\r'; + case 't': + return '\t'; + case 'v': + return '\v'; + default: + return ch; + } +} + +Lexer::Lexer(const char* buffer, size_t bufferSize, AstNameTable& names) + : buffer(buffer) + , bufferSize(bufferSize) + , offset(0) + , line(0) + , lineOffset(0) + , lexeme(Location(Position(0, 0), 0), Lexeme::Eof) + , names(names) + , skipComments(false) + , readNames(true) +{ +} + +void Lexer::setSkipComments(bool skip) +{ + skipComments = skip; +} + +void Lexer::setReadNames(bool read) +{ + readNames = read; +} + +const Lexeme& Lexer::next() +{ + return next(this->skipComments); +} + +const Lexeme& Lexer::next(bool skipComments) +{ + // in skipComments mode we reject valid comments + do + { + // consume whitespace before the token + while (isSpace(peekch())) + consume(); + + prevLocation = lexeme.location; + + lexeme = readNext(); + } while (skipComments && isComment(lexeme)); + + return lexeme; +} + +void Lexer::nextline() +{ + while (peekch() != 0 && peekch() != '\r' && !isNewline(peekch())) + consume(); + + next(); +} + +Lexeme Lexer::lookahead() +{ + unsigned int currentOffset = offset; + unsigned int currentLine = line; + unsigned int currentLineOffset = lineOffset; + Lexeme currentLexeme = lexeme; + Location currentPrevLocation = prevLocation; + + Lexeme result = next(); + + offset = currentOffset; + line = currentLine; + lineOffset = currentLineOffset; + lexeme = currentLexeme; + prevLocation = currentPrevLocation; + + return result; +} + +bool Lexer::isReserved(const std::string& word) +{ + for (int i = Lexeme::Reserved_BEGIN; i < Lexeme::Reserved_END; ++i) + if (word == kReserved[i - Lexeme::Reserved_BEGIN]) + return true; + + return false; +} + +LUAU_FORCEINLINE +char Lexer::peekch() const +{ + return (offset < bufferSize) ? buffer[offset] : 0; +} + +LUAU_FORCEINLINE +char Lexer::peekch(unsigned int lookahead) const +{ + return (offset + lookahead < bufferSize) ? buffer[offset + lookahead] : 0; +} + +Position Lexer::position() const +{ + return Position(line, offset - lineOffset); +} + +void Lexer::consume() +{ + if (isNewline(buffer[offset])) + { + line++; + lineOffset = offset + 1; + } + + offset++; +} + +Lexeme Lexer::readCommentBody() +{ + Position start = position(); + + LUAU_ASSERT(peekch(0) == '-' && peekch(1) == '-'); + consume(); + consume(); + + size_t startOffset = offset; + + if (peekch() == '[') + { + int sep = skipLongSeparator(); + + if (sep >= 0) + { + return readLongString(start, sep, Lexeme::BlockComment, Lexeme::BrokenComment); + } + } + + // fall back to single-line comment + while (peekch() != 0 && peekch() != '\r' && !isNewline(peekch())) + consume(); + + return Lexeme(Location(start, position()), Lexeme::Comment, &buffer[startOffset], offset - startOffset); +} + +// Given a sequence [===[ or ]===], returns: +// 1. number of equal signs (or 0 if none present) between the brackets +// 2. -1 if this is not a long comment/string separator +// 3. -N if this is a malformed separator +// Does *not* consume the closing brace. +int Lexer::skipLongSeparator() +{ + char start = peekch(); + + LUAU_ASSERT(start == '[' || start == ']'); + consume(); + + int count = 0; + + while (peekch() == '=') + { + consume(); + count++; + } + + return (start == peekch()) ? count : (-count) - 1; +} + +Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Lexeme::Type broken) +{ + // skip (second) [ + LUAU_ASSERT(peekch() == '['); + consume(); + + unsigned int startOffset = offset; + + while (peekch()) + { + if (peekch() == ']') + { + if (skipLongSeparator() == sep) + { + LUAU_ASSERT(peekch() == ']'); + consume(); // skip (second) ] + + unsigned int endOffset = offset - sep - 2; + LUAU_ASSERT(endOffset >= startOffset); + + return Lexeme(Location(start, position()), ok, &buffer[startOffset], endOffset - startOffset); + } + } + else + { + consume(); + } + } + + return Lexeme(Location(start, position()), broken); +} + +Lexeme Lexer::readQuotedString() +{ + Position start = position(); + + char delimiter = peekch(); + LUAU_ASSERT(delimiter == '\'' || delimiter == '"'); + consume(); + + unsigned int startOffset = offset; + + while (peekch() != delimiter) + { + switch (peekch()) + { + case 0: + case '\r': + case '\n': + return Lexeme(Location(start, position()), Lexeme::BrokenString); + + case '\\': + consume(); + switch (peekch()) + { + case '\r': + consume(); + if (peekch() == '\n') + consume(); + break; + + case 0: + break; + + case 'z': + consume(); + while (isSpace(peekch())) + consume(); + break; + + default: + consume(); + } + break; + + default: + consume(); + } + } + + consume(); + + return Lexeme(Location(start, position()), Lexeme::QuotedString, &buffer[startOffset], offset - startOffset - 1); +} + +Lexeme Lexer::readNumber(const Position& start, unsigned int startOffset) +{ + LUAU_ASSERT(isDigit(peekch())); + + // This function does not do the number parsing - it only skips a number-like pattern. + // It uses the same logic as Lua stock lexer; the resulting string is later converted + // to a number with proper verification. + do + { + consume(); + } while (isDigit(peekch()) || peekch() == '.' || peekch() == '_'); + + if (peekch() == 'e' || peekch() == 'E') + { + consume(); + + if (peekch() == '+' || peekch() == '-') + consume(); + } + + while (isAlpha(peekch()) || isDigit(peekch()) || peekch() == '_') + consume(); + + return Lexeme(Location(start, position()), Lexeme::Number, &buffer[startOffset], offset - startOffset); +} + +std::pair Lexer::readName() +{ + LUAU_ASSERT(isAlpha(peekch()) || peekch() == '_'); + + unsigned int startOffset = offset; + + do + consume(); + while (isAlpha(peekch()) || isDigit(peekch()) || peekch() == '_'); + + return readNames ? names.getOrAddWithType(&buffer[startOffset], offset - startOffset) + : names.getWithType(&buffer[startOffset], offset - startOffset); +} + +Lexeme Lexer::readNext() +{ + Position start = position(); + + switch (peekch()) + { + case 0: + return Lexeme(Location(start, 0), Lexeme::Eof); + + case '-': + { + if (peekch(1) == '>') + { + consume(); + consume(); + return Lexeme(Location(start, 2), Lexeme::SkinnyArrow); + } + else if (peekch(1) == '=') + { + consume(); + consume(); + return Lexeme(Location(start, 2), Lexeme::SubAssign); + } + else if (peekch(1) == '-') + { + return readCommentBody(); + } + else + { + consume(); + return Lexeme(Location(start, 1), '-'); + } + } + + case '[': + { + int sep = skipLongSeparator(); + + if (sep >= 0) + { + return readLongString(start, sep, Lexeme::RawString, Lexeme::BrokenString); + } + else if (sep == -1) + { + return Lexeme(Location(start, 1), '['); + } + else + { + return Lexeme(Location(start, position()), Lexeme::BrokenString); + } + } + + case '=': + { + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::Equal); + } + else + return Lexeme(Location(start, 1), '='); + } + + case '<': + { + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::LessEqual); + } + else + return Lexeme(Location(start, 1), '<'); + } + + case '>': + { + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::GreaterEqual); + } + else + return Lexeme(Location(start, 1), '>'); + } + + case '~': + { + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::NotEqual); + } + else + return Lexeme(Location(start, 1), '~'); + } + + case '"': + case '\'': + return readQuotedString(); + + case '.': + consume(); + + if (peekch() == '.') + { + consume(); + + if (peekch() == '.') + { + consume(); + + return Lexeme(Location(start, 3), Lexeme::Dot3); + } + else if (peekch() == '=') + { + consume(); + + return Lexeme(Location(start, 3), Lexeme::ConcatAssign); + } + else + return Lexeme(Location(start, 2), Lexeme::Dot2); + } + else + { + if (isDigit(peekch())) + { + return readNumber(start, offset - 1); + } + else + return Lexeme(Location(start, 1), '.'); + } + + case '+': + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::AddAssign); + } + else + return Lexeme(Location(start, 1), '+'); + + case '/': + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::DivAssign); + } + else + return Lexeme(Location(start, 1), '/'); + + case '*': + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::MulAssign); + } + else + return Lexeme(Location(start, 1), '*'); + + case '%': + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::ModAssign); + } + else + return Lexeme(Location(start, 1), '%'); + + case '^': + consume(); + + if (peekch() == '=') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::PowAssign); + } + else + return Lexeme(Location(start, 1), '^'); + + case ':': + { + consume(); + if (peekch() == ':') + { + consume(); + return Lexeme(Location(start, 2), Lexeme::DoubleColon); + } + else + return Lexeme(Location(start, 1), ':'); + } + + case '(': + case ')': + case '{': + case '}': + case ']': + case ';': + case ',': + case '#': + { + char ch = peekch(); + consume(); + + return Lexeme(Location(start, 1), ch); + } + + default: + if (isDigit(peekch())) + { + return readNumber(start, offset); + } + else if (isAlpha(peekch()) || peekch() == '_') + { + std::pair name = readName(); + + return Lexeme(Location(start, position()), name.second, name.first.value); + } + else if (peekch() & 0x80) + { + return readUtf8Error(); + } + else + { + char ch = peekch(); + consume(); + + return Lexeme(Location(start, 1), ch); + } + } +} + +LUAU_NOINLINE Lexeme Lexer::readUtf8Error() +{ + Position start = position(); + uint32_t codepoint = 0; + int size = 0; + + if ((peekch() & 0b10000000) == 0b00000000) + { + size = 1; + codepoint = peekch() & 0x7F; + } + else if ((peekch() & 0b11100000) == 0b11000000) + { + size = 2; + codepoint = peekch() & 0b11111; + } + else if ((peekch() & 0b11110000) == 0b11100000) + { + size = 3; + codepoint = peekch() & 0b1111; + } + else if ((peekch() & 0b11111000) == 0b11110000) + { + size = 4; + codepoint = peekch() & 0b111; + } + else + { + consume(); + return Lexeme(Location(start, position()), Lexeme::BrokenUnicode); + } + + consume(); + + for (int i = 1; i < size; ++i) + { + if ((peekch() & 0b11000000) != 0b10000000) + return Lexeme(Location(start, position()), Lexeme::BrokenUnicode); + + codepoint = codepoint << 6; + codepoint |= (peekch() & 0b00111111); + consume(); + } + + Lexeme result(Location(start, position()), Lexeme::BrokenUnicode); + result.codepoint = codepoint; + return result; +} + +static size_t toUtf8(char* data, unsigned int code) +{ + // U+0000..U+007F + if (code < 0x80) + { + data[0] = char(code); + return 1; + } + // U+0080..U+07FF + else if (code < 0x800) + { + data[0] = char(0xC0 | (code >> 6)); + data[1] = char(0x80 | (code & 0x3F)); + return 2; + } + // U+0800..U+FFFF + else if (code < 0x10000) + { + data[0] = char(0xE0 | (code >> 12)); + data[1] = char(0x80 | ((code >> 6) & 0x3F)); + data[2] = char(0x80 | (code & 0x3F)); + return 3; + } + // U+10000..U+10FFFF + else if (code < 0x110000) + { + data[0] = char(0xF0 | (code >> 18)); + data[1] = char(0x80 | ((code >> 12) & 0x3F)); + data[2] = char(0x80 | ((code >> 6) & 0x3F)); + data[3] = char(0x80 | (code & 0x3F)); + return 4; + } + else + { + return 0; + } +} + +bool Lexer::fixupQuotedString(std::string& data) +{ + if (data.empty() || data.find('\\') == std::string::npos) + return true; + + size_t size = data.size(); + size_t write = 0; + + for (size_t i = 0; i < size;) + { + if (data[i] != '\\') + { + data[write++] = data[i]; + i++; + continue; + } + + if (i + 1 == size) + return false; + + char escape = data[i + 1]; + i += 2; // skip \e + + switch (escape) + { + case '\n': + data[write++] = '\n'; + break; + + case '\r': + data[write++] = '\n'; + if (i < size && data[i] == '\n') + i++; + break; + + case 0: + return false; + + case 'x': + { + // hex escape codes are exactly 2 hex digits long + if (i + 2 > size) + return false; + + unsigned int code = 0; + + for (int j = 0; j < 2; ++j) + { + char ch = data[i + j]; + if (!isHexDigit(ch)) + return false; + + // use or trick to convert to lower case + code = 16 * code + (isDigit(ch) ? ch - '0' : (ch | ' ') - 'a' + 10); + } + + data[write++] = char(code); + i += 2; + break; + } + + case 'z': + { + while (i < size && isSpace(data[i])) + i++; + break; + } + + case 'u': + { + // unicode escape codes are at least 3 characters including braces + if (i + 3 > size) + return false; + + if (data[i] != '{') + return false; + i++; + + if (data[i] == '}') + return false; + + unsigned int code = 0; + + for (int j = 0; j < 16; ++j) + { + if (i == size) + return false; + + char ch = data[i]; + + if (ch == '}') + break; + + if (!isHexDigit(ch)) + return false; + + // use or trick to convert to lower case + code = 16 * code + (isDigit(ch) ? ch - '0' : (ch | ' ') - 'a' + 10); + i++; + } + + if (i == size || data[i] != '}') + return false; + i++; + + size_t utf8 = toUtf8(&data[write], code); + if (utf8 == 0) + return false; + + write += utf8; + break; + } + + default: + { + if (isDigit(escape)) + { + unsigned int code = escape - '0'; + + for (int j = 0; j < 2; ++j) + { + if (i == size || !isDigit(data[i])) + break; + + code = 10 * code + (data[i] - '0'); + i++; + } + + if (code > UCHAR_MAX) + return false; + + data[write++] = char(code); + } + else + { + data[write++] = unescape(escape); + } + } + } + } + + LUAU_ASSERT(write <= size); + data.resize(write); + + return true; +} + +void Lexer::fixupMultilineString(std::string& data) +{ + if (data.empty()) + return; + + // Lua rules for multiline strings are as follows: + // - standalone \r, \r\n, \n\r and \n are all considered newlines + // - first newline in the multiline string is skipped + // - all other newlines are normalized to \n + + // Since our lexer just treats \n as newlines, we apply a simplified set of rules that is sufficient to get normalized newlines for Windows/Unix: + // - \r\n and \n are considered newlines + // - first newline is skipped + // - newlines are normalized to \n + + // This makes the string parsing behavior consistent with general lexing behavior - a standalone \r isn't considered a new line from the line + // tracking perspective + + const char* src = data.c_str(); + char* dst = &data[0]; + + // skip leading newline + if (src[0] == '\r' && src[1] == '\n') + { + src += 2; + } + else if (src[0] == '\n') + { + src += 1; + } + + // parse the rest of the string, converting newlines as we go + while (*src) + { + if (src[0] == '\r' && src[1] == '\n') + { + *dst++ = '\n'; + src += 2; + } + else // note, this handles \n by just writing it without changes + { + *dst++ = *src; + src += 1; + } + } + + data.resize(dst - &data[0]); +} + +} // namespace Luau diff --git a/Ast/src/Location.cpp b/Ast/src/Location.cpp new file mode 100644 index 0000000..d7a899e --- /dev/null +++ b/Ast/src/Location.cpp @@ -0,0 +1,17 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Location.h" + +namespace Luau +{ + +std::string toString(const Position& position) +{ + return "{ line = " + std::to_string(position.line) + ", col = " + std::to_string(position.column) + " }"; +} + +std::string toString(const Location& location) +{ + return "Location { " + toString(location.begin) + ", " + toString(location.end) + " }"; +} + +} // namespace Luau diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp new file mode 100644 index 0000000..6672efe --- /dev/null +++ b/Ast/src/Parser.cpp @@ -0,0 +1,2693 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" + +#include + +// Warning: If you are introducing new syntax, ensure that it is behind a separate +// flag so that we don't break production games by reverting syntax changes. +// See docs/SyntaxChanges.md for an explanation. +LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) +LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) +LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsParserFix, false) +LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) +LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) + +namespace Luau +{ + +inline bool isSpace(char ch) +{ + return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; +} + +static bool isComment(const Lexeme& lexeme) +{ + return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; +} + +ParseError::ParseError(const Location& location, const std::string& message) + : location(location) + , message(message) +{ +} + +const char* ParseError::what() const throw() +{ + return message.c_str(); +} + +const Location& ParseError::getLocation() const +{ + return location; +} + +const std::string& ParseError::getMessage() const +{ + return message; +} + +// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string object / exception plumbing +LUAU_NOINLINE void ParseError::raise(const Location& location, const char* format, ...) +{ + va_list args; + va_start(args, format); + std::string message = vformat(format, args); + va_end(args); + + throw ParseError(location, message); +} + +ParseErrors::ParseErrors(std::vector errors) + : errors(std::move(errors)) +{ + LUAU_ASSERT(!this->errors.empty()); + + if (this->errors.size() == 1) + message = this->errors.front().what(); + else + message = format("%d parse errors", int(this->errors.size())); +} + +const char* ParseErrors::what() const throw() +{ + return message.c_str(); +} + +const std::vector& ParseErrors::getErrors() const +{ + return errors; +} + +template +TempVector::TempVector(std::vector& storage) + : storage(storage) + , offset(storage.size()) + , size_(0) +{ +} + +template +TempVector::~TempVector() +{ + LUAU_ASSERT(storage.size() == offset + size_); + storage.erase(storage.begin() + offset, storage.end()); +} + +template +const T& TempVector::operator[](size_t index) const +{ + LUAU_ASSERT(index < size_); + return storage[offset + index]; +} + +template +const T& TempVector::front() const +{ + LUAU_ASSERT(size_ > 0); + return storage[offset]; +} + +template +const T& TempVector::back() const +{ + LUAU_ASSERT(size_ > 0); + return storage.back(); +} + +template +bool TempVector::empty() const +{ + return size_ == 0; +} + +template +size_t TempVector::size() const +{ + return size_; +} + +template +void TempVector::push_back(const T& item) +{ + LUAU_ASSERT(storage.size() == offset + size_); + storage.push_back(item); + size_++; +} + +static bool shouldParseTypePackAnnotation(Lexer& lexer) +{ + if (lexer.current().type == Lexeme::Dot3) + return true; + else if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == Lexeme::Dot3) + return true; + + return false; +} + +ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options) +{ + Parser p(buffer, bufferSize, names, allocator); + + try + { + std::vector hotcomments; + + while (isComment(p.lexer.current()) || (FFlag::LuauCaptureBrokenCommentSpans && p.lexer.current().type == Lexeme::BrokenComment)) + { + const char* text = p.lexer.current().data; + unsigned int length = p.lexer.current().length; + + if (length && text[0] == '!') + { + unsigned int end = length; + while (end > 0 && isSpace(text[end - 1])) + --end; + + hotcomments.push_back(std::string(text + 1, text + end)); + } + + const Lexeme::Type type = p.lexer.current().type; + const Location loc = p.lexer.current().location; + + p.lexer.next(); + + if (options.captureComments) + p.commentLocations.push_back(Comment{type, loc}); + } + + p.lexer.setSkipComments(true); + + p.options = options; + + AstStatBlock* root = p.parseChunk(); + + return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; + } + catch (ParseError& err) + { + // when catching a fatal error, append it to the list of non-fatal errors and return + p.parseErrors.push_back(err); + + return ParseResult{nullptr, {}, p.parseErrors}; + } +} + +Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator) + : lexer(buffer, bufferSize, names) + , allocator(allocator) + , recursionCounter(0) + , endMismatchSuspect(Location(), Lexeme::Eof) + , localMap(AstName()) +{ + Function top; + top.vararg = true; + + functionStack.push_back(top); + + nameSelf = names.addStatic("self"); + nameNumber = names.addStatic("number"); + nameError = names.addStatic(errorName); + nameNil = names.getOrAdd("nil"); // nil is a reserved keyword + + matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); + matchRecoveryStopOnToken[Lexeme::Type::Eof] = 1; + + // read first lexeme + nextLexeme(); +} + +bool Parser::blockFollow(const Lexeme& l) +{ + return l.type == Lexeme::Eof || l.type == Lexeme::ReservedElse || l.type == Lexeme::ReservedElseif || l.type == Lexeme::ReservedEnd || + l.type == Lexeme::ReservedUntil; +} + +AstStatBlock* Parser::parseChunk() +{ + AstStatBlock* result = parseBlock(); + + if (lexer.current().type != Lexeme::Eof) + expectAndConsumeFail(Lexeme::Eof, nullptr); + + return result; +} + +// chunk ::= {stat [`;']} [laststat [`;']] +// block ::= chunk +AstStatBlock* Parser::parseBlock() +{ + unsigned int localsBegin = saveLocals(); + + AstStatBlock* result = parseBlockNoScope(); + + restoreLocals(localsBegin); + + return result; +} + +static bool isStatLast(AstStat* stat) +{ + return stat->is() || stat->is() || stat->is(); +} + +AstStatBlock* Parser::parseBlockNoScope() +{ + TempVector body(scratchStat); + + const Position prevPosition = lexer.previousLocation().end; + + while (!blockFollow(lexer.current())) + { + unsigned int recursionCounterOld = recursionCounter; + + incrementRecursionCounter("block"); + + AstStat* stat = parseStat(); + + recursionCounter = recursionCounterOld; + + if (lexer.current().type == ';') + { + nextLexeme(); + stat->hasSemicolon = true; + } + + body.push_back(stat); + + if (isStatLast(stat)) + break; + } + + const Location location = Location(prevPosition, lexer.current().location.begin); + + return allocator.alloc(location, copy(body)); +} + +// stat ::= +// varlist `=' explist | +// functioncall | +// do block end | +// while exp do block end | +// repeat block until exp | +// if exp then block {elseif exp then block} [else block] end | +// for binding `=' exp `,' exp [`,' exp] do block end | +// for namelist in explist do block end | +// function funcname funcbody | +// local function Name funcbody | +// local namelist [`=' explist] +// laststat ::= return [explist] | break +AstStat* Parser::parseStat() +{ + // guess the type from the token type + switch (lexer.current().type) + { + case Lexeme::ReservedIf: + return parseIf(); + case Lexeme::ReservedWhile: + return parseWhile(); + case Lexeme::ReservedDo: + return parseDo(); + case Lexeme::ReservedFor: + return parseFor(); + case Lexeme::ReservedRepeat: + return parseRepeat(); + case Lexeme::ReservedFunction: + return parseFunctionStat(); + case Lexeme::ReservedLocal: + return parseLocal(); + case Lexeme::ReservedReturn: + return parseReturn(); + case Lexeme::ReservedBreak: + return parseBreak(); + default:; + } + + Location start = lexer.current().location; + + // we need to disambiguate a few cases, primarily assignment (lvalue = ...) vs statements-that-are calls + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); + + if (expr->is()) + return allocator.alloc(expr->location, expr); + + // if the next token is , or =, it's an assignment (, means it's an assignment with multiple variables) + if (lexer.current().type == ',' || lexer.current().type == '=') + return parseAssignment(expr); + + // if the next token is a compound assignment operator, it's a compound assignment (these don't support multiple variables) + if (std::optional op = parseCompoundOp(lexer.current())) + return parseCompoundAssignment(expr, *op); + + // we know this isn't a call or an assignment; therefore it must be a context-sensitive keyword such as `type` or `continue` + AstName ident = getIdentifier(expr); + + if (options.allowTypeAnnotations) + { + if (ident == "type") + return parseTypeAlias(expr->location, /* exported =*/false); + if (ident == "export" && lexer.current().type == Lexeme::Name && AstName(lexer.current().name) == "type") + { + nextLexeme(); + return parseTypeAlias(expr->location, /* exported =*/true); + } + } + + if (options.supportContinueStatement && ident == "continue") + return parseContinue(expr->location); + + if (options.allowTypeAnnotations && options.allowDeclarationSyntax) + { + if (ident == "declare") + return parseDeclaration(expr->location); + } + + // skip unexpected symbol if lexer couldn't advance at all (statements are parsed in a loop) + if (start == lexer.current().location) + nextLexeme(); + + return reportStatError(expr->location, copy({expr}), {}, "Incomplete statement: expected assignment or a function call"); +} + +// if exp then block {elseif exp then block} [else block] end +AstStat* Parser::parseIf() +{ + Location start = lexer.current().location; + + nextLexeme(); // if / elseif + + AstExpr* cond = parseExpr(); + + Lexeme matchThen = lexer.current(); + bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if statement"); + + AstStatBlock* thenbody = parseBlock(); + + AstStat* elsebody = nullptr; + Location end = start; + std::optional elseLocation; + bool hasEnd = false; + + if (lexer.current().type == Lexeme::ReservedElseif) + { + if (FFlag::LuauIfStatementRecursionGuard) + { + unsigned int recursionCounterOld = recursionCounter; + incrementRecursionCounter("elseif"); + elseLocation = lexer.current().location; + elsebody = parseIf(); + end = elsebody->location; + hasEnd = elsebody->as()->hasEnd; + recursionCounter = recursionCounterOld; + } + else + { + elseLocation = lexer.current().location; + elsebody = parseIf(); + end = elsebody->location; + hasEnd = elsebody->as()->hasEnd; + } + } + else + { + Lexeme matchThenElse = matchThen; + + if (lexer.current().type == Lexeme::ReservedElse) + { + elseLocation = lexer.current().location; + matchThenElse = lexer.current(); + nextLexeme(); + + elsebody = parseBlock(); + elsebody->location.begin = matchThenElse.location.end; + } + + end = lexer.current().location; + + hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); + } + + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, hasThen, matchThen.location, elseLocation, hasEnd); +} + +// while exp do block end +AstStat* Parser::parseWhile() +{ + Location start = lexer.current().location; + + nextLexeme(); // while + + AstExpr* cond = parseExpr(); + + Lexeme matchDo = lexer.current(); + bool hasDo = expectAndConsume(Lexeme::ReservedDo, "while loop"); + + functionStack.back().loopDepth++; + + AstStatBlock* body = parseBlock(); + + functionStack.back().loopDepth--; + + Location end = lexer.current().location; + + bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + + return allocator.alloc(Location(start, end), cond, body, hasDo, matchDo.location, hasEnd); +} + +// repeat block until exp +AstStat* Parser::parseRepeat() +{ + Location start = lexer.current().location; + + Lexeme matchRepeat = lexer.current(); + nextLexeme(); // repeat + + unsigned int localsBegin = saveLocals(); + + functionStack.back().loopDepth++; + + AstStatBlock* body = parseBlockNoScope(); + + functionStack.back().loopDepth--; + + bool hasUntil = expectMatchEndAndConsume(Lexeme::ReservedUntil, matchRepeat); + + AstExpr* cond = parseExpr(); + + restoreLocals(localsBegin); + + return allocator.alloc(Location(start, cond->location), cond, body, hasUntil); +} + +// do block end +AstStat* Parser::parseDo() +{ + Location start = lexer.current().location; + + Lexeme matchDo = lexer.current(); + nextLexeme(); // do + + AstStat* body = parseBlock(); + + body->location.begin = start.begin; + + expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + + return body; +} + +// break +AstStat* Parser::parseBreak() +{ + Location start = lexer.current().location; + + nextLexeme(); // break + + if (functionStack.back().loopDepth == 0) + return reportStatError(start, {}, copy({allocator.alloc(start)}), "break statement must be inside a loop"); + + return allocator.alloc(start); +} + +// continue +AstStat* Parser::parseContinue(const Location& start) +{ + if (functionStack.back().loopDepth == 0) + return reportStatError(start, {}, copy({allocator.alloc(start)}), "continue statement must be inside a loop"); + + // note: the token is already parsed for us! + + return allocator.alloc(start); +} + +// for binding `=' exp `,' exp [`,' exp] do block end | +// for bindinglist in explist do block end | +AstStat* Parser::parseFor() +{ + Location start = lexer.current().location; + + nextLexeme(); // for + + Binding varname = parseBinding(); + + if (lexer.current().type == '=') + { + nextLexeme(); + + AstExpr* from = parseExpr(); + + expectAndConsume(',', "index range"); + + AstExpr* to = parseExpr(); + + AstExpr* step = nullptr; + + if (lexer.current().type == ',') + { + nextLexeme(); + + step = parseExpr(); + } + + Lexeme matchDo = lexer.current(); + bool hasDo = expectAndConsume(Lexeme::ReservedDo, "for loop"); + + unsigned int localsBegin = saveLocals(); + + functionStack.back().loopDepth++; + + AstLocal* var = pushLocal(varname); + + AstStatBlock* body = parseBlock(); + + functionStack.back().loopDepth--; + + restoreLocals(localsBegin); + + Location end = lexer.current().location; + + bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + + return allocator.alloc(Location(start, end), var, from, to, step, body, hasDo, matchDo.location, hasEnd); + } + else + { + TempVector names(scratchBinding); + names.push_back(varname); + + if (lexer.current().type == ',') + { + nextLexeme(); + + parseBindingList(names); + } + + Location inLocation = lexer.current().location; + bool hasIn = expectAndConsume(Lexeme::ReservedIn, "for loop"); + + TempVector values(scratchExpr); + parseExprList(values); + + Lexeme matchDo = lexer.current(); + bool hasDo = expectAndConsume(Lexeme::ReservedDo, "for loop"); + + unsigned int localsBegin = saveLocals(); + + functionStack.back().loopDepth++; + + TempVector vars(scratchLocal); + + for (size_t i = 0; i < names.size(); ++i) + vars.push_back(pushLocal(names[i])); + + AstStatBlock* body = parseBlock(); + + functionStack.back().loopDepth--; + + restoreLocals(localsBegin); + + Location end = lexer.current().location; + + bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchDo); + + return allocator.alloc( + Location(start, end), copy(vars), copy(values), body, hasIn, inLocation, hasDo, matchDo.location, hasEnd); + } +} + +// function funcname funcbody | +// funcname ::= Name {`.' Name} [`:' Name] +AstStat* Parser::parseFunctionStat() +{ + Location start = lexer.current().location; + + Lexeme matchFunction = lexer.current(); + nextLexeme(); + + AstName debugname = (lexer.current().type == Lexeme::Name) ? AstName(lexer.current().name) : AstName(); + + // parse funcname into a chain of indexing operators + AstExpr* expr = parseNameExpr("function name"); + + unsigned int recursionCounterOld = recursionCounter; + + while (lexer.current().type == '.') + { + Position opPosition = lexer.current().location.begin; + nextLexeme(); + + Name name = parseName("field name"); + + // while we could concatenate the name chain, for now let's just write the short name + debugname = name.name; + + expr = allocator.alloc(Location(start, name.location), expr, name.name, name.location, opPosition, '.'); + + // note: while the parser isn't recursive here, we're generating recursive structures of unbounded depth + incrementRecursionCounter("function name"); + } + + recursionCounter = recursionCounterOld; + + // finish with : + bool hasself = false; + + if (lexer.current().type == ':') + { + Position opPosition = lexer.current().location.begin; + nextLexeme(); + + Name name = parseName("method name"); + + // while we could concatenate the name chain, for now let's just write the short name + debugname = name.name; + + expr = allocator.alloc(Location(start, name.location), expr, name.name, name.location, opPosition, ':'); + + hasself = true; + } + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; + + AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, {}).first; + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; + + return allocator.alloc(Location(start, body->location), expr, body); +} + +// local function Name funcbody | +// local bindinglist [`=' explist] +AstStat* Parser::parseLocal() +{ + Location start = lexer.current().location; + + nextLexeme(); // local + + if (lexer.current().type == Lexeme::ReservedFunction) + { + Lexeme matchFunction = lexer.current(); + nextLexeme(); + + // matchFunction is only used for diagnostics; to make it suitable for detecting missed indentation between + // `local function` and `end`, we patch the token to begin at the column where `local` starts + if (matchFunction.location.begin.line == start.begin.line) + matchFunction.location.begin.column = start.begin.column; + + Name name = parseName("variable name"); + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; + + auto [body, var] = parseFunctionBody(false, matchFunction, name.name, name); + + matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; + + Location location{start.begin, body->location.end}; + + return allocator.alloc(location, var, body); + } + else + { + matchRecoveryStopOnToken['=']++; + + TempVector names(scratchBinding); + parseBindingList(names); + + matchRecoveryStopOnToken['=']--; + + TempVector vars(scratchLocal); + + TempVector values(scratchExpr); + + std::optional equalsSignLocation; + + if (lexer.current().type == '=') + { + equalsSignLocation = lexer.current().location; + + nextLexeme(); + + parseExprList(values); + } + + for (size_t i = 0; i < names.size(); ++i) + vars.push_back(pushLocal(names[i])); + + Location end = values.empty() ? lexer.previousLocation() : values.back()->location; + + return allocator.alloc(Location(start, end), copy(vars), copy(values), equalsSignLocation); + } +} + +// return [explist] +AstStat* Parser::parseReturn() +{ + Location start = lexer.current().location; + + nextLexeme(); + + TempVector list(scratchExpr); + + if (!blockFollow(lexer.current()) && lexer.current().type != ';') + parseExprList(list); + + Location end = list.empty() ? start : list.back()->location; + + return allocator.alloc(Location(start, end), copy(list)); +} + +// type Name [`<' varlist `>'] `=' typeannotation +AstStat* Parser::parseTypeAlias(const Location& start, bool exported) +{ + // note: `type` token is already parsed for us, so we just need to parse the rest + + auto name = parseNameOpt("type name"); + + // Use error name if the name is missing + if (!name) + name = Name(nameError, lexer.current().location); + + // TODO: support generic type pack parameters in type aliases CLI-39907 + auto [generics, genericPacks] = parseGenericTypeList(); + + expectAndConsume('=', "type alias"); + + AstType* type = parseTypeAnnotation(); + + return allocator.alloc(Location(start, type->location), name->name, generics, type, exported); +} + +AstDeclaredClassProp Parser::parseDeclaredClassMethod() +{ + nextLexeme(); + Location start = lexer.current().location; + Name fnName = parseName("function name"); + + // TODO: generic method declarations CLI-39909 + AstArray generics; + AstArray genericPacks; + generics.size = 0; + generics.data = nullptr; + genericPacks.size = 0; + genericPacks.data = nullptr; + + Lexeme matchParen = lexer.current(); + expectAndConsume('(', "function parameter list start"); + + TempVector args(scratchBinding); + + std::optional vararg = std::nullopt; + AstTypePack* varargAnnotation = nullptr; + if (lexer.current().type != ')') + std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3 */ true); + + expectMatchAndConsume(')', matchParen); + + AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0), nullptr}); + Location end = lexer.current().location; + + TempVector vars(scratchAnnotation); + TempVector> varNames(scratchOptArgName); + + if (args.size() == 0 || args[0].name.name != "self" || args[0].annotation != nullptr) + { + return AstDeclaredClassProp{fnName.name, + reportTypeAnnotationError(Location(start, end), {}, /*isMissing*/ false, "'self' must be present as the unannotated first parameter"), + true}; + } + + // Skip the first index. + for (size_t i = 1; i < args.size(); ++i) + { + varNames.push_back(AstArgumentName{args[i].name.name, args[i].name.location}); + + if (args[i].annotation) + vars.push_back(args[i].annotation); + else + vars.push_back(reportTypeAnnotationError( + Location(start, end), {}, /*isMissing*/ false, "All declaration parameters aside from 'self' must be annotated")); + } + + if (vararg && !varargAnnotation) + report(start, "All declaration parameters aside from 'self' must be annotated"); + + AstType* fnType = allocator.alloc( + Location(start, end), generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); + + return AstDeclaredClassProp{fnName.name, fnType, true}; +} + +AstStat* Parser::parseDeclaration(const Location& start) +{ + // `declare` token is already parsed at this point + if (lexer.current().type == Lexeme::ReservedFunction) + { + nextLexeme(); + Name globalName = parseName("global function name"); + + auto [generics, genericPacks] = parseGenericTypeList(); + + Lexeme matchParen = lexer.current(); + + expectAndConsume('(', "global function declaration"); + + TempVector args(scratchBinding); + + std::optional vararg; + AstTypePack* varargAnnotation = nullptr; + + if (lexer.current().type != ')') + std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + + expectMatchAndConsume(')', matchParen); + + AstTypeList retTypes = parseOptionalReturnTypeAnnotation().value_or(AstTypeList{copy(nullptr, 0)}); + Location end = lexer.current().location; + + TempVector vars(scratchAnnotation); + TempVector varNames(scratchArgName); + + for (size_t i = 0; i < args.size(); ++i) + { + if (!args[i].annotation) + return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); + + vars.push_back(args[i].annotation); + varNames.push_back({args[i].name.name, args[i].name.location}); + } + + if (vararg && !varargAnnotation) + return reportStatError(Location(start, end), {}, {}, "All declaration parameters must be annotated"); + + return allocator.alloc( + Location(start, end), globalName.name, generics, genericPacks, AstTypeList{copy(vars), varargAnnotation}, copy(varNames), retTypes); + } + else if (AstName(lexer.current().name) == "class") + { + nextLexeme(); + Location classStart = lexer.current().location; + Name className = parseName("class name"); + std::optional superName = std::nullopt; + + if (AstName(lexer.current().name) == "extends") + { + nextLexeme(); + superName = parseName("superclass name").name; + } + + TempVector props(scratchDeclaredClassProps); + + while (lexer.current().type != Lexeme::ReservedEnd) + { + // There are two possibilities: Either it's a property or a function. + if (lexer.current().type == Lexeme::ReservedFunction) + { + props.push_back(parseDeclaredClassMethod()); + } + else + { + Name propName = parseName("property name"); + expectAndConsume(':', "property type annotation"); + AstType* propType = parseTypeAnnotation(); + props.push_back(AstDeclaredClassProp{propName.name, propType, false}); + } + } + + Location classEnd = lexer.current().location; + nextLexeme(); // skip past `end` + + return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); + } + else if (auto globalName = parseNameOpt("global variable name")) + { + expectAndConsume(':', "global variable declaration"); + + AstType* type = parseTypeAnnotation(); + return allocator.alloc(Location(start, type->location), globalName->name, type); + } + else + { + return reportStatError(start, {}, {}, "declare must be followed by an identifier, 'function', or 'class'"); + } +} + +static bool isExprLValue(AstExpr* expr) +{ + return expr->is() || expr->is() || expr->is() || expr->is(); +} + +// varlist `=' explist +AstStat* Parser::parseAssignment(AstExpr* initial) +{ + if (!isExprLValue(initial)) + initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); + + TempVector vars(scratchExpr); + vars.push_back(initial); + + while (lexer.current().type == ',') + { + nextLexeme(); + + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ false); + + if (!isExprLValue(expr)) + expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field"); + + vars.push_back(expr); + } + + expectAndConsume('=', "assignment"); + + TempVector values(scratchExprAux); + parseExprList(values); + + return allocator.alloc(Location(initial->location, values.back()->location), copy(vars), copy(values)); +} + +// var [`+=' | `-=' | `*=' | `/=' | `%=' | `^=' | `..='] exp +AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op) +{ + if (!isExprLValue(initial)) + { + initial = reportExprError(initial->location, copy({initial}), "Assigned expression must be a variable or a field"); + } + + nextLexeme(); + + AstExpr* value = parseExpr(); + + return allocator.alloc(Location(initial->location, value->location), op, initial, value); +} + +// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end +// parlist ::= bindinglist [`,' `...'] | `...' +std::pair Parser::parseFunctionBody( + bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional localName) +{ + Location start = matchFunction.location; + + auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + + Lexeme matchParen = lexer.current(); + expectAndConsume('(', "function"); + + TempVector args(scratchBinding); + + std::optional vararg; + AstTypePack* varargAnnotation = nullptr; + + if (lexer.current().type != ')') + std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + + std::optional argLocation = matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')') + ? std::make_optional(Location(matchParen.location.begin, lexer.current().location.end)) + : std::nullopt; + expectMatchAndConsume(')', matchParen, true); + + std::optional typelist = parseOptionalReturnTypeAnnotation(); + + AstLocal* funLocal = nullptr; + + if (localName) + funLocal = pushLocal(Binding(*localName, nullptr)); + + unsigned int localsBegin = saveLocals(); + + Function fun; + fun.vararg = vararg.has_value(); + + functionStack.push_back(fun); + + AstLocal* self = nullptr; + + if (hasself) + self = pushLocal(Binding(Name(nameSelf, start), nullptr)); + + TempVector vars(scratchLocal); + + for (size_t i = 0; i < args.size(); ++i) + vars.push_back(pushLocal(args[i])); + + AstStatBlock* body = parseBlock(); + + functionStack.pop_back(); + + restoreLocals(localsBegin); + + Location end = lexer.current().location; + + bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); + + return {allocator.alloc(Location(start, end), generics, genericPacks, self, copy(vars), vararg, body, functionStack.size(), + debugname, typelist, varargAnnotation, hasEnd, argLocation), + funLocal}; +} + +// explist ::= {exp `,'} exp +void Parser::parseExprList(TempVector& result) +{ + result.push_back(parseExpr()); + + while (lexer.current().type == ',') + { + nextLexeme(); + + result.push_back(parseExpr()); + } +} + +Parser::Binding Parser::parseBinding() +{ + auto name = parseNameOpt("variable name"); + + // Use placeholder if the name is missing + if (!name) + name = Name(nameError, lexer.current().location); + + AstType* annotation = parseOptionalTypeAnnotation(); + + return Binding(*name, annotation); +} + +// bindinglist ::= (binding | `...') [`,' bindinglist] +std::pair, AstTypePack*> Parser::parseBindingList(TempVector& result, bool allowDot3) +{ + while (true) + { + if (lexer.current().type == Lexeme::Dot3 && allowDot3) + { + Location varargLocation = lexer.current().location; + nextLexeme(); + + AstTypePack* tailAnnotation = nullptr; + if (lexer.current().type == ':') + { + nextLexeme(); + tailAnnotation = parseVariadicArgumentAnnotation(); + } + + return {varargLocation, tailAnnotation}; + } + + result.push_back(parseBinding()); + + if (lexer.current().type != ',') + break; + nextLexeme(); + } + + return {std::nullopt, nullptr}; +} + +AstType* Parser::parseOptionalTypeAnnotation() +{ + if (options.allowTypeAnnotations && lexer.current().type == ':') + { + nextLexeme(); + return parseTypeAnnotation(); + } + else + return nullptr; +} + +// TypeList ::= TypeAnnotation [`,' TypeList] | ...TypeAnnotation +AstTypePack* Parser::parseTypeList(TempVector& result, TempVector>& resultNames) +{ + while (true) + { + if (shouldParseTypePackAnnotation(lexer)) + return parseTypePackAnnotation(); + + if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':') + { + // Fill in previous argument names with empty slots + while (resultNames.size() < result.size()) + resultNames.push_back({}); + + resultNames.push_back(AstArgumentName{AstName(lexer.current().name), lexer.current().location}); + lexer.next(); + + expectAndConsume(':'); + } + else if (!resultNames.empty()) + { + // If we have a type with named arguments, provide elements for all types + resultNames.push_back({}); + } + + result.push_back(parseTypeAnnotation()); + if (lexer.current().type != ',') + break; + nextLexeme(); + } + + return nullptr; +} + +std::optional Parser::parseOptionalReturnTypeAnnotation() +{ + if (options.allowTypeAnnotations && lexer.current().type == ':') + { + nextLexeme(); + + unsigned int oldRecursionCount = recursionCounter; + + auto [_location, result] = parseReturnTypeAnnotation(); + + // At this point, if we find a , character, it indicates that there are multiple return types + // in this type annotation, but the list wasn't wrapped in parentheses. + if (lexer.current().type == ',') + { + report(lexer.current().location, "Expected a statement, got ','; did you forget to wrap the list of return types in parentheses?"); + + nextLexeme(); + } + + recursionCounter = oldRecursionCount; + + return result; + } + + return std::nullopt; +} + +// ReturnType ::= TypeAnnotation | `(' TypeList `)' +std::pair Parser::parseReturnTypeAnnotation() +{ + incrementRecursionCounter("type annotation"); + + TempVector result(scratchAnnotation); + TempVector> resultNames(scratchOptArgName); + AstTypePack* varargAnnotation = nullptr; + + Lexeme begin = lexer.current(); + + if (lexer.current().type != '(') + { + if (shouldParseTypePackAnnotation(lexer)) + varargAnnotation = parseTypePackAnnotation(); + else + result.push_back(parseTypeAnnotation()); + + Location resultLocation = result.size() == 0 ? varargAnnotation->location : result[0]->location; + + return {resultLocation, AstTypeList{copy(result), varargAnnotation}}; + } + + nextLexeme(); + + Location innerBegin = lexer.current().location; + + matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; + + // possibly () -> ReturnType + if (lexer.current().type != ')') + varargAnnotation = parseTypeList(result, resultNames); + + const Location location{begin.location, lexer.current().location}; + + expectMatchAndConsume(')', begin, true); + + matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; + + if (lexer.current().type != Lexeme::SkinnyArrow && resultNames.empty()) + { + // If it turns out that it's just '(A)', it's possible that there are unions/intersections to follow, so fold over it. + if (result.size() == 1) + { + AstType* returnType = parseTypeAnnotation(result, innerBegin); + + return {Location{location, returnType->location}, AstTypeList{copy(&returnType, 1), varargAnnotation}}; + } + + return {location, AstTypeList{copy(result), varargAnnotation}}; + } + + AstArray generics{nullptr, 0}; + AstArray genericPacks{nullptr, 0}; + AstArray types = copy(result); + AstArray> names = copy(resultNames); + + TempVector fallbackReturnTypes(scratchAnnotation); + fallbackReturnTypes.push_back(parseFunctionTypeAnnotationTail(begin, generics, genericPacks, types, names, varargAnnotation)); + + return {Location{location, fallbackReturnTypes[0]->location}, AstTypeList{copy(fallbackReturnTypes), varargAnnotation}}; +} + +// TableIndexer ::= `[' TypeAnnotation `]' `:' TypeAnnotation +AstTableIndexer* Parser::parseTableIndexerAnnotation() +{ + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + + AstType* index = parseTypeAnnotation(); + + expectMatchAndConsume(']', begin); + + expectAndConsume(':', "table field"); + + AstType* result = parseTypeAnnotation(); + + return allocator.alloc(AstTableIndexer{index, result, Location(begin.location, result->location)}); +} + +// TableProp ::= Name `:' TypeAnnotation +// TablePropOrIndexer ::= TableProp | TableIndexer +// PropList ::= TablePropOrIndexer {fieldsep TablePropOrIndexer} [fieldsep] +// TableTypeAnnotation ::= `{' PropList `}' +AstType* Parser::parseTableTypeAnnotation() +{ + incrementRecursionCounter("type annotation"); + + TempVector props(scratchTableTypeProps); + AstTableIndexer* indexer = nullptr; + + Location start = lexer.current().location; + + Lexeme matchBrace = lexer.current(); + expectAndConsume('{', "table type"); + + while (lexer.current().type != '}') + { + if (lexer.current().type == '[') + { + if (indexer) + { + // maybe we don't need to parse the entire badIndexer... + // however, we either have { or [ to lint, not the entire table type or the bad indexer. + AstTableIndexer* badIndexer = parseTableIndexerAnnotation(); + + // we lose all additional indexer expressions from the AST after error recovery here + report(badIndexer->location, "Cannot have more than one table indexer"); + } + else + { + indexer = parseTableIndexerAnnotation(); + } + } + else if (props.empty() && !indexer && !(lexer.current().type == Lexeme::Name && lexer.lookahead().type == ':')) + { + AstType* type = parseTypeAnnotation(); + + // array-like table type: {T} desugars into {[number]: T} + AstType* index = allocator.alloc(type->location, std::nullopt, nameNumber); + indexer = allocator.alloc(AstTableIndexer{index, type, type->location}); + + break; + } + else + { + auto name = parseNameOpt("table field"); + + if (!name) + break; + + expectAndConsume(':', "table field"); + + AstType* type = parseTypeAnnotation(); + + props.push_back({name->name, name->location, type}); + } + + if (lexer.current().type == ',' || lexer.current().type == ';') + { + nextLexeme(); + } + else + { + if (lexer.current().type != '}') + break; + } + } + + Location end = lexer.current().location; + + if (!expectMatchAndConsume('}', matchBrace)) + end = lexer.previousLocation(); + + return allocator.alloc(Location(start, end), copy(props), indexer); +} + +// ReturnType ::= TypeAnnotation | `(' TypeList `)' +// FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType +AstType* Parser::parseFunctionTypeAnnotation() +{ + incrementRecursionCounter("type annotation"); + + bool monomorphic = !(FFlag::LuauParseGenericFunctions && lexer.current().type == '<'); + + auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + + Lexeme begin = lexer.current(); + + if (FFlag::LuauGenericFunctionsParserFix) + expectAndConsume('(', "function parameters"); + else + { + LUAU_ASSERT(begin.type == '('); + nextLexeme(); // ( + } + + matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; + + TempVector params(scratchAnnotation); + TempVector> names(scratchOptArgName); + AstTypePack* varargAnnotation = nullptr; + + if (lexer.current().type != ')') + varargAnnotation = parseTypeList(params, names); + + expectMatchAndConsume(')', begin, true); + + matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; + + // Not a function at all. Just a parenthesized type. + if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) + return params[0]; + + AstArray paramTypes = copy(params); + AstArray> paramNames = copy(names); + + return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation); +} + +AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) + +{ + incrementRecursionCounter("type annotation"); + + // Users occasionally write '()' as the 'unit' type when they actually want to use 'nil', here we'll try to give a more specific error + if (lexer.current().type != Lexeme::SkinnyArrow && generics.size == 0 && genericPacks.size == 0 && params.size == 0) + { + report(Location(begin.location, lexer.previousLocation()), "Expected '->' after '()' when parsing function type; did you mean 'nil'?"); + + return allocator.alloc(begin.location, std::nullopt, nameNil); + } + else + { + expectAndConsume(Lexeme::SkinnyArrow, "function type"); + } + + const auto [endLocation, returnTypeList] = parseReturnTypeAnnotation(); + + AstTypeList paramTypes = AstTypeList{params, varargAnnotation}; + return allocator.alloc(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); +} + +// typeannotation ::= +// nil | +// Name[`.' Name] [`<' namelist `>'] | +// `{' [PropList] `}' | +// `(' [TypeList] `)' `->` ReturnType +// `typeof` typeannotation +AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location& begin) +{ + LUAU_ASSERT(!parts.empty()); + + incrementRecursionCounter("type annotation"); + + bool isUnion = false; + bool isIntersection = false; + + Location location = begin; + + while (true) + { + Lexeme::Type c = lexer.current().type; + if (c == '|') + { + nextLexeme(); + parts.push_back(parseSimpleTypeAnnotation()); + isUnion = true; + } + else if (c == '?') + { + Location loc = lexer.current().location; + nextLexeme(); + parts.push_back(allocator.alloc(loc, std::nullopt, nameNil)); + isUnion = true; + } + else if (c == '&') + { + nextLexeme(); + parts.push_back(parseSimpleTypeAnnotation()); + isIntersection = true; + } + else + break; + } + + if (parts.size() == 1) + return parts[0]; + + if (isUnion && isIntersection) + { + return reportTypeAnnotationError(Location(begin, parts.back()->location), copy(parts), /*isMissing*/ false, + "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); + } + + location.end = parts.back()->location.end; + + if (isUnion) + return allocator.alloc(location, copy(parts)); + + if (isIntersection) + return allocator.alloc(location, copy(parts)); + + LUAU_ASSERT(false); + ParseError::raise(begin, "Composite type was not an intersection or union."); +} + +AstType* Parser::parseTypeAnnotation() +{ + unsigned int oldRecursionCount = recursionCounter; + incrementRecursionCounter("type annotation"); + + Location begin = lexer.current().location; + + TempVector parts(scratchAnnotation); + parts.push_back(parseSimpleTypeAnnotation()); + + recursionCounter = oldRecursionCount; + + return parseTypeAnnotation(parts, begin); +} + +// typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' +// | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType +AstType* Parser::parseSimpleTypeAnnotation() +{ + incrementRecursionCounter("type annotation"); + + Location begin = lexer.current().location; + + if (lexer.current().type == Lexeme::ReservedNil) + { + nextLexeme(); + return allocator.alloc(begin, std::nullopt, nameNil); + } + else if (lexer.current().type == Lexeme::Name) + { + std::optional prefix; + Name name = parseName("type name"); + + if (lexer.current().type == '.') + { + Position pointPosition = lexer.current().location.begin; + nextLexeme(); + + prefix = name.name; + name = parseIndexName("field name", pointPosition); + } + else if (name.name == "typeof") + { + Lexeme typeofBegin = lexer.current(); + expectAndConsume('(', "typeof type"); + + AstExpr* expr = parseExpr(); + + Location end = lexer.current().location; + + expectMatchAndConsume(')', typeofBegin); + + return allocator.alloc(Location(begin, end), expr); + } + + AstArray generics = parseTypeParams(); + + Location end = lexer.previousLocation(); + + return allocator.alloc(Location(begin, end), prefix, name.name, generics); + } + else if (lexer.current().type == '{') + { + return parseTableTypeAnnotation(); + } + else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) + { + return parseFunctionTypeAnnotation(); + } + else + { + Location location = lexer.current().location; + + // For a missing type annoation, capture 'space' between last token and the next one + location = Location(lexer.previousLocation().end, lexer.current().location.begin); + + return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()); + } +} + +AstTypePack* Parser::parseVariadicArgumentAnnotation() +{ + // Generic: a... + if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == Lexeme::Dot3) + { + Name name = parseName("generic name"); + Location end = lexer.current().location; + + // This will not fail because of the lookahead guard. + expectAndConsume(Lexeme::Dot3, "generic type pack annotation"); + return allocator.alloc(Location(name.location, end), name.name); + } + // Variadic: T + else + { + AstType* variadicAnnotation = parseTypeAnnotation(); + return allocator.alloc(variadicAnnotation->location, variadicAnnotation); + } +} + +AstTypePack* Parser::parseTypePackAnnotation() +{ + // Variadic: ...T + if (lexer.current().type == Lexeme::Dot3) + { + Location start = lexer.current().location; + nextLexeme(); + AstType* varargTy = parseTypeAnnotation(); + return allocator.alloc(Location(start, varargTy->location), varargTy); + } + // Generic: a... + else if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == Lexeme::Dot3) + { + Name name = parseName("generic name"); + Location end = lexer.current().location; + + // This will not fail because of the lookahead guard. + expectAndConsume(Lexeme::Dot3, "generic type pack annotation"); + return allocator.alloc(Location(name.location, end), name.name); + } + + // No type pack annotation exists here. + return nullptr; +} + +std::optional Parser::parseUnaryOp(const Lexeme& l) +{ + if (l.type == Lexeme::ReservedNot) + return AstExprUnary::Not; + else if (l.type == '-') + return AstExprUnary::Minus; + else if (l.type == '#') + return AstExprUnary::Len; + else + return std::nullopt; +} + +std::optional Parser::parseBinaryOp(const Lexeme& l) +{ + if (l.type == '+') + return AstExprBinary::Add; + else if (l.type == '-') + return AstExprBinary::Sub; + else if (l.type == '*') + return AstExprBinary::Mul; + else if (l.type == '/') + return AstExprBinary::Div; + else if (l.type == '%') + return AstExprBinary::Mod; + else if (l.type == '^') + return AstExprBinary::Pow; + else if (l.type == Lexeme::Dot2) + return AstExprBinary::Concat; + else if (l.type == Lexeme::NotEqual) + return AstExprBinary::CompareNe; + else if (l.type == Lexeme::Equal) + return AstExprBinary::CompareEq; + else if (l.type == '<') + return AstExprBinary::CompareLt; + else if (l.type == Lexeme::LessEqual) + return AstExprBinary::CompareLe; + else if (l.type == '>') + return AstExprBinary::CompareGt; + else if (l.type == Lexeme::GreaterEqual) + return AstExprBinary::CompareGe; + else if (l.type == Lexeme::ReservedAnd) + return AstExprBinary::And; + else if (l.type == Lexeme::ReservedOr) + return AstExprBinary::Or; + else + return std::nullopt; +} + +std::optional Parser::parseCompoundOp(const Lexeme& l) +{ + if (l.type == Lexeme::AddAssign) + return AstExprBinary::Add; + else if (l.type == Lexeme::SubAssign) + return AstExprBinary::Sub; + else if (l.type == Lexeme::MulAssign) + return AstExprBinary::Mul; + else if (l.type == Lexeme::DivAssign) + return AstExprBinary::Div; + else if (l.type == Lexeme::ModAssign) + return AstExprBinary::Mod; + else if (l.type == Lexeme::PowAssign) + return AstExprBinary::Pow; + else if (l.type == Lexeme::ConcatAssign) + return AstExprBinary::Concat; + else + return std::nullopt; +} + +std::optional Parser::checkUnaryConfusables() +{ + const Lexeme& curr = lexer.current(); + + // early-out: need to check if this is a possible confusable quickly + if (curr.type != '!') + return {}; + + // slow path: possible confusable + Location start = curr.location; + + if (curr.type == '!') + { + report(start, "Unexpected '!', did you mean 'not'?"); + return AstExprUnary::Not; + } + + return {}; +} + +std::optional Parser::checkBinaryConfusables(const BinaryOpPriority binaryPriority[], unsigned int limit) +{ + const Lexeme& curr = lexer.current(); + + // early-out: need to check if this is a possible confusable quickly + if (curr.type != '&' && curr.type != '|' && curr.type != '!') + return {}; + + // slow path: possible confusable + Location start = curr.location; + Lexeme next = lexer.lookahead(); + + if (curr.type == '&' && next.type == '&' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::And].left > limit) + { + nextLexeme(); + report(Location(start, next.location), "Unexpected '&&', did you mean 'and'?"); + return AstExprBinary::And; + } + else if (curr.type == '|' && next.type == '|' && curr.location.end == next.location.begin && binaryPriority[AstExprBinary::Or].left > limit) + { + nextLexeme(); + report(Location(start, next.location), "Unexpected '||', did you mean 'or'?"); + return AstExprBinary::Or; + } + else if (curr.type == '!' && next.type == '=' && curr.location.end == next.location.begin && + binaryPriority[AstExprBinary::CompareNe].left > limit) + { + nextLexeme(); + report(Location(start, next.location), "Unexpected '!=', did you mean '~='?"); + return AstExprBinary::CompareNe; + } + + return std::nullopt; +} + +// subexpr -> (asexp | unop subexpr) { binop subexpr } +// where `binop' is any binary operator with a priority higher than `limit' +AstExpr* Parser::parseExpr(unsigned int limit) +{ + static const BinaryOpPriority binaryPriority[] = { + {6, 6}, {6, 6}, {7, 7}, {7, 7}, {7, 7}, // `+' `-' `*' `/' `%' + {10, 9}, {5, 4}, // power and concat (right associative) + {3, 3}, {3, 3}, // equality and inequality + {3, 3}, {3, 3}, {3, 3}, {3, 3}, // order + {2, 2}, {1, 1} // logical (and/or) + }; + + unsigned int recursionCounterOld = recursionCounter; + + // this handles recursive calls to parseSubExpr/parseExpr + incrementRecursionCounter("expression"); + + const unsigned int unaryPriority = 8; + + Location start = lexer.current().location; + + AstExpr* expr; + + std::optional uop = parseUnaryOp(lexer.current()); + + if (!uop) + uop = checkUnaryConfusables(); + + if (uop) + { + nextLexeme(); + + AstExpr* subexpr = parseExpr(unaryPriority); + + expr = allocator.alloc(Location(start, subexpr->location), *uop, subexpr); + } + else + { + expr = parseAssertionExpr(); + } + + // expand while operators have priorities higher than `limit' + std::optional op = parseBinaryOp(lexer.current()); + + if (!op) + op = checkBinaryConfusables(binaryPriority, limit); + + while (op && binaryPriority[*op].left > limit) + { + nextLexeme(); + + // read sub-expression with higher priority + AstExpr* next = parseExpr(binaryPriority[*op].right); + + expr = allocator.alloc(Location(start, next->location), *op, expr, next); + op = parseBinaryOp(lexer.current()); + + if (!op) + op = checkBinaryConfusables(binaryPriority, limit); + + // note: while the parser isn't recursive here, we're generating recursive structures of unbounded depth + incrementRecursionCounter("expression"); + } + + recursionCounter = recursionCounterOld; + + return expr; +} + +// NAME +AstExpr* Parser::parseNameExpr(const char* context) +{ + auto name = parseNameOpt(context); + + if (!name) + return allocator.alloc(lexer.current().location, copy({}), unsigned(parseErrors.size() - 1)); + + AstLocal* const* value = localMap.find(name->name); + + if (value && *value) + { + AstLocal* local = *value; + + return allocator.alloc(name->location, local, local->functionDepth != functionStack.size() - 1); + } + + return allocator.alloc(name->location, name->name); +} + +// prefixexp -> NAME | '(' expr ')' +AstExpr* Parser::parsePrefixExpr() +{ + if (lexer.current().type == '(') + { + Location start = lexer.current().location; + + Lexeme matchParen = lexer.current(); + nextLexeme(); + + AstExpr* expr = parseExpr(); + + Location end = lexer.current().location; + + if (lexer.current().type != ')') + { + const char* suggestion = (lexer.current().type == '=') ? "; did you mean to use '{' when defining a table?" : nullptr; + + expectMatchAndConsumeFail(static_cast(')'), matchParen, suggestion); + + end = lexer.previousLocation(); + } + else + { + nextLexeme(); + } + + return allocator.alloc(Location(start, end), expr); + } + else + { + return parseNameExpr("expression"); + } +} + +// primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } +AstExpr* Parser::parsePrimaryExpr(bool asStatement) +{ + Location start = lexer.current().location; + + AstExpr* expr = parsePrefixExpr(); + + unsigned int recursionCounterOld = recursionCounter; + + while (true) + { + if (lexer.current().type == '.') + { + Position opPosition = lexer.current().location.begin; + nextLexeme(); + + Name index = parseIndexName(nullptr, opPosition); + + expr = allocator.alloc(Location(start, index.location), expr, index.name, index.location, opPosition, '.'); + } + else if (lexer.current().type == '[') + { + Lexeme matchBracket = lexer.current(); + nextLexeme(); + + AstExpr* index = parseExpr(); + + Location end = lexer.current().location; + + expectMatchAndConsume(']', matchBracket); + + expr = allocator.alloc(Location(start, end), expr, index); + } + else if (lexer.current().type == ':') + { + Position opPosition = lexer.current().location.begin; + nextLexeme(); + + Name index = parseIndexName("method name", opPosition); + AstExpr* func = allocator.alloc(Location(start, index.location), expr, index.name, index.location, opPosition, ':'); + + expr = parseFunctionArgs(func, true, index.location); + } + else if (lexer.current().type == '(') + { + // This error is handled inside 'parseFunctionArgs' as well, but for better error recovery we need to break out the current loop here + if (!asStatement && expr->location.end.line != lexer.current().location.begin.line) + { + report(lexer.current().location, + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " + "new statement; use ';' to separate statements"); + + break; + } + + expr = parseFunctionArgs(expr, false, Location()); + } + else if (lexer.current().type == '{' || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) + { + expr = parseFunctionArgs(expr, false, Location()); + } + else + { + break; + } + + // note: while the parser isn't recursive here, we're generating recursive structures of unbounded depth + incrementRecursionCounter("expression"); + } + + recursionCounter = recursionCounterOld; + + return expr; +} + +// asexp -> simpleexp [`::' typeannotation] +AstExpr* Parser::parseAssertionExpr() +{ + Location start = lexer.current().location; + AstExpr* expr = parseSimpleExpr(); + + if (options.allowTypeAnnotations && lexer.current().type == Lexeme::DoubleColon) + { + nextLexeme(); + AstType* annotation = parseTypeAnnotation(); + return allocator.alloc(Location(start, annotation->location), expr, annotation); + } + else + return expr; +} + +static bool parseNumber(double& result, const char* data) +{ + // binary literal + if (data[0] == '0' && (data[1] == 'b' || data[1] == 'B') && data[2]) + { + char* end = nullptr; + unsigned long long value = strtoull(data + 2, &end, 2); + + result = double(value); + return *end == 0; + } + // hexadecimal literal + else if (data[0] == '0' && (data[1] == 'x' || data[1] == 'X') && data[2]) + { + char* end = nullptr; + unsigned long long value = strtoull(data + 2, &end, 16); + + result = double(value); + return *end == 0; + } + else + { + char* end = nullptr; + double value = strtod(data, &end); + + result = value; + return *end == 0; + } +} + +// simpleexp -> NUMBER | STRING | NIL | true | false | ... | constructor | FUNCTION body | primaryexp +AstExpr* Parser::parseSimpleExpr() +{ + Location start = lexer.current().location; + + if (lexer.current().type == Lexeme::ReservedNil) + { + nextLexeme(); + + return allocator.alloc(start); + } + else if (lexer.current().type == Lexeme::ReservedTrue) + { + nextLexeme(); + + return allocator.alloc(start, true); + } + else if (lexer.current().type == Lexeme::ReservedFalse) + { + nextLexeme(); + + return allocator.alloc(start, false); + } + else if (lexer.current().type == Lexeme::ReservedFunction) + { + Lexeme matchFunction = lexer.current(); + nextLexeme(); + + return parseFunctionBody(false, matchFunction, AstName(), {}).first; + } + else if (lexer.current().type == Lexeme::Number) + { + scratchData.assign(lexer.current().data, lexer.current().length); + + // Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al + if (scratchData.find('_') != std::string::npos) + { + scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end()); + } + + double value = 0; + if (parseNumber(value, scratchData.c_str())) + { + nextLexeme(); + + return allocator.alloc(start, value); + } + else + { + nextLexeme(); + + return reportExprError(start, {}, "Malformed number"); + } + } + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) + { + return parseString(); + } + else if (lexer.current().type == Lexeme::BrokenString) + { + nextLexeme(); + return reportExprError(start, {}, "Malformed string"); + } + else if (lexer.current().type == Lexeme::Dot3) + { + if (functionStack.back().vararg) + { + nextLexeme(); + + return allocator.alloc(start); + } + else + { + nextLexeme(); + + return reportExprError(start, {}, "Cannot use '...' outside of a vararg function"); + } + } + else if (lexer.current().type == '{') + { + return parseTableConstructor(); + } + else if (FFlag::LuauIfElseExpressionBaseSupport && lexer.current().type == Lexeme::ReservedIf) + { + return parseIfElseExpr(); + } + else + { + return parsePrimaryExpr(/* asStatement= */ false); + } +} + +// args ::= `(' [explist] `)' | tableconstructor | String +AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation) +{ + if (lexer.current().type == '(') + { + Position argStart = lexer.current().location.end; + if (func->location.end.line != lexer.current().location.begin.line) + { + report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of " + "new statement; use ';' to separate statements"); + } + + Lexeme matchParen = lexer.current(); + nextLexeme(); + + TempVector args(scratchExpr); + + if (lexer.current().type != ')') + parseExprList(args); + + Location end = lexer.current().location; + Position argEnd = end.end; + + expectMatchAndConsume(')', matchParen); + + return allocator.alloc(Location(func->location, end), func, copy(args), self, Location(argStart, argEnd)); + } + else if (lexer.current().type == '{') + { + Position argStart = lexer.current().location.end; + AstExpr* expr = parseTableConstructor(); + Position argEnd = lexer.previousLocation().end; + + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, Location(argStart, argEnd)); + } + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) + { + Location argLocation = lexer.current().location; + AstExpr* expr = parseString(); + + return allocator.alloc(Location(func->location, expr->location), func, copy(&expr, 1), self, argLocation); + } + else + { + if (self && lexer.current().location.begin.line != func->location.end.line) + { + return reportExprError(func->location, copy({func}), "Expected function call arguments after '('"); + } + else + { + return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}), + "Expected '(', '{' or when parsing function call, got %s", lexer.current().toString().c_str()); + } + } +} + +// tableconstructor ::= `{' [fieldlist] `}' +// fieldlist ::= field {fieldsep field} [fieldsep] +// field ::= `[' exp `]' `=' exp | Name `=' exp | exp +// fieldsep ::= `,' | `;' +AstExpr* Parser::parseTableConstructor() +{ + TempVector items(scratchItem); + + Location start = lexer.current().location; + + Lexeme matchBrace = lexer.current(); + expectAndConsume('{', "table literal"); + + while (lexer.current().type != '}') + { + if (lexer.current().type == '[') + { + Lexeme matchLocationBracket = lexer.current(); + nextLexeme(); + + AstExpr* key = parseExpr(); + + expectMatchAndConsume(']', matchLocationBracket); + + expectAndConsume('=', "table field"); + + AstExpr* value = parseExpr(); + + items.push_back({AstExprTable::Item::General, key, value}); + } + else if (lexer.current().type == Lexeme::Name && lexer.lookahead().type == '=') + { + Name name = parseName("table field"); + + expectAndConsume('=', "table field"); + + AstArray nameString; + nameString.data = const_cast(name.name.value); + nameString.size = strlen(name.name.value); + + AstExpr* key = allocator.alloc(name.location, nameString); + AstExpr* value = parseExpr(); + + items.push_back({AstExprTable::Item::Record, key, value}); + } + else + { + AstExpr* expr = parseExpr(); + + items.push_back({AstExprTable::Item::List, nullptr, expr}); + } + + if (lexer.current().type == ',' || lexer.current().type == ';') + { + nextLexeme(); + } + else + { + if (lexer.current().type != '}') + break; + } + } + + Location end = lexer.current().location; + + if (!expectMatchAndConsume('}', matchBrace)) + end = lexer.previousLocation(); + + return allocator.alloc(Location(start, end), copy(items)); +} + +AstExpr* Parser::parseIfElseExpr() +{ + bool hasElse = false; + Location start = lexer.current().location; + + nextLexeme(); // skip if / elseif + + AstExpr* condition = parseExpr(); + + bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if then else expression"); + + AstExpr* trueExpr = parseExpr(); + AstExpr* falseExpr = nullptr; + + if (lexer.current().type == Lexeme::ReservedElseif) + { + unsigned int oldRecursionCount = recursionCounter; + incrementRecursionCounter("expression"); + hasElse = true; + falseExpr = parseIfElseExpr(); + recursionCounter = oldRecursionCount; + } + else + { + hasElse = expectAndConsume(Lexeme::ReservedElse, "if then else expression"); + falseExpr = parseExpr(); + } + + Location end = falseExpr->location; + + return allocator.alloc(Location(start, end), condition, hasThen, trueExpr, hasElse, falseExpr); +} + +// Name +std::optional Parser::parseNameOpt(const char* context) +{ + if (lexer.current().type != Lexeme::Name) + { + reportNameError(context); + + return {}; + } + + Name result(AstName(lexer.current().name), lexer.current().location); + + nextLexeme(); + + return result; +} + +Parser::Name Parser::parseName(const char* context) +{ + if (auto name = parseNameOpt(context)) + return *name; + + Location location = lexer.current().location; + location.end = location.begin; + + return Name(nameError, location); +} + +Parser::Name Parser::parseIndexName(const char* context, const Position& previous) +{ + if (auto name = parseNameOpt(context)) + return *name; + + // If we have a reserved keyword next at the same line, assume it's an incomplete name + if (lexer.current().type >= Lexeme::Reserved_BEGIN && lexer.current().type < Lexeme::Reserved_END && + lexer.current().location.begin.line == previous.line) + { + Name result(AstName(lexer.current().name), lexer.current().location); + + nextLexeme(); + + return result; + } + + Location location = lexer.current().location; + location.end = location.begin; + + return Name(nameError, location); +} + +std::pair, AstArray> Parser::parseGenericTypeListIfFFlagParseGenericFunctions() +{ + if (FFlag::LuauParseGenericFunctions) + return Parser::parseGenericTypeList(); + AstArray generics; + AstArray genericPacks; + generics.size = 0; + generics.data = nullptr; + genericPacks.size = 0; + genericPacks.data = nullptr; + return std::pair(generics, genericPacks); +} + +std::pair, AstArray> Parser::parseGenericTypeList() +{ + TempVector names{scratchName}; + TempVector namePacks{scratchPackName}; + + if (lexer.current().type == '<') + { + Lexeme begin = lexer.current(); + nextLexeme(); + + bool seenPack = false; + while (true) + { + AstName name = parseName().name; + if (FFlag::LuauParseGenericFunctions && lexer.current().type == Lexeme::Dot3) + { + seenPack = true; + nextLexeme(); + namePacks.push_back(name); + } + else + { + if (seenPack) + report(lexer.current().location, "Generic types come before generic type packs"); + + names.push_back(name); + } + + if (lexer.current().type == ',') + nextLexeme(); + else + break; + } + + expectMatchAndConsume('>', begin); + } + + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); + return {generics, genericPacks}; +} + +AstArray Parser::parseTypeParams() +{ + TempVector result{scratchAnnotation}; + + if (lexer.current().type == '<') + { + Lexeme begin = lexer.current(); + nextLexeme(); + + while (true) + { + result.push_back(parseTypeAnnotation()); + if (lexer.current().type == ',') + nextLexeme(); + else + break; + } + + expectMatchAndConsume('>', begin); + } + + return copy(result); +} + +AstExpr* Parser::parseString() +{ + LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); + + scratchData.assign(lexer.current().data, lexer.current().length); + + if (lexer.current().type == Lexeme::QuotedString) + { + if (!Lexer::fixupQuotedString(scratchData)) + { + Location location = lexer.current().location; + + nextLexeme(); + + return reportExprError(location, {}, "String literal contains malformed escape sequence"); + } + } + else + { + Lexer::fixupMultilineString(scratchData); + } + + Location start = lexer.current().location; + AstArray value = copy(scratchData); + + nextLexeme(); + + return allocator.alloc(start, value); +} + +AstLocal* Parser::pushLocal(const Binding& binding) +{ + const Name& name = binding.name; + AstLocal*& local = localMap[name.name]; + + local = allocator.alloc( + name.name, name.location, /* shadow= */ local, functionStack.size() - 1, functionStack.back().loopDepth, binding.annotation); + + localStack.push_back(local); + + return local; +} + +unsigned int Parser::saveLocals() +{ + return unsigned(localStack.size()); +} + +void Parser::restoreLocals(unsigned int offset) +{ + for (size_t i = localStack.size(); i > offset; --i) + { + AstLocal* l = localStack[i - 1]; + + localMap[l->name] = l->shadow; + } + + localStack.resize(offset); +} + +bool Parser::expectAndConsume(char value, const char* context) +{ + return expectAndConsume(static_cast(static_cast(value)), context); +} + +bool Parser::expectAndConsume(Lexeme::Type type, const char* context) +{ + if (lexer.current().type != type) + { + expectAndConsumeFail(type, context); + + // check if this is an extra token and the expected token is next + if (lexer.lookahead().type == type) + { + // skip invalid and consume expected + nextLexeme(); + nextLexeme(); + } + + return false; + } + else + { + nextLexeme(); + return true; + } +} + +// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is +// cold +LUAU_NOINLINE void Parser::expectAndConsumeFail(Lexeme::Type type, const char* context) +{ + std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString(); + std::string currLexemeString = lexer.current().toString(); + + if (context) + report(lexer.current().location, "Expected %s when parsing %s, got %s", typeString.c_str(), context, currLexemeString.c_str()); + else + report(lexer.current().location, "Expected %s, got %s", typeString.c_str(), currLexemeString.c_str()); +} + +bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing) +{ + Lexeme::Type type = static_cast(static_cast(value)); + + if (lexer.current().type != type) + { + expectMatchAndConsumeFail(type, begin); + + if (searchForMissing) + { + // previous location is taken because 'current' lexeme is already the next token + unsigned currentLine = lexer.previousLocation().end.line; + + // search to the end of the line for expected token + // we will also stop if we hit a token that can be handled by parsing function above the current one + Lexeme::Type lexemeType = lexer.current().type; + + while (currentLine == lexer.current().location.begin.line && lexemeType != type && matchRecoveryStopOnToken[lexemeType] == 0) + { + nextLexeme(); + lexemeType = lexer.current().type; + } + + if (lexemeType == type) + { + nextLexeme(); + + return true; + } + } + else + { + // check if this is an extra token and the expected token is next + if (lexer.lookahead().type == type) + { + // skip invalid and consume expected + nextLexeme(); + nextLexeme(); + + return true; + } + } + + return false; + } + else + { + nextLexeme(); + + return true; + } +} + +// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is +// cold +LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra) +{ + std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString(); + + if (lexer.current().location.begin.line == begin.location.begin.line) + report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), begin.toString().c_str(), + begin.location.begin.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); + else + report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), begin.toString().c_str(), + begin.location.begin.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); +} + +bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) +{ + if (lexer.current().type != type) + { + expectMatchEndAndConsumeFail(type, begin); + + // check if this is an extra token and the expected token is next + if (lexer.lookahead().type == type) + { + // skip invalid and consume expected + nextLexeme(); + nextLexeme(); + + return true; + } + + return false; + } + else + { + // If the token matches on a different line and a different column, it suggests misleading indentation + // This can be used to pinpoint the problem location for a possible future *actual* mismatch + if (lexer.current().location.begin.line != begin.location.begin.line && + lexer.current().location.begin.column != begin.location.begin.column && + endMismatchSuspect.location.begin.line < begin.location.begin.line) // Only replace the previous suspect with more recent suspects + { + endMismatchSuspect = begin; + } + + nextLexeme(); + + return true; + } +} + +// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is +// cold +LUAU_NOINLINE void Parser::expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin) +{ + if (endMismatchSuspect.type != Lexeme::Eof && endMismatchSuspect.location.begin.line > begin.location.begin.line) + { + std::string suggestion = + format("; did you forget to close %s at line %d?", endMismatchSuspect.toString().c_str(), endMismatchSuspect.location.begin.line + 1); + + expectMatchAndConsumeFail(type, begin, suggestion.c_str()); + } + else + { + expectMatchAndConsumeFail(type, begin); + } +} + +template +AstArray Parser::copy(const T* data, size_t size) +{ + AstArray result; + + result.data = size ? static_cast(allocator.allocate(sizeof(T) * size)) : nullptr; + result.size = size; + + // This is equivalent to std::uninitialized_copy, but without the exception guarantee + // since our types don't have destructors + for (size_t i = 0; i < size; ++i) + new (result.data + i) T(data[i]); + + return result; +} + +template +AstArray Parser::copy(const TempVector& data) +{ + return copy(data.empty() ? nullptr : &data[0], data.size()); +} + +template +AstArray Parser::copy(std::initializer_list data) +{ + return copy(data.size() == 0 ? nullptr : data.begin(), data.size()); +} + +AstArray Parser::copy(const std::string& data) +{ + AstArray result = copy(data.c_str(), data.size() + 1); + + result.size = data.size(); + + return result; +} + +void Parser::incrementRecursionCounter(const char* context) +{ + recursionCounter++; + + if (recursionCounter > unsigned(FInt::LuauRecursionLimit)) + { + ParseError::raise(lexer.current().location, "Exceeded allowed recursion depth; simplify your %s to make the code compile", context); + } +} + +void Parser::report(const Location& location, const char* format, va_list args) +{ + // To reduce number of errors reported to user for incomplete statements, we skip multiple errors at the same location + // For example, consider 'local a = (((b + ' where multiple tokens haven't been written yet + if (!parseErrors.empty() && location == parseErrors.back().getLocation()) + return; + + std::string message = vformat(format, args); + + // when limited to a single error, behave as if the error recovery is disabled + if (FInt::LuauParseErrorLimit == 1) + throw ParseError(location, message); + + parseErrors.emplace_back(location, message); + + if (parseErrors.size() >= unsigned(FInt::LuauParseErrorLimit)) + ParseError::raise(location, "Reached error limit (%d)", int(FInt::LuauParseErrorLimit)); +} + +void Parser::report(const Location& location, const char* format, ...) +{ + va_list args; + va_start(args, format); + report(location, format, args); + va_end(args); +} + +LUAU_NOINLINE void Parser::reportNameError(const char* context) +{ + if (context) + report(lexer.current().location, "Expected identifier when parsing %s, got %s", context, lexer.current().toString().c_str()); + else + report(lexer.current().location, "Expected identifier, got %s", lexer.current().toString().c_str()); +} + +AstStatError* Parser::reportStatError( + const Location& location, const AstArray& expressions, const AstArray& statements, const char* format, ...) +{ + va_list args; + va_start(args, format); + report(location, format, args); + va_end(args); + + return allocator.alloc(location, expressions, statements, unsigned(parseErrors.size() - 1)); +} + +AstExprError* Parser::reportExprError(const Location& location, const AstArray& expressions, const char* format, ...) +{ + va_list args; + va_start(args, format); + report(location, format, args); + va_end(args); + + return allocator.alloc(location, expressions, unsigned(parseErrors.size() - 1)); +} + +AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) +{ + va_list args; + va_start(args, format); + report(location, format, args); + va_end(args); + + return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); +} + +const Lexeme& Parser::nextLexeme() +{ + if (options.captureComments) + { + while (true) + { + const Lexeme& lexeme = lexer.next(/*skipComments*/ false); + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (FFlag::LuauCaptureBrokenCommentSpans && lexeme.type == Lexeme::BrokenComment) + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + if (isComment(lexeme)) + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + else + return lexeme; + } + } + else + return lexer.next(); +} + +} // namespace Luau diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp new file mode 100644 index 0000000..24b2283 --- /dev/null +++ b/Ast/src/StringUtils.cpp @@ -0,0 +1,228 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/StringUtils.h" + +#include "Luau/Common.h" + +#include +#include +#include +#include + +namespace Luau +{ + +static void vformatAppend(std::string& ret, const char* fmt, va_list args) +{ + va_list argscopy; + va_copy(argscopy, args); +#ifdef _MSC_VER + int actualSize = _vscprintf(fmt, argscopy); +#else + int actualSize = vsnprintf(NULL, 0, fmt, argscopy); +#endif + va_end(argscopy); + + if (actualSize <= 0) + return; + + size_t sz = ret.size(); + ret.resize(sz + actualSize); + vsnprintf(&ret[0] + sz, actualSize + 1, fmt, args); +} + +std::string format(const char* fmt, ...) +{ + std::string result; + va_list args; + va_start(args, fmt); + vformatAppend(result, fmt, args); + va_end(args); + return result; +} + +void formatAppend(std::string& str, const char* fmt, ...) +{ + va_list args; + va_start(args, fmt); + vformatAppend(str, fmt, args); + va_end(args); +} + +std::string vformat(const char* fmt, va_list args) +{ + std::string ret; + vformatAppend(ret, fmt, args); + return ret; +} + +template +static std::string joinImpl(const std::vector& segments, std::string_view delimiter) +{ + if (segments.empty()) + return ""; + + size_t len = (segments.size() - 1) * delimiter.size(); + for (const auto& sv : segments) + len += sv.size(); + + std::string result; + result.resize(len); + char* dest = const_cast(result.data()); // This const_cast is only necessary until C++17 + + auto it = segments.begin(); + memcpy(dest, it->data(), it->size()); + dest += it->size(); + ++it; + for (; it != segments.end(); ++it) + { + memcpy(dest, delimiter.data(), delimiter.size()); + dest += delimiter.size(); + memcpy(dest, it->data(), it->size()); + dest += it->size(); + } + + LUAU_ASSERT(dest == result.data() + len); + + return result; +} + +std::string join(const std::vector& segments, std::string_view delimiter) +{ + return joinImpl(segments, delimiter); +} + +std::string join(const std::vector& segments, std::string_view delimiter) +{ + return joinImpl(segments, delimiter); +} + +std::vector split(std::string_view s, char delimiter) +{ + std::vector result; + + while (!s.empty()) + { + auto index = s.find(delimiter); + if (index == std::string::npos) + { + result.push_back(s); + break; + } + result.push_back(s.substr(0, index)); + s.remove_prefix(index + 1); + } + + return result; +} + +size_t editDistance(std::string_view a, std::string_view b) +{ + // When there are matching prefix and suffix, they end up computing as zero cost, effectively making it no-op. We drop these characters. + while (!a.empty() && !b.empty() && a.front() == b.front()) + { + a.remove_prefix(1); + b.remove_prefix(1); + } + + while (!a.empty() && !b.empty() && a.back() == b.back()) + { + a.remove_suffix(1); + b.remove_suffix(1); + } + + // Since we know the edit distance is the difference of the length of A and B discounting the matching prefixes and suffixes, + // it is therefore pointless to run the rest of this function to find that out. We immediately infer this size and return it. + if (a.empty()) + return b.size(); + if (b.empty()) + return a.size(); + + size_t maxDistance = a.size() + b.size(); + + std::vector distances((a.size() + 2) * (b.size() + 2), 0); + auto getPos = [b](size_t x, size_t y) -> size_t { + return (x * (b.size() + 2)) + y; + }; + + distances[0] = maxDistance; + + for (size_t x = 0; x <= a.size(); ++x) + { + distances[getPos(x + 1, 0)] = maxDistance; + distances[getPos(x + 1, 1)] = x; + } + + for (size_t y = 0; y <= b.size(); ++y) + { + distances[getPos(0, y + 1)] = maxDistance; + distances[getPos(1, y + 1)] = y; + } + + std::array seenCharToRow; + seenCharToRow.fill(0); + + for (size_t x = 1; x <= a.size(); ++x) + { + size_t lastMatchedY = 0; + + for (size_t y = 1; y <= b.size(); ++y) + { + size_t x1 = seenCharToRow[b[y - 1]]; + size_t y1 = lastMatchedY; + + size_t cost = 1; + if (a[x - 1] == b[y - 1]) + { + cost = 0; + lastMatchedY = y; + } + + size_t transposition = distances[getPos(x1, y1)] + (x - x1 - 1) + 1 + (y - y1 - 1); + size_t substitution = distances[getPos(x, y)] + cost; + size_t insertion = distances[getPos(x, y + 1)] + 1; + size_t deletion = distances[getPos(x + 1, y)] + 1; + + // It's more performant to use std::min(size_t, size_t) rather than the initializer_list overload. + // Until proven otherwise, please do not change this. + distances[getPos(x + 1, y + 1)] = std::min(std::min(insertion, deletion), std::min(substitution, transposition)); + } + + seenCharToRow[a[x - 1]] = x; + } + + return distances[getPos(a.size() + 1, b.size() + 1)]; +} + +bool startsWith(std::string_view haystack, std::string_view needle) +{ + // ::starts_with is C++20 + return haystack.size() >= needle.size() && haystack.substr(0, needle.size()) == needle; +} + +bool equalsLower(std::string_view lhs, std::string_view rhs) +{ + if (lhs.size() != rhs.size()) + return false; + + for (size_t i = 0; i < lhs.size(); ++i) + if (tolower(uint8_t(lhs[i])) != tolower(uint8_t(rhs[i]))) + return false; + + return true; +} + +size_t hashRange(const char* data, size_t size) +{ + // FNV-1a + uint32_t hash = 2166136261; + + for (size_t i = 0; i < size; ++i) + { + hash ^= uint8_t(data[i]); + hash *= 16777619; + } + + return hash; +} + +} // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp new file mode 100644 index 0000000..920502b --- /dev/null +++ b/CLI/Analyze.cpp @@ -0,0 +1,252 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ModuleResolver.h" +#include "Luau/TypeInfer.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Frontend.h" +#include "Luau/TypeAttach.h" +#include "Luau/Transpiler.h" + +#include "FileUtils.h" + +enum class ReportFormat +{ + Default, + Luacheck +}; + +static void report(ReportFormat format, const char* name, const Luau::Location& location, const char* type, const char* message) +{ + switch (format) + { + case ReportFormat::Default: + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); + break; + + case ReportFormat::Luacheck: + { + // Note: luacheck's end column is inclusive but our end column is exclusive + // In addition, luacheck doesn't support multi-line messages, so if the error is multiline we'll fake end column as 100 and hope for the best + int columnEnd = (location.begin.line == location.end.line) ? location.end.column : 100; + + fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, columnEnd, type, message); + break; + } + } +} + +static void reportError(ReportFormat format, const char* name, const Luau::TypeError& error) +{ + if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) + report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); + else + report(format, name, error.location, "TypeError", Luau::toString(error).c_str()); +} + +static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) +{ + report(format, name, warning.location, Luau::LintWarning::getName(warning.code), warning.text.c_str()); +} + +static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) +{ + Luau::CheckResult cr = frontend.check(name); + + if (!frontend.getSourceModule(name)) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + for (auto& error : cr.errors) + reportError(format, name, error); + + Luau::LintResult lr = frontend.lint(name); + + for (auto& error : lr.errors) + reportWarning(format, name, error); + for (auto& warning : lr.warnings) + reportWarning(format, name, warning); + + if (annotate) + { + Luau::SourceModule* sm = frontend.getSourceModule(name); + Luau::ModulePtr m = frontend.moduleResolver.getModule(name); + + Luau::attachTypeData(*sm, *m); + + std::string annotated = Luau::transpileWithTypes(*sm->root); + + printf("%s", annotated.c_str()); + } + + return cr.errors.empty() && lr.errors.empty(); +} + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [--mode] [options] [file list]\n", argv0); + printf("\n"); + printf("Available modes:\n"); + printf(" omitted: typecheck and lint input files\n"); + printf(" --annotate: typecheck input files and output source with type annotations\n"); + printf("\n"); + printf("Available options:\n"); + printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); +} + +static int assertionHandler(const char* expr, const char* file, int line) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +struct CliFileResolver : Luau::FileResolver +{ + std::optional readSource(const Luau::ModuleName& name) override + { + std::optional source = readFile(name); + if (!source) + return std::nullopt; + + return Luau::SourceCode{*source, Luau::SourceCode::Module}; + } + + bool moduleExists(const Luau::ModuleName& name) const override + { + return !!readFile(name); + } + + std::optional fromAstFragment(Luau::AstExpr* expr) const override + { + return std::nullopt; + } + + Luau::ModuleName concat(const Luau::ModuleName& lhs, std::string_view rhs) const override + { + return lhs + "/" + std::string(rhs); + } + + std::optional getParentModuleName(const Luau::ModuleName& name) const override + { + return std::nullopt; + } + + std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override + { + return std::nullopt; + } +}; + +struct CliConfigResolver : Luau::ConfigResolver +{ + Luau::Config defaultConfig; + + mutable std::unordered_map configCache; + mutable std::vector> configErrors; + + CliConfigResolver() + { + defaultConfig.mode = Luau::Mode::Nonstrict; + } + + const Luau::Config& getConfig(const Luau::ModuleName& name) const override + { + std::optional path = getParentPath(name); + if (!path) + return defaultConfig; + + return readConfigRec(*path); + } + + const Luau::Config& readConfigRec(const std::string& path) const + { + auto it = configCache.find(path); + if (it != configCache.end()) + return it->second; + + std::optional parent = getParentPath(path); + Luau::Config result = parent ? readConfigRec(*parent) : defaultConfig; + + std::string configPath = joinPaths(path, Luau::kConfigName); + + if (std::optional contents = readFile(configPath)) + { + std::optional error = Luau::parseConfig(*contents, result); + if (error) + configErrors.push_back({configPath, *error}); + } + + return configCache[path] = result; + } +}; + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + if (argc >= 2 && strcmp(argv[1], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + + ReportFormat format = ReportFormat::Default; + bool annotate = false; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] != '-') + continue; + + if (strcmp(argv[i], "--formatter=plain") == 0) + format = ReportFormat::Luacheck; + else if (strcmp(argv[i], "--annotate") == 0) + annotate = true; + } + + Luau::FrontendOptions frontendOptions; + frontendOptions.retainFullTypeGraphs = annotate; + + CliFileResolver fileResolver; + CliConfigResolver configResolver; + Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); + + Luau::registerBuiltinTypes(frontend.typeChecker); + Luau::freeze(frontend.typeChecker.globalTypes); + + int failed = 0; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + failed += !analyzeFile(frontend, name.c_str(), format, annotate); + }); + } + else + { + failed += !analyzeFile(frontend, argv[i], format, annotate); + } + } + + if (!configResolver.configErrors.empty()) + { + failed += int(configResolver.configErrors.size()); + + for (const auto& pair : configResolver.configErrors) + fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str()); + } + + return (format == ReportFormat::Luacheck) ? 0 : failed; +} + + diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp new file mode 100644 index 0000000..0702b74 --- /dev/null +++ b/CLI/FileUtils.cpp @@ -0,0 +1,224 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "FileUtils.h" + +#include "Luau/Common.h" + +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include +#else +#include +#include +#include +#include +#endif + +#include + +#ifdef _WIN32 +static std::wstring fromUtf8(const std::string& path) +{ + size_t result = MultiByteToWideChar(CP_UTF8, 0, path.data(), int(path.size()), nullptr, 0); + LUAU_ASSERT(result); + + std::wstring buf(result, L'\0'); + MultiByteToWideChar(CP_UTF8, 0, path.data(), int(path.size()), &buf[0], int(buf.size())); + + return buf; +} + +static std::string toUtf8(const std::wstring& path) +{ + size_t result = WideCharToMultiByte(CP_UTF8, 0, path.data(), int(path.size()), nullptr, 0, nullptr, nullptr); + LUAU_ASSERT(result); + + std::string buf(result, '\0'); + WideCharToMultiByte(CP_UTF8, 0, path.data(), int(path.size()), &buf[0], int(buf.size()), nullptr, nullptr); + + return buf; +} +#endif + +std::optional readFile(const std::string& name) +{ +#ifdef _WIN32 + FILE* file = _wfopen(fromUtf8(name).c_str(), L"rb"); +#else + FILE* file = fopen(name.c_str(), "rb"); +#endif + + if (!file) + return std::nullopt; + + fseek(file, 0, SEEK_END); + long length = ftell(file); + if (length < 0) + { + fclose(file); + return std::nullopt; + } + fseek(file, 0, SEEK_SET); + + std::string result(length, 0); + + size_t read = fread(result.data(), 1, length, file); + fclose(file); + + if (read != size_t(length)) + return std::nullopt; + + return result; +} + +template +static void joinPaths(std::basic_string& str, const Ch* lhs, const Ch* rhs) +{ + str = lhs; + if (!str.empty() && str.back() != '/' && str.back() != '\\' && *rhs != '/' && *rhs != '\\') + str += '/'; + str += rhs; +} + +#ifdef _WIN32 +static bool traverseDirectoryRec(const std::wstring& path, const std::function& callback) +{ + std::wstring query = path + std::wstring(L"/*"); + + WIN32_FIND_DATAW data; + HANDLE h = FindFirstFileW(query.c_str(), &data); + + if (h == INVALID_HANDLE_VALUE) + return false; + + std::wstring buf; + + do + { + if (wcscmp(data.cFileName, L".") != 0 && wcscmp(data.cFileName, L"..") != 0) + { + joinPaths(buf, path.c_str(), data.cFileName); + + if (data.dwFileAttributes & FILE_ATTRIBUTE_REPARSE_POINT) + { + // Skip reparse points to avoid handling cycles + } + else if (data.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) + { + traverseDirectoryRec(buf, callback); + } + else + { + callback(toUtf8(buf)); + } + } + } while (FindNextFileW(h, &data)); + + FindClose(h); + + return true; +} + +bool traverseDirectory(const std::string& path, const std::function& callback) +{ + return traverseDirectoryRec(fromUtf8(path), callback); +} +#else +static bool traverseDirectoryRec(const std::string& path, const std::function& callback) +{ + int fd = open(path.c_str(), O_DIRECTORY); + DIR* dir = fdopendir(fd); + + if (!dir) + return false; + + std::string buf; + + while (dirent* entry = readdir(dir)) + { + const dirent& data = *entry; + + if (strcmp(data.d_name, ".") != 0 && strcmp(data.d_name, "..") != 0) + { + joinPaths(buf, path.c_str(), data.d_name); + + int type = data.d_type; + + // we need to stat DT_UNKNOWN to be able to tell the type + if (type == DT_UNKNOWN) + { + struct stat st = {}; +#ifdef _ATFILE_SOURCE + fstatat(fd, data.d_name, &st, 0); +#else + lstat(buf.c_str(), &st); +#endif + + type = IFTODT(st.st_mode); + } + + if (type == DT_DIR) + { + traverseDirectoryRec(buf, callback); + } + else if (type == DT_REG) + { + callback(buf); + } + else if (type == DT_LNK) + { + // Skip symbolic links to avoid handling cycles + } + } + } + + closedir(dir); + + return true; +} + +bool traverseDirectory(const std::string& path, const std::function& callback) +{ + return traverseDirectoryRec(path, callback); +} +#endif + +bool isDirectory(const std::string& path) +{ +#ifdef _WIN32 + return (GetFileAttributesW(fromUtf8(path).c_str()) & FILE_ATTRIBUTE_DIRECTORY) != 0; +#else + struct stat st = {}; + lstat(path.c_str(), &st); + return (st.st_mode & S_IFMT) == S_IFDIR; +#endif +} + +std::string joinPaths(const std::string& lhs, const std::string& rhs) +{ + std::string result = lhs; + if (!result.empty() && result.back() != '/' && result.back() != '\\') + result += '/'; + result += rhs; + return result; +} + +std::optional getParentPath(const std::string& path) +{ + if (path == "" || path == "." || path == "/") + return std::nullopt; + +#ifdef _WIN32 + if (path.size() == 2 && path.back() == ':') + return std::nullopt; +#endif + + std::string::size_type slash = path.find_last_of("\\/", path.size() - 1); + + if (slash == 0) + return "/"; + + if (slash != std::string::npos) + return path.substr(0, slash); + + return ""; +} diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h new file mode 100644 index 0000000..f7fbe8a --- /dev/null +++ b/CLI/FileUtils.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include +#include +#include + +std::optional readFile(const std::string& name); + +bool isDirectory(const std::string& path); +bool traverseDirectory(const std::string& path, const std::function& callback); + +std::string joinPaths(const std::string& lhs, const std::string& rhs); +std::optional getParentPath(const std::string& path); diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp new file mode 100644 index 0000000..c6d15a7 --- /dev/null +++ b/CLI/Profiler.cpp @@ -0,0 +1,155 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" + +#include "Luau/DenseHash.h" + +#include +#include +#include + +struct Profiler +{ + // static state + lua_Callbacks* callbacks = nullptr; + int frequency = 1000; + std::thread thread; + + // variables for communication between loop and trigger + std::atomic exit = false; + std::atomic ticks = 0; + std::atomic samples = 0; + + // private state for trigger + uint64_t currentTicks = 0; + std::string stackScratch; + + // statistics, updated by trigger + Luau::DenseHashMap data{""}; + uint64_t gc[16] = {}; +} gProfiler; + +static void profilerTrigger(lua_State* L, int gc) +{ + uint64_t currentTicks = gProfiler.ticks.load(); + uint64_t elapsedTicks = currentTicks - gProfiler.currentTicks; + + if (elapsedTicks) + { + std::string& stack = gProfiler.stackScratch; + + stack.clear(); + + if (gc > 0) + stack += "GC,GC,"; + + lua_Debug ar; + for (int level = 0; lua_getinfo(L, level, "sn", &ar); ++level) + { + if (!stack.empty()) + stack += ';'; + + stack += ar.short_src; + stack += ','; + if (ar.name) + stack += ar.name; + stack += ','; + if (ar.linedefined > 0) + stack += std::to_string(ar.linedefined); + } + + if (!stack.empty()) + { + gProfiler.data[stack] += elapsedTicks; + } + + if (gc > 0) + { + gProfiler.gc[gc] += elapsedTicks; + } + } + + gProfiler.currentTicks = currentTicks; + gProfiler.callbacks->interrupt = nullptr; +} + +static void profilerLoop() +{ + double last = lua_clock(); + + while (!gProfiler.exit) + { + double now = lua_clock(); + + if (now - last >= 1.0 / double(gProfiler.frequency)) + { + gProfiler.ticks += uint64_t((now - last) * 1e6); + gProfiler.samples++; + gProfiler.callbacks->interrupt = profilerTrigger; + + last = now; + } + else + { + std::this_thread::yield(); + } + } +} + +void profilerStart(lua_State* L, int frequency) +{ + gProfiler.frequency = frequency; + gProfiler.callbacks = lua_callbacks(L); + + gProfiler.exit = false; + gProfiler.thread = std::thread(profilerLoop); +} + +void profilerStop() +{ + gProfiler.exit = true; + gProfiler.thread.join(); +} + +void profilerDump(const char* name) +{ + FILE* f = fopen(name, "wb"); + if (!f) + { + fprintf(stderr, "Error opening profile %s\n", name); + return; + } + + uint64_t total = 0; + + for (auto& p : gProfiler.data) + { + fprintf(f, "%lld %s\n", static_cast(p.second), p.first.c_str()); + total += p.second; + } + + fclose(f); + + printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", name, double(total) / 1e6, + static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); + + uint64_t totalgc = 0; + for (uint64_t p : gProfiler.gc) + totalgc += p; + + if (totalgc) + { + printf("GC: %.3f seconds (%.2f%%)", double(totalgc) / 1e6, double(totalgc) / double(total) * 100); + + for (size_t i = 0; i < std::size(gProfiler.gc); ++i) + { + extern const char* luaC_statename(int state); + + uint64_t p = gProfiler.gc[i]; + + if (p) + printf(", %s %.2f%%", luaC_statename(int(i)), double(p) / double(totalgc) * 100); + } + + printf("\n"); + } +} diff --git a/CLI/Profiler.h b/CLI/Profiler.h new file mode 100644 index 0000000..0a407e4 --- /dev/null +++ b/CLI/Profiler.h @@ -0,0 +1,8 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +struct lua_State; + +void profilerStart(lua_State* L, int frequency); +void profilerStop(); +void profilerDump(const char* name); \ No newline at end of file diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp new file mode 100644 index 0000000..6baa21e --- /dev/null +++ b/CLI/Repl.cpp @@ -0,0 +1,495 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Luau/Compiler.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Parser.h" + +#include "FileUtils.h" +#include "Profiler.h" + +#include "linenoise.hpp" + +#include + +static int lua_loadstring(lua_State* L) +{ + size_t l = 0; + const char* s = luaL_checklstring(L, 1, &l); + const char* chunkname = luaL_optstring(L, 2, s); + + lua_setsafeenv(L, LUA_ENVIRONINDEX, false); + + std::string bytecode = Luau::compile(std::string(s, l)); + if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + return 1; + + lua_pushnil(L); + lua_insert(L, -2); /* put before error message */ + return 2; /* return nil plus error message */ +} + +static int finishrequire(lua_State* L) +{ + if (lua_isstring(L, -1)) + lua_error(L); + + return 1; +} + +static int lua_require(lua_State* L) +{ + std::string name = luaL_checkstring(L, 1); + std::string chunkname = "=" + name; + + luaL_findtable(L, LUA_REGISTRYINDEX, "_MODULES", 1); + + // return the module from the cache + lua_getfield(L, -1, name.c_str()); + if (!lua_isnil(L, -1)) + return finishrequire(L); + lua_pop(L, 1); + + std::optional source = readFile(name + ".lua"); + if (!source) + luaL_argerrorL(L, 1, ("error loading " + name).c_str()); + + // module needs to run in a new thread, isolated from the rest + lua_State* GL = lua_mainthread(L); + lua_State* ML = lua_newthread(GL); + lua_xmove(GL, L, 1); + + // new thread needs to have the globals sandboxed + luaL_sandboxthread(ML); + + // now we can compile & run module on the new thread + std::string bytecode = Luau::compile(*source); + if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + { + int status = lua_resume(ML, L, 0); + + if (status == 0) + { + if (lua_gettop(ML) == 0) + lua_pushstring(ML, "module must return a value"); + else if (!lua_istable(ML, -1) && !lua_isfunction(ML, -1)) + lua_pushstring(ML, "module must return a table or function"); + } + else if (status == LUA_YIELD) + { + lua_pushstring(ML, "module can not yield"); + } + else if (!lua_isstring(ML, -1)) + { + lua_pushstring(ML, "unknown error while running module"); + } + } + + // there's now a return value on top of ML; stack of L is MODULES thread + lua_xmove(ML, L, 1); + lua_pushvalue(L, -1); + lua_setfield(L, -4, name.c_str()); + + return finishrequire(L); +} + +static int lua_collectgarbage(lua_State* L) +{ + const char* option = luaL_optstring(L, 1, "collect"); + + if (strcmp(option, "collect") == 0) + { + lua_gc(L, LUA_GCCOLLECT, 0); + return 0; + } + + if (strcmp(option, "count") == 0) + { + int c = lua_gc(L, LUA_GCCOUNT, 0); + lua_pushnumber(L, c); + return 1; + } + + luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); +} + +static void setupState(lua_State* L) +{ + luaL_openlibs(L); + + static const luaL_Reg funcs[] = { + {"loadstring", lua_loadstring}, + {"require", lua_require}, + {"collectgarbage", lua_collectgarbage}, + {NULL, NULL}, + }; + + lua_pushvalue(L, LUA_GLOBALSINDEX); + luaL_register(L, NULL, funcs); + lua_pop(L, 1); + + luaL_sandbox(L); +} + +static std::string runCode(lua_State* L, const std::string& source) +{ + std::string bytecode = Luau::compile(source); + + if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) + { + size_t len; + const char* msg = lua_tolstring(L, -1, &len); + + std::string error(msg, len); + lua_pop(L, 1); + + return error; + } + + lua_State* T = lua_newthread(L); + + lua_pushvalue(L, -2); + lua_remove(L, -3); + lua_xmove(L, T, 1); + + int status = lua_resume(T, NULL, 0); + + if (status == 0) + { + int n = lua_gettop(T); + + if (n) + { + luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); + lua_getglobal(T, "print"); + lua_insert(T, 1); + lua_pcall(T, n, 0, 0); + } + } + else + { + std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(T, -1); + error += "\nstack backtrace:\n"; + error += lua_debugtrace(T); + + fprintf(stdout, "%s", error.c_str()); + } + + lua_pop(L, 1); + return std::string(); +} + +static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) +{ + std::string_view lookup = editBuffer + start; + + for (;;) + { + size_t dot = lookup.find('.'); + std::string_view prefix = lookup.substr(0, dot); + + if (dot == std::string_view::npos) + { + // table, key + lua_pushnil(L); + + while (lua_next(L, -2) != 0) + { + // table, key, value + std::string_view key = lua_tostring(L, -2); + + if (!key.empty() && Luau::startsWith(key, prefix)) + completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + + lua_pop(L, 1); + } + + break; + } + else + { + // find the key in the table + lua_pushlstring(L, prefix.data(), prefix.size()); + lua_rawget(L, -2); + lua_remove(L, -2); + + if (lua_isnil(L, -1)) + break; + + lookup.remove_prefix(dot + 1); + } + } + + lua_pop(L, 1); +} + +static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) +{ + size_t start = strlen(editBuffer); + while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.')) + start--; + + // look the value up in current global table first + lua_pushvalue(L, LUA_GLOBALSINDEX); + completeIndexer(L, editBuffer, start, completions); + + // and in actual global table after that + lua_getglobal(L, "_G"); + completeIndexer(L, editBuffer, start, completions); +} + +static void runRepl() +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + setupState(L); + + luaL_sandboxthread(L); + + linenoise::SetCompletionCallback([L](const char* editBuffer, std::vector& completions) { + completeRepl(L, editBuffer, completions); + }); + + std::string buffer; + + for (;;) + { + bool quit = false; + std::string line = linenoise::Readline(buffer.empty() ? "> " : ">> ", quit); + if (quit) + break; + + if (buffer.empty() && runCode(L, std::string("return ") + line) == std::string()) + { + linenoise::AddHistory(line.c_str()); + continue; + } + + buffer += line; + buffer += " "; // linenoise doesn't work very well with multiline history entries + + std::string error = runCode(L, buffer); + + if (error.length() >= 5 && error.compare(error.length() - 5, 5, "") == 0) + { + continue; + } + + if (error.length()) + { + fprintf(stdout, "%s\n", error.c_str()); + } + + linenoise::AddHistory(buffer.c_str()); + buffer.clear(); + } +} + +static bool runFile(const char* name, lua_State* GL) +{ + std::optional source = readFile(name); + if (!source) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + // module needs to run in a new thread, isolated from the rest + lua_State* L = lua_newthread(GL); + + // new thread needs to have the globals sandboxed + luaL_sandboxthread(L); + + std::string chunkname = "=" + std::string(name); + + std::string bytecode = Luau::compile(*source); + int status = 0; + + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + { + status = lua_resume(L, NULL, 0); + } + else + { + status = LUA_ERRSYNTAX; + } + + if (status == 0) + { + return true; + } + else + { + std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(L, -1); + error += "\nstacktrace:\n"; + error += lua_debugtrace(L); + + fprintf(stderr, "%s", error.c_str()); + return false; + } +} + +static void report(const char* name, const Luau::Location& location, const char* type, const char* message) +{ + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); +} + +static void reportError(const char* name, const Luau::ParseError& error) +{ + report(name, error.getLocation(), "SyntaxError", error.what()); +} + +static void reportError(const char* name, const Luau::CompileError& error) +{ + report(name, error.getLocation(), "CompileError", error.what()); +} + +static bool compileFile(const char* name) +{ + std::optional source = readFile(name); + if (!source) + { + fprintf(stderr, "Error opening %s\n", name); + return false; + } + + try + { + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpSource(*source); + + Luau::compileOrThrow(bcb, *source); + + printf("%s", bcb.dumpEverything().c_str()); + + return true; + } + catch (Luau::ParseErrors& e) + { + for (auto& error : e.getErrors()) + reportError(name, error); + return false; + } + catch (Luau::CompileError& e) + { + reportError(name, e); + return false; + } +} + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [--mode] [options] [file list]\n", argv0); + printf("\n"); + printf("When mode and file list are omitted, an interactive REPL is started instead.\n"); + printf("\n"); + printf("Available modes:\n"); + printf(" omitted: compile and run input files one by one\n"); + printf(" --compile: compile input files and output resulting bytecode\n"); + printf("\n"); + printf("Available options:\n"); + printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); +} + +static int assertionHandler(const char* expr, const char* file, int line) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + if (argc == 1) + { + runRepl(); + return 0; + } + + if (argc >= 2 && strcmp(argv[1], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + + if (argc >= 2 && strcmp(argv[1], "--compile") == 0) + { + int failed = 0; + + for (int i = 2; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + failed += !compileFile(name.c_str()); + }); + } + else + { + failed += !compileFile(argv[i]); + } + } + + return failed; + } + + { + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + setupState(L); + + int profile = 0; + + for (int i = 1; i < argc; ++i) + if (strcmp(argv[i], "--profile") == 0) + profile = 10000; // default to 10 KHz + else if (strncmp(argv[i], "--profile=", 10) == 0) + profile = atoi(argv[i] + 10); + + if (profile) + profilerStart(L, profile); + + int failed = 0; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + failed += !runFile(name.c_str(), L); + }); + } + else + { + failed += !runFile(argv[i], L); + } + } + + if (profile) + { + profilerStop(); + profilerDump("profile.out"); + } + + return failed; + } +} + + diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..d6598f2 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,88 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +if(EXT_PLATFORM_STRING) + include(EXTLuau.cmake) + return() +endif() + +cmake_minimum_required(VERSION 3.0) +project(Luau LANGUAGES CXX) + +option(LUAU_BUILD_CLI "Build CLI" ON) +option(LUAU_BUILD_TESTS "Build tests" ON) + +add_library(Luau.Ast STATIC) +add_library(Luau.Compiler STATIC) +add_library(Luau.Analysis STATIC) +add_library(Luau.VM STATIC) + +if(LUAU_BUILD_CLI) + add_executable(Luau.Repl.CLI) + add_executable(Luau.Analyze.CLI) + + # This also adds target `name` on Linux/macOS and `name.exe` on Windows + set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) + set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) +endif() + +if(LUAU_BUILD_TESTS) + add_executable(Luau.UnitTest) + add_executable(Luau.Conformance) +endif() +include(Sources.cmake) + +target_compile_features(Luau.Ast PUBLIC cxx_std_17) +target_include_directories(Luau.Ast PUBLIC Ast/include) + +target_compile_features(Luau.Compiler PUBLIC cxx_std_17) +target_include_directories(Luau.Compiler PUBLIC Compiler/include) +target_link_libraries(Luau.Compiler PUBLIC Luau.Ast) + +target_compile_features(Luau.Analysis PUBLIC cxx_std_17) +target_include_directories(Luau.Analysis PUBLIC Analysis/include) +target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) + +target_compile_features(Luau.VM PRIVATE cxx_std_11) +target_include_directories(Luau.VM PUBLIC VM/include) + +set(LUAU_OPTIONS) + +if(MSVC) + list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions. + list(APPEND LUAU_OPTIONS /WX) # Warnings are errors + list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores +else() + list(APPEND LUAU_OPTIONS -Wall) # All warnings + list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors + + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + list(APPEND LUAU_OPTIONS -Wno-unused) # GCC considers variables declared/checked in if() as unused + endif() +endif() + +target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) +target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) + +if(LUAU_BUILD_CLI) + target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + + target_include_directories(Luau.Repl.CLI PRIVATE extern) + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) + + if(UNIX) + target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + endif() + + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) +endif() + +if(LUAU_BUILD_TESTS) + target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) + target_include_directories(Luau.UnitTest PRIVATE extern) + target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) + + target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) + target_include_directories(Luau.Conformance PRIVATE extern) + target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) +endif() diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h new file mode 100644 index 0000000..07be2e7 --- /dev/null +++ b/Compiler/include/Luau/Bytecode.h @@ -0,0 +1,478 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +// clang-format off + +// This header contains the bytecode definition for Luau interpreter +// Creating the bytecode is outside the scope of this file and is handled by bytecode builder (BytecodeBuilder.h) and bytecode compiler (Compiler.h) +// Note that ALL enums declared in this file are order-sensitive since the values are baked into bytecode that needs to be processed by legacy clients. + +// Bytecode definitions +// Bytecode instructions are using "word code" - each instruction is one or many 32-bit words. +// The first word in the instruction is always the instruction header, and *must* contain the opcode (enum below) in the least significant byte. +// +// Instruction word can be encoded using one of the following encodings: +// ABC - least-significant byte for the opcode, followed by three bytes, A, B and C; each byte declares a register index, small index into some other table or an unsigned integral value +// AD - least-significant byte for the opcode, followed by A byte, followed by D half-word (16-bit integer). D is a signed integer that commonly specifies constant table index or jump offset +// E - least-significant byte for the opcode, followed by E (24-bit integer). E is a signed integer that commonly specifies a jump offset +// +// Instruction word is sometimes followed by one extra word, indicated as AUX - this is just a 32-bit word and is decoded according to the specification for each opcode. +// For each opcode the encoding is *static* - that is, based on the opcode you know a-priory how large the instruction is, with the exception of NEWCLOSURE + +// Bytecode indices +// Bytecode instructions commonly refer to integer values that define offsets or indices for various entities. For each type, there's a maximum encodable value. +// Note that in some cases, the compiler will set a lower limit than the maximum encodable value is to prevent fragile code into bumping against the limits whenever we change the compilation details. +// Additionally, in some specific instructions such as ANDK, the limit on the encoded value is smaller; this means that if a value is larger, a different instruction must be selected. +// +// Registers: 0-254. Registers refer to the values on the function's stack frame, including arguments. +// Upvalues: 0-254. Upvalues refer to the values stored in the closure object. +// Constants: 0-2^23-1. Constants are stored in a table allocated with each proto; to allow for future bytecode tweaks the encodable value is limited to 23 bits. +// Closures: 0-2^15-1. Closures are created from child protos via a child index; the limit is for the number of closures immediately referenced in each function. +// Jumps: -2^23..2^23. Jump offsets are specified in word increments, so jumping over an instruction may sometimes require an offset of 2 or more. +enum LuauOpcode +{ + // NOP: noop + LOP_NOP, + + // BREAK: debugger break + LOP_BREAK, + + // LOADNIL: sets register to nil + // A: target register + LOP_LOADNIL, + + // LOADB: sets register to boolean and jumps to a given short offset (used to compile comparison results into a boolean) + // A: target register + // B: value (0/1) + // C: jump offset + LOP_LOADB, + + // LOADN: sets register to a number literal + // A: target register + // D: value (-32768..32767) + LOP_LOADN, + + // LOADK: sets register to an entry from the constant table from the proto (number/string) + // A: target register + // D: constant table index (0..32767) + LOP_LOADK, + + // MOVE: move (copy) value from one register to another + // A: target register + // B: source register + LOP_MOVE, + + // GETGLOBAL: load value from global table using constant string as a key + // A: target register + // C: predicted slot index (based on hash) + // AUX: constant table index + LOP_GETGLOBAL, + + // SETGLOBAL: set value in global table using constant string as a key + // A: source register + // C: predicted slot index (based on hash) + // AUX: constant table index + LOP_SETGLOBAL, + + // GETUPVAL: load upvalue from the upvalue table for the current function + // A: target register + // B: upvalue index (0..255) + LOP_GETUPVAL, + + // SETUPVAL: store value into the upvalue table for the current function + // A: target register + // B: upvalue index (0..255) + LOP_SETUPVAL, + + // CLOSEUPVALS: close (migrate to heap) all upvalues that were captured for registers >= target + // A: target register + LOP_CLOSEUPVALS, + + // GETIMPORT: load imported global table global from the constant table + // A: target register + // D: constant table index (0..32767); we assume that imports are loaded into the constant table + // AUX: 3 10-bit indices of constant strings that, combined, constitute an import path; length of the path is set by the top 2 bits (1,2,3) + LOP_GETIMPORT, + + // GETTABLE: load value from table into target register using key from register + // A: target register + // B: table register + // C: index register + LOP_GETTABLE, + + // SETTABLE: store source register into table using key from register + // A: source register + // B: table register + // C: index register + LOP_SETTABLE, + + // GETTABLEKS: load value from table into target register using constant string as a key + // A: target register + // B: table register + // C: predicted slot index (based on hash) + // AUX: constant table index + LOP_GETTABLEKS, + + // SETTABLEKS: store source register into table using constant string as a key + // A: source register + // B: table register + // C: predicted slot index (based on hash) + // AUX: constant table index + LOP_SETTABLEKS, + + // GETTABLEN: load value from table into target register using small integer index as a key + // A: target register + // B: table register + // C: index-1 (index is 1..256) + LOP_GETTABLEN, + + // SETTABLEN: store source register into table using small integer index as a key + // A: source register + // B: table register + // C: index-1 (index is 1..256) + LOP_SETTABLEN, + + // NEWCLOSURE: create closure from a child proto; followed by a CAPTURE instruction for each upvalue + // A: target register + // D: child proto index (0..32767) + LOP_NEWCLOSURE, + + // NAMECALL: prepare to call specified method by name by loading function from source register using constant index into target register and copying source register into target register + 1 + // A: target register + // B: source register + // C: predicted slot index (based on hash) + // AUX: constant table index + // Note that this instruction must be followed directly by CALL; it prepares the arguments + // This instruction is roughly equivalent to GETTABLEKS + MOVE pair, but we need a special instruction to support custom __namecall metamethod + LOP_NAMECALL, + + // CALL: call specified function + // A: register where the function object lives, followed by arguments; results are placed starting from the same register + // B: argument count + 1, or 0 to preserve all arguments up to top (MULTRET) + // C: result count + 1, or 0 to preserve all values and adjust top (MULTRET) + LOP_CALL, + + // RETURN: returns specified values from the function + // A: register where the returned values start + // B: number of returned values + 1, or 0 to return all values up to top (MULTRET) + LOP_RETURN, + + // JUMP: jumps to target offset + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + LOP_JUMP, + + // JUMPBACK: jumps to target offset; this is equivalent to JUMP but is used as a safepoint to be able to interrupt while/repeat loops + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + LOP_JUMPBACK, + + // JUMPIF: jumps to target offset if register is not nil/false + // A: source register + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + LOP_JUMPIF, + + // JUMPIFNOT: jumps to target offset if register is nil/false + // A: source register + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + LOP_JUMPIFNOT, + + // JUMPIFEQ, JUMPIFLE, JUMPIFLT, JUMPIFNOTEQ, JUMPIFNOTLE, JUMPIFNOTLT: jumps to target offset if the comparison is true (or false, for NOT variants) + // A: source register 1 + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + // AUX: source register 2 + LOP_JUMPIFEQ, + LOP_JUMPIFLE, + LOP_JUMPIFLT, + LOP_JUMPIFNOTEQ, + LOP_JUMPIFNOTLE, + LOP_JUMPIFNOTLT, + + // ADD, SUB, MUL, DIV, MOD, POW: compute arithmetic operation between two source registers and put the result into target register + // A: target register + // B: source register 1 + // C: source register 2 + LOP_ADD, + LOP_SUB, + LOP_MUL, + LOP_DIV, + LOP_MOD, + LOP_POW, + + // ADDK, SUBK, MULK, DIVK, MODK, POWK: compute arithmetic operation between the source register and a constant and put the result into target register + // A: target register + // B: source register + // C: constant table index (0..255) + LOP_ADDK, + LOP_SUBK, + LOP_MULK, + LOP_DIVK, + LOP_MODK, + LOP_POWK, + + // AND, OR: perform `and` or `or` operation (selecting first or second register based on whether the first one is truthful) and put the result into target register + // A: target register + // B: source register 1 + // C: source register 2 + LOP_AND, + LOP_OR, + + // ANDK, ORK: perform `and` or `or` operation (selecting source register or constant based on whether the source register is truthful) and put the result into target register + // A: target register + // B: source register + // C: constant table index (0..255) + LOP_ANDK, + LOP_ORK, + + // CONCAT: concatenate all strings between B and C (inclusive) and put the result into A + // A: target register + // B: source register start + // C: source register end + LOP_CONCAT, + + // NOT, MINUS, LENGTH: compute unary operation for source register and put the result into target register + // A: target register + // B: source register + LOP_NOT, + LOP_MINUS, + LOP_LENGTH, + + // NEWTABLE: create table in target register + // A: target register + // B: table size, stored as 0 for v=0 and ceil(log2(v))+1 for v!=0 + // AUX: array size + LOP_NEWTABLE, + + // DUPTABLE: duplicate table using the constant table template to target register + // A: target register + // D: constant table index (0..32767) + LOP_DUPTABLE, + + // SETLIST: set a list of values to table in target register + // A: target register + // B: source register start + // C: value count + 1, or 0 to use all values up to top (MULTRET) + // AUX: table index to start from + LOP_SETLIST, + + // FORNPREP: prepare a numeric for loop, jump over the loop if first iteration doesn't need to run + // A: target register; numeric for loops assume a register layout [limit, step, index, variable] + // D: jump offset (-32768..32767) + // limit/step are immutable, index isn't visible to user code since it's copied into variable + LOP_FORNPREP, + + // FORNLOOP: adjust loop variables for one iteration, jump back to the loop header if loop needs to continue + // A: target register; see FORNPREP for register layout + // D: jump offset (-32768..32767) + LOP_FORNLOOP, + + // FORGLOOP: adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue + // A: target register; generic for loops assume a register layout [generator, state, index, variables...] + // D: jump offset (-32768..32767) + // AUX: variable count (1..255) + // loop variables are adjusted by calling generator(state, index) and expecting it to return a tuple that's copied to the user variables + // the first variable is then copied into index; generator/state are immutable, index isn't visible to user code + LOP_FORGLOOP, + + // FORGPREP_INEXT/FORGLOOP_INEXT: FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_inext + // FORGPREP_INEXT prepares the index variable and jumps to FORGLOOP_INEXT + // FORGLOOP_INEXT has identical encoding and semantics to FORGLOOP (except for AUX encoding) + LOP_FORGPREP_INEXT, + LOP_FORGLOOP_INEXT, + + // FORGPREP_NEXT/FORGLOOP_NEXT: FORGLOOP with 2 output variables (no AUX encoding), assuming generator is luaB_next + // FORGPREP_NEXT prepares the index variable and jumps to FORGLOOP_NEXT + // FORGLOOP_NEXT has identical encoding and semantics to FORGLOOP (except for AUX encoding) + LOP_FORGPREP_NEXT, + LOP_FORGLOOP_NEXT, + + // GETVARARGS: copy variables into the target register from vararg storage for current function + // A: target register + // B: variable count + 1, or 0 to copy all variables and adjust top (MULTRET) + LOP_GETVARARGS, + + // DUPCLOSURE: create closure from a pre-created function object (reusing it unless environments diverge) + // A: target register + // D: constant table index (0..32767) + LOP_DUPCLOSURE, + + // PREPVARARGS: prepare stack for variadic functions so that GETVARARGS works correctly + // A: number of fixed arguments + LOP_PREPVARARGS, + + // LOADKX: sets register to an entry from the constant table from the proto (number/string) + // A: target register + // AUX: constant table index + LOP_LOADKX, + + // JUMPX: jumps to the target offset; like JUMPBACK, supports interruption + // E: jump offset (-2^23..2^23; 0 means "next instruction" aka "don't jump") + LOP_JUMPX, + + // FASTCALL: perform a fast call of a built-in function + // A: builtin function id (see LuauBuiltinFunction) + // C: jump offset to get to following CALL + // FASTCALL is followed by one of (GETIMPORT, MOVE, GETUPVAL) instructions and by CALL instruction + // This is necessary so that if FASTCALL can't perform the call inline, it can continue normal execution + // If FASTCALL *can* perform the call, it jumps over the instructions *and* over the next CALL + // Note that FASTCALL will read the actual call arguments, such as argument/result registers and counts, from the CALL instruction + LOP_FASTCALL, + + // COVERAGE: update coverage information stored in the instruction + // E: hit count for the instruction (0..2^23-1) + // The hit count is incremented by VM every time the instruction is executed, and saturates at 2^23-1 + LOP_COVERAGE, + + // CAPTURE: capture a local or an upvalue as an upvalue into a newly created closure; only valid after NEWCLOSURE + // A: capture type, see LuauCaptureType + // B: source register (for VAL/REF) or upvalue index (for UPVAL/UPREF) + LOP_CAPTURE, + + // JUMPIFEQK, JUMPIFNOTEQK: jumps to target offset if the comparison with constant is true (or false, for NOT variants) + // A: source register 1 + // D: jump offset (-32768..32767; 0 means "next instruction" aka "don't jump") + // AUX: constant table index + LOP_JUMPIFEQK, + LOP_JUMPIFNOTEQK, + + // FASTCALL1: perform a fast call of a built-in function using 1 register argument + // A: builtin function id (see LuauBuiltinFunction) + // B: source argument register + // C: jump offset to get to following CALL + LOP_FASTCALL1, + + // FASTCALL2: perform a fast call of a built-in function using 2 register arguments + // A: builtin function id (see LuauBuiltinFunction) + // B: source argument register + // C: jump offset to get to following CALL + // AUX: source register 2 in least-significant byte + LOP_FASTCALL2, + + // FASTCALL2K: perform a fast call of a built-in function using 1 register argument and 1 constant argument + // A: builtin function id (see LuauBuiltinFunction) + // B: source argument register + // C: jump offset to get to following CALL + // AUX: constant index + LOP_FASTCALL2K, + + // Enum entry for number of opcodes, not a valid opcode by itself! + LOP__COUNT +}; + +// Bytecode instruction header: it's always a 32-bit integer, with low byte (first byte in little endian) containing the opcode +// Some instruction types require more data and have more 32-bit integers following the header +#define LUAU_INSN_OP(insn) ((insn) & 0xff) + +// ABC encoding: three 8-bit values, containing registers or small numbers +#define LUAU_INSN_A(insn) (((insn) >> 8) & 0xff) +#define LUAU_INSN_B(insn) (((insn) >> 16) & 0xff) +#define LUAU_INSN_C(insn) (((insn) >> 24) & 0xff) + +// AD encoding: one 8-bit value, one signed 16-bit value +#define LUAU_INSN_D(insn) (int32_t(insn) >> 16) + +// E encoding: one signed 24-bit value +#define LUAU_INSN_E(insn) (int32_t(insn) >> 8) + +// Bytecode tags, used internally for bytecode encoded as a string +enum LuauBytecodeTag +{ + // Bytecode version + LBC_VERSION = 1, + // Types of constant table entries + LBC_CONSTANT_NIL = 0, + LBC_CONSTANT_BOOLEAN, + LBC_CONSTANT_NUMBER, + LBC_CONSTANT_STRING, + LBC_CONSTANT_IMPORT, + LBC_CONSTANT_TABLE, + LBC_CONSTANT_CLOSURE, +}; + +// Builtin function ids, used in LOP_FASTCALL +enum LuauBuiltinFunction +{ + LBF_NONE = 0, + + // assert() + LBF_ASSERT, + + // math. + LBF_MATH_ABS, + LBF_MATH_ACOS, + LBF_MATH_ASIN, + LBF_MATH_ATAN2, + LBF_MATH_ATAN, + LBF_MATH_CEIL, + LBF_MATH_COSH, + LBF_MATH_COS, + LBF_MATH_DEG, + LBF_MATH_EXP, + LBF_MATH_FLOOR, + LBF_MATH_FMOD, + LBF_MATH_FREXP, + LBF_MATH_LDEXP, + LBF_MATH_LOG10, + LBF_MATH_LOG, + LBF_MATH_MAX, + LBF_MATH_MIN, + LBF_MATH_MODF, + LBF_MATH_POW, + LBF_MATH_RAD, + LBF_MATH_SINH, + LBF_MATH_SIN, + LBF_MATH_SQRT, + LBF_MATH_TANH, + LBF_MATH_TAN, + + // bit32. + LBF_BIT32_ARSHIFT, + LBF_BIT32_BAND, + LBF_BIT32_BNOT, + LBF_BIT32_BOR, + LBF_BIT32_BXOR, + LBF_BIT32_BTEST, + LBF_BIT32_EXTRACT, + LBF_BIT32_LROTATE, + LBF_BIT32_LSHIFT, + LBF_BIT32_REPLACE, + LBF_BIT32_RROTATE, + LBF_BIT32_RSHIFT, + + // type() + LBF_TYPE, + + // string. + LBF_STRING_BYTE, + LBF_STRING_CHAR, + LBF_STRING_LEN, + + // typeof() + LBF_TYPEOF, + + // string. + LBF_STRING_SUB, + + // math. + LBF_MATH_CLAMP, + LBF_MATH_SIGN, + LBF_MATH_ROUND, + + // raw* + LBF_RAWSET, + LBF_RAWGET, + LBF_RAWEQUAL, + + // table. + LBF_TABLE_INSERT, + LBF_TABLE_UNPACK, + + // vector ctor + LBF_VECTOR, +}; + +// Capture type, used in LOP_CAPTURE +enum LuauCaptureType +{ + LCT_VAL = 0, + LCT_REF, + LCT_UPVAL, +}; diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h new file mode 100644 index 0000000..d4ebad6 --- /dev/null +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -0,0 +1,250 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Bytecode.h" +#include "Luau/DenseHash.h" + +#include + +namespace Luau +{ + +class BytecodeEncoder +{ +public: + virtual ~BytecodeEncoder() {} + + virtual uint8_t encodeOp(uint8_t op) = 0; +}; + +class BytecodeBuilder +{ +public: + // BytecodeBuilder does *not* copy the data passed via StringRef; instead, it keeps the ref around until finalize() + // Please be careful with the lifetime of the data that's being passed because of this. + // The safe and correct pattern is to only build StringRefs out of pieces of AST (AstName or AstArray<>) that are backed by AstAllocator. + // Note that you must finalize() the builder before the Allocator backing the Ast is destroyed. + struct StringRef + { + // To construct a StringRef, use sref() from Compiler.cpp. + const char* data = nullptr; + size_t length = 0; + + bool operator==(const StringRef& other) const; + }; + + struct TableShape + { + static const unsigned int kMaxLength = 32; + + int32_t keys[kMaxLength]; + unsigned int length = 0; + + bool operator==(const TableShape& other) const; + }; + + BytecodeBuilder(BytecodeEncoder* encoder = 0); + + uint32_t beginFunction(uint8_t numparams, bool isvararg = false); + void endFunction(uint8_t maxstacksize, uint8_t numupvalues); + + void setMainFunction(uint32_t fid); + + int32_t addConstantNil(); + int32_t addConstantBoolean(bool value); + int32_t addConstantNumber(double value); + int32_t addConstantString(StringRef value); + int32_t addImport(uint32_t iid); + int32_t addConstantTable(const TableShape& shape); + int32_t addConstantClosure(uint32_t fid); + + int16_t addChildFunction(uint32_t fid); + + void emitABC(LuauOpcode op, uint8_t a, uint8_t b, uint8_t c); + void emitAD(LuauOpcode op, uint8_t a, int16_t d); + void emitE(LuauOpcode op, int32_t e); + void emitAux(uint32_t aux); + + size_t emitLabel(); + + [[nodiscard]] bool patchJumpD(size_t jumpLabel, size_t targetLabel); + [[nodiscard]] bool patchSkipC(size_t jumpLabel, size_t targetLabel); + + void foldJumps(); + void expandJumps(); + + void setDebugFunctionName(StringRef name); + void setDebugLine(int line); + void pushDebugLocal(StringRef name, uint8_t reg, uint32_t startpc, uint32_t endpc); + void pushDebugUpval(StringRef name); + uint32_t getDebugPC() const; + + void finalize(); + + enum DumpFlags + { + Dump_Code = 1 << 0, + Dump_Lines = 1 << 1, + Dump_Source = 1 << 2, + Dump_Locals = 1 << 3, + }; + + void setDumpFlags(uint32_t flags) + { + dumpFlags = flags; + dumpFunctionPtr = &BytecodeBuilder::dumpCurrentFunction; + } + + void setDumpSource(const std::string& source); + + const std::string& getBytecode() const + { + LUAU_ASSERT(!bytecode.empty()); // did you forget to call finalize? + return bytecode; + } + + std::string dumpFunction(uint32_t id) const; + std::string dumpEverything() const; + + static uint32_t getImportId(int32_t id0); + static uint32_t getImportId(int32_t id0, int32_t id1); + static uint32_t getImportId(int32_t id0, int32_t id1, int32_t id2); + + static uint32_t getStringHash(StringRef key); + + static std::string getError(const std::string& message); + +private: + struct Constant + { + enum Type + { + Type_Nil, + Type_Boolean, + Type_Number, + Type_String, + Type_Import, + Type_Table, + Type_Closure, + }; + + Type type; + union + { + bool valueBoolean; + double valueNumber; + unsigned int valueString; // index into string table + uint32_t valueImport; // 10-10-10-2 encoded import id + uint32_t valueTable; // index into tableShapes[] + uint32_t valueClosure; // index of function in global list + }; + }; + + struct ConstantKey + { + Constant::Type type; + // Note: this stores value* from Constant; when type is Number_Double, this stores the same bits as double does but in uint64_t. + uint64_t value; + + bool operator==(const ConstantKey& key) const + { + return type == key.type && value == key.value; + } + }; + + struct Function + { + std::string data; + + uint8_t maxstacksize = 0; + uint8_t numparams = 0; + uint8_t numupvalues = 0; + bool isvararg = false; + + unsigned int debugname = 0; + + std::string dump; + std::string dumpname; + }; + + struct DebugLocal + { + unsigned int name; + + uint8_t reg; + uint32_t startpc; + uint32_t endpc; + }; + + struct DebugUpval + { + unsigned int name; + }; + + struct Jump + { + uint32_t source; + uint32_t target; + }; + + struct StringRefHash + { + size_t operator()(const StringRef& v) const; + }; + + struct ConstantKeyHash + { + size_t operator()(const ConstantKey& key) const; + }; + + struct TableShapeHash + { + size_t operator()(const TableShape& v) const; + }; + + std::vector functions; + uint32_t currentFunction = ~0u; + uint32_t mainFunction = ~0u; + + std::vector insns; + std::vector lines; + std::vector constants; + std::vector protos; + std::vector jumps; + + std::vector tableShapes; + + bool hasLongJumps = false; + + DenseHashMap constantMap; + DenseHashMap tableShapeMap; + + int debugLine = 0; + + std::vector debugLocals; + std::vector debugUpvals; + + DenseHashMap stringTable; + + BytecodeEncoder* encoder = nullptr; + std::string bytecode; + + uint32_t dumpFlags = 0; + std::vector dumpSource; + + std::string (BytecodeBuilder::*dumpFunctionPtr)() const = nullptr; + + void validate() const; + + std::string dumpCurrentFunction() const; + const uint32_t* dumpInstruction(const uint32_t* opcode, std::string& output) const; + + void writeFunction(std::string& ss, uint32_t id) const; + void writeLineInfo(std::string& ss) const; + void writeStringTable(std::string& ss) const; + + int32_t addConstant(const ConstantKey& key, const Constant& value); + unsigned int addStringTableEntry(StringRef value); +}; + +} // namespace Luau diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h new file mode 100644 index 0000000..f8d6715 --- /dev/null +++ b/Compiler/include/Luau/Compiler.h @@ -0,0 +1,67 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/ParseOptions.h" +#include "Luau/Location.h" +#include "Luau/StringUtils.h" +#include "Luau/Common.h" + +namespace Luau +{ +class AstStatBlock; +class AstNameTable; +class BytecodeBuilder; +class BytecodeEncoder; + +struct CompileOptions +{ + // default bytecode version target; can be used to compile code for older clients + int bytecodeVersion = 1; + + // 0 - no optimization + // 1 - baseline optimization level that doesn't prevent debuggability + // 2 - includes optimizations that harm debuggability such as inlining + int optimizationLevel = 1; + + // 0 - no debugging support + // 1 - line info & function names only; sufficient for backtraces + // 2 - full debug info with local & upvalue names; necessary for debugger + int debugLevel = 1; + + // 0 - no code coverage support + // 1 - statement coverage + // 2 - statement and expression coverage (verbose) + int coverageLevel = 0; + + // global builtin to construct vectors; disabled by default + const char* vectorLib = nullptr; + const char* vectorCtor = nullptr; +}; + +class CompileError : public std::exception +{ +public: + CompileError(const Location& location, const std::string& message); + + virtual ~CompileError() throw(); + + virtual const char* what() const throw(); + + const Location& getLocation() const; + + static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); + +private: + Location location; + std::string message; +}; + +// compiles bytecode into bytecode builder using either a pre-parsed AST or parsing it from source; throws on errors +void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options = {}); +void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const CompileOptions& options = {}, const ParseOptions& parseOptions = {}); + +// compiles bytecode into a bytecode blob, that either contains the valid bytecode or an encoded error that luau_load can decode +std::string compile( + const std::string& source, const CompileOptions& options = {}, const ParseOptions& parseOptions = {}, BytecodeEncoder* encoder = nullptr); + +} // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp new file mode 100644 index 0000000..3280c8a --- /dev/null +++ b/Compiler/src/BytecodeBuilder.cpp @@ -0,0 +1,1726 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BytecodeBuilder.h" + +#include "Luau/StringUtils.h" + +#include +#include + +namespace Luau +{ + +static const uint32_t kMaxConstantCount = 1 << 23; +static const uint32_t kMaxClosureCount = 1 << 15; + +static const int kMaxJumpDistance = 1 << 23; + +static int log2(int v) +{ + LUAU_ASSERT(v); + + int r = 0; + + while (v >= (2 << r)) + r++; + + return r; +} + +static void writeByte(std::string& ss, unsigned char value) +{ + ss.append(reinterpret_cast(&value), sizeof(value)); +} + +static void writeInt(std::string& ss, int value) +{ + ss.append(reinterpret_cast(&value), sizeof(value)); +} + +static void writeDouble(std::string& ss, double value) +{ + ss.append(reinterpret_cast(&value), sizeof(value)); +} + +static void writeVarInt(std::string& ss, unsigned int value) +{ + do + { + writeByte(ss, (value & 127) | ((value > 127) << 7)); + value >>= 7; + } while (value); +} + +static int getOpLength(LuauOpcode op) +{ + switch (op) + { + case LOP_GETGLOBAL: + case LOP_SETGLOBAL: + case LOP_GETIMPORT: + case LOP_GETTABLEKS: + case LOP_SETTABLEKS: + case LOP_NAMECALL: + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + case LOP_NEWTABLE: + case LOP_SETLIST: + case LOP_FORGLOOP: + case LOP_LOADKX: + case LOP_JUMPIFEQK: + case LOP_JUMPIFNOTEQK: + case LOP_FASTCALL2: + case LOP_FASTCALL2K: + return 2; + + default: + return 1; + } +} + +bool BytecodeBuilder::StringRef::operator==(const StringRef& other) const +{ + return (data && other.data) ? (length == other.length && memcmp(data, other.data, length) == 0) : (data == other.data); +} + +bool BytecodeBuilder::TableShape::operator==(const TableShape& other) const +{ + return length == other.length && memcmp(keys, other.keys, length * sizeof(keys[0])) == 0; +} + +size_t BytecodeBuilder::StringRefHash::operator()(const StringRef& v) const +{ + return hashRange(v.data, v.length); +} + +size_t BytecodeBuilder::ConstantKeyHash::operator()(const ConstantKey& key) const +{ + // finalizer from MurmurHash64B + const uint32_t m = 0x5bd1e995; + + uint32_t h1 = uint32_t(key.value); + uint32_t h2 = uint32_t(key.value >> 32) ^ (key.type * m); + + h1 ^= h2 >> 18; + h1 *= m; + h2 ^= h1 >> 22; + h2 *= m; + h1 ^= h2 >> 17; + h1 *= m; + h2 ^= h1 >> 19; + h2 *= m; + + // ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half) + return size_t(h2); +} + +size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const +{ + // FNV-1a inspired hash (note that we feed integers instead of bytes) + uint32_t hash = 2166136261; + + for (size_t i = 0; i < v.length; ++i) + { + hash ^= v.keys[i]; + hash *= 16777619; + } + + return hash; +} + +BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder) + : constantMap({Constant::Type_Nil, ~0ull}) + , tableShapeMap(TableShape()) + , stringTable({nullptr, 0}) + , encoder(encoder) +{ + LUAU_ASSERT(stringTable.find(StringRef{"", 0}) == nullptr); +} + +uint32_t BytecodeBuilder::beginFunction(uint8_t numparams, bool isvararg) +{ + LUAU_ASSERT(currentFunction == ~0u); + + uint32_t id = uint32_t(functions.size()); + + Function func; + func.numparams = numparams; + func.isvararg = isvararg; + + functions.push_back(func); + + currentFunction = id; + + hasLongJumps = false; + debugLine = 0; + + return id; +} + +void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues) +{ + LUAU_ASSERT(currentFunction != ~0u); + + Function& func = functions[currentFunction]; + + func.maxstacksize = maxstacksize; + func.numupvalues = numupvalues; + +#ifdef LUAU_ASSERTENABLED + validate(); +#endif + + // very approximate: 4 bytes per instruction for code, 1 byte for debug line, and 1-2 bytes for aux data like constants + func.data.reserve(insns.size() * 7); + + writeFunction(func.data, currentFunction); + + currentFunction = ~0u; + + // this call is indirect to make sure we only gain link time dependency on dumpCurrentFunction when needed + if (dumpFunctionPtr) + func.dump = (this->*dumpFunctionPtr)(); + + insns.clear(); + lines.clear(); + constants.clear(); + protos.clear(); + jumps.clear(); + tableShapes.clear(); + + debugLocals.clear(); + debugUpvals.clear(); + + constantMap.clear(); + tableShapeMap.clear(); +} + +void BytecodeBuilder::setMainFunction(uint32_t fid) +{ + mainFunction = fid; +} + +int32_t BytecodeBuilder::addConstant(const ConstantKey& key, const Constant& value) +{ + if (int32_t* cache = constantMap.find(key)) + return *cache; + + uint32_t id = uint32_t(constants.size()); + + if (id >= kMaxConstantCount) + return -1; + + constantMap[key] = int32_t(id); + constants.push_back(value); + + return int32_t(id); +} + +unsigned int BytecodeBuilder::addStringTableEntry(StringRef value) +{ + unsigned int& index = stringTable[value]; + + // note: bytecode serialization format uses 1-based table indices, 0 is reserved to mean nil + if (index == 0) + index = uint32_t(stringTable.size()); + + return index; +} + +int32_t BytecodeBuilder::addConstantNil() +{ + Constant c = {Constant::Type_Nil}; + + ConstantKey k = {Constant::Type_Nil}; + return addConstant(k, c); +} + +int32_t BytecodeBuilder::addConstantBoolean(bool value) +{ + Constant c = {Constant::Type_Boolean}; + c.valueBoolean = value; + + ConstantKey k = {Constant::Type_Boolean, value}; + return addConstant(k, c); +} + +int32_t BytecodeBuilder::addConstantNumber(double value) +{ + Constant c = {Constant::Type_Number}; + c.valueNumber = value; + + ConstantKey k = {Constant::Type_Number}; + static_assert(sizeof(k.value) == sizeof(value), "Expecting double to be 64-bit"); + memcpy(&k.value, &value, sizeof(value)); + + return addConstant(k, c); +} + +int32_t BytecodeBuilder::addConstantString(StringRef value) +{ + unsigned int index = addStringTableEntry(value); + + Constant c = {Constant::Type_String}; + c.valueString = index; + + ConstantKey k = {Constant::Type_String, index}; + + return addConstant(k, c); +} + +int32_t BytecodeBuilder::addImport(uint32_t iid) +{ + Constant c = {Constant::Type_Import}; + c.valueImport = iid; + + ConstantKey k = {Constant::Type_Import, iid}; + + return addConstant(k, c); +} + +int32_t BytecodeBuilder::addConstantTable(const TableShape& shape) +{ + if (int32_t* cache = tableShapeMap.find(shape)) + return *cache; + + uint32_t id = uint32_t(constants.size()); + + if (id >= kMaxConstantCount) + return -1; + + Constant value = {Constant::Type_Table}; + value.valueTable = uint32_t(tableShapes.size()); + + tableShapeMap[shape] = int32_t(id); + tableShapes.push_back(shape); + constants.push_back(value); + + return int32_t(id); +} + +int32_t BytecodeBuilder::addConstantClosure(uint32_t fid) +{ + Constant c = {Constant::Type_Closure}; + c.valueClosure = fid; + + ConstantKey k = {Constant::Type_Closure, fid}; + + return addConstant(k, c); +} + +int16_t BytecodeBuilder::addChildFunction(uint32_t fid) +{ + uint32_t id = uint32_t(protos.size()); + + if (id >= kMaxClosureCount) + return -1; + + protos.push_back(fid); + + return int16_t(id); +} + +void BytecodeBuilder::emitABC(LuauOpcode op, uint8_t a, uint8_t b, uint8_t c) +{ + uint32_t insn = uint32_t(op) | (a << 8) | (b << 16) | (c << 24); + + insns.push_back(insn); + lines.push_back(debugLine); +} + +void BytecodeBuilder::emitAD(LuauOpcode op, uint8_t a, int16_t d) +{ + uint32_t insn = uint32_t(op) | (a << 8) | (uint16_t(d) << 16); + + insns.push_back(insn); + lines.push_back(debugLine); +} + +void BytecodeBuilder::emitE(LuauOpcode op, int32_t e) +{ + uint32_t insn = uint32_t(op) | (uint32_t(e) << 8); + + insns.push_back(insn); + lines.push_back(debugLine); +} + +void BytecodeBuilder::emitAux(uint32_t aux) +{ + insns.push_back(aux); + lines.push_back(debugLine); +} + +size_t BytecodeBuilder::emitLabel() +{ + return insns.size(); +} + +bool BytecodeBuilder::patchJumpD(size_t jumpLabel, size_t targetLabel) +{ + LUAU_ASSERT(jumpLabel < insns.size()); + + unsigned int jumpInsn = insns[jumpLabel]; + (void)jumpInsn; + + LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_JUMP || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIF || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOT || + LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLT || + LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLT || + LUAU_INSN_OP(jumpInsn) == LOP_FORNPREP || LUAU_INSN_OP(jumpInsn) == LOP_FORNLOOP || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP || + LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_INEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_INEXT || + LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_NEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_NEXT || + LUAU_INSN_OP(jumpInsn) == LOP_JUMPBACK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQK); + LUAU_ASSERT(LUAU_INSN_D(jumpInsn) == 0); + + LUAU_ASSERT(targetLabel <= insns.size()); + + int offset = int(targetLabel) - int(jumpLabel) - 1; + + if (int16_t(offset) == offset) + { + insns[jumpLabel] |= uint16_t(offset) << 16; + } + else if (abs(offset) < kMaxJumpDistance) + { + // our jump doesn't fit into 16 bits; we will need to repatch the bytecode sequence with jump trampolines, see expandJumps + hasLongJumps = true; + } + else + { + return false; + } + + jumps.push_back({uint32_t(jumpLabel), uint32_t(targetLabel)}); + return true; +} + +bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel) +{ + LUAU_ASSERT(jumpLabel < insns.size()); + + unsigned int jumpInsn = insns[jumpLabel]; + (void)jumpInsn; + + LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL1 || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2 || + LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2K); + LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0); + + int offset = int(targetLabel) - int(jumpLabel) - 1; + + if (uint8_t(offset) != offset) + { + return false; + } + + insns[jumpLabel] |= offset << 24; + return true; +} + +void BytecodeBuilder::setDebugFunctionName(StringRef name) +{ + unsigned int index = addStringTableEntry(name); + + functions[currentFunction].debugname = index; + + if (dumpFunctionPtr) + functions[currentFunction].dumpname = std::string(name.data, name.length); +} + +void BytecodeBuilder::setDebugLine(int line) +{ + debugLine = line; +} + +void BytecodeBuilder::pushDebugLocal(StringRef name, uint8_t reg, uint32_t startpc, uint32_t endpc) +{ + unsigned int index = addStringTableEntry(name); + + DebugLocal local; + local.name = index; + local.reg = reg; + local.startpc = startpc; + local.endpc = endpc; + + debugLocals.push_back(local); +} + +void BytecodeBuilder::pushDebugUpval(StringRef name) +{ + unsigned int index = addStringTableEntry(name); + + DebugUpval upval; + upval.name = index; + + debugUpvals.push_back(upval); +} + +uint32_t BytecodeBuilder::getDebugPC() const +{ + return uint32_t(insns.size()); +} + +void BytecodeBuilder::finalize() +{ + LUAU_ASSERT(bytecode.empty()); + bytecode = char(LBC_VERSION); + + writeStringTable(bytecode); + + writeVarInt(bytecode, uint32_t(functions.size())); + + for (const Function& func : functions) + bytecode += func.data; + + LUAU_ASSERT(mainFunction < functions.size()); + writeVarInt(bytecode, mainFunction); +} + +void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const +{ + LUAU_ASSERT(id < functions.size()); + const Function& func = functions[id]; + + // header + writeByte(ss, func.maxstacksize); + writeByte(ss, func.numparams); + writeByte(ss, func.numupvalues); + writeByte(ss, func.isvararg); + + // instructions + writeVarInt(ss, uint32_t(insns.size())); + + for (size_t i = 0; i < insns.size();) + { + uint8_t op = LUAU_INSN_OP(insns[i]); + LUAU_ASSERT(op < LOP__COUNT); + + int oplen = getOpLength(LuauOpcode(op)); + uint8_t openc = encoder ? encoder->encodeOp(op) : op; + + writeInt(ss, openc | (insns[i] & ~0xff)); + + for (int j = 1; j < oplen; ++j) + writeInt(ss, insns[i + j]); + + i += oplen; + } + + // constants + writeVarInt(ss, uint32_t(constants.size())); + + for (const Constant& c : constants) + { + switch (c.type) + { + case Constant::Type_Nil: + writeByte(ss, LBC_CONSTANT_NIL); + break; + + case Constant::Type_Boolean: + writeByte(ss, LBC_CONSTANT_BOOLEAN); + writeByte(ss, c.valueBoolean); + break; + + case Constant::Type_Number: + writeByte(ss, LBC_CONSTANT_NUMBER); + writeDouble(ss, c.valueNumber); + break; + + case Constant::Type_String: + writeByte(ss, LBC_CONSTANT_STRING); + writeVarInt(ss, c.valueString); + break; + + case Constant::Type_Import: + writeByte(ss, LBC_CONSTANT_IMPORT); + writeInt(ss, c.valueImport); + break; + + case Constant::Type_Table: + { + const TableShape& shape = tableShapes[c.valueTable]; + writeByte(ss, LBC_CONSTANT_TABLE); + writeVarInt(ss, uint32_t(shape.length)); + for (unsigned int i = 0; i < shape.length; ++i) + writeVarInt(ss, shape.keys[i]); + break; + } + + case Constant::Type_Closure: + writeByte(ss, LBC_CONSTANT_CLOSURE); + writeVarInt(ss, c.valueClosure); + break; + + default: + LUAU_ASSERT(!"Unsupported constant type"); + } + } + + // child protos + writeVarInt(ss, uint32_t(protos.size())); + + for (uint32_t child : protos) + writeVarInt(ss, child); + + // debug info + writeVarInt(ss, func.debugname); + + bool hasLines = true; + + for (int line : lines) + if (line == 0) + { + hasLines = false; + break; + } + + if (hasLines) + { + writeByte(ss, 1); + + writeLineInfo(ss); + } + else + { + writeByte(ss, 0); + } + + bool hasDebug = !debugLocals.empty() || !debugUpvals.empty(); + + if (hasDebug) + { + writeByte(ss, 1); + + writeVarInt(ss, uint32_t(debugLocals.size())); + + for (const DebugLocal& l : debugLocals) + { + writeVarInt(ss, l.name); + writeVarInt(ss, l.startpc); + writeVarInt(ss, l.endpc); + writeByte(ss, l.reg); + } + + writeVarInt(ss, uint32_t(debugUpvals.size())); + + for (const DebugUpval& l : debugUpvals) + { + writeVarInt(ss, l.name); + } + } + else + { + writeByte(ss, 0); + } +} + +void BytecodeBuilder::writeLineInfo(std::string& ss) const +{ + // this function encodes lines inside each span as a 8-bit delta to span baseline + // span is always a power of two; depending on the line info input, it may need to be as low as 1 + int span = 1 << 24; + + // first pass: determine span length + for (size_t offset = 0; offset < lines.size(); offset += span) + { + size_t next = offset; + + int min = lines[offset]; + int max = lines[offset]; + + for (; next < lines.size() && next < offset + span; ++next) + { + min = std::min(min, lines[next]); + max = std::max(max, lines[next]); + + if (max - min > 255) + break; + } + + if (next < lines.size() && next - offset < size_t(span)) + { + // since not all lines in the range fit in 8b delta, we need to shrink the span + // next iteration will need to reprocess some lines again since span changed + span = 1 << log2(int(next - offset)); + } + } + + // second pass: compute span base + std::vector baseline((lines.size() - 1) / span + 1); + + for (size_t offset = 0; offset < lines.size(); offset += span) + { + size_t next = offset; + + int min = lines[offset]; + + for (; next < lines.size() && next < offset + span; ++next) + min = std::min(min, lines[next]); + + baseline[offset / span] = min; + } + + // third pass: write resulting data + int logspan = log2(span); + + writeByte(ss, logspan); + + uint8_t lastOffset = 0; + + for (size_t i = 0; i < lines.size(); ++i) + { + int delta = lines[i] - baseline[i >> logspan]; + LUAU_ASSERT(delta >= 0 && delta <= 255); + + writeByte(ss, delta - lastOffset); + lastOffset = delta; + } + + int lastLine = 0; + + for (size_t i = 0; i < baseline.size(); ++i) + { + writeInt(ss, baseline[i] - lastLine); + lastLine = baseline[i]; + } +} + +void BytecodeBuilder::writeStringTable(std::string& ss) const +{ + std::vector strings(stringTable.size()); + + for (auto& p : stringTable) + { + LUAU_ASSERT(p.second > 0 && p.second <= strings.size()); + strings[p.second - 1] = p.first; + } + + writeVarInt(ss, uint32_t(strings.size())); + + for (auto& s : strings) + { + writeVarInt(ss, uint32_t(s.length)); + ss.append(s.data, s.length); + } +} + +uint32_t BytecodeBuilder::getImportId(int32_t id0) +{ + LUAU_ASSERT(unsigned(id0) < 1024); + + return (1u << 30) | (id0 << 20); +} + +uint32_t BytecodeBuilder::getImportId(int32_t id0, int32_t id1) +{ + LUAU_ASSERT(unsigned(id0 | id1) < 1024); + + return (2u << 30) | (id0 << 20) | (id1 << 10); +} + +uint32_t BytecodeBuilder::getImportId(int32_t id0, int32_t id1, int32_t id2) +{ + LUAU_ASSERT(unsigned(id0 | id1 | id2) < 1024); + + return (3u << 30) | (id0 << 20) | (id1 << 10) | id2; +} + +uint32_t BytecodeBuilder::getStringHash(StringRef key) +{ + // This hashing algorithm should match luaS_hash defined in VM/lstring.cpp for short inputs; we can't use that code directly to keep compiler and + // VM independent in terms of compilation/linking. The resulting string hashes are embedded into bytecode binary and result in a better initial + // guess for the field hashes which improves performance during initial code execution. We omit the long string processing here for simplicity, as + // it doesn't really matter on long identifiers. + const char* str = key.data; + size_t len = key.length; + + unsigned int h = unsigned(len); + + // original Lua 5.1 hash for compatibility (exact match when len<32) + for (size_t i = len; i > 0; --i) + h ^= (h << 5) + (h >> 2) + (uint8_t)str[i - 1]; + + return h; +} + +void BytecodeBuilder::foldJumps() +{ + // if our function has long jumps, some processing below can make jump instructions not-jumps (e.g. JUMP->RETURN) + // it's safer to skip this processing + if (hasLongJumps) + return; + + for (Jump& jump : jumps) + { + uint32_t jumpLabel = jump.source; + + uint32_t jumpInsn = insns[jumpLabel]; + + // follow jump target through forward unconditional jumps + // we only follow forward jumps to make sure the process terminates + uint32_t targetLabel = jumpLabel + 1 + LUAU_INSN_D(jumpInsn); + LUAU_ASSERT(targetLabel < insns.size()); + uint32_t targetInsn = insns[targetLabel]; + + while (LUAU_INSN_OP(targetInsn) == LOP_JUMP && LUAU_INSN_D(targetInsn) >= 0) + { + targetLabel = targetLabel + 1 + LUAU_INSN_D(targetInsn); + LUAU_ASSERT(targetLabel < insns.size()); + targetInsn = insns[targetLabel]; + } + + int offset = int(targetLabel) - int(jumpLabel) - 1; + + // for unconditional jumps to RETURN, we can replace JUMP with RETURN + if (LUAU_INSN_OP(jumpInsn) == LOP_JUMP && LUAU_INSN_OP(targetInsn) == LOP_RETURN) + { + insns[jumpLabel] = targetInsn; + lines[jumpLabel] = lines[targetLabel]; + } + else if (int16_t(offset) == offset) + { + insns[jumpLabel] &= 0xffff; + insns[jumpLabel] |= uint16_t(offset) << 16; + } + + jump.target = targetLabel; + } +} + +void BytecodeBuilder::expandJumps() +{ + if (!hasLongJumps) + return; + + // we have some jump instructions that couldn't be patched which means their offset didn't fit into 16 bits + // our strategy for replacing instructions is as follows: instead of + // OP jumpoffset + // we will synthesize a jump trampoline before our instruction (note that jump offsets are relative to next instruction): + // JUMP +1 + // JUMPX jumpoffset + // OP -2 + // the idea is that during forward execution, we will jump over JUMPX into OP; if OP decides to jump, it will jump to JUMPX + // JUMPX can carry a 24-bit jump offset + + // jump trampolines expand the code size, which can increase existing jump distances. + // because of this, we may need to expand jumps that previously fit into 16-bit just fine. + // the worst-case expansion is 3x, so to be conservative we will repatch all jumps that have an offset >= 32767/3 + const int kMaxJumpDistanceConservative = 32767 / 3; + + // we will need to process jumps in order + std::sort(jumps.begin(), jumps.end(), [](const Jump& lhs, const Jump& rhs) { + return lhs.source < rhs.source; + }); + + // first, let's add jump thunks for every jump with a distance that's too big + // we will create new instruction buffers, with remap table keeping track of the moves: remap[oldpc] = newpc + std::vector remap(insns.size()); + + std::vector newinsns; + std::vector newlines; + + LUAU_ASSERT(insns.size() == lines.size()); + newinsns.reserve(insns.size()); + newlines.reserve(insns.size()); + + size_t currentJump = 0; + size_t pendingTrampolines = 0; + + for (size_t i = 0; i < insns.size();) + { + uint8_t op = LUAU_INSN_OP(insns[i]); + LUAU_ASSERT(op < LOP__COUNT); + + if (currentJump < jumps.size() && jumps[currentJump].source == i) + { + int offset = int(jumps[currentJump].target) - int(jumps[currentJump].source) - 1; + + if (abs(offset) > kMaxJumpDistanceConservative) + { + // insert jump trampoline as described above; we keep JUMPX offset uninitialized in this pass + newinsns.push_back(LOP_JUMP | (1 << 16)); + newinsns.push_back(LOP_JUMPX); + + newlines.push_back(lines[i]); + newlines.push_back(lines[i]); + + pendingTrampolines++; + } + + currentJump++; + } + + int oplen = getOpLength(LuauOpcode(op)); + + // copy instruction and line info to the new stream + for (int j = 0; j < oplen; ++j) + { + remap[i] = uint32_t(newinsns.size()); + + newinsns.push_back(insns[i]); + newlines.push_back(lines[i]); + + i++; + } + } + + LUAU_ASSERT(currentJump == jumps.size()); + LUAU_ASSERT(pendingTrampolines > 0); + + // now we need to recompute offsets for jump instructions - we could not do this in the first pass because the offsets are between *target* + // instructions + for (Jump& jump : jumps) + { + int offset = int(jump.target) - int(jump.source) - 1; + int newoffset = int(remap[jump.target]) - int(remap[jump.source]) - 1; + + if (abs(offset) > kMaxJumpDistanceConservative) + { + // fix up jump trampoline + uint32_t& insnt = newinsns[remap[jump.source] - 1]; + uint32_t& insnj = newinsns[remap[jump.source]]; + + LUAU_ASSERT(LUAU_INSN_OP(insnt) == LOP_JUMPX); + + // patch JUMPX to JUMPX to target location; note that newoffset is the offset of the jump *relative to OP*, so we need to add 1 to make it + // relative to JUMPX + insnt &= 0xff; + insnt |= uint32_t(newoffset + 1) << 8; + + // patch OP to OP -2 + insnj &= 0xffff; + insnj |= uint16_t(-2) << 16; + + pendingTrampolines--; + } + else + { + uint32_t& insn = newinsns[remap[jump.source]]; + + // make sure jump instruction had the correct offset before we started + LUAU_ASSERT(LUAU_INSN_D(insn) == offset); + + // patch instruction with the new offset + LUAU_ASSERT(int16_t(newoffset) == newoffset); + + insn &= 0xffff; + insn |= uint16_t(newoffset) << 16; + } + } + + LUAU_ASSERT(pendingTrampolines == 0); + + // this was hard, but we're done. + insns.swap(newinsns); + lines.swap(newlines); +} + +std::string BytecodeBuilder::getError(const std::string& message) +{ + // 0 acts as a special marker for error bytecode (it's equal to LBC_VERSION for valid bytecode blobs) + std::string result; + result += char(0); + result += message; + + return result; +} + +#ifdef LUAU_ASSERTENABLED +void BytecodeBuilder::validate() const +{ +#define VREG(v) LUAU_ASSERT(unsigned(v) < func.maxstacksize) +#define VREGRANGE(v, count) LUAU_ASSERT(unsigned(v + (count < 0 ? 0 : count)) <= func.maxstacksize) +#define VUPVAL(v) LUAU_ASSERT(unsigned(v) < func.numupvalues) +#define VCONST(v, kind) LUAU_ASSERT(unsigned(v) < constants.size() && constants[v].type == Constant::Type_##kind) +#define VCONSTANY(v) LUAU_ASSERT(unsigned(v) < constants.size()) +#define VJUMP(v) LUAU_ASSERT(size_t(i + 1 + v) < insns.size() && insnvalid[i + 1 + v]) + + LUAU_ASSERT(currentFunction != ~0u); + + const Function& func = functions[currentFunction]; + + // first pass: tag instruction offsets so that we can validate jumps + std::vector insnvalid(insns.size(), false); + + for (size_t i = 0; i < insns.size();) + { + uint8_t op = LUAU_INSN_OP(insns[i]); + + insnvalid[i] = true; + + i += getOpLength(LuauOpcode(op)); + LUAU_ASSERT(i <= insns.size()); + } + + // second pass: validate the rest of the bytecode + for (size_t i = 0; i < insns.size();) + { + uint32_t insn = insns[i]; + uint8_t op = LUAU_INSN_OP(insn); + + switch (op) + { + case LOP_LOADNIL: + VREG(LUAU_INSN_A(insn)); + break; + + case LOP_LOADB: + VREG(LUAU_INSN_A(insn)); + LUAU_ASSERT(LUAU_INSN_B(insn) == 0 || LUAU_INSN_B(insn) == 1); + VJUMP(LUAU_INSN_C(insn)); + break; + + case LOP_LOADN: + VREG(LUAU_INSN_A(insn)); + break; + + case LOP_LOADK: + VREG(LUAU_INSN_A(insn)); + VCONSTANY(LUAU_INSN_D(insn)); + break; + + case LOP_MOVE: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + break; + + case LOP_GETGLOBAL: + case LOP_SETGLOBAL: + VREG(LUAU_INSN_A(insn)); + VCONST(insns[i + 1], String); + break; + + case LOP_GETUPVAL: + case LOP_SETUPVAL: + VREG(LUAU_INSN_A(insn)); + VUPVAL(LUAU_INSN_B(insn)); + break; + + case LOP_CLOSEUPVALS: + VREG(LUAU_INSN_A(insn)); + break; + + case LOP_GETIMPORT: + VREG(LUAU_INSN_A(insn)); + VCONST(LUAU_INSN_D(insn), Import); + // TODO: check insn[i + 1] for conformance with 10-bit import encoding + break; + + case LOP_GETTABLE: + case LOP_SETTABLE: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VREG(LUAU_INSN_C(insn)); + break; + + case LOP_GETTABLEKS: + case LOP_SETTABLEKS: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VCONST(insns[i + 1], String); + break; + + case LOP_GETTABLEN: + case LOP_SETTABLEN: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + break; + + case LOP_NEWCLOSURE: + { + VREG(LUAU_INSN_A(insn)); + LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < protos.size()); + LUAU_ASSERT(protos[LUAU_INSN_D(insn)] < functions.size()); + unsigned int numupvalues = functions[protos[LUAU_INSN_D(insn)]].numupvalues; + + for (unsigned int j = 0; j < numupvalues; ++j) + { + LUAU_ASSERT(i + 1 + j < insns.size()); + uint32_t cinsn = insns[i + 1 + j]; + LUAU_ASSERT(LUAU_INSN_OP(cinsn) == LOP_CAPTURE); + } + } + break; + + case LOP_NAMECALL: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VCONST(insns[i + 1], String); + LUAU_ASSERT(LUAU_INSN_OP(insns[i + 2]) == LOP_CALL); + break; + + case LOP_CALL: + { + int nparams = LUAU_INSN_B(insn) - 1; + int nresults = LUAU_INSN_C(insn) - 1; + VREG(LUAU_INSN_A(insn)); + VREGRANGE(LUAU_INSN_A(insn) + 1, nparams); // 1..nparams + VREGRANGE(LUAU_INSN_A(insn), nresults); // 1..nresults + } + break; + + case LOP_RETURN: + { + int nresults = LUAU_INSN_B(insn) - 1; + VREGRANGE(LUAU_INSN_A(insn), nresults); // 0..nresults-1 + } + break; + + case LOP_JUMP: + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIF: + case LOP_JUMPIFNOT: + VREG(LUAU_INSN_A(insn)); + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + VREG(LUAU_INSN_A(insn)); + VREG(insns[i + 1]); + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFEQK: + case LOP_JUMPIFNOTEQK: + VREG(LUAU_INSN_A(insn)); + VCONSTANY(insns[i + 1]); + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_ADD: + case LOP_SUB: + case LOP_MUL: + case LOP_DIV: + case LOP_MOD: + case LOP_POW: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VREG(LUAU_INSN_C(insn)); + break; + + case LOP_ADDK: + case LOP_SUBK: + case LOP_MULK: + case LOP_DIVK: + case LOP_MODK: + case LOP_POWK: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VCONST(LUAU_INSN_C(insn), Number); + break; + + case LOP_AND: + case LOP_OR: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VREG(LUAU_INSN_C(insn)); + break; + + case LOP_ANDK: + case LOP_ORK: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VCONSTANY(LUAU_INSN_C(insn)); + break; + + case LOP_CONCAT: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + VREG(LUAU_INSN_C(insn)); + LUAU_ASSERT(LUAU_INSN_B(insn) <= LUAU_INSN_C(insn)); + break; + + case LOP_NOT: + case LOP_MINUS: + case LOP_LENGTH: + VREG(LUAU_INSN_A(insn)); + VREG(LUAU_INSN_B(insn)); + break; + + case LOP_NEWTABLE: + VREG(LUAU_INSN_A(insn)); + break; + + case LOP_DUPTABLE: + VREG(LUAU_INSN_A(insn)); + VCONST(LUAU_INSN_D(insn), Table); + break; + + case LOP_SETLIST: + { + int count = LUAU_INSN_C(insn) - 1; + VREG(LUAU_INSN_A(insn)); + VREGRANGE(LUAU_INSN_B(insn), count); + } + break; + + case LOP_FORNPREP: + case LOP_FORNLOOP: + VREG(LUAU_INSN_A(insn) + 2); // for loop protocol: A, A+1, A+2 are used for iteration + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_FORGLOOP: + VREG( + LUAU_INSN_A(insn) + 2 + insns[i + 1]); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, ... are loop variables + VJUMP(LUAU_INSN_D(insn)); + LUAU_ASSERT(insns[i + 1] >= 1); + break; + + case LOP_FORGPREP_INEXT: + case LOP_FORGLOOP_INEXT: + case LOP_FORGPREP_NEXT: + case LOP_FORGLOOP_NEXT: + VREG(LUAU_INSN_A(insn) + 4); // forg loop protocol: A, A+1, A+2 are used for iteration protocol; A+3, A+4 are loop variables + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_GETVARARGS: + { + int nresults = LUAU_INSN_B(insn) - 1; + VREGRANGE(LUAU_INSN_A(insn), nresults); // 0..nresults-1 + } + break; + + case LOP_DUPCLOSURE: + { + VREG(LUAU_INSN_A(insn)); + VCONST(LUAU_INSN_D(insn), Closure); + unsigned int proto = constants[LUAU_INSN_D(insn)].valueClosure; + LUAU_ASSERT(proto < functions.size()); + unsigned int numupvalues = functions[proto].numupvalues; + + for (unsigned int j = 0; j < numupvalues; ++j) + { + LUAU_ASSERT(i + 1 + j < insns.size()); + uint32_t cinsn = insns[i + 1 + j]; + LUAU_ASSERT(LUAU_INSN_OP(cinsn) == LOP_CAPTURE); + LUAU_ASSERT(LUAU_INSN_A(cinsn) == LCT_VAL || LUAU_INSN_A(cinsn) == LCT_UPVAL); + } + } + break; + + case LOP_PREPVARARGS: + LUAU_ASSERT(LUAU_INSN_A(insn) == func.numparams); + LUAU_ASSERT(func.isvararg); + break; + + case LOP_BREAK: + break; + + case LOP_JUMPBACK: + VJUMP(LUAU_INSN_D(insn)); + break; + + case LOP_LOADKX: + VREG(LUAU_INSN_A(insn)); + VCONSTANY(insns[i + 1]); + break; + + case LOP_JUMPX: + VJUMP(LUAU_INSN_E(insn)); + break; + + case LOP_FASTCALL: + VJUMP(LUAU_INSN_C(insn)); + LUAU_ASSERT(LUAU_INSN_OP(insns[i + 1 + LUAU_INSN_C(insn)]) == LOP_CALL); + break; + + case LOP_FASTCALL1: + VREG(LUAU_INSN_B(insn)); + VJUMP(LUAU_INSN_C(insn)); + LUAU_ASSERT(LUAU_INSN_OP(insns[i + 1 + LUAU_INSN_C(insn)]) == LOP_CALL); + break; + + case LOP_FASTCALL2: + VREG(LUAU_INSN_B(insn)); + VJUMP(LUAU_INSN_C(insn)); + LUAU_ASSERT(LUAU_INSN_OP(insns[i + 1 + LUAU_INSN_C(insn)]) == LOP_CALL); + VREG(insns[i + 1]); + break; + + case LOP_FASTCALL2K: + VREG(LUAU_INSN_B(insn)); + VJUMP(LUAU_INSN_C(insn)); + LUAU_ASSERT(LUAU_INSN_OP(insns[i + 1 + LUAU_INSN_C(insn)]) == LOP_CALL); + VCONSTANY(insns[i + 1]); + break; + + case LOP_COVERAGE: + break; + + case LOP_CAPTURE: + switch (LUAU_INSN_A(insn)) + { + case LCT_VAL: + case LCT_REF: + VREG(LUAU_INSN_B(insn)); + break; + + case LCT_UPVAL: + VUPVAL(LUAU_INSN_B(insn)); + break; + + default: + LUAU_ASSERT(!"Unsupported capture type"); + } + break; + + default: + LUAU_ASSERT(!"Unsupported opcode"); + } + + i += getOpLength(LuauOpcode(op)); + LUAU_ASSERT(i <= insns.size()); + } + +#undef VREG +#undef VREGEND +#undef VUPVAL +#undef VCONST +#undef VCONSTANY +#undef VJUMP +} +#endif + +const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result) const +{ + uint32_t insn = *code++; + + switch (LUAU_INSN_OP(insn)) + { + case LOP_LOADNIL: + formatAppend(result, "LOADNIL R%d\n", LUAU_INSN_A(insn)); + break; + + case LOP_LOADB: + if (LUAU_INSN_C(insn)) + formatAppend(result, "LOADB R%d %d +%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + else + formatAppend(result, "LOADB R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + break; + + case LOP_LOADN: + formatAppend(result, "LOADN R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_LOADK: + formatAppend(result, "LOADK R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_MOVE: + formatAppend(result, "MOVE R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + break; + + case LOP_GETGLOBAL: + formatAppend(result, "GETGLOBAL R%d K%d\n", LUAU_INSN_A(insn), *code++); + break; + + case LOP_SETGLOBAL: + formatAppend(result, "SETGLOBAL R%d K%d\n", LUAU_INSN_A(insn), *code++); + break; + + case LOP_GETUPVAL: + formatAppend(result, "GETUPVAL R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + break; + + case LOP_SETUPVAL: + formatAppend(result, "SETUPVAL R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + break; + + case LOP_CLOSEUPVALS: + formatAppend(result, "CLOSEUPVALS R%d\n", LUAU_INSN_A(insn)); + break; + + case LOP_GETIMPORT: + formatAppend(result, "GETIMPORT R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + code++; // AUX + break; + + case LOP_GETTABLE: + formatAppend(result, "GETTABLE R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_SETTABLE: + formatAppend(result, "SETTABLE R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_GETTABLEKS: + formatAppend(result, "GETTABLEKS R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code++); + break; + + case LOP_SETTABLEKS: + formatAppend(result, "SETTABLEKS R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code++); + break; + + case LOP_GETTABLEN: + formatAppend(result, "GETTABLEN R%d R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn) + 1); + break; + + case LOP_SETTABLEN: + formatAppend(result, "SETTABLEN R%d R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn) + 1); + break; + + case LOP_NEWCLOSURE: + formatAppend(result, "NEWCLOSURE R%d P%d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_NAMECALL: + formatAppend(result, "NAMECALL R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), *code++); + break; + + case LOP_CALL: + formatAppend(result, "CALL R%d %d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn) - 1, LUAU_INSN_C(insn) - 1); + break; + + case LOP_RETURN: + formatAppend(result, "RETURN R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn) - 1); + break; + + case LOP_JUMP: + formatAppend(result, "JUMP %+d\n", LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIF: + formatAppend(result, "JUMPIF R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFNOT: + formatAppend(result, "JUMPIFNOT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFEQ: + formatAppend(result, "JUMPIFEQ R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFLE: + formatAppend(result, "JUMPIFLE R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFLT: + formatAppend(result, "JUMPIFLT R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFNOTEQ: + formatAppend(result, "JUMPIFNOTEQ R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFNOTLE: + formatAppend(result, "JUMPIFNOTLE R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFNOTLT: + formatAppend(result, "JUMPIFNOTLT R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + case LOP_ADD: + formatAppend(result, "ADD R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_SUB: + formatAppend(result, "SUB R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_MUL: + formatAppend(result, "MUL R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_DIV: + formatAppend(result, "DIV R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_MOD: + formatAppend(result, "MOD R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_POW: + formatAppend(result, "POW R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_ADDK: + formatAppend(result, "ADDK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_SUBK: + formatAppend(result, "SUBK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_MULK: + formatAppend(result, "MULK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_DIVK: + formatAppend(result, "DIVK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_MODK: + formatAppend(result, "MODK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_POWK: + formatAppend(result, "POWK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_AND: + formatAppend(result, "AND R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_OR: + formatAppend(result, "OR R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_ANDK: + formatAppend(result, "ANDK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_ORK: + formatAppend(result, "ORK R%d R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_CONCAT: + formatAppend(result, "CONCAT R%d R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + + case LOP_NOT: + formatAppend(result, "NOT R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + break; + + case LOP_MINUS: + formatAppend(result, "MINUS R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + break; + + case LOP_LENGTH: + formatAppend(result, "LENGTH R%d R%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn)); + break; + + case LOP_NEWTABLE: + formatAppend(result, "NEWTABLE R%d %d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn) == 0 ? 0 : 1 << (LUAU_INSN_B(insn) - 1), *code++); + break; + + case LOP_DUPTABLE: + formatAppend(result, "DUPTABLE R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_SETLIST: + formatAppend(result, "SETLIST R%d R%d %d [%d]\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn) - 1, *code++); + break; + + case LOP_FORNPREP: + formatAppend(result, "FORNPREP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_FORNLOOP: + formatAppend(result, "FORNLOOP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_FORGLOOP: + formatAppend(result, "FORGLOOP R%d %+d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn), *code++); + break; + + case LOP_FORGPREP_INEXT: + formatAppend(result, "FORGPREP_INEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_FORGLOOP_INEXT: + formatAppend(result, "FORGLOOP_INEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_FORGPREP_NEXT: + formatAppend(result, "FORGPREP_NEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_FORGLOOP_NEXT: + formatAppend(result, "FORGLOOP_NEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_GETVARARGS: + formatAppend(result, "GETVARARGS R%d %d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn) - 1); + break; + + case LOP_DUPCLOSURE: + formatAppend(result, "DUPCLOSURE R%d K%d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn)); + break; + + case LOP_BREAK: + formatAppend(result, "BREAK\n"); + break; + + case LOP_JUMPBACK: + formatAppend(result, "JUMPBACK %+d\n", LUAU_INSN_D(insn)); + break; + + case LOP_LOADKX: + formatAppend(result, "LOADKX R%d K%d\n", LUAU_INSN_A(insn), *code++); + break; + + case LOP_JUMPX: + formatAppend(result, "JUMPX %+d\n", LUAU_INSN_E(insn)); + break; + + case LOP_FASTCALL: + formatAppend(result, "FASTCALL %d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_C(insn)); + break; + + case LOP_FASTCALL1: + formatAppend(result, "FASTCALL1 %d R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn)); + break; + case LOP_FASTCALL2: + { + uint32_t aux = *code++; + formatAppend(result, "FASTCALL2 %d R%d R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, LUAU_INSN_C(insn)); + break; + } + case LOP_FASTCALL2K: + { + uint32_t aux = *code++; + formatAppend(result, "FASTCALL2K %d R%d K%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, LUAU_INSN_C(insn)); + break; + } + + case LOP_COVERAGE: + formatAppend(result, "COVERAGE\n"); + break; + + case LOP_CAPTURE: + formatAppend(result, "CAPTURE %s %c%d\n", + LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" : LUAU_INSN_A(insn) == LCT_REF ? "REF" : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" : "", + LUAU_INSN_A(insn) == LCT_UPVAL ? 'U' : 'R', LUAU_INSN_B(insn)); + break; + + case LOP_JUMPIFEQK: + formatAppend(result, "JUMPIFEQK R%d K%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + case LOP_JUMPIFNOTEQK: + formatAppend(result, "JUMPIFNOTEQK R%d K%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn)); + break; + + default: + LUAU_ASSERT(!"Unsupported opcode"); + } + + return code; +} + +std::string BytecodeBuilder::dumpCurrentFunction() const +{ + if ((dumpFlags & Dump_Code) == 0) + return std::string(); + + const uint32_t* code = insns.data(); + const uint32_t* codeEnd = insns.data() + insns.size(); + + int lastLine = -1; + + std::string result; + + if (dumpFlags & Dump_Locals) + { + for (size_t i = 0; i < debugLocals.size(); ++i) + { + const DebugLocal& l = debugLocals[i]; + + LUAU_ASSERT(l.startpc < l.endpc); + LUAU_ASSERT(l.startpc < lines.size()); + LUAU_ASSERT(l.endpc <= lines.size()); // endpc is exclusive in the debug info, but it's more intuitive to print inclusive data + + // it would be nice to emit name as well but it requires reverse lookup through stringtable + formatAppend(result, "local %d: reg %d, start pc %d line %d, end pc %d line %d\n", int(i), l.reg, l.startpc, lines[l.startpc], + l.endpc - 1, lines[l.endpc - 1]); + } + } + + while (code != codeEnd) + { + uint8_t op = LUAU_INSN_OP(*code); + + if (op == LOP_PREPVARARGS) + { + // Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information + code++; + continue; + } + + if (dumpFlags & Dump_Source) + { + int line = lines[code - insns.data()]; + + if (line > 0 && line != lastLine) + { + LUAU_ASSERT(size_t(line - 1) < dumpSource.size()); + formatAppend(result, "%5d: %s\n", line, dumpSource[line - 1].c_str()); + lastLine = line; + } + } + + if (dumpFlags & Dump_Lines) + { + formatAppend(result, "%d: ", lines[code - insns.data()]); + } + + code = dumpInstruction(code, result); + } + + return result; +} + +void BytecodeBuilder::setDumpSource(const std::string& source) +{ + dumpSource.clear(); + + std::string::size_type pos = 0; + + while (pos != std::string::npos) + { + std::string::size_type next = source.find('\n', pos); + + if (next == std::string::npos) + { + dumpSource.push_back(source.substr(pos)); + pos = next; + } + else + { + dumpSource.push_back(source.substr(pos, next - pos)); + pos = next + 1; + } + + if (!dumpSource.back().empty() && dumpSource.back().back() == '\r') + dumpSource.back().pop_back(); + } +} + +std::string BytecodeBuilder::dumpFunction(uint32_t id) const +{ + LUAU_ASSERT(id < functions.size()); + + return functions[id].dump; +} + +std::string BytecodeBuilder::dumpEverything() const +{ + std::string result; + + for (size_t i = 0; i < functions.size(); ++i) + { + std::string debugname = functions[i].dumpname.empty() ? "??" : functions[i].dumpname; + + formatAppend(result, "Function %d (%s):\n", int(i), debugname.c_str()); + + result += functions[i].dump; + result += "\n"; + } + + return result; +} + +} // namespace Luau diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp new file mode 100644 index 0000000..022eccb --- /dev/null +++ b/Compiler/src/Compiler.cpp @@ -0,0 +1,3778 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Compiler.h" + +#include "Luau/Parser.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Common.h" + +#include +#include +#include + +LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) +LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) +LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) +LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) + +namespace Luau +{ + +static const uint32_t kMaxRegisterCount = 255; +static const uint32_t kMaxUpvalueCount = 200; +static const uint32_t kMaxLocalCount = 200; + +static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; + +CompileError::CompileError(const Location& location, const std::string& message) + : location(location) + , message(message) +{ +} + +CompileError::~CompileError() throw() {} + +const char* CompileError::what() const throw() +{ + return message.c_str(); +} + +const Location& CompileError::getLocation() const +{ + return location; +} + +// NOINLINE is used to limit the stack cost of this function due to std::string object / exception plumbing +LUAU_NOINLINE void CompileError::raise(const Location& location, const char* format, ...) +{ + va_list args; + va_start(args, format); + std::string message = vformat(format, args); + va_end(args); + + throw CompileError(location, message); +} + +static BytecodeBuilder::StringRef sref(AstName name) +{ + LUAU_ASSERT(name.value); + return {name.value, strlen(name.value)}; +} + +static BytecodeBuilder::StringRef sref(AstArray data) +{ + LUAU_ASSERT(data.data); + return {data.data, data.size}; +} + +struct Compiler +{ + struct Constant; + struct RegScope; + + Compiler(BytecodeBuilder& bytecode, const CompileOptions& options) + : bytecode(bytecode) + , options(options) + , functions(nullptr) + , locals(nullptr) + , globals(AstName()) + , constants(nullptr) + , predictedTableSize(nullptr) + { + } + + uint8_t getLocal(AstLocal* local) + { + Local* l = locals.find(local); + LUAU_ASSERT(l); + LUAU_ASSERT(l->allocated); + + return l->reg; + } + + uint8_t getUpval(AstLocal* local) + { + for (size_t uid = 0; uid < upvals.size(); ++uid) + if (upvals[uid] == local) + return uint8_t(uid); + + if (upvals.size() >= kMaxUpvalueCount) + CompileError::raise( + local->location, "Out of upvalue registers when trying to allocate %s: exceeded limit %d", local->name.value, kMaxUpvalueCount); + + // mark local as captured so that closeLocals emits LOP_CLOSEUPVALS accordingly + Local& l = locals[local]; + l.captured = true; + + upvals.push_back(local); + + return uint8_t(upvals.size() - 1); + } + + bool allPathsEndWithReturn(AstStat* node) + { + if (AstStatBlock* stat = node->as()) + return stat->body.size > 0 && allPathsEndWithReturn(stat->body.data[stat->body.size - 1]); + else if (node->is()) + return true; + else if (AstStatIf* stat = node->as()) + return stat->elsebody && allPathsEndWithReturn(stat->thenbody) && allPathsEndWithReturn(stat->elsebody); + else + return false; + } + + void emitLoadK(uint8_t target, int32_t cid) + { + LUAU_ASSERT(cid >= 0); + + if (cid < 32768) + { + bytecode.emitAD(LOP_LOADK, target, int16_t(cid)); + } + else + { + bytecode.emitAD(LOP_LOADKX, target, 0); + bytecode.emitAux(cid); + } + } + + uint32_t compileFunction(AstExprFunction* func) + { + LUAU_ASSERT(!functions.contains(func)); + LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty()); + + RegScope rs(this); + + bool self = func->self != 0; + uint32_t fid = bytecode.beginFunction(uint8_t(self + func->args.size), func->vararg); + + setDebugLine(func); + + if (func->vararg) + bytecode.emitABC(LOP_PREPVARARGS, uint8_t(self + func->args.size), 0, 0); + + uint8_t args = allocReg(func, self + unsigned(func->args.size)); + + if (func->self) + pushLocal(func->self, args); + + for (size_t i = 0; i < func->args.size; ++i) + pushLocal(func->args.data[i], uint8_t(args + self + i)); + + AstStatBlock* stat = func->body; + + for (size_t i = 0; i < stat->body.size; ++i) + compileStat(stat->body.data[i]); + + // valid function bytecode must always end with RETURN + // we elide this if we're guaranteed to hit a RETURN statement regardless of the control flow + if (!allPathsEndWithReturn(stat)) + { + setDebugLineEnd(stat); + closeLocals(0); + + bytecode.emitABC(LOP_RETURN, 0, 1, 0); + } + + // constant folding may remove some upvalue refs from bytecode, so this puts them back + if (options.optimizationLevel >= 1 && options.debugLevel >= 2) + gatherConstUpvals(func); + + if (options.debugLevel >= 1 && func->debugname.value) + bytecode.setDebugFunctionName(sref(func->debugname)); + + if (options.debugLevel >= 2 && !upvals.empty()) + { + for (AstLocal* l : upvals) + bytecode.pushDebugUpval(sref(l->name)); + } + + if (options.optimizationLevel >= 1) + bytecode.foldJumps(); + + bytecode.expandJumps(); + + popLocals(0); + + bytecode.endFunction(uint8_t(stackSize), uint8_t(upvals.size())); + + stackSize = 0; + + Function& f = functions[func]; + f.id = fid; + f.upvals = std::move(upvals); + + return fid; + } + + // note: this doesn't just clobber target (assuming it's temp), but also clobbers *all* allocated registers >= target! + // this is important to be able to support "multret" semantics due to Lua call frame structure + bool compileExprTempMultRet(AstExpr* node, uint8_t target) + { + if (AstExprCall* expr = node->as()) + { + // We temporarily swap out regTop to have targetTop work correctly... + // This is a crude hack but it's necessary for correctness :( + RegScope rs(this, target); + compileExprCall(expr, target, /* targetCount= */ 0, /* targetTop= */ true, /* multRet= */ true); + return true; + } + else if (AstExprVarargs* expr = node->as()) + { + // We temporarily swap out regTop to have targetTop work correctly... + // This is a crude hack but it's necessary for correctness :( + RegScope rs(this, target); + compileExprVarargs(expr, target, /* targetCount= */ 0, /* multRet= */ true); + return true; + } + else + { + compileExprTemp(node, target); + return false; + } + } + + // note: this doesn't just clobber target (assuming it's temp), but also clobbers *all* allocated registers >= target! + // this is important to be able to emit code that takes fewer registers and runs faster + void compileExprTempTop(AstExpr* node, uint8_t target) + { + // We temporarily swap out regTop to have targetTop work correctly... + // This is a crude hack but it's necessary for performance :( + // It makes sure that nested call expressions can use targetTop optimization and don't need to have too many registers + RegScope rs(this, target + 1); + compileExprTemp(node, target); + } + + void compileExprVarargs(AstExprVarargs* expr, uint8_t target, uint8_t targetCount, bool multRet = false) + { + LUAU_ASSERT(!multRet || unsigned(target + targetCount) == regTop); + + setDebugLine(expr); // normally compileExpr sets up line info, but compileExprCall can be called directly + + bytecode.emitABC(LOP_GETVARARGS, target, multRet ? 0 : uint8_t(targetCount + 1), 0); + } + + void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) + { + LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); + + setDebugLine(expr); // normally compileExpr sets up line info, but compileExprCall can be called directly + + RegScope rs(this); + + unsigned int regCount = std::max(unsigned(1 + expr->self + expr->args.size), unsigned(targetCount)); + + // Optimization: if target points to the top of the stack, we can start the call at oldTop - 1 and won't need MOVE at the end + uint8_t regs = targetTop ? allocReg(expr, regCount - targetCount) - targetCount : allocReg(expr, regCount); + + uint8_t selfreg = 0; + + int bfid = -1; + + if (options.optimizationLevel >= 1) + { + Builtin builtin = getBuiltin(expr->func); + bfid = getBuiltinFunctionId(builtin); + } + + if (expr->self) + { + AstExprIndexName* fi = expr->func->as(); + LUAU_ASSERT(fi); + + // Optimization: use local register directly in NAMECALL if possible + if (isExprLocalReg(fi->expr)) + { + selfreg = getLocal(fi->expr->as()->local); + } + else + { + // Note: to be able to compile very deeply nested self call chains (obj:method1():method2():...), we need to be able to do this in + // finite stack space NAMECALL will happily move object from regs to regs+1 but we need to compute it into regs so that + // compileExprTempTop doesn't increase stack usage for every recursive call + selfreg = regs; + + compileExprTempTop(fi->expr, selfreg); + } + } + else if (bfid < 0) + { + compileExprTempTop(expr->func, regs); + } + + // Note: if the last argument is ExprVararg or ExprCall, we need to route that directly to the called function preserving the # of args + bool multCall = false; + bool skipArgs = false; + + if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + { + AstExpr* last = expr->args.data[expr->args.size - 1]; + skipArgs = !(last->is() || last->is()); + } + + if (!skipArgs) + { + for (size_t i = 0; i < expr->args.size; ++i) + if (i + 1 == expr->args.size) + multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); + else + compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); + } + + setDebugLine(expr->func); + + if (expr->self) + { + AstExprIndexName* fi = expr->func->as(); + LUAU_ASSERT(fi); + + BytecodeBuilder::StringRef iname = sref(fi->index); + int32_t cid = bytecode.addConstantString(iname); + if (cid < 0) + CompileError::raise(fi->location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(LOP_NAMECALL, regs, selfreg, uint8_t(BytecodeBuilder::getStringHash(iname))); + bytecode.emitAux(cid); + } + else if (bfid >= 0) + { + size_t fastcallLabel; + + if (skipArgs) + { + LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; + + uint32_t args[2] = {}; + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0) + { + if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) + { + opc = LOP_FASTCALL2K; + args[i] = cid; + break; + } + } + + if (isExprLocalReg(expr->args.data[i])) + args[i] = getLocal(expr->args.data[i]->as()->local); + else + { + args[i] = uint8_t(regs + 1 + i); + compileExprTempTop(expr->args.data[i], args[i]); + } + } + + fastcallLabel = bytecode.emitLabel(); + bytecode.emitABC(opc, uint8_t(bfid), args[0], 0); + if (opc != LOP_FASTCALL1) + bytecode.emitAux(args[1]); + + // Set up a traditional Lua stack for the subsequent LOP_CALL. + // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for + // these FASTCALL variants. + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0 && opc == LOP_FASTCALL2K) + { + emitLoadK(uint8_t(regs + 1 + i), args[i]); + break; + } + + if (args[i] != regs + 1 + i) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), args[i], 0); + } + } + else + { + fastcallLabel = bytecode.emitLabel(); + bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0); + } + + // note, these instructions are normally not executed and are used as a fallback for FASTCALL + // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten + compileExprTemp(expr->func, regs); + + size_t callLabel = bytecode.emitLabel(); + + // FASTCALL will skip over the instructions needed to compute function and jump over CALL which must immediately follow the instruction + // sequence after FASTCALL + if (!bytecode.patchSkipC(fastcallLabel, callLabel)) + CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile"); + } + + bytecode.emitABC(LOP_CALL, regs, multCall ? 0 : uint8_t(expr->self + expr->args.size + 1), multRet ? 0 : uint8_t(targetCount + 1)); + + // if we didn't output results directly to target, we need to move them + if (!targetTop) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0); + } + } + + bool shouldShareClosure(AstExprFunction* func) + { + const Function* f = functions.find(func); + if (!f) + return false; + + for (AstLocal* uv : f->upvals) + { + Local* ul = locals.find(uv); + LUAU_ASSERT(ul); + + if (ul->written) + return false; + + // it's technically safe to share closures whenever all upvalues are immutable + // this is because of a runtime equality check in DUPCLOSURE. + // however, this results in frequent deoptimization and increases the set of reachable objects, making some temporary objects permanent + // instead we apply a heuristic: we share closures if they refer to top-level upvalues, or closures that refer to top-level upvalues + // this will only deoptimize (outside of fenv changes) if top level code is executed twice with different results. + if (uv->functionDepth != 0 || uv->loopDepth != 0) + { + if (!ul->func) + return false; + + if (ul->func != func && !shouldShareClosure(ul->func)) + return false; + } + } + + return true; + } + + void compileExprFunction(AstExprFunction* expr, uint8_t target) + { + const Function* f = functions.find(expr); + LUAU_ASSERT(f); + + // when the closure has upvalues we'll use this to create the closure at runtime + // when the closure has no upvalues, we use constant closures that technically don't rely on the child function list + // however, it's still important to add the child function because debugger relies on the function hierarchy when setting breakpoints + int16_t pid = bytecode.addChildFunction(f->id); + if (pid < 0) + CompileError::raise(expr->location, "Exceeded closure limit; simplify the code to compile"); + + bool shared = false; + + if (FFlag::LuauPreloadClosuresUpval) + { + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) + { + int32_t cid = bytecode.addConstantClosure(f->id); + + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + shared = true; + } + } + } + // Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects + // (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used) + else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed) + { + int32_t cid = bytecode.addConstantClosure(f->id); + + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + return; + } + } + + if (!shared) + bytecode.emitAD(LOP_NEWCLOSURE, target, pid); + + for (AstLocal* uv : f->upvals) + { + LUAU_ASSERT(uv->functionDepth < expr->functionDepth); + + Local* ul = locals.find(uv); + LUAU_ASSERT(ul); + + bool immutable = !ul->written; + + if (uv->functionDepth == expr->functionDepth - 1) + { + // get local variable + uint8_t reg = getLocal(uv); + + bytecode.emitABC(LOP_CAPTURE, immutable ? LCT_VAL : LCT_REF, reg, 0); + } + else + { + // get upvalue from parent frame + // note: this will add uv to the current upvalue list if necessary + uint8_t uid = getUpval(uv); + + bytecode.emitABC(LOP_CAPTURE, LCT_UPVAL, uid, 0); + } + } + } + + LuauOpcode getUnaryOp(AstExprUnary::Op op) + { + switch (op) + { + case AstExprUnary::Not: + return LOP_NOT; + + case AstExprUnary::Minus: + return LOP_MINUS; + + case AstExprUnary::Len: + return LOP_LENGTH; + + default: + LUAU_ASSERT(!"Unexpected unary operation"); + return LOP_NOP; + } + } + + LuauOpcode getBinaryOpArith(AstExprBinary::Op op, bool k = false) + { + switch (op) + { + case AstExprBinary::Add: + return k ? LOP_ADDK : LOP_ADD; + + case AstExprBinary::Sub: + return k ? LOP_SUBK : LOP_SUB; + + case AstExprBinary::Mul: + return k ? LOP_MULK : LOP_MUL; + + case AstExprBinary::Div: + return k ? LOP_DIVK : LOP_DIV; + + case AstExprBinary::Mod: + return k ? LOP_MODK : LOP_MOD; + + case AstExprBinary::Pow: + return k ? LOP_POWK : LOP_POW; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + return LOP_NOP; + } + } + + LuauOpcode getJumpOpCompare(AstExprBinary::Op op, bool not_ = false) + { + switch (op) + { + case AstExprBinary::CompareNe: + return not_ ? LOP_JUMPIFEQ : LOP_JUMPIFNOTEQ; + + case AstExprBinary::CompareEq: + return not_ ? LOP_JUMPIFNOTEQ : LOP_JUMPIFEQ; + + case AstExprBinary::CompareLt: + case AstExprBinary::CompareGt: + return not_ ? LOP_JUMPIFNOTLT : LOP_JUMPIFLT; + + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGe: + return not_ ? LOP_JUMPIFNOTLE : LOP_JUMPIFLE; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + return LOP_NOP; + } + } + + bool isConstant(AstExpr* node) + { + const Constant* cv = constants.find(node); + + return cv && cv->type != Constant::Type_Unknown; + } + + bool isConstantTrue(AstExpr* node) + { + const Constant* cv = constants.find(node); + + return cv && cv->type != Constant::Type_Unknown && cv->isTruthful(); + } + + bool isConstantFalse(AstExpr* node) + { + const Constant* cv = constants.find(node); + + return cv && cv->type != Constant::Type_Unknown && !cv->isTruthful(); + } + + size_t compileCompareJump(AstExprBinary* expr, bool not_ = false) + { + RegScope rs(this); + LuauOpcode opc = getJumpOpCompare(expr->op, not_); + + bool isEq = (opc == LOP_JUMPIFEQ || opc == LOP_JUMPIFNOTEQ); + AstExpr* left = expr->left; + AstExpr* right = expr->right; + + bool operandIsConstant = isConstant(right); + if (isEq && !operandIsConstant) + { + operandIsConstant = isConstant(left); + if (operandIsConstant) + std::swap(left, right); + } + + uint8_t rl = compileExprAuto(left, rs); + int32_t rr = -1; + + if (isEq && operandIsConstant) + { + if (opc == LOP_JUMPIFEQ) + opc = LOP_JUMPIFEQK; + else if (opc == LOP_JUMPIFNOTEQ) + opc = LOP_JUMPIFNOTEQK; + + rr = getConstantIndex(right); + LUAU_ASSERT(rr >= 0); + } + else + rr = compileExprAuto(right, rs); + + size_t jumpLabel = bytecode.emitLabel(); + + if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) + { + bytecode.emitAD(opc, rr, 0); + bytecode.emitAux(rl); + } + else + { + bytecode.emitAD(opc, rl, 0); + bytecode.emitAux(rr); + } + + return jumpLabel; + } + + int32_t getConstantNumber(AstExpr* node) + { + const Constant* c = constants.find(node); + + if (c && c->type == Constant::Type_Number) + { + int cid = bytecode.addConstantNumber(c->valueNumber); + if (cid < 0) + CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); + + return cid; + } + + return -1; + } + + int32_t getConstantIndex(AstExpr* node) + { + const Constant* c = constants.find(node); + + if (!c) + return -1; + + int cid = -1; + + switch (c->type) + { + case Constant::Type_Nil: + cid = bytecode.addConstantNil(); + break; + + case Constant::Type_Boolean: + cid = bytecode.addConstantBoolean(c->valueBoolean); + break; + + case Constant::Type_Number: + cid = bytecode.addConstantNumber(c->valueNumber); + break; + + case Constant::Type_String: + cid = bytecode.addConstantString(sref(c->valueString)); + break; + + default: + LUAU_ASSERT(!"Unexpected constant type"); + return -1; + } + + if (cid < 0) + CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); + + return cid; + } + + // compile expr to target temp register + // if the expr (or not expr if onlyTruth is false) is truthful, jump via skipJump + // if the expr (or not expr if onlyTruth is false) is falseful, fall through (target isn't guaranteed to be updated in this case) + // if target is omitted, then the jump behavior is the same - skipJump or fallthrough depending on the truthfulness of the expression + void compileConditionValue(AstExpr* node, const uint8_t* target, std::vector& skipJump, bool onlyTruth) + { + // Optimization: we don't need to compute constant values + const Constant* cv = constants.find(node); + + if (cv && cv->type != Constant::Type_Unknown) + { + // note that we only need to compute the value if it's truthful; otherwise we cal fall through + if (cv->isTruthful() == onlyTruth) + { + if (target) + compileExprTemp(node, *target); + + skipJump.push_back(bytecode.emitLabel()); + bytecode.emitAD(LOP_JUMP, 0, 0); + } + return; + } + + if (AstExprBinary* expr = node->as()) + { + switch (expr->op) + { + case AstExprBinary::And: + case AstExprBinary::Or: + { + // disambiguation: there's 4 cases (we only need truthful or falseful results based on onlyTruth) + // onlyTruth = 1: a and b transforms to a ? b : dontcare + // onlyTruth = 1: a or b transforms to a ? a : a + // onlyTruth = 0: a and b transforms to !a ? a : b + // onlyTruth = 0: a or b transforms to !a ? b : dontcare + if (onlyTruth == (expr->op == AstExprBinary::And)) + { + // we need to compile the left hand side, and skip to "dontcare" (aka fallthrough of the entire statement) if it's not the same as + // onlyTruth if it's the same then the result of the expression is the right hand side because of this, we *never* care about the + // result of the left hand side + std::vector elseJump; + compileConditionValue(expr->left, nullptr, elseJump, !onlyTruth); + + // fallthrough indicates that we need to compute & return the right hand side + // we use compileConditionValue again to process any extra and/or statements directly + compileConditionValue(expr->right, target, skipJump, onlyTruth); + + size_t elseLabel = bytecode.emitLabel(); + + patchJumps(expr, elseJump, elseLabel); + } + else + { + // we need to compute the left hand side first; note that we will jump to skipJump if we know the answer + compileConditionValue(expr->left, target, skipJump, onlyTruth); + + // we will fall through if computing the left hand didn't give us an "interesting" result + // we still use compileConditionValue to recursively optimize any and/or/compare statements + compileConditionValue(expr->right, target, skipJump, onlyTruth); + } + return; + } + break; + + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGt: + case AstExprBinary::CompareGe: + { + if (target) + { + // since target is a temp register, we'll initialize it to 1, and then jump if the comparison is true + // if the comparison is false, we'll fallthrough and target will still be 1 but target has unspecified value for falseful results + // when we only care about falseful values instead of truthful values, the process is the same but with flipped conditionals + bytecode.emitABC(LOP_LOADB, *target, onlyTruth ? 1 : 0, 0); + } + + size_t jumpLabel = compileCompareJump(expr, /* not= */ !onlyTruth); + + skipJump.push_back(jumpLabel); + return; + } + break; + + // fall-through to default path below + default:; + } + } + + if (AstExprUnary* expr = node->as()) + { + // if we *do* need to compute the target, we'd have to inject "not" ops on every return path + // this is possible but cumbersome; so for now we only optimize not expression when we *don't* need the value + if (!target && expr->op == AstExprUnary::Not) + { + compileConditionValue(expr->expr, target, skipJump, !onlyTruth); + return; + } + } + + if (AstExprGroup* expr = node->as()) + { + compileConditionValue(expr->expr, target, skipJump, onlyTruth); + return; + } + + RegScope rs(this); + uint8_t reg; + + if (target) + { + reg = *target; + compileExprTemp(node, reg); + } + else + { + reg = compileExprAuto(node, rs); + } + + skipJump.push_back(bytecode.emitLabel()); + bytecode.emitAD(onlyTruth ? LOP_JUMPIF : LOP_JUMPIFNOT, reg, 0); + } + + // checks if compiling the expression as a condition value generates code that's faster than using compileExpr + bool isConditionFast(AstExpr* node) + { + const Constant* cv = constants.find(node); + + if (cv && cv->type != Constant::Type_Unknown) + return true; + + if (AstExprBinary* expr = node->as()) + { + switch (expr->op) + { + case AstExprBinary::And: + case AstExprBinary::Or: + return true; + + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGt: + case AstExprBinary::CompareGe: + return true; + + default: + return false; + } + } + + if (AstExprGroup* expr = node->as()) + return isConditionFast(expr->expr); + + return false; + } + + void compileExprAndOr(AstExprBinary* expr, uint8_t target, bool targetTemp) + { + bool and_ = (expr->op == AstExprBinary::And); + + RegScope rs(this); + + // Optimization: when left hand side is a constant, we can emit left hand side or right hand side + const Constant* cl = constants.find(expr->left); + + if (cl && cl->type != Constant::Type_Unknown) + { + compileExpr(and_ == cl->isTruthful() ? expr->right : expr->left, target, targetTemp); + return; + } + + // Note: two optimizations below can lead to inefficient codegen when the left hand side is a condition + if (!isConditionFast(expr->left)) + { + // Optimization: when right hand side is a local variable, we can use AND/OR + if (isExprLocalReg(expr->right)) + { + uint8_t lr = compileExprAuto(expr->left, rs); + uint8_t rr = getLocal(expr->right->as()->local); + + bytecode.emitABC(and_ ? LOP_AND : LOP_OR, target, lr, rr); + return; + } + + // Optimization: when right hand side is a constant, we can use ANDK/ORK + int32_t cid = getConstantIndex(expr->right); + + if (cid >= 0 && cid <= 255) + { + uint8_t lr = compileExprAuto(expr->left, rs); + + bytecode.emitABC(and_ ? LOP_ANDK : LOP_ORK, target, lr, uint8_t(cid)); + return; + } + } + + // Optimization: if target is a temp register, we can clobber it which allows us to compute the result directly into it + // If it's not a temp register, then something like `a = a > 1 or a + 2` may clobber `a` while evaluating left hand side, and `a+2` will break + uint8_t reg = targetTemp ? target : allocReg(expr, 1); + + std::vector skipJump; + compileConditionValue(expr->left, ®, skipJump, /* onlyTruth= */ !and_); + + compileExprTemp(expr->right, reg); + + size_t moveLabel = bytecode.emitLabel(); + + patchJumps(expr, skipJump, moveLabel); + + if (target != reg) + bytecode.emitABC(LOP_MOVE, target, reg, 0); + } + + void compileExprUnary(AstExprUnary* expr, uint8_t target) + { + RegScope rs(this); + + uint8_t re = compileExprAuto(expr->expr, rs); + + bytecode.emitABC(getUnaryOp(expr->op), target, re, 0); + } + + static void unrollConcats(std::vector& args) + { + for (;;) + { + AstExprBinary* be = args.back()->as(); + + if (!be || be->op != AstExprBinary::Concat) + break; + + args.back() = be->left; + args.push_back(be->right); + } + } + + void compileExprBinary(AstExprBinary* expr, uint8_t target, bool targetTemp) + { + RegScope rs(this); + + switch (expr->op) + { + case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + { + int32_t rc = getConstantNumber(expr->right); + + if (rc >= 0 && rc <= 255) + { + uint8_t rl = compileExprAuto(expr->left, rs); + + bytecode.emitABC(getBinaryOpArith(expr->op, /* k= */ true), target, rl, uint8_t(rc)); + } + else + { + uint8_t rl = compileExprAuto(expr->left, rs); + uint8_t rr = compileExprAuto(expr->right, rs); + + bytecode.emitABC(getBinaryOpArith(expr->op), target, rl, rr); + } + } + break; + + case AstExprBinary::Concat: + { + std::vector args = {expr->left, expr->right}; + + // unroll the tree of concats down the right hand side to be able to do multiple ops + unrollConcats(args); + + uint8_t regs = allocReg(expr, unsigned(args.size())); + + for (size_t i = 0; i < args.size(); ++i) + compileExprTemp(args[i], uint8_t(regs + i)); + + bytecode.emitABC(LOP_CONCAT, target, regs, uint8_t(regs + args.size() - 1)); + } + break; + + case AstExprBinary::CompareNe: + case AstExprBinary::CompareEq: + case AstExprBinary::CompareLt: + case AstExprBinary::CompareLe: + case AstExprBinary::CompareGt: + case AstExprBinary::CompareGe: + { + size_t jumpLabel = compileCompareJump(expr); + + // note: this skips over the next LOADB instruction because of "1" in the C slot + bytecode.emitABC(LOP_LOADB, target, 0, 1); + + size_t thenLabel = bytecode.emitLabel(); + + bytecode.emitABC(LOP_LOADB, target, 1, 0); + + patchJump(expr, jumpLabel, thenLabel); + } + break; + + case AstExprBinary::And: + case AstExprBinary::Or: + { + compileExprAndOr(expr, target, targetTemp); + } + break; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + } + } + + void compileExprIfElse(AstExprIfElse* expr, uint8_t target, bool targetTemp) + { + if (isConstant(expr->condition)) + { + if (isConstantTrue(expr->condition)) + { + compileExpr(expr->trueExpr, target, targetTemp); + } + else + { + compileExpr(expr->falseExpr, target, targetTemp); + } + } + else + { + std::vector elseJump; + compileConditionValue(expr->condition, nullptr, elseJump, false); + compileExpr(expr->trueExpr, target, targetTemp); + + // Jump over else expression evaluation + size_t thenLabel = bytecode.emitLabel(); + bytecode.emitAD(LOP_JUMP, 0, 0); + + size_t elseLabel = bytecode.emitLabel(); + compileExpr(expr->falseExpr, target, targetTemp); + size_t endLabel = bytecode.emitLabel(); + + patchJumps(expr, elseJump, elseLabel); + patchJump(expr, thenLabel, endLabel); + } + } + + static uint8_t encodeHashSize(unsigned int hashSize) + { + size_t hashSizeLog2 = 0; + while ((1u << hashSizeLog2) < hashSize) + hashSizeLog2++; + + return hashSize == 0 ? 0 : uint8_t(hashSizeLog2 + 1); + } + + void compileExprTable(AstExprTable* expr, uint8_t target, bool targetTemp) + { + // Optimization: if the table is empty, we can compute it directly into the target + if (expr->items.size == 0) + { + auto [hashSize, arraySize] = predictedTableSize[expr]; + + bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(hashSize), 0); + bytecode.emitAux(arraySize); + return; + } + + unsigned int arraySize = 0; + unsigned int hashSize = 0; + unsigned int recordSize = 0; + unsigned int indexSize = 0; + + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + arraySize += (item.kind == AstExprTable::Item::List); + hashSize += (item.kind != AstExprTable::Item::List); + recordSize += (item.kind == AstExprTable::Item::Record); + } + + // Optimization: allocate sequential explicitly specified numeric indices ([1]) as arrays + if (arraySize == 0 && hashSize > 0) + { + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + AstExprConstantNumber* ckey = item.key->as(); + + indexSize += (ckey && ckey->value == double(indexSize + 1)); + } + + // we only perform the optimization if we don't have any other []-keys + // technically it's "safe" to do this even if we have other keys, but doing so changes iteration order and may break existing code + if (hashSize == recordSize + indexSize) + hashSize = recordSize; + else + indexSize = 0; + } + + int encodedHashSize = encodeHashSize(hashSize); + + RegScope rs(this); + + // Optimization: if target is a temp register, we can clobber it which allows us to compute the result directly into it + uint8_t reg = targetTemp ? target : allocReg(expr, 1); + + // Optimization: when all items are record fields, use template tables to compile expression + if (arraySize == 0 && indexSize == 0 && hashSize == recordSize && recordSize >= 1 && recordSize <= BytecodeBuilder::TableShape::kMaxLength) + { + BytecodeBuilder::TableShape shape; + + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + LUAU_ASSERT(item.kind == AstExprTable::Item::Record); + + AstExprConstantString* ckey = item.key->as(); + LUAU_ASSERT(ckey); + + int cid = bytecode.addConstantString(sref(ckey->value)); + if (cid < 0) + CompileError::raise(ckey->location, "Exceeded constant limit; simplify the code to compile"); + + LUAU_ASSERT(shape.length < BytecodeBuilder::TableShape::kMaxLength); + shape.keys[shape.length++] = int16_t(cid); + } + + int32_t tid = bytecode.addConstantTable(shape); + if (tid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + if (tid < 32768) + { + bytecode.emitAD(LOP_DUPTABLE, reg, int16_t(tid)); + } + else + { + bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitAux(0); + } + } + else + { + // Optimization: instead of allocating one extra element when the last element of the table literal is ..., let SETLIST allocate the + // correct amount of storage + const AstExprTable::Item* last = expr->items.size > 0 ? &expr->items.data[expr->items.size - 1] : nullptr; + + bool trailingVarargs = last && last->kind == AstExprTable::Item::List && last->value->is(); + LUAU_ASSERT(!trailingVarargs || arraySize > 0); + + bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitAux(arraySize - trailingVarargs + indexSize); + } + + unsigned int arrayChunkSize = std::min(16u, arraySize); + uint8_t arrayChunkReg = allocReg(expr, arrayChunkSize); + unsigned int arrayChunkCurrent = 0; + + unsigned int arrayIndex = 1; + bool multRet = false; + + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + AstExpr* key = item.key; + AstExpr* value = item.value; + + // some key/value pairs don't require us to compile the expressions, so we need to setup the line info here + setDebugLine(value); + + if (options.coverageLevel >= 2) + { + bytecode.emitABC(LOP_COVERAGE, 0, 0, 0); + } + + // flush array chunk on overflow or before hash keys to maintain insertion order + if (arrayChunkCurrent > 0 && (key || arrayChunkCurrent == arrayChunkSize)) + { + bytecode.emitABC(LOP_SETLIST, reg, arrayChunkReg, uint8_t(arrayChunkCurrent + 1)); + bytecode.emitAux(arrayIndex); + arrayIndex += arrayChunkCurrent; + arrayChunkCurrent = 0; + } + + // items with a key are set one by one via SETTABLE/SETTABLEKS + if (key) + { + RegScope rsi(this); + + // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax + if (AstExprConstantString* ckey = key->as()) + { + BytecodeBuilder::StringRef cname = sref(ckey->value); + int32_t cid = bytecode.addConstantString(cname); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); + bytecode.emitAux(cid); + } + else if (AstExprConstantNumber* ckey = key->as(); + ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) + { + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); + } + else + { + uint8_t rk = compileExprAuto(key, rsi); + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); + } + } + // items without a key are set using SETLIST so that we can initialize large arrays quickly + else + { + uint8_t temp = uint8_t(arrayChunkReg + arrayChunkCurrent); + + if (i + 1 == expr->items.size) + multRet = compileExprTempMultRet(value, temp); + else + compileExprTempTop(value, temp); + + arrayChunkCurrent++; + } + } + + // flush last array chunk; note that this needs multret handling if the last expression was multret + if (arrayChunkCurrent) + { + bytecode.emitABC(LOP_SETLIST, reg, arrayChunkReg, multRet ? 0 : uint8_t(arrayChunkCurrent + 1)); + bytecode.emitAux(arrayIndex); + } + + if (target != reg) + bytecode.emitABC(LOP_MOVE, target, reg, 0); + } + + bool canImport(AstExprGlobal* expr) + { + const Global* global = globals.find(expr->name); + + return options.optimizationLevel >= 1 && (!global || !global->written); + } + + bool canImportChain(AstExprGlobal* expr) + { + const Global* global = globals.find(expr->name); + + return options.optimizationLevel >= 1 && (!global || (!global->written && !global->special)); + } + + void compileExprIndexName(AstExprIndexName* expr, uint8_t target) + { + setDebugLine(expr); // normally compileExpr sets up line info, but compileExprIndexName can be called directly + + // Optimization: index chains that start from global variables can be compiled into GETIMPORT statement + AstExprGlobal* importRoot = 0; + AstExprIndexName* import1 = 0; + AstExprIndexName* import2 = 0; + + if (AstExprIndexName* index = expr->expr->as()) + { + importRoot = index->expr->as(); + import1 = index; + import2 = expr; + } + else + { + importRoot = expr->expr->as(); + import1 = expr; + } + + if (importRoot && canImportChain(importRoot)) + { + int32_t id0 = bytecode.addConstantString(sref(importRoot->name)); + int32_t id1 = bytecode.addConstantString(sref(import1->index)); + int32_t id2 = import2 ? bytecode.addConstantString(sref(import2->index)) : -1; + + if (id0 < 0 || id1 < 0 || (import2 && id2 < 0)) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + // Note: GETIMPORT encoding is limited to 10 bits per object id component + if (id0 < 1024 && id1 < 1024 && id2 < 1024) + { + uint32_t iid = import2 ? BytecodeBuilder::getImportId(id0, id1, id2) : BytecodeBuilder::getImportId(id0, id1); + int32_t cid = bytecode.addImport(iid); + + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_GETIMPORT, target, int16_t(cid)); + bytecode.emitAux(iid); + return; + } + } + } + + RegScope rs(this); + uint8_t reg = compileExprAuto(expr->expr, rs); + + BytecodeBuilder::StringRef iname = sref(expr->index); + int32_t cid = bytecode.addConstantString(iname); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(LOP_GETTABLEKS, target, reg, uint8_t(BytecodeBuilder::getStringHash(iname))); + bytecode.emitAux(cid); + } + + void compileExprIndexExpr(AstExprIndexExpr* expr, uint8_t target) + { + RegScope rs(this); + + const Constant* cv = constants.find(expr->index); + + if (cv && cv->type == Constant::Type_Number && double(int(cv->valueNumber)) == cv->valueNumber && cv->valueNumber >= 1 && + cv->valueNumber <= 256) + { + uint8_t rt = compileExprAuto(expr->expr, rs); + uint8_t i = uint8_t(int(cv->valueNumber) - 1); + + bytecode.emitABC(LOP_GETTABLEN, target, rt, i); + } + else if (cv && cv->type == Constant::Type_String) + { + uint8_t rt = compileExprAuto(expr->expr, rs); + + BytecodeBuilder::StringRef iname = sref(cv->valueString); + int32_t cid = bytecode.addConstantString(iname); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); + bytecode.emitAux(cid); + } + else + { + uint8_t rt = compileExprAuto(expr->expr, rs); + uint8_t ri = compileExprAuto(expr->index, rs); + + bytecode.emitABC(LOP_GETTABLE, target, rt, ri); + } + } + + void compileExprGlobal(AstExprGlobal* expr, uint8_t target) + { + // Optimization: builtin globals can be retrieved using GETIMPORT + if (canImport(expr)) + { + int32_t id0 = bytecode.addConstantString(sref(expr->name)); + if (id0 < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + // Note: GETIMPORT encoding is limited to 10 bits per object id component + if (id0 < 1024) + { + uint32_t iid = BytecodeBuilder::getImportId(id0); + int32_t cid = bytecode.addImport(iid); + + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_GETIMPORT, target, int16_t(cid)); + bytecode.emitAux(iid); + return; + } + } + } + + BytecodeBuilder::StringRef gname = sref(expr->name); + int32_t cid = bytecode.addConstantString(gname); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(LOP_GETGLOBAL, target, 0, uint8_t(BytecodeBuilder::getStringHash(gname))); + bytecode.emitAux(cid); + } + + void compileExprConstant(AstExpr* node, const Constant* cv, uint8_t target) + { + switch (cv->type) + { + case Constant::Type_Nil: + bytecode.emitABC(LOP_LOADNIL, target, 0, 0); + break; + + case Constant::Type_Boolean: + bytecode.emitABC(LOP_LOADB, target, cv->valueBoolean, 0); + break; + + case Constant::Type_Number: + { + double d = cv->valueNumber; + + if (d >= std::numeric_limits::min() && d <= std::numeric_limits::max() && double(int16_t(d)) == d && + !(d == 0.0 && signbit(d))) + { + // short number encoding: doesn't require a table entry lookup + bytecode.emitAD(LOP_LOADN, target, int16_t(d)); + } + else + { + // long number encoding: use generic constant path + int32_t cid = bytecode.addConstantNumber(d); + if (cid < 0) + CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); + + emitLoadK(target, cid); + } + } + break; + + case Constant::Type_String: + { + int32_t cid = bytecode.addConstantString(sref(cv->valueString)); + if (cid < 0) + CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); + + emitLoadK(target, cid); + } + break; + + default: + LUAU_ASSERT(!"Unexpected constant type"); + } + } + + void compileExpr(AstExpr* node, uint8_t target, bool targetTemp = false) + { + setDebugLine(node); + + if (options.coverageLevel >= 2 && needsCoverage(node)) + { + bytecode.emitABC(LOP_COVERAGE, 0, 0, 0); + } + + // Optimization: if expression has a constant value, we can emit it directly + if (const Constant* cv = constants.find(node)) + { + if (cv->type != Constant::Type_Unknown) + { + compileExprConstant(node, cv, target); + return; + } + } + + if (AstExprGroup* expr = node->as()) + { + compileExpr(expr->expr, target, targetTemp); + } + else if (node->is()) + { + bytecode.emitABC(LOP_LOADNIL, target, 0, 0); + } + else if (AstExprConstantBool* expr = node->as()) + { + bytecode.emitABC(LOP_LOADB, target, expr->value, 0); + } + else if (AstExprConstantNumber* expr = node->as()) + { + int32_t cid = bytecode.addConstantNumber(expr->value); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + emitLoadK(target, cid); + } + else if (AstExprConstantString* expr = node->as()) + { + int32_t cid = bytecode.addConstantString(sref(expr->value)); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + emitLoadK(target, cid); + } + else if (AstExprLocal* expr = node->as()) + { + if (expr->upvalue) + { + uint8_t uid = getUpval(expr->local); + + bytecode.emitABC(LOP_GETUPVAL, target, uid, 0); + } + else + { + uint8_t reg = getLocal(expr->local); + + bytecode.emitABC(LOP_MOVE, target, reg, 0); + } + } + else if (AstExprGlobal* expr = node->as()) + { + compileExprGlobal(expr, target); + } + else if (AstExprVarargs* expr = node->as()) + { + compileExprVarargs(expr, target, /* targetCount= */ 1); + } + else if (AstExprCall* expr = node->as()) + { + // Optimization: when targeting temporary registers, we can compile call in a special mode that doesn't require extra register moves + if (targetTemp && target == regTop - 1) + compileExprCall(expr, target, 1, /* targetTop= */ true); + else + compileExprCall(expr, target, /* targetCount= */ 1); + } + else if (AstExprIndexName* expr = node->as()) + { + compileExprIndexName(expr, target); + } + else if (AstExprIndexExpr* expr = node->as()) + { + compileExprIndexExpr(expr, target); + } + else if (AstExprFunction* expr = node->as()) + { + compileExprFunction(expr, target); + } + else if (AstExprTable* expr = node->as()) + { + compileExprTable(expr, target, targetTemp); + } + else if (AstExprUnary* expr = node->as()) + { + compileExprUnary(expr, target); + } + else if (AstExprBinary* expr = node->as()) + { + compileExprBinary(expr, target, targetTemp); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + compileExpr(expr->expr, target, targetTemp); + } + else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) + { + compileExprIfElse(expr, target, targetTemp); + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + } + } + + void compileExprTemp(AstExpr* node, uint8_t target) + { + return compileExpr(node, target, /* targetTemp= */ true); + } + + uint8_t compileExprAuto(AstExpr* node, RegScope&) + { + // Optimization: directly return locals instead of copying them to a temporary + if (isExprLocalReg(node)) + return getLocal(node->as()->local); + + // note: the register is owned by the parent scope + uint8_t reg = allocReg(node, 1); + + compileExprTemp(node, reg); + + return reg; + } + + // initializes target..target+targetCount-1 range using expressions from the list + // if list has fewer expressions, and last expression is a call, we assume the call returns the rest of the values + // if list has fewer expressions, and last expression isn't a call, we fill the rest with nil + // assumes target register range can be clobbered and is at the top of the register space + void compileExprListTop(const AstArray& list, uint8_t target, uint8_t targetCount) + { + // we assume that target range is at the top of the register space and can be clobbered + // this is what allows us to compile the last call expression - if it's a call - using targetTop=true + LUAU_ASSERT(unsigned(target + targetCount) == regTop); + + if (list.size == targetCount) + { + for (size_t i = 0; i < list.size; ++i) + compileExprTemp(list.data[i], uint8_t(target + i)); + } + else if (list.size > targetCount) + { + for (size_t i = 0; i < targetCount; ++i) + compileExprTemp(list.data[i], uint8_t(target + i)); + + // compute expressions with values that go nowhere; this is required to run side-effecting code if any + for (size_t i = targetCount; i < list.size; ++i) + { + RegScope rsi(this); + compileExprAuto(list.data[i], rsi); + } + } + else if (list.size > 0) + { + for (size_t i = 0; i < list.size - 1; ++i) + compileExprTemp(list.data[i], uint8_t(target + i)); + + AstExpr* last = list.data[list.size - 1]; + + if (AstExprCall* expr = last->as()) + { + compileExprCall(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1)), /* targetTop= */ true); + } + else if (AstExprVarargs* expr = last->as()) + { + compileExprVarargs(expr, uint8_t(target + list.size - 1), uint8_t(targetCount - (list.size - 1))); + } + else + { + compileExprTemp(last, uint8_t(target + list.size - 1)); + + for (size_t i = list.size; i < targetCount; ++i) + bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0); + } + } + else + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_LOADNIL, uint8_t(target + i), 0, 0); + } + } + + struct LValue + { + enum Kind + { + Kind_Local, + Kind_Upvalue, + Kind_Global, + Kind_IndexName, + Kind_IndexNumber, + Kind_IndexExpr, + }; + + Kind kind; + uint8_t reg; // register for local (Local) or table (Index*) + uint8_t upval; + uint8_t index; // register for index in IndexExpr + uint8_t number; // index-1 (0-255) in IndexNumber + BytecodeBuilder::StringRef name; + Location location; + }; + + LValue compileLValue(AstExpr* node, RegScope& rs) + { + setDebugLine(node); + + if (AstExprLocal* expr = node->as()) + { + if (expr->upvalue) + { + LValue result = {LValue::Kind_Upvalue}; + result.upval = getUpval(expr->local); + result.location = node->location; + + return result; + } + else + { + LValue result = {LValue::Kind_Local}; + result.reg = getLocal(expr->local); + result.location = node->location; + + return result; + } + } + else if (AstExprGlobal* expr = node->as()) + { + LValue result = {LValue::Kind_Global}; + result.name = sref(expr->name); + result.location = node->location; + + return result; + } + else if (AstExprIndexName* expr = node->as()) + { + LValue result = {LValue::Kind_IndexName}; + result.reg = compileExprAuto(expr->expr, rs); + result.name = sref(expr->index); + result.location = node->location; + + return result; + } + else if (AstExprIndexExpr* expr = node->as()) + { + const Constant* cv = constants.find(expr->index); + + if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && + double(int(cv->valueNumber)) == cv->valueNumber) + { + LValue result = {LValue::Kind_IndexNumber}; + result.reg = compileExprAuto(expr->expr, rs); + result.number = uint8_t(int(cv->valueNumber) - 1); + result.location = node->location; + + return result; + } + else if (cv && cv->type == Constant::Type_String) + { + LValue result = {LValue::Kind_IndexName}; + result.reg = compileExprAuto(expr->expr, rs); + result.name = sref(cv->valueString); + result.location = node->location; + + return result; + } + else + { + LValue result = {LValue::Kind_IndexExpr}; + result.reg = compileExprAuto(expr->expr, rs); + result.index = compileExprAuto(expr->index, rs); + result.location = node->location; + + return result; + } + } + else + { + LUAU_ASSERT(!"Unknown assignment expression"); + + return LValue(); + } + } + + void compileLValueUse(const LValue& lv, uint8_t reg, bool set) + { + switch (lv.kind) + { + case LValue::Kind_Local: + if (set) + bytecode.emitABC(LOP_MOVE, lv.reg, reg, 0); + else + bytecode.emitABC(LOP_MOVE, reg, lv.reg, 0); + break; + + case LValue::Kind_Upvalue: + bytecode.emitABC(set ? LOP_SETUPVAL : LOP_GETUPVAL, reg, lv.upval, 0); + break; + + case LValue::Kind_Global: + { + int32_t cid = bytecode.addConstantString(lv.name); + if (cid < 0) + CompileError::raise(lv.location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(set ? LOP_SETGLOBAL : LOP_GETGLOBAL, reg, 0, uint8_t(BytecodeBuilder::getStringHash(lv.name))); + bytecode.emitAux(cid); + } + break; + + case LValue::Kind_IndexName: + { + int32_t cid = bytecode.addConstantString(lv.name); + if (cid < 0) + CompileError::raise(lv.location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(set ? LOP_SETTABLEKS : LOP_GETTABLEKS, reg, lv.reg, uint8_t(BytecodeBuilder::getStringHash(lv.name))); + bytecode.emitAux(cid); + } + break; + + case LValue::Kind_IndexNumber: + bytecode.emitABC(set ? LOP_SETTABLEN : LOP_GETTABLEN, reg, lv.reg, lv.number); + break; + + case LValue::Kind_IndexExpr: + bytecode.emitABC(set ? LOP_SETTABLE : LOP_GETTABLE, reg, lv.reg, lv.index); + break; + + default: + LUAU_ASSERT(!"Unknown lvalue kind"); + } + } + + void compileAssign(const LValue& lv, uint8_t source) + { + compileLValueUse(lv, source, /* set= */ true); + } + + bool isExprLocalReg(AstExpr* expr) + { + AstExprLocal* le = expr->as(); + if (!le || le->upvalue) + return false; + + Local* l = locals.find(le->local); + LUAU_ASSERT(l); + + return l->allocated; + } + + bool isStatBreak(AstStat* node) + { + if (AstStatBlock* stat = node->as()) + return stat->body.size == 1 && stat->body.data[0]->is(); + + return node->is(); + } + + AstStatContinue* extractStatContinue(AstStatBlock* block) + { + if (block->body.size == 1) + return block->body.data[0]->as(); + else + return nullptr; + } + + void compileStatIf(AstStatIf* stat) + { + // Optimization: condition is always false => we only need the else body + if (isConstantFalse(stat->condition)) + { + if (stat->elsebody) + compileStat(stat->elsebody); + return; + } + + // Optimization: body is a "break" statement with no "else" => we can directly break out of the loop in "then" case + if (!stat->elsebody && isStatBreak(stat->thenbody) && !areLocalsCaptured(loops.back().localOffset)) + { + // fallthrough = continue with the loop as usual + std::vector elseJump; + compileConditionValue(stat->condition, nullptr, elseJump, true); + + for (size_t jump : elseJump) + loopJumps.push_back({LoopJump::Break, jump}); + return; + } + + AstStat* continueStatement = extractStatContinue(stat->thenbody); + + // Optimization: body is a "continue" statement with no "else" => we can directly continue in "then" case + if (!stat->elsebody && continueStatement != nullptr && !areLocalsCaptured(loops.back().localOffset)) + { + if (loops.back().untilCondition) + validateContinueUntil(continueStatement, loops.back().untilCondition); + + // fallthrough = proceed with the loop body as usual + std::vector elseJump; + compileConditionValue(stat->condition, nullptr, elseJump, true); + + for (size_t jump : elseJump) + loopJumps.push_back({LoopJump::Continue, jump}); + return; + } + + std::vector elseJump; + compileConditionValue(stat->condition, nullptr, elseJump, false); + + compileStat(stat->thenbody); + + if (stat->elsebody && elseJump.size() > 0) + { + // we don't need to skip past "else" body if "then" ends with return + // this is important because, if "else" also ends with return, we may *not* have any statement to skip to! + if (allPathsEndWithReturn(stat->thenbody)) + { + size_t elseLabel = bytecode.emitLabel(); + + compileStat(stat->elsebody); + + patchJumps(stat, elseJump, elseLabel); + } + else + { + size_t thenLabel = bytecode.emitLabel(); + + bytecode.emitAD(LOP_JUMP, 0, 0); + + size_t elseLabel = bytecode.emitLabel(); + + compileStat(stat->elsebody); + + size_t endLabel = bytecode.emitLabel(); + + patchJumps(stat, elseJump, elseLabel); + patchJump(stat, thenLabel, endLabel); + } + } + else + { + size_t endLabel = bytecode.emitLabel(); + + patchJumps(stat, elseJump, endLabel); + } + } + + void compileStatWhile(AstStatWhile* stat) + { + // Optimization: condition is always false => there's no loop! + if (isConstantFalse(stat->condition)) + return; + + size_t oldJumps = loopJumps.size(); + size_t oldLocals = localStack.size(); + + loops.push_back({oldLocals, nullptr}); + + size_t loopLabel = bytecode.emitLabel(); + + std::vector elseJump; + compileConditionValue(stat->condition, nullptr, elseJump, false); + + compileStat(stat->body); + + size_t contLabel = bytecode.emitLabel(); + + size_t backLabel = bytecode.emitLabel(); + + setDebugLine(stat->condition); + + // Note: this is using JUMPBACK, not JUMP, since JUMPBACK is interruptible and we want all loops to have at least one interruptible + // instruction + bytecode.emitAD(LOP_JUMPBACK, 0, 0); + + size_t endLabel = bytecode.emitLabel(); + + patchJump(stat, backLabel, loopLabel); + patchJumps(stat, elseJump, endLabel); + + patchLoopJumps(stat, oldJumps, endLabel, contLabel); + loopJumps.resize(oldJumps); + + loops.pop_back(); + } + + void compileStatRepeat(AstStatRepeat* stat) + { + size_t oldJumps = loopJumps.size(); + size_t oldLocals = localStack.size(); + + loops.push_back({oldLocals, stat->condition}); + + size_t loopLabel = bytecode.emitLabel(); + + // note: we "inline" compileStatBlock here so that we can close/pop locals after evaluating condition + // this is necessary because condition can access locals declared inside the repeat..until body + AstStatBlock* body = stat->body; + + RegScope rs(this); + + for (size_t i = 0; i < body->body.size; ++i) + compileStat(body->body.data[i]); + + size_t contLabel = bytecode.emitLabel(); + + size_t endLabel; + + setDebugLine(stat->condition); + + if (isConstantTrue(stat->condition)) + { + closeLocals(oldLocals); + + endLabel = bytecode.emitLabel(); + } + else + { + std::vector skipJump; + compileConditionValue(stat->condition, nullptr, skipJump, true); + + // we close locals *after* we compute loop conditionals because during computation of condition it's (in theory) possible that user code + // mutates them + closeLocals(oldLocals); + + size_t backLabel = bytecode.emitLabel(); + + // Note: this is using JUMPBACK, not JUMP, since JUMPBACK is interruptible and we want all loops to have at least one interruptible + // instruction + bytecode.emitAD(LOP_JUMPBACK, 0, 0); + + size_t skipLabel = bytecode.emitLabel(); + + // we need to close locals *again* after the loop ends because the first closeLocals would be jumped over on the last iteration + closeLocals(oldLocals); + + endLabel = bytecode.emitLabel(); + + patchJump(stat, backLabel, loopLabel); + patchJumps(stat, skipJump, skipLabel); + } + + popLocals(oldLocals); + + patchLoopJumps(stat, oldJumps, endLabel, contLabel); + loopJumps.resize(oldJumps); + + loops.pop_back(); + } + + void compileStatReturn(AstStatReturn* stat) + { + RegScope rs(this); + + uint8_t temp = 0; + bool multRet = false; + + // Optimization: return local value directly instead of copying it into a temporary + if (stat->list.size == 1 && isExprLocalReg(stat->list.data[0])) + { + AstExprLocal* le = stat->list.data[0]->as(); + LUAU_ASSERT(le); + + temp = getLocal(le->local); + } + else if (stat->list.size > 0) + { + temp = allocReg(stat, unsigned(stat->list.size)); + + // Note: if the last element is a function call or a vararg specifier, then we need to somehow return all values that that call returned + for (size_t i = 0; i < stat->list.size; ++i) + if (i + 1 == stat->list.size) + multRet = compileExprTempMultRet(stat->list.data[i], uint8_t(temp + i)); + else + compileExprTempTop(stat->list.data[i], uint8_t(temp + i)); + } + + closeLocals(0); + + bytecode.emitABC(LOP_RETURN, uint8_t(temp), multRet ? 0 : uint8_t(stat->list.size + 1), 0); + } + + bool areLocalsRedundant(AstStatLocal* stat) + { + // Extra expressions may have side effects + if (stat->values.size > stat->vars.size) + return false; + + for (AstLocal* local : stat->vars) + { + Local* l = locals.find(local); + + if (!l || l->constant.type == Constant::Type_Unknown) + return false; + } + + return true; + } + + void compileStatLocal(AstStatLocal* stat) + { + // Optimization: we don't need to allocate and assign const locals, since their uses will be constant-folded + if (options.optimizationLevel >= 1 && options.debugLevel <= 1 && areLocalsRedundant(stat)) + return; + + // note: allocReg in this case allocates into parent block register - note that we don't have RegScope here + uint8_t vars = allocReg(stat, unsigned(stat->vars.size)); + + compileExprListTop(stat->values, vars, uint8_t(stat->vars.size)); + + for (size_t i = 0; i < stat->vars.size; ++i) + pushLocal(stat->vars.data[i], uint8_t(vars + i)); + } + + void compileStatFor(AstStatFor* stat) + { + RegScope rs(this); + + size_t oldLocals = localStack.size(); + size_t oldJumps = loopJumps.size(); + + loops.push_back({oldLocals, nullptr}); + + // register layout: limit, step, index + uint8_t regs = allocReg(stat, 3); + + // if the iteration index is assigned from within the loop, we need to protect the internal index from the assignment + // to do that, we will copy the index into an actual local variable on each iteration + // this makes sure the code inside the loop can't interfere with the iteration process (other than modifying the table we're iterating + // through) + uint8_t varreg = regs + 2; + + Local* il = locals.find(stat->var); + + if (il && il->written) + varreg = allocReg(stat, 1); + + compileExprTemp(stat->from, uint8_t(regs + 2)); + compileExprTemp(stat->to, uint8_t(regs + 0)); + + if (stat->step) + compileExprTemp(stat->step, uint8_t(regs + 1)); + else + bytecode.emitABC(LOP_LOADN, uint8_t(regs + 1), 1, 0); + + size_t forLabel = bytecode.emitLabel(); + + bytecode.emitAD(LOP_FORNPREP, regs, 0); + + size_t loopLabel = bytecode.emitLabel(); + + if (varreg != regs + 2) + bytecode.emitABC(LOP_MOVE, varreg, regs + 2, 0); + + pushLocal(stat->var, varreg); + + compileStat(stat->body); + + closeLocals(oldLocals); + popLocals(oldLocals); + + setDebugLine(stat); + + size_t contLabel = bytecode.emitLabel(); + + size_t backLabel = bytecode.emitLabel(); + + bytecode.emitAD(LOP_FORNLOOP, regs, 0); + + size_t endLabel = bytecode.emitLabel(); + + patchJump(stat, forLabel, endLabel); + patchJump(stat, backLabel, loopLabel); + + patchLoopJumps(stat, oldJumps, endLabel, contLabel); + loopJumps.resize(oldJumps); + + loops.pop_back(); + } + + void compileStatForIn(AstStatForIn* stat) + { + RegScope rs(this); + + size_t oldLocals = localStack.size(); + size_t oldJumps = loopJumps.size(); + + loops.push_back({oldLocals, nullptr}); + + // register layout: generator, state, index, variables... + uint8_t regs = allocReg(stat, 3); + + // this puts initial values of (generator, state, index) into the loop registers + compileExprListTop(stat->values, regs, 3); + + // for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)" + // this requires at least extra 3 stack slots after index + // note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough + reserveReg(stat, 3); + + // note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2 + uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u)); + LUAU_ASSERT(vars == regs + 3); + + // Optimization: when we iterate through pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration + // index These instructions dynamically check if generator is equal to next/inext and bail out They assume that the generator produces 2 + // variables, which is why we allocate at least 2 above (see vars assignment) + LuauOpcode skipOp = LOP_JUMP; + LuauOpcode loopOp = LOP_FORGLOOP; + + if (options.optimizationLevel >= 1 && stat->vars.size <= 2) + { + if (stat->values.size == 1 && stat->values.data[0]->is()) + { + Builtin builtin = getBuiltin(stat->values.data[0]->as()->func); + + if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) + { + skipOp = LOP_FORGPREP_INEXT; + loopOp = LOP_FORGLOOP_INEXT; + } + else if (builtin.isGlobal("pairs")) // for .. in pairs(t) + { + skipOp = LOP_FORGPREP_NEXT; + loopOp = LOP_FORGLOOP_NEXT; + } + } + else if (stat->values.size == 2) + { + Builtin builtin = getBuiltin(stat->values.data[0]); + + if (builtin.isGlobal("next")) // for .. in next,t + { + skipOp = LOP_FORGPREP_NEXT; + loopOp = LOP_FORGLOOP_NEXT; + } + } + } + + // first iteration jumps into FORGLOOP instruction, but for ipairs/pairs it does extra preparation that makes the cost of an extra instruction + // worthwhile + size_t skipLabel = bytecode.emitLabel(); + + bytecode.emitAD(skipOp, regs, 0); + + size_t loopLabel = bytecode.emitLabel(); + + for (size_t i = 0; i < stat->vars.size; ++i) + pushLocal(stat->vars.data[i], uint8_t(vars + i)); + + compileStat(stat->body); + + closeLocals(oldLocals); + popLocals(oldLocals); + + setDebugLine(stat); + + size_t contLabel = bytecode.emitLabel(); + + size_t backLabel = bytecode.emitLabel(); + + bytecode.emitAD(loopOp, regs, 0); + + // note: FORGLOOP needs variable count encoded in AUX field, other loop instructions assume a fixed variable count + if (loopOp == LOP_FORGLOOP) + bytecode.emitAux(uint32_t(stat->vars.size)); + + size_t endLabel = bytecode.emitLabel(); + + patchJump(stat, skipLabel, backLabel); + patchJump(stat, backLabel, loopLabel); + + patchLoopJumps(stat, oldJumps, endLabel, contLabel); + loopJumps.resize(oldJumps); + + loops.pop_back(); + } + + void resolveAssignConflicts(AstStat* stat, std::vector& vars) + { + // regsUsed[i] is true if we have assigned the register during earlier assignments + // regsRemap[i] is set to the register where the original (pre-assignment) copy was made + // note: regsRemap is uninitialized intentionally to speed small assignments up; regsRemap[i] is valid iff regsUsed[i] + std::bitset<256> regsUsed; + uint8_t regsRemap[256]; + + for (size_t i = 0; i < vars.size(); ++i) + { + LValue& li = vars[i]; + + if (li.kind == LValue::Kind_Local) + { + if (!regsUsed[li.reg]) + { + regsUsed[li.reg] = true; + regsRemap[li.reg] = li.reg; + } + } + else if (li.kind == LValue::Kind_IndexName || li.kind == LValue::Kind_IndexNumber || li.kind == LValue::Kind_IndexExpr) + { + // we're looking for assignments before this one that invalidate any of the registers involved + if (regsUsed[li.reg]) + { + // the register may have been evacuated previously, but if it wasn't - move it now + if (regsRemap[li.reg] == li.reg) + { + uint8_t reg = allocReg(stat, 1); + bytecode.emitABC(LOP_MOVE, reg, li.reg, 0); + + regsRemap[li.reg] = reg; + } + + li.reg = regsRemap[li.reg]; + } + + if (li.kind == LValue::Kind_IndexExpr && regsUsed[li.index]) + { + // the register may have been evacuated previously, but if it wasn't - move it now + if (regsRemap[li.index] == li.index) + { + uint8_t reg = allocReg(stat, 1); + bytecode.emitABC(LOP_MOVE, reg, li.index, 0); + + regsRemap[li.index] = reg; + } + + li.index = regsRemap[li.index]; + } + } + } + } + + void compileStatAssign(AstStatAssign* stat) + { + RegScope rs(this); + + // Optimization: one to one assignments don't require complex conflict resolution machinery and allow us to skip temporary registers for + // locals + if (stat->vars.size == 1 && stat->values.size == 1) + { + LValue var = compileLValue(stat->vars.data[0], rs); + + // Optimization: assign to locals directly + if (var.kind == LValue::Kind_Local) + { + compileExpr(stat->values.data[0], var.reg); + } + else + { + uint8_t reg = compileExprAuto(stat->values.data[0], rs); + + setDebugLine(stat->vars.data[0]); + compileAssign(var, reg); + } + return; + } + + // compute all l-values: note that this doesn't assign anything yet but it allocates registers and computes complex expressions on the left + // hand side for example, in "a[expr] = foo" expr will get evaluated here + std::vector vars(stat->vars.size); + + for (size_t i = 0; i < stat->vars.size; ++i) + vars[i] = compileLValue(stat->vars.data[i], rs); + + // perform conflict resolution: if any lvalue refers to a local reg that will be reassigned before that, we save the local variable in a + // temporary reg + resolveAssignConflicts(stat, vars); + + // compute values into temporaries + uint8_t regs = allocReg(stat, unsigned(stat->vars.size)); + + compileExprListTop(stat->values, regs, uint8_t(stat->vars.size)); + + // assign variables that have associated values; note that if we have fewer values than variables, we'll assign nil because compileExprListTop + // will generate nils + for (size_t i = 0; i < stat->vars.size; ++i) + { + setDebugLine(stat->vars.data[i]); + compileAssign(vars[i], uint8_t(regs + i)); + } + } + + void compileStatCompoundAssign(AstStatCompoundAssign* stat) + { + RegScope rs(this); + + LValue var = compileLValue(stat->var, rs); + + // Optimization: assign to locals directly + uint8_t target = (var.kind == LValue::Kind_Local) ? var.reg : allocReg(stat, 1); + + switch (stat->op) + { + case AstExprBinary::Add: + case AstExprBinary::Sub: + case AstExprBinary::Mul: + case AstExprBinary::Div: + case AstExprBinary::Mod: + case AstExprBinary::Pow: + { + if (var.kind != LValue::Kind_Local) + compileLValueUse(var, target, /* set= */ false); + + int32_t rc = getConstantNumber(stat->value); + + if (rc >= 0 && rc <= 255) + { + bytecode.emitABC(getBinaryOpArith(stat->op, /* k= */ true), target, target, uint8_t(rc)); + } + else + { + uint8_t rr = compileExprAuto(stat->value, rs); + + bytecode.emitABC(getBinaryOpArith(stat->op), target, target, rr); + } + } + break; + + case AstExprBinary::Concat: + { + std::vector args = {stat->value}; + + // unroll the tree of concats down the right hand side to be able to do multiple ops + unrollConcats(args); + + uint8_t regs = allocReg(stat, unsigned(1 + args.size())); + + compileLValueUse(var, regs, /* set= */ false); + + for (size_t i = 0; i < args.size(); ++i) + compileExprTemp(args[i], uint8_t(regs + 1 + i)); + + bytecode.emitABC(LOP_CONCAT, target, regs, uint8_t(regs + args.size())); + } + break; + + default: + LUAU_ASSERT(!"Unexpected compound assignment operation"); + } + + if (var.kind != LValue::Kind_Local) + compileAssign(var, target); + } + + void compileStatFunction(AstStatFunction* stat) + { + // Optimization: compile value expresion directly into target local register + if (isExprLocalReg(stat->name)) + { + AstExprLocal* le = stat->name->as(); + LUAU_ASSERT(le); + + compileExpr(stat->func, getLocal(le->local)); + return; + } + + RegScope rs(this); + uint8_t reg = allocReg(stat, 1); + + compileExprTemp(stat->func, reg); + + LValue var = compileLValue(stat->name, rs); + compileAssign(var, reg); + } + + void compileStat(AstStat* node) + { + setDebugLine(node); + + if (options.coverageLevel >= 1 && needsCoverage(node)) + { + bytecode.emitABC(LOP_COVERAGE, 0, 0, 0); + } + + if (AstStatBlock* stat = node->as()) + { + RegScope rs(this); + + size_t oldLocals = localStack.size(); + + for (size_t i = 0; i < stat->body.size; ++i) + compileStat(stat->body.data[i]); + + closeLocals(oldLocals); + + popLocals(oldLocals); + } + else if (AstStatIf* stat = node->as()) + { + compileStatIf(stat); + } + else if (AstStatWhile* stat = node->as()) + { + compileStatWhile(stat); + } + else if (AstStatRepeat* stat = node->as()) + { + compileStatRepeat(stat); + } + else if (node->is()) + { + // before exiting out of the loop, we need to close all local variables that were captured in closures since loop start + // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here + LUAU_ASSERT(!loops.empty()); + closeLocals(loops.back().localOffset); + + size_t label = bytecode.emitLabel(); + + bytecode.emitAD(LOP_JUMP, 0, 0); + + loopJumps.push_back({LoopJump::Break, label}); + } + else if (AstStatContinue* stat = node->as()) + { + if (loops.back().untilCondition) + validateContinueUntil(stat, loops.back().untilCondition); + + // before continuing, we need to close all local variables that were captured in closures since loop start + // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here + LUAU_ASSERT(!loops.empty()); + closeLocals(loops.back().localOffset); + + size_t label = bytecode.emitLabel(); + + bytecode.emitAD(LOP_JUMP, 0, 0); + + loopJumps.push_back({LoopJump::Continue, label}); + } + else if (AstStatReturn* stat = node->as()) + { + compileStatReturn(stat); + } + else if (AstStatExpr* stat = node->as()) + { + // Optimization: since we don't need to read anything from the stack, we can compile the call to not return anything which saves register + // moves + if (AstExprCall* expr = stat->expr->as()) + { + uint8_t target = uint8_t(regTop); + + compileExprCall(expr, target, /* targetCount= */ 0); + } + else + { + RegScope rs(this); + compileExprAuto(stat->expr, rs); + } + } + else if (AstStatLocal* stat = node->as()) + { + compileStatLocal(stat); + } + else if (AstStatFor* stat = node->as()) + { + compileStatFor(stat); + } + else if (AstStatForIn* stat = node->as()) + { + compileStatForIn(stat); + } + else if (AstStatAssign* stat = node->as()) + { + compileStatAssign(stat); + } + else if (AstStatCompoundAssign* stat = node->as()) + { + compileStatCompoundAssign(stat); + } + else if (AstStatFunction* stat = node->as()) + { + compileStatFunction(stat); + } + else if (AstStatLocalFunction* stat = node->as()) + { + uint8_t var = allocReg(stat, 1); + + pushLocal(stat->name, var); + compileExprFunction(stat->func, var); + + Local& l = locals[stat->name]; + + // we *have* to pushLocal before we compile the function, since the function may refer to the local as an upvalue + // however, this means the debugpc for the local is at an instruction where the local value hasn't been computed yet + // to fix this we just move the debugpc after the local value is established + l.debugpc = bytecode.getDebugPC(); + } + else if (node->is()) + { + // do nothing + } + else + { + LUAU_ASSERT(!"Unknown statement type"); + } + } + + void validateContinueUntil(AstStat* cont, AstExpr* condition) + { + UndefinedLocalVisitor visitor(this); + condition->visit(&visitor); + + if (visitor.undef) + CompileError::raise(condition->location, + "Local %s used in the repeat..until condition is undefined because continue statement on line %d jumps over it", + visitor.undef->name.value, cont->location.begin.line + 1); + } + + void gatherConstUpvals(AstExprFunction* func) + { + ConstUpvalueVisitor visitor(this); + func->body->visit(&visitor); + + for (AstLocal* local : visitor.upvals) + getUpval(local); + } + + void pushLocal(AstLocal* local, uint8_t reg) + { + if (localStack.size() >= kMaxLocalCount) + CompileError::raise( + local->location, "Out of local registers when trying to allocate %s: exceeded limit %d", local->name.value, kMaxLocalCount); + + localStack.push_back(local); + + Local& l = locals[local]; + + LUAU_ASSERT(!l.allocated); + + l.reg = reg; + l.allocated = true; + l.debugpc = bytecode.getDebugPC(); + } + + bool areLocalsCaptured(size_t start) + { + LUAU_ASSERT(start <= localStack.size()); + + for (size_t i = start; i < localStack.size(); ++i) + { + Local* l = locals.find(localStack[i]); + LUAU_ASSERT(l); + + if (l->captured && l->written) + return true; + } + + return false; + } + + void closeLocals(size_t start) + { + LUAU_ASSERT(start <= localStack.size()); + + bool captured = false; + uint8_t captureReg = 255; + + for (size_t i = start; i < localStack.size(); ++i) + { + Local* l = locals.find(localStack[i]); + LUAU_ASSERT(l); + + if (l->captured && l->written) + { + captured = true; + captureReg = std::min(captureReg, l->reg); + } + } + + if (captured) + { + bytecode.emitABC(LOP_CLOSEUPVALS, captureReg, 0, 0); + } + } + + void popLocals(size_t start) + { + LUAU_ASSERT(start <= localStack.size()); + + for (size_t i = start; i < localStack.size(); ++i) + { + Local* l = locals.find(localStack[i]); + LUAU_ASSERT(l); + LUAU_ASSERT(l->allocated); + + l->allocated = false; + + if (options.debugLevel >= 2) + { + uint32_t debugpc = bytecode.getDebugPC(); + + bytecode.pushDebugLocal(sref(localStack[i]->name), l->reg, l->debugpc, debugpc); + } + } + + localStack.resize(start); + } + + void patchJump(AstNode* node, size_t label, size_t target) + { + if (!bytecode.patchJumpD(label, target)) + CompileError::raise(node->location, "Exceeded jump distance limit; simplify the code to compile"); + } + + void patchJumps(AstNode* node, std::vector& labels, size_t target) + { + for (size_t l : labels) + patchJump(node, l, target); + } + + void patchLoopJumps(AstNode* node, size_t oldJumps, size_t endLabel, size_t contLabel) + { + LUAU_ASSERT(oldJumps <= loopJumps.size()); + + for (size_t i = oldJumps; i < loopJumps.size(); ++i) + { + const LoopJump& lj = loopJumps[i]; + + switch (lj.type) + { + case LoopJump::Break: + patchJump(node, lj.label, endLabel); + break; + + case LoopJump::Continue: + patchJump(node, lj.label, contLabel); + break; + + default: + LUAU_ASSERT(!"Unknown loop jump type"); + } + } + } + + uint8_t allocReg(AstNode* node, unsigned int count) + { + unsigned int top = regTop; + if (top + count > kMaxRegisterCount) + CompileError::raise(node->location, "Out of registers when trying to allocate %d registers: exceeded limit %d", count, kMaxRegisterCount); + + regTop += count; + stackSize = std::max(stackSize, regTop); + + return uint8_t(top); + } + + void reserveReg(AstNode* node, unsigned int count) + { + if (regTop + count > kMaxRegisterCount) + CompileError::raise(node->location, "Out of registers when trying to allocate %d registers: exceeded limit %d", count, kMaxRegisterCount); + + stackSize = std::max(stackSize, regTop + count); + } + + void setDebugLine(AstNode* node) + { + if (options.debugLevel >= 1) + bytecode.setDebugLine(node->location.begin.line + 1); + } + + void setDebugLineEnd(AstNode* node) + { + if (options.debugLevel >= 1) + bytecode.setDebugLine(node->location.end.line + 1); + } + + bool needsCoverage(AstNode* node) + { + return !node->is() && !node->is(); + } + + struct AssignmentVisitor : AstVisitor + { + struct Hasher + { + size_t operator()(const std::pair& p) const + { + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + + DenseHashMap localToTable; + DenseHashSet, Hasher> fields; + + AssignmentVisitor(Compiler* self) + : localToTable(nullptr) + , fields(std::pair()) + , self(self) + { + } + + void assignField(AstExpr* expr, AstName index) + { + if (AstExprLocal* lv = expr->as()) + { + if (AstExprTable** table = localToTable.find(lv->local)) + { + std::pair field = {*table, index}; + + if (!fields.contains(field)) + { + fields.insert(field); + self->predictedTableSize[*table].first += 1; + } + } + } + } + + void assignField(AstExpr* expr, AstExpr* index) + { + AstExprLocal* lv = expr->as(); + AstExprConstantNumber* number = index->as(); + + if (lv && number) + { + if (AstExprTable** table = localToTable.find(lv->local)) + { + unsigned int& arraySize = self->predictedTableSize[*table].second; + + if (number->value == double(arraySize + 1)) + arraySize += 1; + } + } + } + + void assign(AstExpr* var) + { + if (AstExprLocal* lv = var->as()) + { + self->locals[lv->local].written = true; + } + else if (AstExprGlobal* gv = var->as()) + { + self->globals[gv->name].written = true; + } + else if (AstExprIndexName* index = var->as()) + { + assignField(index->expr, index->index); + + var->visit(this); + } + else if (AstExprIndexExpr* index = var->as()) + { + assignField(index->expr, index->index); + + var->visit(this); + } + else + { + // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 + var->visit(this); + } + } + + AstExprTable* getTableHint(AstExpr* expr) + { + // unadorned table literal + if (AstExprTable* table = expr->as()) + return table; + + // setmetatable(table literal, ...) + if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) + if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") + if (AstExprTable* table = call->args.data[0]->as()) + return table; + + return nullptr; + } + + bool visit(AstStatLocal* node) override + { + // track local -> table association so that we can update table size prediction in assignField + if (node->vars.size == 1 && node->values.size == 1) + if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) + localToTable[node->vars.data[0]] = table; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + node->value->visit(this); + + return false; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } + + Compiler* self; + }; + + struct ConstantVisitor : AstVisitor + { + ConstantVisitor(Compiler* self) + : self(self) + { + } + + void analyzeUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) + { + switch (op) + { + case AstExprUnary::Not: + if (arg.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !arg.isTruthful(); + } + break; + + case AstExprUnary::Minus: + if (arg.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = -arg.valueNumber; + } + break; + + case AstExprUnary::Len: + break; + + default: + LUAU_ASSERT(!"Unexpected unary operation"); + } + } + + bool constantsEqual(const Constant& la, const Constant& ra) + { + LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); + + switch (la.type) + { + case Constant::Type_Nil: + return ra.type == Constant::Type_Nil; + + case Constant::Type_Boolean: + return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; + + case Constant::Type_Number: + return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; + + case Constant::Type_String: + return ra.type == Constant::Type_String && la.valueString.size == ra.valueString.size && + memcmp(la.valueString.data, ra.valueString.data, la.valueString.size) == 0; + + default: + LUAU_ASSERT(!"Unexpected constant type in comparison"); + return false; + } + } + + void analyzeBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) + { + switch (op) + { + case AstExprBinary::Add: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber + ra.valueNumber; + } + break; + + case AstExprBinary::Sub: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - ra.valueNumber; + } + break; + + case AstExprBinary::Mul: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber * ra.valueNumber; + } + break; + + case AstExprBinary::Div: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber / ra.valueNumber; + } + break; + + case AstExprBinary::Mod: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; + } + break; + + case AstExprBinary::Pow: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = pow(la.valueNumber, ra.valueNumber); + } + break; + + case AstExprBinary::Concat: + break; + + case AstExprBinary::CompareNe: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareEq: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareLt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber < ra.valueNumber; + } + break; + + case AstExprBinary::CompareLe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber <= ra.valueNumber; + } + break; + + case AstExprBinary::CompareGt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber > ra.valueNumber; + } + break; + + case AstExprBinary::CompareGe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber >= ra.valueNumber; + } + break; + + case AstExprBinary::And: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? ra : la; + } + break; + + case AstExprBinary::Or: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? la : ra; + } + break; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + } + } + + Constant analyze(AstExpr* node) + { + Constant result; + result.type = Constant::Type_Unknown; + + if (AstExprGroup* expr = node->as()) + { + result = analyze(expr->expr); + } + else if (node->is()) + { + result.type = Constant::Type_Nil; + } + else if (AstExprConstantBool* expr = node->as()) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = expr->value; + } + else if (AstExprConstantNumber* expr = node->as()) + { + result.type = Constant::Type_Number; + result.valueNumber = expr->value; + } + else if (AstExprConstantString* expr = node->as()) + { + result.type = Constant::Type_String; + result.valueString = expr->value; + } + else if (AstExprLocal* expr = node->as()) + { + const Local* l = self->locals.find(expr->local); + + if (l && l->constant.type != Constant::Type_Unknown) + { + LUAU_ASSERT(!l->written); + result = l->constant; + } + } + else if (node->is()) + { + // nope + } + else if (node->is()) + { + // nope + } + else if (AstExprCall* expr = node->as()) + { + analyze(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + analyze(expr->args.data[i]); + } + else if (AstExprIndexName* expr = node->as()) + { + analyze(expr->expr); + } + else if (AstExprIndexExpr* expr = node->as()) + { + analyze(expr->expr); + analyze(expr->index); + } + else if (AstExprFunction* expr = node->as()) + { + // this is necessary to propagate constant information in all child functions + expr->body->visit(this); + } + else if (AstExprTable* expr = node->as()) + { + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + analyze(item.key); + + analyze(item.value); + } + } + else if (AstExprUnary* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + analyzeUnary(result, expr->op, arg); + } + else if (AstExprBinary* expr = node->as()) + { + Constant la = analyze(expr->left); + Constant ra = analyze(expr->right); + + analyzeBinary(result, expr->op, la, ra); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + result = arg; + } + else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) + { + Constant cond = analyze(expr->condition); + Constant trueExpr = analyze(expr->trueExpr); + Constant falseExpr = analyze(expr->falseExpr); + if (cond.type != Constant::Type_Unknown) + { + result = cond.isTruthful() ? trueExpr : falseExpr; + } + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + } + + if (result.type != Constant::Type_Unknown) + self->constants[node] = result; + + return result; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression + analyze(node); + + return false; + } + + bool visit(AstStatLocal* node) override + { + // for values that match 1-1 we record the initializing expression for future analysis + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + { + Local& l = self->locals[node->vars.data[i]]; + + l.init = node->values.data[i]; + } + + // all values that align wrt indexing are simple - we just match them 1-1 + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + { + Constant arg = analyze(node->values.data[i]); + + if (arg.type != Constant::Type_Unknown) + { + Local& l = self->locals[node->vars.data[i]]; + + // note: we rely on AssignmentVisitor to have been run before us + if (!l.written) + l.constant = arg; + } + } + + if (node->vars.size > node->values.size) + { + // if we have trailing variables, then depending on whether the last value is capable of returning multiple values + // (aka call or varargs), we either don't know anything about these vars, or we know they're nil + AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; + bool multRet = last && (last->is() || last->is()); + + for (size_t i = node->values.size; i < node->vars.size; ++i) + { + if (!multRet) + { + Local& l = self->locals[node->vars.data[i]]; + + // note: we rely on AssignmentVisitor to have been run before us + if (!l.written) + { + l.constant.type = Constant::Type_Nil; + } + } + } + } + else + { + // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside + // them + for (size_t i = node->vars.size; i < node->values.size; ++i) + analyze(node->values.data[i]); + } + + return false; + } + + Compiler* self; + }; + + struct FenvVisitor : AstVisitor + { + bool& getfenvUsed; + bool& setfenvUsed; + + FenvVisitor(bool& getfenvUsed, bool& setfenvUsed) + : getfenvUsed(getfenvUsed) + , setfenvUsed(setfenvUsed) + { + } + + bool visit(AstExprGlobal* node) override + { + if (node->name == "getfenv") + getfenvUsed = true; + if (node->name == "setfenv") + setfenvUsed = true; + + return false; + } + }; + + struct FunctionVisitor : AstVisitor + { + Compiler* self; + std::vector& functions; + + FunctionVisitor(Compiler* self, std::vector& functions) + : self(self) + , functions(functions) + { + } + + bool visit(AstExprFunction* node) override + { + node->body->visit(this); + + // this makes sure all functions that are used when compiling this one have been already added to the vector + functions.push_back(node); + + return false; + } + + bool visit(AstStatLocalFunction* node) override + { + // record local->function association for some optimizations + if (FFlag::LuauPreloadClosuresUpval) + self->locals[node->name].func = node->func; + + return true; + } + }; + + struct UndefinedLocalVisitor : AstVisitor + { + UndefinedLocalVisitor(Compiler* self) + : self(self) + , undef(nullptr) + { + } + + void check(AstLocal* local) + { + Local& l = self->locals[local]; + + if (!l.allocated && !undef) + undef = local; + } + + bool visit(AstExprLocal* node) override + { + if (!node->upvalue) + check(node->local); + + return false; + } + + bool visit(AstExprFunction* node) override + { + const Function* f = self->functions.find(node); + LUAU_ASSERT(f); + + for (AstLocal* uv : f->upvals) + { + LUAU_ASSERT(uv->functionDepth < node->functionDepth); + + if (uv->functionDepth == node->functionDepth - 1) + check(uv); + } + + return false; + } + + Compiler* self; + AstLocal* undef; + }; + + struct ConstUpvalueVisitor : AstVisitor + { + ConstUpvalueVisitor(Compiler* self) + : self(self) + { + } + + bool visit(AstExprLocal* node) override + { + if (node->upvalue && self->isConstant(node)) + { + upvals.push_back(node->local); + } + + return false; + } + + bool visit(AstExprFunction* node) override + { + // short-circuits the traversal to make it faster + return false; + } + + Compiler* self; + std::vector upvals; + }; + + struct RegScope + { + RegScope(Compiler* self) + : self(self) + , oldTop(self->regTop) + { + } + + // This ctor is useful to forcefully adjust the stack frame in case we know that registers after a certain point are scratch and can be + // discarded + RegScope(Compiler* self, unsigned int top) + : self(self) + , oldTop(self->regTop) + { + LUAU_ASSERT(top <= self->regTop); + self->regTop = top; + } + + ~RegScope() + { + self->regTop = oldTop; + } + + Compiler* self; + unsigned int oldTop; + }; + + struct Function + { + uint32_t id; + std::vector upvals; + }; + + struct Constant + { + enum Type + { + Type_Unknown, + Type_Nil, + Type_Boolean, + Type_Number, + Type_String, + }; + + Type type = Type_Unknown; + + union + { + bool valueBoolean; + double valueNumber; + AstArray valueString = {}; + }; + + bool isTruthful() const + { + LUAU_ASSERT(type != Type_Unknown); + return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); + } + }; + + struct Local + { + uint8_t reg = 0; + bool allocated = false; + bool captured = false; + bool written = false; + AstExpr* init = nullptr; + uint32_t debugpc = 0; + Constant constant; + AstExprFunction* func = nullptr; + }; + + struct Global + { + bool special = false; + bool written = false; + }; + + struct Builtin + { + AstName object; + AstName method; + + bool empty() const + { + return object == AstName() && method == AstName(); + } + + bool isGlobal(const char* name) const + { + return object == AstName() && method == name; + } + + bool isMethod(const char* table, const char* name) const + { + return object == table && method == name; + } + }; + + struct LoopJump + { + enum Type + { + Break, + Continue + }; + + Type type; + size_t label; + }; + + struct Loop + { + size_t localOffset; + + AstExpr* untilCondition; + }; + + Builtin getBuiltin(AstExpr* node) + { + if (AstExprLocal* expr = node->as()) + { + Local* l = locals.find(expr->local); + + return l && !l->written && l->init ? getBuiltin(l->init) : Builtin(); + } + else if (AstExprIndexName* expr = node->as()) + { + if (AstExprGlobal* object = expr->expr->as()) + { + Global* g = globals.find(object->name); + + return !g || (!g->special && !g->written) ? Builtin{object->name, expr->index} : Builtin(); + } + else + { + return Builtin(); + } + } + else if (AstExprGlobal* expr = node->as()) + { + Global* g = globals.find(expr->name); + + return !g || !g->written ? Builtin{AstName(), expr->name} : Builtin(); + } + else + { + return Builtin(); + } + } + + int getBuiltinFunctionId(const Builtin& builtin) + { + if (builtin.empty()) + return -1; + + if (builtin.isGlobal("assert")) + return LBF_ASSERT; + + if (builtin.isGlobal("type")) + return LBF_TYPE; + + if (builtin.isGlobal("typeof")) + return LBF_TYPEOF; + + if (builtin.isGlobal("rawset")) + return LBF_RAWSET; + if (builtin.isGlobal("rawget")) + return LBF_RAWGET; + if (builtin.isGlobal("rawequal")) + return LBF_RAWEQUAL; + + if (builtin.isGlobal("unpack")) + return LBF_TABLE_UNPACK; + + if (builtin.object == "math") + { + if (builtin.method == "abs") + return LBF_MATH_ABS; + if (builtin.method == "acos") + return LBF_MATH_ACOS; + if (builtin.method == "asin") + return LBF_MATH_ASIN; + if (builtin.method == "atan2") + return LBF_MATH_ATAN2; + if (builtin.method == "atan") + return LBF_MATH_ATAN; + if (builtin.method == "ceil") + return LBF_MATH_CEIL; + if (builtin.method == "cosh") + return LBF_MATH_COSH; + if (builtin.method == "cos") + return LBF_MATH_COS; + if (builtin.method == "deg") + return LBF_MATH_DEG; + if (builtin.method == "exp") + return LBF_MATH_EXP; + if (builtin.method == "floor") + return LBF_MATH_FLOOR; + if (builtin.method == "fmod") + return LBF_MATH_FMOD; + if (builtin.method == "frexp") + return LBF_MATH_FREXP; + if (builtin.method == "ldexp") + return LBF_MATH_LDEXP; + if (builtin.method == "log10") + return LBF_MATH_LOG10; + if (builtin.method == "log") + return LBF_MATH_LOG; + if (builtin.method == "max") + return LBF_MATH_MAX; + if (builtin.method == "min") + return LBF_MATH_MIN; + if (builtin.method == "modf") + return LBF_MATH_MODF; + if (builtin.method == "pow") + return LBF_MATH_POW; + if (builtin.method == "rad") + return LBF_MATH_RAD; + if (builtin.method == "sinh") + return LBF_MATH_SINH; + if (builtin.method == "sin") + return LBF_MATH_SIN; + if (builtin.method == "sqrt") + return LBF_MATH_SQRT; + if (builtin.method == "tanh") + return LBF_MATH_TANH; + if (builtin.method == "tan") + return LBF_MATH_TAN; + if (builtin.method == "clamp") + return LBF_MATH_CLAMP; + if (builtin.method == "sign") + return LBF_MATH_SIGN; + if (builtin.method == "round") + return LBF_MATH_ROUND; + } + + if (builtin.object == "bit32") + { + if (builtin.method == "arshift") + return LBF_BIT32_ARSHIFT; + if (builtin.method == "band") + return LBF_BIT32_BAND; + if (builtin.method == "bnot") + return LBF_BIT32_BNOT; + if (builtin.method == "bor") + return LBF_BIT32_BOR; + if (builtin.method == "bxor") + return LBF_BIT32_BXOR; + if (builtin.method == "btest") + return LBF_BIT32_BTEST; + if (builtin.method == "extract") + return LBF_BIT32_EXTRACT; + if (builtin.method == "lrotate") + return LBF_BIT32_LROTATE; + if (builtin.method == "lshift") + return LBF_BIT32_LSHIFT; + if (builtin.method == "replace") + return LBF_BIT32_REPLACE; + if (builtin.method == "rrotate") + return LBF_BIT32_RROTATE; + if (builtin.method == "rshift") + return LBF_BIT32_RSHIFT; + } + + if (builtin.object == "string") + { + if (builtin.method == "byte") + return LBF_STRING_BYTE; + if (builtin.method == "char") + return LBF_STRING_CHAR; + if (builtin.method == "len") + return LBF_STRING_LEN; + if (builtin.method == "sub") + return LBF_STRING_SUB; + } + + if (builtin.object == "table") + { + if (builtin.method == "insert") + return LBF_TABLE_INSERT; + if (builtin.method == "unpack") + return LBF_TABLE_UNPACK; + } + + if (options.vectorCtor) + { + if (options.vectorLib) + { + if (builtin.object == options.vectorLib && builtin.method == options.vectorCtor) + return LBF_VECTOR; + } + else + { + if (builtin.isGlobal(options.vectorCtor)) + return LBF_VECTOR; + } + } + + return -1; + } + + BytecodeBuilder& bytecode; + + CompileOptions options; + + DenseHashMap functions; + DenseHashMap locals; + DenseHashMap globals; + DenseHashMap constants; + DenseHashMap> predictedTableSize; + + unsigned int regTop = 0; + unsigned int stackSize = 0; + + bool getfenvUsed = false; + bool setfenvUsed = false; + + std::vector localStack; + std::vector upvals; + std::vector loopJumps; + std::vector loops; +}; + +void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) +{ + Compiler compiler(bytecode, options); + + // since access to some global objects may result in values that change over time, we block table imports + for (const char* global : kSpecialGlobals) + { + AstName name = names.get(global); + + if (name.value) + compiler.globals[name].special = true; + } + + // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written + Compiler::AssignmentVisitor assignmentVisitor(&compiler); + root->visit(&assignmentVisitor); + + // this visitor traverses the AST to analyze constantness of expressions, filling constants[] and Local::constant/Local::init + if (options.optimizationLevel >= 1) + { + Compiler::ConstantVisitor constantVisitor(&compiler); + root->visit(&constantVisitor); + } + + // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found + if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1) + { + Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); + root->visit(&fenvVisitor); + } + + // gathers all functions with the invariant that all function references are to functions earlier in the list + // for example, function foo() return function() end end will result in two vector entries, [0] = anonymous and [1] = foo + std::vector functions; + Compiler::FunctionVisitor functionVisitor(&compiler, functions); + root->visit(&functionVisitor); + + for (AstExprFunction* expr : functions) + compiler.compileFunction(expr); + + AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), /* self= */ nullptr, + AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); + uint32_t mainid = compiler.compileFunction(&main); + + bytecode.setMainFunction(mainid); + bytecode.finalize(); +} + +void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions) +{ + Allocator allocator; + AstNameTable names(allocator); + ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions); + + if (!result.errors.empty()) + throw ParseErrors(result.errors); + + AstStatBlock* root = result.root; + + compileOrThrow(bytecode, root, names, options); +} + +std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder) +{ + Allocator allocator; + AstNameTable names(allocator); + ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions); + + if (!result.errors.empty()) + { + // Users of this function expect only a single error message + const Luau::ParseError& parseError = result.errors.front(); + std::string error = format(":%d: %s", parseError.getLocation().begin.line + 1, parseError.what()); + + return BytecodeBuilder::getError(error); + } + + try + { + BytecodeBuilder bcb(encoder); + compileOrThrow(bcb, result.root, names, options); + + return bcb.getBytecode(); + } + catch (CompileError& e) + { + std::string error = format(":%d: %s", e.getLocation().begin.line + 1, e.what()); + return BytecodeBuilder::getError(error); + } +} + +} // namespace Luau diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..d63e729 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019-2021 Roblox Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0056870 --- /dev/null +++ b/Makefile @@ -0,0 +1,169 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +MAKEFLAGS+=-r -j8 +COMMA=, + +config=debug + +BUILD=build/$(config) + +AST_SOURCES=$(wildcard Ast/src/*.cpp) +AST_OBJECTS=$(AST_SOURCES:%=$(BUILD)/%.o) +AST_TARGET=$(BUILD)/libluauast.a + +COMPILER_SOURCES=$(wildcard Compiler/src/*.cpp) +COMPILER_OBJECTS=$(COMPILER_SOURCES:%=$(BUILD)/%.o) +COMPILER_TARGET=$(BUILD)/libluaucompiler.a + +ANALYSIS_SOURCES=$(wildcard Analysis/src/*.cpp) +ANALYSIS_OBJECTS=$(ANALYSIS_SOURCES:%=$(BUILD)/%.o) +ANALYSIS_TARGET=$(BUILD)/libluauanalysis.a + +VM_SOURCES=$(wildcard VM/src/*.cpp) +VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o) +VM_TARGET=$(BUILD)/libluauvm.a + +TESTS_SOURCES=$(wildcard tests/*.cpp) +TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) +TESTS_TARGET=$(BUILD)/luau-tests + +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Repl.cpp +REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) +REPL_CLI_TARGET=$(BUILD)/luau + +ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Analyze.cpp +ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) +ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze + +FUZZ_SOURCES=$(wildcard fuzz/*.cpp) +FUZZ_OBJECTS=$(FUZZ_SOURCES:%=$(BUILD)/%.o) + +TESTS_ARGS= +ifneq ($(flags),) + TESTS_ARGS+=--fflags=$(flags) +endif + +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) + +# common flags +CXXFLAGS=-g -Wall -Werror +LDFLAGS= + +CXXFLAGS+=-Wno-unused # temporary, for older gcc versions + +# configuration-specific flags +ifeq ($(config),release) + CXXFLAGS+=-O2 -DNDEBUG +endif + +ifeq ($(config),coverage) + CXXFLAGS+=-fprofile-instr-generate -fcoverage-mapping + LDFLAGS+=-fprofile-instr-generate +endif + +ifeq ($(config),sanitize) + CXXFLAGS+=-fsanitize=address -O1 + LDFLAGS+=-fsanitize=address +endif + +ifeq ($(config),analyze) + CXXFLAGS+=--analyze +endif + +ifeq ($(config),fuzz) + CXX=clang++ # our fuzzing infra relies on llvm fuzzer + CXXFLAGS+=-fsanitize=address,fuzzer -Ibuild/libprotobuf-mutator -Ibuild/libprotobuf-mutator/external.protobuf/include -O2 + LDFLAGS+=-fsanitize=address,fuzzer +endif + +# target-specific flags +$(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include +$(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include +$(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include +$(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -Iextern +$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern +$(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern +$(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include + +$(REPL_CLI_TARGET): LDFLAGS+=-lpthread +fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a + +# pseudo targets +.PHONY: all test clean coverage format luau-size + +all: $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET) $(TESTS_TARGET) + +test: $(TESTS_TARGET) + $(TESTS_TARGET) $(TESTS_ARGS) + +clean: + rm -rf $(BUILD) + +coverage: $(TESTS_TARGET) + $(TESTS_TARGET) --fflags=true + mv default.profraw default-flags.profraw + $(TESTS_TARGET) + llvm-profdata merge default.profraw default-flags.profraw -o default.profdata + rm default.profraw default-flags.profraw + llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests + llvm-cov report -ignore-filename-regex=\(tests\|extern\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests + +format: + find . -name '*.h' -or -name '*.cpp' | xargs clang-format -i + +luau-size: luau + nm --print-size --demangle luau | grep ' t void luau_execute' | awk -F ' ' '{sum += strtonum("0x" $$2)} END {print sum " interpreter" }' + nm --print-size --demangle luau | grep ' t luauF_' | awk -F ' ' '{sum += strtonum("0x" $$2)} END {print sum " builtins" }' + +# executable target aliases +luau: $(REPL_CLI_TARGET) + cp $^ $@ + +luau-analyze: $(ANALYZE_CLI_TARGET) + cp $^ $@ + +# executable targets +$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) +$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) +$(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) + +$(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET): + $(CXX) $^ $(LDFLAGS) -o $@ + +# executable targets for fuzzing +fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) +fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator +fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator + +fuzz-%: + $(CXX) $^ $(LDFLAGS) -o $@ + +# static library targets +$(AST_TARGET): $(AST_OBJECTS) +$(COMPILER_TARGET): $(COMPILER_OBJECTS) +$(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS) +$(VM_TARGET): $(VM_OBJECTS) + +$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(VM_TARGET): + ar rcs $@ $^ + +# object file targets +$(BUILD)/%.cpp.o: %.cpp + @mkdir -p $(dir $@) + $(CXX) $< $(CXXFLAGS) -c -MMD -MP -o $@ + +# protobuf fuzzer setup +fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator + cd fuzz && ../build/libprotobuf-mutator/external.protobuf/bin/protoc luau.proto --cpp_out=. + mv fuzz/luau.pb.cc fuzz/luau.pb.cpp + +$(BUILD)/fuzz/proto.cpp.o: build/libprotobuf-mutator +$(BUILD)/fuzz/protoprint.cpp.o: build/libprotobuf-mutator + +build/libprotobuf-mutator: + git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator + CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator -D CMAKE_BUILD_TYPE=Release -D LIB_PROTO_MUTATOR_DOWNLOAD_PROTOBUF=ON -D LIB_PROTO_MUTATOR_TESTING=OFF + make -C build/libprotobuf-mutator -j8 + +# picks up include dependencies for all object files +-include $(OBJECTS:.o=.d) diff --git a/Sources.cmake b/Sources.cmake new file mode 100644 index 0000000..6f96f6a --- /dev/null +++ b/Sources.cmake @@ -0,0 +1,215 @@ +# Luau.Ast Sources +target_sources(Luau.Ast PRIVATE + Ast/include/Luau/Ast.h + Ast/include/Luau/Common.h + Ast/include/Luau/Confusables.h + Ast/include/Luau/DenseHash.h + Ast/include/Luau/Lexer.h + Ast/include/Luau/Location.h + Ast/include/Luau/ParseOptions.h + Ast/include/Luau/Parser.h + Ast/include/Luau/StringUtils.h + + Ast/src/Ast.cpp + Ast/src/Confusables.cpp + Ast/src/Lexer.cpp + Ast/src/Location.cpp + Ast/src/Parser.cpp + Ast/src/StringUtils.cpp +) + +# Luau.Compiler Sources +target_sources(Luau.Compiler PRIVATE + Compiler/include/Luau/Bytecode.h + Compiler/include/Luau/BytecodeBuilder.h + Compiler/include/Luau/Compiler.h + + Compiler/src/BytecodeBuilder.cpp + Compiler/src/Compiler.cpp +) + +# Luau.Analysis Sources +target_sources(Luau.Analysis PRIVATE + Analysis/include/Luau/AstQuery.h + Analysis/include/Luau/Autocomplete.h + Analysis/include/Luau/BuiltinDefinitions.h + Analysis/include/Luau/Config.h + Analysis/include/Luau/Documentation.h + Analysis/include/Luau/Error.h + Analysis/include/Luau/FileResolver.h + Analysis/include/Luau/Frontend.h + Analysis/include/Luau/IostreamHelpers.h + Analysis/include/Luau/JsonEncoder.h + Analysis/include/Luau/Linter.h + Analysis/include/Luau/Module.h + Analysis/include/Luau/ModuleResolver.h + Analysis/include/Luau/Predicate.h + Analysis/include/Luau/RecursionCounter.h + Analysis/include/Luau/RequireTracer.h + Analysis/include/Luau/Substitution.h + Analysis/include/Luau/Symbol.h + Analysis/include/Luau/TopoSortStatements.h + Analysis/include/Luau/ToString.h + Analysis/include/Luau/Transpiler.h + Analysis/include/Luau/TxnLog.h + Analysis/include/Luau/TypeAttach.h + Analysis/include/Luau/TypedAllocator.h + Analysis/include/Luau/TypeInfer.h + Analysis/include/Luau/TypePack.h + Analysis/include/Luau/TypeUtils.h + Analysis/include/Luau/TypeVar.h + Analysis/include/Luau/Unifiable.h + Analysis/include/Luau/Unifier.h + Analysis/include/Luau/Variant.h + Analysis/include/Luau/VisitTypeVar.h + + Analysis/src/AstQuery.cpp + Analysis/src/Autocomplete.cpp + Analysis/src/BuiltinDefinitions.cpp + Analysis/src/Config.cpp + Analysis/src/Error.cpp + Analysis/src/Frontend.cpp + Analysis/src/IostreamHelpers.cpp + Analysis/src/JsonEncoder.cpp + Analysis/src/Linter.cpp + Analysis/src/Module.cpp + Analysis/src/Predicate.cpp + Analysis/src/RequireTracer.cpp + Analysis/src/Substitution.cpp + Analysis/src/Symbol.cpp + Analysis/src/TopoSortStatements.cpp + Analysis/src/ToString.cpp + Analysis/src/Transpiler.cpp + Analysis/src/TxnLog.cpp + Analysis/src/TypeAttach.cpp + Analysis/src/TypedAllocator.cpp + Analysis/src/TypeInfer.cpp + Analysis/src/TypePack.cpp + Analysis/src/TypeUtils.cpp + Analysis/src/TypeVar.cpp + Analysis/src/Unifiable.cpp + Analysis/src/Unifier.cpp + Analysis/src/EmbeddedBuiltinDefinitions.cpp +) + +# Luau.VM Sources +target_sources(Luau.VM PRIVATE + VM/include/lua.h + VM/include/luaconf.h + VM/include/lualib.h + + VM/src/lapi.cpp + VM/src/laux.cpp + VM/src/lbaselib.cpp + VM/src/lbitlib.cpp + VM/src/lbuiltins.cpp + VM/src/lcorolib.cpp + VM/src/ldblib.cpp + VM/src/ldebug.cpp + VM/src/ldo.cpp + VM/src/lfunc.cpp + VM/src/lgc.cpp + VM/src/linit.cpp + VM/src/lmathlib.cpp + VM/src/lmem.cpp + VM/src/lobject.cpp + VM/src/loslib.cpp + VM/src/lperf.cpp + VM/src/lstate.cpp + VM/src/lstring.cpp + VM/src/lstrlib.cpp + VM/src/ltable.cpp + VM/src/ltablib.cpp + VM/src/ltm.cpp + VM/src/lutf8lib.cpp + VM/src/lvmexecute.cpp + VM/src/lvmload.cpp + VM/src/lvmutils.cpp + VM/src/lapi.h + VM/src/lbuiltins.h + VM/src/lbytecode.h + VM/src/lcommon.h + VM/src/ldebug.h + VM/src/ldo.h + VM/src/lfunc.h + VM/src/lgc.h + VM/src/lmem.h + VM/src/lnumutils.h + VM/src/lobject.h + VM/src/lstate.h + VM/src/lstring.h + VM/src/ltable.h + VM/src/ltm.h + VM/src/lvm.h +) + +if(TARGET Luau.Repl.CLI) + # Luau.Repl.CLI Sources + target_sources(Luau.Repl.CLI PRIVATE + CLI/FileUtils.h + CLI/FileUtils.cpp + CLI/Profiler.h + CLI/Profiler.cpp + CLI/Repl.cpp) +endif() + +if(TARGET Luau.Analyze.CLI) + # Luau.Analyze.CLI Sources + target_sources(Luau.Analyze.CLI PRIVATE + CLI/FileUtils.h + CLI/FileUtils.cpp + CLI/Analyze.cpp) +endif() + +if(TARGET Luau.UnitTest) + # Luau.UnitTest Sources + target_sources(Luau.UnitTest PRIVATE + tests/Fixture.h + tests/IostreamOptional.h + tests/ScopedFlags.h + tests/Fixture.cpp + tests/AstQuery.test.cpp + tests/AstVisitor.test.cpp + tests/Autocomplete.test.cpp + tests/BuiltinDefinitions.test.cpp + tests/Compiler.test.cpp + tests/Config.test.cpp + tests/Error.test.cpp + tests/Frontend.test.cpp + tests/JsonEncoder.test.cpp + tests/Linter.test.cpp + tests/Module.test.cpp + tests/NonstrictMode.test.cpp + tests/Parser.test.cpp + tests/Predicate.test.cpp + tests/RequireTracer.test.cpp + tests/StringUtils.test.cpp + tests/Symbol.test.cpp + tests/TopoSort.test.cpp + tests/ToString.test.cpp + tests/Transpiler.test.cpp + tests/TypeInfer.annotations.test.cpp + tests/TypeInfer.builtins.test.cpp + tests/TypeInfer.classes.test.cpp + tests/TypeInfer.definitions.test.cpp + tests/TypeInfer.generics.test.cpp + tests/TypeInfer.intersectionTypes.test.cpp + tests/TypeInfer.provisional.test.cpp + tests/TypeInfer.refinements.test.cpp + tests/TypeInfer.tables.test.cpp + tests/TypeInfer.test.cpp + tests/TypeInfer.tryUnify.test.cpp + tests/TypeInfer.typePacks.cpp + tests/TypeInfer.unionTypes.test.cpp + tests/TypePack.test.cpp + tests/TypeVar.test.cpp + tests/Variant.test.cpp + tests/main.cpp) +endif() + +if(TARGET Luau.Conformance) + # Luau.Conformance Sources + target_sources(Luau.Conformance PRIVATE + tests/Conformance.test.cpp + tests/main.cpp) +endif() diff --git a/VM/include/lua.h b/VM/include/lua.h new file mode 100644 index 0000000..2f93ad9 --- /dev/null +++ b/VM/include/lua.h @@ -0,0 +1,385 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include +#include +#include + +#include "luaconf.h" + + + + +/* option for multiple returns in `lua_pcall' and `lua_call' */ +#define LUA_MULTRET (-1) + +/* +** pseudo-indices +*/ +#define LUA_REGISTRYINDEX (-10000) +#define LUA_ENVIRONINDEX (-10001) +#define LUA_GLOBALSINDEX (-10002) +#define lua_upvalueindex(i) (LUA_GLOBALSINDEX - (i)) + +/* thread status; 0 is OK */ +enum lua_Status +{ + LUA_OK = 0, + LUA_YIELD, + LUA_ERRRUN, + LUA_ERRSYNTAX, + LUA_ERRMEM, + LUA_ERRERR, + LUA_BREAK, /* yielded for a debug breakpoint */ +}; + +typedef struct lua_State lua_State; + +typedef int (*lua_CFunction)(lua_State* L); +typedef int (*lua_Continuation)(lua_State* L, int status); + +/* +** prototype for memory-allocation functions +*/ + +typedef void* (*lua_Alloc)(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize); + +/* non-return type */ +#define l_noret void LUA_NORETURN + +/* +** basic types +*/ +#define LUA_TNONE (-1) + +/* + * WARNING: if you change the order of this enumeration, + * grep "ORDER TYPE" + */ +// clang-format off +enum lua_Type +{ + LUA_TNIL = 0, /* must be 0 due to lua_isnoneornil */ + LUA_TBOOLEAN = 1, /* must be 1 due to l_isfalse */ + + + LUA_TLIGHTUSERDATA, + LUA_TNUMBER, + LUA_TVECTOR, + + LUA_TSTRING, /* all types above this must be value types, all types below this must be GC types - see iscollectable */ + + + LUA_TTABLE, + LUA_TFUNCTION, + LUA_TUSERDATA, + LUA_TTHREAD, + + /* values below this line are used in GCObject tags but may never show up in TValue type tags */ + LUA_TPROTO, + LUA_TUPVAL, + LUA_TDEADKEY, + + /* the count of TValue type tags */ + LUA_T_COUNT = LUA_TPROTO +}; +// clang-format on + +/* type of numbers in Luau */ +typedef double lua_Number; + +/* type for integer functions */ +typedef int lua_Integer; + +/* unsigned integer type */ +typedef unsigned lua_Unsigned; + +/* +** state manipulation +*/ +LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud); +LUA_API void lua_close(lua_State* L); +LUA_API lua_State* lua_newthread(lua_State* L); +LUA_API lua_State* lua_mainthread(lua_State* L); + +/* +** basic stack manipulation +*/ +LUA_API int lua_gettop(lua_State* L); +LUA_API void lua_settop(lua_State* L, int idx); +LUA_API void lua_pushvalue(lua_State* L, int idx); +LUA_API void lua_remove(lua_State* L, int idx); +LUA_API void lua_insert(lua_State* L, int idx); +LUA_API void lua_replace(lua_State* L, int idx); +LUA_API int lua_checkstack(lua_State* L, int sz); +LUA_API void lua_rawcheckstack(lua_State* L, int sz); /* allows for unlimited stack frames */ + +LUA_API void lua_xmove(lua_State* from, lua_State* to, int n); +LUA_API void lua_xpush(lua_State* from, lua_State* to, int idx); + +/* +** access functions (stack -> C) +*/ + +LUA_API int lua_isnumber(lua_State* L, int idx); +LUA_API int lua_isstring(lua_State* L, int idx); +LUA_API int lua_iscfunction(lua_State* L, int idx); +LUA_API int lua_isLfunction(lua_State* L, int idx); +LUA_API int lua_isuserdata(lua_State* L, int idx); +LUA_API int lua_type(lua_State* L, int idx); +LUA_API const char* lua_typename(lua_State* L, int tp); + +LUA_API int lua_equal(lua_State* L, int idx1, int idx2); +LUA_API int lua_rawequal(lua_State* L, int idx1, int idx2); +LUA_API int lua_lessthan(lua_State* L, int idx1, int idx2); + +LUA_API double lua_tonumberx(lua_State* L, int idx, int* isnum); +LUA_API int lua_tointegerx(lua_State* L, int idx, int* isnum); +LUA_API unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum); +LUA_API const float* lua_tovector(lua_State* L, int idx); +LUA_API int lua_toboolean(lua_State* L, int idx); +LUA_API const char* lua_tolstring(lua_State* L, int idx, size_t* len); +LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom); +LUA_API const char* lua_namecallatom(lua_State* L, int* atom); +LUA_API int lua_objlen(lua_State* L, int idx); +LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx); +LUA_API void* lua_touserdata(lua_State* L, int idx); +LUA_API void* lua_touserdatatagged(lua_State* L, int idx, int tag); +LUA_API int lua_userdatatag(lua_State* L, int idx); +LUA_API lua_State* lua_tothread(lua_State* L, int idx); +LUA_API const void* lua_topointer(lua_State* L, int idx); + +/* +** push functions (C -> stack) +*/ +LUA_API void lua_pushnil(lua_State* L); +LUA_API void lua_pushnumber(lua_State* L, double n); +LUA_API void lua_pushinteger(lua_State* L, int n); +LUA_API void lua_pushunsigned(lua_State* L, unsigned n); +LUA_API void lua_pushvector(lua_State* L, float x, float y, float z); +LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); +LUA_API void lua_pushstring(lua_State* L, const char* s); +LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); +LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); +LUA_API void lua_pushcfunction( + lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL); +LUA_API void lua_pushboolean(lua_State* L, int b); +LUA_API void lua_pushlightuserdata(lua_State* L, void* p); +LUA_API int lua_pushthread(lua_State* L); + +/* +** get functions (Lua -> stack) +*/ +LUA_API void lua_gettable(lua_State* L, int idx); +LUA_API void lua_getfield(lua_State* L, int idx, const char* k); +LUA_API void lua_rawgetfield(lua_State* L, int idx, const char* k); +LUA_API void lua_rawget(lua_State* L, int idx); +LUA_API void lua_rawgeti(lua_State* L, int idx, int n); +LUA_API void lua_createtable(lua_State* L, int narr, int nrec); + +LUA_API void lua_setreadonly(lua_State* L, int idx, bool value); +LUA_API int lua_getreadonly(lua_State* L, int idx); +LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value); + +LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); +LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); +LUA_API int lua_getmetatable(lua_State* L, int objindex); +LUA_API void lua_getfenv(lua_State* L, int idx); + +/* +** set functions (stack -> Lua) +*/ +LUA_API void lua_settable(lua_State* L, int idx); +LUA_API void lua_setfield(lua_State* L, int idx, const char* k); +LUA_API void lua_rawset(lua_State* L, int idx); +LUA_API void lua_rawseti(lua_State* L, int idx, int n); +LUA_API int lua_setmetatable(lua_State* L, int objindex); +LUA_API int lua_setfenv(lua_State* L, int idx); + +/* +** `load' and `call' functions (load and run Luau bytecode) +*/ +LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0); +LUA_API void lua_call(lua_State* L, int nargs, int nresults); +LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc); + +/* +** coroutine functions +*/ +LUA_API int lua_yield(lua_State* L, int nresults); +LUA_API int lua_break(lua_State* L); +LUA_API int lua_resume(lua_State* L, lua_State* from, int narg); +LUA_API int lua_resumeerror(lua_State* L, lua_State* from); +LUA_API int lua_status(lua_State* L); +LUA_API int lua_isyieldable(lua_State* L); + +/* +** garbage-collection function and options +*/ + +enum lua_GCOp +{ + LUA_GCSTOP, + LUA_GCRESTART, + LUA_GCCOLLECT, + LUA_GCCOUNT, + LUA_GCISRUNNING, + + // garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation + // explicit GC steps allow to perform some amount of work at custom points to offset the need for GC assists + // note that GC might also be paused for some duration (until bytes allocated meet the threshold) + // if an explicit step is performed during this pause, it will trigger the start of the next collection cycle + LUA_GCSTEP, + + LUA_GCSETGOAL, + LUA_GCSETSTEPMUL, + LUA_GCSETSTEPSIZE, +}; + +LUA_API int lua_gc(lua_State* L, int what, int data); + +/* +** miscellaneous functions +*/ + +LUA_API l_noret lua_error(lua_State* L); + +LUA_API int lua_next(lua_State* L, int idx); + +LUA_API void lua_concat(lua_State* L, int n); + +LUA_API uintptr_t lua_encodepointer(lua_State* L, uintptr_t p); + +LUA_API double lua_clock(); + +LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)); + +/* +** reference system, can be used to pin objects +*/ +#define LUA_NOREF -1 +#define LUA_REFNIL 0 + +LUA_API int lua_ref(lua_State* L, int idx); +LUA_API void lua_unref(lua_State* L, int ref); + +#define lua_getref(L, ref) lua_rawgeti(L, LUA_REGISTRYINDEX, (ref)) + +/* +** =============================================================== +** some useful macros +** =============================================================== +*/ +#define lua_tonumber(L, i) lua_tonumberx(L, i, NULL) +#define lua_tointeger(L, i) lua_tointegerx(L, i, NULL) +#define lua_tounsigned(L, i) lua_tounsignedx(L, i, NULL) + +#define lua_pop(L, n) lua_settop(L, -(n)-1) + +#define lua_newtable(L) lua_createtable(L, 0, 0) + +#define lua_strlen(L, i) lua_objlen(L, (i)) + +#define lua_isfunction(L, n) (lua_type(L, (n)) == LUA_TFUNCTION) +#define lua_istable(L, n) (lua_type(L, (n)) == LUA_TTABLE) +#define lua_islightuserdata(L, n) (lua_type(L, (n)) == LUA_TLIGHTUSERDATA) +#define lua_isnil(L, n) (lua_type(L, (n)) == LUA_TNIL) +#define lua_isboolean(L, n) (lua_type(L, (n)) == LUA_TBOOLEAN) +#define lua_isthread(L, n) (lua_type(L, (n)) == LUA_TTHREAD) +#define lua_isnone(L, n) (lua_type(L, (n)) == LUA_TNONE) +#define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) + +#define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) + +#define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) +#define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) + +#define lua_tostring(L, i) lua_tolstring(L, (i), NULL) + +#define lua_pushfstring(L, fmt, ...) lua_pushfstringL(L, fmt, ##__VA_ARGS__) + +/* +** {====================================================================== +** Debug API +** ======================================================================= +*/ + +typedef struct lua_Debug lua_Debug; /* activation record */ + +/* Functions to be called by the debugger in specific events */ +typedef void (*lua_Hook)(lua_State* L, lua_Debug* ar); + +LUA_API int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar); +LUA_API int lua_getargument(lua_State* L, int level, int n); +LUA_API const char* lua_getlocal(lua_State* L, int level, int n); +LUA_API const char* lua_setlocal(lua_State* L, int level, int n); +LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); +LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); + +LUA_API void lua_singlestep(lua_State* L, bool singlestep); +LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable); + +/* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ +LUA_API const char* lua_debugtrace(lua_State* L); + +struct lua_Debug +{ + const char* name; /* (n) */ + const char* what; /* (s) `Lua', `C', `main', `tail' */ + const char* source; /* (s) */ + int linedefined; /* (s) */ + int currentline; /* (l) */ + unsigned char nupvals; /* (u) number of upvalues */ + unsigned char nparams; /* (a) number of parameters */ + char isvararg; /* (a) */ + char short_src[LUA_IDSIZE]; /* (s) */ + void* userdata; /* only valid in luau_callhook */ +}; + +/* }====================================================================== */ + +/* Callbacks that can be used to reconfigure behavior of the VM dynamically. + * These are shared between all coroutines. + * + * Note: interrupt is safe to set from an arbitrary thread but all other callbacks + * can only be changed when the VM is not running any code */ +struct lua_Callbacks +{ + void (*interrupt)(lua_State* L, int gc); /* gets called at safepoints (loop back edges, call/ret, gc) if set */ + void (*panic)(lua_State* L, int errcode); /* gets called when an unprotected error is raised (if longjmp is used) */ + + void (*userthread)(lua_State* LP, lua_State* L); /* gets called when L is created (LP == parent) or destroyed (LP == NULL) */ + int16_t (*useratom)(const char* s, size_t l); /* gets called when a string is created; returned atom can be retrieved via tostringatom */ + + void (*debugbreak)(lua_State* L, lua_Debug* ar); /* gets called when BREAK instruction is encountered */ + void (*debugstep)(lua_State* L, lua_Debug* ar); /* gets called after each instruction in single step mode */ + void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ + void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ +}; + +LUA_API lua_Callbacks* lua_callbacks(lua_State* L); + +/****************************************************************************** + * Copyright (c) 2019-2021 Roblox Corporation + * Copyright (C) 1994-2008 Lua.org, PUC-Rio. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ******************************************************************************/ diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h new file mode 100644 index 0000000..aa008a2 --- /dev/null +++ b/VM/include/luaconf.h @@ -0,0 +1,124 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +// When debugging complex issues, consider enabling one of these: +// This will reallocate the stack very aggressively at every opportunity; use this with asan to catch stale stack pointers +// #define HARDSTACKTESTS 1 +// This will call GC validation very aggressively at every incremental GC step; use this with caution as it's SLOW +// #define HARDMEMTESTS 1 +// This will call GC validation very aggressively at every GC opportunity; use this with caution as it's VERY SLOW +// #define HARDMEMTESTS 2 + +// To force MSVC2017+ to generate SSE2 code for some stdlib functions we need to locally enable /fp:fast +// Note that /fp:fast changes the semantics of floating point comparisons so this is only safe to do for functions without ones +#if defined(_MSC_VER) && !defined(__clang__) +#define LUAU_FASTMATH_BEGIN __pragma(float_control(precise, off, push)) +#define LUAU_FASTMATH_END __pragma(float_control(pop)) +#else +#define LUAU_FASTMATH_BEGIN +#define LUAU_FASTMATH_END +#endif + +// Used on functions that have a printf-like interface to validate them statically +#if defined(__GNUC__) +#define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) +#else +#define LUA_PRINTF_ATTR(fmt, arg) +#endif + +#ifdef _MSC_VER +#define LUA_NORETURN __declspec(noreturn) +#else +#define LUA_NORETURN __attribute__((__noreturn__)) +#endif + +/* Can be used to reconfigure visibility/exports for public APIs */ +#define LUA_API extern +#define LUALIB_API LUA_API + +/* Can be used to reconfigure visibility for internal APIs */ +#if defined(__GNUC__) +#define LUAI_FUNC __attribute__((visibility("hidden"))) extern +#define LUAI_DATA LUAI_FUNC +#else +#define LUAI_FUNC extern +#define LUAI_DATA extern +#endif + +/* Can be used to reconfigure internal error handling to use longjmp instead of C++ EH */ +#define LUA_USE_LONGJMP 0 + +/* LUA_IDSIZE gives the maximum size for the description of the source */ +#define LUA_IDSIZE 256 + +/* +@@ LUAI_GCGOAL defines the desired top heap size in relation to the live heap +@* size at the end of the GC cycle +** CHANGE it if you want the GC to run faster or slower (higher values +** mean larger GC pauses which mean slower collection.) You can also change +** this value dynamically. +*/ +#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ + +/* +@@ LUAI_GCSTEPMUL / LUAI_GCSTEPSIZE define the default speed of garbage collection +@* relative to memory allocation. +** Every LUAI_GCSTEPSIZE KB allocated, incremental collector collects LUAI_GCSTEPSIZE +** times LUAI_GCSTEPMUL% bytes. +** CHANGE it if you want to change the granularity of the garbage +** collection. +*/ +#define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ +#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ + +/* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ +#define LUA_MINSTACK 20 + +/* LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use */ +#define LUAI_MAXCSTACK 8000 + +/* LUAI_MAXCALLS limits the number of nested calls */ +#define LUAI_MAXCALLS 20000 + +/* LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size */ +#define LUAI_MAXCCALLS 200 + +/* buffer size used for on-stack string operations; this limit depends on native stack size */ +#define LUA_BUFFERSIZE 512 + +/* number of valid Lua userdata tags */ +#define LUA_UTAG_LIMIT 128 + +/* upper bound for number of size classes used by page allocator */ +#define LUA_SIZECLASSES 32 + +/* available number of separate memory categories */ +#define LUA_MEMORY_CATEGORIES 256 + +/* minimum size for the string table (must be power of 2) */ +#define LUA_MINSTRTABSIZE 32 + +/* maximum number of captures supported by pattern matching */ +#define LUA_MAXCAPTURES 32 + +/* }================================================================== */ + +/* Default number printing format and the string length limit */ +#define LUA_NUMBER_FMT "%.14g" +#define LUAI_MAXNUMBER2STR 32 /* 16 digits, sign, point, and \0 */ + +/* +@@ LUAI_USER_ALIGNMENT_T is a type that requires maximum alignment. +** CHANGE it if your system requires alignments larger than double. (For +** instance, if your system supports long doubles and they must be +** aligned in 16-byte boundaries, then you should add long double in the +** union.) Probably you do not need to change this. +*/ +#define LUAI_USER_ALIGNMENT_T \ + union \ + { \ + double u; \ + void* s; \ + long l; \ + } diff --git a/VM/include/lualib.h b/VM/include/lualib.h new file mode 100644 index 0000000..7a09ae9 --- /dev/null +++ b/VM/include/lualib.h @@ -0,0 +1,129 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lua.h" + +#define luaL_error(L, fmt, ...) luaL_errorL(L, fmt, ##__VA_ARGS__) +#define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname) +#define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg) + +typedef struct luaL_Reg +{ + const char* name; + lua_CFunction func; +} luaL_Reg; + +LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l); +LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e); +LUALIB_API int luaL_callmeta(lua_State* L, int obj, const char* e); +LUALIB_API l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname); +LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg); +LUALIB_API const char* luaL_checklstring(lua_State* L, int numArg, size_t* l); +LUALIB_API const char* luaL_optlstring(lua_State* L, int numArg, const char* def, size_t* l); +LUALIB_API double luaL_checknumber(lua_State* L, int numArg); +LUALIB_API double luaL_optnumber(lua_State* L, int nArg, double def); + +LUALIB_API int luaL_checkinteger(lua_State* L, int numArg); +LUALIB_API int luaL_optinteger(lua_State* L, int nArg, int def); +LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int numArg); +LUALIB_API unsigned luaL_optunsigned(lua_State* L, int numArg, unsigned def); + +LUALIB_API void luaL_checkstack(lua_State* L, int sz, const char* msg); +LUALIB_API void luaL_checktype(lua_State* L, int narg, int t); +LUALIB_API void luaL_checkany(lua_State* L, int narg); + +LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname); +LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname); + +LUALIB_API void luaL_where(lua_State* L, int lvl); +LUALIB_API LUA_PRINTF_ATTR(2, 3) l_noret luaL_errorL(lua_State* L, const char* fmt, ...); + +LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]); + +LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len); + +LUALIB_API lua_State* luaL_newstate(void); + +LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint); + +/* +** =============================================================== +** some useful macros +** =============================================================== +*/ + +#define luaL_argcheck(L, cond, arg, extramsg) ((void)((cond) ? (void)0 : luaL_argerror(L, arg, extramsg))) +#define luaL_argexpected(L, cond, arg, tname) ((void)((cond) ? (void)0 : luaL_typeerror(L, arg, tname))) + +#define luaL_checkstring(L, n) (luaL_checklstring(L, (n), NULL)) +#define luaL_optstring(L, n, d) (luaL_optlstring(L, (n), (d), NULL)) + +#define luaL_typename(L, i) lua_typename(L, lua_type(L, (i))) + +#define luaL_getmetatable(L, n) (lua_getfield(L, LUA_REGISTRYINDEX, (n))) + +#define luaL_opt(L, f, n, d) (lua_isnoneornil(L, (n)) ? (d) : f(L, (n))) + +/* generic buffer manipulation */ + +struct luaL_Buffer +{ + char* p; // current position in buffer + char* end; // end of the current buffer + lua_State* L; + struct TString* storage; + char buffer[LUA_BUFFERSIZE]; +}; + +// when internal buffer storage is exhaused, a mutable string value 'storage' will be placed on the stack +// in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) +// with the exception of luaL_addvalue that expects the value at the top and string buffer further away (top-2) +// functions that accept a 'boxloc' support string buffer placement at any location in the stack +// all the buffer users we have in Luau match this pattern, but it's something to keep in mind for new uses of buffers + +#define luaL_addchar(B, c) ((void)((B)->p < (B)->end || luaL_extendbuffer(B, 1, -1)), (*(B)->p++ = (char)(c))) +#define luaL_addstring(B, s) luaL_addlstring(B, s, strlen(s)) + +LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B); +LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size); +LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc); +LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc); +LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t l); +LUALIB_API void luaL_addvalue(luaL_Buffer* B); +LUALIB_API void luaL_pushresult(luaL_Buffer* B); +LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size); + +/* builtin libraries */ +LUALIB_API int luaopen_base(lua_State* L); + +#define LUA_COLIBNAME "coroutine" +LUALIB_API int luaopen_coroutine(lua_State* L); + +#define LUA_TABLIBNAME "table" +LUALIB_API int luaopen_table(lua_State* L); + +#define LUA_OSLIBNAME "os" +LUALIB_API int luaopen_os(lua_State* L); + +#define LUA_STRLIBNAME "string" +LUALIB_API int luaopen_string(lua_State* L); + +#define LUA_BITLIBNAME "bit32" +LUALIB_API int luaopen_bit32(lua_State* L); + +#define LUA_UTF8LIBNAME "utf8" +LUALIB_API int luaopen_utf8(lua_State* L); + +#define LUA_MATHLIBNAME "math" +LUALIB_API int luaopen_math(lua_State* L); + +#define LUA_DBLIBNAME "debug" +LUALIB_API int luaopen_debug(lua_State* L); + +/* open all builtin libraries */ +LUALIB_API void luaL_openlibs(lua_State* L); + +/* sandbox libraries and globals */ +LUALIB_API void luaL_sandbox(lua_State* L); +LUALIB_API void luaL_sandboxthread(lua_State* L); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp new file mode 100644 index 0000000..0131536 --- /dev/null +++ b/VM/src/lapi.cpp @@ -0,0 +1,1273 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lapi.h" + +#include "lstate.h" +#include "lstring.h" +#include "ltable.h" +#include "lfunc.h" +#include "lgc.h" +#include "ldo.h" +#include "lvm.h" +#include "lnumutils.h" + +#include + +LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) + +const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" + "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" + "$URL: www.lua.org $\n"; + +const char* luau_ident = "$Luau: Copyright (C) 2019-2021 Roblox Corporation $\n" + "$URL: luau-lang.org $\n"; + +#define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) + +#define api_checkvalidindex(L, i) api_check(L, (i) != luaO_nilobject) + +#define api_incr_top(L) \ + { \ + api_check(L, L->top < L->ci->top); \ + L->top++; \ + } + +static Table* getcurrenv(lua_State* L) +{ + if (L->ci == L->base_ci) /* no enclosing function? */ + return hvalue(gt(L)); /* use global table as environment */ + else + { + Closure* func = curr_func(L); + return func->env; + } +} + +static LUAU_NOINLINE TValue* index2adrslow(lua_State* L, int idx) +{ + api_check(L, idx <= 0); + if (idx > LUA_REGISTRYINDEX) + { + api_check(L, idx != 0 && -idx <= L->top - L->base); + return L->top + idx; + } + else + switch (idx) + { /* pseudo-indices */ + case LUA_REGISTRYINDEX: + return registry(L); + case LUA_ENVIRONINDEX: + { + sethvalue(L, &L->env, getcurrenv(L)); + return &L->env; + } + case LUA_GLOBALSINDEX: + return gt(L); + default: + { + Closure* func = curr_func(L); + idx = LUA_GLOBALSINDEX - idx; + return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); + } + } +} + +static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) +{ + if (idx > 0) + { + TValue* o = L->base + (idx - 1); + api_check(L, idx <= L->ci->top - L->base); + if (o >= L->top) + return cast_to(TValue*, luaO_nilobject); + else + return o; + } + else + { + return index2adrslow(L, idx); + } +} + +const TValue* luaA_toobject(lua_State* L, int idx) +{ + StkId p = index2adr(L, idx); + return (p == luaO_nilobject) ? NULL : p; +} + +void luaA_pushobject(lua_State* L, const TValue* o) +{ + setobj2s(L, L->top, o); + api_incr_top(L); +} + +int lua_checkstack(lua_State* L, int size) +{ + int res = 1; + if (size > LUAI_MAXCSTACK || (L->top - L->base + size) > LUAI_MAXCSTACK) + res = 0; /* stack overflow */ + else if (size > 0) + { + luaD_checkstack(L, size); + expandstacklimit(L, L->top + size); + } + return res; +} + +void lua_rawcheckstack(lua_State* L, int size) +{ + luaD_checkstack(L, size); + expandstacklimit(L, L->top + size); + return; +} + +void lua_xmove(lua_State* from, lua_State* to, int n) +{ + if (from == to) + return; + api_checknelems(from, n); + api_check(from, from->global == to->global); + api_check(from, to->ci->top - to->top >= n); + luaC_checkthreadsleep(to); + + StkId ttop = to->top; + StkId ftop = from->top - n; + for (int i = 0; i < n; i++) + setobj2s(to, ttop + i, ftop + i); + + from->top = ftop; + to->top = ttop + n; + + return; +} + +void lua_xpush(lua_State* from, lua_State* to, int idx) +{ + api_check(from, from->global == to->global); + luaC_checkthreadsleep(to); + setobj2s(to, to->top, index2adr(from, idx)); + api_incr_top(to); + return; +} + +lua_State* lua_newthread(lua_State* L) +{ + luaC_checkGC(L); + luaC_checkthreadsleep(L); + lua_State* L1 = luaE_newthread(L); + setthvalue(L, L->top, L1); + api_incr_top(L); + global_State* g = L->global; + if (g->cb.userthread) + g->cb.userthread(L, L1); + return L1; +} + +lua_State* lua_mainthread(lua_State* L) +{ + return L->global->mainthread; +} + +/* +** basic stack manipulation +*/ + +int lua_gettop(lua_State* L) +{ + return cast_int(L->top - L->base); +} + +void lua_settop(lua_State* L, int idx) +{ + if (idx >= 0) + { + api_check(L, idx <= L->stack_last - L->base); + while (L->top < L->base + idx) + setnilvalue(L->top++); + L->top = L->base + idx; + } + else + { + api_check(L, -(idx + 1) <= (L->top - L->base)); + L->top += idx + 1; /* `subtract' index (index is negative) */ + } + return; +} + +void lua_remove(lua_State* L, int idx) +{ + StkId p = index2adr(L, idx); + api_checkvalidindex(L, p); + while (++p < L->top) + setobjs2s(L, p - 1, p); + L->top--; + return; +} + +void lua_insert(lua_State* L, int idx) +{ + luaC_checkthreadsleep(L); + StkId p = index2adr(L, idx); + api_checkvalidindex(L, p); + for (StkId q = L->top; q > p; q--) + setobjs2s(L, q, q - 1); + setobjs2s(L, p, L->top); + return; +} + +void lua_replace(lua_State* L, int idx) +{ + /* explicit test for incompatible code */ + if (idx == LUA_ENVIRONINDEX && L->ci == L->base_ci) + luaG_runerror(L, "no calling environment"); + api_checknelems(L, 1); + luaC_checkthreadsleep(L); + StkId o = index2adr(L, idx); + api_checkvalidindex(L, o); + if (idx == LUA_ENVIRONINDEX) + { + Closure* func = curr_func(L); + api_check(L, ttistable(L->top - 1)); + func->env = hvalue(L->top - 1); + luaC_barrier(L, func, L->top - 1); + } + else + { + setobj(L, o, L->top - 1); + if (idx < LUA_GLOBALSINDEX) /* function upvalue? */ + luaC_barrier(L, curr_func(L), L->top - 1); + } + L->top--; + return; +} + +void lua_pushvalue(lua_State* L, int idx) +{ + luaC_checkthreadsleep(L); + StkId o = index2adr(L, idx); + setobj2s(L, L->top, o); + api_incr_top(L); + return; +} + +/* +** access functions (stack -> C) +*/ + +int lua_type(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + return (o == luaO_nilobject) ? LUA_TNONE : ttype(o); +} + +const char* lua_typename(lua_State* L, int t) +{ + return (t == LUA_TNONE) ? "no value" : luaT_typenames[t]; +} + +int lua_iscfunction(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + return iscfunction(o); +} + +int lua_isLfunction(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + return isLfunction(o); +} + +int lua_isnumber(lua_State* L, int idx) +{ + TValue n; + const TValue* o = index2adr(L, idx); + return tonumber(o, &n); +} + +int lua_isstring(lua_State* L, int idx) +{ + int t = lua_type(L, idx); + return (t == LUA_TSTRING || t == LUA_TNUMBER); +} + +int lua_isuserdata(lua_State* L, int idx) +{ + const TValue* o = index2adr(L, idx); + return (ttisuserdata(o) || ttislightuserdata(o)); +} + +int lua_rawequal(lua_State* L, int index1, int index2) +{ + StkId o1 = index2adr(L, index1); + StkId o2 = index2adr(L, index2); + return (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaO_rawequalObj(o1, o2); +} + +int lua_equal(lua_State* L, int index1, int index2) +{ + StkId o1, o2; + int i; + o1 = index2adr(L, index1); + o2 = index2adr(L, index2); + i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : equalobj(L, o1, o2); + return i; +} + +int lua_lessthan(lua_State* L, int index1, int index2) +{ + StkId o1, o2; + int i; + o1 = index2adr(L, index1); + o2 = index2adr(L, index2); + i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaV_lessthan(L, o1, o2); + return i; +} + +double lua_tonumberx(lua_State* L, int idx, int* isnum) +{ + TValue n; + const TValue* o = index2adr(L, idx); + if (tonumber(o, &n)) + { + if (isnum) + *isnum = 1; + return nvalue(o); + } + else + { + if (isnum) + *isnum = 0; + return 0; + } +} + +int lua_tointegerx(lua_State* L, int idx, int* isnum) +{ + TValue n; + const TValue* o = index2adr(L, idx); + if (tonumber(o, &n)) + { + int res; + double num = nvalue(o); + luai_num2int(res, num); + if (isnum) + *isnum = 1; + return res; + } + else + { + if (isnum) + *isnum = 0; + return 0; + } +} + +unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) +{ + TValue n; + const TValue* o = index2adr(L, idx); + if (tonumber(o, &n)) + { + unsigned res; + double num = nvalue(o); + luai_num2unsigned(res, num); + if (isnum) + *isnum = 1; + return res; + } + else + { + if (isnum) + *isnum = 0; + return 0; + } +} + +int lua_toboolean(lua_State* L, int idx) +{ + const TValue* o = index2adr(L, idx); + return !l_isfalse(o); +} + +const char* lua_tolstring(lua_State* L, int idx, size_t* len) +{ + StkId o = index2adr(L, idx); + if (!ttisstring(o)) + { + luaC_checkthreadsleep(L); + if (!luaV_tostring(L, o)) + { /* conversion failed? */ + if (len != NULL) + *len = 0; + return NULL; + } + luaC_checkGC(L); + o = index2adr(L, idx); /* previous call may reallocate the stack */ + } + if (len != NULL) + *len = tsvalue(o)->len; + return svalue(o); +} + +const char* lua_tostringatom(lua_State* L, int idx, int* atom) +{ + StkId o = index2adr(L, idx); + if (!ttisstring(o)) + return NULL; + const TString* s = tsvalue(o); + if (atom) + *atom = s->atom; + return getstr(s); +} + +const char* lua_namecallatom(lua_State* L, int* atom) +{ + const TString* s = L->namecall; + if (!s) + return NULL; + if (atom) + *atom = s->atom; + return getstr(s); +} + +const float* lua_tovector(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + if (!ttisvector(o)) + { + return NULL; + } + return vvalue(o); +} + +int lua_objlen(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + switch (ttype(o)) + { + case LUA_TSTRING: + return tsvalue(o)->len; + case LUA_TUSERDATA: + return uvalue(o)->len; + case LUA_TTABLE: + return luaH_getn(hvalue(o)); + case LUA_TNUMBER: + { + int l = (luaV_tostring(L, o) ? tsvalue(o)->len : 0); + return l; + } + default: + return 0; + } +} + +lua_CFunction lua_tocfunction(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + return (!iscfunction(o)) ? NULL : cast_to(lua_CFunction, clvalue(o)->c.f); +} + +void* lua_touserdata(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + switch (ttype(o)) + { + case LUA_TUSERDATA: + return uvalue(o)->data; + case LUA_TLIGHTUSERDATA: + return pvalue(o); + default: + return NULL; + } +} + +void* lua_touserdatatagged(lua_State* L, int idx, int tag) +{ + StkId o = index2adr(L, idx); + return (ttisuserdata(o) && uvalue(o)->tag == tag) ? uvalue(o)->data : NULL; +} + +int lua_userdatatag(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + if (ttisuserdata(o)) + return uvalue(o)->tag; + return -1; +} + +lua_State* lua_tothread(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + return (!ttisthread(o)) ? NULL : thvalue(o); +} + +const void* lua_topointer(lua_State* L, int idx) +{ + StkId o = index2adr(L, idx); + switch (ttype(o)) + { + case LUA_TTABLE: + return hvalue(o); + case LUA_TFUNCTION: + return clvalue(o); + case LUA_TTHREAD: + return thvalue(o); + case LUA_TUSERDATA: + case LUA_TLIGHTUSERDATA: + return lua_touserdata(L, idx); + default: + return NULL; + } +} + +/* +** push functions (C -> stack) +*/ + +void lua_pushnil(lua_State* L) +{ + setnilvalue(L->top); + api_incr_top(L); + return; +} + +void lua_pushnumber(lua_State* L, double n) +{ + setnvalue(L->top, n); + api_incr_top(L); + return; +} + +void lua_pushinteger(lua_State* L, int n) +{ + setnvalue(L->top, cast_num(n)); + api_incr_top(L); + return; +} + +void lua_pushunsigned(lua_State* L, unsigned u) +{ + setnvalue(L->top, cast_num(u)); + api_incr_top(L); + return; +} + +void lua_pushvector(lua_State* L, float x, float y, float z) +{ + setvvalue(L->top, x, y, z); + api_incr_top(L); + return; +} + +void lua_pushlstring(lua_State* L, const char* s, size_t len) +{ + luaC_checkGC(L); + luaC_checkthreadsleep(L); + setsvalue2s(L, L->top, luaS_newlstr(L, s, len)); + api_incr_top(L); + return; +} + +void lua_pushstring(lua_State* L, const char* s) +{ + if (s == NULL) + lua_pushnil(L); + else + lua_pushlstring(L, s, strlen(s)); +} + +const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp) +{ + luaC_checkGC(L); + luaC_checkthreadsleep(L); + const char* ret = luaO_pushvfstring(L, fmt, argp); + return ret; +} + +const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) +{ + luaC_checkGC(L); + luaC_checkthreadsleep(L); + va_list argp; + va_start(argp, fmt); + const char* ret = luaO_pushvfstring(L, fmt, argp); + va_end(argp); + return ret; +} + +void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) +{ + luaC_checkGC(L); + luaC_checkthreadsleep(L); + api_checknelems(L, nup); + Closure* cl = luaF_newCclosure(L, nup, getcurrenv(L)); + cl->c.f = fn; + cl->c.cont = cont; + cl->c.debugname = debugname; + L->top -= nup; + while (nup--) + setobj2n(L, &cl->c.upvals[nup], L->top + nup); + setclvalue(L, L->top, cl); + LUAU_ASSERT(iswhite(obj2gco(cl))); + api_incr_top(L); + return; +} + +void lua_pushboolean(lua_State* L, int b) +{ + setbvalue(L->top, (b != 0)); /* ensure that true is 1 */ + api_incr_top(L); + return; +} + +void lua_pushlightuserdata(lua_State* L, void* p) +{ + setpvalue(L->top, p); + api_incr_top(L); + return; +} + +int lua_pushthread(lua_State* L) +{ + luaC_checkthreadsleep(L); + setthvalue(L, L->top, L); + api_incr_top(L); + return L->global->mainthread == L; +} + +/* +** get functions (Lua -> stack) +*/ + +void lua_gettable(lua_State* L, int idx) +{ + luaC_checkthreadsleep(L); + StkId t = index2adr(L, idx); + api_checkvalidindex(L, t); + luaV_gettable(L, t, L->top - 1, L->top - 1); + return; +} + +void lua_getfield(lua_State* L, int idx, const char* k) +{ + luaC_checkthreadsleep(L); + StkId t = index2adr(L, idx); + api_checkvalidindex(L, t); + TValue key; + setsvalue(L, &key, luaS_new(L, k)); + luaV_gettable(L, t, &key, L->top); + api_incr_top(L); + return; +} + +void lua_rawgetfield(lua_State* L, int idx, const char* k) +{ + luaC_checkthreadsleep(L); + StkId t = index2adr(L, idx); + api_check(L, ttistable(t)); + TValue key; + setsvalue(L, &key, luaS_new(L, k)); + setobj2s(L, L->top, luaH_getstr(hvalue(t), tsvalue(&key))); + api_incr_top(L); + return; +} + +void lua_rawget(lua_State* L, int idx) +{ + luaC_checkthreadsleep(L); + StkId t = index2adr(L, idx); + api_check(L, ttistable(t)); + setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); + return; +} + +void lua_rawgeti(lua_State* L, int idx, int n) +{ + luaC_checkthreadsleep(L); + StkId t = index2adr(L, idx); + api_check(L, ttistable(t)); + setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); + api_incr_top(L); + return; +} + +void lua_createtable(lua_State* L, int narray, int nrec) +{ + luaC_checkGC(L); + luaC_checkthreadsleep(L); + sethvalue(L, L->top, luaH_new(L, narray, nrec)); + api_incr_top(L); + return; +} + +void lua_setreadonly(lua_State* L, int objindex, bool value) +{ + const TValue* o = index2adr(L, objindex); + api_check(L, ttistable(o)); + Table* t = hvalue(o); + t->readonly = value; + return; +} + +int lua_getreadonly(lua_State* L, int objindex) +{ + const TValue* o = index2adr(L, objindex); + api_check(L, ttistable(o)); + Table* t = hvalue(o); + int res = t->readonly; + return res; +} + +void lua_setsafeenv(lua_State* L, int objindex, bool value) +{ + const TValue* o = index2adr(L, objindex); + api_check(L, ttistable(o)); + Table* t = hvalue(o); + t->safeenv = value; + return; +} + +int lua_getmetatable(lua_State* L, int objindex) +{ + const TValue* obj; + Table* mt = NULL; + int res; + obj = index2adr(L, objindex); + switch (ttype(obj)) + { + case LUA_TTABLE: + mt = hvalue(obj)->metatable; + break; + case LUA_TUSERDATA: + mt = uvalue(obj)->metatable; + break; + default: + mt = L->global->mt[ttype(obj)]; + break; + } + if (mt == NULL) + res = 0; + else + { + sethvalue(L, L->top, mt); + api_incr_top(L); + res = 1; + } + return res; +} + +void lua_getfenv(lua_State* L, int idx) +{ + StkId o; + o = index2adr(L, idx); + api_checkvalidindex(L, o); + switch (ttype(o)) + { + case LUA_TFUNCTION: + sethvalue(L, L->top, clvalue(o)->env); + break; + case LUA_TTHREAD: + setobj2s(L, L->top, gt(thvalue(o))); + break; + default: + setnilvalue(L->top); + break; + } + api_incr_top(L); + return; +} + +/* +** set functions (stack -> Lua) +*/ + +void lua_settable(lua_State* L, int idx) +{ + StkId t; + api_checknelems(L, 2); + t = index2adr(L, idx); + api_checkvalidindex(L, t); + luaV_settable(L, t, L->top - 2, L->top - 1); + L->top -= 2; /* pop index and value */ + return; +} + +void lua_setfield(lua_State* L, int idx, const char* k) +{ + StkId t; + TValue key; + api_checknelems(L, 1); + t = index2adr(L, idx); + api_checkvalidindex(L, t); + setsvalue(L, &key, luaS_new(L, k)); + luaV_settable(L, t, &key, L->top - 1); + L->top--; /* pop value */ + return; +} + +void lua_rawset(lua_State* L, int idx) +{ + StkId t; + api_checknelems(L, 2); + t = index2adr(L, idx); + api_check(L, ttistable(t)); + if (hvalue(t)->readonly) + luaG_runerror(L, "Attempt to modify a readonly table"); + setobj2t(L, luaH_set(L, hvalue(t), L->top - 2), L->top - 1); + luaC_barriert(L, hvalue(t), L->top - 1); + L->top -= 2; + return; +} + +void lua_rawseti(lua_State* L, int idx, int n) +{ + StkId o; + api_checknelems(L, 1); + o = index2adr(L, idx); + api_check(L, ttistable(o)); + if (hvalue(o)->readonly) + luaG_runerror(L, "Attempt to modify a readonly table"); + setobj2t(L, luaH_setnum(L, hvalue(o), n), L->top - 1); + luaC_barriert(L, hvalue(o), L->top - 1); + L->top--; + return; +} + +int lua_setmetatable(lua_State* L, int objindex) +{ + TValue* obj; + Table* mt; + api_checknelems(L, 1); + obj = index2adr(L, objindex); + api_checkvalidindex(L, obj); + if (ttisnil(L->top - 1)) + mt = NULL; + else + { + api_check(L, ttistable(L->top - 1)); + mt = hvalue(L->top - 1); + } + switch (ttype(obj)) + { + case LUA_TTABLE: + { + if (hvalue(obj)->readonly) + luaG_runerror(L, "Attempt to modify a readonly table"); + hvalue(obj)->metatable = mt; + if (mt) + luaC_objbarriert(L, hvalue(obj), mt); + break; + } + case LUA_TUSERDATA: + { + uvalue(obj)->metatable = mt; + if (mt) + luaC_objbarrier(L, uvalue(obj), mt); + break; + } + default: + { + L->global->mt[ttype(obj)] = mt; + break; + } + } + L->top--; + return 1; +} + +int lua_setfenv(lua_State* L, int idx) +{ + StkId o; + int res = 1; + api_checknelems(L, 1); + o = index2adr(L, idx); + api_checkvalidindex(L, o); + api_check(L, ttistable(L->top - 1)); + switch (ttype(o)) + { + case LUA_TFUNCTION: + clvalue(o)->env = hvalue(L->top - 1); + break; + case LUA_TTHREAD: + sethvalue(L, gt(thvalue(o)), hvalue(L->top - 1)); + break; + default: + res = 0; + break; + } + if (res) + { + luaC_objbarrier(L, &gcvalue(o)->gch, hvalue(L->top - 1)); + } + L->top--; + return res; +} + +/* +** `load' and `call' functions (run Lua code) +*/ + +#define adjustresults(L, nres) \ + { \ + if (nres == LUA_MULTRET && L->top >= L->ci->top) \ + L->ci->top = L->top; \ + } + +#define checkresults(L, na, nr) api_check(L, (nr) == LUA_MULTRET || (L->ci->top - L->top >= (nr) - (na))) + +void lua_call(lua_State* L, int nargs, int nresults) +{ + StkId func; + api_checknelems(L, nargs + 1); + api_check(L, L->status == 0); + checkresults(L, nargs, nresults); + func = L->top - (nargs + 1); + + int wasActive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); + + luaD_call(L, func, nresults); + + if (!wasActive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + + adjustresults(L, nresults); + return; +} + +/* +** Execute a protected call. +*/ +struct CallS +{ /* data to `f_call' */ + StkId func; + int nresults; +}; + +static void f_call(lua_State* L, void* ud) +{ + struct CallS* c = cast_to(struct CallS*, ud); + luaD_call(L, c->func, c->nresults); + return; +} + +int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) +{ + struct CallS c; + int status; + ptrdiff_t func; + api_checknelems(L, nargs + 1); + api_check(L, L->status == 0); + checkresults(L, nargs, nresults); + if (errfunc == 0) + func = 0; + else + { + StkId o = index2adr(L, errfunc); + api_checkvalidindex(L, o); + func = savestack(L, o); + } + c.func = L->top - (nargs + 1); /* function to be called */ + c.nresults = nresults; + + int wasActive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); + + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + + if (!wasActive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + + adjustresults(L, nresults); + return status; +} + +int lua_status(lua_State* L) +{ + return L->status; +} + +/* +** Garbage-collection function +*/ + +int lua_gc(lua_State* L, int what, int data) +{ + int res = 0; + condhardmemtests(luaC_validate(L), 1); + global_State* g = L->global; + switch (what) + { + case LUA_GCSTOP: + { + g->GCthreshold = SIZE_MAX; + break; + } + case LUA_GCRESTART: + { + g->GCthreshold = g->totalbytes; + break; + } + case LUA_GCCOLLECT: + { + luaC_fullgc(L); + break; + } + case LUA_GCCOUNT: + { + /* GC values are expressed in Kbytes: #bytes/2^10 */ + res = cast_int(g->totalbytes >> 10); + break; + } + case LUA_GCISRUNNING: + { + res = (g->GCthreshold != SIZE_MAX); + break; + } + case LUA_GCSTEP: + { + size_t prevthreshold = g->GCthreshold; + size_t amount = (cast_to(size_t, data) << 10); + + // temporarily adjust the threshold so that we can perform GC work + if (amount <= g->totalbytes) + g->GCthreshold = g->totalbytes - amount; + else + g->GCthreshold = 0; + + bool waspaused = g->gcstate == GCSpause; + + // track how much work the loop will actually perform + size_t actualwork = 0; + + while (g->GCthreshold <= g->totalbytes) + { + luaC_step(L, false); + + actualwork += g->gcstepsize; + + if (g->gcstate == GCSpause) + { /* end of cycle? */ + res = 1; /* signal it */ + break; + } + } + + // if cycle hasn't finished, advance threshold forward for the amount of extra work performed + if (g->gcstate != GCSpause) + { + // if a new cycle was triggered by explicit step, we ignore old threshold as that shows an incorrect 'credit' of GC work + if (waspaused) + g->GCthreshold = g->totalbytes + actualwork; + else + g->GCthreshold = prevthreshold + actualwork; + } + break; + } + case LUA_GCSETGOAL: + { + res = g->gcgoal; + g->gcgoal = data; + break; + } + case LUA_GCSETSTEPMUL: + { + res = g->gcstepmul; + g->gcstepmul = data; + break; + } + case LUA_GCSETSTEPSIZE: + { + /* GC values are expressed in Kbytes: #bytes/2^10 */ + res = g->gcstepsize >> 10; + g->gcstepsize = data << 10; + break; + } + default: + res = -1; /* invalid option */ + } + return res; +} + +/* +** miscellaneous functions +*/ + +l_noret lua_error(lua_State* L) +{ + api_checknelems(L, 1); + + luaD_throw(L, LUA_ERRRUN); +} + +int lua_next(lua_State* L, int idx) +{ + luaC_checkthreadsleep(L); + StkId t = index2adr(L, idx); + api_check(L, ttistable(t)); + int more = luaH_next(L, hvalue(t), L->top - 1); + if (more) + { + api_incr_top(L); + } + else /* no more elements */ + L->top -= 1; /* remove key */ + return more; +} + +void lua_concat(lua_State* L, int n) +{ + api_checknelems(L, n); + if (n >= 2) + { + luaC_checkGC(L); + luaC_checkthreadsleep(L); + luaV_concat(L, n, cast_int(L->top - L->base) - 1); + L->top -= (n - 1); + } + else if (n == 0) + { /* push empty string */ + luaC_checkthreadsleep(L); + setsvalue2s(L, L->top, luaS_newlstr(L, "", 0)); + api_incr_top(L); + } + /* else n == 1; nothing to do */ + return; +} + +void* lua_newuserdata(lua_State* L, size_t sz, int tag) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + luaC_checkGC(L); + luaC_checkthreadsleep(L); + Udata* u = luaS_newudata(L, sz, tag); + setuvalue(L, L->top, u); + api_incr_top(L); + return u->data; +} + +void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) +{ + luaC_checkGC(L); + luaC_checkthreadsleep(L); + Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); + memcpy(u->data + sz, &dtor, sizeof(dtor)); + setuvalue(L, L->top, u); + api_incr_top(L); + return u->data; +} + +static const char* aux_upvalue(StkId fi, int n, TValue** val) +{ + Closure* f; + if (!ttisfunction(fi)) + return NULL; + f = clvalue(fi); + if (f->isC) + { + if (!(1 <= n && n <= f->nupvalues)) + return NULL; + *val = &f->c.upvals[n - 1]; + return ""; + } + else + { + Proto* p = f->l.p; + if (!(1 <= n && n <= p->sizeupvalues)) + return NULL; + TValue* r = &f->l.uprefs[n - 1]; + *val = ttisupval(r) ? upvalue(r)->v : r; + return getstr(p->upvalues[n - 1]); + } +} + +const char* lua_getupvalue(lua_State* L, int funcindex, int n) +{ + luaC_checkthreadsleep(L); + TValue* val; + const char* name = aux_upvalue(index2adr(L, funcindex), n, &val); + if (name) + { + setobj2s(L, L->top, val); + api_incr_top(L); + } + return name; +} + +const char* lua_setupvalue(lua_State* L, int funcindex, int n) +{ + const char* name; + TValue* val; + StkId fi; + fi = index2adr(L, funcindex); + api_checknelems(L, 1); + name = aux_upvalue(fi, n, &val); + if (name) + { + L->top--; + setobj(L, val, L->top); + luaC_barrier(L, clvalue(fi), L->top); + luaC_upvalbarrier(L, NULL, val); + } + return name; +} + +uintptr_t lua_encodepointer(lua_State* L, uintptr_t p) +{ + global_State* g = L->global; + return uintptr_t((g->ptrenckey[0] * p + g->ptrenckey[2]) ^ (g->ptrenckey[1] * p + g->ptrenckey[3])); +} + +int lua_ref(lua_State* L, int idx) +{ + int ref = LUA_REFNIL; + global_State* g = L->global; + StkId p = index2adr(L, idx); + if (!ttisnil(p)) + { + Table* reg = hvalue(registry(L)); + + if (g->registryfree != 0) + { /* reuse existing slot */ + ref = g->registryfree; + } + else + { /* no free elements */ + ref = luaH_getn(reg); + ref++; /* create new reference */ + } + + TValue* slot = luaH_setnum(L, reg, ref); + if (g->registryfree != 0) + g->registryfree = int(nvalue(slot)); + setobj2t(L, slot, p); + luaC_barriert(L, reg, p); + } + return ref; +} + +void lua_unref(lua_State* L, int ref) +{ + if (ref <= LUA_REFNIL) + return; + + global_State* g = L->global; + Table* reg = hvalue(registry(L)); + TValue* slot = luaH_setnum(L, reg, ref); + setnvalue(slot, g->registryfree); /* NB: no barrier needed because value isn't collectable */ + g->registryfree = ref; + return; +} + +void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)) +{ + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + L->global->udatagc[tag] = dtor; +} + +lua_Callbacks* lua_callbacks(lua_State* L) +{ + return &L->global->cb; +} diff --git a/VM/src/lapi.h b/VM/src/lapi.h new file mode 100644 index 0000000..b727218 --- /dev/null +++ b/VM/src/lapi.h @@ -0,0 +1,8 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +LUAI_FUNC const TValue* luaA_toobject(lua_State* L, int idx); +LUAI_FUNC void luaA_pushobject(lua_State* L, const TValue* o); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp new file mode 100644 index 0000000..e37618f --- /dev/null +++ b/VM/src/laux.cpp @@ -0,0 +1,477 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lobject.h" +#include "lstate.h" +#include "lstring.h" +#include "lapi.h" +#include "lgc.h" + +#include + +/* convert a stack index to positive */ +#define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) + +/* +** {====================================================== +** Error-report functions +** ======================================================= +*/ + +static const char* currfuncname(lua_State* L) +{ + Closure* cl = L->ci > L->base_ci ? curr_func(L) : NULL; + const char* debugname = cl && cl->isC ? cl->c.debugname + 0 : NULL; + + if (debugname && strcmp(debugname, "__namecall") == 0) + return L->namecall ? getstr(L->namecall) : NULL; + else + return debugname; +} + +LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) +{ + const char* fname = currfuncname(L); + + if (fname) + luaL_error(L, "invalid argument #%d to '%s' (%s)", narg, fname, extramsg); + else + luaL_error(L, "invalid argument #%d (%s)", narg, extramsg); +} + +LUALIB_API l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname) +{ + const char* fname = currfuncname(L); + const TValue* obj = luaA_toobject(L, narg); + + if (obj) + { + if (fname) + luaL_error(L, "invalid argument #%d to '%s' (%s expected, got %s)", narg, fname, tname, luaT_objtypename(L, obj)); + else + luaL_error(L, "invalid argument #%d (%s expected, got %s)", narg, tname, luaT_objtypename(L, obj)); + } + else + { + if (fname) + luaL_error(L, "missing argument #%d to '%s' (%s expected)", narg, fname, tname); + else + luaL_error(L, "missing argument #%d (%s expected)", narg, tname); + } +} + +static l_noret tag_error(lua_State* L, int narg, int tag) +{ + luaL_typeerrorL(L, narg, lua_typename(L, tag)); +} + +LUALIB_API void luaL_where(lua_State* L, int level) +{ + lua_Debug ar; + if (lua_getinfo(L, level, "sl", &ar) && ar.currentline > 0) + { + lua_pushfstring(L, "%s:%d: ", ar.short_src, ar.currentline); + return; + } + lua_pushliteral(L, ""); /* else, no information available... */ +} + +LUALIB_API l_noret luaL_errorL(lua_State* L, const char* fmt, ...) +{ + va_list argp; + va_start(argp, fmt); + luaL_where(L, 1); + lua_pushvfstring(L, fmt, argp); + va_end(argp); + lua_concat(L, 2); + lua_error(L); +} + +/* }====================================================== */ + +LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) +{ + const char* name = (def) ? luaL_optstring(L, narg, def) : luaL_checkstring(L, narg); + int i; + for (i = 0; lst[i]; i++) + if (strcmp(lst[i], name) == 0) + return i; + const char* msg = lua_pushfstring(L, "invalid option '%s'", name); + luaL_argerrorL(L, narg, msg); +} + +LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname) +{ + lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get registry.name */ + if (!lua_isnil(L, -1)) /* name already in use? */ + return 0; /* leave previous value on top, but return 0 */ + lua_pop(L, 1); + lua_newtable(L); /* create metatable */ + lua_pushvalue(L, -1); + lua_setfield(L, LUA_REGISTRYINDEX, tname); /* registry.name = metatable */ + return 1; +} + +LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname) +{ + void* p = lua_touserdata(L, ud); + if (p != NULL) + { /* value is a userdata? */ + if (lua_getmetatable(L, ud)) + { /* does it have a metatable? */ + lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get correct metatable */ + if (lua_rawequal(L, -1, -2)) + { /* does it have the correct mt? */ + lua_pop(L, 2); /* remove both metatables */ + return p; + } + } + } + luaL_typeerrorL(L, ud, tname); /* else error */ +} + +LUALIB_API void luaL_checkstack(lua_State* L, int space, const char* mes) +{ + if (!lua_checkstack(L, space)) + luaL_error(L, "stack overflow (%s)", mes); +} + +LUALIB_API void luaL_checktype(lua_State* L, int narg, int t) +{ + if (lua_type(L, narg) != t) + tag_error(L, narg, t); +} + +LUALIB_API void luaL_checkany(lua_State* L, int narg) +{ + if (lua_type(L, narg) == LUA_TNONE) + luaL_error(L, "missing argument #%d", narg); +} + +LUALIB_API const char* luaL_checklstring(lua_State* L, int narg, size_t* len) +{ + const char* s = lua_tolstring(L, narg, len); + if (!s) + tag_error(L, narg, LUA_TSTRING); + return s; +} + +LUALIB_API const char* luaL_optlstring(lua_State* L, int narg, const char* def, size_t* len) +{ + if (lua_isnoneornil(L, narg)) + { + if (len) + *len = (def ? strlen(def) : 0); + return def; + } + else + return luaL_checklstring(L, narg, len); +} + +LUALIB_API double luaL_checknumber(lua_State* L, int narg) +{ + int isnum; + double d = lua_tonumberx(L, narg, &isnum); + if (!isnum) + tag_error(L, narg, LUA_TNUMBER); + return d; +} + +LUALIB_API double luaL_optnumber(lua_State* L, int narg, double def) +{ + return luaL_opt(L, luaL_checknumber, narg, def); +} + +LUALIB_API int luaL_checkinteger(lua_State* L, int narg) +{ + int isnum; + int d = lua_tointegerx(L, narg, &isnum); + if (!isnum) + tag_error(L, narg, LUA_TNUMBER); + return d; +} + +LUALIB_API int luaL_optinteger(lua_State* L, int narg, int def) +{ + return luaL_opt(L, luaL_checkinteger, narg, def); +} + +LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int narg) +{ + int isnum; + unsigned d = lua_tounsignedx(L, narg, &isnum); + if (!isnum) + tag_error(L, narg, LUA_TNUMBER); + return d; +} + +LUALIB_API unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) +{ + return luaL_opt(L, luaL_checkunsigned, narg, def); +} + +LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* event) +{ + if (!lua_getmetatable(L, obj)) /* no metatable? */ + return 0; + lua_pushstring(L, event); + lua_rawget(L, -2); + if (lua_isnil(L, -1)) + { + lua_pop(L, 2); /* remove metatable and metafield */ + return 0; + } + else + { + lua_remove(L, -2); /* remove only metatable */ + return 1; + } +} + +LUALIB_API int luaL_callmeta(lua_State* L, int obj, const char* event) +{ + obj = abs_index(L, obj); + if (!luaL_getmetafield(L, obj, event)) /* no metafield? */ + return 0; + lua_pushvalue(L, obj); + lua_call(L, 1, 1); + return 1; +} + +static int libsize(const luaL_Reg* l) +{ + int size = 0; + for (; l->name; l++) + size++; + return size; +} + +LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) +{ + if (libname) + { + int size = libsize(l); + /* check whether lib already exists */ + luaL_findtable(L, LUA_REGISTRYINDEX, "_LOADED", 1); + lua_getfield(L, -1, libname); /* get _LOADED[libname] */ + if (!lua_istable(L, -1)) + { /* not found? */ + lua_pop(L, 1); /* remove previous result */ + /* try global variable (and create one if it does not exist) */ + if (luaL_findtable(L, LUA_GLOBALSINDEX, libname, size) != NULL) + luaL_error(L, "name conflict for module '%s'", libname); + lua_pushvalue(L, -1); + lua_setfield(L, -3, libname); /* _LOADED[libname] = new table */ + } + lua_remove(L, -2); /* remove _LOADED table */ + } + for (; l->name; l++) + { + lua_pushcfunction(L, l->func, l->name); + lua_setfield(L, -2, l->name); + } +} + +LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) +{ + const char* e; + lua_pushvalue(L, idx); + do + { + e = strchr(fname, '.'); + if (e == NULL) + e = fname + strlen(fname); + lua_pushlstring(L, fname, e - fname); + lua_rawget(L, -2); + if (lua_isnil(L, -1)) + { /* no such field? */ + lua_pop(L, 1); /* remove this nil */ + lua_createtable(L, 0, (*e == '.' ? 1 : szhint)); /* new table for field */ + lua_pushlstring(L, fname, e - fname); + lua_pushvalue(L, -2); + lua_settable(L, -4); /* set new table into field */ + } + else if (!lua_istable(L, -1)) + { /* field has a non-table value? */ + lua_pop(L, 2); /* remove table and value */ + return fname; /* return problematic part of the name */ + } + lua_remove(L, -2); /* remove previous table */ + fname = e + 1; + } while (*e == '.'); + return NULL; +} + +/* +** {====================================================== +** Generic Buffer manipulation +** ======================================================= +*/ + +static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desiredsize) +{ + size_t newsize = currentsize + currentsize / 2; + + // check for size oveflow + if (SIZE_MAX - desiredsize < currentsize) + luaL_error(L, "buffer too large"); + + // growth factor might not be enough to satisfy the desired size + if (newsize < desiredsize) + newsize = desiredsize; + + return newsize; +} + +LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B) +{ + // start with an internal buffer + B->p = B->buffer; + B->end = B->p + LUA_BUFFERSIZE; + + B->L = L; + B->storage = nullptr; +} + +LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size) +{ + luaL_buffinit(L, B); + luaL_reservebuffer(B, size, -1); + return B->p; +} + +LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) +{ + lua_State* L = B->L; + + if (B->storage) + LUAU_ASSERT(B->storage == tsvalue(L->top + boxloc)); + + char* base = B->storage ? B->storage->data : B->buffer; + + size_t capacity = B->end - base; + size_t nextsize = getnextbuffersize(B->L, capacity, capacity + additionalsize); + + TString* newStorage = luaS_bufstart(L, nextsize); + + memcpy(newStorage->data, base, B->p - base); + + // place the string storage at the expected position in the stack + if (base == B->buffer) + { + lua_pushnil(L); + lua_insert(L, boxloc); + } + + setsvalue2s(L, L->top + boxloc, newStorage); + B->p = newStorage->data + (B->p - base); + B->end = newStorage->data + nextsize; + B->storage = newStorage; + + return B->p; +} + +LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) +{ + if (size_t(B->end - B->p) < size) + luaL_extendbuffer(B, size - (B->end - B->p), boxloc); +} + +LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) +{ + if (size_t(B->end - B->p) < len) + luaL_extendbuffer(B, len - (B->end - B->p), -1); + + memcpy(B->p, s, len); + B->p += len; +} + +LUALIB_API void luaL_addvalue(luaL_Buffer* B) +{ + lua_State* L = B->L; + + size_t vl; + if (const char* s = lua_tolstring(L, -1, &vl)) + { + if (size_t(B->end - B->p) < vl) + luaL_extendbuffer(B, vl - (B->end - B->p), -2); + + memcpy(B->p, s, vl); + B->p += vl; + + lua_pop(L, 1); + } +} + +LUALIB_API void luaL_pushresult(luaL_Buffer* B) +{ + lua_State* L = B->L; + + if (TString* storage = B->storage) + { + luaC_checkGC(L); + + // if we finished just at the end of the string buffer, we can convert it to a mutable stirng without a copy + if (B->p == B->end) + { + setsvalue2s(L, L->top - 1, luaS_buffinish(L, storage)); + } + else + { + setsvalue2s(L, L->top - 1, luaS_newlstr(L, storage->data, B->p - storage->data)); + } + } + else + { + lua_pushlstring(L, B->buffer, B->p - B->buffer); + } +} + +LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size) +{ + B->p += size; + luaL_pushresult(B); +} + +/* }====================================================== */ + +LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len) +{ + if (luaL_callmeta(L, idx, "__tostring")) /* is there a metafield? */ + { + if (!lua_isstring(L, -1)) + luaL_error(L, "'__tostring' must return a string"); + return lua_tolstring(L, -1, len); + } + + switch (lua_type(L, idx)) + { + case LUA_TNUMBER: + lua_pushstring(L, lua_tostring(L, idx)); + break; + case LUA_TSTRING: + lua_pushvalue(L, idx); + break; + case LUA_TBOOLEAN: + lua_pushstring(L, (lua_toboolean(L, idx) ? "true" : "false")); + break; + case LUA_TNIL: + lua_pushliteral(L, "nil"); + break; + case LUA_TVECTOR: + { + const float* v = lua_tovector(L, idx); + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); + break; + } + default: + { + const void* ptr = lua_topointer(L, idx); + unsigned long long enc = lua_encodepointer(L, uintptr_t(ptr)); + lua_pushfstring(L, "%s: 0x%016llx", luaL_typename(L, idx), enc); + break; + } + } + return lua_tolstring(L, -1, len); +} diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp new file mode 100644 index 0000000..87fc163 --- /dev/null +++ b/VM/src/lbaselib.cpp @@ -0,0 +1,466 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lstate.h" +#include "lapi.h" +#include "ldo.h" + +#include +#include +#include + +static void writestring(const char* s, size_t l) +{ + fwrite(s, 1, l, stdout); +} + +static int luaB_print(lua_State* L) +{ + int n = lua_gettop(L); /* number of arguments */ + for (int i = 1; i <= n; i++) + { + size_t l; + const char* s = luaL_tolstring(L, i, &l); /* convert to string using __tostring et al */ + if (i > 1) + writestring("\t", 1); + writestring(s, l); + lua_pop(L, 1); /* pop result */ + } + writestring("\n", 1); + return 0; +} + +static int luaB_tonumber(lua_State* L) +{ + int base = luaL_optinteger(L, 2, 10); + if (base == 10) + { /* standard conversion */ + luaL_checkany(L, 1); + if (lua_isnumber(L, 1)) + { + lua_pushnumber(L, lua_tonumber(L, 1)); + return 1; + } + } + else + { + const char* s1 = luaL_checkstring(L, 1); + luaL_argcheck(L, 2 <= base && base <= 36, 2, "base out of range"); + char* s2; + unsigned long long n; + n = strtoull(s1, &s2, base); + if (s1 != s2) + { /* at least one valid digit? */ + while (isspace((unsigned char)(*s2))) + s2++; /* skip trailing spaces */ + if (*s2 == '\0') + { /* no invalid trailing characters? */ + lua_pushnumber(L, (double)n); + return 1; + } + } + } + lua_pushnil(L); /* else not a number */ + return 1; +} + +static int luaB_error(lua_State* L) +{ + int level = luaL_optinteger(L, 2, 1); + lua_settop(L, 1); + if (lua_isstring(L, 1) && level > 0) + { /* add extra information? */ + luaL_where(L, level); + lua_pushvalue(L, 1); + lua_concat(L, 2); + } + lua_error(L); +} + +static int luaB_getmetatable(lua_State* L) +{ + luaL_checkany(L, 1); + if (!lua_getmetatable(L, 1)) + { + lua_pushnil(L); + return 1; /* no metatable */ + } + luaL_getmetafield(L, 1, "__metatable"); + return 1; /* returns either __metatable field (if present) or metatable */ +} + +static int luaB_setmetatable(lua_State* L) +{ + int t = lua_type(L, 2); + luaL_checktype(L, 1, LUA_TTABLE); + luaL_argexpected(L, t == LUA_TNIL || t == LUA_TTABLE, 2, "nil or table"); + if (luaL_getmetafield(L, 1, "__metatable")) + luaL_error(L, "cannot change a protected metatable"); + lua_settop(L, 2); + lua_setmetatable(L, 1); + return 1; +} + +static void getfunc(lua_State* L, int opt) +{ + if (lua_isfunction(L, 1)) + lua_pushvalue(L, 1); + else + { + lua_Debug ar; + int level = opt ? luaL_optinteger(L, 1, 1) : luaL_checkinteger(L, 1); + luaL_argcheck(L, level >= 0, 1, "level must be non-negative"); + if (lua_getinfo(L, level, "f", &ar) == 0) + luaL_argerror(L, 1, "invalid level"); + if (lua_isnil(L, -1)) + luaL_error(L, "no function environment for tail call at level %d", level); + } +} + +static int luaB_getfenv(lua_State* L) +{ + getfunc(L, 1); + if (lua_iscfunction(L, -1)) /* is a C function? */ + lua_pushvalue(L, LUA_GLOBALSINDEX); /* return the thread's global env. */ + else + lua_getfenv(L, -1); + lua_setsafeenv(L, -1, false); + return 1; +} + +static int luaB_setfenv(lua_State* L) +{ + luaL_checktype(L, 2, LUA_TTABLE); + getfunc(L, 0); + lua_pushvalue(L, 2); + lua_setsafeenv(L, -1, false); + if (lua_isnumber(L, 1) && lua_tonumber(L, 1) == 0) + { + /* change environment of current thread */ + lua_pushthread(L); + lua_insert(L, -2); + lua_setfenv(L, -2); + return 0; + } + else if (lua_iscfunction(L, -2) || lua_setfenv(L, -2) == 0) + luaL_error(L, "'setfenv' cannot change environment of given object"); + return 1; +} + +static int luaB_rawequal(lua_State* L) +{ + luaL_checkany(L, 1); + luaL_checkany(L, 2); + lua_pushboolean(L, lua_rawequal(L, 1, 2)); + return 1; +} + +static int luaB_rawget(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checkany(L, 2); + lua_settop(L, 2); + lua_rawget(L, 1); + return 1; +} + +static int luaB_rawset(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checkany(L, 2); + luaL_checkany(L, 3); + lua_settop(L, 3); + lua_rawset(L, 1); + return 1; +} + +static int luaB_gcinfo(lua_State* L) +{ + lua_pushinteger(L, lua_gc(L, LUA_GCCOUNT, 0)); + return 1; +} + +static int luaB_type(lua_State* L) +{ + luaL_checkany(L, 1); + lua_pushstring(L, luaL_typename(L, 1)); + return 1; +} + +static int luaB_typeof(lua_State* L) +{ + luaL_checkany(L, 1); + const TValue* obj = luaA_toobject(L, 1); + lua_pushstring(L, luaT_objtypename(L, obj)); + return 1; +} + +int luaB_next(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + lua_settop(L, 2); /* create a 2nd argument if there isn't one */ + if (lua_next(L, 1)) + return 2; + else + { + lua_pushnil(L); + return 1; + } +} + +static int luaB_pairs(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + lua_pushvalue(L, lua_upvalueindex(1)); /* return generator, */ + lua_pushvalue(L, 1); /* state, */ + lua_pushnil(L); /* and initial value */ + return 3; +} + +int luaB_inext(lua_State* L) +{ + int i = luaL_checkinteger(L, 2); + luaL_checktype(L, 1, LUA_TTABLE); + i++; /* next value */ + lua_pushinteger(L, i); + lua_rawgeti(L, 1, i); + return (lua_isnil(L, -1)) ? 0 : 2; +} + +static int luaB_ipairs(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + lua_pushvalue(L, lua_upvalueindex(1)); /* return generator, */ + lua_pushvalue(L, 1); /* state, */ + lua_pushinteger(L, 0); /* and initial value */ + return 3; +} + +static int luaB_assert(lua_State* L) +{ + luaL_checkany(L, 1); + if (!lua_toboolean(L, 1)) + luaL_error(L, "%s", luaL_optstring(L, 2, "assertion failed!")); + return lua_gettop(L); +} + +static int luaB_select(lua_State* L) +{ + int n = lua_gettop(L); + if (lua_type(L, 1) == LUA_TSTRING && *lua_tostring(L, 1) == '#') + { + lua_pushinteger(L, n - 1); + return 1; + } + else + { + int i = luaL_checkinteger(L, 1); + if (i < 0) + i = n + i; + else if (i > n) + i = n; + luaL_argcheck(L, 1 <= i, 1, "index out of range"); + return n - i; + } +} + +static void luaB_pcallrun(lua_State* L, void* ud) +{ + StkId func = (StkId)ud; + + luaD_call(L, func, LUA_MULTRET); +} + +static int luaB_pcally(lua_State* L) +{ + luaL_checkany(L, 1); + + StkId func = L->base; + + // any errors from this point on are handled by continuation + L->ci->flags |= LUA_CALLINFO_HANDLE; + + // maintain yieldable invariant (baseCcalls <= nCcalls) + L->baseCcalls++; + int status = luaD_pcall(L, luaB_pcallrun, func, savestack(L, func), 0); + L->baseCcalls--; + + // necessary to accomodate functions that return lots of values + expandstacklimit(L, L->top); + + // yielding means we need to propagate yield; resume will call continuation function later + if (status == 0 && (L->status == LUA_YIELD || L->status == LUA_BREAK)) + return -1; // -1 is a marker for yielding from C + + // immediate return (error or success) + lua_rawcheckstack(L, 1); + lua_pushboolean(L, status == 0); + lua_insert(L, 1); + return lua_gettop(L); // return status + all results +} + +static int luaB_pcallcont(lua_State* L, int status) +{ + if (status == 0) + { + lua_rawcheckstack(L, 1); + lua_pushboolean(L, true); + lua_insert(L, 1); // insert status before all results + return lua_gettop(L); + } + else + { + lua_rawcheckstack(L, 1); + lua_pushboolean(L, false); + lua_insert(L, -2); // insert status before error object + return 2; + } +} + +static int luaB_xpcally(lua_State* L) +{ + luaL_checktype(L, 2, LUA_TFUNCTION); + + /* swap function & error function */ + lua_pushvalue(L, 1); + lua_pushvalue(L, 2); + lua_replace(L, 1); + lua_replace(L, 2); + /* at this point the stack looks like err, f, args */ + + // any errors from this point on are handled by continuation + L->ci->flags |= LUA_CALLINFO_HANDLE; + + StkId errf = L->base; + StkId func = L->base + 1; + + // maintain yieldable invariant (baseCcalls <= nCcalls) + L->baseCcalls++; + int status = luaD_pcall(L, luaB_pcallrun, func, savestack(L, func), savestack(L, errf)); + L->baseCcalls--; + + // necessary to accomodate functions that return lots of values + expandstacklimit(L, L->top); + + // yielding means we need to propagate yield; resume will call continuation function later + if (status == 0 && (L->status == LUA_YIELD || L->status == LUA_BREAK)) + return -1; // -1 is a marker for yielding from C + + // immediate return (error or success) + lua_rawcheckstack(L, 1); + lua_pushboolean(L, status == 0); + lua_replace(L, 1); // replace error function with status + return lua_gettop(L); // return status + all results +} + +static void luaB_xpcallerr(lua_State* L, void* ud) +{ + StkId func = (StkId)ud; + + luaD_call(L, func, 1); +} + +static int luaB_xpcallcont(lua_State* L, int status) +{ + if (status == 0) + { + lua_rawcheckstack(L, 1); + lua_pushboolean(L, true); + lua_replace(L, 1); // replace error function with status + return lua_gettop(L); /* return status + all results */ + } + else + { + lua_rawcheckstack(L, 3); + lua_pushboolean(L, false); + lua_pushvalue(L, 1); // push error function on top of the stack + lua_pushvalue(L, -3); // push error object (that was on top of the stack before) + + StkId res = L->top - 3; + StkId errf = L->top - 2; + + // note: we pass res as errfunc as a short cut; if errf generates an error, we'll try to execute res (boolean) and fail + luaD_pcall(L, luaB_xpcallerr, errf, savestack(L, errf), savestack(L, res)); + + return 2; + } +} + +static int luaB_tostring(lua_State* L) +{ + luaL_checkany(L, 1); + luaL_tolstring(L, 1, NULL); + return 1; +} + +static int luaB_newproxy(lua_State* L) +{ + int t = lua_type(L, 1); + luaL_argexpected(L, t == LUA_TNONE || t == LUA_TNIL || t == LUA_TBOOLEAN, 1, "nil or boolean"); + + bool needsmt = lua_toboolean(L, 1); + + lua_newuserdata(L, 0, 0); + + if (needsmt) + { + lua_newtable(L); + lua_setmetatable(L, -2); + } + + return 1; +} + +static const luaL_Reg base_funcs[] = { + {"assert", luaB_assert}, + {"error", luaB_error}, + {"gcinfo", luaB_gcinfo}, + {"getfenv", luaB_getfenv}, + {"getmetatable", luaB_getmetatable}, + {"next", luaB_next}, + {"newproxy", luaB_newproxy}, + {"print", luaB_print}, + {"rawequal", luaB_rawequal}, + {"rawget", luaB_rawget}, + {"rawset", luaB_rawset}, + {"select", luaB_select}, + {"setfenv", luaB_setfenv}, + {"setmetatable", luaB_setmetatable}, + {"tonumber", luaB_tonumber}, + {"tostring", luaB_tostring}, + {"type", luaB_type}, + {"typeof", luaB_typeof}, + {NULL, NULL}, +}; + +static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u) +{ + lua_pushcfunction(L, u); + lua_pushcfunction(L, f, name, 1); + lua_setfield(L, -2, name); +} + +LUALIB_API int luaopen_base(lua_State* L) +{ + /* set global _G */ + lua_pushvalue(L, LUA_GLOBALSINDEX); + lua_setglobal(L, "_G"); + + /* open lib into global table */ + luaL_register(L, "_G", base_funcs); + lua_pushliteral(L, "Luau"); + lua_setglobal(L, "_VERSION"); /* set global _VERSION */ + + /* `ipairs' and `pairs' need auxiliary functions as upvalues */ + auxopen(L, "ipairs", luaB_ipairs, luaB_inext); + auxopen(L, "pairs", luaB_pairs, luaB_next); + + lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont); + lua_setfield(L, -2, "pcall"); + + lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); + lua_setfield(L, -2, "xpcall"); + + return 1; +} diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp new file mode 100644 index 0000000..0754a35 --- /dev/null +++ b/VM/src/lbitlib.cpp @@ -0,0 +1,201 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lnumutils.h" + +#define ALLONES ~0u +#define NBITS int(8 * sizeof(unsigned)) + +/* macro to trim extra bits */ +#define trim(x) ((x)&ALLONES) + +/* builds a number with 'n' ones (1 <= n <= NBITS) */ +#define mask(n) (~((ALLONES << 1) << ((n)-1))) + +typedef unsigned b_uint; + +static b_uint andaux(lua_State* L) +{ + int i, n = lua_gettop(L); + b_uint r = ~(b_uint)0; + for (i = 1; i <= n; i++) + r &= luaL_checkunsigned(L, i); + return trim(r); +} + +static int b_and(lua_State* L) +{ + b_uint r = andaux(L); + lua_pushunsigned(L, r); + return 1; +} + +static int b_test(lua_State* L) +{ + b_uint r = andaux(L); + lua_pushboolean(L, r != 0); + return 1; +} + +static int b_or(lua_State* L) +{ + int i, n = lua_gettop(L); + b_uint r = 0; + for (i = 1; i <= n; i++) + r |= luaL_checkunsigned(L, i); + lua_pushunsigned(L, trim(r)); + return 1; +} + +static int b_xor(lua_State* L) +{ + int i, n = lua_gettop(L); + b_uint r = 0; + for (i = 1; i <= n; i++) + r ^= luaL_checkunsigned(L, i); + lua_pushunsigned(L, trim(r)); + return 1; +} + +static int b_not(lua_State* L) +{ + b_uint r = ~luaL_checkunsigned(L, 1); + lua_pushunsigned(L, trim(r)); + return 1; +} + +static int b_shift(lua_State* L, b_uint r, int i) +{ + if (i < 0) + { /* shift right? */ + i = -i; + r = trim(r); + if (i >= NBITS) + r = 0; + else + r >>= i; + } + else + { /* shift left */ + if (i >= NBITS) + r = 0; + else + r <<= i; + r = trim(r); + } + lua_pushunsigned(L, r); + return 1; +} + +static int b_lshift(lua_State* L) +{ + return b_shift(L, luaL_checkunsigned(L, 1), luaL_checkinteger(L, 2)); +} + +static int b_rshift(lua_State* L) +{ + return b_shift(L, luaL_checkunsigned(L, 1), -luaL_checkinteger(L, 2)); +} + +static int b_arshift(lua_State* L) +{ + b_uint r = luaL_checkunsigned(L, 1); + int i = luaL_checkinteger(L, 2); + if (i < 0 || !(r & ((b_uint)1 << (NBITS - 1)))) + return b_shift(L, r, -i); + else + { /* arithmetic shift for 'negative' number */ + if (i >= NBITS) + r = ALLONES; + else + r = trim((r >> i) | ~(~(b_uint)0 >> i)); /* add signal bit */ + lua_pushunsigned(L, r); + return 1; + } +} + +static int b_rot(lua_State* L, int i) +{ + b_uint r = luaL_checkunsigned(L, 1); + i &= (NBITS - 1); /* i = i % NBITS */ + r = trim(r); + if (i != 0) /* avoid undefined shift of NBITS when i == 0 */ + r = (r << i) | (r >> (NBITS - i)); + lua_pushunsigned(L, trim(r)); + return 1; +} + +static int b_lrot(lua_State* L) +{ + return b_rot(L, luaL_checkinteger(L, 2)); +} + +static int b_rrot(lua_State* L) +{ + return b_rot(L, -luaL_checkinteger(L, 2)); +} + +/* +** get field and width arguments for field-manipulation functions, +** checking whether they are valid. +** ('luaL_error' called without 'return' to avoid later warnings about +** 'width' being used uninitialized.) +*/ +static int fieldargs(lua_State* L, int farg, int* width) +{ + int f = luaL_checkinteger(L, farg); + int w = luaL_optinteger(L, farg + 1, 1); + luaL_argcheck(L, 0 <= f, farg, "field cannot be negative"); + luaL_argcheck(L, 0 < w, farg + 1, "width must be positive"); + if (f + w > NBITS) + luaL_error(L, "trying to access non-existent bits"); + *width = w; + return f; +} + +static int b_extract(lua_State* L) +{ + int w; + b_uint r = luaL_checkunsigned(L, 1); + int f = fieldargs(L, 2, &w); + r = (r >> f) & mask(w); + lua_pushunsigned(L, r); + return 1; +} + +static int b_replace(lua_State* L) +{ + int w; + b_uint r = luaL_checkunsigned(L, 1); + b_uint v = luaL_checkunsigned(L, 2); + int f = fieldargs(L, 3, &w); + int m = mask(w); + v &= m; /* erase bits outside given width */ + r = (r & ~(m << f)) | (v << f); + lua_pushunsigned(L, r); + return 1; +} + +static const luaL_Reg bitlib[] = { + {"arshift", b_arshift}, + {"band", b_and}, + {"bnot", b_not}, + {"bor", b_or}, + {"bxor", b_xor}, + {"btest", b_test}, + {"extract", b_extract}, + {"lrotate", b_lrot}, + {"lshift", b_lshift}, + {"replace", b_replace}, + {"rrotate", b_rrot}, + {"rshift", b_rshift}, + {NULL, NULL}, +}; + +LUALIB_API int luaopen_bit32(lua_State* L) +{ + luaL_register(L, LUA_BITLIBNAME, bitlib); + + return 1; +} diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp new file mode 100644 index 0000000..e1c99b2 --- /dev/null +++ b/VM/src/lbuiltins.cpp @@ -0,0 +1,1099 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lbuiltins.h" + +#include "lstate.h" +#include "lstring.h" +#include "ltable.h" +#include "lgc.h" +#include "lnumutils.h" +#include "ldo.h" + +#include + +#ifdef _MSC_VER +#include +#endif + +// luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM +// The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. +// If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path +// If luauF_* succeeds, it needs to return *all* requested arguments, filling results with nil as appropriate. +// On input, nparams refers to the actual number of arguments (0+), whereas nresults contains LUA_MULTRET for arbitrary returns or 0+ for a +// fixed-length return Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults +// is <= expected number, which covers the LUA_MULTRET case. + +static int luauF_assert(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults == 0 && !l_isfalse(arg0)) + { + return 0; + } + + return -1; +} + +static int luauF_abs(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, fabs(a1)); + return 1; + } + + return -1; +} + +static int luauF_acos(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, acos(a1)); + return 1; + } + + return -1; +} + +static int luauF_asin(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, asin(a1)); + return 1; + } + + return -1; +} + +static int luauF_atan2(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + setnvalue(res, atan2(a1, a2)); + return 1; + } + + return -1; +} + +static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, atan(a1)); + return 1; + } + + return -1; +} + +LUAU_FASTMATH_BEGIN +static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, ceil(a1)); + return 1; + } + + return -1; +} +LUAU_FASTMATH_END + +static int luauF_cosh(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, cosh(a1)); + return 1; + } + + return -1; +} + +static int luauF_cos(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, cos(a1)); + return 1; + } + + return -1; +} + +static int luauF_deg(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + const double rpd = (3.14159265358979323846 / 180.0); + setnvalue(res, a1 / rpd); + return 1; + } + + return -1; +} + +static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, exp(a1)); + return 1; + } + + return -1; +} + +LUAU_FASTMATH_BEGIN +static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, floor(a1)); + return 1; + } + + return -1; +} +LUAU_FASTMATH_END + +static int luauF_fmod(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + setnvalue(res, fmod(a1, a2)); + return 1; + } + + return -1; +} + +static int luauF_frexp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 2 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + int e; + double f = frexp(a1, &e); + setnvalue(res, f); + setnvalue(res + 1, double(e)); + return 2; + } + + return -1; +} + +static int luauF_ldexp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + setnvalue(res, ldexp(a1, int(a2))); + return 1; + } + + return -1; +} + +static int luauF_log10(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, log10(a1)); + return 1; + } + + return -1; +} + +static int luauF_log(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + + if (nparams == 1) + { + setnvalue(res, log(a1)); + return 1; + } + else if (ttisnumber(args)) + { + double a2 = nvalue(args); + + if (a2 == 2.0) + { + setnvalue(res, log2(a1)); + return 1; + } + else if (a2 == 10.0) + { + setnvalue(res, log10(a1)); + return 1; + } + else + { + setnvalue(res, log(a1) / log(a2)); + return 1; + } + } + } + + return -1; +} + +static int luauF_max(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double r = nvalue(arg0); + + for (int i = 2; i <= nparams; ++i) + { + if (!ttisnumber(args + (i - 2))) + return -1; + + double a = nvalue(args + (i - 2)); + + r = (a > r) ? a : r; + } + + setnvalue(res, r); + return 1; + } + + return -1; +} + +static int luauF_min(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double r = nvalue(arg0); + + for (int i = 2; i <= nparams; ++i) + { + if (!ttisnumber(args + (i - 2))) + return -1; + + double a = nvalue(args + (i - 2)); + + r = (a < r) ? a : r; + } + + setnvalue(res, r); + return 1; + } + + return -1; +} + +static int luauF_modf(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 2 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + double ip; + double fp = modf(a1, &ip); + setnvalue(res, ip); + setnvalue(res + 1, fp); + return 2; + } + + return -1; +} + +static int luauF_pow(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + setnvalue(res, pow(a1, a2)); + return 1; + } + + return -1; +} + +static int luauF_rad(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + const double rpd = (3.14159265358979323846 / 180.0); + setnvalue(res, a1 * rpd); + return 1; + } + + return -1; +} + +static int luauF_sinh(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, sinh(a1)); + return 1; + } + + return -1; +} + +static int luauF_sin(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, sin(a1)); + return 1; + } + + return -1; +} + +LUAU_FASTMATH_BEGIN +static int luauF_sqrt(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, sqrt(a1)); + return 1; + } + + return -1; +} +LUAU_FASTMATH_END + +static int luauF_tanh(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, tanh(a1)); + return 1; + } + + return -1; +} + +static int luauF_tan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, tan(a1)); + return 1; + } + + return -1; +} + +static int luauF_arshift(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + + unsigned u; + luai_num2unsigned(u, a1); + int s = int(a2); + + // note: we only specialize fast-path that doesn't require further conditionals (negative shifts and shifts greater or equal to bit width can + // be handled generically) + if (unsigned(s) < 32) + { + // note: technically right shift of negative values is UB, but this behavior is getting defined in C++20 and all compilers do the right + // (shift) thing. + uint32_t r = int32_t(u) >> s; + + setnvalue(res, double(r)); + return 1; + } + } + + return -1; +} + +static int luauF_band(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + uint32_t r = ~0u; + + if (!ttisnumber(arg0)) + return -1; + + { + double a1 = nvalue(arg0); + unsigned u; + luai_num2unsigned(u, a1); + + r &= u; + } + + for (int i = 2; i <= nparams; ++i) + { + if (!ttisnumber(args + (i - 2))) + return -1; + + double a = nvalue(args + (i - 2)); + unsigned u; + luai_num2unsigned(u, a); + + r &= u; + } + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_bnot(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + unsigned u; + luai_num2unsigned(u, a1); + + uint32_t r = ~u; + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_bor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + uint32_t r = 0; + + if (!ttisnumber(arg0)) + return -1; + + { + double a1 = nvalue(arg0); + unsigned u; + luai_num2unsigned(u, a1); + + r |= u; + } + + for (int i = 2; i <= nparams; ++i) + { + if (!ttisnumber(args + (i - 2))) + return -1; + + double a = nvalue(args + (i - 2)); + unsigned u; + luai_num2unsigned(u, a); + + r |= u; + } + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_bxor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + uint32_t r = 0; + + if (!ttisnumber(arg0)) + return -1; + + { + double a1 = nvalue(arg0); + unsigned u; + luai_num2unsigned(u, a1); + + r ^= u; + } + + for (int i = 2; i <= nparams; ++i) + { + if (!ttisnumber(args + (i - 2))) + return -1; + + double a = nvalue(args + (i - 2)); + unsigned u; + luai_num2unsigned(u, a); + + r ^= u; + } + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_btest(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + uint32_t r = ~0u; + + if (!ttisnumber(arg0)) + return -1; + + { + double a1 = nvalue(arg0); + unsigned u; + luai_num2unsigned(u, a1); + + r &= u; + } + + for (int i = 2; i <= nparams; ++i) + { + if (!ttisnumber(args + (i - 2))) + return -1; + + double a = nvalue(args + (i - 2)); + unsigned u; + luai_num2unsigned(u, a); + + r &= u; + } + + setbvalue(res, r != 0); + return 1; + } + + return -1; +} + +static int luauF_extract(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + double a3 = nvalue(args + 1); + + unsigned n; + luai_num2unsigned(n, a1); + int f = int(a2); + int w = int(a3); + + if (f >= 0 && w > 0 && f + w <= 32) + { + uint32_t m = ~(0xfffffffeu << (w - 1)); + uint32_t r = (n >> f) & m; + + setnvalue(res, double(r)); + return 1; + } + } + + return -1; +} + +static int luauF_lrotate(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + + unsigned u; + luai_num2unsigned(u, a1); + int s = int(a2); + + // MSVC doesn't recognize the rotate form that is UB-safe +#ifdef _MSC_VER + uint32_t r = _rotl(u, s); +#else + uint32_t r = (u << (s & 31)) | (u >> ((32 - s) & 31)); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_lshift(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + + unsigned u; + luai_num2unsigned(u, a1); + int s = int(a2); + + // note: we only specialize fast-path that doesn't require further conditionals (negative shifts and shifts greater or equal to bit width can + // be handled generically) + if (unsigned(s) < 32) + { + uint32_t r = u << s; + + setnvalue(res, double(r)); + return 1; + } + } + + return -1; +} + +static int luauF_replace(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + double a3 = nvalue(args + 1); + double a4 = nvalue(args + 2); + + unsigned n, v; + luai_num2unsigned(n, a1); + luai_num2unsigned(v, a2); + int f = int(a3); + int w = int(a4); + + if (f >= 0 && w > 0 && f + w <= 32) + { + uint32_t m = ~(0xfffffffeu << (w - 1)); + uint32_t r = (n & ~(m << f)) | ((v & m) << f); + + setnvalue(res, double(r)); + return 1; + } + } + + return -1; +} + +static int luauF_rrotate(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + + unsigned u; + luai_num2unsigned(u, a1); + int s = int(a2); + + // MSVC doesn't recognize the rotate form that is UB-safe +#ifdef _MSC_VER + uint32_t r = _rotr(u, s); +#else + uint32_t r = (u >> (s & 31)) | (u << ((32 - s) & 31)); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_rshift(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args)) + { + double a1 = nvalue(arg0); + double a2 = nvalue(args); + + unsigned u; + luai_num2unsigned(u, a1); + int s = int(a2); + + // note: we only specialize fast-path that doesn't require further conditionals (negative shifts and shifts greater or equal to bit width can + // be handled generically) + if (unsigned(s) < 32) + { + uint32_t r = u >> s; + + setnvalue(res, double(r)); + return 1; + } + } + + return -1; +} + +static int luauF_type(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + int tt = ttype(arg0); + TString* ttname = L->global->ttname[tt]; + + setsvalue2s(L, res, ttname); + return 1; + } + + return -1; +} + +static int luauF_byte(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && ttisstring(arg0) && ttisnumber(args)) + { + TString* ts = tsvalue(arg0); + int i = int(nvalue(args)); + int j = (nparams >= 3) ? (ttisnumber(args + 1) ? int(nvalue(args + 1)) : 0) : i; + + if (i >= 1 && j >= i && j <= int(ts->len)) + { + int c = j - i + 1; + const char* s = getstr(ts); + + // for vararg returns, we only support a single result + // this is because this frees us from concerns about stack space + if (c == (nresults < 0 ? 1 : nresults)) + { + for (int k = 0; k < c; ++k) + { + setnvalue(res + k, uint8_t(s[i + k - 1])); + } + + return c; + } + } + } + + return -1; +} + +static int luauF_char(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + char buffer[8]; + + if (nparams < int(sizeof(buffer)) && nresults <= 1) + { + + if (nparams >= 1) + { + if (!ttisnumber(arg0)) + return -1; + + int ch = int(nvalue(arg0)); + + if ((unsigned char)(ch) != ch) + return -1; + + buffer[0] = ch; + } + + for (int i = 2; i <= nparams; ++i) + { + if (!ttisnumber(args + (i - 2))) + return -1; + + int ch = int(nvalue(args + (i - 2))); + + if ((unsigned char)(ch) != ch) + return -1; + + buffer[i - 1] = ch; + } + + buffer[nparams] = 0; + + setsvalue2s(L, res, luaS_newlstr(L, buffer, nparams)); + return 1; + } + + return -1; +} + +static int luauF_len(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisstring(arg0)) + { + TString* ts = tsvalue(arg0); + + setnvalue(res, int(ts->len)); + return 1; + } + + return -1; +} + +static int luauF_typeof(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1) + { + const TString* ttname = luaT_objtypenamestr(L, arg0); + + setsvalue2s(L, res, ttname); + return 1; + } + + return -1; +} + +static int luauF_sub(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisstring(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + TString* ts = tsvalue(arg0); + int i = int(nvalue(args)); + int j = int(nvalue(args + 1)); + + if (i >= 1 && j >= i && unsigned(j - 1) < unsigned(ts->len)) + { + setsvalue2s(L, res, luaS_newlstr(L, getstr(ts) + (i - 1), j - i + 1)); + return 1; + } + } + + return -1; +} + +static int luauF_clamp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double v = nvalue(arg0); + double min = nvalue(args); + double max = nvalue(args + 1); + + if (min <= max) + { + double r = v < min ? min : v; + r = r > max ? max : r; + + setnvalue(res, r); + return 1; + } + } + + return -1; +} + +static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double v = nvalue(arg0); + setnvalue(res, v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0); + return 1; + } + + return -1; +} + +LUAU_FASTMATH_BEGIN +static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double v = nvalue(arg0); + setnvalue(res, round(v)); + return 1; + } + + return -1; +} +LUAU_FASTMATH_END + +static int luauF_rawequal(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1) + { + setbvalue(res, luaO_rawequalObj(arg0, args)); + return 1; + } + + return -1; +} + +static int luauF_rawget(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 2 && nresults <= 1 && ttistable(arg0)) + { + setobj2s(L, res, luaH_get(hvalue(arg0), args)); + return 1; + } + + return -1; +} + +static int luauF_rawset(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttistable(arg0)) + { + const TValue* key = args; + if (ttisnil(key)) + return -1; + else if (ttisnumber(key) && luai_numisnan(nvalue(key))) + return -1; + else if (ttisvector(key) && luai_vecisnan(vvalue(key))) + return -1; + + if (hvalue(arg0)->readonly) + return -1; + + setobj2s(L, res, arg0); + setobj2t(L, luaH_set(L, hvalue(arg0), args), args + 1); + luaC_barriert(L, hvalue(arg0), args + 1); + return 1; + } + + return -1; +} + +static int luauF_tinsert(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams == 2 && nresults <= 0 && ttistable(arg0)) + { + if (hvalue(arg0)->readonly) + return -1; + + int pos = luaH_getn(hvalue(arg0)) + 1; + setobj2t(L, luaH_setnum(L, hvalue(arg0), pos), args); + luaC_barriert(L, hvalue(arg0), args); + return 0; + } + + return -1; +} + +static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults < 0 && ttistable(arg0)) + { + Table* t = hvalue(arg0); + int n = -1; + + if (nparams == 1) + n = luaH_getn(t); + else if (nparams == 3 && ttisnumber(args) && ttisnumber(args + 1) && nvalue(args) == 1.0) + n = int(nvalue(args + 1)); + + if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n) + { + TValue* array = t->array; + for (int i = 0; i < n; ++i) + setobj2s(L, res + i, array + i); + expandstacklimit(L, res + n); + return n; + } + } + + return -1; +} + +static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) + { + double x = nvalue(arg0); + double y = nvalue(args); + double z = nvalue(args + 1); + + setvvalue(res, float(x), float(y), float(z)); + return 1; + } + + return -1; +} + +luau_FastFunction luauF_table[256] = { + NULL, + luauF_assert, + + luauF_abs, + luauF_acos, + luauF_asin, + luauF_atan2, + luauF_atan, + luauF_ceil, + luauF_cosh, + luauF_cos, + luauF_deg, + luauF_exp, + luauF_floor, + luauF_fmod, + luauF_frexp, + luauF_ldexp, + luauF_log10, + luauF_log, + luauF_max, + luauF_min, + luauF_modf, + luauF_pow, + luauF_rad, + luauF_sinh, + luauF_sin, + luauF_sqrt, + luauF_tanh, + luauF_tan, + + luauF_arshift, + luauF_band, + luauF_bnot, + luauF_bor, + luauF_bxor, + luauF_btest, + luauF_extract, + luauF_lrotate, + luauF_lshift, + luauF_replace, + luauF_rrotate, + luauF_rshift, + + luauF_type, + + luauF_byte, + luauF_char, + luauF_len, + + luauF_typeof, + + luauF_sub, + + luauF_clamp, + luauF_sign, + luauF_round, + + luauF_rawset, + luauF_rawget, + luauF_rawequal, + + luauF_tinsert, + luauF_tunpack, + + luauF_vector, +}; diff --git a/VM/src/lbuiltins.h b/VM/src/lbuiltins.h new file mode 100644 index 0000000..a642c93 --- /dev/null +++ b/VM/src/lbuiltins.h @@ -0,0 +1,9 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +typedef int (*luau_FastFunction)(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams); + +extern luau_FastFunction luauF_table[256]; diff --git a/VM/src/lbytecode.h b/VM/src/lbytecode.h new file mode 100644 index 0000000..c4d250d --- /dev/null +++ b/VM/src/lbytecode.h @@ -0,0 +1,9 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +// This is a forwarding header for Luau bytecode definition +// Luau consists of several components, including compiler (Ast, Compiler) and VM (virtual machine) +// These components are fully independent, but they both need the bytecode format defined in this header +// so it needs to be shared. +#include "../../Compiler/include/Luau/Bytecode.h" diff --git a/VM/src/lcommon.h b/VM/src/lcommon.h new file mode 100644 index 0000000..adbd81f --- /dev/null +++ b/VM/src/lcommon.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include +#include + +#include "luaconf.h" + +// This is a forwarding header for Luau common definition (assertions, flags) +// Luau consists of several components, including compiler (Ast, Compiler) and VM (virtual machine) +// These components are fully independent, but they need a common set of utilities defined in this header +// so it needs to be shared. +#include "../../Ast/include/Luau/Common.h" + +typedef LUAI_USER_ALIGNMENT_T L_Umaxalign; + +/* internal assertions for in-house debugging */ +#define check_exp(c, e) (LUAU_ASSERT(c), (e)) +#define api_check(l, e) LUAU_ASSERT(e) + +#ifndef cast_to +#define cast_to(t, exp) ((t)(exp)) +#endif + +#define cast_byte(i) cast_to(uint8_t, (i)) +#define cast_num(i) cast_to(double, (i)) +#define cast_int(i) cast_to(int, (i)) + +/* +** type for virtual-machine instructions +** must be an unsigned with (at least) 4 bytes (see details in lopcodes.h) +*/ +typedef uint32_t Instruction; + +/* +** macro to control inclusion of some hard tests on stack reallocation +*/ +#if defined(HARDSTACKTESTS) && HARDSTACKTESTS +#define condhardstacktests(x) (x) +#else +#define condhardstacktests(x) ((void)0) +#endif + +/* +** macro to control inclusion of some hard tests on garbage collection +*/ +#if defined(HARDMEMTESTS) && HARDMEMTESTS +#define condhardmemtests(x, l) (HARDMEMTESTS >= l ? (x) : (void)0) +#else +#define condhardmemtests(x, l) ((void)0) +#endif diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp new file mode 100644 index 0000000..c0b50b9 --- /dev/null +++ b/VM/src/lcorolib.cpp @@ -0,0 +1,265 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lstate.h" +#include "lvm.h" + +LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) + +#define CO_RUN 0 /* running */ +#define CO_SUS 1 /* suspended */ +#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ +#define CO_DEAD 3 + +#define CO_STATUS_ERROR -1 +#define CO_STATUS_BREAK -2 + +static const char* const statnames[] = {"running", "suspended", "normal", "dead"}; + +static int costatus(lua_State* L, lua_State* co) +{ + if (co == L) + return CO_RUN; + if (co->status == LUA_YIELD) + return CO_SUS; + if (co->status == LUA_BREAK) + return CO_NOR; + if (co->status != 0) /* some error occured */ + return CO_DEAD; + if (co->ci != co->base_ci) /* does it have frames? */ + return CO_NOR; + if (co->top == co->base) + return CO_DEAD; + return CO_SUS; /* initial state */ +} + +static int luaB_costatus(lua_State* L) +{ + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + lua_pushstring(L, statnames[costatus(L, co)]); + return 1; +} + +static int auxresume(lua_State* L, lua_State* co, int narg) +{ + // error handling for edge cases + if (co->status != LUA_YIELD) + { + int status = costatus(L, co); + if (status != CO_SUS) + { + lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]); + return CO_STATUS_ERROR; + } + } + + if (narg) + { + if (!lua_checkstack(co, narg)) + luaL_error(L, "too many arguments to resume"); + lua_xmove(L, co, narg); + } + + co->singlestep = L->singlestep; + + int status = lua_resume(co, L, narg); + if (status == 0 || status == LUA_YIELD) + { + int nres = cast_int(co->top - co->base); + if (nres) + { + /* +1 accounts for true/false status in resumefinish */ + if (nres + 1 > LUA_MINSTACK && !lua_checkstack(L, nres + 1)) + luaL_error(L, "too many results to resume"); + lua_xmove(co, L, nres); /* move yielded values */ + } + return nres; + } + else if (status == LUA_BREAK) + { + return CO_STATUS_BREAK; + } + else + { + lua_xmove(co, L, 1); /* move error message */ + return CO_STATUS_ERROR; + } +} + +static int interruptThread(lua_State* L, lua_State* co) +{ + // notify the debugger that the thread was suspended + if (L->global->cb.debuginterrupt) + luau_callhook(L, L->global->cb.debuginterrupt, co); + + return lua_break(L); +} + +static int auxresumecont(lua_State* L, lua_State* co) +{ + if (co->status == 0 || co->status == LUA_YIELD) + { + int nres = cast_int(co->top - co->base); + if (!lua_checkstack(L, nres + 1)) + luaL_error(L, "too many results to resume"); + lua_xmove(co, L, nres); /* move yielded values */ + return nres; + } + else + { + lua_rawcheckstack(L, 2); + lua_xmove(co, L, 1); /* move error message */ + return CO_STATUS_ERROR; + } +} + +static int luaB_coresumefinish(lua_State* L, int r) +{ + if (r < 0) + { + lua_pushboolean(L, 0); + lua_insert(L, -2); + return 2; /* return false + error message */ + } + else + { + lua_pushboolean(L, 1); + lua_insert(L, -(r + 1)); + return r + 1; /* return true + `resume' returns */ + } +} + +static int luaB_coresumey(lua_State* L) +{ + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + int narg = cast_int(L->top - L->base) - 1; + int r = auxresume(L, co, narg); + + if (r == CO_STATUS_BREAK) + return interruptThread(L, co); + + return luaB_coresumefinish(L, r); +} + +static int luaB_coresumecont(lua_State* L, int status) +{ + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + + // if coroutine still hasn't yielded after the break, break current thread again + if (co->status == LUA_BREAK) + return interruptThread(L, co); + + int r = auxresumecont(L, co); + + return luaB_coresumefinish(L, r); +} + +static int luaB_auxwrapfinish(lua_State* L, int r) +{ + if (r < 0) + { + if (lua_isstring(L, -1)) + { /* error object is a string? */ + luaL_where(L, 1); /* add extra info */ + lua_insert(L, -2); + lua_concat(L, 2); + } + lua_error(L); /* propagate error */ + } + return r; +} + +static int luaB_auxwrapy(lua_State* L) +{ + lua_State* co = lua_tothread(L, lua_upvalueindex(1)); + int narg = cast_int(L->top - L->base); + int r = auxresume(L, co, narg); + + if (r == CO_STATUS_BREAK) + return interruptThread(L, co); + + return luaB_auxwrapfinish(L, r); +} + +static int luaB_auxwrapcont(lua_State* L, int status) +{ + lua_State* co = lua_tothread(L, lua_upvalueindex(1)); + + // if coroutine still hasn't yielded after the break, break current thread again + if (co->status == LUA_BREAK) + return interruptThread(L, co); + + int r = auxresumecont(L, co); + + return luaB_auxwrapfinish(L, r); +} + +static int luaB_cocreate(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TFUNCTION); + lua_State* NL = lua_newthread(L); + + if (FFlag::LuauPreferXpush) + { + lua_xpush(L, NL, 1); // push function on top of NL + } + else + { + lua_pushvalue(L, 1); /* move function to top */ + lua_xmove(L, NL, 1); /* move function from L to NL */ + } + + return 1; +} + +static int luaB_cowrap(lua_State* L) +{ + luaB_cocreate(L); + + lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont); + + return 1; +} + +static int luaB_yield(lua_State* L) +{ + int nres = cast_int(L->top - L->base); + return lua_yield(L, nres); +} + +static int luaB_corunning(lua_State* L) +{ + if (lua_pushthread(L)) + lua_pushnil(L); /* main thread is not a coroutine */ + return 1; +} + +static int luaB_yieldable(lua_State* L) +{ + lua_pushboolean(L, lua_isyieldable(L)); + return 1; +} + +static const luaL_Reg co_funcs[] = { + {"create", luaB_cocreate}, + {"running", luaB_corunning}, + {"status", luaB_costatus}, + {"wrap", luaB_cowrap}, + {"yield", luaB_yield}, + {"isyieldable", luaB_yieldable}, + {NULL, NULL}, +}; + +LUALIB_API int luaopen_coroutine(lua_State* L) +{ + luaL_register(L, LUA_COLIBNAME, co_funcs); + + lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont); + lua_setfield(L, -2, "resume"); + + return 1; +} diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp new file mode 100644 index 0000000..965d2b3 --- /dev/null +++ b/VM/src/ldblib.cpp @@ -0,0 +1,167 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lvm.h" + +#include +#include +#include + +static lua_State* getthread(lua_State* L, int* arg) +{ + if (lua_isthread(L, 1)) + { + *arg = 1; + return lua_tothread(L, 1); + } + else + { + *arg = 0; + return L; + } +} + +static int db_info(lua_State* L) +{ + int arg; + lua_State* L1 = getthread(L, &arg); + + // If L1 != L, L1 can be in any state, and therefore there are no guarantees about its stack space + if (L != L1) + lua_rawcheckstack(L1, 1); // for 'f' option + + int level; + if (lua_isnumber(L, arg + 1)) + { + level = (int)lua_tointeger(L, arg + 1); + luaL_argcheck(L, level >= 0, arg + 1, "level can't be negative"); + } + else if (arg == 0 && lua_isfunction(L, 1)) + { + // convert absolute index to relative index + level = -lua_gettop(L); + } + else + luaL_argerror(L, arg + 1, "function or level expected"); + + const char* options = luaL_checkstring(L, arg + 2); + + lua_Debug ar; + if (!lua_getinfo(L1, level, options, &ar)) + return 0; + + int results = 0; + bool occurs[26] = {}; + + for (const char* it = options; *it; ++it) + { + if (unsigned(*it - 'a') < 26) + { + if (occurs[*it - 'a']) + luaL_argerror(L, arg + 2, "duplicate option"); + occurs[*it - 'a'] = true; + } + + switch (*it) + { + case 's': + lua_pushstring(L, ar.short_src); + results++; + break; + + case 'l': + lua_pushinteger(L, ar.currentline); + results++; + break; + + case 'n': + lua_pushstring(L, ar.name ? ar.name : ""); + results++; + break; + + case 'f': + if (L1 == L) + lua_pushvalue(L, -1 - results); /* function is right before results */ + else + lua_xmove(L1, L, 1); /* function is at top of L1 */ + results++; + break; + + case 'a': + lua_pushinteger(L, ar.nparams); + lua_pushboolean(L, ar.isvararg); + results += 2; + break; + + default: + luaL_argerror(L, arg + 2, "invalid option"); + } + } + + return results; +} + +static int db_traceback(lua_State* L) +{ + int arg; + lua_State* L1 = getthread(L, &arg); + const char* msg = luaL_optstring(L, arg + 1, NULL); + int level = luaL_optinteger(L, arg + 2, (L == L1) ? 1 : 0); + luaL_argcheck(L, level >= 0, arg + 2, "level can't be negative"); + + luaL_Buffer buf; + luaL_buffinit(L, &buf); + + if (msg) + { + luaL_addstring(&buf, msg); + luaL_addstring(&buf, "\n"); + } + + lua_Debug ar; + for (int i = level; lua_getinfo(L1, i, "sln", &ar); ++i) + { + if (strcmp(ar.what, "C") == 0) + continue; + + if (ar.source) + luaL_addstring(&buf, ar.short_src); + + if (ar.currentline > 0) + { + char line[32]; +#ifdef _MSC_VER + _itoa(ar.currentline, line, 10); // 5x faster than sprintf +#else + sprintf(line, "%d", ar.currentline); +#endif + + luaL_addchar(&buf, ':'); + luaL_addstring(&buf, line); + } + + if (ar.name) + { + luaL_addstring(&buf, " function "); + luaL_addstring(&buf, ar.name); + } + + luaL_addchar(&buf, '\n'); + } + + luaL_pushresult(&buf); + return 1; +} + +static const luaL_Reg dblib[] = { + {"info", db_info}, + {"traceback", db_traceback}, + {NULL, NULL}, +}; + +LUALIB_API int luaopen_debug(lua_State* L) +{ + luaL_register(L, LUA_DBLIBNAME, dblib); + return 1; +} diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp new file mode 100644 index 0000000..1890e68 --- /dev/null +++ b/VM/src/ldebug.cpp @@ -0,0 +1,428 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "ldebug.h" + +#include "lapi.h" +#include "lfunc.h" +#include "lmem.h" +#include "lgc.h" +#include "ldo.h" +#include "lbytecode.h" + +#include +#include + +static const char* getfuncname(Closure* f); + +static int currentpc(lua_State* L, CallInfo* ci) +{ + return pcRel(ci->savedpc, ci_func(ci)->l.p); +} + +static int currentline(lua_State* L, CallInfo* ci) +{ + return luaG_getline(ci_func(ci)->l.p, currentpc(L, ci)); +} + +static Proto* getluaproto(CallInfo* ci) +{ + return (isLua(ci) ? cast_to(Proto*, ci_func(ci)->l.p) : NULL); +} + +int lua_getargument(lua_State* L, int level, int n) +{ + if (unsigned(level) >= unsigned(L->ci - L->base_ci)) + return 0; + + CallInfo* ci = L->ci - level; + Proto* fp = getluaproto(ci); + int res = 0; + + if (fp && n > 0) + { + if (n <= fp->numparams) + { + luaC_checkthreadsleep(L); + luaA_pushobject(L, ci->base + (n - 1)); + res = 1; + } + else if (fp->is_vararg && n < ci->base - ci->func) + { + luaC_checkthreadsleep(L); + luaA_pushobject(L, ci->func + n); + res = 1; + } + } + + return res; +} + +const char* lua_getlocal(lua_State* L, int level, int n) +{ + if (unsigned(level) >= unsigned(L->ci - L->base_ci)) + return 0; + + CallInfo* ci = L->ci - level; + Proto* fp = getluaproto(ci); + const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL; + if (var) + { + luaC_checkthreadsleep(L); + luaA_pushobject(L, ci->base + var->reg); + } + const char* name = var ? getstr(var->varname) : NULL; + return name; +} + +const char* lua_setlocal(lua_State* L, int level, int n) +{ + if (unsigned(level) >= unsigned(L->ci - L->base_ci)) + return 0; + + CallInfo* ci = L->ci - level; + Proto* fp = getluaproto(ci); + const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL; + if (var) + setobjs2s(L, ci->base + var->reg, L->top - 1); + L->top--; /* pop value */ + const char* name = var ? getstr(var->varname) : NULL; + return name; +} + +static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) +{ + int status = 1; + for (; *what; what++) + { + switch (*what) + { + case 's': + { + if (f->isC) + { + ar->source = "=[C]"; + ar->what = "C"; + ar->linedefined = -1; + } + else + { + ar->source = getstr(f->l.p->source); + ar->what = "Lua"; + ar->linedefined = luaG_getline(f->l.p, 0); + } + luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); + break; + } + case 'l': + { + if (ci) + { + ar->currentline = isLua(ci) ? currentline(L, ci) : -1; + } + else + { + ar->currentline = f->isC ? -1 : luaG_getline(f->l.p, 0); + } + + break; + } + case 'u': + { + ar->nupvals = f->nupvalues; + break; + } + case 'a': + { + if (f->isC) + { + ar->isvararg = 1; + ar->nparams = 0; + } + else + { + ar->isvararg = f->l.p->is_vararg; + ar->nparams = f->l.p->numparams; + } + break; + } + case 'n': + { + ar->name = ci ? getfuncname(ci_func(ci)) : getfuncname(f); + break; + } + default:; + } + } + return status; +} + +int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) +{ + int status = 0; + Closure* f = NULL; + CallInfo* ci = NULL; + if (level < 0) + { + StkId func = L->top + level; + api_check(L, ttisfunction(func)); + f = clvalue(func); + } + else if (unsigned(level) < unsigned(L->ci - L->base_ci)) + { + ci = L->ci - level; + LUAU_ASSERT(ttisfunction(ci->func)); + f = clvalue(ci->func); + } + if (f) + { + status = auxgetinfo(L, what, ar, f, ci); + if (strchr(what, 'f')) + { + luaC_checkthreadsleep(L); + setclvalue(L, L->top, f); + incr_top(L); + } + } + return status; +} + +static const char* getfuncname(Closure* cl) +{ + if (cl->isC) + { + if (cl->c.debugname) + { + return cl->c.debugname; + } + } + else + { + Proto* p = cl->l.p; + + if (p->debugname) + { + return getstr(p->debugname); + } + } + return nullptr; +} + +l_noret luaG_typeerrorL(lua_State* L, const TValue* o, const char* op) +{ + const char* t = luaT_objtypename(L, o); + + luaG_runerror(L, "attempt to %s a %s value", op, t); +} + +l_noret luaG_forerrorL(lua_State* L, const TValue* o, const char* what) +{ + const char* t = luaT_objtypename(L, o); + + luaG_runerror(L, "invalid 'for' %s (number expected, got %s)", what, t); +} + +l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2) +{ + const char* t1 = luaT_objtypename(L, p1); + const char* t2 = luaT_objtypename(L, p2); + + luaG_runerror(L, "attempt to concatenate %s with %s", t1, t2); +} + +l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op) +{ + const char* t1 = luaT_objtypename(L, p1); + const char* t2 = luaT_objtypename(L, p2); + const char* opname = luaT_eventname[op] + 2; // skip __ from metamethod name + + if (t1 == t2) + luaG_runerror(L, "attempt to perform arithmetic (%s) on %s", opname, t1); + else + luaG_runerror(L, "attempt to perform arithmetic (%s) on %s and %s", opname, t1, t2); +} + +l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op) +{ + const char* t1 = luaT_objtypename(L, p1); + const char* t2 = luaT_objtypename(L, p2); + const char* opname = (op == TM_LT) ? "<" : (op == TM_LE) ? "<=" : "=="; + + luaG_runerror(L, "attempt to compare %s %s %s", t1, opname, t2); +} + +l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2) +{ + const char* t1 = luaT_objtypename(L, p1); + const char* t2 = luaT_objtypename(L, p2); + const TString* key = ttisstring(p2) ? tsvalue(p2) : 0; + + if (key && key->len <= 64) // limit length to make sure we don't generate very long error messages for very long keys + luaG_runerror(L, "attempt to index %s with '%s'", t1, getstr(key)); + else + luaG_runerror(L, "attempt to index %s with %s", t1, t2); +} + +static void pusherror(lua_State* L, const char* msg) +{ + CallInfo* ci = L->ci; + if (isLua(ci)) + { + char buff[LUA_IDSIZE]; /* add file:line information */ + luaO_chunkid(buff, getstr(getluaproto(ci)->source), LUA_IDSIZE); + int line = currentline(L, ci); + luaO_pushfstring(L, "%s:%d: %s", buff, line, msg); + } + else + { + lua_pushstring(L, msg); + } +} + +l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...) +{ + va_list argp; + va_start(argp, fmt); + char result[LUA_BUFFERSIZE]; + vsnprintf(result, sizeof(result), fmt, argp); + va_end(argp); + + pusherror(L, result); + luaD_throw(L, LUA_ERRRUN); +} + +void luaG_pusherror(lua_State* L, const char* error) +{ + pusherror(L, error); +} + +void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) +{ + if (p->lineinfo) + { + for (int i = 0; i < p->sizecode; ++i) + { + // note: we keep prologue as is, instead opting to break at the first meaningful instruction + if (LUAU_INSN_OP(p->code[i]) == LOP_PREPVARARGS) + continue; + + if (luaG_getline(p, i) != line) + continue; + + // lazy copy of the original opcode array; done when the first breakpoint is set + if (!p->debuginsn) + { + p->debuginsn = luaM_newarray(L, p->sizecode, uint8_t, p->memcat); + for (int j = 0; j < p->sizecode; ++j) + p->debuginsn[j] = LUAU_INSN_OP(p->code[j]); + } + + uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]); + + // patch just the opcode byte, leave arguments alone + p->code[i] &= ~0xff; + p->code[i] |= op; + LUAU_ASSERT(LUAU_INSN_OP(p->code[i]) == op); + + // note: this is important! + // we only patch the *first* instruction in each proto that's attributed to a given line + // this can be changed, but if requires making patching a bit more nuanced so that we don't patch AUX words + break; + } + } + + for (int i = 0; i < p->sizep; ++i) + { + luaG_breakpoint(L, p->p[i], line, enable); + } +} + +bool luaG_onbreak(lua_State* L) +{ + if (L->ci == L->base_ci) + return false; + + if (!isLua(L->ci)) + return false; + + return LUAU_INSN_OP(*L->ci->savedpc) == LOP_BREAK; +} + +int luaG_getline(Proto* p, int pc) +{ + LUAU_ASSERT(pc >= 0 && pc < p->sizecode); + + if (!p->lineinfo) + return 0; + + return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; +} + +void lua_singlestep(lua_State* L, bool singlestep) +{ + L->singlestep = singlestep; +} + +void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable) +{ + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); + + luaG_breakpoint(L, clvalue(func)->l.p, line, enable); +} + +static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) +{ + size_t size = strlen(data); + size_t copy = offset + size >= bufsize ? bufsize - offset - 1 : size; + memcpy(buf + offset, data, copy); + return offset + copy; +} + +const char* lua_debugtrace(lua_State* L) +{ + static char buf[4096]; + + const int limit1 = 10; + const int limit2 = 10; + + int depth = int(L->ci - L->base_ci); + size_t offset = 0; + + lua_Debug ar; + for (int level = 0; lua_getinfo(L, level, "sln", &ar); ++level) + { + if (ar.source) + offset = append(buf, sizeof(buf), offset, ar.short_src); + + if (ar.currentline > 0) + { + char line[32]; + sprintf(line, ":%d", ar.currentline); + + offset = append(buf, sizeof(buf), offset, line); + } + + if (ar.name) + { + offset = append(buf, sizeof(buf), offset, " function "); + offset = append(buf, sizeof(buf), offset, ar.name); + } + + offset = append(buf, sizeof(buf), offset, "\n"); + + if (depth > limit1 + limit2 && level == limit1 - 1) + { + char skip[32]; + sprintf(skip, "... (+%d frames)\n", int(depth - limit1 - limit2)); + + offset = append(buf, sizeof(buf), offset, skip); + + level = depth - limit2 - 1; + } + } + + LUAU_ASSERT(offset < sizeof(buf)); + buf[offset] = '\0'; + + return buf; +} diff --git a/VM/src/ldebug.h b/VM/src/ldebug.h new file mode 100644 index 0000000..cf905e9 --- /dev/null +++ b/VM/src/ldebug.h @@ -0,0 +1,28 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lstate.h" + +#define pcRel(pc, p) ((pc) ? cast_to(int, (pc) - (p)->code) - 1 : 0) + +#define luaG_typeerror(L, o, opname) luaG_typeerrorL(L, o, opname) +#define luaG_forerror(L, o, what) luaG_forerrorL(L, o, what) +#define luaG_runerror(L, fmt, ...) luaG_runerrorL(L, fmt, ##__VA_ARGS__) + +#define LUA_MEMERRMSG "not enough memory" +#define LUA_ERRERRMSG "error in error handling" + +LUAI_FUNC l_noret luaG_typeerrorL(lua_State* L, const TValue* o, const char* opname); +LUAI_FUNC l_noret luaG_forerrorL(lua_State* L, const TValue* o, const char* what); +LUAI_FUNC l_noret luaG_concaterror(lua_State* L, StkId p1, StkId p2); +LUAI_FUNC l_noret luaG_aritherror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); +LUAI_FUNC l_noret luaG_ordererror(lua_State* L, const TValue* p1, const TValue* p2, TMS op); +LUAI_FUNC l_noret luaG_indexerror(lua_State* L, const TValue* p1, const TValue* p2); +LUAI_FUNC LUA_PRINTF_ATTR(2, 3) l_noret luaG_runerrorL(lua_State* L, const char* fmt, ...); +LUAI_FUNC void luaG_pusherror(lua_State* L, const char* error); + +LUAI_FUNC void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable); +LUAI_FUNC bool luaG_onbreak(lua_State* L); + +LUAI_FUNC int luaG_getline(Proto* p, int pc); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp new file mode 100644 index 0000000..ee4962a --- /dev/null +++ b/VM/src/ldo.cpp @@ -0,0 +1,554 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "ldo.h" + +#include "lstring.h" +#include "lfunc.h" +#include "lgc.h" +#include "lmem.h" +#include "lvm.h" + +#include + +#include +#include + +LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) + +/* +** {====================================================== +** Error-recovery functions +** ======================================================= +*/ + +#if LUA_USE_LONGJMP +struct lua_jmpbuf +{ + lua_jmpbuf* volatile prev; + volatile int status; + jmp_buf buf; +}; + +int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud) +{ + lua_jmpbuf jb; + jb.prev = L->global->errorjmp; + jb.status = 0; + L->global->errorjmp = &jb; + + if (setjmp(jb.buf) == 0) + f(L, ud); + + L->global->errorjmp = jb.prev; + return jb.status; +} + +l_noret luaD_throw(lua_State* L, int errcode) +{ + if (lua_jmpbuf* jb = L->global->errorjmp) + { + jb->status = errcode; + longjmp(jb->buf, 1); + } + + if (L->global->panic) + L->global->panic(L, errcode); + + abort(); +} +#else +class lua_exception : public std::exception +{ +public: + lua_exception(lua_State* L, int status) + : L(L) + , status(status) + { + } + + const char* what() const throw() override + { + if (FFlag::LuauExceptionMessageFix) + { + // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. + if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) + { + // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. + if (const char* str = lua_tostring(L, -1)) + { + return str; + } + } + + switch (status) + { + case LUA_ERRRUN: + return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + case LUA_ERRSYNTAX: + return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + case LUA_ERRMEM: + return "lua_exception: " LUA_MEMERRMSG; + case LUA_ERRERR: + return "lua_exception: " LUA_ERRERRMSG; + default: + return "lua_exception: unexpected exception status"; + } + } + else + { + return lua_tostring(L, -1); + } + } + + int getStatus() const + { + return status; + } + +private: + lua_State* L; + int status; +}; + +int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud) +{ + int status = 0; + + try + { + f(L, ud); + return 0; + } + catch (lua_exception& e) + { + // lua_exception means that luaD_throw was called and an exception object is on stack if status is ERRRUN + status = e.getStatus(); + } + catch (std::exception& e) + { + // Luau will never throw this, but this can catch exceptions that escape from C++ implementations of external functions + try + { + // there's no exception object on stack; let's push the error on stack so that error handling below can proceed + luaG_pusherror(L, e.what()); + status = LUA_ERRRUN; + } + catch (std::exception&) + { + // out of memory while allocating error string + status = LUA_ERRMEM; + } + } + + return status; +} + +l_noret luaD_throw(lua_State* L, int errcode) +{ + throw lua_exception(L, errcode); +} +#endif + +/* }====================================================== */ + +static void correctstack(lua_State* L, TValue* oldstack) +{ + CallInfo* ci; + GCObject* up; + L->top = (L->top - oldstack) + L->stack; + for (up = L->openupval; up != NULL; up = up->gch.next) + gco2uv(up)->v = (gco2uv(up)->v - oldstack) + L->stack; + for (ci = L->base_ci; ci <= L->ci; ci++) + { + ci->top = (ci->top - oldstack) + L->stack; + ci->base = (ci->base - oldstack) + L->stack; + ci->func = (ci->func - oldstack) + L->stack; + } + L->base = (L->base - oldstack) + L->stack; +} + +void luaD_reallocstack(lua_State* L, int newsize) +{ + TValue* oldstack = L->stack; + int realsize = newsize + 1 + EXTRA_STACK; + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK - 1); + luaM_reallocarray(L, L->stack, L->stacksize, realsize, TValue, L->memcat); + for (int i = L->stacksize; i < realsize; i++) + setnilvalue(L->stack + i); /* erase new segment */ + L->stacksize = realsize; + L->stack_last = L->stack + newsize; + correctstack(L, oldstack); +} + +void luaD_reallocCI(lua_State* L, int newsize) +{ + CallInfo* oldci = L->base_ci; + luaM_reallocarray(L, L->base_ci, L->size_ci, newsize, CallInfo, L->memcat); + L->size_ci = newsize; + L->ci = (L->ci - oldci) + L->base_ci; + L->end_ci = L->base_ci + L->size_ci - 1; +} + +void luaD_growstack(lua_State* L, int n) +{ + if (n <= L->stacksize) /* double size is enough? */ + luaD_reallocstack(L, 2 * L->stacksize); + else + luaD_reallocstack(L, L->stacksize + n); +} + +CallInfo* luaD_growCI(lua_State* L) +{ + if (L->size_ci > LUAI_MAXCALLS) /* overflow while handling overflow? */ + luaD_throw(L, LUA_ERRERR); + else + { + luaD_reallocCI(L, 2 * L->size_ci); + if (L->size_ci > LUAI_MAXCALLS) + luaG_runerror(L, "stack overflow"); + } + return ++L->ci; +} + +/* +** Call a function (C or Lua). The function to be called is at *func. +** The arguments are on the stack, right after the function. +** When returns, all the results are on the stack, starting at the original +** function position. +*/ +void luaD_call(lua_State* L, StkId func, int nResults) +{ + if (++L->nCcalls >= LUAI_MAXCCALLS) + { + if (L->nCcalls == LUAI_MAXCCALLS) + luaG_runerror(L, "C stack overflow"); + else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3))) + luaD_throw(L, LUA_ERRERR); /* error while handing stack error */ + } + if (luau_precall(L, func, nResults) == PCRLUA) + { /* is a Lua function? */ + L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ + luau_execute(L); /* call it */ + } + L->nCcalls--; + luaC_checkGC(L); +} + +static void seterrorobj(lua_State* L, int errcode, StkId oldtop) +{ + switch (errcode) + { + case LUA_ERRMEM: + { + setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_MEMERRMSG)); /* can not fail because string is pinned in luaopen */ + break; + } + case LUA_ERRERR: + { + setsvalue2s(L, oldtop, luaS_newliteral(L, LUA_ERRERRMSG)); /* can not fail because string is pinned in luaopen */ + break; + } + case LUA_ERRSYNTAX: + case LUA_ERRRUN: + { + setobjs2s(L, oldtop, L->top - 1); /* error message on current top */ + break; + } + } + L->top = oldtop + 1; +} + +static void resume_continue(lua_State* L) +{ + // unroll Lua/C combined stack, processing continuations + while (L->status == 0 && L->ci > L->base_ci) + { + LUAU_ASSERT(L->baseCcalls == L->nCcalls); + + Closure* cl = curr_func(L); + + if (cl->isC) + { + LUAU_ASSERT(cl->c.cont); + + // C continuation; we expect this to be followed by Lua continuations + int n = cl->c.cont(L, 0); + + // Continuation can break again + if (L->status == LUA_BREAK) + break; + + luau_poscall(L, L->top - n); + } + else + { + // Lua continuation; it terminates at the end of the stack or at another C continuation + luau_execute(L); + } + } +} + +static void resume(lua_State* L, void* ud) +{ + StkId firstArg = cast_to(StkId, ud); + + if (L->status == 0) + { + // start coroutine + LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base); + if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) + return; + + L->ci->flags |= LUA_CALLINFO_RETURN; + } + else + { + // resume from previous yield or break + LUAU_ASSERT(L->status == LUA_YIELD || L->status == LUA_BREAK); + L->status = 0; + + Closure* cl = curr_func(L); + + if (cl->isC) + { + // if the top stack frame is a C call continuation, resume_continue will handle that case + if (!cl->c.cont) + { + // finish interrupted execution of `OP_CALL' + luau_poscall(L, firstArg); + } + } + else + { + // yielded inside a hook: just continue its execution + L->base = L->ci->base; + } + } + + // run continuations from the stack; typically resumes Lua code and pcalls + resume_continue(L); +} + +static CallInfo* resume_findhandler(lua_State* L) +{ + CallInfo* ci = L->ci; + + while (ci > L->base_ci) + { + if (ci->flags & LUA_CALLINFO_HANDLE) + return ci; + + ci--; + } + + return NULL; +} + +static void resume_handle(lua_State* L, void* ud) +{ + CallInfo* ci = (CallInfo*)ud; + Closure* cl = ci_func(ci); + + LUAU_ASSERT(ci->flags & LUA_CALLINFO_HANDLE); + LUAU_ASSERT(cl->isC && cl->c.cont); + LUAU_ASSERT(L->status != 0); + + // restore nCcalls back to base since this might not have happened during error handling + L->nCcalls = L->baseCcalls; + + // make sure we don't run the handler the second time + ci->flags &= ~LUA_CALLINFO_HANDLE; + + // restore thread status to 0 since we're handling the error + int status = L->status; + L->status = 0; + + // push error object to stack top if it's not already there + if (status != LUA_ERRRUN) + seterrorobj(L, status, L->top); + + // adjust the stack frame for ci to prepare for cont call + L->base = ci->base; + ci->top = L->top; + + // save ci pointer - it will be invalidated by cont call! + ptrdiff_t old_ci = saveci(L, ci); + + // handle the error in continuation; note that this executes on top of original stack! + int n = cl->c.cont(L, status); + + // restore the stack frame to the frame with continuation + L->ci = restoreci(L, old_ci); + + // close eventual pending closures; this means it's now safe to restore stack + luaF_close(L, L->base); + + // finish cont call and restore stack to previous ci top + luau_poscall(L, L->top - n); + + // run remaining continuations from the stack; typically resumes pcalls + resume_continue(L); +} + +static int resume_error(lua_State* L, const char* msg) +{ + L->top = L->ci->base; + setsvalue2s(L, L->top, luaS_new(L, msg)); + incr_top(L); + return LUA_ERRRUN; +} + +static void resume_finish(lua_State* L, int status) +{ + L->nCcalls = L->baseCcalls; + resetbit(L->stackstate, THREAD_ACTIVEBIT); + + if (status != 0) + { /* error? */ + L->status = cast_byte(status); /* mark thread as `dead' */ + seterrorobj(L, status, L->top); + L->ci->top = L->top; + } + else if (L->status == 0) + { + expandstacklimit(L, L->top); + } +} + +int lua_resume(lua_State* L, lua_State* from, int nargs) +{ + int status; + if (L->status != LUA_YIELD && L->status != LUA_BREAK && (L->status != 0 || L->ci != L->base_ci)) + return resume_error(L, "cannot resume non-suspended coroutine"); + + L->nCcalls = from ? from->nCcalls : 0; + if (L->nCcalls >= LUAI_MAXCCALLS) + return resume_error(L, "C stack overflow"); + + L->baseCcalls = ++L->nCcalls; + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + + luaC_checkthreadsleep(L); + + status = luaD_rawrunprotected(L, resume, L->top - nargs); + + CallInfo* ch = NULL; + while (status != 0 && (ch = resume_findhandler(L)) != NULL) + { + L->status = cast_byte(status); + status = luaD_rawrunprotected(L, resume_handle, ch); + } + + resume_finish(L, status); + --L->nCcalls; + return L->status; +} + +int lua_resumeerror(lua_State* L, lua_State* from) +{ + int status; + if (L->status != LUA_YIELD && L->status != LUA_BREAK && (L->status != 0 || L->ci != L->base_ci)) + return resume_error(L, "cannot resume non-suspended coroutine"); + + L->nCcalls = from ? from->nCcalls : 0; + if (L->nCcalls >= LUAI_MAXCCALLS) + return resume_error(L, "C stack overflow"); + + L->baseCcalls = ++L->nCcalls; + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + + luaC_checkthreadsleep(L); + + status = LUA_ERRRUN; + + CallInfo* ch = NULL; + while (status != 0 && (ch = resume_findhandler(L)) != NULL) + { + L->status = cast_byte(status); + status = luaD_rawrunprotected(L, resume_handle, ch); + } + + resume_finish(L, status); + --L->nCcalls; + return L->status; +} + +int lua_yield(lua_State* L, int nresults) +{ + if (L->nCcalls > L->baseCcalls) + luaG_runerror(L, "attempt to yield across metamethod/C-call boundary"); + L->base = L->top - nresults; /* protect stack slots below */ + L->status = LUA_YIELD; + return -1; +} + +int lua_break(lua_State* L) +{ + if (L->nCcalls > L->baseCcalls) + luaG_runerror(L, "attempt to break across metamethod/C-call boundary"); + L->status = LUA_BREAK; + return -1; +} + +int lua_isyieldable(lua_State* L) +{ + return (L->nCcalls <= L->baseCcalls); +} + +static void callerrfunc(lua_State* L, void* ud) +{ + StkId errfunc = cast_to(StkId, ud); + + setobjs2s(L, L->top, L->top - 1); + setobjs2s(L, L->top - 1, errfunc); + incr_top(L); + luaD_call(L, L->top - 2, 1); +} + +static void restore_stack_limit(lua_State* L) +{ + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK - 1); + if (L->size_ci > LUAI_MAXCALLS) + { /* there was an overflow? */ + int inuse = cast_int(L->ci - L->base_ci); + if (inuse + 1 < LUAI_MAXCALLS) /* can `undo' overflow? */ + luaD_reallocCI(L, LUAI_MAXCALLS); + } +} + +int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t ef) +{ + int status; + unsigned short oldnCcalls = L->nCcalls; + ptrdiff_t old_ci = saveci(L, L->ci); + status = luaD_rawrunprotected(L, func, u); + if (status != 0) + { + // call user-defined error function (used in xpcall) + if (ef) + { + // if errfunc fails, we fail with "error in error handling" + if (luaD_rawrunprotected(L, callerrfunc, restorestack(L, ef)) != 0) + status = LUA_ERRERR; + } + + // an error occured, check if we have a protected error callback + if (L->global->cb.debugprotectederror) + { + L->global->cb.debugprotectederror(L); + + // debug hook is only allowed to break + if (L->status == LUA_BREAK) + return 0; + } + + StkId oldtop = restorestack(L, old_top); + luaF_close(L, oldtop); /* close eventual pending closures */ + seterrorobj(L, status, oldtop); + L->nCcalls = oldnCcalls; + L->ci = restoreci(L, old_ci); + L->base = L->ci->base; + restore_stack_limit(L); + } + return status; +} diff --git a/VM/src/ldo.h b/VM/src/ldo.h new file mode 100644 index 0000000..4fe1c34 --- /dev/null +++ b/VM/src/ldo.h @@ -0,0 +1,54 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" +#include "lstate.h" +#include "luaconf.h" +#include "ldebug.h" + +#define luaD_checkstack(L, n) \ + if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ + luaD_growstack(L, n); \ + else \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK - 1)); + +#define incr_top(L) \ + { \ + luaD_checkstack(L, 1); \ + L->top++; \ + } + +#define savestack(L, p) ((char*)(p) - (char*)L->stack) +#define restorestack(L, n) ((TValue*)((char*)L->stack + (n))) + +#define expandstacklimit(L, p) \ + { \ + LUAU_ASSERT((p) <= (L)->stack_last); \ + if ((L)->ci->top < (p)) \ + (L)->ci->top = (p); \ + } + +#define incr_ci(L) ((L->ci == L->end_ci) ? luaD_growCI(L) : (condhardstacktests(luaD_reallocCI(L, L->size_ci)), ++L->ci)) + +#define saveci(L, p) ((char*)(p) - (char*)L->base_ci) +#define restoreci(L, n) ((CallInfo*)((char*)L->base_ci + (n))) + +/* results from luaD_precall */ +#define PCRLUA 0 /* initiated a call to a Lua function */ +#define PCRC 1 /* did a call to a C function */ +#define PCRYIELD 2 /* C funtion yielded */ + +/* type of protected functions, to be ran by `runprotected' */ +typedef void (*Pfunc)(lua_State* L, void* ud); + +LUAI_FUNC CallInfo* luaD_growCI(lua_State* L); + +LUAI_FUNC void luaD_call(lua_State* L, StkId func, int nResults); +LUAI_FUNC int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t oldtop, ptrdiff_t ef); +LUAI_FUNC void luaD_reallocCI(lua_State* L, int newsize); +LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize); +LUAI_FUNC void luaD_growstack(lua_State* L, int n); + +LUAI_FUNC l_noret luaD_throw(lua_State* L, int errcode); +LUAI_FUNC int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp new file mode 100644 index 0000000..0b54302 --- /dev/null +++ b/VM/src/lfunc.cpp @@ -0,0 +1,167 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lfunc.h" + +#include "lstate.h" +#include "lmem.h" +#include "lgc.h" + +Proto* luaF_newproto(lua_State* L) +{ + Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); + luaC_link(L, f, LUA_TPROTO); + f->k = NULL; + f->sizek = 0; + f->p = NULL; + f->sizep = 0; + f->code = NULL; + f->sizecode = 0; + f->sizeupvalues = 0; + f->nups = 0; + f->upvalues = NULL; + f->numparams = 0; + f->is_vararg = 0; + f->maxstacksize = 0; + f->sizelineinfo = 0; + f->linegaplog2 = 0; + f->lineinfo = NULL; + f->abslineinfo = NULL; + f->sizelocvars = 0; + f->locvars = NULL; + f->source = NULL; + f->debugname = NULL; + f->debuginsn = NULL; + return f; +} + +Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) +{ + Closure* c = luaM_new(L, Closure, sizeLclosure(nelems), L->activememcat); + luaC_link(L, c, LUA_TFUNCTION); + c->isC = 0; + c->env = e; + c->nupvalues = cast_byte(nelems); + c->stacksize = p->maxstacksize; + c->preload = 0; + c->l.p = p; + for (int i = 0; i < nelems; ++i) + setnilvalue(&c->l.uprefs[i]); + return c; +} + +Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) +{ + Closure* c = luaM_new(L, Closure, sizeCclosure(nelems), L->activememcat); + luaC_link(L, c, LUA_TFUNCTION); + c->isC = 1; + c->env = e; + c->nupvalues = cast_byte(nelems); + c->stacksize = LUA_MINSTACK; + c->preload = 0; + c->c.f = NULL; + c->c.cont = NULL; + c->c.debugname = NULL; + return c; +} + +UpVal* luaF_findupval(lua_State* L, StkId level) +{ + global_State* g = L->global; + GCObject** pp = &L->openupval; + UpVal* p; + UpVal* uv; + while (*pp != NULL && (p = gco2uv(*pp))->v >= level) + { + LUAU_ASSERT(p->v != &p->u.value); + if (p->v == level) + { /* found a corresponding upvalue? */ + if (isdead(g, obj2gco(p))) /* is it dead? */ + changewhite(obj2gco(p)); /* ressurect it */ + return p; + } + pp = &p->next; + } + uv = luaM_new(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ + uv->tt = LUA_TUPVAL; + uv->marked = luaC_white(g); + uv->memcat = L->activememcat; + uv->v = level; /* current value lives in the stack */ + uv->next = *pp; /* chain it in the proper position */ + *pp = obj2gco(uv); + uv->u.l.prev = &g->uvhead; /* double link it in `uvhead' list */ + uv->u.l.next = g->uvhead.u.l.next; + uv->u.l.next->u.l.prev = uv; + g->uvhead.u.l.next = uv; + LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + return uv; +} + +static void unlinkupval(UpVal* uv) +{ + LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + uv->u.l.next->u.l.prev = uv->u.l.prev; /* remove from `uvhead' list */ + uv->u.l.prev->u.l.next = uv->u.l.next; +} + +void luaF_freeupval(lua_State* L, UpVal* uv) +{ + if (uv->v != &uv->u.value) /* is it open? */ + unlinkupval(uv); /* remove from open list */ + luaM_free(L, uv, sizeof(UpVal), uv->memcat); /* free upvalue */ +} + +void luaF_close(lua_State* L, StkId level) +{ + UpVal* uv; + global_State* g = L->global; + while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) + { + GCObject* o = obj2gco(uv); + LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); + L->openupval = uv->next; /* remove from `open' list */ + if (isdead(g, o)) + luaF_freeupval(L, uv); /* free upvalue */ + else + { + unlinkupval(uv); + setobj(L, &uv->u.value, uv->v); + uv->v = &uv->u.value; /* now current value lives here */ + luaC_linkupval(L, uv); /* link upvalue into `gcroot' list */ + } + } +} + +void luaF_freeproto(lua_State* L, Proto* f) +{ + luaM_freearray(L, f->code, f->sizecode, Instruction, f->memcat); + luaM_freearray(L, f->p, f->sizep, Proto*, f->memcat); + luaM_freearray(L, f->k, f->sizek, TValue, f->memcat); + if (f->lineinfo) + luaM_freearray(L, f->lineinfo, f->sizelineinfo, uint8_t, f->memcat); + luaM_freearray(L, f->locvars, f->sizelocvars, struct LocVar, f->memcat); + luaM_freearray(L, f->upvalues, f->sizeupvalues, TString*, f->memcat); + if (f->debuginsn) + luaM_freearray(L, f->debuginsn, f->sizecode, uint8_t, f->memcat); + luaM_free(L, f, sizeof(Proto), f->memcat); +} + +void luaF_freeclosure(lua_State* L, Closure* c) +{ + int size = c->isC ? sizeCclosure(c->nupvalues) : sizeLclosure(c->nupvalues); + luaM_free(L, c, size, c->memcat); +} + +const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc) +{ + int i; + for (i = 0; i < f->sizelocvars; i++) + { + if (pc >= f->locvars[i].startpc && pc < f->locvars[i].endpc) + { /* is variable active? */ + local_number--; + if (local_number == 0) + return &f->locvars[i]; + } + } + return NULL; /* not found */ +} diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h new file mode 100644 index 0000000..4be2366 --- /dev/null +++ b/VM/src/lfunc.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +#define sizeCclosure(n) (offsetof(Closure, c.upvals) + sizeof(TValue) * (n)) +#define sizeLclosure(n) (offsetof(Closure, l.uprefs) + sizeof(TValue) * (n)) + +LUAI_FUNC Proto* luaF_newproto(lua_State* L); +LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p); +LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e); +LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); +LUAI_FUNC void luaF_close(lua_State* L, StkId level); +LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f); +LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c); +LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv); +LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp new file mode 100644 index 0000000..9b040fb --- /dev/null +++ b/VM/src/lgc.cpp @@ -0,0 +1,1696 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lgc.h" + +#include "lobject.h" +#include "lstate.h" +#include "ltable.h" +#include "lfunc.h" +#include "lstring.h" +#include "ldo.h" + +#include +#include + +LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) +LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) +LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) +LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) +LUAU_FASTFLAG(LuauArrayBoundary) + +#define GC_SWEEPMAX 40 +#define GC_SWEEPCOST 10 + +#define GC_INTERRUPT(state) \ + { \ + void (*interrupt)(lua_State*, int) = g->cb.interrupt; \ + if (LUAU_UNLIKELY(!!interrupt)) \ + interrupt(L, state); \ + } + +#define maskmarks cast_byte(~(bitmask(BLACKBIT) | WHITEBITS)) + +#define makewhite(g, x) ((x)->gch.marked = cast_byte(((x)->gch.marked & maskmarks) | luaC_white(g))) + +#define white2gray(x) reset2bits((x)->gch.marked, WHITE0BIT, WHITE1BIT) +#define black2gray(x) resetbit((x)->gch.marked, BLACKBIT) + +#define stringmark(s) reset2bits((s)->marked, WHITE0BIT, WHITE1BIT) + +#define markvalue(g, o) \ + { \ + checkconsistency(o); \ + if (iscollectable(o) && iswhite(gcvalue(o))) \ + reallymarkobject(g, gcvalue(o)); \ + } + +#define markobject(g, t) \ + { \ + if (iswhite(obj2gco(t))) \ + reallymarkobject(g, obj2gco(t)); \ + } + +static void recordGcStateTime(global_State* g, int startgcstate, double seconds, bool assist) +{ + switch (startgcstate) + { + case GCSpause: + // record root mark time if we have switched to next state + if (g->gcstate == GCSpropagate) + g->gcstats.currcycle.marktime += seconds; + break; + case GCSpropagate: + case GCSpropagateagain: + g->gcstats.currcycle.marktime += seconds; + + // atomic step had to be performed during the switch and it's tracked separately + if (g->gcstate == GCSsweepstring) + g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; + break; + case GCSsweepstring: + case GCSsweep: + g->gcstats.currcycle.sweeptime += seconds; + break; + } + + if (assist) + g->gcstats.stepassisttimeacc += seconds; + else + g->gcstats.stepexplicittimeacc += seconds; +} + +static void startGcCycleStats(global_State* g) +{ + g->gcstats.currcycle.starttimestamp = lua_clock(); + g->gcstats.currcycle.waittime = g->gcstats.currcycle.starttimestamp - g->gcstats.lastcycle.endtimestamp; +} + +static void finishGcCycleStats(global_State* g) +{ + g->gcstats.currcycle.endtimestamp = lua_clock(); + g->gcstats.currcycle.endtotalsizebytes = g->totalbytes; + + g->gcstats.completedcycles++; + g->gcstats.lastcycle = g->gcstats.currcycle; + g->gcstats.currcycle = GCCycleStats(); + + g->gcstats.cyclestatsacc.markitems += g->gcstats.lastcycle.markitems; + g->gcstats.cyclestatsacc.marktime += g->gcstats.lastcycle.marktime; + g->gcstats.cyclestatsacc.atomictime += g->gcstats.lastcycle.atomictime; + g->gcstats.cyclestatsacc.sweepitems += g->gcstats.lastcycle.sweepitems; + g->gcstats.cyclestatsacc.sweeptime += g->gcstats.lastcycle.sweeptime; +} + +static void removeentry(LuaNode* n) +{ + LUAU_ASSERT(ttisnil(gval(n))); + if (iscollectable(gkey(n))) + setttype(gkey(n), LUA_TDEADKEY); /* dead key; remove it */ +} + +static void reallymarkobject(global_State* g, GCObject* o) +{ + LUAU_ASSERT(iswhite(o) && !isdead(g, o)); + white2gray(o); + switch (o->gch.tt) + { + case LUA_TSTRING: + { + return; + } + case LUA_TUSERDATA: + { + Table* mt = gco2u(o)->metatable; + gray2black(o); /* udata are never gray */ + if (mt) + markobject(g, mt); + return; + } + case LUA_TUPVAL: + { + UpVal* uv = gco2uv(o); + markvalue(g, uv->v); + if (uv->v == &uv->u.value) /* closed? */ + gray2black(o); /* open upvalues are never black */ + return; + } + case LUA_TFUNCTION: + { + gco2cl(o)->gclist = g->gray; + g->gray = o; + break; + } + case LUA_TTABLE: + { + gco2h(o)->gclist = g->gray; + g->gray = o; + break; + } + case LUA_TTHREAD: + { + gco2th(o)->gclist = g->gray; + g->gray = o; + break; + } + case LUA_TPROTO: + { + gco2p(o)->gclist = g->gray; + g->gray = o; + break; + } + default: + LUAU_ASSERT(0); + } +} + +static const char* gettablemode(global_State* g, Table* h) +{ + const TValue* mode = gfasttm(g, h->metatable, TM_MODE); + + if (mode && ttisstring(mode)) + return svalue(mode); + + return NULL; +} + +static int traversetable(global_State* g, Table* h) +{ + int i; + int weakkey = 0; + int weakvalue = 0; + if (h->metatable) + markobject(g, cast_to(Table*, h->metatable)); + + if (FFlag::LuauShrinkWeakTables) + { + /* is there a weak mode? */ + if (const char* modev = gettablemode(g, h)) + { + weakkey = (strchr(modev, 'k') != NULL); + weakvalue = (strchr(modev, 'v') != NULL); + if (weakkey || weakvalue) + { /* is really weak? */ + h->gclist = g->weak; /* must be cleared after GC, ... */ + g->weak = obj2gco(h); /* ... so put in the appropriate list */ + } + } + } + else + { + const TValue* mode = gfasttm(g, h->metatable, TM_MODE); + if (mode && ttisstring(mode)) + { /* is there a weak mode? */ + const char* modev = svalue(mode); + weakkey = (strchr(modev, 'k') != NULL); + weakvalue = (strchr(modev, 'v') != NULL); + if (weakkey || weakvalue) + { /* is really weak? */ + h->gclist = g->weak; /* must be cleared after GC, ... */ + g->weak = obj2gco(h); /* ... so put in the appropriate list */ + } + } + } + + if (weakkey && weakvalue) + return 1; + if (!weakvalue) + { + i = h->sizearray; + while (i--) + markvalue(g, &h->array[i]); + } + i = sizenode(h); + while (i--) + { + LuaNode* n = gnode(h, i); + LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); + if (ttisnil(gval(n))) + removeentry(n); /* remove empty entries */ + else + { + LUAU_ASSERT(!ttisnil(gkey(n))); + if (!weakkey) + markvalue(g, gkey(n)); + if (!weakvalue) + markvalue(g, gval(n)); + } + } + return weakkey || weakvalue; +} + +/* +** All marks are conditional because a GC may happen while the +** prototype is still being created +*/ +static void traverseproto(global_State* g, Proto* f) +{ + int i; + if (f->source) + stringmark(f->source); + if (f->debugname) + stringmark(f->debugname); + for (i = 0; i < f->sizek; i++) /* mark literals */ + markvalue(g, &f->k[i]); + for (i = 0; i < f->sizeupvalues; i++) + { /* mark upvalue names */ + if (f->upvalues[i]) + stringmark(f->upvalues[i]); + } + for (i = 0; i < f->sizep; i++) + { /* mark nested protos */ + if (f->p[i]) + markobject(g, f->p[i]); + } + for (i = 0; i < f->sizelocvars; i++) + { /* mark local-variable names */ + if (f->locvars[i].varname) + stringmark(f->locvars[i].varname); + } +} + +static void traverseclosure(global_State* g, Closure* cl) +{ + markobject(g, cl->env); + if (cl->isC) + { + int i; + for (i = 0; i < cl->nupvalues; i++) /* mark its upvalues */ + markvalue(g, &cl->c.upvals[i]); + } + else + { + int i; + LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); + markobject(g, cast_to(Proto*, cl->l.p)); + for (i = 0; i < cl->nupvalues; i++) /* mark its upvalues */ + markvalue(g, &cl->l.uprefs[i]); + } +} + +static void traversestack(global_State* g, lua_State* l, bool clearstack) +{ + markvalue(g, gt(l)); + if (l->namecall) + stringmark(l->namecall); + for (StkId o = l->stack; o < l->top; o++) + markvalue(g, o); + /* final traversal? */ + if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack)) + { + StkId stack_end = l->stack + l->stacksize; + for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */ + setnilvalue(o); + } +} + +/* +** traverse one gray object, turning it to black. +** Returns `quantity' traversed. +*/ +static size_t propagatemark(global_State* g) +{ + GCObject* o = g->gray; + LUAU_ASSERT(isgray(o)); + gray2black(o); + switch (o->gch.tt) + { + case LUA_TTABLE: + { + Table* h = gco2h(o); + g->gray = h->gclist; + if (traversetable(g, h)) /* table is weak? */ + black2gray(o); /* keep it gray */ + return sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + } + case LUA_TFUNCTION: + { + Closure* cl = gco2cl(o); + g->gray = cl->gclist; + traverseclosure(g, cl); + return cl->isC ? sizeCclosure(cl->nupvalues) : sizeLclosure(cl->nupvalues); + } + case LUA_TTHREAD: + { + lua_State* th = gco2th(o); + g->gray = th->gclist; + + if (FFlag::LuauGcFullSkipInactiveThreads) + { + LUAU_ASSERT(!luaC_threadsleeping(th)); + + // threads that are executing and the main thread are not deactivated + bool active = luaC_threadactive(th) || th == th->global->mainthread; + + if (!active && g->gcstate == GCSpropagate) + { + traversestack(g, th, /* clearstack= */ true); + + l_setbit(th->stackstate, THREAD_SLEEPINGBIT); + } + else + { + th->gclist = g->grayagain; + g->grayagain = o; + + black2gray(o); + + traversestack(g, th, /* clearstack= */ false); + } + } + else + { + th->gclist = g->grayagain; + g->grayagain = o; + + black2gray(o); + + traversestack(g, th, /* clearstack= */ false); + } + + return sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; + } + case LUA_TPROTO: + { + Proto* p = gco2p(o); + g->gray = p->gclist; + traverseproto(g, p); + return sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; + } + default: + LUAU_ASSERT(0); + return 0; + } +} + +static void propagateall(global_State* g) +{ + while (g->gray) + { + propagatemark(g); + } +} + +/* +** The next function tells whether a key or value can be cleared from +** a weak table. Non-collectable objects are never removed from weak +** tables. Strings behave as `values', so are never removed too. for +** other objects: if really collected, cannot keep them. +*/ +static int isobjcleared(GCObject* o) +{ + if (o->gch.tt == LUA_TSTRING) + { + stringmark(&o->ts); /* strings are `values', so are never weak */ + return 0; + } + + return iswhite(o); +} + +#define iscleared(o) (iscollectable(o) && isobjcleared(gcvalue(o))) + +/* +** clear collected entries from weaktables +*/ +static void cleartable(lua_State* L, GCObject* l) +{ + while (l) + { + Table* h = gco2h(l); + int i = h->sizearray; + while (i--) + { + TValue* o = &h->array[i]; + if (iscleared(o)) /* value was collected? */ + setnilvalue(o); /* remove value */ + } + i = sizenode(h); + int activevalues = 0; + while (i--) + { + LuaNode* n = gnode(h, i); + + if (FFlag::LuauShrinkWeakTables) + { + // non-empty entry? + if (!ttisnil(gval(n))) + { + // can we clear key or value? + if (iscleared(gkey(n)) || iscleared(gval(n))) + { + setnilvalue(gval(n)); /* remove value ... */ + removeentry(n); /* remove entry from table */ + } + else + { + activevalues++; + } + } + } + else + { + if (!ttisnil(gval(n)) && /* non-empty entry? */ + (iscleared(gkey(n)) || iscleared(gval(n)))) + { + setnilvalue(gval(n)); /* remove value ... */ + removeentry(n); /* remove entry from table */ + } + } + } + + if (FFlag::LuauShrinkWeakTables) + { + if (const char* modev = gettablemode(L->global, h)) + { + // are we allowed to shrink this weak table? + if (strchr(modev, 's')) + { + // shrink at 37.5% occupancy + if (activevalues < sizenode(h) * 3 / 8) + luaH_resizehash(L, h, activevalues); + } + } + } + + l = h->gclist; + } +} + +static void shrinkstack(lua_State* L) +{ + /* compute used stack - note that we can't use th->top if we're in the middle of vararg call */ + StkId lim = L->top; + for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) + { + LUAU_ASSERT(ci->top <= L->stack_last); + if (lim < ci->top) + lim = ci->top; + } + + /* shrink stack and callinfo arrays if we aren't using most of the space */ + int ci_used = cast_int(L->ci - L->base_ci); /* number of `ci' in use */ + int s_used = cast_int(lim - L->stack); /* part of stack in use */ + if (L->size_ci > LUAI_MAXCALLS) /* handling overflow? */ + return; /* do not touch the stacks */ + if (3 * ci_used < L->size_ci && 2 * BASIC_CI_SIZE < L->size_ci) + luaD_reallocCI(L, L->size_ci / 2); /* still big enough... */ + condhardstacktests(luaD_reallocCI(L, ci_used + 1)); + if (3 * s_used < L->stacksize && 2 * (BASIC_STACK_SIZE + EXTRA_STACK) < L->stacksize) + luaD_reallocstack(L, L->stacksize / 2); /* still big enough... */ + condhardstacktests(luaD_reallocstack(L, s_used)); +} + +static void freeobj(lua_State* L, GCObject* o) +{ + switch (o->gch.tt) + { + case LUA_TPROTO: + luaF_freeproto(L, gco2p(o)); + break; + case LUA_TFUNCTION: + luaF_freeclosure(L, gco2cl(o)); + break; + case LUA_TUPVAL: + luaF_freeupval(L, gco2uv(o)); + break; + case LUA_TTABLE: + luaH_free(L, gco2h(o)); + break; + case LUA_TTHREAD: + LUAU_ASSERT(gco2th(o) != L && gco2th(o) != L->global->mainthread); + luaE_freethread(L, gco2th(o)); + break; + case LUA_TSTRING: + luaS_free(L, gco2ts(o)); + break; + case LUA_TUSERDATA: + luaS_freeudata(L, gco2u(o)); + break; + default: + LUAU_ASSERT(0); + } +} + +#define sweepwholelist(L, p, tc) sweeplist(L, p, SIZE_MAX, tc) + +static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* traversedcount) +{ + GCObject* curr; + global_State* g = L->global; + int deadmask = otherwhite(g); + size_t startcount = count; + LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); /* make sure we never sweep fixed objects */ + while ((curr = *p) != NULL && count-- > 0) + { + int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; + if (curr->gch.tt == LUA_TTHREAD) + { + sweepwholelist(L, &gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ + + lua_State* th = gco2th(curr); + + if (alive) + { + resetbit(th->stackstate, THREAD_SLEEPINGBIT); + shrinkstack(th); + } + } + if (alive) + { /* not dead? */ + LUAU_ASSERT(!isdead(g, curr)); + makewhite(g, curr); /* make it white (for next cycle) */ + p = &curr->gch.next; + } + else + { /* must erase `curr' */ + LUAU_ASSERT(isdead(g, curr)); + *p = curr->gch.next; + if (curr == g->rootgc) /* is the first element of the list? */ + g->rootgc = curr->gch.next; /* adjust first */ + freeobj(L, curr); + } + } + + // if we didn't reach the end of the list it means that we've stopped because the count dropped below zero + if (traversedcount) + *traversedcount += startcount - (curr ? count + 1 : count); + + return p; +} + +static void deletelist(lua_State* L, GCObject** p, GCObject* limit) +{ + GCObject* curr; + while ((curr = *p) != limit) + { + if (curr->gch.tt == LUA_TTHREAD) /* delete open upvalues of each thread */ + deletelist(L, &gco2th(curr)->openupval, NULL); + + *p = curr->gch.next; + freeobj(L, curr); + } +} + +static void shrinkbuffers(lua_State* L) +{ + global_State* g = L->global; + /* check size of string hash */ + if (g->strt.nuse < cast_to(uint32_t, g->strt.size / 4) && g->strt.size > LUA_MINSTRTABSIZE * 2) + luaS_resize(L, g->strt.size / 2); /* table is too big */ +} + +static void shrinkbuffersfull(lua_State* L) +{ + global_State* g = L->global; + /* check size of string hash */ + int hashsize = g->strt.size; + while (g->strt.nuse < cast_to(uint32_t, hashsize / 4) && hashsize > LUA_MINSTRTABSIZE * 2) + hashsize /= 2; + if (hashsize != g->strt.size) + luaS_resize(L, hashsize); /* table is too big */ +} + +void luaC_freeall(lua_State* L) +{ + global_State* g = L->global; + + LUAU_ASSERT(L == g->mainthread); + LUAU_ASSERT(L->next == NULL); /* mainthread is at the end of rootgc list */ + + deletelist(L, &g->rootgc, obj2gco(L)); + + for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + deletelist(L, &g->strt.hash[i], NULL); + + LUAU_ASSERT(L->global->strt.nuse == 0); + deletelist(L, &g->strbufgc, NULL); + // unfortunately, when string objects are freed, the string table use count is decremented + // even when the string is a buffer that wasn't placed into the table + L->global->strt.nuse = 0; +} + +static void markmt(global_State* g) +{ + int i; + for (i = 0; i < LUA_T_COUNT; i++) + if (g->mt[i]) + markobject(g, g->mt[i]); +} + +/* mark root set */ +static void markroot(lua_State* L) +{ + global_State* g = L->global; + g->gray = NULL; + g->grayagain = NULL; + g->weak = NULL; + markobject(g, g->mainthread); + /* make global table be traversed before main stack */ + markvalue(g, gt(g->mainthread)); + markvalue(g, registry(L)); + markmt(g); + g->gcstate = GCSpropagate; +} + +static void remarkupvals(global_State* g) +{ + UpVal* uv; + for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + { + LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + if (isgray(obj2gco(uv))) + markvalue(g, uv->v); + } +} + +static void atomic(lua_State* L) +{ + global_State* g = L->global; + g->gcstate = GCSatomic; + /* remark occasional upvalues of (maybe) dead threads */ + remarkupvals(g); + /* traverse objects caught by write barrier and by 'remarkupvals' */ + propagateall(g); + /* remark weak tables */ + g->gray = g->weak; + g->weak = NULL; + LUAU_ASSERT(!iswhite(obj2gco(g->mainthread))); + markobject(g, L); /* mark running thread */ + markmt(g); /* mark basic metatables (again) */ + propagateall(g); + /* remark gray again */ + g->gray = g->grayagain; + g->grayagain = NULL; + propagateall(g); + cleartable(L, g->weak); /* remove collected objects from weak tables */ + g->weak = NULL; + /* flip current white */ + g->currentwhite = cast_byte(otherwhite(g)); + g->sweepstrgc = 0; + g->sweepgc = &g->rootgc; + g->gcstate = GCSsweepstring; + + GC_INTERRUPT(GCSatomic); +} + +static size_t singlestep(lua_State* L) +{ + size_t cost = 0; + global_State* g = L->global; + switch (g->gcstate) + { + case GCSpause: + { + markroot(L); /* start a new collection */ + break; + } + case GCSpropagate: + { + if (FFlag::LuauRescanGrayAgain) + { + if (g->gray) + { + g->gcstats.currcycle.markitems++; + + cost = propagatemark(g); + } + else + { + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; + + g->gcstate = GCSpropagateagain; + } + } + else + { + if (g->gray) + { + g->gcstats.currcycle.markitems++; + + cost = propagatemark(g); + } + else /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + } + break; + } + case GCSpropagateagain: + { + if (g->gray) + { + g->gcstats.currcycle.markitems++; + + cost = propagatemark(g); + } + else /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + break; + } + case GCSsweepstring: + { + size_t traversedcount = 0; + sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); + + // nothing more to sweep? + if (g->sweepstrgc >= g->strt.size) + { + // sweep string buffer list and preserve used string count + uint32_t nuse = L->global->strt.nuse; + sweepwholelist(L, &g->strbufgc, &traversedcount); + L->global->strt.nuse = nuse; + + g->gcstate = GCSsweep; // end sweep-string phase + } + + g->gcstats.currcycle.sweepitems += traversedcount; + + cost = GC_SWEEPCOST; + break; + } + case GCSsweep: + { + size_t traversedcount = 0; + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + + if (*g->sweepgc == NULL) + { /* nothing more to sweep? */ + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } + cost = GC_SWEEPMAX * GC_SWEEPCOST; + break; + } + default: + LUAU_ASSERT(0); + } + + return cost; +} + +static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats) +{ + // adjust for error using Proportional-Integral controller + // https://en.wikipedia.org/wiki/PID_controller + int32_t errorKb = int32_t((cyclestats->atomicstarttotalsizebytes - cyclestats->heapgoalsizebytes) / 1024); + + // we use sliding window for the error integral to avoid error sum 'windup' when the desired target cannot be reached + int32_t* slot = &triggerstats->terms[triggerstats->termpos % triggerstats->termcount]; + int32_t prev = *slot; + *slot = errorKb; + triggerstats->integral += errorKb - prev; + triggerstats->termpos++; + + // controller tuning + // https://en.wikipedia.org/wiki/Ziegler%E2%80%93Nichols_method + const double Ku = 0.9; // ultimate gain (measured) + const double Tu = 2.5; // oscillation period (measured) + + const double Kp = 0.45 * Ku; // proportional gain + const double Ti = 0.8 * Tu; + const double Ki = 0.54 * Ku / Ti; // integral gain + + double proportionalTerm = Kp * errorKb; + double integralTerm = Ki * triggerstats->integral; + + double totalTerm = proportionalTerm + integralTerm; + + return int64_t(totalTerm * 1024); +} + +static size_t getheaptrigger(global_State* g, size_t heapgoal) +{ + GCCycleStats* lastcycle = &g->gcstats.lastcycle; + GCCycleStats* currcycle = &g->gcstats.currcycle; + + // adjust threshold based on a guess of how many bytes will be allocated between the cycle start and sweep phase + // our goal is to begin the sweep when used memory has reached the heap goal + const double durationthreshold = 1e-3; + double allocationduration = currcycle->atomicstarttimestamp - lastcycle->endtimestamp; + + // avoid measuring intervals smaller than 1ms + if (allocationduration < durationthreshold) + return heapgoal; + + double allocationrate = (currcycle->atomicstarttotalsizebytes - lastcycle->endtotalsizebytes) / allocationduration; + double markduration = currcycle->atomicstarttimestamp - currcycle->starttimestamp; + + int64_t expectedgrowth = int64_t(markduration * allocationrate); + int64_t offset = getheaptriggererroroffset(&g->gcstats.triggerstats, currcycle); + int64_t heaptrigger = heapgoal - (expectedgrowth + offset); + + // clamp the trigger between memory use at the end of the cycle and the heap goal + return heaptrigger < int64_t(g->totalbytes) ? g->totalbytes : (heaptrigger > int64_t(heapgoal) ? heapgoal : size_t(heaptrigger)); +} + +void luaC_step(lua_State* L, bool assist) +{ + global_State* g = L->global; + ptrdiff_t lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + LUAU_ASSERT(g->totalbytes >= g->GCthreshold); + size_t debt = g->totalbytes - g->GCthreshold; + + GC_INTERRUPT(0); + + // at the start of the new cycle + if (g->gcstate == GCSpause) + startGcCycleStats(g); + + if (assist) + g->gcstats.currcycle.assistwork += lim; + else + g->gcstats.currcycle.explicitwork += lim; + + int lastgcstate = g->gcstate; + double lastttimestamp = lua_clock(); + + // always perform at least one single step + do + { + lim -= singlestep(L); + + // if we have switched to a different state, capture the duration of last stage + // this way we reduce the number of timer calls we make + if (lastgcstate != g->gcstate) + { + GC_INTERRUPT(lastgcstate); + + double now = lua_clock(); + + recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + + lastttimestamp = now; + lastgcstate = g->gcstate; + } + } while (lim > 0 && g->gcstate != GCSpause); + + recordGcStateTime(g, lastgcstate, lua_clock() - lastttimestamp, assist); + + // at the end of the last cycle + if (g->gcstate == GCSpause) + { + // at the end of a collection cycle, set goal based on gcgoal setting + size_t heapgoal = (g->totalbytes / 100) * g->gcgoal; + size_t heaptrigger = getheaptrigger(g, heapgoal); + + g->GCthreshold = heaptrigger; + + finishGcCycleStats(g); + + g->gcstats.currcycle.heapgoalsizebytes = heapgoal; + g->gcstats.currcycle.heaptriggersizebytes = heaptrigger; + } + else + { + g->GCthreshold = g->totalbytes + g->gcstepsize; + + // compensate if GC is "behind schedule" (has some debt to pay) + if (g->GCthreshold > debt) + g->GCthreshold -= debt; + } + + GC_INTERRUPT(g->gcstate); +} + +void luaC_fullgc(lua_State* L) +{ + global_State* g = L->global; + + if (g->gcstate == GCSpause) + startGcCycleStats(g); + + if (g->gcstate <= GCSpropagateagain) + { + /* reset sweep marks to sweep all elements (returning them to white) */ + g->sweepstrgc = 0; + g->sweepgc = &g->rootgc; + /* reset other collector lists */ + g->gray = NULL; + g->grayagain = NULL; + g->weak = NULL; + g->gcstate = GCSsweepstring; + } + LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain); + /* finish any pending sweep phase */ + while (g->gcstate != GCSpause) + { + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + singlestep(L); + } + + finishGcCycleStats(g); + + /* run a full collection cycle */ + startGcCycleStats(g); + + markroot(L); + while (g->gcstate != GCSpause) + { + singlestep(L); + } + /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ + shrinkbuffersfull(L); + + size_t heapgoalsizebytes = (g->totalbytes / 100) * g->gcgoal; + + // trigger cannot be correctly adjusted after a forced full GC. + // we will try to place it so that we can reach the goal based on + // the rate at which we run the GC relative to allocation rate + // and on amount of bytes we need to traverse in propagation stage. + // goal and stepmul are defined in percents + g->GCthreshold = g->totalbytes * (g->gcgoal * g->gcstepmul / 100 - 100) / g->gcstepmul; + + // but it might be impossible to satisfy that directly + if (g->GCthreshold < g->totalbytes) + g->GCthreshold = g->totalbytes; + + finishGcCycleStats(g); + + g->gcstats.currcycle.heapgoalsizebytes = heapgoalsizebytes; + g->gcstats.currcycle.heaptriggersizebytes = g->GCthreshold; +} + +void luaC_barrierupval(lua_State* L, GCObject* v) +{ + if (FFlag::LuauGcFullSkipInactiveThreads) + { + global_State* g = L->global; + LUAU_ASSERT(iswhite(v) && !isdead(g, v)); + + if (keepinvariant(g)) + reallymarkobject(g, v); + } +} + +void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) +{ + global_State* g = L->global; + LUAU_ASSERT(isblack(o) && iswhite(v) && !isdead(g, v) && !isdead(g, o)); + LUAU_ASSERT(g->gcstate != GCSpause); + /* must keep invariant? */ + if (keepinvariant(g)) + reallymarkobject(g, v); /* restore invariant */ + else /* don't mind */ + makewhite(g, o); /* mark as white just to avoid other barriers */ +} + +void luaC_barriertable(lua_State* L, Table* t, GCObject* v) +{ + global_State* g = L->global; + GCObject* o = obj2gco(t); + + // in the second propagation stage, table assignment barrier works as a forward barrier + if (FFlag::LuauRescanGrayAgainForwardBarrier && g->gcstate == GCSpropagateagain) + { + LUAU_ASSERT(isblack(o) && iswhite(v) && !isdead(g, v) && !isdead(g, o)); + reallymarkobject(g, v); + return; + } + + LUAU_ASSERT(isblack(o) && !isdead(g, o)); + LUAU_ASSERT(g->gcstate != GCSpause); + black2gray(o); /* make table gray (again) */ + t->gclist = g->grayagain; + g->grayagain = o; +} + +void luaC_barrierback(lua_State* L, Table* t) +{ + global_State* g = L->global; + GCObject* o = obj2gco(t); + LUAU_ASSERT(isblack(o) && !isdead(g, o)); + LUAU_ASSERT(g->gcstate != GCSpause); + black2gray(o); /* make table gray (again) */ + t->gclist = g->grayagain; + g->grayagain = o; +} + +void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt) +{ + global_State* g = L->global; + o->gch.next = g->rootgc; + g->rootgc = o; + o->gch.marked = luaC_white(g); + o->gch.tt = tt; + o->gch.memcat = L->activememcat; +} + +void luaC_linkupval(lua_State* L, UpVal* uv) +{ + global_State* g = L->global; + GCObject* o = obj2gco(uv); + o->gch.next = g->rootgc; /* link upvalue into `rootgc' list */ + g->rootgc = o; + if (isgray(o)) + { + if (keepinvariant(g)) + { + gray2black(o); /* closed upvalues need barrier */ + luaC_barrier(L, uv, uv->v); + } + else + { /* sweep phase: sweep it (turning it into white) */ + makewhite(g, o); + LUAU_ASSERT(g->gcstate != GCSpause); + } + } +} + +static void validateobjref(global_State* g, GCObject* f, GCObject* t) +{ + LUAU_ASSERT(!isdead(g, t)); + + if (keepinvariant(g)) + { + /* basic incremental invariant: black can't point to white */ + LUAU_ASSERT(!(isblack(f) && iswhite(t))); + } +} + +static void validateref(global_State* g, GCObject* f, TValue* v) +{ + if (iscollectable(v)) + { + LUAU_ASSERT(ttype(v) == gcvalue(v)->gch.tt); + validateobjref(g, f, gcvalue(v)); + } +} + +static void validatetable(global_State* g, Table* h) +{ + int sizenode = 1 << h->lsizenode; + + if (FFlag::LuauArrayBoundary) + LUAU_ASSERT(h->lastfree <= sizenode); + else + LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); + + if (h->metatable) + validateobjref(g, obj2gco(h), obj2gco(h->metatable)); + + for (int i = 0; i < h->sizearray; ++i) + validateref(g, obj2gco(h), &h->array[i]); + + for (int i = 0; i < sizenode; ++i) + { + LuaNode* n = &h->node[i]; + + LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); + LUAU_ASSERT(i + gnext(n) >= 0 && i + gnext(n) < sizenode); + + if (!ttisnil(gval(n))) + { + TValue k = {}; + k.tt = gkey(n)->tt; + k.value = gkey(n)->value; + + validateref(g, obj2gco(h), &k); + validateref(g, obj2gco(h), gval(n)); + } + } +} + +static void validateclosure(global_State* g, Closure* cl) +{ + validateobjref(g, obj2gco(cl), obj2gco(cl->env)); + + if (cl->isC) + { + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->c.upvals[i]); + } + else + { + LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); + + validateobjref(g, obj2gco(cl), obj2gco(cl->l.p)); + + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->l.uprefs[i]); + } +} + +static void validatestack(global_State* g, lua_State* l) +{ + validateref(g, obj2gco(l), gt(l)); + + for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) + { + LUAU_ASSERT(l->stack <= ci->base); + LUAU_ASSERT(ci->func <= ci->base && ci->base <= ci->top); + LUAU_ASSERT(ci->top <= l->stack_last); + } + + // note: stack refs can violate gc invariant so we only check for liveness + for (StkId o = l->stack; o < l->top; ++o) + checkliveness(g, o); + + if (l->namecall) + validateobjref(g, obj2gco(l), obj2gco(l->namecall)); + + for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) + { + LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); + LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); + } +} + +static void validateproto(global_State* g, Proto* f) +{ + if (f->source) + validateobjref(g, obj2gco(f), obj2gco(f->source)); + + if (f->debugname) + validateobjref(g, obj2gco(f), obj2gco(f->debugname)); + + for (int i = 0; i < f->sizek; ++i) + validateref(g, obj2gco(f), &f->k[i]); + + for (int i = 0; i < f->sizeupvalues; ++i) + if (f->upvalues[i]) + validateobjref(g, obj2gco(f), obj2gco(f->upvalues[i])); + + for (int i = 0; i < f->sizep; ++i) + if (f->p[i]) + validateobjref(g, obj2gco(f), obj2gco(f->p[i])); + + for (int i = 0; i < f->sizelocvars; i++) + if (f->locvars[i].varname) + validateobjref(g, obj2gco(f), obj2gco(f->locvars[i].varname)); +} + +static void validateobj(global_State* g, GCObject* o) +{ + /* dead objects can only occur during sweep */ + if (isdead(g, o)) + { + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + return; + } + + switch (o->gch.tt) + { + case LUA_TSTRING: + break; + + case LUA_TTABLE: + validatetable(g, gco2h(o)); + break; + + case LUA_TFUNCTION: + validateclosure(g, gco2cl(o)); + break; + + case LUA_TUSERDATA: + if (gco2u(o)->metatable) + validateobjref(g, o, obj2gco(gco2u(o)->metatable)); + break; + + case LUA_TTHREAD: + validatestack(g, gco2th(o)); + break; + + case LUA_TPROTO: + validateproto(g, gco2p(o)); + break; + + case LUA_TUPVAL: + validateref(g, o, gco2uv(o)->v); + break; + + default: + LUAU_ASSERT(!"unexpected object type"); + } +} + +static void validatelist(global_State* g, GCObject* o) +{ + while (o) + { + validateobj(g, o); + + o = o->gch.next; + } +} + +static void validategraylist(global_State* g, GCObject* o) +{ + if (!keepinvariant(g)) + return; + + while (o) + { + LUAU_ASSERT(isgray(o)); + + switch (o->gch.tt) + { + case LUA_TTABLE: + o = gco2h(o)->gclist; + break; + case LUA_TFUNCTION: + o = gco2cl(o)->gclist; + break; + case LUA_TTHREAD: + o = gco2th(o)->gclist; + break; + case LUA_TPROTO: + o = gco2p(o)->gclist; + break; + default: + LUAU_ASSERT(!"unknown object in gray list"); + return; + } + } +} + +void luaC_validate(lua_State* L) +{ + global_State* g = L->global; + + LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); + checkliveness(g, &g->registry); + + for (int i = 0; i < LUA_T_COUNT; ++i) + if (g->mt[i]) + LUAU_ASSERT(!isdead(g, obj2gco(g->mt[i]))); + + validategraylist(g, g->weak); + validategraylist(g, g->gray); + validategraylist(g, g->grayagain); + + for (int i = 0; i < g->strt.size; ++i) + validatelist(g, g->strt.hash[i]); + + validatelist(g, g->rootgc); + validatelist(g, g->strbufgc); + + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + { + LUAU_ASSERT(uv->tt == LUA_TUPVAL); + LUAU_ASSERT(uv->v != &uv->u.value); + LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + } +} + +inline bool safejson(char ch) +{ + return unsigned(ch) < 128 && ch >= 32 && ch != '\\' && ch != '\"'; +} + +static void dumpref(FILE* f, GCObject* o) +{ + fprintf(f, "\"%p\"", o); +} + +static void dumprefs(FILE* f, TValue* data, size_t size) +{ + bool first = true; + + for (size_t i = 0; i < size; ++i) + { + if (iscollectable(&data[i])) + { + if (!first) + fputc(',', f); + first = false; + + dumpref(f, gcvalue(&data[i])); + } + } +} + +static void dumpstringdata(FILE* f, const char* data, size_t len) +{ + for (size_t i = 0; i < len; ++i) + fputc(safejson(data[i]) ? data[i] : '?', f); +} + +static void dumpstring(FILE* f, TString* ts) +{ + fprintf(f, "{\"type\":\"string\",\"cat\":%d,\"size\":%d,\"data\":\"", ts->memcat, int(sizestring(ts->len))); + dumpstringdata(f, ts->data, ts->len); + fprintf(f, "\"}"); +} + +static void dumptable(FILE* f, Table* h) +{ + size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + + fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); + + if (h->node != &luaH_dummynode) + { + fprintf(f, ",\"pairs\":["); + + bool first = true; + + for (int i = 0; i < sizenode(h); ++i) + { + const LuaNode& n = h->node[i]; + + if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) + { + if (!first) + fputc(',', f); + first = false; + + if (iscollectable(&n.key)) + dumpref(f, gcvalue(&n.key)); + else + fprintf(f, "null"); + + fputc(',', f); + + if (iscollectable(&n.val)) + dumpref(f, gcvalue(&n.val)); + else + fprintf(f, "null"); + } + } + + fprintf(f, "]"); + } + if (h->sizearray) + { + fprintf(f, ",\"array\":["); + dumprefs(f, h->array, h->sizearray); + fprintf(f, "]"); + } + if (h->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(h->metatable)); + } + fprintf(f, "}"); +} + +static void dumpclosure(FILE* f, Closure* cl) +{ + fprintf(f, "{\"type\":\"function\",\"cat\":%d,\"size\":%d", cl->memcat, + cl->isC ? int(sizeCclosure(cl->nupvalues)) : int(sizeLclosure(cl->nupvalues))); + + fprintf(f, ",\"env\":"); + dumpref(f, obj2gco(cl->env)); + if (cl->isC) + { + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->c.upvals, cl->nupvalues); + fprintf(f, "]"); + } + } + else + { + fprintf(f, ",\"proto\":"); + dumpref(f, obj2gco(cl->l.p)); + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->l.uprefs, cl->nupvalues); + fprintf(f, "]"); + } + } + fprintf(f, "}"); +} + +static void dumpudata(FILE* f, Udata* u) +{ + fprintf(f, "{\"type\":\"userdata\",\"cat\":%d,\"size\":%d,\"tag\":%d", u->memcat, int(sizeudata(u->len)), u->tag); + + if (u->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(u->metatable)); + } + fprintf(f, "}"); +} + +static void dumpthread(FILE* f, lua_State* th) +{ + size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; + + fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); + + if (iscollectable(&th->l_gt)) + { + fprintf(f, ",\"env\":"); + dumpref(f, gcvalue(&th->l_gt)); + } + + Closure* tcl = 0; + for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) + { + if (ttisfunction(ci->func)) + { + tcl = clvalue(ci->func); + break; + } + } + + if (tcl && !tcl->isC && tcl->l.p->source) + { + Proto* p = tcl->l.p; + + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (th->top > th->stack) + { + fprintf(f, ",\"stack\":["); + dumprefs(f, th->stack, th->top - th->stack); + fprintf(f, "]"); + } + fprintf(f, "}"); +} + +static void dumpproto(FILE* f, Proto* p) +{ + size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; + + fprintf(f, "{\"type\":\"proto\",\"cat\":%d,\"size\":%d", p->memcat, int(size)); + + if (p->source) + { + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (p->sizek) + { + fprintf(f, ",\"constants\":["); + dumprefs(f, p->k, p->sizek); + fprintf(f, "]"); + } + + if (p->sizep) + { + fprintf(f, ",\"protos\":["); + for (int i = 0; i < p->sizep; ++i) + { + if (i != 0) + fputc(',', f); + dumpref(f, obj2gco(p->p[i])); + } + fprintf(f, "]"); + } + + fprintf(f, "}"); +} + +static void dumpupval(FILE* f, UpVal* uv) +{ + fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); + + if (iscollectable(uv->v)) + { + fprintf(f, ",\"object\":"); + dumpref(f, gcvalue(uv->v)); + } + fprintf(f, "}"); +} + +static void dumpobj(FILE* f, GCObject* o) +{ + switch (o->gch.tt) + { + case LUA_TSTRING: + return dumpstring(f, gco2ts(o)); + + case LUA_TTABLE: + return dumptable(f, gco2h(o)); + + case LUA_TFUNCTION: + return dumpclosure(f, gco2cl(o)); + + case LUA_TUSERDATA: + return dumpudata(f, gco2u(o)); + + case LUA_TTHREAD: + return dumpthread(f, gco2th(o)); + + case LUA_TPROTO: + return dumpproto(f, gco2p(o)); + + case LUA_TUPVAL: + return dumpupval(f, gco2uv(o)); + + default: + LUAU_ASSERT(0); + } +} + +static void dumplist(FILE* f, GCObject* o) +{ + while (o) + { + dumpref(f, o); + fputc(':', f); + dumpobj(f, o); + fputc(',', f); + fputc('\n', f); + + // thread has additional list containing collectable objects that are not present in rootgc + if (o->gch.tt == LUA_TTHREAD) + dumplist(f, gco2th(o)->openupval); + + o = o->gch.next; + } +} + +void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) +{ + global_State* g = L->global; + FILE* f = static_cast(file); + + fprintf(f, "{\"objects\":{\n"); + dumplist(f, g->rootgc); + dumplist(f, g->strbufgc); + for (int i = 0; i < g->strt.size; ++i) + dumplist(f, g->strt.hash[i]); + + fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , + fprintf(f, "},\"roots\":{\n"); + fprintf(f, "\"mainthread\":"); + dumpref(f, obj2gco(g->mainthread)); + fprintf(f, ",\"registry\":"); + dumpref(f, gcvalue(&g->registry)); + + fprintf(f, "},\"stats\":{\n"); + + fprintf(f, "\"size\":%d,\n", int(g->totalbytes)); + + fprintf(f, "\"categories\":{\n"); + for (int i = 0; i < LUA_MEMORY_CATEGORIES; i++) + { + if (size_t bytes = g->memcatbytes[i]) + { + if (categoryName) + fprintf(f, "\"%d\":{\"name\":\"%s\", \"size\":%d},\n", i, categoryName(L, i), int(bytes)); + else + fprintf(f, "\"%d\":{\"size\":%d},\n", i, int(bytes)); + } + } + fprintf(f, "\"none\":{}\n"); // to avoid issues with trailing , + fprintf(f, "}\n"); + fprintf(f, "}}\n"); +} + +// measure the allocation rate in bytes/sec +// returns -1 if allocation rate cannot be measured +int64_t luaC_allocationrate(lua_State* L) +{ + global_State* g = L->global; + const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms + + if (g->gcstate <= GCSpropagateagain) + { + double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; + + if (duration < durationthreshold) + return -1; + + return int64_t((g->totalbytes - g->gcstats.lastcycle.endtotalsizebytes) / duration); + } + + // totalbytes is unstable during the sweep, use the rate measured at the end of mark phase + double duration = g->gcstats.currcycle.atomicstarttimestamp - g->gcstats.lastcycle.endtimestamp; + + if (duration < durationthreshold) + return -1; + + return int64_t((g->gcstats.currcycle.atomicstarttotalsizebytes - g->gcstats.lastcycle.endtotalsizebytes) / duration); +} + +void luaC_wakethread(lua_State* L) +{ + if (!luaC_threadsleeping(L)) + return; + + global_State* g = L->global; + + resetbit(L->stackstate, THREAD_SLEEPINGBIT); + + if (keepinvariant(g)) + { + GCObject* o = obj2gco(L); + + L->gclist = g->grayagain; + g->grayagain = o; + + black2gray(o); + } +} + +const char* luaC_statename(int state) +{ + switch (state) + { + case GCSpause: + return "pause"; + + case GCSpropagate: + return "mark"; + + case GCSpropagateagain: + return "remark"; + + case GCSatomic: + return "atomic"; + + case GCSsweepstring: + return "sweepstring"; + + case GCSsweep: + return "sweep"; + + default: + return NULL; + } +} diff --git a/VM/src/lgc.h b/VM/src/lgc.h new file mode 100644 index 0000000..dc780bb --- /dev/null +++ b/VM/src/lgc.h @@ -0,0 +1,150 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "ldo.h" +#include "lobject.h" +#include "lstate.h" + +LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) + +/* +** Possible states of the Garbage Collector +*/ +#define GCSpause 0 +#define GCSpropagate 1 +#define GCSpropagateagain 2 +#define GCSatomic 3 +#define GCSsweepstring 4 +#define GCSsweep 5 + +/* +** macro to tell when main invariant (white objects cannot point to black +** ones) must be kept. During a collection, the sweep +** phase may break the invariant, as objects turned white may point to +** still-black objects. The invariant is restored when sweep ends and +** all objects are white again. +*/ +#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain) + +/* +** some userful bit tricks +*/ +#define resetbits(x, m) ((x) &= cast_to(uint8_t, ~(m))) +#define setbits(x, m) ((x) |= (m)) +#define testbits(x, m) ((x) & (m)) +#define bitmask(b) (1 << (b)) +#define bit2mask(b1, b2) (bitmask(b1) | bitmask(b2)) +#define l_setbit(x, b) setbits(x, bitmask(b)) +#define resetbit(x, b) resetbits(x, bitmask(b)) +#define testbit(x, b) testbits(x, bitmask(b)) +#define set2bits(x, b1, b2) setbits(x, (bit2mask(b1, b2))) +#define reset2bits(x, b1, b2) resetbits(x, (bit2mask(b1, b2))) +#define test2bits(x, b1, b2) testbits(x, (bit2mask(b1, b2))) + +/* +** Layout for bit use in `marked' field: +** bit 0 - object is white (type 0) +** bit 1 - object is white (type 1) +** bit 2 - object is black +** bit 3 - object is fixed (should not be collected) +*/ + +#define WHITE0BIT 0 +#define WHITE1BIT 1 +#define BLACKBIT 2 +#define FIXEDBIT 3 +#define WHITEBITS bit2mask(WHITE0BIT, WHITE1BIT) + +#define iswhite(x) test2bits((x)->gch.marked, WHITE0BIT, WHITE1BIT) +#define isblack(x) testbit((x)->gch.marked, BLACKBIT) +#define isgray(x) (!testbits((x)->gch.marked, WHITEBITS | bitmask(BLACKBIT))) +#define isfixed(x) testbit((x)->gch.marked, FIXEDBIT) + +#define otherwhite(g) (g->currentwhite ^ WHITEBITS) +#define isdead(g, v) (((v)->gch.marked & (WHITEBITS | bitmask(FIXEDBIT))) == (otherwhite(g) & WHITEBITS)) + +#define changewhite(x) ((x)->gch.marked ^= WHITEBITS) +#define gray2black(x) l_setbit((x)->gch.marked, BLACKBIT) + +#define luaC_white(g) cast_to(uint8_t, ((g)->currentwhite) & WHITEBITS) + +// Thread stack states +#define THREAD_ACTIVEBIT 0 // thread is currently active +#define THREAD_SLEEPINGBIT 1 // thread is not executing and stack should not be modified + +#define luaC_threadactive(L) (testbit((L)->stackstate, THREAD_ACTIVEBIT)) +#define luaC_threadsleeping(L) (testbit((L)->stackstate, THREAD_SLEEPINGBIT)) + +#define luaC_checkGC(L) \ + { \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK - 1)); \ + if (L->global->totalbytes >= L->global->GCthreshold) \ + { \ + condhardmemtests(luaC_validate(L), 1); \ + luaC_step(L, true); \ + } \ + else \ + { \ + condhardmemtests(luaC_validate(L), 2); \ + } \ + } + +#define luaC_barrier(L, p, v) \ + { \ + if (iscollectable(v) && isblack(obj2gco(p)) && iswhite(gcvalue(v))) \ + luaC_barrierf(L, obj2gco(p), gcvalue(v)); \ + } + +#define luaC_barriert(L, t, v) \ + { \ + if (iscollectable(v) && isblack(obj2gco(t)) && iswhite(gcvalue(v))) \ + luaC_barriertable(L, t, gcvalue(v)); \ + } + +#define luaC_barrierfast(L, t) \ + { \ + if (isblack(obj2gco(t))) \ + luaC_barrierback(L, t); \ + } + +#define luaC_objbarrier(L, p, o) \ + { \ + if (isblack(obj2gco(p)) && iswhite(obj2gco(o))) \ + luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ + } + +#define luaC_objbarriert(L, t, o) \ + { \ + if (isblack(obj2gco(t)) && iswhite(obj2gco(o))) \ + luaC_barriertable(L, t, obj2gco(o)); \ + } + +#define luaC_upvalbarrier(L, uv, tv) \ + { \ + if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || ((UpVal*)uv)->v != &((UpVal*)uv)->u.value)) \ + luaC_barrierupval(L, gcvalue(tv)); \ + } + +#define luaC_checkthreadsleep(L) \ + { \ + if (luaC_threadsleeping(L)) \ + luaC_wakethread(L); \ + } + +#define luaC_link(L, o, tt) luaC_linkobj(L, cast_to(GCObject*, (o)), tt) + +LUAI_FUNC void luaC_freeall(lua_State* L); +LUAI_FUNC void luaC_step(lua_State* L, bool assist); +LUAI_FUNC void luaC_fullgc(lua_State* L); +LUAI_FUNC void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt); +LUAI_FUNC void luaC_linkupval(lua_State* L, UpVal* uv); +LUAI_FUNC void luaC_barrierupval(lua_State* L, GCObject* v); +LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); +LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); +LUAI_FUNC void luaC_barrierback(lua_State* L, Table* t); +LUAI_FUNC void luaC_validate(lua_State* L); +LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); +LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); +LUAI_FUNC void luaC_wakethread(lua_State* L); +LUAI_FUNC const char* luaC_statename(int state); \ No newline at end of file diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp new file mode 100644 index 0000000..bf5e738 --- /dev/null +++ b/VM/src/linit.cpp @@ -0,0 +1,87 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include + +static const luaL_Reg lualibs[] = { + {"", luaopen_base}, + {LUA_COLIBNAME, luaopen_coroutine}, + {LUA_TABLIBNAME, luaopen_table}, + {LUA_OSLIBNAME, luaopen_os}, + {LUA_STRLIBNAME, luaopen_string}, + {LUA_MATHLIBNAME, luaopen_math}, + {LUA_DBLIBNAME, luaopen_debug}, + {LUA_UTF8LIBNAME, luaopen_utf8}, + {LUA_BITLIBNAME, luaopen_bit32}, + {NULL, NULL}, +}; + +LUALIB_API void luaL_openlibs(lua_State* L) +{ + const luaL_Reg* lib = lualibs; + for (; lib->func; lib++) + { + lua_pushcfunction(L, lib->func); + lua_pushstring(L, lib->name); + lua_call(L, 1, 0); + } +} + +LUALIB_API void luaL_sandbox(lua_State* L) +{ + // set all libraries to read-only + lua_pushnil(L); + while (lua_next(L, LUA_GLOBALSINDEX) != 0) + { + if (lua_istable(L, -1)) + lua_setreadonly(L, -1, true); + + lua_pop(L, 1); + } + + // set all builtin metatables to read-only + lua_pushliteral(L, ""); + lua_getmetatable(L, -1); + lua_setreadonly(L, -1, true); + lua_pop(L, 1); + + // set globals to readonly and activate safeenv since the env is immutable + lua_setreadonly(L, LUA_GLOBALSINDEX, true); + lua_setsafeenv(L, LUA_GLOBALSINDEX, true); +} + +LUALIB_API void luaL_sandboxthread(lua_State* L) +{ + // create new global table that proxies reads to original table + lua_newtable(L); + + lua_newtable(L); + lua_pushvalue(L, LUA_GLOBALSINDEX); + lua_setfield(L, -2, "__index"); + lua_setreadonly(L, -1, true); + + lua_setmetatable(L, -2); + + // we can set safeenv now although it's important to set it to false if code is loaded twice into the thread + lua_replace(L, LUA_GLOBALSINDEX); + lua_setsafeenv(L, LUA_GLOBALSINDEX, true); +} + +static void* l_alloc(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize) +{ + (void)ud; + (void)osize; + if (nsize == 0) + { + free(ptr); + return NULL; + } + else + return realloc(ptr, nsize); +} + +LUALIB_API lua_State* luaL_newstate(void) +{ + return lua_newstate(l_alloc, NULL); +} diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp new file mode 100644 index 0000000..8e476a5 --- /dev/null +++ b/VM/src/lmathlib.cpp @@ -0,0 +1,446 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lstate.h" + +#include +#include + +#undef PI +#define PI (3.14159265358979323846) +#define RADIANS_PER_DEGREE (PI / 180.0) + +#define PCG32_INC 105 + +static uint32_t pcg32_random(uint64_t* state) +{ + uint64_t oldstate = *state; + *state = oldstate * 6364136223846793005ULL + (PCG32_INC | 1); + uint32_t xorshifted = uint32_t(((oldstate >> 18u) ^ oldstate) >> 27u); + uint32_t rot = uint32_t(oldstate >> 59u); + return (xorshifted >> rot) | (xorshifted << ((-int32_t(rot)) & 31)); +} + +static void pcg32_seed(uint64_t* state, uint64_t seed) +{ + *state = 0; + pcg32_random(state); + *state += seed; + pcg32_random(state); +} + +static int math_abs(lua_State* L) +{ + lua_pushnumber(L, fabs(luaL_checknumber(L, 1))); + return 1; +} + +static int math_sin(lua_State* L) +{ + lua_pushnumber(L, sin(luaL_checknumber(L, 1))); + return 1; +} + +static int math_sinh(lua_State* L) +{ + lua_pushnumber(L, sinh(luaL_checknumber(L, 1))); + return 1; +} + +static int math_cos(lua_State* L) +{ + lua_pushnumber(L, cos(luaL_checknumber(L, 1))); + return 1; +} + +static int math_cosh(lua_State* L) +{ + lua_pushnumber(L, cosh(luaL_checknumber(L, 1))); + return 1; +} + +static int math_tan(lua_State* L) +{ + lua_pushnumber(L, tan(luaL_checknumber(L, 1))); + return 1; +} + +static int math_tanh(lua_State* L) +{ + lua_pushnumber(L, tanh(luaL_checknumber(L, 1))); + return 1; +} + +static int math_asin(lua_State* L) +{ + lua_pushnumber(L, asin(luaL_checknumber(L, 1))); + return 1; +} + +static int math_acos(lua_State* L) +{ + lua_pushnumber(L, acos(luaL_checknumber(L, 1))); + return 1; +} + +static int math_atan(lua_State* L) +{ + lua_pushnumber(L, atan(luaL_checknumber(L, 1))); + return 1; +} + +static int math_atan2(lua_State* L) +{ + lua_pushnumber(L, atan2(luaL_checknumber(L, 1), luaL_checknumber(L, 2))); + return 1; +} + +static int math_ceil(lua_State* L) +{ + lua_pushnumber(L, ceil(luaL_checknumber(L, 1))); + return 1; +} + +static int math_floor(lua_State* L) +{ + lua_pushnumber(L, floor(luaL_checknumber(L, 1))); + return 1; +} + +static int math_fmod(lua_State* L) +{ + lua_pushnumber(L, fmod(luaL_checknumber(L, 1), luaL_checknumber(L, 2))); + return 1; +} + +static int math_modf(lua_State* L) +{ + double ip; + double fp = modf(luaL_checknumber(L, 1), &ip); + lua_pushnumber(L, ip); + lua_pushnumber(L, fp); + return 2; +} + +static int math_sqrt(lua_State* L) +{ + lua_pushnumber(L, sqrt(luaL_checknumber(L, 1))); + return 1; +} + +static int math_pow(lua_State* L) +{ + lua_pushnumber(L, pow(luaL_checknumber(L, 1), luaL_checknumber(L, 2))); + return 1; +} + +static int math_log(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double res; + if (lua_isnoneornil(L, 2)) + res = log(x); + else + { + double base = luaL_checknumber(L, 2); + if (base == 2.0) + res = log2(x); + else if (base == 10.0) + res = log10(x); + else + res = log(x) / log(base); + } + lua_pushnumber(L, res); + return 1; +} + +static int math_log10(lua_State* L) +{ + lua_pushnumber(L, log10(luaL_checknumber(L, 1))); + return 1; +} + +static int math_exp(lua_State* L) +{ + lua_pushnumber(L, exp(luaL_checknumber(L, 1))); + return 1; +} + +static int math_deg(lua_State* L) +{ + lua_pushnumber(L, luaL_checknumber(L, 1) / RADIANS_PER_DEGREE); + return 1; +} + +static int math_rad(lua_State* L) +{ + lua_pushnumber(L, luaL_checknumber(L, 1) * RADIANS_PER_DEGREE); + return 1; +} + +static int math_frexp(lua_State* L) +{ + int e; + lua_pushnumber(L, frexp(luaL_checknumber(L, 1), &e)); + lua_pushinteger(L, e); + return 2; +} + +static int math_ldexp(lua_State* L) +{ + lua_pushnumber(L, ldexp(luaL_checknumber(L, 1), luaL_checkinteger(L, 2))); + return 1; +} + +static int math_min(lua_State* L) +{ + int n = lua_gettop(L); /* number of arguments */ + double dmin = luaL_checknumber(L, 1); + int i; + for (i = 2; i <= n; i++) + { + double d = luaL_checknumber(L, i); + if (d < dmin) + dmin = d; + } + lua_pushnumber(L, dmin); + return 1; +} + +static int math_max(lua_State* L) +{ + int n = lua_gettop(L); /* number of arguments */ + double dmax = luaL_checknumber(L, 1); + int i; + for (i = 2; i <= n; i++) + { + double d = luaL_checknumber(L, i); + if (d > dmax) + dmax = d; + } + lua_pushnumber(L, dmax); + return 1; +} + +static int math_random(lua_State* L) +{ + global_State* g = L->global; + switch (lua_gettop(L)) + { /* check number of arguments */ + case 0: + { /* no arguments */ + // Using ldexp instead of division for speed & clarity. + // See http://mumble.net/~campbell/tmp/random_real.c for details on generating doubles from integer ranges. + uint32_t rl = pcg32_random(&g->rngstate); + uint32_t rh = pcg32_random(&g->rngstate); + double rd = ldexp(double(rl | (uint64_t(rh) << 32)), -64); + lua_pushnumber(L, rd); /* number between 0 and 1 */ + break; + } + case 1: + { /* only upper limit */ + int u = luaL_checkinteger(L, 1); + luaL_argcheck(L, 1 <= u, 1, "interval is empty"); + + uint64_t x = uint64_t(u) * pcg32_random(&g->rngstate); + int r = int(1 + (x >> 32)); + lua_pushinteger(L, r); /* int between 1 and `u' */ + break; + } + case 2: + { /* lower and upper limits */ + int l = luaL_checkinteger(L, 1); + int u = luaL_checkinteger(L, 2); + luaL_argcheck(L, l <= u, 2, "interval is empty"); + + uint32_t ul = uint32_t(u) - uint32_t(l); + luaL_argcheck(L, ul < UINT_MAX, 2, "interval is too large"); // -INT_MIN..INT_MAX interval can result in integer overflow + uint64_t x = uint64_t(ul + 1) * pcg32_random(&g->rngstate); + int r = int(l + (x >> 32)); + lua_pushinteger(L, r); /* int between `l' and `u' */ + break; + } + default: + luaL_error(L, "wrong number of arguments"); + } + return 1; +} + +static int math_randomseed(lua_State* L) +{ + int seed = luaL_checkinteger(L, 1); + + pcg32_seed(&L->global->rngstate, seed); + return 0; +} + +static const unsigned char kPerlin[512] = {151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, + 37, 240, 21, 10, 23, 190, 6, 148, 247, 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, + 20, 125, 136, 171, 168, 68, 175, 74, 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, + 55, 46, 245, 40, 244, 102, 143, 54, 65, 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, 86, + 164, 100, 109, 198, 173, 186, 3, 64, 52, 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, 17, + 182, 189, 28, 42, 223, 183, 170, 213, 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, 110, + 79, 113, 224, 232, 178, 185, 112, 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, 14, + 239, 107, 49, 192, 214, 31, 181, 199, 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, + 72, 243, 141, 128, 195, 78, 66, 215, 61, 156, 180, + + 151, 160, 137, 91, 90, 15, 131, 13, 201, 95, 96, 53, 194, 233, 7, 225, 140, 36, 103, 30, 69, 142, 8, 99, 37, 240, 21, 10, 23, 190, 6, 148, 247, + 120, 234, 75, 0, 26, 197, 62, 94, 252, 219, 203, 117, 35, 11, 32, 57, 177, 33, 88, 237, 149, 56, 87, 174, 20, 125, 136, 171, 168, 68, 175, 74, + 165, 71, 134, 139, 48, 27, 166, 77, 146, 158, 231, 83, 111, 229, 122, 60, 211, 133, 230, 220, 105, 92, 41, 55, 46, 245, 40, 244, 102, 143, 54, 65, + 25, 63, 161, 1, 216, 80, 73, 209, 76, 132, 187, 208, 89, 18, 169, 200, 196, 135, 130, 116, 188, 159, 86, 164, 100, 109, 198, 173, 186, 3, 64, 52, + 217, 226, 250, 124, 123, 5, 202, 38, 147, 118, 126, 255, 82, 85, 212, 207, 206, 59, 227, 47, 16, 58, 17, 182, 189, 28, 42, 223, 183, 170, 213, + 119, 248, 152, 2, 44, 154, 163, 70, 221, 153, 101, 155, 167, 43, 172, 9, 129, 22, 39, 253, 19, 98, 108, 110, 79, 113, 224, 232, 178, 185, 112, + 104, 218, 246, 97, 228, 251, 34, 242, 193, 238, 210, 144, 12, 191, 179, 162, 241, 81, 51, 145, 235, 249, 14, 239, 107, 49, 192, 214, 31, 181, 199, + 106, 157, 184, 84, 204, 176, 115, 121, 50, 45, 127, 4, 150, 254, 138, 236, 205, 93, 222, 114, 67, 29, 24, 72, 243, 141, 128, 195, 78, 66, 215, 61, + 156, 180}; + +static float fade(float t) +{ + return t * t * t * (t * (t * 6 - 15) + 10); +} + +static float lerp(float t, float a, float b) +{ + return a + t * (b - a); +} + +static float grad(unsigned char hash, float x, float y, float z) +{ + unsigned char h = hash & 15; + float u = (h < 8) ? x : y; + float v = (h < 4) ? y : (h == 12 || h == 14) ? x : z; + + return (h & 1 ? -u : u) + (h & 2 ? -v : v); +} + +static float perlin(float x, float y, float z) +{ + float xflr = floorf(x); + float yflr = floorf(y); + float zflr = floorf(z); + + int xi = int(xflr) & 255; + int yi = int(yflr) & 255; + int zi = int(zflr) & 255; + + float xf = x - xflr; + float yf = y - yflr; + float zf = z - zflr; + + float u = fade(xf); + float v = fade(yf); + float w = fade(zf); + + const unsigned char* p = kPerlin; + + int a = p[xi] + yi; + int aa = p[a] + zi; + int ab = p[a + 1] + zi; + + int b = p[xi + 1] + yi; + int ba = p[b] + zi; + int bb = p[b + 1] + zi; + + return lerp(w, + lerp(v, lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), + lerp(v, lerp(u, grad(p[aa + 1], xf, yf, zf - 1), grad(p[ba + 1], xf - 1, yf, zf - 1)), + lerp(u, grad(p[ab + 1], xf, yf - 1, zf - 1), grad(p[bb + 1], xf - 1, yf - 1, zf - 1)))); +} + +static int math_noise(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double y = luaL_optnumber(L, 2, 0.0); + double z = luaL_optnumber(L, 3, 0.0); + + double r = perlin((float)x, (float)y, (float)z); + + lua_pushnumber(L, r); + + return 1; +} + +static int math_clamp(lua_State* L) +{ + double v = luaL_checknumber(L, 1); + double min = luaL_checknumber(L, 2); + double max = luaL_checknumber(L, 3); + + luaL_argcheck(L, min <= max, 3, "max must be greater than or equal to min"); + + double r = v < min ? min : v; + r = r > max ? max : r; + + lua_pushnumber(L, r); + return 1; +} + +static int math_sign(lua_State* L) +{ + double v = luaL_checknumber(L, 1); + lua_pushnumber(L, v > 0.0 ? 1.0 : v < 0.0 ? -1.0 : 0.0); + return 1; +} + +static int math_round(lua_State* L) +{ + double v = luaL_checknumber(L, 1); + lua_pushnumber(L, round(v)); + return 1; +} + +static const luaL_Reg mathlib[] = { + {"abs", math_abs}, + {"acos", math_acos}, + {"asin", math_asin}, + {"atan2", math_atan2}, + {"atan", math_atan}, + {"ceil", math_ceil}, + {"cosh", math_cosh}, + {"cos", math_cos}, + {"deg", math_deg}, + {"exp", math_exp}, + {"floor", math_floor}, + {"fmod", math_fmod}, + {"frexp", math_frexp}, + {"ldexp", math_ldexp}, + {"log10", math_log10}, + {"log", math_log}, + {"max", math_max}, + {"min", math_min}, + {"modf", math_modf}, + {"pow", math_pow}, + {"rad", math_rad}, + {"random", math_random}, + {"randomseed", math_randomseed}, + {"sinh", math_sinh}, + {"sin", math_sin}, + {"sqrt", math_sqrt}, + {"tanh", math_tanh}, + {"tan", math_tan}, + {"noise", math_noise}, + {"clamp", math_clamp}, + {"sign", math_sign}, + {"round", math_round}, + {NULL, NULL}, +}; + +/* +** Open math library +*/ +LUALIB_API int luaopen_math(lua_State* L) +{ + uint64_t seed = uintptr_t(L); + seed ^= time(NULL); + seed ^= clock(); + + pcg32_seed(&L->global->rngstate, seed); + + luaL_register(L, LUA_MATHLIBNAME, mathlib); + lua_pushnumber(L, PI); + lua_setfield(L, -2, "pi"); + lua_pushnumber(L, HUGE_VAL); + lua_setfield(L, -2, "huge"); + return 1; +} diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp new file mode 100644 index 0000000..2759f3b --- /dev/null +++ b/VM/src/lmem.cpp @@ -0,0 +1,340 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lmem.h" + +#include "lstate.h" +#include "ldo.h" +#include "ldebug.h" + +#include + +#ifndef __has_feature +#define __has_feature(x) 0 +#endif + +#if __has_feature(address_sanitizer) || defined(LUAU_ENABLE_ASAN) +#include +#define ASAN_POISON_MEMORY_REGION(addr, size) __asan_poison_memory_region((addr), (size)) +#define ASAN_UNPOISON_MEMORY_REGION(addr, size) __asan_unpoison_memory_region((addr), (size)) +#else +#define ASAN_POISON_MEMORY_REGION(addr, size) (void)0 +#define ASAN_UNPOISON_MEMORY_REGION(addr, size) (void)0 +#endif + +/* + * The sizes of Luau objects aren't crucial for code correctness, but they are crucial for memory efficiency + * To prevent some of them accidentally growing and us losing memory without realizing it, we're going to lock + * the sizes of all critical structures down. + */ +#if defined(__APPLE__) && !defined(__MACH__) +#define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : gcc32) +#else +// Android somehow uses a similar ABI to MSVC, *not* to iOS... +#define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : ms32) +#endif + +static_assert(sizeof(TValue) == ABISWITCH(16, 16, 16), "size mismatch for value"); +static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); +static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); +static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); +static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table entry"); + +const size_t kSizeClasses = LUA_SIZECLASSES; +const size_t kMaxSmallSize = 512; +const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata +const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms + +struct SizeClassConfig +{ + int sizeOfClass[kSizeClasses]; + int8_t classForSize[kMaxSmallSize + 1]; + int classCount = 0; + + SizeClassConfig() + { + memset(sizeOfClass, 0, sizeof(sizeOfClass)); + memset(classForSize, -1, sizeof(classForSize)); + + // we use a progressive size class scheme: + // - all size classes are aligned by 8b to satisfy pointer alignment requirements + // - we first allocate sizes classes in multiples of 8 + // - after the first cutoff we allocate size classes in multiples of 16 + // - after the second cutoff we allocate size classes in multiples of 32 + // this balances internal fragmentation vs external fragmentation + for (int size = 8; size < 64; size += 8) + sizeOfClass[classCount++] = size; + + for (int size = 64; size < 256; size += 16) + sizeOfClass[classCount++] = size; + + for (int size = 256; size <= 512; size += 32) + sizeOfClass[classCount++] = size; + + LUAU_ASSERT(size_t(classCount) <= kSizeClasses); + + // fill the lookup table for all classes + for (int klass = 0; klass < classCount; ++klass) + classForSize[sizeOfClass[klass]] = int8_t(klass); + + // fill the gaps in lookup table + for (int size = kMaxSmallSize - 1; size >= 0; --size) + if (classForSize[size] < 0) + classForSize[size] = classForSize[size + 1]; + } +}; + +const SizeClassConfig kSizeClassConfig; + +// size class for a block of size sz +#define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSize ? kSizeClassConfig.classForSize[sz] : -1) + +// metadata for a block is stored in the first pointer of the block +#define metadata(block) (*(void**)(block)) + +/* +** About the realloc function: +** void * frealloc (void *ud, void *ptr, size_t osize, size_t nsize); +** (`osize' is the old size, `nsize' is the new size) +** +** Lua ensures that (ptr == NULL) iff (osize == 0). +** +** * frealloc(ud, NULL, 0, x) creates a new block of size `x' +** +** * frealloc(ud, p, x, 0) frees the block `p' +** (in this specific case, frealloc must return NULL). +** particularly, frealloc(ud, NULL, 0, 0) does nothing +** (which is equivalent to free(NULL) in ANSI C) +** +** frealloc returns NULL if it cannot create or reallocate the area +** (any reallocation to an equal or smaller size cannot fail!) +*/ + +struct lua_Page +{ + lua_Page* prev; + lua_Page* next; + + int busyBlocks; + int blockSize; + + void* freeList; + int freeNext; + + union + { + char data[1]; + double align1; + void* align2; + }; +}; + +l_noret luaM_toobig(lua_State* L) +{ + luaG_runerror(L, "memory allocation error: block too big"); +} + +static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) +{ + global_State* g = L->global; + lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, kPageSize); + if (!page) + luaD_throw(L, LUA_ERRMEM); + + int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + kBlockHeader; + int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; + + ASAN_POISON_MEMORY_REGION(page->data, blockSize * blockCount); + + // setup page header + page->prev = NULL; + page->next = NULL; + + page->busyBlocks = 0; + page->blockSize = blockSize; + + // note: we start with the last block in the page and move downward + // either order would work, but that way we don't need to store the block count in the page + // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order + page->freeList = NULL; + page->freeNext = (blockCount - 1) * blockSize; + + // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) + LUAU_ASSERT(!g->freepages[sizeClass]); + g->freepages[sizeClass] = page; + + return page; +} + +static void luaM_freepage(lua_State* L, lua_Page* page, uint8_t sizeClass) +{ + global_State* g = L->global; + + // remove page from freelist + if (page->next) + page->next->prev = page->prev; + + if (page->prev) + page->prev->next = page->next; + else if (g->freepages[sizeClass] == page) + g->freepages[sizeClass] = page->next; + + // so long + (*g->frealloc)(L, g->ud, page, kPageSize, 0); +} + +static void* luaM_newblock(lua_State* L, int sizeClass) +{ + global_State* g = L->global; + lua_Page* page = g->freepages[sizeClass]; + + // slow path: no page in the freelist, allocate a new one + if (!page) + page = luaM_newpage(L, sizeClass); + + LUAU_ASSERT(!page->prev); + LUAU_ASSERT(page->freeList || page->freeNext >= 0); + LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass] + kBlockHeader); + + void* block; + + if (page->freeNext >= 0) + { + block = page->data + page->freeNext; + ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); + + page->freeNext -= page->blockSize; + page->busyBlocks++; + } + else + { + block = page->freeList; + ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); + + page->freeList = metadata(block); + page->busyBlocks++; + } + + // the first word in a block point back to the page + metadata(block) = page; + + // if we allocate the last block out of a page, we need to remove it from free list + if (!page->freeList && page->freeNext < 0) + { + g->freepages[sizeClass] = page->next; + if (page->next) + page->next->prev = NULL; + page->next = NULL; + } + + // the user data is right after the metadata + return (char*)block + kBlockHeader; +} + +static void luaM_freeblock(lua_State* L, int sizeClass, void* block) +{ + global_State* g = L->global; + + // the user data is right after the metadata + LUAU_ASSERT(block); + block = (char*)block - kBlockHeader; + + lua_Page* page = (lua_Page*)metadata(block); + LUAU_ASSERT(page && page->busyBlocks > 0); + LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass] + kBlockHeader); + + // if the page wasn't in the page free list, it should be now since it got a block! + if (!page->freeList && page->freeNext < 0) + { + LUAU_ASSERT(!page->prev); + LUAU_ASSERT(!page->next); + + page->next = g->freepages[sizeClass]; + if (page->next) + page->next->prev = page; + g->freepages[sizeClass] = page; + } + + // add the block to the free list inside the page + metadata(block) = page->freeList; + page->freeList = block; + + ASAN_POISON_MEMORY_REGION(block, page->blockSize); + + page->busyBlocks--; + + // if it's the last block in the page, we don't need the page + if (page->busyBlocks == 0) + luaM_freepage(L, page, sizeClass); +} + +/* +** generic allocation routines. +*/ +void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) +{ + global_State* g = L->global; + + int nclass = sizeclass(nsize); + + void* block = nclass >= 0 ? luaM_newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + if (block == NULL && nsize > 0) + luaD_throw(L, LUA_ERRMEM); + + g->totalbytes += nsize; + g->memcatbytes[memcat] += nsize; + + return block; +} + +void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) +{ + global_State* g = L->global; + LUAU_ASSERT((osize == 0) == (block == NULL)); + + int oclass = sizeclass(osize); + + if (oclass >= 0) + luaM_freeblock(L, oclass, block); + else + (*g->frealloc)(L, g->ud, block, osize, 0); + + g->totalbytes -= osize; + g->memcatbytes[memcat] -= osize; +} + +void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8_t memcat) +{ + global_State* g = L->global; + LUAU_ASSERT((osize == 0) == (block == NULL)); + + int nclass = sizeclass(nsize); + int oclass = sizeclass(osize); + void* result; + + // if either block needs to be allocated using a block allocator, we can't use realloc directly + if (nclass >= 0 || oclass >= 0) + { + result = nclass >= 0 ? luaM_newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + if (result == NULL && nsize > 0) + luaD_throw(L, LUA_ERRMEM); + + if (osize > 0 && nsize > 0) + memcpy(result, block, osize < nsize ? osize : nsize); + + if (oclass >= 0) + luaM_freeblock(L, oclass, block); + else + (*g->frealloc)(L, g->ud, block, osize, 0); + } + else + { + result = (*g->frealloc)(L, g->ud, block, osize, nsize); + if (result == NULL && nsize > 0) + luaD_throw(L, LUA_ERRMEM); + } + + LUAU_ASSERT((nsize == 0) == (result == NULL)); + g->totalbytes = (g->totalbytes - osize) + nsize; + g->memcatbytes[memcat] += nsize - osize; + return result; +} diff --git a/VM/src/lmem.h b/VM/src/lmem.h new file mode 100644 index 0000000..f526a1b --- /dev/null +++ b/VM/src/lmem.h @@ -0,0 +1,21 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lua.h" + +#define luaM_new(L, t, size, memcat) cast_to(t*, luaM_new_(L, size, memcat)) +#define luaM_free(L, p, size, memcat) luaM_free_(L, (p), size, memcat) + +#define luaM_arraysize_(n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) + +#define luaM_newarray(L, n, t, memcat) cast_to(t*, luaM_new_(L, luaM_arraysize_(n, sizeof(t)), memcat)) +#define luaM_freearray(L, b, n, t, memcat) luaM_free_(L, (b), (n) * sizeof(t), memcat) +#define luaM_reallocarray(L, v, oldn, n, t, memcat) \ + ((v) = cast_to(t*, luaM_realloc_(L, v, (oldn) * sizeof(t), luaM_arraysize_(n, sizeof(t)), memcat))) + +LUAI_FUNC void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat); +LUAI_FUNC void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat); +LUAI_FUNC void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8_t memcat); + +LUAI_FUNC l_noret luaM_toobig(lua_State* L); diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h new file mode 100644 index 0000000..43f8014 --- /dev/null +++ b/VM/src/lnumutils.h @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include +#include + +#define luai_numadd(a, b) ((a) + (b)) +#define luai_numsub(a, b) ((a) - (b)) +#define luai_nummul(a, b) ((a) * (b)) +#define luai_numdiv(a, b) ((a) / (b)) +#define luai_numpow(a, b) (pow(a, b)) +#define luai_numunm(a) (-(a)) +#define luai_numisnan(a) ((a) != (a)) +#define luai_numeq(a, b) ((a) == (b)) +#define luai_numlt(a, b) ((a) < (b)) +#define luai_numle(a, b) ((a) <= (b)) + +inline bool luai_veceq(const float* a, const float* b) +{ + return a[0] == b[0] && a[1] == b[1] && a[2] == b[2]; +} + +inline bool luai_vecisnan(const float* a) +{ + return a[0] != a[0] || a[1] != a[1] || a[2] != a[2]; +} + +LUAU_FASTMATH_BEGIN +inline double luai_nummod(double a, double b) +{ + return a - floor(a / b) * b; +} +LUAU_FASTMATH_END + +#define luai_num2int(i, d) ((i) = (int)(d)) + +/* On MSVC in 32-bit, double to unsigned cast compiles into a call to __dtoui3, so we invoke x87->int64 conversion path manually */ +#if defined(_MSC_VER) && defined(_M_IX86) +#define luai_num2unsigned(i, n) \ + { \ + __int64 l; \ + __asm { __asm fld n __asm fistp l} \ + ; \ + i = (unsigned int)l; \ + } +#else +#define luai_num2unsigned(i, n) ((i) = (unsigned)(long long)(n)) +#endif + +#define luai_num2str(s, n) snprintf((s), sizeof(s), LUA_NUMBER_FMT, (n)) +#define luai_str2num(s, p) strtod((s), (p)) diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp new file mode 100644 index 0000000..bf13e6e --- /dev/null +++ b/VM/src/lobject.cpp @@ -0,0 +1,160 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lobject.h" + +#include "lstate.h" +#include "lstring.h" +#include "lgc.h" +#include "ldo.h" +#include "lnumutils.h" + +#include +#include +#include +#include + + + +const TValue luaO_nilobject_ = {{NULL}, LUA_TNIL}; + +int luaO_log2(unsigned int x) +{ + static const uint8_t log_2[256] = {0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}; + int l = -1; + while (x >= 256) + { + l += 8; + x >>= 8; + } + return l + log_2[x]; +} + +int luaO_rawequalObj(const TValue* t1, const TValue* t2) +{ + if (ttype(t1) != ttype(t2)) + return 0; + else + switch (ttype(t1)) + { + case LUA_TNIL: + return 1; + case LUA_TNUMBER: + return luai_numeq(nvalue(t1), nvalue(t2)); + case LUA_TVECTOR: + return luai_veceq(vvalue(t1), vvalue(t2)); + case LUA_TBOOLEAN: + return bvalue(t1) == bvalue(t2); /* boolean true must be 1 !! */ + case LUA_TLIGHTUSERDATA: + return pvalue(t1) == pvalue(t2); + default: + LUAU_ASSERT(iscollectable(t1)); + return gcvalue(t1) == gcvalue(t2); + } +} + +int luaO_rawequalKey(const TKey* t1, const TValue* t2) +{ + if (ttype(t1) != ttype(t2)) + return 0; + else + switch (ttype(t1)) + { + case LUA_TNIL: + return 1; + case LUA_TNUMBER: + return luai_numeq(nvalue(t1), nvalue(t2)); + case LUA_TVECTOR: + return luai_veceq(vvalue(t1), vvalue(t2)); + case LUA_TBOOLEAN: + return bvalue(t1) == bvalue(t2); /* boolean true must be 1 !! */ + case LUA_TLIGHTUSERDATA: + return pvalue(t1) == pvalue(t2); + default: + LUAU_ASSERT(iscollectable(t1)); + return gcvalue(t1) == gcvalue(t2); + } +} + +int luaO_str2d(const char* s, double* result) +{ + char* endptr; + *result = luai_str2num(s, &endptr); + if (endptr == s) + return 0; /* conversion failed */ + if (*endptr == 'x' || *endptr == 'X') /* maybe an hexadecimal constant? */ + *result = cast_num(strtoul(s, &endptr, 16)); + if (*endptr == '\0') + return 1; /* most common case */ + while (isspace(cast_to(unsigned char, *endptr))) + endptr++; + if (*endptr != '\0') + return 0; /* invalid trailing characters? */ + return 1; +} + +const char* luaO_pushvfstring(lua_State* L, const char* fmt, va_list argp) +{ + char result[LUA_BUFFERSIZE]; + vsnprintf(result, sizeof(result), fmt, argp); + + setsvalue2s(L, L->top, luaS_new(L, result)); + incr_top(L); + return svalue(L->top - 1); +} + +const char* luaO_pushfstring(lua_State* L, const char* fmt, ...) +{ + const char* msg; + va_list argp; + va_start(argp, fmt); + msg = luaO_pushvfstring(L, fmt, argp); + va_end(argp); + return msg; +} + +void luaO_chunkid(char* out, const char* source, size_t bufflen) +{ + if (*source == '=') + { + source++; /* skip the `=' */ + size_t srclen = strlen(source); + size_t dstlen = srclen < bufflen ? srclen : bufflen - 1; + memcpy(out, source, dstlen); + out[dstlen] = '\0'; + } + else if (*source == '@') + { + size_t l; + source++; /* skip the `@' */ + bufflen -= sizeof(" '...' "); + l = strlen(source); + strcpy(out, ""); + if (l > bufflen) + { + source += (l - bufflen); /* get last part of file name */ + strcat(out, "..."); + } + strcat(out, source); + } + else + { /* out = [string "string"] */ + size_t len = strcspn(source, "\n\r"); /* stop at first newline */ + bufflen -= sizeof(" [string \"...\"] "); + if (len > bufflen) + len = bufflen; + strcpy(out, "[string \""); + if (source[len] != '\0') + { /* must truncate? */ + strncat(out, source, len); + strcat(out, "..."); + } + else + strcat(out, source); + strcat(out, "\"]"); + } +} diff --git a/VM/src/lobject.h b/VM/src/lobject.h new file mode 100644 index 0000000..c5f2e2f --- /dev/null +++ b/VM/src/lobject.h @@ -0,0 +1,447 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lua.h" +#include "lcommon.h" + +/* +** Union of all collectible objects +*/ +typedef union GCObject GCObject; + +/* +** Common Header for all collectible objects (in macro form, to be +** included in other objects) +*/ +// clang-format off +#define CommonHeader \ + GCObject* next; \ + uint8_t tt; uint8_t marked; uint8_t memcat +// clang-format on + +/* +** Common header in struct form +*/ +typedef struct GCheader +{ + CommonHeader; +} GCheader; + +/* +** Union of all Lua values +*/ +typedef union +{ + GCObject* gc; + void* p; + double n; + int b; + float v[2]; // v[0], v[1] live here; v[2] lives in TValue::extra +} Value; + +/* +** Tagged Values +*/ + +typedef struct lua_TValue +{ + Value value; + int extra; + int tt; +} TValue; + +/* Macros to test type */ +#define ttisnil(o) (ttype(o) == LUA_TNIL) +#define ttisnumber(o) (ttype(o) == LUA_TNUMBER) +#define ttisstring(o) (ttype(o) == LUA_TSTRING) +#define ttistable(o) (ttype(o) == LUA_TTABLE) +#define ttisfunction(o) (ttype(o) == LUA_TFUNCTION) +#define ttisboolean(o) (ttype(o) == LUA_TBOOLEAN) +#define ttisuserdata(o) (ttype(o) == LUA_TUSERDATA) +#define ttisthread(o) (ttype(o) == LUA_TTHREAD) +#define ttislightuserdata(o) (ttype(o) == LUA_TLIGHTUSERDATA) +#define ttisvector(o) (ttype(o) == LUA_TVECTOR) +#define ttisupval(o) (ttype(o) == LUA_TUPVAL) + +/* Macros to access values */ +#define ttype(o) ((o)->tt) +#define gcvalue(o) check_exp(iscollectable(o), (o)->value.gc) +#define pvalue(o) check_exp(ttislightuserdata(o), (o)->value.p) +#define nvalue(o) check_exp(ttisnumber(o), (o)->value.n) +#define vvalue(o) check_exp(ttisvector(o), (o)->value.v) +#define tsvalue(o) check_exp(ttisstring(o), &(o)->value.gc->ts) +#define uvalue(o) check_exp(ttisuserdata(o), &(o)->value.gc->u) +#define clvalue(o) check_exp(ttisfunction(o), &(o)->value.gc->cl) +#define hvalue(o) check_exp(ttistable(o), &(o)->value.gc->h) +#define bvalue(o) check_exp(ttisboolean(o), (o)->value.b) +#define thvalue(o) check_exp(ttisthread(o), &(o)->value.gc->th) +#define upvalue(o) check_exp(ttisupval(o), &(o)->value.gc->uv) + +// beware bit magic: a value is false if it's nil or boolean false +// baseline implementation: (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) +// we'd like a branchless version of this which helps with performance, and a very fast version +// so our strategy is to always read the boolean value (not using bvalue(o) because that asserts when type isn't boolean) +// we then combine it with type to produce 0/1 as follows: +// - when type is nil (0), & makes the result 0 +// - when type is boolean (1), we effectively only look at the bottom bit, so result is 0 iff boolean value is 0 +// - when type is different, it must have some of the top bits set - we keep all top bits of boolean value so the result is non-0 +#define l_isfalse(o) (!(((o)->value.b | ~1) & ttype(o))) + +/* +** for internal debug only +*/ +#define checkconsistency(obj) LUAU_ASSERT(!iscollectable(obj) || (ttype(obj) == (obj)->value.gc->gch.tt)) + +#define checkliveness(g, obj) LUAU_ASSERT(!iscollectable(obj) || ((ttype(obj) == (obj)->value.gc->gch.tt) && !isdead(g, (obj)->value.gc))) + +/* Macros to set values */ +#define setnilvalue(obj) ((obj)->tt = LUA_TNIL) + +#define setnvalue(obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.n = (x); \ + i_o->tt = LUA_TNUMBER; \ + } + +#define setvvalue(obj, x, y, z) \ + { \ + TValue* i_o = (obj); \ + float* i_v = i_o->value.v; \ + i_v[0] = (x); \ + i_v[1] = (y); \ + i_v[2] = (z); \ + i_o->tt = LUA_TVECTOR; \ + } + +#define setpvalue(obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.p = (x); \ + i_o->tt = LUA_TLIGHTUSERDATA; \ + } + +#define setbvalue(obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.b = (x); \ + i_o->tt = LUA_TBOOLEAN; \ + } + +#define setsvalue(L, obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.gc = cast_to(GCObject*, (x)); \ + i_o->tt = LUA_TSTRING; \ + checkliveness(L->global, i_o); \ + } + +#define setuvalue(L, obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.gc = cast_to(GCObject*, (x)); \ + i_o->tt = LUA_TUSERDATA; \ + checkliveness(L->global, i_o); \ + } + +#define setthvalue(L, obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.gc = cast_to(GCObject*, (x)); \ + i_o->tt = LUA_TTHREAD; \ + checkliveness(L->global, i_o); \ + } + +#define setclvalue(L, obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.gc = cast_to(GCObject*, (x)); \ + i_o->tt = LUA_TFUNCTION; \ + checkliveness(L->global, i_o); \ + } + +#define sethvalue(L, obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.gc = cast_to(GCObject*, (x)); \ + i_o->tt = LUA_TTABLE; \ + checkliveness(L->global, i_o); \ + } + +#define setptvalue(L, obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.gc = cast_to(GCObject*, (x)); \ + i_o->tt = LUA_TPROTO; \ + checkliveness(L->global, i_o); \ + } + +#define setupvalue(L, obj, x) \ + { \ + TValue* i_o = (obj); \ + i_o->value.gc = cast_to(GCObject*, (x)); \ + i_o->tt = LUA_TUPVAL; \ + checkliveness(L->global, i_o); \ + } + +#define setobj(L, obj1, obj2) \ + { \ + const TValue* o2 = (obj2); \ + TValue* o1 = (obj1); \ + *o1 = *o2; \ + checkliveness(L->global, o1); \ + } + +/* +** different types of sets, according to destination +*/ + +/* from stack to (same) stack */ +#define setobjs2s setobj +/* to stack (not from same stack) */ +#define setobj2s setobj +#define setsvalue2s setsvalue +#define sethvalue2s sethvalue +#define setptvalue2s setptvalue +/* from table to same table */ +#define setobjt2t setobj +/* to table */ +#define setobj2t setobj +/* to new object */ +#define setobj2n setobj +#define setsvalue2n setsvalue + +#define setttype(obj, tt) (ttype(obj) = (tt)) + +#define iscollectable(o) (ttype(o) >= LUA_TSTRING) + +typedef TValue* StkId; /* index to stack elements */ + +/* +** String headers for string table +*/ +typedef struct TString +{ + CommonHeader; + + int16_t atom; + + unsigned int hash; + unsigned int len; + + char data[1]; // string data is allocated right after the header +} TString; + +#define getstr(ts) (ts)->data +#define svalue(o) getstr(tsvalue(o)) + +typedef struct Udata +{ + CommonHeader; + + uint8_t tag; + + int len; + + struct Table* metatable; + + union + { + char data[1]; // userdata is allocated right after the header + L_Umaxalign dummy; // ensures maximum alignment for data + }; +} Udata; + +/* +** Function Prototypes +*/ +// clang-format off +typedef struct Proto +{ + CommonHeader; + + + TValue* k; /* constants used by the function */ + Instruction* code; /* function bytecode */ + struct Proto** p; /* functions defined inside the function */ + uint8_t* lineinfo; /* for each instruction, line number as a delta from baseline */ + int* abslineinfo; /* baseline line info, one entry for each 1<isC) +#define isLfunction(o) (ttype(o) == LUA_TFUNCTION && !clvalue(o)->isC) + +/* +** Tables +*/ + +typedef struct TKey +{ + ::Value value; + int extra; + unsigned tt : 4; + int next : 28; /* for chaining */ +} TKey; + +typedef struct LuaNode +{ + TValue val; + TKey key; +} LuaNode; + +/* copy a value into a key */ +#define setnodekey(L, node, obj) \ + { \ + LuaNode* n_ = (node); \ + const TValue* i_o = (obj); \ + n_->key.value = i_o->value; \ + n_->key.extra = i_o->extra; \ + n_->key.tt = i_o->tt; \ + checkliveness(L->global, i_o); \ + } + +/* copy a value from a key */ +#define getnodekey(L, obj, node) \ + { \ + TValue* i_o = (obj); \ + const LuaNode* n_ = (node); \ + i_o->value = n_->key.value; \ + i_o->extra = n_->key.extra; \ + i_o->tt = n_->key.tt; \ + checkliveness(L->global, i_o); \ + } + +// clang-format off +typedef struct Table +{ + CommonHeader; + + + uint8_t flags; /* 1<

lsizenode)) + +#define luaO_nilobject (&luaO_nilobject_) + +LUAI_DATA const TValue luaO_nilobject_; + +#define ceillog2(x) (luaO_log2((x)-1) + 1) + +LUAI_FUNC int luaO_log2(unsigned int x); +LUAI_FUNC int luaO_rawequalObj(const TValue* t1, const TValue* t2); +LUAI_FUNC int luaO_rawequalKey(const TKey* t1, const TValue* t2); +LUAI_FUNC int luaO_str2d(const char* s, double* result); +LUAI_FUNC const char* luaO_pushvfstring(lua_State* L, const char* fmt, va_list argp); +LUAI_FUNC const char* luaO_pushfstring(lua_State* L, const char* fmt, ...); +LUAI_FUNC void luaO_chunkid(char* out, const char* source, size_t len); diff --git a/VM/src/loslib.cpp b/VM/src/loslib.cpp new file mode 100644 index 0000000..8eaef60 --- /dev/null +++ b/VM/src/loslib.cpp @@ -0,0 +1,193 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include +#include + +#define LUA_STRFTIMEOPTIONS "aAbBcdHIjmMpSUwWxXyYzZ%" + +#if defined(_WIN32) +static tm* gmtime_r(const time_t* timep, tm* result) +{ + return gmtime_s(result, timep) == 0 ? result : NULL; +} + +static tm* localtime_r(const time_t* timep, tm* result) +{ + return localtime_s(result, timep) == 0 ? result : NULL; +} + +static time_t timegm(struct tm* timep) +{ + return _mkgmtime(timep); +} +#endif + +static int os_clock(lua_State* L) +{ + lua_pushnumber(L, lua_clock()); + return 1; +} + +/* +** {====================================================== +** Time/Date operations +** { year=%Y, month=%m, day=%d, hour=%H, min=%M, sec=%S, +** wday=%w+1, yday=%j, isdst=? } +** ======================================================= +*/ + +static void setfield(lua_State* L, const char* key, int value) +{ + lua_pushinteger(L, value); + lua_setfield(L, -2, key); +} + +static void setboolfield(lua_State* L, const char* key, int value) +{ + if (value < 0) /* undefined? */ + return; /* does not set field */ + lua_pushboolean(L, value); + lua_setfield(L, -2, key); +} + +static int getboolfield(lua_State* L, const char* key) +{ + int res; + lua_rawgetfield(L, -1, key); + res = lua_isnil(L, -1) ? -1 : lua_toboolean(L, -1); + lua_pop(L, 1); + return res; +} + +static int getfield(lua_State* L, const char* key, int d) +{ + int res; + lua_rawgetfield(L, -1, key); + if (lua_isnumber(L, -1)) + res = (int)lua_tointeger(L, -1); + else + { + if (d < 0) + luaL_error(L, "field '%s' missing in date table", key); + res = d; + } + lua_pop(L, 1); + return res; +} + +static int os_date(lua_State* L) +{ + const char* s = luaL_optstring(L, 1, "%c"); + time_t t = luaL_opt(L, (time_t)luaL_checknumber, 2, time(NULL)); + + struct tm tm; + struct tm* stm; + if (*s == '!') + { /* UTC? */ + stm = gmtime_r(&t, &tm); + s++; /* skip `!' */ + } + else + { + // on Windows, localtime() fails with dates before epoch start so we disallow that + stm = t < 0 ? NULL : localtime_r(&t, &tm); + } + + if (stm == NULL) /* invalid date? */ + { + lua_pushnil(L); + } + else if (strcmp(s, "*t") == 0) + { + lua_createtable(L, 0, 9); /* 9 = number of fields */ + setfield(L, "sec", stm->tm_sec); + setfield(L, "min", stm->tm_min); + setfield(L, "hour", stm->tm_hour); + setfield(L, "day", stm->tm_mday); + setfield(L, "month", stm->tm_mon + 1); + setfield(L, "year", stm->tm_year + 1900); + setfield(L, "wday", stm->tm_wday + 1); + setfield(L, "yday", stm->tm_yday + 1); + setboolfield(L, "isdst", stm->tm_isdst); + } + else + { + char cc[3]; + cc[0] = '%'; + cc[2] = '\0'; + + luaL_Buffer b; + luaL_buffinit(L, &b); + for (; *s; s++) + { + if (*s != '%' || *(s + 1) == '\0') /* no conversion specifier? */ + { + luaL_addchar(&b, *s); + } + else if (strchr(LUA_STRFTIMEOPTIONS, *(s + 1)) == 0) + { + luaL_argerror(L, 1, "invalid conversion specifier"); + } + else + { + size_t reslen; + char buff[200]; /* should be big enough for any conversion result */ + cc[1] = *(++s); + reslen = strftime(buff, sizeof(buff), cc, stm); + luaL_addlstring(&b, buff, reslen); + } + } + luaL_pushresult(&b); + } + return 1; +} + +static int os_time(lua_State* L) +{ + time_t t; + if (lua_isnoneornil(L, 1)) /* called without args? */ + t = time(NULL); /* get current time */ + else + { + struct tm ts; + luaL_checktype(L, 1, LUA_TTABLE); + lua_settop(L, 1); /* make sure table is at the top */ + ts.tm_sec = getfield(L, "sec", 0); + ts.tm_min = getfield(L, "min", 0); + ts.tm_hour = getfield(L, "hour", 12); + ts.tm_mday = getfield(L, "day", -1); + ts.tm_mon = getfield(L, "month", -1) - 1; + ts.tm_year = getfield(L, "year", -1) - 1900; + ts.tm_isdst = getboolfield(L, "isdst"); + + // Note: upstream Lua uses mktime() here which assumes input is local time, but we prefer UTC for consistency + t = timegm(&ts); + } + if (t == (time_t)(-1)) + lua_pushnil(L); + else + lua_pushnumber(L, (double)t); + return 1; +} + +static int os_difftime(lua_State* L) +{ + lua_pushnumber(L, difftime((time_t)(luaL_checknumber(L, 1)), (time_t)(luaL_optnumber(L, 2, 0)))); + return 1; +} + +static const luaL_Reg syslib[] = { + {"clock", os_clock}, + {"date", os_date}, + {"difftime", os_difftime}, + {"time", os_time}, + {NULL, NULL}, +}; + +LUALIB_API int luaopen_os(lua_State* L) +{ + luaL_register(L, LUA_OSLIBNAME, syslib); + return 1; +} diff --git a/VM/src/lperf.cpp b/VM/src/lperf.cpp new file mode 100644 index 0000000..2f6c729 --- /dev/null +++ b/VM/src/lperf.cpp @@ -0,0 +1,55 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lua.h" + +#ifdef _WIN32 +#include +#endif + +#ifdef __APPLE__ +#include +#include +#endif + +#include + +static double clock_period() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceFrequency(&result); + return 1.0 / double(result.QuadPart); +#elif defined(__APPLE__) + mach_timebase_info_data_t result = {}; + mach_timebase_info(&result); + return double(result.numer) / double(result.denom) * 1e-9; +#elif defined(__linux__) + return 1e-9; +#else + return 1.0 / double(CLOCKS_PER_SEC); +#endif +} + +static double clock_timestamp() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceCounter(&result); + return double(result.QuadPart); +#elif defined(__APPLE__) + return double(mach_absolute_time()); +#elif defined(__linux__) + timespec now; + clock_gettime(CLOCK_MONOTONIC, &now); + return now.tv_sec * 1e9 + now.tv_nsec; +#else + return double(clock()); +#endif +} + +double lua_clock() +{ + static double period = clock_period(); + + return clock_timestamp() * period; +} diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp new file mode 100644 index 0000000..0b2dfb6 --- /dev/null +++ b/VM/src/lstate.cpp @@ -0,0 +1,199 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lstate.h" + +#include "ltable.h" +#include "lstring.h" +#include "lfunc.h" +#include "lmem.h" +#include "lgc.h" +#include "ldo.h" +#include "ldebug.h" + +/* +** Main thread combines a thread state and the global state +*/ +typedef struct LG +{ + lua_State l; + global_State g; +} LG; + +static void stack_init(lua_State* L1, lua_State* L) +{ + /* initialize CallInfo array */ + L1->base_ci = luaM_newarray(L, BASIC_CI_SIZE, CallInfo, L1->memcat); + L1->ci = L1->base_ci; + L1->size_ci = BASIC_CI_SIZE; + L1->end_ci = L1->base_ci + L1->size_ci - 1; + /* initialize stack array */ + L1->stack = luaM_newarray(L, BASIC_STACK_SIZE + EXTRA_STACK, TValue, L1->memcat); + L1->stacksize = BASIC_STACK_SIZE + EXTRA_STACK; + for (int i = 0; i < BASIC_STACK_SIZE + EXTRA_STACK; i++) + setnilvalue(L1->stack + i); /* erase new stack */ + L1->top = L1->stack; + L1->stack_last = L1->stack + (L1->stacksize - EXTRA_STACK) - 1; + /* initialize first ci */ + L1->ci->func = L1->top; + setnilvalue(L1->top++); /* `function' entry for this `ci' */ + L1->base = L1->ci->base = L1->top; + L1->ci->top = L1->top + LUA_MINSTACK; +} + +static void freestack(lua_State* L, lua_State* L1) +{ + luaM_freearray(L, L1->base_ci, L1->size_ci, CallInfo, L1->memcat); + luaM_freearray(L, L1->stack, L1->stacksize, TValue, L1->memcat); +} + +/* +** open parts that may cause memory-allocation errors +*/ +static void f_luaopen(lua_State* L, void* ud) +{ + global_State* g = L->global; + stack_init(L, L); /* init stack */ + sethvalue(L, gt(L), luaH_new(L, 0, 2)); /* table of globals */ + sethvalue(L, registry(L), luaH_new(L, 0, 2)); /* registry */ + luaS_resize(L, LUA_MINSTRTABSIZE); /* initial size of string table */ + luaT_init(L); + luaS_fix(luaS_newliteral(L, LUA_MEMERRMSG)); /* pin to make sure we can always throw this error */ + luaS_fix(luaS_newliteral(L, LUA_ERRERRMSG)); /* pin to make sure we can always throw this error */ + g->GCthreshold = 4 * g->totalbytes; +} + +static void preinit_state(lua_State* L, global_State* g) +{ + L->global = g; + L->stack = NULL; + L->stacksize = 0; + L->openupval = NULL; + L->size_ci = 0; + L->nCcalls = L->baseCcalls = 0; + L->status = 0; + L->base_ci = L->ci = NULL; + L->namecall = NULL; + L->cachedslot = 0; + L->singlestep = false; + L->stackstate = 0; + L->activememcat = 0; + L->userdata = NULL; + setnilvalue(gt(L)); +} + +static void close_state(lua_State* L) +{ + global_State* g = L->global; + luaF_close(L, L->stack); /* close all upvalues for this thread */ + luaC_freeall(L); /* collect all objects */ + LUAU_ASSERT(g->rootgc == obj2gco(L)); + LUAU_ASSERT(g->strbufgc == NULL); + LUAU_ASSERT(g->strt.nuse == 0); + luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); + freestack(L, L); + LUAU_ASSERT(g->totalbytes == sizeof(LG)); + for (int i = 0; i < LUA_SIZECLASSES; i++) + LUAU_ASSERT(g->freepages[i] == NULL); + LUAU_ASSERT(g->memcatbytes[0] == sizeof(LG)); + for (int i = 1; i < LUA_MEMORY_CATEGORIES; i++) + LUAU_ASSERT(g->memcatbytes[i] == 0); + (*g->frealloc)(L, g->ud, L, sizeof(LG), 0); +} + +lua_State* luaE_newthread(lua_State* L) +{ + lua_State* L1 = luaM_new(L, lua_State, sizeof(lua_State), L->activememcat); + luaC_link(L, L1, LUA_TTHREAD); + preinit_state(L1, L->global); + L1->activememcat = L->activememcat; // inherit the active memory category + stack_init(L1, L); /* init stack */ + setobj2n(L, gt(L1), gt(L)); /* share table of globals */ + L1->singlestep = L->singlestep; + LUAU_ASSERT(iswhite(obj2gco(L1))); + return L1; +} + +void luaE_freethread(lua_State* L, lua_State* L1) +{ + luaF_close(L1, L1->stack); /* close all upvalues for this thread */ + LUAU_ASSERT(L1->openupval == NULL); + global_State* g = L->global; + if (g->cb.userthread) + g->cb.userthread(NULL, L1); + freestack(L, L1); + luaM_free(L, L1, sizeof(lua_State), L1->memcat); +} + +lua_State* lua_newstate(lua_Alloc f, void* ud) +{ + int i; + lua_State* L; + global_State* g; + void* l = (*f)(NULL, ud, NULL, 0, sizeof(LG)); + if (l == NULL) + return NULL; + L = (lua_State*)l; + g = &((LG*)L)->g; + L->next = NULL; + L->tt = LUA_TTHREAD; + L->marked = g->currentwhite = bit2mask(WHITE0BIT, FIXEDBIT); + L->memcat = 0; + preinit_state(L, g); + g->frealloc = f; + g->ud = ud; + g->mainthread = L; + g->uvhead.u.l.prev = &g->uvhead; + g->uvhead.u.l.next = &g->uvhead; + g->GCthreshold = 0; /* mark it as unfinished state */ + g->registryfree = 0; + g->errorjmp = NULL; + g->rngstate = 0; + g->ptrenckey[0] = 1; + g->ptrenckey[1] = 0; + g->ptrenckey[2] = 0; + g->ptrenckey[3] = 0; + g->strt.size = 0; + g->strt.nuse = 0; + g->strt.hash = NULL; + setnilvalue(registry(L)); + g->gcstate = GCSpause; + g->rootgc = obj2gco(L); + g->sweepstrgc = 0; + g->sweepgc = &g->rootgc; + g->gray = NULL; + g->grayagain = NULL; + g->weak = NULL; + g->strbufgc = NULL; + g->totalbytes = sizeof(LG); + g->gcgoal = LUAI_GCGOAL; + g->gcstepmul = LUAI_GCSTEPMUL; + g->gcstepsize = LUAI_GCSTEPSIZE << 10; + for (i = 0; i < LUA_SIZECLASSES; i++) + g->freepages[i] = NULL; + for (i = 0; i < LUA_T_COUNT; i++) + g->mt[i] = NULL; + for (i = 0; i < LUA_UTAG_LIMIT; i++) + g->udatagc[i] = NULL; + for (i = 0; i < LUA_MEMORY_CATEGORIES; i++) + g->memcatbytes[i] = 0; + + g->memcatbytes[0] = sizeof(LG); + + g->cb = lua_Callbacks(); + g->gcstats = GCStats(); + + if (luaD_rawrunprotected(L, f_luaopen, NULL) != 0) + { + /* memory allocation error: free partial state */ + close_state(L); + L = NULL; + } + return L; +} + +void lua_close(lua_State* L) +{ + L = L->global->mainthread; /* only the main thread can be closed */ + luaF_close(L, L->stack); /* close all upvalues for this thread */ + close_state(L); +} diff --git a/VM/src/lstate.h b/VM/src/lstate.h new file mode 100644 index 0000000..5637988 --- /dev/null +++ b/VM/src/lstate.h @@ -0,0 +1,271 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" +#include "ltm.h" + +/* table of globals */ +#define gt(L) (&L->l_gt) + +/* registry */ +#define registry(L) (&L->global->registry) + +/* extra stack space to handle TM calls and some other extras */ +#define EXTRA_STACK 5 + +#define BASIC_CI_SIZE 8 + +#define BASIC_STACK_SIZE (2 * LUA_MINSTACK) + +// clang-format off +typedef struct stringtable +{ + + GCObject** hash; + uint32_t nuse; /* number of elements */ + int size; +} stringtable; +// clang-format on + +/* +** informations about a call +** +** the general Lua stack frame structure is as follows: +** - each function gets a stack frame, with function "registers" being stack slots on the frame +** - function arguments are associated with registers 0+ +** - function locals and temporaries follow after; usually locals are a consecutive block per scope, and temporaries are allocated after this, but +*this is up to the compiler +** +** when function doesn't have varargs, the stack layout is as follows: +** ^ (func) ^^ [fixed args] [locals + temporaries] +** where ^ is the 'func' pointer in CallInfo struct, and ^^ is the 'base' pointer (which is what registers are relative to) +** +** when function *does* have varargs, the stack layout is more complex - the runtime has to copy the fixed arguments so that the 0+ addressing still +*works as follows: +** ^ (func) [fixed args] [varargs] ^^ [fixed args] [locals + temporaries] +** +** computing the sizes of these individual blocks works as follows: +** - the number of fixed args is always matching the `numparams` in a function's Proto object; runtime adds `nil` during the call execution as +*necessary +** - the number of variadic args can be computed by evaluating (ci->base - ci->func - 1 - numparams) +** +** the CallInfo structures are allocated as an array, with each subsequent call being *appended* to this array (so if f calls g, CallInfo for g +*immediately follows CallInfo for f) +** the `nresults` field in CallInfo is set by the caller to tell the function how many arguments the caller is expecting on the stack after the +*function returns +** the `flags` field in CallInfo contains internal execution flags that are important for pcall/etc, see LUA_CALLINFO_* +*/ +// clang-format off +typedef struct CallInfo +{ + + StkId base; /* base for this function */ + StkId func; /* function index in the stack */ + StkId top; /* top for this function */ + const Instruction* savedpc; + + int nresults; /* expected number of results from this function */ + unsigned int flags; /* call frame flags, see LUA_CALLINFO_* */ +} CallInfo; +// clang-format on + +#define LUA_CALLINFO_RETURN (1 << 0) /* should the interpreter return after returning from this callinfo? first frame must have this set */ +#define LUA_CALLINFO_HANDLE (1 << 1) /* should the error thrown during execution get handled by continuation from this callinfo? func must be C */ + +#define curr_func(L) (clvalue(L->ci->func)) +#define ci_func(ci) (clvalue((ci)->func)) +#define f_isLua(ci) (!ci_func(ci)->isC) +#define isLua(ci) (ttisfunction((ci)->func) && f_isLua(ci)) + +struct GCCycleStats +{ + size_t heapgoalsizebytes = 0; + size_t heaptriggersizebytes = 0; + + double waittime = 0.0; // time from end of the last cycle to the start of a new one + + double starttimestamp = 0.0; + double endtimestamp = 0.0; + + double marktime = 0.0; + + double atomicstarttimestamp = 0.0; + size_t atomicstarttotalsizebytes = 0; + double atomictime = 0.0; + + double sweeptime = 0.0; + + size_t markitems = 0; + size_t sweepitems = 0; + + size_t assistwork = 0; + size_t explicitwork = 0; + + size_t endtotalsizebytes = 0; +}; + +// data for proportional-integral controller of heap trigger value +struct GCHeapTriggerStats +{ + static const unsigned termcount = 32; + int32_t terms[termcount] = {0}; + uint32_t termpos = 0; + int32_t integral = 0; +}; + +struct GCStats +{ + double stepexplicittimeacc = 0.0; + double stepassisttimeacc = 0.0; + + // when cycle is completed, last cycle values are updated + uint64_t completedcycles = 0; + + GCCycleStats lastcycle; + GCCycleStats currcycle; + + // only step count and their time is accumulated + GCCycleStats cyclestatsacc; + + GCHeapTriggerStats triggerstats; +}; + +/* +** `global state', shared by all threads of this state +*/ +// clang-format off +typedef struct global_State +{ + stringtable strt; /* hash table for strings */ + + + lua_Alloc frealloc; /* function to reallocate memory */ + void* ud; /* auxiliary data to `frealloc' */ + + + uint8_t currentwhite; + uint8_t gcstate; /* state of garbage collector */ + + + int sweepstrgc; /* position of sweep in `strt' */ + GCObject* rootgc; /* list of all collectable objects */ + GCObject** sweepgc; /* position of sweep in `rootgc' */ + GCObject* gray; /* list of gray objects */ + GCObject* grayagain; /* list of objects to be traversed atomically */ + GCObject* weak; /* list of weak tables (to be cleared) */ + + GCObject* strbufgc; // list of all string buffer objects + + + size_t GCthreshold; // when totalbytes > GCthreshold; run GC step + size_t totalbytes; // number of bytes currently allocated + int gcgoal; // see LUAI_GCGOAL + int gcstepmul; // see LUAI_GCSTEPMUL + int gcstepsize; // see LUAI_GCSTEPSIZE + + struct lua_Page* freepages[LUA_SIZECLASSES]; /* free page linked list for each size class */ + + size_t memcatbytes[LUA_MEMORY_CATEGORIES]; /* total amount of memory used by each memory category */ + + + struct lua_State* mainthread; + UpVal uvhead; /* head of double-linked list of all open upvalues */ + struct Table* mt[LUA_T_COUNT]; /* metatables for basic types */ + TString* ttname[LUA_T_COUNT]; /* names for basic types */ + TString* tmname[TM_N]; /* array with tag-method names */ + + TValue registry; /* registry table, used by lua_ref and LUA_REGISTRYINDEX */ + int registryfree; /* next free slot in registry */ + + struct lua_jmpbuf* errorjmp; /* jump buffer data for longjmp-style error handling */ + + uint64_t rngstate; /* PCG random number generator state */ + uint64_t ptrenckey[4]; /* pointer encoding key for display */ + + void (*udatagc[LUA_UTAG_LIMIT])(void*); /* for each userdata tag, a gc callback to be called immediately before freeing memory */ + + lua_Callbacks cb; + + GCStats gcstats; + +} global_State; +// clang-format on + +/* +** `per thread' state +*/ +// clang-format off +struct lua_State +{ + CommonHeader; + uint8_t status; + + uint8_t activememcat; /* memory category that is used for new GC object allocations */ + uint8_t stackstate; + + bool singlestep; /* call debugstep hook after each instruction */ + + + StkId top; /* first free slot in the stack */ + StkId base; /* base of current function */ + global_State* global; + CallInfo* ci; /* call info for current function */ + StkId stack_last; /* last free slot in the stack */ + StkId stack; /* stack base */ + + + CallInfo* end_ci; /* points after end of ci array*/ + CallInfo* base_ci; /* array of CallInfo's */ + + + int stacksize; + int size_ci; /* size of array `base_ci' */ + + + unsigned short nCcalls; /* number of nested C calls */ + unsigned short baseCcalls; /* nested C calls when resuming coroutine */ + + int cachedslot; /* when table operations or INDEX/NEWINDEX is invoked from Luau, what is the expected slot for lookup? */ + + + TValue l_gt; /* table of globals */ + TValue env; /* temporary place for environments */ + GCObject* openupval; /* list of open upvalues in this stack */ + GCObject* gclist; + + TString* namecall; /* when invoked from Luau using NAMECALL, what method do we need to invoke? */ + + void* userdata; +}; +// clang-format on + +/* +** Union of all collectible objects +*/ +union GCObject +{ + GCheader gch; + struct TString ts; + struct Udata u; + struct Closure cl; + struct Table h; + struct Proto p; + struct UpVal uv; + struct lua_State th; /* thread */ +}; + +/* macros to convert a GCObject into a specific value */ +#define gco2ts(o) check_exp((o)->gch.tt == LUA_TSTRING, &((o)->ts)) +#define gco2u(o) check_exp((o)->gch.tt == LUA_TUSERDATA, &((o)->u)) +#define gco2cl(o) check_exp((o)->gch.tt == LUA_TFUNCTION, &((o)->cl)) +#define gco2h(o) check_exp((o)->gch.tt == LUA_TTABLE, &((o)->h)) +#define gco2p(o) check_exp((o)->gch.tt == LUA_TPROTO, &((o)->p)) +#define gco2uv(o) check_exp((o)->gch.tt == LUA_TUPVAL, &((o)->uv)) +#define gco2th(o) check_exp((o)->gch.tt == LUA_TTHREAD, &((o)->th)) + +/* macro to convert any Lua object into a GCObject */ +#define obj2gco(v) check_exp(iscollectable(v), cast_to(GCObject*, (v) + 0)) + +LUAI_FUNC lua_State* luaE_newthread(lua_State* L); +LUAI_FUNC void luaE_freethread(lua_State* L, lua_State* L1); diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp new file mode 100644 index 0000000..d77e17c --- /dev/null +++ b/VM/src/lstring.cpp @@ -0,0 +1,237 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lstring.h" + +#include "lgc.h" +#include "lmem.h" + +#include + +unsigned int luaS_hash(const char* str, size_t len) +{ + // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash + unsigned int a = 0, b = 0; + unsigned int h = unsigned(len); + + // hash prefix in 12b chunks (using aligned reads) with ARX based hash (LuaJIT v2.1, lookup3) + // note that we stop at length<32 to maintain compatibility with Lua 5.1 + while (len >= 32) + { +#define rol(x, s) ((x >> s) | (x << (32 - s))) +#define mix(u, v, w) a ^= h, a -= rol(h, u), b ^= a, b -= rol(a, v), h ^= b, h -= rol(b, w) + + // should compile into fast unaligned reads + uint32_t block[3]; + memcpy(block, str, 12); + + a += block[0]; + b += block[1]; + h += block[2]; + mix(14, 11, 25); + str += 12; + len -= 12; + +#undef mix +#undef rol + } + + // original Lua 5.1 hash for compatibility (exact match when len<32) + for (size_t i = len; i > 0; --i) + h ^= (h << 5) + (h >> 2) + (uint8_t)str[i - 1]; + + return h; +} + +void luaS_resize(lua_State* L, int newsize) +{ + GCObject** newhash; + stringtable* tb; + int i; + if (L->global->gcstate == GCSsweepstring) + return; /* cannot resize during GC traverse */ + newhash = luaM_newarray(L, newsize, GCObject*, 0); + tb = &L->global->strt; + for (i = 0; i < newsize; i++) + newhash[i] = NULL; + /* rehash */ + for (i = 0; i < tb->size; i++) + { + GCObject* p = tb->hash[i]; + while (p) + { /* for each node in the list */ + GCObject* next = p->gch.next; /* save next */ + unsigned int h = gco2ts(p)->hash; + int h1 = lmod(h, newsize); /* new position */ + LUAU_ASSERT(cast_int(h % newsize) == lmod(h, newsize)); + p->gch.next = newhash[h1]; /* chain it */ + newhash[h1] = p; + p = next; + } + } + luaM_freearray(L, tb->hash, tb->size, TString*, 0); + tb->size = newsize; + tb->hash = newhash; +} + +static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) +{ + TString* ts; + stringtable* tb; + if (l > MAXSSIZE) + luaM_toobig(L); + ts = luaM_new(L, TString, sizestring(l), L->activememcat); + ts->len = unsigned(l); + ts->hash = h; + ts->marked = luaC_white(L->global); + ts->tt = LUA_TSTRING; + ts->memcat = L->activememcat; + memcpy(ts->data, str, l); + ts->data[l] = '\0'; /* ending 0 */ + ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1; + tb = &L->global->strt; + h = lmod(h, tb->size); + ts->next = tb->hash[h]; /* chain new entry */ + tb->hash[h] = obj2gco(ts); + tb->nuse++; + if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) + luaS_resize(L, tb->size * 2); /* too crowded */ + return ts; +} + +static void linkstrbuf(lua_State* L, TString* ts) +{ + global_State* g = L->global; + GCObject* o = obj2gco(ts); + o->gch.next = g->strbufgc; + g->strbufgc = o; + o->gch.marked = luaC_white(g); +} + +static void unlinkstrbuf(lua_State* L, TString* ts) +{ + global_State* g = L->global; + + GCObject** p = &g->strbufgc; + + while (GCObject* curr = *p) + { + if (curr == obj2gco(ts)) + { + *p = curr->gch.next; + return; + } + else + { + p = &curr->gch.next; + } + } + + LUAU_ASSERT(!"failed to find string buffer"); +} + +TString* luaS_bufstart(lua_State* L, size_t size) +{ + if (size > MAXSSIZE) + luaM_toobig(L); + + TString* ts = luaM_new(L, TString, sizestring(size), L->activememcat); + + ts->tt = LUA_TSTRING; + ts->memcat = L->activememcat; + linkstrbuf(L, ts); + + ts->len = unsigned(size); + + return ts; +} + +TString* luaS_buffinish(lua_State* L, TString* ts) +{ + unsigned int h = luaS_hash(ts->data, ts->len); + stringtable* tb = &L->global->strt; + int bucket = lmod(h, tb->size); + + // search if we already have this string in the hash table + for (GCObject* o = tb->hash[bucket]; o != NULL; o = o->gch.next) + { + TString* el = gco2ts(o); + + if (el->len == ts->len && memcmp(el->data, ts->data, ts->len) == 0) + { + // string may be dead + if (isdead(L->global, o)) + changewhite(o); + + return el; + } + } + + unlinkstrbuf(L, ts); + + ts->hash = h; + ts->data[ts->len] = '\0'; // ending 0 + + // Complete string object + ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; + ts->next = tb->hash[bucket]; // chain new entry + tb->hash[bucket] = obj2gco(ts); + + tb->nuse++; + if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) + luaS_resize(L, tb->size * 2); // too crowded + + return ts; +} + +TString* luaS_newlstr(lua_State* L, const char* str, size_t l) +{ + GCObject* o; + unsigned int h = luaS_hash(str, l); + for (o = L->global->strt.hash[lmod(h, L->global->strt.size)]; o != NULL; o = o->gch.next) + { + TString* ts = gco2ts(o); + if (ts->len == l && (memcmp(str, getstr(ts), l) == 0)) + { + /* string may be dead */ + if (isdead(L->global, o)) + changewhite(o); + return ts; + } + } + return newlstr(L, str, l, h); /* not found */ +} + +void luaS_free(lua_State* L, TString* ts) +{ + L->global->strt.nuse--; + luaM_free(L, ts, sizestring(ts->len), ts->memcat); +} + +Udata* luaS_newudata(lua_State* L, size_t s, int tag) +{ + if (s > INT_MAX - sizeof(Udata)) + luaM_toobig(L); + Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); + luaC_link(L, u, LUA_TUSERDATA); + u->len = int(s); + u->metatable = NULL; + LUAU_ASSERT(tag >= 0 && tag <= 255); + u->tag = uint8_t(tag); + return u; +} + +void luaS_freeudata(lua_State* L, Udata* u) +{ + LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); + + void (*dtor)(void*) = nullptr; + if (u->tag == UTAG_IDTOR) + memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor)); + else if (u->tag) + dtor = L->global->udatagc[u->tag]; + + if (dtor) + dtor(u->data); + + luaM_free(L, u, sizeudata(u->len), u->memcat); +} diff --git a/VM/src/lstring.h b/VM/src/lstring.h new file mode 100644 index 0000000..612da28 --- /dev/null +++ b/VM/src/lstring.h @@ -0,0 +1,33 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" +#include "lstate.h" + +/* string size limit */ +#define MAXSSIZE (1 << 30) + +/* special tag value is used for user data with inline dtors */ +#define UTAG_IDTOR LUA_UTAG_LIMIT + +#define sizestring(len) (offsetof(TString, data) + len + 1) +#define sizeudata(len) (offsetof(Udata, data) + len) + +#define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s))) +#define luaS_newliteral(L, s) (luaS_newlstr(L, "" s, (sizeof(s) / sizeof(char)) - 1)) + +#define luaS_fix(s) l_setbit((s)->marked, FIXEDBIT) + +LUAI_FUNC unsigned int luaS_hash(const char* str, size_t len); + +LUAI_FUNC void luaS_resize(lua_State* L, int newsize); + +LUAI_FUNC TString* luaS_newlstr(lua_State* L, const char* str, size_t l); +LUAI_FUNC void luaS_free(lua_State* L, TString* ts); + +LUAI_FUNC Udata* luaS_newudata(lua_State* L, size_t s, int tag); +LUAI_FUNC void luaS_freeudata(lua_State* L, Udata* u); + +LUAI_FUNC TString* luaS_bufstart(lua_State* L, size_t size); +LUAI_FUNC TString* luaS_buffinish(lua_State* L, TString* ts); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp new file mode 100644 index 0000000..a9db372 --- /dev/null +++ b/VM/src/lstrlib.cpp @@ -0,0 +1,1654 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lstring.h" + +#include +#include +#include + +/* macro to `unsign' a character */ +#define uchar(c) ((unsigned char)(c)) + +static int str_len(lua_State* L) +{ + size_t l; + luaL_checklstring(L, 1, &l); + lua_pushinteger(L, (int)l); + return 1; +} + +static int posrelat(int pos, size_t len) +{ + /* relative string position: negative means back from end */ + if (pos < 0) + pos += (int)len + 1; + return (pos >= 0) ? pos : 0; +} + +static int str_sub(lua_State* L) +{ + size_t l; + const char* s = luaL_checklstring(L, 1, &l); + int start = posrelat(luaL_checkinteger(L, 2), l); + int end = posrelat(luaL_optinteger(L, 3, -1), l); + if (start < 1) + start = 1; + if (end > (int)l) + end = (int)l; + if (start <= end) + lua_pushlstring(L, s + start - 1, end - start + 1); + else + lua_pushliteral(L, ""); + return 1; +} + +static int str_reverse(lua_State* L) +{ + size_t l; + const char* s = luaL_checklstring(L, 1, &l); + luaL_Buffer b; + char* ptr = luaL_buffinitsize(L, &b, l); + while (l--) + *ptr++ = s[l]; + luaL_pushresultsize(&b, ptr - b.p); + return 1; +} + +static int str_lower(lua_State* L) +{ + size_t l; + const char* s = luaL_checklstring(L, 1, &l); + luaL_Buffer b; + char* ptr = luaL_buffinitsize(L, &b, l); + for (size_t i = 0; i < l; i++) + *ptr++ = tolower(uchar(s[i])); + luaL_pushresultsize(&b, l); + return 1; +} + +static int str_upper(lua_State* L) +{ + size_t l; + const char* s = luaL_checklstring(L, 1, &l); + luaL_Buffer b; + char* ptr = luaL_buffinitsize(L, &b, l); + for (size_t i = 0; i < l; i++) + *ptr++ = toupper(uchar(s[i])); + luaL_pushresultsize(&b, l); + return 1; +} + +static int str_rep(lua_State* L) +{ + size_t l; + const char* s = luaL_checklstring(L, 1, &l); + int n = luaL_checkinteger(L, 2); + + if (n <= 0) + { + lua_pushliteral(L, ""); + return 1; + } + + if (l > MAXSSIZE / (size_t)n) // may overflow? + luaL_error(L, "resulting string too large"); + + luaL_Buffer b; + char* ptr = luaL_buffinitsize(L, &b, l * n); + + const char* start = ptr; + + size_t left = l * n; + size_t step = l; + + memcpy(ptr, s, l); + ptr += l; + left -= l; + + // use the increasing 'pattern' inside our target buffer to fill the next part + while (step < left) + { + memcpy(ptr, start, step); + ptr += step; + left -= step; + step <<= 1; + } + + // fill tail + memcpy(ptr, start, left); + ptr += left; + + luaL_pushresultsize(&b, l * n); + + return 1; +} + +static int str_byte(lua_State* L) +{ + size_t l; + const char* s = luaL_checklstring(L, 1, &l); + int posi = posrelat(luaL_optinteger(L, 2, 1), l); + int pose = posrelat(luaL_optinteger(L, 3, posi), l); + int n, i; + if (posi <= 0) + posi = 1; + if ((size_t)pose > l) + pose = (int)l; + if (posi > pose) + return 0; /* empty interval; return no values */ + n = (int)(pose - posi + 1); + if (posi + n <= pose) /* overflow? */ + luaL_error(L, "string slice too long"); + luaL_checkstack(L, n, "string slice too long"); + for (i = 0; i < n; i++) + lua_pushinteger(L, uchar(s[posi + i - 1])); + return n; +} + +static int str_char(lua_State* L) +{ + int n = lua_gettop(L); /* number of arguments */ + + luaL_Buffer b; + char* ptr = luaL_buffinitsize(L, &b, n); + + for (int i = 1; i <= n; i++) + { + int c = luaL_checkinteger(L, i); + luaL_argcheck(L, uchar(c) == c, i, "invalid value"); + + *ptr++ = uchar(c); + } + luaL_pushresultsize(&b, n); + return 1; +} + +/* +** {====================================================== +** PATTERN MATCHING +** ======================================================= +*/ + +#define CAP_UNFINISHED (-1) +#define CAP_POSITION (-2) + +typedef struct MatchState +{ + int matchdepth; /* control for recursive depth (to avoid C stack overflow) */ + const char* src_init; /* init of source string */ + const char* src_end; /* end ('\0') of source string */ + const char* p_end; /* end ('\0') of pattern */ + lua_State* L; + int level; /* total number of captures (finished or unfinished) */ + struct + { + const char* init; + ptrdiff_t len; + } capture[LUA_MAXCAPTURES]; +} MatchState; + +/* recursive function */ +static const char* match(MatchState* ms, const char* s, const char* p); + +#define L_ESC '%' +#define SPECIALS "^$*+?.([%-" + +static int check_capture(MatchState* ms, int l) +{ + l -= '1'; + if (l < 0 || l >= ms->level || ms->capture[l].len == CAP_UNFINISHED) + luaL_error(ms->L, "invalid capture index %%%d", l + 1); + return l; +} + +static int capture_to_close(MatchState* ms) +{ + int level = ms->level; + for (level--; level >= 0; level--) + if (ms->capture[level].len == CAP_UNFINISHED) + return level; + luaL_error(ms->L, "invalid pattern capture"); +} + +static const char* classend(MatchState* ms, const char* p) +{ + switch (*p++) + { + case L_ESC: + { + if (p == ms->p_end) + luaL_error(ms->L, "malformed pattern (ends with '%%')"); + return p + 1; + } + case '[': + { + if (*p == '^') + p++; + do + { /* look for a `]' */ + if (p == ms->p_end) + luaL_error(ms->L, "malformed pattern (missing ']')"); + if (*(p++) == L_ESC && p < ms->p_end) + p++; /* skip escapes (e.g. `%]') */ + } while (*p != ']'); + return p + 1; + } + default: + { + return p; + } + } +} + +static int match_class(int c, int cl) +{ + int res; + switch (tolower(cl)) + { + case 'a': + res = isalpha(c); + break; + case 'c': + res = iscntrl(c); + break; + case 'd': + res = isdigit(c); + break; + case 'g': + res = isgraph(c); + break; + case 'l': + res = islower(c); + break; + case 'p': + res = ispunct(c); + break; + case 's': + res = isspace(c); + break; + case 'u': + res = isupper(c); + break; + case 'w': + res = isalnum(c); + break; + case 'x': + res = isxdigit(c); + break; + case 'z': + res = (c == 0); + break; /* deprecated option */ + default: + return (cl == c); + } + return (islower(cl) ? res : !res); +} + +static int matchbracketclass(int c, const char* p, const char* ec) +{ + int sig = 1; + if (*(p + 1) == '^') + { + sig = 0; + p++; /* skip the `^' */ + } + while (++p < ec) + { + if (*p == L_ESC) + { + p++; + if (match_class(c, uchar(*p))) + return sig; + } + else if ((*(p + 1) == '-') && (p + 2 < ec)) + { + p += 2; + if (uchar(*(p - 2)) <= c && c <= uchar(*p)) + return sig; + } + else if (uchar(*p) == c) + return sig; + } + return !sig; +} + +static int singlematch(MatchState* ms, const char* s, const char* p, const char* ep) +{ + if (s >= ms->src_end) + return 0; + else + { + int c = uchar(*s); + switch (*p) + { + case '.': + return 1; /* matches any char */ + case L_ESC: + return match_class(c, uchar(*(p + 1))); + case '[': + return matchbracketclass(c, p, ep - 1); + default: + return (uchar(*p) == c); + } + } +} + +static const char* matchbalance(MatchState* ms, const char* s, const char* p) +{ + if (p >= ms->p_end - 1) + luaL_error(ms->L, "malformed pattern (missing arguments to '%%b')"); + if (*s != *p) + return NULL; + else + { + int b = *p; + int e = *(p + 1); + int cont = 1; + while (++s < ms->src_end) + { + if (*s == e) + { + if (--cont == 0) + return s + 1; + } + else if (*s == b) + cont++; + } + } + return NULL; /* string ends out of balance */ +} + +static const char* max_expand(MatchState* ms, const char* s, const char* p, const char* ep) +{ + ptrdiff_t i = 0; /* counts maximum expand for item */ + while (singlematch(ms, s + i, p, ep)) + i++; + /* keeps trying to match with the maximum repetitions */ + while (i >= 0) + { + const char* res = match(ms, (s + i), ep + 1); + if (res) + return res; + i--; /* else didn't match; reduce 1 repetition to try again */ + } + return NULL; +} + +static const char* min_expand(MatchState* ms, const char* s, const char* p, const char* ep) +{ + for (;;) + { + const char* res = match(ms, s, ep + 1); + if (res != NULL) + return res; + else if (singlematch(ms, s, p, ep)) + s++; /* try with one more repetition */ + else + return NULL; + } +} + +static const char* start_capture(MatchState* ms, const char* s, const char* p, int what) +{ + const char* res; + int level = ms->level; + if (level >= LUA_MAXCAPTURES) + luaL_error(ms->L, "too many captures"); + ms->capture[level].init = s; + ms->capture[level].len = what; + ms->level = level + 1; + if ((res = match(ms, s, p)) == NULL) /* match failed? */ + ms->level--; /* undo capture */ + return res; +} + +static const char* end_capture(MatchState* ms, const char* s, const char* p) +{ + int l = capture_to_close(ms); + const char* res; + ms->capture[l].len = s - ms->capture[l].init; /* close capture */ + if ((res = match(ms, s, p)) == NULL) /* match failed? */ + ms->capture[l].len = CAP_UNFINISHED; /* undo capture */ + return res; +} + +static const char* match_capture(MatchState* ms, const char* s, int l) +{ + size_t len; + l = check_capture(ms, l); + len = ms->capture[l].len; + if ((size_t)(ms->src_end - s) >= len && memcmp(ms->capture[l].init, s, len) == 0) + return s + len; + else + return NULL; +} + +static const char* match(MatchState* ms, const char* s, const char* p) +{ + if (ms->matchdepth-- == 0) + luaL_error(ms->L, "pattern too complex"); +init: /* using goto's to optimize tail recursion */ + if (p != ms->p_end) + { /* end of pattern? */ + switch (*p) + { + case '(': + { /* start capture */ + if (*(p + 1) == ')') /* position capture? */ + s = start_capture(ms, s, p + 2, CAP_POSITION); + else + s = start_capture(ms, s, p + 1, CAP_UNFINISHED); + break; + } + case ')': + { /* end capture */ + s = end_capture(ms, s, p + 1); + break; + } + case '$': + { + if ((p + 1) != ms->p_end) /* is the `$' the last char in pattern? */ + goto dflt; /* no; go to default */ + s = (s == ms->src_end) ? s : NULL; /* check end of string */ + break; + } + case L_ESC: + { /* escaped sequences not in the format class[*+?-]? */ + switch (*(p + 1)) + { + case 'b': + { /* balanced string? */ + s = matchbalance(ms, s, p + 2); + if (s != NULL) + { + p += 4; + goto init; /* return match(ms, s, p + 4); */ + } /* else fail (s == NULL) */ + break; + } + case 'f': + { /* frontier? */ + const char* ep; + char previous; + p += 2; + if (*p != '[') + luaL_error(ms->L, "missing '[' after '%%f' in pattern"); + ep = classend(ms, p); /* points to what is next */ + previous = (s == ms->src_init) ? '\0' : *(s - 1); + if (!matchbracketclass(uchar(previous), p, ep - 1) && matchbracketclass(uchar(*s), p, ep - 1)) + { + p = ep; + goto init; /* return match(ms, s, ep); */ + } + s = NULL; /* match failed */ + break; + } + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + { /* capture results (%0-%9)? */ + s = match_capture(ms, s, uchar(*(p + 1))); + if (s != NULL) + { + p += 2; + goto init; /* return match(ms, s, p + 2) */ + } + break; + } + default: + goto dflt; + } + break; + } + default: + dflt: + { /* pattern class plus optional suffix */ + const char* ep = classend(ms, p); /* points to optional suffix */ + /* does not match at least once? */ + if (!singlematch(ms, s, p, ep)) + { + if (*ep == '*' || *ep == '?' || *ep == '-') + { /* accept empty? */ + p = ep + 1; + goto init; /* return match(ms, s, ep + 1); */ + } + else /* '+' or no suffix */ + s = NULL; /* fail */ + } + else + { /* matched once */ + switch (*ep) + { /* handle optional suffix */ + case '?': + { /* optional */ + const char* res; + if ((res = match(ms, s + 1, ep + 1)) != NULL) + s = res; + else + { + p = ep + 1; + goto init; /* else return match(ms, s, ep + 1); */ + } + break; + } + case '+': /* 1 or more repetitions */ + s++; /* 1 match already done */ + /* go through */ + case '*': /* 0 or more repetitions */ + s = max_expand(ms, s, p, ep); + break; + case '-': /* 0 or more repetitions (minimum) */ + s = min_expand(ms, s, p, ep); + break; + default: /* no suffix */ + s++; + p = ep; + goto init; /* return match(ms, s + 1, ep); */ + } + } + break; + } + } + } + ms->matchdepth++; + return s; +} + +static const char* lmemfind(const char* s1, size_t l1, const char* s2, size_t l2) +{ + if (l2 == 0) + return s1; /* empty strings are everywhere */ + else if (l2 > l1) + return NULL; /* avoids a negative `l1' */ + else + { + const char* init; /* to search for a `*s2' inside `s1' */ + l2--; /* 1st char will be checked by `memchr' */ + l1 = l1 - l2; /* `s2' cannot be found after that */ + while (l1 > 0 && (init = (const char*)memchr(s1, *s2, l1)) != NULL) + { + init++; /* 1st char is already checked */ + if (memcmp(init, s2 + 1, l2) == 0) + return init - 1; + else + { /* correct `l1' and `s1' to try again */ + l1 -= init - s1; + s1 = init; + } + } + return NULL; /* not found */ + } +} + +static void push_onecapture(MatchState* ms, int i, const char* s, const char* e) +{ + if (i >= ms->level) + { + if (i == 0) /* ms->level == 0, too */ + lua_pushlstring(ms->L, s, e - s); /* add whole match */ + else + luaL_error(ms->L, "invalid capture index"); + } + else + { + ptrdiff_t l = ms->capture[i].len; + if (l == CAP_UNFINISHED) + luaL_error(ms->L, "unfinished capture"); + if (l == CAP_POSITION) + lua_pushinteger(ms->L, (int)(ms->capture[i].init - ms->src_init) + 1); + else + lua_pushlstring(ms->L, ms->capture[i].init, l); + } +} + +static int push_captures(MatchState* ms, const char* s, const char* e) +{ + int i; + int nlevels = (ms->level == 0 && s) ? 1 : ms->level; + luaL_checkstack(ms->L, nlevels, "too many captures"); + for (i = 0; i < nlevels; i++) + push_onecapture(ms, i, s, e); + return nlevels; /* number of strings pushed */ +} + +/* check whether pattern has no special characters */ +static int nospecials(const char* p, size_t l) +{ + size_t upto = 0; + do + { + if (strpbrk(p + upto, SPECIALS)) + return 0; /* pattern has a special character */ + upto += strlen(p + upto) + 1; /* may have more after \0 */ + } while (upto <= l); + return 1; /* no special chars found */ +} + +static void prepstate(MatchState* ms, lua_State* L, const char* s, size_t ls, const char* p, size_t lp) +{ + ms->L = L; + ms->matchdepth = LUAI_MAXCCALLS; + ms->src_init = s; + ms->src_end = s + ls; + ms->p_end = p + lp; +} + +static void reprepstate(MatchState* ms) +{ + ms->level = 0; + LUAU_ASSERT(ms->matchdepth == LUAI_MAXCCALLS); +} + +static int str_find_aux(lua_State* L, int find) +{ + size_t ls, lp; + const char* s = luaL_checklstring(L, 1, &ls); + const char* p = luaL_checklstring(L, 2, &lp); + int init = posrelat(luaL_optinteger(L, 3, 1), ls); + if (init < 1) + init = 1; + else if (init > (int)ls + 1) + { /* start after string's end? */ + lua_pushnil(L); /* cannot find anything */ + return 1; + } + /* explicit request or no special characters? */ + if (find && (lua_toboolean(L, 4) || nospecials(p, lp))) + { + /* do a plain search */ + const char* s2 = lmemfind(s + init - 1, ls - init + 1, p, lp); + if (s2) + { + lua_pushinteger(L, (int)(s2 - s + 1)); + lua_pushinteger(L, (int)(s2 - s + lp)); + return 2; + } + } + else + { + MatchState ms; + const char* s1 = s + init - 1; + int anchor = (*p == '^'); + if (anchor) + { + p++; + lp--; /* skip anchor character */ + } + prepstate(&ms, L, s, ls, p, lp); + do + { + const char* res; + reprepstate(&ms); + if ((res = match(&ms, s1, p)) != NULL) + { + if (find) + { + lua_pushinteger(L, (int)(s1 - s + 1)); /* start */ + lua_pushinteger(L, (int)(res - s)); /* end */ + return push_captures(&ms, NULL, 0) + 2; + } + else + return push_captures(&ms, s1, res); + } + } while (s1++ < ms.src_end && !anchor); + } + lua_pushnil(L); /* not found */ + return 1; +} + +static int str_find(lua_State* L) +{ + return str_find_aux(L, 1); +} + +static int str_match(lua_State* L) +{ + return str_find_aux(L, 0); +} + +static int gmatch_aux(lua_State* L) +{ + MatchState ms; + size_t ls, lp; + const char* s = lua_tolstring(L, lua_upvalueindex(1), &ls); + const char* p = lua_tolstring(L, lua_upvalueindex(2), &lp); + const char* src; + prepstate(&ms, L, s, ls, p, lp); + for (src = s + (size_t)lua_tointeger(L, lua_upvalueindex(3)); src <= ms.src_end; src++) + { + const char* e; + reprepstate(&ms); + if ((e = match(&ms, src, p)) != NULL) + { + int newstart = (int)(e - s); + if (e == src) + newstart++; /* empty match? go at least one position */ + lua_pushinteger(L, newstart); + lua_replace(L, lua_upvalueindex(3)); + return push_captures(&ms, src, e); + } + } + return 0; /* not found */ +} + +static int gmatch(lua_State* L) +{ + luaL_checkstring(L, 1); + luaL_checkstring(L, 2); + lua_settop(L, 2); + lua_pushinteger(L, 0); + lua_pushcfunction(L, gmatch_aux, NULL, 3); + return 1; +} + +static void add_s(MatchState* ms, luaL_Buffer* b, const char* s, const char* e) +{ + size_t l, i; + const char* news = lua_tolstring(ms->L, 3, &l); + + luaL_reservebuffer(b, l, -1); + + for (i = 0; i < l; i++) + { + if (news[i] != L_ESC) + luaL_addchar(b, news[i]); + else + { + i++; /* skip ESC */ + if (!isdigit(uchar(news[i]))) + { + if (news[i] != L_ESC) + luaL_error(ms->L, "invalid use of '%c' in replacement string", L_ESC); + luaL_addchar(b, news[i]); + } + else if (news[i] == '0') + luaL_addlstring(b, s, e - s); + else + { + push_onecapture(ms, news[i] - '1', s, e); + luaL_addvalue(b); /* add capture to accumulated result */ + } + } + } +} + +static void add_value(MatchState* ms, luaL_Buffer* b, const char* s, const char* e, int tr) +{ + lua_State* L = ms->L; + switch (tr) + { + case LUA_TFUNCTION: + { + int n; + lua_pushvalue(L, 3); + n = push_captures(ms, s, e); + lua_call(L, n, 1); + break; + } + case LUA_TTABLE: + { + push_onecapture(ms, 0, s, e); + lua_gettable(L, 3); + break; + } + default: + { /* LUA_TNUMBER or LUA_TSTRING */ + add_s(ms, b, s, e); + return; + } + } + if (!lua_toboolean(L, -1)) + { /* nil or false? */ + lua_pop(L, 1); + lua_pushlstring(L, s, e - s); /* keep original text */ + } + else if (!lua_isstring(L, -1)) + luaL_error(L, "invalid replacement value (a %s)", luaL_typename(L, -1)); + luaL_addvalue(b); /* add result to accumulator */ +} + +static int str_gsub(lua_State* L) +{ + size_t srcl, lp; + const char* src = luaL_checklstring(L, 1, &srcl); + const char* p = luaL_checklstring(L, 2, &lp); + int tr = lua_type(L, 3); + int max_s = luaL_optinteger(L, 4, (int)srcl + 1); + int anchor = (*p == '^'); + int n = 0; + MatchState ms; + luaL_Buffer b; + luaL_argexpected(L, tr == LUA_TNUMBER || tr == LUA_TSTRING || tr == LUA_TFUNCTION || tr == LUA_TTABLE, 3, "string/function/table"); + luaL_buffinit(L, &b); + if (anchor) + { + p++; + lp--; /* skip anchor character */ + } + prepstate(&ms, L, src, srcl, p, lp); + while (n < max_s) + { + const char* e; + reprepstate(&ms); + e = match(&ms, src, p); + if (e) + { + n++; + add_value(&ms, &b, src, e, tr); + } + if (e && e > src) /* non empty match? */ + src = e; /* skip it */ + else if (src < ms.src_end) + luaL_addchar(&b, *src++); + else + break; + if (anchor) + break; + } + luaL_addlstring(&b, src, ms.src_end - src); + luaL_pushresult(&b); + lua_pushinteger(L, n); /* number of substitutions */ + return 2; +} + +/* }====================================================== */ + +/* valid flags in a format specification */ +#define FLAGS "-+ #0" +/* maximum size of each formatted item (> len(format('%99.99f', -1e308))) */ +#define MAX_ITEM 512 +/* maximum size of each format specification (such as '%-099.99d') */ +#define MAX_FORMAT 32 + +static void addquoted(lua_State* L, luaL_Buffer* b, int arg) +{ + size_t l; + const char* s = luaL_checklstring(L, arg, &l); + + luaL_reservebuffer(b, l + 2, -1); + + luaL_addchar(b, '"'); + while (l--) + { + switch (*s) + { + case '"': + case '\\': + case '\n': + { + luaL_addchar(b, '\\'); + luaL_addchar(b, *s); + break; + } + case '\r': + { + luaL_addlstring(b, "\\r", 2); + break; + } + case '\0': + { + luaL_addlstring(b, "\\000", 4); + break; + } + default: + { + luaL_addchar(b, *s); + break; + } + } + s++; + } + luaL_addchar(b, '"'); +} + +static const char* scanformat(lua_State* L, const char* strfrmt, char* form, size_t* size) +{ + const char* p = strfrmt; + while (*p != '\0' && strchr(FLAGS, *p) != NULL) + p++; /* skip flags */ + if ((size_t)(p - strfrmt) >= sizeof(FLAGS)) + luaL_error(L, "invalid format (repeated flags)"); + if (isdigit(uchar(*p))) + p++; /* skip width */ + if (isdigit(uchar(*p))) + p++; /* (2 digits at most) */ + if (*p == '.') + { + p++; + if (isdigit(uchar(*p))) + p++; /* skip precision */ + if (isdigit(uchar(*p))) + p++; /* (2 digits at most) */ + } + if (isdigit(uchar(*p))) + luaL_error(L, "invalid format (width or precision too long)"); + *(form++) = '%'; + *size = p - strfrmt + 1; + strncpy(form, strfrmt, *size); + form += *size; + *form = '\0'; + return p; +} + +static void addInt64Format(char form[MAX_FORMAT], char formatIndicator, size_t formatItemSize) +{ + LUAU_ASSERT((formatItemSize + 3) <= MAX_FORMAT); + LUAU_ASSERT(form[0] == '%'); + LUAU_ASSERT(form[formatItemSize] != 0); + LUAU_ASSERT(form[formatItemSize + 1] == 0); + form[formatItemSize + 0] = 'l'; + form[formatItemSize + 1] = 'l'; + form[formatItemSize + 2] = formatIndicator; + form[formatItemSize + 3] = 0; +} + +static int str_format(lua_State* L) +{ + int top = lua_gettop(L); + int arg = 1; + size_t sfl; + const char* strfrmt = luaL_checklstring(L, arg, &sfl); + const char* strfrmt_end = strfrmt + sfl; + luaL_Buffer b; + luaL_buffinit(L, &b); + while (strfrmt < strfrmt_end) + { + if (*strfrmt != L_ESC) + luaL_addchar(&b, *strfrmt++); + else if (*++strfrmt == L_ESC) + luaL_addchar(&b, *strfrmt++); /* %% */ + else + { /* format item */ + char form[MAX_FORMAT]; /* to store the format (`%...') */ + char buff[MAX_ITEM]; /* to store the formatted item */ + if (++arg > top) + luaL_error(L, "missing argument #%d", arg); + size_t formatItemSize = 0; + strfrmt = scanformat(L, strfrmt, form, &formatItemSize); + char formatIndicator = *strfrmt++; + switch (formatIndicator) + { + case 'c': + { + sprintf(buff, form, (int)luaL_checknumber(L, arg)); + break; + } + case 'd': + case 'i': + { + addInt64Format(form, formatIndicator, formatItemSize); + sprintf(buff, form, (long long)luaL_checknumber(L, arg)); + break; + } + case 'o': + case 'u': + case 'x': + case 'X': + { + double argValue = luaL_checknumber(L, arg); + addInt64Format(form, formatIndicator, formatItemSize); + unsigned long long v = (argValue < 0) ? (unsigned long long)(long long)argValue : (unsigned long long)argValue; + sprintf(buff, form, v); + break; + } + case 'e': + case 'E': + case 'f': + case 'g': + case 'G': + { + sprintf(buff, form, (double)luaL_checknumber(L, arg)); + break; + } + case 'q': + { + addquoted(L, &b, arg); + continue; /* skip the 'addsize' at the end */ + } + case 's': + { + size_t l; + const char* s = luaL_checklstring(L, arg, &l); + if (!strchr(form, '.') && l >= 100) + { + /* no precision and string is too long to be formatted; + keep original string */ + lua_pushvalue(L, arg); + luaL_addvalue(&b); + continue; /* skip the `addsize' at the end */ + } + else + { + sprintf(buff, form, s); + break; + } + } + default: + { /* also treat cases `pnLlh' */ + luaL_error(L, "invalid option '%%%c' to 'format'", *(strfrmt - 1)); + } + } + luaL_addlstring(&b, buff, strlen(buff)); + } + } + luaL_pushresult(&b); + return 1; +} + +static int str_split(lua_State* L) +{ + size_t haystackLen; + const char* haystack = luaL_checklstring(L, 1, &haystackLen); + size_t needleLen; + const char* needle = luaL_optlstring(L, 2, ",", &needleLen); + + const char* begin = haystack; + const char* end = haystack + haystackLen; + const char* spanStart = begin; + int numMatches = 0; + + lua_createtable(L, 0, 0); + + if (needleLen == 0) + begin++; + + // Don't iterate the last needleLen - 1 bytes of the string - they are + // impossible to be splits and would let us memcmp past the end of the + // buffer. + for (const char* iter = begin; iter <= end - needleLen; iter++) + { + // Use of memcmp here instead of strncmp is so that we allow embedded + // nulls to be used in either of the haystack or the needle strings. + // Most Lua string APIs allow embedded nulls, and this should be no + // exception. + if (memcmp(iter, needle, needleLen) == 0) + { + lua_pushinteger(L, ++numMatches); + lua_pushlstring(L, spanStart, iter - spanStart); + lua_settable(L, -3); + + spanStart = iter + needleLen; + if (needleLen > 0) + iter += needleLen - 1; + } + } + + if (needleLen > 0) + { + lua_pushinteger(L, ++numMatches); + lua_pushlstring(L, spanStart, end - spanStart); + lua_settable(L, -3); + } + + return 1; +} + +/* +** {====================================================== +** PACK/UNPACK +** ======================================================= +*/ + +/* value used for padding */ +#if !defined(LUAL_PACKPADBYTE) +#define LUAL_PACKPADBYTE 0x00 +#endif + +/* maximum size for the binary representation of an integer */ +#define MAXINTSIZE 16 + +/* number of bits in a character */ +#define NB CHAR_BIT + +/* mask for one character (NB 1's) */ +#define MC ((1 << NB) - 1) + +/* internal size of integers used for pack/unpack */ +#define SZINT (int)sizeof(long long) + +/* dummy union to get native endianness */ +static const union +{ + int dummy; + char little; /* true iff machine is little endian */ +} nativeendian = {1}; + +/* assume we need to align for double & pointers */ +#define MAXALIGN 8 + +/* +** Union for serializing floats +*/ +typedef union Ftypes +{ + float f; + double d; + double n; + char buff[5 * sizeof(double)]; /* enough for any float type */ +} Ftypes; + +/* +** information to pack/unpack stuff +*/ +typedef struct Header +{ + lua_State* L; + int islittle; + int maxalign; +} Header; + +/* +** options for pack/unpack +*/ +typedef enum KOption +{ + Kint, /* signed integers */ + Kuint, /* unsigned integers */ + Kfloat, /* floating-point numbers */ + Kchar, /* fixed-length strings */ + Kstring, /* strings with prefixed length */ + Kzstr, /* zero-terminated strings */ + Kpadding, /* padding */ + Kpaddalign, /* padding for alignment */ + Knop /* no-op (configuration or spaces) */ +} KOption; + +/* +** Read an integer numeral from string 'fmt' or return 'df' if +** there is no numeral +*/ +static int digit(int c) +{ + return '0' <= c && c <= '9'; +} + +static int getnum(Header* h, const char** fmt, int df) +{ + if (!digit(**fmt)) /* no number? */ + return df; /* return default value */ + else + { + int a = 0; + do + { + a = a * 10 + (*((*fmt)++) - '0'); + } while (digit(**fmt) && a <= (INT_MAX - 9) / 10); + if (a > MAXSSIZE || digit(**fmt)) + luaL_error(h->L, "size specifier is too large"); + return a; + } +} + +/* +** Read an integer numeral and raises an error if it is larger +** than the maximum size for integers. +*/ +static int getnumlimit(Header* h, const char** fmt, int df) +{ + int sz = getnum(h, fmt, df); + if (sz > MAXINTSIZE || sz <= 0) + luaL_error(h->L, "integral size (%d) out of limits [1,%d]", sz, MAXINTSIZE); + return sz; +} + +/* +** Initialize Header +*/ +static void initheader(lua_State* L, Header* h) +{ + h->L = L; + h->islittle = nativeendian.little; + h->maxalign = 1; +} + +/* +** Read and classify next option. 'size' is filled with option's size. +*/ +static KOption getoption(Header* h, const char** fmt, int* size) +{ + int opt = *((*fmt)++); + *size = 0; /* default */ + switch (opt) + { + case 'b': + *size = 1; + return Kint; + case 'B': + *size = 1; + return Kuint; + case 'h': + *size = 2; + return Kint; + case 'H': + *size = 2; + return Kuint; + case 'l': + *size = 8; + return Kint; + case 'L': + *size = 8; + return Kuint; + case 'j': + *size = 4; + return Kint; + case 'J': + *size = 4; + return Kuint; + case 'T': + *size = 4; + return Kuint; + case 'f': + *size = 4; + return Kfloat; + case 'd': + *size = 8; + return Kfloat; + case 'n': + *size = 8; + return Kfloat; + case 'i': + *size = getnumlimit(h, fmt, 4); + return Kint; + case 'I': + *size = getnumlimit(h, fmt, 4); + return Kuint; + case 's': + *size = getnumlimit(h, fmt, 4); + return Kstring; + case 'c': + *size = getnum(h, fmt, -1); + if (*size == -1) + luaL_error(h->L, "missing size for format option 'c'"); + return Kchar; + case 'z': + return Kzstr; + case 'x': + *size = 1; + return Kpadding; + case 'X': + return Kpaddalign; + case ' ': + break; + case '<': + h->islittle = 1; + break; + case '>': + h->islittle = 0; + break; + case '=': + h->islittle = nativeendian.little; + break; + case '!': + h->maxalign = getnumlimit(h, fmt, MAXALIGN); + break; + default: + luaL_error(h->L, "invalid format option '%c'", opt); + } + return Knop; +} + +/* +** Read, classify, and fill other details about the next option. +** 'psize' is filled with option's size, 'notoalign' with its +** alignment requirements. +** Local variable 'size' gets the size to be aligned. (Kpadal option +** always gets its full alignment, other options are limited by +** the maximum alignment ('maxalign'). Kchar option needs no alignment +** despite its size. +*/ +static KOption getdetails(Header* h, size_t totalsize, const char** fmt, int* psize, int* ntoalign) +{ + KOption opt = getoption(h, fmt, psize); + int align = *psize; /* usually, alignment follows size */ + if (opt == Kpaddalign) + { /* 'X' gets alignment from following option */ + if (**fmt == '\0' || getoption(h, fmt, &align) == Kchar || align == 0) + luaL_argerror(h->L, 1, "invalid next option for option 'X'"); + } + if (align <= 1 || opt == Kchar) /* need no alignment? */ + *ntoalign = 0; + else + { + if (align > h->maxalign) /* enforce maximum alignment */ + align = h->maxalign; + if ((align & (align - 1)) != 0) /* is 'align' not a power of 2? */ + luaL_argerror(h->L, 1, "format asks for alignment not power of 2"); + *ntoalign = (align - (int)(totalsize & (align - 1))) & (align - 1); + } + return opt; +} + +/* +** Pack integer 'n' with 'size' bytes and 'islittle' endianness. +** The final 'if' handles the case when 'size' is larger than +** the size of a Lua integer, correcting the extra sign-extension +** bytes if necessary (by default they would be zeros). +*/ +static void packint(luaL_Buffer* b, unsigned long long n, int islittle, int size, int neg) +{ + LUAU_ASSERT(size <= MAXINTSIZE); + char buff[MAXINTSIZE]; + int i; + buff[islittle ? 0 : size - 1] = (char)(n & MC); /* first byte */ + for (i = 1; i < size; i++) + { + n >>= NB; + buff[islittle ? i : size - 1 - i] = (char)(n & MC); + } + if (neg && size > SZINT) + { /* negative number need sign extension? */ + for (i = SZINT; i < size; i++) /* correct extra bytes */ + buff[islittle ? i : size - 1 - i] = (char)MC; + } + luaL_addlstring(b, buff, size); /* add result to buffer */ +} + +/* +** Copy 'size' bytes from 'src' to 'dest', correcting endianness if +** given 'islittle' is different from native endianness. +*/ +static void copywithendian(volatile char* dest, volatile const char* src, int size, int islittle) +{ + if (islittle == nativeendian.little) + { + while (size-- != 0) + *(dest++) = *(src++); + } + else + { + dest += size - 1; + while (size-- != 0) + *(dest--) = *(src++); + } +} + +static int str_pack(lua_State* L) +{ + luaL_Buffer b; + Header h; + const char* fmt = luaL_checkstring(L, 1); /* format string */ + int arg = 1; /* current argument to pack */ + size_t totalsize = 0; /* accumulate total size of result */ + initheader(L, &h); + lua_pushnil(L); /* mark to separate arguments from string buffer */ + luaL_buffinit(L, &b); + while (*fmt != '\0') + { + int size, ntoalign; + KOption opt = getdetails(&h, totalsize, &fmt, &size, &ntoalign); + totalsize += ntoalign + size; + while (ntoalign-- > 0) + luaL_addchar(&b, LUAL_PACKPADBYTE); /* fill alignment */ + arg++; + switch (opt) + { + case Kint: + { /* signed integers */ + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) + { /* need overflow check? */ + long long lim = (long long)1 << ((size * NB) - 1); + luaL_argcheck(L, -lim <= n && n < lim, arg, "integer overflow"); + } + packint(&b, n, h.islittle, size, (n < 0)); + break; + } + case Kuint: + { /* unsigned integers */ + unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, n, h.islittle, size, 0); + break; + } + case Kfloat: + { /* floating-point options */ + volatile Ftypes u; + char buff[MAXINTSIZE]; + double n = luaL_checknumber(L, arg); /* get argument */ + if (size == sizeof(u.f)) + u.f = (float)n; /* copy it into 'u' */ + else if (size == sizeof(u.d)) + u.d = (double)n; + else + u.n = n; + /* move 'u' to final result, correcting endianness if needed */ + copywithendian(buff, u.buff, size, h.islittle); + luaL_addlstring(&b, buff, size); + break; + } + case Kchar: + { /* fixed-size string */ + size_t len; + const char* s = luaL_checklstring(L, arg, &len); + luaL_argcheck(L, len <= (size_t)size, arg, "string longer than given size"); + luaL_addlstring(&b, s, len); /* add string */ + while (len++ < (size_t)size) /* pad extra space */ + luaL_addchar(&b, LUAL_PACKPADBYTE); + break; + } + case Kstring: + { /* strings with length count */ + size_t len; + const char* s = luaL_checklstring(L, arg, &len); + luaL_argcheck(L, size >= (int)sizeof(size_t) || len < ((size_t)1 << (size * NB)), arg, "string length does not fit in given size"); + packint(&b, len, h.islittle, size, 0); /* pack length */ + luaL_addlstring(&b, s, len); + totalsize += len; + break; + } + case Kzstr: + { /* zero-terminated string */ + size_t len; + const char* s = luaL_checklstring(L, arg, &len); + luaL_argcheck(L, strlen(s) == len, arg, "string contains zeros"); + luaL_addlstring(&b, s, len); + luaL_addchar(&b, '\0'); /* add zero at the end */ + totalsize += len + 1; + break; + } + case Kpadding: + luaL_addchar(&b, LUAL_PACKPADBYTE); /* FALLTHROUGH */ + case Kpaddalign: + case Knop: + arg--; /* undo increment */ + break; + } + } + luaL_pushresult(&b); + return 1; +} + +static int str_packsize(lua_State* L) +{ + Header h; + const char* fmt = luaL_checkstring(L, 1); /* format string */ + int totalsize = 0; /* accumulate total size of result */ + initheader(L, &h); + while (*fmt != '\0') + { + int size, ntoalign; + KOption opt = getdetails(&h, totalsize, &fmt, &size, &ntoalign); + luaL_argcheck(L, opt != Kstring && opt != Kzstr, 1, "variable-length format"); + size += ntoalign; /* total space used by option */ + luaL_argcheck(L, totalsize <= MAXSSIZE - size, 1, "format result too large"); + totalsize += size; + } + lua_pushinteger(L, totalsize); + return 1; +} + +/* +** Unpack an integer with 'size' bytes and 'islittle' endianness. +** If size is smaller than the size of a Lua integer and integer +** is signed, must do sign extension (propagating the sign to the +** higher bits); if size is larger than the size of a Lua integer, +** it must check the unread bytes to see whether they do not cause an +** overflow. +*/ +static long long unpackint(lua_State* L, const char* str, int islittle, int size, int issigned) +{ + unsigned long long res = 0; + int i; + int limit = (size <= SZINT) ? size : SZINT; + for (i = limit - 1; i >= 0; i--) + { + res <<= NB; + res |= (unsigned char)str[islittle ? i : size - 1 - i]; + } + if (size < SZINT) + { /* real size smaller than int? */ + if (issigned) + { /* needs sign extension? */ + unsigned long long mask = (unsigned long long)1 << (size * NB - 1); + res = ((res ^ mask) - mask); /* do sign extension */ + } + } + else if (size > SZINT) + { /* must check unread bytes */ + int mask = (!issigned || (long long)res >= 0) ? 0 : MC; + for (i = limit; i < size; i++) + { + if ((unsigned char)str[islittle ? i : size - 1 - i] != mask) + luaL_error(L, "%d-byte integer does not fit into Lua Integer", size); + } + } + return (long long)res; +} + +static int str_unpack(lua_State* L) +{ + Header h; + const char* fmt = luaL_checkstring(L, 1); + size_t ld; + const char* data = luaL_checklstring(L, 2, &ld); + int pos = posrelat(luaL_optinteger(L, 3, 1), ld) - 1; + if (pos < 0) + pos = 0; + int n = 0; /* number of results */ + luaL_argcheck(L, size_t(pos) <= ld, 3, "initial position out of string"); + initheader(L, &h); + while (*fmt != '\0') + { + int size, ntoalign; + KOption opt = getdetails(&h, pos, &fmt, &size, &ntoalign); + luaL_argcheck(L, (size_t)ntoalign + size <= ld - pos, 2, "data string too short"); + pos += ntoalign; /* skip alignment */ + /* stack space for item + next position */ + luaL_checkstack(L, 2, "too many results"); + n++; + switch (opt) + { + case Kint: + { + long long res = unpackint(L, data + pos, h.islittle, size, true); + lua_pushnumber(L, (double)res); + break; + } + case Kuint: + { + unsigned long long res = unpackint(L, data + pos, h.islittle, size, false); + lua_pushnumber(L, (double)res); + break; + } + case Kfloat: + { + volatile Ftypes u; + double num; + copywithendian(u.buff, data + pos, size, h.islittle); + if (size == sizeof(u.f)) + num = (double)u.f; + else if (size == sizeof(u.d)) + num = (double)u.d; + else + num = u.n; + lua_pushnumber(L, num); + break; + } + case Kchar: + { + lua_pushlstring(L, data + pos, size); + break; + } + case Kstring: + { + size_t len = (size_t)unpackint(L, data + pos, h.islittle, size, 0); + luaL_argcheck(L, len <= ld - pos - size, 2, "data string too short"); + lua_pushlstring(L, data + pos + size, len); + pos += (int)len; /* skip string */ + break; + } + case Kzstr: + { + size_t len = strlen(data + pos); + luaL_argcheck(L, pos + len < ld, 2, "unfinished string for format 'z'"); + lua_pushlstring(L, data + pos, len); + pos += (int)len + 1; /* skip string plus final '\0' */ + break; + } + case Kpaddalign: + case Kpadding: + case Knop: + n--; /* undo increment */ + break; + } + pos += size; + } + lua_pushinteger(L, pos + 1); /* next position */ + return n + 1; +} + +/* }====================================================== */ + +static const luaL_Reg strlib[] = { + {"byte", str_byte}, + {"char", str_char}, + {"find", str_find}, + {"format", str_format}, + {"gmatch", gmatch}, + {"gsub", str_gsub}, + {"len", str_len}, + {"lower", str_lower}, + {"match", str_match}, + {"rep", str_rep}, + {"reverse", str_reverse}, + {"sub", str_sub}, + {"upper", str_upper}, + {"split", str_split}, + {"pack", str_pack}, + {"packsize", str_packsize}, + {"unpack", str_unpack}, + {NULL, NULL}, +}; + +static void createmetatable(lua_State* L) +{ + lua_createtable(L, 0, 1); /* create metatable for strings */ + lua_pushliteral(L, ""); /* dummy string */ + lua_pushvalue(L, -2); + lua_setmetatable(L, -2); /* set string metatable */ + lua_pop(L, 1); /* pop dummy string */ + lua_pushvalue(L, -2); /* string library... */ + lua_setfield(L, -2, "__index"); /* ...is the __index metamethod */ + lua_pop(L, 1); /* pop metatable */ +} + +/* +** Open string library +*/ +LUALIB_API int luaopen_string(lua_State* L) +{ + luaL_register(L, LUA_STRLIBNAME, strlib); + createmetatable(L); + + return 1; +} diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp new file mode 100644 index 0000000..883442a --- /dev/null +++ b/VM/src/ltable.cpp @@ -0,0 +1,799 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details + +/* +** Implementation of tables (aka arrays, objects, or hash tables). +** Tables keep its elements in two parts: an array part and a hash part. +** Non-negative integer keys are all candidates to be kept in the array +** part. The actual size of the array is the largest `n' such that at +** least half the slots between 0 and n are in use. +** Hash uses a mix of chained scatter table with Brent's variation. +** A main invariant of these tables is that, if an element is not +** in its main position (i.e. the `original' position that its hash gives +** to it), then the colliding element is in its own main position. +** Hence even when the load factor reaches 100%, performance remains good. +*/ + +#include "ltable.h" + +#include "lstate.h" +#include "ldebug.h" +#include "lgc.h" +#include "lmem.h" +#include "lnumutils.h" + +#include + +LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) + +// max size of both array and hash part is 2^MAXBITS +#define MAXBITS 26 +#define MAXSIZE (1 << MAXBITS) + +// TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case +static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); +static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); +static_assert(TKey{{NULL}, 0, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); + +// reset cache of absent metamethods, cache is updated in luaT_gettm +#define invalidateTMcache(t) t->flags = 0 + +// empty hash data points to dummynode so that we can always dereference it +const LuaNode luaH_dummynode = { + {{NULL}, 0, LUA_TNIL}, /* value */ + {{NULL}, 0, LUA_TNIL, 0} /* key */ +}; + +#define dummynode (&luaH_dummynode) + +// hash is always reduced mod 2^k +#define hashpow2(t, n) (gnode(t, lmod((n), sizenode(t)))) + +#define hashstr(t, str) hashpow2(t, (str)->hash) +#define hashboolean(t, p) hashpow2(t, p) + +static LuaNode* hashpointer(const Table* t, const void* p) +{ + // we discard the high 32-bit portion of the pointer on 64-bit platforms as it doesn't carry much entropy anyway + unsigned int h = unsigned(uintptr_t(p)); + + // MurmurHash3 32-bit finalizer + h ^= h >> 16; + h *= 0x85ebca6bu; + h ^= h >> 13; + h *= 0xc2b2ae35u; + h ^= h >> 16; + + return hashpow2(t, h); +} + +static LuaNode* hashnum(const Table* t, double n) +{ + static_assert(sizeof(double) == sizeof(unsigned int) * 2, "expected a 8-byte double"); + unsigned int i[2]; + memcpy(i, &n, sizeof(i)); + + // mask out sign bit to make sure -0 and 0 hash to the same value + uint32_t h1 = i[0]; + uint32_t h2 = i[1] & 0x7fffffff; + + // finalizer from MurmurHash64B + const uint32_t m = 0x5bd1e995; + + h1 ^= h2 >> 18; + h1 *= m; + h2 ^= h1 >> 22; + h2 *= m; + h1 ^= h2 >> 17; + h1 *= m; + h2 ^= h1 >> 19; + h2 *= m; + + // ... truncated to 32-bit output (normally hash is equal to (uint64_t(h1) << 32) | h2, but we only really need the lower 32-bit half) + return hashpow2(t, h2); +} + +static LuaNode* hashvec(const Table* t, const float* v) +{ + unsigned int i[3]; + memcpy(i, v, sizeof(i)); + + // convert -0 to 0 to make sure they hash to the same value + i[0] = (i[0] == 0x8000000) ? 0 : i[0]; + i[1] = (i[1] == 0x8000000) ? 0 : i[1]; + i[2] = (i[2] == 0x8000000) ? 0 : i[2]; + + // scramble bits to make sure that integer coordinates have entropy in lower bits + i[0] ^= i[0] >> 17; + i[1] ^= i[1] >> 17; + i[2] ^= i[2] >> 17; + + // Optimized Spatial Hashing for Collision Detection of Deformable Objects + unsigned int h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791); + + return hashpow2(t, h); +} + +/* +** returns the `main' position of an element in a table (that is, the index +** of its hash value) +*/ +static LuaNode* mainposition(const Table* t, const TValue* key) +{ + switch (ttype(key)) + { + case LUA_TNUMBER: + return hashnum(t, nvalue(key)); + case LUA_TVECTOR: + return hashvec(t, vvalue(key)); + case LUA_TSTRING: + return hashstr(t, tsvalue(key)); + case LUA_TBOOLEAN: + return hashboolean(t, bvalue(key)); + case LUA_TLIGHTUSERDATA: + return hashpointer(t, pvalue(key)); + default: + return hashpointer(t, gcvalue(key)); + } +} + +/* +** returns the index for `key' if `key' is an appropriate key to live in +** the array part of the table, -1 otherwise. +*/ +static int arrayindex(double key) +{ + int i; + luai_num2int(i, key); + + return luai_numeq(cast_num(i), key) ? i : -1; +} + +/* +** returns the index of a `key' for table traversals. First goes all +** elements in the array part, then elements in the hash part. The +** beginning of a traversal is signalled by -1. +*/ +static int findindex(lua_State* L, Table* t, StkId key) +{ + int i; + if (ttisnil(key)) + return -1; /* first iteration */ + i = ttisnumber(key) ? arrayindex(nvalue(key)) : -1; + if (0 < i && i <= t->sizearray) /* is `key' inside array part? */ + return i - 1; /* yes; that's the index (corrected to C) */ + else + { + LuaNode* n = mainposition(t, key); + for (;;) + { /* check whether `key' is somewhere in the chain */ + /* key may be dead already, but it is ok to use it in `next' */ + if (luaO_rawequalKey(gkey(n), key) || (ttype(gkey(n)) == LUA_TDEADKEY && iscollectable(key) && gcvalue(gkey(n)) == gcvalue(key))) + { + i = cast_int(n - gnode(t, 0)); /* key index in hash table */ + /* hash elements are numbered after array ones */ + return i + t->sizearray; + } + if (gnext(n) == 0) + break; + n += gnext(n); + } + luaG_runerror(L, "invalid key to 'next'"); /* key not found */ + } +} + +int luaH_next(lua_State* L, Table* t, StkId key) +{ + int i = findindex(L, t, key); /* find original element */ + for (i++; i < t->sizearray; i++) + { /* try first array part */ + if (!ttisnil(&t->array[i])) + { /* a non-nil value? */ + setnvalue(key, cast_num(i + 1)); + setobj2s(L, key + 1, &t->array[i]); + return 1; + } + } + for (i -= t->sizearray; i < sizenode(t); i++) + { /* then hash part */ + if (!ttisnil(gval(gnode(t, i)))) + { /* a non-nil value? */ + getnodekey(L, key, gnode(t, i)); + setobj2s(L, key + 1, gval(gnode(t, i))); + return 1; + } + } + return 0; /* no more elements */ +} + +/* +** {============================================================= +** Rehash +** ============================================================== +*/ + +#define maybesetaboundary(t, boundary) \ + { \ + if (FFlag::LuauArrayBoundary && t->aboundary <= 0) \ + t->aboundary = -int(boundary); \ + } + +#define getaboundary(t) (t->aboundary < 0 ? -t->aboundary : t->sizearray) + +static int computesizes(int nums[], int* narray) +{ + int i; + int twotoi; /* 2^i */ + int a = 0; /* number of elements smaller than 2^i */ + int na = 0; /* number of elements to go to array part */ + int n = 0; /* optimal size for array part */ + for (i = 0, twotoi = 1; twotoi / 2 < *narray; i++, twotoi *= 2) + { + if (nums[i] > 0) + { + a += nums[i]; + if (a > twotoi / 2) + { /* more than half elements present? */ + n = twotoi; /* optimal size (till now) */ + na = a; /* all elements smaller than n will go to array part */ + } + } + if (a == *narray) + break; /* all elements already counted */ + } + *narray = n; + LUAU_ASSERT(*narray / 2 <= na && na <= *narray); + return na; +} + +static int countint(double key, int* nums) +{ + int k = arrayindex(key); + if (0 < k && k <= MAXSIZE) + { /* is `key' an appropriate array index? */ + nums[ceillog2(k)]++; /* count as such */ + return 1; + } + else + return 0; +} + +static int numusearray(const Table* t, int* nums) +{ + int lg; + int ttlg; /* 2^lg */ + int ause = 0; /* summation of `nums' */ + int i = 1; /* count to traverse all array keys */ + for (lg = 0, ttlg = 1; lg <= MAXBITS; lg++, ttlg *= 2) + { /* for each slice */ + int lc = 0; /* counter */ + int lim = ttlg; + if (lim > t->sizearray) + { + lim = t->sizearray; /* adjust upper limit */ + if (i > lim) + break; /* no more elements to count */ + } + /* count elements in range (2^(lg-1), 2^lg] */ + for (; i <= lim; i++) + { + if (!ttisnil(&t->array[i - 1])) + lc++; + } + nums[lg] += lc; + ause += lc; + } + return ause; +} + +static int numusehash(const Table* t, int* nums, int* pnasize) +{ + int totaluse = 0; /* total number of elements */ + int ause = 0; /* summation of `nums' */ + int i = sizenode(t); + while (i--) + { + LuaNode* n = &t->node[i]; + if (!ttisnil(gval(n))) + { + if (ttisnumber(gkey(n))) + ause += countint(nvalue(gkey(n)), nums); + totaluse++; + } + } + *pnasize += ause; + return totaluse; +} + +static void setarrayvector(lua_State* L, Table* t, int size) +{ + if (size > MAXSIZE) + luaG_runerror(L, "table overflow"); + luaM_reallocarray(L, t->array, t->sizearray, size, TValue, t->memcat); + TValue* array = t->array; + for (int i = t->sizearray; i < size; i++) + setnilvalue(&array[i]); + t->sizearray = size; +} + +static void setnodevector(lua_State* L, Table* t, int size) +{ + int lsize; + if (size == 0) + { /* no elements to hash part? */ + t->node = cast_to(LuaNode*, dummynode); /* use common `dummynode' */ + lsize = 0; + } + else + { + int i; + lsize = ceillog2(size); + if (lsize > MAXBITS) + luaG_runerror(L, "table overflow"); + size = twoto(lsize); + t->node = luaM_newarray(L, size, LuaNode, t->memcat); + for (i = 0; i < size; i++) + { + LuaNode* n = gnode(t, i); + gnext(n) = 0; + setnilvalue(gkey(n)); + setnilvalue(gval(n)); + } + } + t->lsizenode = cast_byte(lsize); + t->nodemask8 = cast_byte((1 << lsize) - 1); + t->lastfree = size; /* all positions are free */ +} + +static void resize(lua_State* L, Table* t, int nasize, int nhsize) +{ + if (nasize > MAXSIZE || nhsize > MAXSIZE) + luaG_runerror(L, "table overflow"); + int oldasize = t->sizearray; + int oldhsize = t->lsizenode; + LuaNode* nold = t->node; /* save old hash ... */ + if (nasize > oldasize) /* array part must grow? */ + setarrayvector(L, t, nasize); + /* create new hash part with appropriate size */ + setnodevector(L, t, nhsize); + if (nasize < oldasize) + { /* array part must shrink? */ + t->sizearray = nasize; + /* re-insert elements from vanishing slice */ + for (int i = nasize; i < oldasize; i++) + { + if (!ttisnil(&t->array[i])) + setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); + } + /* shrink array */ + luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); + } + /* re-insert elements from hash part */ + for (int i = twoto(oldhsize) - 1; i >= 0; i--) + { + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) + { + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, luaH_set(L, t, &ok), gval(old)); + } + } + if (nold != dummynode) + luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ +} + +void luaH_resizearray(lua_State* L, Table* t, int nasize) +{ + int nsize = (t->node == dummynode) ? 0 : sizenode(t); + resize(L, t, nasize, nsize); +} + +void luaH_resizehash(lua_State* L, Table* t, int nhsize) +{ + resize(L, t, t->sizearray, nhsize); +} + +static void rehash(lua_State* L, Table* t, const TValue* ek) +{ + int nums[MAXBITS + 1]; /* nums[i] = number of keys between 2^(i-1) and 2^i */ + for (int i = 0; i <= MAXBITS; i++) + nums[i] = 0; /* reset counts */ + int nasize = numusearray(t, nums); /* count keys in array part */ + int totaluse = nasize; /* all those keys are integer keys */ + totaluse += numusehash(t, nums, &nasize); /* count keys in hash part */ + /* count extra key */ + if (ttisnumber(ek)) + nasize += countint(nvalue(ek), nums); + totaluse++; + /* compute new size for array part */ + int na = computesizes(nums, &nasize); + /* resize the table to new computed sizes */ + resize(L, t, nasize, totaluse - na); +} + +/* +** }============================================================= +*/ + +Table* luaH_new(lua_State* L, int narray, int nhash) +{ + Table* t = luaM_new(L, Table, sizeof(Table), L->activememcat); + luaC_link(L, t, LUA_TTABLE); + t->metatable = NULL; + t->flags = cast_byte(~0); + t->array = NULL; + t->sizearray = 0; + t->lastfree = 0; + t->lsizenode = 0; + t->readonly = 0; + t->safeenv = 0; + t->nodemask8 = 0; + t->node = cast_to(LuaNode*, dummynode); + if (narray > 0) + setarrayvector(L, t, narray); + if (nhash > 0) + setnodevector(L, t, nhash); + return t; +} + +void luaH_free(lua_State* L, Table* t) +{ + if (t->node != dummynode) + luaM_freearray(L, t->node, sizenode(t), LuaNode, t->memcat); + luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); + luaM_free(L, t, sizeof(Table), t->memcat); +} + +static LuaNode* getfreepos(Table* t) +{ + while (t->lastfree > 0) + { + t->lastfree--; + + LuaNode* n = gnode(t, t->lastfree); + if (ttisnil(gkey(n))) + return n; + } + return NULL; /* could not find a free place */ +} + +/* +** inserts a new key into a hash table; first, check whether key's main +** position is free. If not, check whether colliding node is in its main +** position or not: if it is not, move colliding node to an empty place and +** put new key in its main position; otherwise (colliding node is in its main +** position), new key goes to an empty position. +*/ +static TValue* newkey(lua_State* L, Table* t, const TValue* key) +{ + LuaNode* mp = mainposition(t, key); + if (!ttisnil(gval(mp)) || mp == dummynode) + { + LuaNode* othern; + LuaNode* n = getfreepos(t); /* get a free place */ + if (n == NULL) + { /* cannot find a free place? */ + rehash(L, t, key); /* grow table */ + return luaH_set(L, t, key); /* re-insert key into grown table */ + } + LUAU_ASSERT(n != dummynode); + TValue mk; + getnodekey(L, &mk, mp); + othern = mainposition(t, &mk); + if (othern != mp) + { /* is colliding node out of its main position? */ + /* yes; move colliding node into free position */ + while (othern + gnext(othern) != mp) + othern += gnext(othern); /* find previous */ + gnext(othern) = cast_int(n - othern); /* redo the chain with `n' in place of `mp' */ + *n = *mp; /* copy colliding node into free pos. (mp->next also goes) */ + if (gnext(mp) != 0) + { + gnext(n) += cast_int(mp - n); /* correct 'next' */ + gnext(mp) = 0; /* now 'mp' is free */ + } + setnilvalue(gval(mp)); + } + else + { /* colliding node is in its own main position */ + /* new node will go into free position */ + if (gnext(mp) != 0) + gnext(n) = cast_int((mp + gnext(mp)) - n); /* chain new position */ + else + LUAU_ASSERT(gnext(n) == 0); + gnext(mp) = cast_int(n - mp); + mp = n; + } + } + setnodekey(L, mp, key); + luaC_barriert(L, t, key); + LUAU_ASSERT(ttisnil(gval(mp))); + return gval(mp); +} + +/* +** search function for integers +*/ +const TValue* luaH_getnum(Table* t, int key) +{ + /* (1 <= key && key <= t->sizearray) */ + if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) + return &t->array[key - 1]; + else if (t->node != dummynode) + { + double nk = cast_num(key); + LuaNode* n = hashnum(t, nk); + for (;;) + { /* check whether `key' is somewhere in the chain */ + if (ttisnumber(gkey(n)) && luai_numeq(nvalue(gkey(n)), nk)) + return gval(n); /* that's it */ + if (gnext(n) == 0) + break; + n += gnext(n); + } + return luaO_nilobject; + } + else + return luaO_nilobject; +} + +/* +** search function for strings +*/ +const TValue* luaH_getstr(Table* t, TString* key) +{ + LuaNode* n = hashstr(t, key); + for (;;) + { /* check whether `key' is somewhere in the chain */ + if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == key) + return gval(n); /* that's it */ + if (gnext(n) == 0) + break; + n += gnext(n); + } + return luaO_nilobject; +} + +/* +** main search function +*/ +const TValue* luaH_get(Table* t, const TValue* key) +{ + switch (ttype(key)) + { + case LUA_TNIL: + return luaO_nilobject; + case LUA_TSTRING: + return luaH_getstr(t, tsvalue(key)); + case LUA_TNUMBER: + { + int k; + double n = nvalue(key); + luai_num2int(k, n); + if (luai_numeq(cast_num(k), nvalue(key))) /* index is int? */ + return luaH_getnum(t, k); /* use specialized version */ + /* else go through */ + } + default: + { + LuaNode* n = mainposition(t, key); + for (;;) + { /* check whether `key' is somewhere in the chain */ + if (luaO_rawequalKey(gkey(n), key)) + return gval(n); /* that's it */ + if (gnext(n) == 0) + break; + n += gnext(n); + } + return luaO_nilobject; + } + } +} + +TValue* luaH_set(lua_State* L, Table* t, const TValue* key) +{ + const TValue* p = luaH_get(t, key); + invalidateTMcache(t); + if (p != luaO_nilobject) + return cast_to(TValue*, p); + else + { + if (ttisnil(key)) + luaG_runerror(L, "table index is nil"); + else if (ttisnumber(key) && luai_numisnan(nvalue(key))) + luaG_runerror(L, "table index is NaN"); + else if (ttisvector(key) && luai_vecisnan(vvalue(key))) + luaG_runerror(L, "table index contains NaN"); + return newkey(L, t, key); + } +} + +TValue* luaH_setnum(lua_State* L, Table* t, int key) +{ + /* (1 <= key && key <= t->sizearray) */ + if (cast_to(unsigned int, key - 1) < cast_to(unsigned int, t->sizearray)) + return &t->array[key - 1]; + /* hash fallback */ + const TValue* p = luaH_getnum(t, key); + if (p != luaO_nilobject) + return cast_to(TValue*, p); + else + { + TValue k; + setnvalue(&k, cast_num(key)); + return newkey(L, t, &k); + } +} + +TValue* luaH_setstr(lua_State* L, Table* t, TString* key) +{ + const TValue* p = luaH_getstr(t, key); + invalidateTMcache(t); + if (p != luaO_nilobject) + return cast_to(TValue*, p); + else + { + TValue k; + setsvalue(L, &k, key); + return newkey(L, t, &k); + } +} + +static LUAU_NOINLINE int unbound_search(Table* t, unsigned int j) +{ + unsigned int i = j; /* i is zero or a present index */ + j++; + /* find `i' and `j' such that i is present and j is not */ + while (!ttisnil(luaH_getnum(t, j))) + { + i = j; + j *= 2; + if (j > cast_to(unsigned int, INT_MAX)) + { /* overflow? */ + /* table was built with bad purposes: resort to linear search */ + i = 1; + while (!ttisnil(luaH_getnum(t, i))) + i++; + return i - 1; + } + } + /* now do a binary search between them */ + while (j - i > 1) + { + unsigned int m = (i + j) / 2; + if (ttisnil(luaH_getnum(t, m))) + j = m; + else + i = m; + } + return i; +} + +static int updateaboundary(Table* t, int boundary) +{ + if (boundary < t->sizearray && ttisnil(&t->array[boundary - 1])) + { + if (boundary >= 2 && !ttisnil(&t->array[boundary - 2])) + { + maybesetaboundary(t, boundary - 1); + return boundary - 1; + } + } + else if (boundary + 1 < t->sizearray && !ttisnil(&t->array[boundary]) && ttisnil(&t->array[boundary + 1])) + { + maybesetaboundary(t, boundary + 1); + return boundary + 1; + } + + return 0; +} + +/* +** Try to find a boundary in table `t'. A `boundary' is an integer index +** such that t[i] is non-nil and t[i+1] is nil (and 0 if t[1] is nil). +*/ +int luaH_getn(Table* t) +{ + int boundary = getaboundary(t); + + if (FFlag::LuauArrayBoundary && boundary > 0) + { + if (!ttisnil(&t->array[t->sizearray - 1]) && t->node == dummynode) + return t->sizearray; /* fast-path: the end of the array in `t' already refers to a boundary */ + if (boundary < t->sizearray && !ttisnil(&t->array[boundary - 1]) && ttisnil(&t->array[boundary])) + return boundary; /* fast-path: boundary already refers to a boundary in `t' */ + + int foundboundary = updateaboundary(t, boundary); + if (foundboundary > 0) + return foundboundary; + } + + int j = t->sizearray; + + if (j > 0 && ttisnil(&t->array[j - 1])) + { + // "branchless" binary search from Array Layouts for Comparison-Based Searching, Paul Khuong, Pat Morin, 2017. + // note that clang is cmov-shy on cmovs around memory operands, so it will compile this to a branchy loop. + TValue* base = t->array; + int rest = j; + while (int half = rest >> 1) + { + base = ttisnil(&base[half]) ? base : base + half; + rest -= half; + } + int boundary = !ttisnil(base) + int(base - t->array); + maybesetaboundary(t, boundary); + return boundary; + } + /* else must find a boundary in hash part */ + else if (t->node == dummynode) /* hash part is empty? */ + return j; /* that is easy... */ + else + return unbound_search(t, j); +} + +Table* luaH_clone(lua_State* L, Table* tt) +{ + Table* t = luaM_new(L, Table, sizeof(Table), L->activememcat); + luaC_link(L, t, LUA_TTABLE); + t->metatable = tt->metatable; + t->flags = tt->flags; + t->array = NULL; + t->sizearray = 0; + t->lsizenode = 0; + t->nodemask8 = 0; + t->readonly = 0; + t->safeenv = 0; + t->node = cast_to(LuaNode*, dummynode); + t->lastfree = 0; + + if (tt->sizearray) + { + t->array = luaM_newarray(L, tt->sizearray, TValue, t->memcat); + maybesetaboundary(t, getaboundary(tt)); + t->sizearray = tt->sizearray; + + memcpy(t->array, tt->array, t->sizearray * sizeof(TValue)); + } + + if (tt->node != dummynode) + { + int size = 1 << tt->lsizenode; + t->node = luaM_newarray(L, size, LuaNode, t->memcat); + t->lsizenode = tt->lsizenode; + t->nodemask8 = tt->nodemask8; + memcpy(t->node, tt->node, size * sizeof(LuaNode)); + t->lastfree = tt->lastfree; + } + + return t; +} + +void luaH_clear(Table* tt) +{ + /* clear array part */ + for (int i = 0; i < tt->sizearray; ++i) + { + setnilvalue(&tt->array[i]); + } + + maybesetaboundary(tt, 0); + + /* clear hash part */ + if (tt->node != dummynode) + { + int size = sizenode(tt); + tt->lastfree = size; + for (int i = 0; i < size; ++i) + { + LuaNode* n = gnode(tt, i); + setnilvalue(gkey(n)); + setnilvalue(gval(n)); + gnext(n) = 0; + } + } + + /* back to empty -> no tag methods present */ + tt->flags = cast_byte(~0); +} diff --git a/VM/src/ltable.h b/VM/src/ltable.h new file mode 100644 index 0000000..f98d87b --- /dev/null +++ b/VM/src/ltable.h @@ -0,0 +1,30 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +#define gnode(t, i) (&(t)->node[i]) +#define gkey(n) (&(n)->key) +#define gval(n) (&(n)->val) +#define gnext(n) ((n)->key.next) + +static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast below is incorrect"); +#define gval2slot(t, v) int(cast_to(LuaNode*, static_cast(v)) - t->node) + +LUAI_FUNC const TValue* luaH_getnum(Table* t, int key); +LUAI_FUNC TValue* luaH_setnum(lua_State* L, Table* t, int key); +LUAI_FUNC const TValue* luaH_getstr(Table* t, TString* key); +LUAI_FUNC TValue* luaH_setstr(lua_State* L, Table* t, TString* key); +LUAI_FUNC const TValue* luaH_get(Table* t, const TValue* key); +LUAI_FUNC TValue* luaH_set(lua_State* L, Table* t, const TValue* key); +LUAI_FUNC Table* luaH_new(lua_State* L, int narray, int lnhash); +LUAI_FUNC void luaH_resizearray(lua_State* L, Table* t, int nasize); +LUAI_FUNC void luaH_resizehash(lua_State* L, Table* t, int nhsize); +LUAI_FUNC void luaH_free(lua_State* L, Table* t); +LUAI_FUNC int luaH_next(lua_State* L, Table* t, StkId key); +LUAI_FUNC int luaH_getn(Table* t); +LUAI_FUNC Table* luaH_clone(lua_State* L, Table* tt); +LUAI_FUNC void luaH_clear(Table* tt); + +extern const LuaNode luaH_dummynode; diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp new file mode 100644 index 0000000..090e183 --- /dev/null +++ b/VM/src/ltablib.cpp @@ -0,0 +1,569 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lstate.h" +#include "ltable.h" +#include "lstring.h" +#include "lgc.h" +#include "ldebug.h" +#include "lvm.h" + +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false) + +LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) + +bool lua_telemetry_table_move_oob_src_from = false; +bool lua_telemetry_table_move_oob_src_to = false; +bool lua_telemetry_table_move_oob_dst = false; + +static int foreachi(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checktype(L, 2, LUA_TFUNCTION); + int i; + int n = lua_objlen(L, 1); + for (i = 1; i <= n; i++) + { + lua_pushvalue(L, 2); /* function */ + lua_pushinteger(L, i); /* 1st argument */ + lua_rawgeti(L, 1, i); /* 2nd argument */ + lua_call(L, 2, 1); + if (!lua_isnil(L, -1)) + return 1; + lua_pop(L, 1); /* remove nil result */ + } + return 0; +} + +static int foreach (lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checktype(L, 2, LUA_TFUNCTION); + lua_pushnil(L); /* first key */ + while (lua_next(L, 1)) + { + lua_pushvalue(L, 2); /* function */ + lua_pushvalue(L, -3); /* key */ + lua_pushvalue(L, -3); /* value */ + lua_call(L, 2, 1); + if (!lua_isnil(L, -1)) + return 1; + lua_pop(L, 2); /* remove value and result */ + } + return 0; +} + +static int maxn(lua_State* L) +{ + double max = 0; + luaL_checktype(L, 1, LUA_TTABLE); + lua_pushnil(L); /* first key */ + while (lua_next(L, 1)) + { + lua_pop(L, 1); /* remove value */ + if (lua_type(L, -1) == LUA_TNUMBER) + { + double v = lua_tonumber(L, -1); + if (v > max) + max = v; + } + } + lua_pushnumber(L, max); + return 1; +} + +static int getn(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + lua_pushinteger(L, lua_objlen(L, 1)); + return 1; +} + +static void moveelements(lua_State* L, int srct, int dstt, int f, int e, int t) +{ + Table* src = hvalue(L->base + (srct - 1)); + Table* dst = hvalue(L->base + (dstt - 1)); + + if (dst->readonly) + luaG_runerror(L, "Attempt to modify a readonly table"); + + int n = e - f + 1; /* number of elements to move */ + + if (cast_to(unsigned int, f - 1) < cast_to(unsigned int, src->sizearray) && + cast_to(unsigned int, t - 1) < cast_to(unsigned int, dst->sizearray) && + cast_to(unsigned int, f - 1 + n) <= cast_to(unsigned int, src->sizearray) && + cast_to(unsigned int, t - 1 + n) <= cast_to(unsigned int, dst->sizearray)) + { + TValue* srcarray = src->array; + TValue* dstarray = dst->array; + + if (t > e || t <= f || (dstt != srct && dst != src)) + { + for (int i = 0; i < n; ++i) + { + TValue* s = &srcarray[f + i - 1]; + TValue* d = &dstarray[t + i - 1]; + setobj2t(L, d, s); + } + } + else + { + for (int i = n - 1; i >= 0; i--) + { + TValue* s = &srcarray[(f + i) - 1]; + TValue* d = &dstarray[(t + i) - 1]; + setobj2t(L, d, s); + } + } + + luaC_barrierfast(L, dst); + } + else + { + if (t > e || t <= f || dst != src) + { + for (int i = 0; i < n; ++i) + { + lua_rawgeti(L, srct, f + i); + lua_rawseti(L, dstt, t + i); + } + } + else + { + for (int i = n - 1; i >= 0; i--) + { + lua_rawgeti(L, srct, f + i); + lua_rawseti(L, dstt, t + i); + } + } + } +} + +static int tinsert(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + int n = lua_objlen(L, 1); + int pos; /* where to insert new element */ + switch (lua_gettop(L)) + { + case 2: + { /* called with only 2 arguments */ + pos = n + 1; /* insert new element at the end */ + break; + } + case 3: + { + pos = luaL_checkinteger(L, 2); /* 2nd argument is the position */ + + /* move up elements if necessary */ + if (1 <= pos && pos <= n) + moveelements(L, 1, 1, pos, n, pos + 1); + break; + } + default: + { + luaL_error(L, "wrong number of arguments to 'insert'"); + } + } + lua_rawseti(L, 1, pos); /* t[pos] = v */ + return 0; +} + +static int tremove(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + int n = lua_objlen(L, 1); + int pos = luaL_optinteger(L, 2, n); + + if (!(1 <= pos && pos <= n)) /* position is outside bounds? */ + return 0; /* nothing to remove */ + lua_rawgeti(L, 1, pos); /* result = t[pos] */ + + moveelements(L, 1, 1, pos + 1, n, pos); + + lua_pushnil(L); + lua_rawseti(L, 1, n); /* t[n] = nil */ + return 1; +} + +/* +** Copy elements (1[f], ..., 1[e]) into (tt[t], tt[t+1], ...). Whenever +** possible, copy in increasing order, which is better for rehashing. +** "possible" means destination after original range, or smaller +** than origin, or copying to another table. +*/ +static int tmove(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + int f = luaL_checkinteger(L, 2); + int e = luaL_checkinteger(L, 3); + int t = luaL_checkinteger(L, 4); + int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ + luaL_checktype(L, tt, LUA_TTABLE); + + if (DFFlag::LuauTableMoveTelemetry) + { + int nf = lua_objlen(L, 1); + int nt = lua_objlen(L, tt); + + // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) + if (!(f == 1 || (f >= 1 && f <= nf))) + lua_telemetry_table_move_oob_src_from = true; + if (!(e == nf || (e >= 1 && e <= nf))) + lua_telemetry_table_move_oob_src_to = true; + + // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) + if (!(t == nt + 1 || (t >= 1 && t <= nt + 1))) + lua_telemetry_table_move_oob_dst = true; + } + + if (e >= f) + { /* otherwise, nothing to move */ + luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); + int n = e - f + 1; /* number of elements to move */ + luaL_argcheck(L, t <= INT_MAX - n + 1, 4, "destination wrap around"); + + Table* dst = hvalue(L->base + (tt - 1)); + + if (dst->readonly) /* also checked in moveelements, but this blocks resizes of r/o tables */ + luaG_runerror(L, "Attempt to modify a readonly table"); + + if (t > 0 && (t - 1) <= dst->sizearray && (t - 1 + n) > dst->sizearray) + { /* grow the destination table array */ + luaH_resizearray(L, dst, t - 1 + n); + } + + moveelements(L, 1, tt, f, e, t); + } + lua_pushvalue(L, tt); /* return destination table */ + return 1; +} + +static void addfield(lua_State* L, luaL_Buffer* b, int i) +{ + lua_rawgeti(L, 1, i); + if (!lua_isstring(L, -1)) + luaL_error(L, "invalid value (%s) at index %d in table for 'concat'", luaL_typename(L, -1), i); + luaL_addvalue(b); +} + +static int tconcat(lua_State* L) +{ + luaL_Buffer b; + size_t lsep; + int i, last; + const char* sep = luaL_optlstring(L, 2, "", &lsep); + luaL_checktype(L, 1, LUA_TTABLE); + i = luaL_optinteger(L, 3, 1); + last = luaL_opt(L, luaL_checkinteger, 4, lua_objlen(L, 1)); + luaL_buffinit(L, &b); + for (; i < last; i++) + { + addfield(L, &b, i); + luaL_addlstring(&b, sep, lsep); + } + if (i == last) /* add last value (if interval was not empty) */ + addfield(L, &b, i); + luaL_pushresult(&b); + return 1; +} + +static int tpack(lua_State* L) +{ + int n = lua_gettop(L); /* number of elements to pack */ + lua_createtable(L, n, 1); /* create result table */ + + Table* t = hvalue(L->top - 1); + + for (int i = 0; i < n; ++i) + { + TValue* e = &t->array[i]; + setobj2t(L, e, L->base + i); + } + + /* t.n = number of elements */ + TValue* nv = luaH_setstr(L, t, luaS_newliteral(L, "n")); + setnvalue(nv, n); + + return 1; /* return table */ +} + +static int tunpack(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + Table* t = hvalue(L->base); + + int i = luaL_optinteger(L, 2, 1); + int e = luaL_opt(L, luaL_checkinteger, 3, lua_objlen(L, 1)); + if (i > e) + return 0; /* empty range */ + unsigned n = (unsigned)e - i; /* number of elements minus 1 (avoid overflows) */ + if (n >= (unsigned int)INT_MAX || !lua_checkstack(L, (int)(++n))) + luaL_error(L, "too many results to unpack"); + + // fast-path: direct array-to-stack copy + if (i == 1 && int(n) <= t->sizearray) + { + for (i = 0; i < int(n); i++) + setobj2s(L, L->top + i, &t->array[i]); + L->top += n; + } + else + { + /* push arg[i..e - 1] (to avoid overflows) */ + for (; i < e; i++) + lua_rawgeti(L, 1, i); + lua_rawgeti(L, 1, e); /* push last element */ + } + return (int)n; +} + +/* +** {====================================================== +** Quicksort +** (based on `Algorithms in MODULA-3', Robert Sedgewick; +** Addison-Wesley, 1993.) +*/ + +static void set2(lua_State* L, int i, int j) +{ + lua_rawseti(L, 1, i); + lua_rawseti(L, 1, j); +} + +static int sort_comp(lua_State* L, int a, int b) +{ + if (!lua_isnil(L, 2)) + { /* function? */ + int res; + lua_pushvalue(L, 2); + lua_pushvalue(L, a - 1); /* -1 to compensate function */ + lua_pushvalue(L, b - 2); /* -2 to compensate function and `a' */ + lua_call(L, 2, 1); + res = lua_toboolean(L, -1); + lua_pop(L, 1); + return res; + } + else /* a < b? */ + return lua_lessthan(L, a, b); +} + +static void auxsort(lua_State* L, int l, int u) +{ + while (l < u) + { /* for tail recursion */ + int i, j; + /* sort elements a[l], a[(l+u)/2] and a[u] */ + lua_rawgeti(L, 1, l); + lua_rawgeti(L, 1, u); + if (sort_comp(L, -1, -2)) /* a[u] < a[l]? */ + set2(L, l, u); /* swap a[l] - a[u] */ + else + lua_pop(L, 2); + if (u - l == 1) + break; /* only 2 elements */ + i = (l + u) / 2; + lua_rawgeti(L, 1, i); + lua_rawgeti(L, 1, l); + if (sort_comp(L, -2, -1)) /* a[i]= P */ + while (lua_rawgeti(L, 1, ++i), sort_comp(L, -1, -2)) + { + if (i >= u) + luaL_error(L, "invalid order function for sorting"); + lua_pop(L, 1); /* remove a[i] */ + } + /* repeat --j until a[j] <= P */ + while (lua_rawgeti(L, 1, --j), sort_comp(L, -3, -1)) + { + if (j <= l) + luaL_error(L, "invalid order function for sorting"); + lua_pop(L, 1); /* remove a[j] */ + } + if (j < i) + { + lua_pop(L, 3); /* pop pivot, a[i], a[j] */ + break; + } + set2(L, i, j); + } + lua_rawgeti(L, 1, u - 1); + lua_rawgeti(L, 1, i); + set2(L, u - 1, i); /* swap pivot (a[u-1]) with a[i] */ + /* a[l..i-1] <= a[i] == P <= a[i+1..u] */ + /* adjust so that smaller half is in [j..i] and larger one in [l..u] */ + if (i - l < u - i) + { + j = l; + i = i - 1; + l = i + 2; + } + else + { + j = i + 1; + i = u; + u = j - 2; + } + auxsort(L, j, i); /* call recursively the smaller one */ + } /* repeat the routine for the larger one */ +} + +static int sort(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + int n = lua_objlen(L, 1); + luaL_checkstack(L, 40, ""); /* assume array is smaller than 2^40 */ + if (!lua_isnoneornil(L, 2)) /* is there a 2nd argument? */ + luaL_checktype(L, 2, LUA_TFUNCTION); + lua_settop(L, 2); /* make sure there is two arguments */ + auxsort(L, 1, n); + return 0; +} + +/* }====================================================== */ + +static int tcreate(lua_State* L) +{ + int size = luaL_checkinteger(L, 1); + if (size < 0) + luaL_argerror(L, 1, "size out of range"); + + if (!lua_isnoneornil(L, 2)) + { + lua_createtable(L, size, 0); + Table* t = hvalue(L->top - 1); + + StkId v = L->base + 1; + + for (int i = 0; i < size; ++i) + { + TValue* e = &t->array[i]; + setobj2t(L, e, v); + } + } + else + { + lua_createtable(L, size, 0); + } + + return 1; +} + +static int tfind(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + luaL_checkany(L, 2); + int init = luaL_optinteger(L, 3, 1); + if (init < 1) + luaL_argerror(L, 3, "index out of range"); + + Table* t = hvalue(L->base); + StkId v = L->base + 1; + + for (int i = init;; ++i) + { + const TValue* e = luaH_getnum(t, i); + if (ttisnil(e)) + break; + + if (equalobj(L, v, e)) + { + lua_pushinteger(L, i); + return 1; + } + } + + lua_pushnil(L); + return 1; +} + +static int tclear(lua_State* L) +{ + luaL_checktype(L, 1, LUA_TTABLE); + + Table* tt = hvalue(L->base); + if (tt->readonly) + luaG_runerror(L, "Attempt to modify a readonly table"); + + luaH_clear(tt); + return 0; +} + +static int tfreeze(lua_State* L) +{ + if (!FFlag::LuauTableFreeze) + luaG_runerror(L, "table.freeze is disabled"); + + luaL_checktype(L, 1, LUA_TTABLE); + luaL_argcheck(L, !lua_getreadonly(L, 1), 1, "table is already frozen"); + luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); + + lua_setreadonly(L, 1, true); + + lua_pushvalue(L, 1); + return 1; +} + +static int tisfrozen(lua_State* L) +{ + if (!FFlag::LuauTableFreeze) + luaG_runerror(L, "table.isfrozen is disabled"); + + luaL_checktype(L, 1, LUA_TTABLE); + + lua_pushboolean(L, lua_getreadonly(L, 1)); + return 1; +} + +static const luaL_Reg tab_funcs[] = { + {"concat", tconcat}, + {"foreach", foreach}, + {"foreachi", foreachi}, + {"getn", getn}, + {"maxn", maxn}, + {"insert", tinsert}, + {"remove", tremove}, + {"sort", sort}, + {"pack", tpack}, + {"unpack", tunpack}, + {"move", tmove}, + {"create", tcreate}, + {"find", tfind}, + {"clear", tclear}, + {"freeze", tfreeze}, + {"isfrozen", tisfrozen}, + {NULL, NULL}, +}; + +LUALIB_API int luaopen_table(lua_State* L) +{ + luaL_register(L, LUA_TABLIBNAME, tab_funcs); + + // Lua 5.1 compat + lua_pushcfunction(L, tunpack, "unpack"); + lua_setglobal(L, "unpack"); + + return 1; +} diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp new file mode 100644 index 0000000..a77a7c7 --- /dev/null +++ b/VM/src/ltm.cpp @@ -0,0 +1,140 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "ltm.h" + +#include "lstate.h" +#include "lstring.h" +#include "ltable.h" +#include "lgc.h" + +#include + +// clang-format off +const char* const luaT_typenames[] = { + /* ORDER TYPE */ + "nil", + "boolean", + + + "userdata", + "number", + "vector", + + "string", + + + "table", + "function", + "userdata", + "thread", +}; + +const char* const luaT_eventname[] = { + /* ORDER TM */ + + "__index", + "__newindex", + "__mode", + "__namecall", + + "__eq", + + + "__add", + "__sub", + "__mul", + "__div", + "__mod", + "__pow", + "__unm", + + + "__len", + "__lt", + "__le", + "__concat", + "__call", + "__type", +}; +// clang-format on + +static_assert(sizeof(luaT_typenames) / sizeof(luaT_typenames[0]) == LUA_T_COUNT, "luaT_typenames size mismatch"); +static_assert(sizeof(luaT_eventname) / sizeof(luaT_eventname[0]) == TM_N, "luaT_eventname size mismatch"); + +void luaT_init(lua_State* L) +{ + int i; + for (i = 0; i < LUA_T_COUNT; i++) + { + L->global->ttname[i] = luaS_new(L, luaT_typenames[i]); + luaS_fix(L->global->ttname[i]); /* never collect these names */ + } + for (i = 0; i < TM_N; i++) + { + L->global->tmname[i] = luaS_new(L, luaT_eventname[i]); + luaS_fix(L->global->tmname[i]); /* never collect these names */ + } +} + +/* +** function to be used with macro "fasttm": optimized for absence of +** tag methods. +*/ +const TValue* luaT_gettm(Table* events, TMS event, TString* ename) +{ + const TValue* tm = luaH_getstr(events, ename); + LUAU_ASSERT(event <= TM_EQ); + if (ttisnil(tm)) + { /* no tag method? */ + events->flags |= cast_byte(1u << event); /* cache this fact */ + return NULL; + } + else + return tm; +} + +const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) +{ + /* + NB: Tag-methods were replaced by meta-methods in Lua 5.0, but the + old names are still around (this function, for example). + */ + Table* mt; + switch (ttype(o)) + { + case LUA_TTABLE: + mt = hvalue(o)->metatable; + break; + case LUA_TUSERDATA: + mt = uvalue(o)->metatable; + break; + default: + mt = L->global->mt[ttype(o)]; + } + return (mt ? luaH_getstr(mt, L->global->tmname[event]) : luaO_nilobject); +} + +const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) +{ + if (ttisuserdata(o) && uvalue(o)->tag && uvalue(o)->metatable) + { + const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); + + if (ttisstring(type)) + return tsvalue(type); + } + else if (Table* mt = L->global->mt[ttype(o)]) + { + const TValue* type = luaH_getstr(mt, L->global->tmname[TM_TYPE]); + + if (ttisstring(type)) + return tsvalue(type); + } + + return L->global->ttname[ttype(o)]; +} + +const char* luaT_objtypename(lua_State* L, const TValue* o) +{ + return getstr(luaT_objtypenamestr(L, o)); +} diff --git a/VM/src/ltm.h b/VM/src/ltm.h new file mode 100644 index 0000000..0e4e915 --- /dev/null +++ b/VM/src/ltm.h @@ -0,0 +1,57 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +/* + * WARNING: if you change the order of this enumeration, + * grep "ORDER TM" + */ +// clang-format off +typedef enum +{ + + TM_INDEX, + TM_NEWINDEX, + TM_MODE, + TM_NAMECALL, + + TM_EQ, /* last tag method with `fast' access */ + + + TM_ADD, + TM_SUB, + TM_MUL, + TM_DIV, + TM_MOD, + TM_POW, + TM_UNM, + + + TM_LEN, + TM_LT, + TM_LE, + TM_CONCAT, + TM_CALL, + TM_TYPE, + + TM_N /* number of elements in the enum */ +} TMS; +// clang-format on + +#define gfasttm(g, et, e) ((et) == NULL ? NULL : ((et)->flags & (1u << (e))) ? NULL : luaT_gettm(et, e, (g)->tmname[e])) + +#define fasttm(l, et, e) gfasttm(l->global, et, e) +#define fastnotm(et, e) ((et) == NULL || ((et)->flags & (1u << (e)))) + +LUAI_DATA const char* const luaT_typenames[]; +LUAI_DATA const char* const luaT_eventname[]; + +LUAI_FUNC const TValue* luaT_gettm(Table* events, TMS event, TString* ename); +LUAI_FUNC const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event); + +LUAI_FUNC const TString* luaT_objtypenamestr(lua_State* L, const TValue* o); +LUAI_FUNC const char* luaT_objtypename(lua_State* L, const TValue* o); + +LUAI_FUNC void luaT_init(lua_State* L); diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp new file mode 100644 index 0000000..6a02629 --- /dev/null +++ b/VM/src/lutf8lib.cpp @@ -0,0 +1,294 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lualib.h" + +#include "lcommon.h" + +#define MAXUNICODE 0x10FFFF + +#define iscont(p) ((*(p)&0xC0) == 0x80) + +/* from strlib */ +/* translate a relative string position: negative means back from end */ +static int u_posrelat(int pos, size_t len) +{ + if (pos >= 0) + return pos; + else if (0u - (size_t)pos > len) + return 0; + else + return (int)len + pos + 1; +} + +/* +** Decode one UTF-8 sequence, returning NULL if byte sequence is invalid. +*/ +static const char* utf8_decode(const char* o, int* val) +{ + static const unsigned int limits[] = {0xFF, 0x7F, 0x7FF, 0xFFFF}; + const unsigned char* s = (const unsigned char*)o; + unsigned int c = s[0]; + unsigned int res = 0; /* final result */ + if (c < 0x80) /* ascii? */ + res = c; + else + { + int count = 0; /* to count number of continuation bytes */ + while (c & 0x40) + { /* still have continuation bytes? */ + int cc = s[++count]; /* read next byte */ + if ((cc & 0xC0) != 0x80) /* not a continuation byte? */ + return NULL; /* invalid byte sequence */ + res = (res << 6) | (cc & 0x3F); /* add lower 6 bits from cont. byte */ + c <<= 1; /* to test next bit */ + } + res |= ((c & 0x7F) << (count * 5)); /* add first byte */ + if (count > 3 || res > MAXUNICODE || res <= limits[count]) + return NULL; /* invalid byte sequence */ + s += count; /* skip continuation bytes read */ + } + if (val) + *val = res; + return (const char*)s + 1; /* +1 to include first byte */ +} + +/* +** utf8len(s [, i [, j]]) --> number of characters that start in the +** range [i,j], or nil + current position if 's' is not well formed in +** that interval +*/ +static int utflen(lua_State* L) +{ + int n = 0; + size_t len; + const char* s = luaL_checklstring(L, 1, &len); + int posi = u_posrelat(luaL_optinteger(L, 2, 1), len); + int posj = u_posrelat(luaL_optinteger(L, 3, -1), len); + luaL_argcheck(L, 1 <= posi && --posi <= (int)len, 2, "initial position out of string"); + luaL_argcheck(L, --posj < (int)len, 3, "final position out of string"); + while (posi <= posj) + { + const char* s1 = utf8_decode(s + posi, NULL); + if (s1 == NULL) + { /* conversion error? */ + lua_pushnil(L); /* return nil ... */ + lua_pushinteger(L, posi + 1); /* ... and current position */ + return 2; + } + posi = (int)(s1 - s); + n++; + } + lua_pushinteger(L, n); + return 1; +} + +/* +** codepoint(s, [i, [j]]) -> returns codepoints for all characters +** that start in the range [i,j] +*/ +static int codepoint(lua_State* L) +{ + size_t len; + const char* s = luaL_checklstring(L, 1, &len); + int posi = u_posrelat(luaL_optinteger(L, 2, 1), len); + int pose = u_posrelat(luaL_optinteger(L, 3, posi), len); + int n; + const char* se; + luaL_argcheck(L, posi >= 1, 2, "out of range"); + luaL_argcheck(L, pose <= (int)len, 3, "out of range"); + if (posi > pose) + return 0; /* empty interval; return no values */ + if (pose - posi >= INT_MAX) /* (int -> int) overflow? */ + luaL_error(L, "string slice too long"); + n = (int)(pose - posi) + 1; + luaL_checkstack(L, n, "string slice too long"); + n = 0; + se = s + pose; + for (s += posi - 1; s < se;) + { + int code; + s = utf8_decode(s, &code); + if (s == NULL) + luaL_error(L, "invalid UTF-8 code"); + lua_pushinteger(L, code); + n++; + } + return n; +} + +// from Lua 5.3 lobject.h +#define UTF8BUFFSZ 8 + +// from Lua 5.3 lobject.c, copied verbatim + static +static int luaO_utf8esc(char* buff, unsigned long x) +{ + int n = 1; /* number of bytes put in buffer (backwards) */ + LUAU_ASSERT(x <= 0x10FFFF); + if (x < 0x80) /* ascii? */ + buff[UTF8BUFFSZ - 1] = cast_to(char, x); + else + { /* need continuation bytes */ + unsigned int mfb = 0x3f; /* maximum that fits in first byte */ + do + { /* add continuation bytes */ + buff[UTF8BUFFSZ - (n++)] = cast_to(char, 0x80 | (x & 0x3f)); + x >>= 6; /* remove added bits */ + mfb >>= 1; /* now there is one less bit available in first byte */ + } while (x > mfb); /* still needs continuation byte? */ + buff[UTF8BUFFSZ - n] = cast_to(char, (~mfb << 1) | x); /* add first byte */ + } + return n; +} + +// lighter replacement for pushutfchar; doesn't push any string onto the stack +static int buffutfchar(lua_State* L, int arg, char* buff, const char** charstr) +{ + int code = luaL_checkinteger(L, arg); + luaL_argcheck(L, 0 <= code && code <= MAXUNICODE, arg, "value out of range"); + int l = luaO_utf8esc(buff, cast_to(long, code)); + *charstr = buff + UTF8BUFFSZ - l; + return l; +} + +/* +** utfchar(n1, n2, ...) -> char(n1)..char(n2)... +** +** This version avoids the need to make more invasive upgrades elsewhere (like +** implementing the %U escape in lua_pushfstring) and avoids pushing string +** objects for each codepoint in the multi-argument case. -Jovanni +*/ +static int utfchar(lua_State* L) +{ + char buff[UTF8BUFFSZ]; + const char* charstr; + + int n = lua_gettop(L); /* number of arguments */ + if (n == 1) + { /* optimize common case of single char */ + int l = buffutfchar(L, 1, buff, &charstr); + lua_pushlstring(L, charstr, l); + } + else + { + luaL_Buffer b; + luaL_buffinit(L, &b); + for (int i = 1; i <= n; i++) + { + int l = buffutfchar(L, i, buff, &charstr); + luaL_addlstring(&b, charstr, l); + } + luaL_pushresult(&b); + } + return 1; +} + +/* +** offset(s, n, [i]) -> index where n-th character counting from +** position 'i' starts; 0 means character at 'i'. +*/ +static int byteoffset(lua_State* L) +{ + size_t len; + const char* s = luaL_checklstring(L, 1, &len); + int n = luaL_checkinteger(L, 2); + int posi = (n >= 0) ? 1 : (int)len + 1; + posi = u_posrelat(luaL_optinteger(L, 3, posi), len); + luaL_argcheck(L, 1 <= posi && --posi <= (int)len, 3, "position out of range"); + if (n == 0) + { + /* find beginning of current byte sequence */ + while (posi > 0 && iscont(s + posi)) + posi--; + } + else + { + if (iscont(s + posi)) + luaL_error(L, "initial position is a continuation byte"); + if (n < 0) + { + while (n < 0 && posi > 0) + { /* move back */ + do + { /* find beginning of previous character */ + posi--; + } while (posi > 0 && iscont(s + posi)); + n++; + } + } + else + { + n--; /* do not move for 1st character */ + while (n > 0 && posi < (int)len) + { + do + { /* find beginning of next character */ + posi++; + } while (iscont(s + posi)); /* (cannot pass final '\0') */ + n--; + } + } + } + if (n == 0) /* did it find given character? */ + lua_pushinteger(L, posi + 1); + else /* no such character */ + lua_pushnil(L); + return 1; +} + +static int iter_aux(lua_State* L) +{ + size_t len; + const char* s = luaL_checklstring(L, 1, &len); + int n = lua_tointeger(L, 2) - 1; + if (n < 0) /* first iteration? */ + n = 0; /* start from here */ + else if (n < (int)len) + { + n++; /* skip current byte */ + while (iscont(s + n)) + n++; /* and its continuations */ + } + if (n >= (int)len) + return 0; /* no more codepoints */ + else + { + int code; + const char* next = utf8_decode(s + n, &code); + if (next == NULL || iscont(next)) + luaL_error(L, "invalid UTF-8 code"); + lua_pushinteger(L, n + 1); + lua_pushinteger(L, code); + return 2; + } +} + +static int iter_codes(lua_State* L) +{ + luaL_checkstring(L, 1); + lua_pushcfunction(L, iter_aux); + lua_pushvalue(L, 1); + lua_pushinteger(L, 0); + return 3; +} + +/* pattern to match a single UTF-8 character */ +#define UTF8PATT "[\0-\x7F\xC2-\xF4][\x80-\xBF]*" + +static const luaL_Reg funcs[] = { + {"offset", byteoffset}, + {"codepoint", codepoint}, + {"char", utfchar}, + {"len", utflen}, + {"codes", iter_codes}, + {NULL, NULL}, +}; + +LUALIB_API int luaopen_utf8(lua_State* L) +{ + luaL_register(L, LUA_UTF8LIBNAME, funcs); + + lua_pushlstring(L, UTF8PATT, sizeof(UTF8PATT) / sizeof(char) - 1); + lua_setfield(L, -2, "charpattern"); + + return 1; +} diff --git a/VM/src/lvm.h b/VM/src/lvm.h new file mode 100644 index 0000000..25a2716 --- /dev/null +++ b/VM/src/lvm.h @@ -0,0 +1,31 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" +#include "ltm.h" + +#define tostring(L, o) ((ttype(o) == LUA_TSTRING) || (luaV_tostring(L, o))) + +#define tonumber(o, n) (ttype(o) == LUA_TNUMBER || (((o) = luaV_tonumber(o, n)) != NULL)) + +#define equalobj(L, o1, o2) (ttype(o1) == ttype(o2) && luaV_equalval(L, o1, o2)) + +LUAI_FUNC int luaV_strcmp(const TString* ls, const TString* rs); +LUAI_FUNC int luaV_lessthan(lua_State* L, const TValue* l, const TValue* r); +LUAI_FUNC int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r); +LUAI_FUNC int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2); +LUAI_FUNC void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op); +LUAI_FUNC void luaV_dolen(lua_State* L, StkId ra, const TValue* rb); +LUAI_FUNC const TValue* luaV_tonumber(const TValue* obj, TValue* n); +LUAI_FUNC const float* luaV_tovector(const TValue* obj); +LUAI_FUNC int luaV_tostring(lua_State* L, StkId obj); +LUAI_FUNC void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val); +LUAI_FUNC void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val); +LUAI_FUNC void luaV_concat(lua_State* L, int total, int last); +LUAI_FUNC void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil); + +LUAI_FUNC void luau_execute(lua_State* L); +LUAI_FUNC int luau_precall(lua_State* L, struct lua_TValue* func, int nresults); +LUAI_FUNC void luau_poscall(lua_State* L, StkId first); +LUAI_FUNC void luau_callhook(lua_State* L, lua_Hook hook, void* userdata); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp new file mode 100644 index 0000000..5f0ee92 --- /dev/null +++ b/VM/src/lvmexecute.cpp @@ -0,0 +1,2957 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lvm.h" + +#include "lstate.h" +#include "ltable.h" +#include "lfunc.h" +#include "lstring.h" +#include "lgc.h" +#include "lmem.h" +#include "ldebug.h" +#include "ldo.h" +#include "lbuiltins.h" +#include "lnumutils.h" +#include "lbytecode.h" + +#include + +LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false) + +// Disable c99-designator to avoid the warning in CGOTO dispatch table +#ifdef __clang__ +#if __has_warning("-Wc99-designator") +#pragma clang diagnostic ignored "-Wc99-designator" +#endif +#endif + +// When working with VM code, pay attention to these rules for correctness: +// 1. Many external Lua functions can fail; for them to fail and be able to generate a proper stack, we need to copy pc to L->ci->savedpc before the +// call +// 2. Many external Lua functions can reallocate the stack. This invalidates stack pointers in VM C stack frame, most importantly base, but also +// ra/rb/rc! +// 3. VM_PROTECT macro saves savedpc and restores base for you; most external calls need to be wrapped into that. However, it does NOT restore +// ra/rb/rc! +// 4. When copying an object to any existing object as a field, generally speaking you need to call luaC_barrier! Be careful with all setobj calls +// 5. To make 4 easier to follow, please use setobj2s for copies to stack and setobj for other copies. +// 6. You can define HARDSTACKTESTS in llimits.h which will aggressively realloc stack; with address sanitizer this should be effective at finding +// stack corruption bugs +// 7. Many external Lua functions can call GC! GC will *not* traverse pointers to new objects that aren't reachable from Lua root. Be careful when +// creating new Lua objects, store them to stack soon. + +// When calling luau_callTM, we usually push the arguments to the top of the stack. +// This is safe to do for complicated reasons: +// - stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) +// - stack reallocation copies values past stack_last + +// All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT +// This makes sure that we save the pc (in case the Lua call needs to generate a backtrace) before the call, +// and restores the stack pointer after in case stack gets reallocated +// Should only be used on the slow paths. +#define VM_PROTECT(x) \ + { \ + L->ci->savedpc = pc; \ + { \ + x; \ + }; \ + base = L->base; \ + } + +// Some external functions can cause an error, but never reallocate the stack; for these, VM_PROTECT_PC() is +// a cheaper version of VM_PROTECT that can be called before the external call. +#define VM_PROTECT_PC() L->ci->savedpc = pc + +#define VM_REG(i) (LUAU_ASSERT(unsigned(i) < unsigned(L->top - base)), &base[i]) +#define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) +#define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) + +#define VM_PATCH_C(pc, slot) ((uint8_t*)(pc))[3] = uint8_t(slot) + +// NOTE: If debugging the Luau code, disable this macro to prevent timeouts from +// occurring when tracing code in Visual Studio / XCode +#if 0 +#define VM_INTERRUPT() +#else +#define VM_INTERRUPT() \ + { \ + void (*interrupt)(lua_State*, int) = L->global->cb.interrupt; \ + if (LUAU_UNLIKELY(!!interrupt)) \ + { /* the interrupt hook is called right before we advance pc */ \ + VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ + } \ + } +#endif + + +#define VM_DISPATCH_OP(op) &&CASE_##op + + +#define VM_DISPATCH_TABLE() \ + VM_DISPATCH_OP(LOP_NOP), VM_DISPATCH_OP(LOP_BREAK), VM_DISPATCH_OP(LOP_LOADNIL), VM_DISPATCH_OP(LOP_LOADB), VM_DISPATCH_OP(LOP_LOADN), \ + VM_DISPATCH_OP(LOP_LOADK), VM_DISPATCH_OP(LOP_MOVE), VM_DISPATCH_OP(LOP_GETGLOBAL), VM_DISPATCH_OP(LOP_SETGLOBAL), \ + VM_DISPATCH_OP(LOP_GETUPVAL), VM_DISPATCH_OP(LOP_SETUPVAL), VM_DISPATCH_OP(LOP_CLOSEUPVALS), VM_DISPATCH_OP(LOP_GETIMPORT), \ + VM_DISPATCH_OP(LOP_GETTABLE), VM_DISPATCH_OP(LOP_SETTABLE), VM_DISPATCH_OP(LOP_GETTABLEKS), VM_DISPATCH_OP(LOP_SETTABLEKS), \ + VM_DISPATCH_OP(LOP_GETTABLEN), VM_DISPATCH_OP(LOP_SETTABLEN), VM_DISPATCH_OP(LOP_NEWCLOSURE), VM_DISPATCH_OP(LOP_NAMECALL), \ + VM_DISPATCH_OP(LOP_CALL), VM_DISPATCH_OP(LOP_RETURN), VM_DISPATCH_OP(LOP_JUMP), VM_DISPATCH_OP(LOP_JUMPBACK), VM_DISPATCH_OP(LOP_JUMPIF), \ + VM_DISPATCH_OP(LOP_JUMPIFNOT), VM_DISPATCH_OP(LOP_JUMPIFEQ), VM_DISPATCH_OP(LOP_JUMPIFLE), VM_DISPATCH_OP(LOP_JUMPIFLT), \ + VM_DISPATCH_OP(LOP_JUMPIFNOTEQ), VM_DISPATCH_OP(LOP_JUMPIFNOTLE), VM_DISPATCH_OP(LOP_JUMPIFNOTLT), VM_DISPATCH_OP(LOP_ADD), \ + VM_DISPATCH_OP(LOP_SUB), VM_DISPATCH_OP(LOP_MUL), VM_DISPATCH_OP(LOP_DIV), VM_DISPATCH_OP(LOP_MOD), VM_DISPATCH_OP(LOP_POW), \ + VM_DISPATCH_OP(LOP_ADDK), VM_DISPATCH_OP(LOP_SUBK), VM_DISPATCH_OP(LOP_MULK), VM_DISPATCH_OP(LOP_DIVK), VM_DISPATCH_OP(LOP_MODK), \ + VM_DISPATCH_OP(LOP_POWK), VM_DISPATCH_OP(LOP_AND), VM_DISPATCH_OP(LOP_OR), VM_DISPATCH_OP(LOP_ANDK), VM_DISPATCH_OP(LOP_ORK), \ + VM_DISPATCH_OP(LOP_CONCAT), VM_DISPATCH_OP(LOP_NOT), VM_DISPATCH_OP(LOP_MINUS), VM_DISPATCH_OP(LOP_LENGTH), VM_DISPATCH_OP(LOP_NEWTABLE), \ + VM_DISPATCH_OP(LOP_DUPTABLE), VM_DISPATCH_OP(LOP_SETLIST), VM_DISPATCH_OP(LOP_FORNPREP), VM_DISPATCH_OP(LOP_FORNLOOP), \ + VM_DISPATCH_OP(LOP_FORGLOOP), VM_DISPATCH_OP(LOP_FORGPREP_INEXT), VM_DISPATCH_OP(LOP_FORGLOOP_INEXT), VM_DISPATCH_OP(LOP_FORGPREP_NEXT), \ + VM_DISPATCH_OP(LOP_FORGLOOP_NEXT), VM_DISPATCH_OP(LOP_GETVARARGS), VM_DISPATCH_OP(LOP_DUPCLOSURE), VM_DISPATCH_OP(LOP_PREPVARARGS), \ + VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ + VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ + VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), + +#if defined(__GNUC__) || defined(__clang__) +#define VM_USE_CGOTO 1 +#else +#define VM_USE_CGOTO 0 +#endif + +/** + * These macros help dispatching Luau opcodes using either case + * statements or computed goto. + * VM_CASE(op) Generates either a case statement or a label + * VM_NEXT() fetch a byte and dispatch or jump to the beginning of the switch statement + * VM_CONTINUE() Use an opcode override to dispatch with computed goto or + * switch statement to skip a LOP_BREAK instruction. + */ +#if VM_USE_CGOTO +#define VM_CASE(op) CASE_##op: +#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[*(uint8_t*)pc]) +#define VM_CONTINUE(op) goto* kDispatchTable[uint8_t(op)] +#else +#define VM_CASE(op) case op: +#define VM_NEXT() goto dispatch +#define VM_CONTINUE(op) \ + dispatchOp = uint8_t(op); \ + goto dispatchContinue +#endif + +LUAU_NOINLINE static void luau_prepareFORN(lua_State* L, StkId plimit, StkId pstep, StkId pinit) +{ + if (!ttisnumber(pinit) && !luaV_tonumber(pinit, pinit)) + luaG_forerror(L, pinit, "initial value"); + if (!ttisnumber(plimit) && !luaV_tonumber(plimit, plimit)) + luaG_forerror(L, plimit, "limit"); + if (!ttisnumber(pstep) && !luaV_tonumber(pstep, pstep)) + luaG_forerror(L, pstep, "step"); +} + +LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c) +{ + StkId ra = &L->base[a]; + LUAU_ASSERT(ra + 6 <= L->top); + + setobjs2s(L, ra + 3 + 2, ra + 2); + setobjs2s(L, ra + 3 + 1, ra + 1); + setobjs2s(L, ra + 3, ra); + + L->top = ra + 3 + 3; /* func. + 2 args (state and index) */ + LUAU_ASSERT(L->top <= L->stack_last); + + luaD_call(L, ra + 3, c); + L->top = L->ci->top; + + // recompute ra since stack might have been reallocated + ra = &L->base[a]; + LUAU_ASSERT(ra < L->top); + + // copy first variable back into the iteration index + setobjs2s(L, ra + 2, ra + 3); + + return ttisnil(ra + 2); +} + +// calls a C function f with no yielding support; optionally save one resulting value to the res register +// the function and arguments have to already be pushed to L->top +LUAU_NOINLINE static void luau_callTM(lua_State* L, int nparams, int res) +{ + ++L->nCcalls; + + if (L->nCcalls >= LUAI_MAXCCALLS) + luaG_runerror(L, "C stack overflow"); + + luaD_checkstack(L, LUA_MINSTACK); + + StkId top = L->top; + StkId fun = top - nparams - 1; + + CallInfo* ci = incr_ci(L); + ci->func = fun; + ci->base = fun + 1; + ci->top = top + LUA_MINSTACK; + ci->savedpc = NULL; + ci->flags = 0; + ci->nresults = (res >= 0); + LUAU_ASSERT(ci->top <= L->stack_last); + + LUAU_ASSERT(ttisfunction(ci->func)); + LUAU_ASSERT(clvalue(ci->func)->isC); + + L->base = fun + 1; + LUAU_ASSERT(L->top == L->base + nparams); + + lua_CFunction func = clvalue(fun)->c.f; + int n = func(L); + LUAU_ASSERT(n >= 0); // yields should have been blocked by nCcalls + + // ci is our callinfo, cip is our parent + // note that we read L->ci again since it may have been reallocated by the call + CallInfo* cip = L->ci - 1; + + // copy return value into parent stack + if (res >= 0) + { + if (n > 0) + { + setobj2s(L, &cip->base[res], L->top - n); + } + else + { + setnilvalue(&cip->base[res]); + } + } + + L->ci = cip; + L->base = cip->base; + L->top = cip->top; + + --L->nCcalls; +} + +LUAU_NOINLINE static void luau_tryfuncTM(lua_State* L, StkId func) +{ + const TValue* tm = luaT_gettmbyobj(L, func, TM_CALL); + if (!ttisfunction(tm)) + luaG_typeerror(L, func, "call"); + for (StkId p = L->top; p > func; p--) /* open space for metamethod */ + setobjs2s(L, p, p - 1); + L->top++; /* stack space pre-allocated by the caller */ + setobj2s(L, func, tm); /* tag method is the new function to be called */ +} + +LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) +{ + ptrdiff_t base = savestack(L, L->base); + ptrdiff_t top = savestack(L, L->top); + ptrdiff_t ci_top = savestack(L, L->ci->top); + int status = L->status; + + // if the hook is called externally on a paused thread, we need to make sure the paused thread can emit Lua calls + if (status == LUA_YIELD || status == LUA_BREAK) + { + L->status = 0; + L->base = L->ci->base; + } + + luaD_checkstack(L, LUA_MINSTACK); /* ensure minimum stack size */ + L->ci->top = L->top + LUA_MINSTACK; + LUAU_ASSERT(L->ci->top <= L->stack_last); + + // note: the pc expectations of the hook are matching the general "pc points to next instruction" + // however, for the hook to be able to continue execution from the same point, this is called with savedpc at the *current* instruction + if (L->ci->savedpc) + L->ci->savedpc++; + + Closure* cl = clvalue(L->ci->func); + + lua_Debug ar; + ar.currentline = cl->isC ? -1 : luaG_getline(cl->l.p, pcRel(L->ci->savedpc, cl->l.p)); + ar.userdata = userdata; + + hook(L, &ar); + + if (L->ci->savedpc) + L->ci->savedpc--; + + L->ci->top = restorestack(L, ci_top); + L->top = restorestack(L, top); + + // note that we only restore the paused state if the hook hasn't yielded by itself + if (status == LUA_YIELD && L->status != LUA_YIELD) + { + L->status = LUA_YIELD; + L->base = restorestack(L, base); + } + else if (status == LUA_BREAK) + { + LUAU_ASSERT(L->status != LUA_BREAK); // hook shouldn't break again + + L->status = LUA_BREAK; + L->base = restorestack(L, base); + } +} + +inline bool luau_skipstep(uint8_t op) +{ + return op == LOP_PREPVARARGS || op == LOP_BREAK; +} + +// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv +LUAI_FUNC int luaB_inext(lua_State* L); +LUAI_FUNC int luaB_next(lua_State* L); + +template +static void luau_execute(lua_State* L) +{ +#if VM_USE_CGOTO + static const void* kDispatchTable[256] = {VM_DISPATCH_TABLE()}; +#endif + + // the critical interpreter state, stored in locals for performance + // the hope is that these map to registers without spilling (which is not true for x86 :/) + Closure* cl; + StkId base; + TValue* k; + const Instruction* pc; + + LUAU_ASSERT(isLua(L->ci)); + LUAU_ASSERT(luaC_threadactive(L)); + LUAU_ASSERT(!luaC_threadsleeping(L)); + + pc = L->ci->savedpc; + cl = clvalue(L->ci->func); + base = L->base; + k = cl->l.p->k; + + VM_NEXT(); // starts the interpreter "loop" + + { + dispatch: + // Note: this code doesn't always execute! on some platforms we use computed goto which bypasses all of this unless we run in single-step mode + // Therefore only ever put assertions here. + LUAU_ASSERT(base == L->base && L->base == L->ci->base); + LUAU_ASSERT(base <= L->top && L->top <= L->stack + L->stacksize); + + // ... and singlestep logic :) + if (SingleStep) + { + if (L->global->cb.debugstep && !luau_skipstep(*(uint8_t*)pc)) + { + VM_PROTECT(luau_callhook(L, L->global->cb.debugstep, NULL)); + + // allow debugstep hook to put thread into error/yield state + if (L->status != 0) + goto exit; + } + +#if VM_USE_CGOTO + VM_CONTINUE(*(uint8_t*)pc); +#endif + } + +#if !VM_USE_CGOTO + // Note: this assumes that LUAU_INSN_OP() decodes the first byte (aka least significant byte in the little endian encoding) + size_t dispatchOp = *(uint8_t*)pc; + + dispatchContinue: + switch (dispatchOp) +#endif + { + VM_CASE(LOP_NOP) + { + Instruction insn = *pc++; + LUAU_ASSERT(insn == 0); + VM_NEXT(); + } + + VM_CASE(LOP_LOADNIL) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setnilvalue(ra); + VM_NEXT(); + } + + VM_CASE(LOP_LOADB) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setbvalue(ra, LUAU_INSN_B(insn)); + + pc += LUAU_INSN_C(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_LOADN) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setnvalue(ra, LUAU_INSN_D(insn)); + VM_NEXT(); + } + + VM_CASE(LOP_LOADK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + setobj2s(L, ra, kv); + VM_NEXT(); + } + + VM_CASE(LOP_MOVE) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + setobj2s(L, ra, rb); + VM_NEXT(); + } + + VM_CASE(LOP_GETGLOBAL) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: value is in expected slot + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv)) && !ttisnil(gval(n))) + { + setobj2s(L, ra, gval(n)); + VM_NEXT(); + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + VM_NEXT(); + } + } + + VM_CASE(LOP_SETGLOBAL) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: value is in expected slot + Table* h = cl->env; + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) + { + setobj(L, gval(n), ra); + luaC_barriert(L, h, ra); + VM_NEXT(); + } + else + { + // slow-path, may invoke Lua calls via __newindex metamethod + TValue g; + sethvalue(L, &g, h); + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, &g, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + VM_NEXT(); + } + } + + VM_CASE(LOP_GETUPVAL) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* ur = VM_UV(LUAU_INSN_B(insn)); + TValue* v = ttisupval(ur) ? upvalue(ur)->v : ur; + + setobj2s(L, ra, v); + VM_NEXT(); + } + + VM_CASE(LOP_SETUPVAL) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* ur = VM_UV(LUAU_INSN_B(insn)); + UpVal* uv = upvalue(ur); + + setobj(L, uv->v, ra); + luaC_barrier(L, uv, ra); + luaC_upvalbarrier(L, uv, uv->v); + VM_NEXT(); + } + + VM_CASE(LOP_CLOSEUPVALS) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (L->openupval && gco2uv(L->openupval)->v >= ra) + luaF_close(L, ra); + VM_NEXT(); + } + + VM_CASE(LOP_GETIMPORT) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + // fast-path: import resolution was successful and closure environment is "safe" for import + if (!ttisnil(kv) && cl->env->safeenv) + { + setobj2s(L, ra, kv); + pc++; // skip over AUX + VM_NEXT(); + } + else + { + uint32_t aux = *pc++; + + VM_PROTECT(luaV_getimport(L, cl->env, k, aux, /* propagatenil= */ false)); + ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + + setobj2s(L, ra, L->top - 1); + L->top--; + VM_NEXT(); + } + } + + VM_CASE(LOP_GETTABLEKS) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + setobj2s(L, ra, gval(n)); + VM_NEXT(); + } + else if (!h->metatable) + { + // fast-path: value is not in expected slot, but the table lookup doesn't involve metatable + const TValue* res = luaH_getstr(h, tsvalue(kv)); + + if (res != luaO_nilobject) + { + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + } + + setobj2s(L, ra, res); + VM_NEXT(); + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + VM_NEXT(); + } + } + else + { + // fast-path: user data with C __index TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_INDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + VM_NEXT(); + } + else if (ttisvector(rb)) + { + // fast-path: quick case-insensitive comparison with "X"/"Y"/"Z" + const char* name = getstr(tsvalue(kv)); + int ic = (name[0] | ' ') - 'x'; + + if (unsigned(ic) < 3 && name[1] == '\0') + { + setnvalue(ra, rb->value.v[ic]); + VM_NEXT(); + } + + fn = fasttm(L, L->global->mt[LUA_TVECTOR], TM_INDEX); + + if (fn && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + VM_NEXT(); + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + VM_NEXT(); + } + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_SETTABLEKS) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + // fast-path: built-in table + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: value is in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)) && !h->readonly)) + { + setobj(L, gval(n), ra); + luaC_barriert(L, h, ra); + VM_NEXT(); + } + else if (fastnotm(h->metatable, TM_NEWINDEX) && !h->readonly) + { + VM_PROTECT_PC(); // set may fail + + TValue* res = luaH_setstr(L, h, tsvalue(kv)); + + if (res != luaO_nilobject) + { + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); + } + + setobj(L, res, ra); + luaC_barriert(L, h, ra); + VM_NEXT(); + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + L->cachedslot = slot; + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + VM_NEXT(); + } + } + else + { + // fast-path: user data with C __index TM + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 4 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + setobj2s(L, top + 3, ra); + L->top = top + 4; + + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luau_callTM(L, 3, -1)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + VM_NEXT(); + } + else + { + // slow-path, may invoke Lua calls via __index metamethod + VM_PROTECT(luaV_settable(L, rb, kv, ra)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_GETTABLE) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path: array lookup + if (ttistable(rb) && ttisnumber(rc)) + { + Table* h = hvalue(rb); + + double indexd = nvalue(rc); + int index = int(indexd); + + // index has to be an exact integer and in-bounds for the array portion + if (LUAU_LIKELY(unsigned(index - 1) < unsigned(h->sizearray) && !h->metatable && double(index) == indexd)) + { + setobj2s(L, ra, &h->array[unsigned(index - 1)]); + VM_NEXT(); + } + else + { + // slow-path: handles out of bounds array lookups and non-integer numeric keys + VM_PROTECT(luaV_gettable(L, rb, rc, ra)); + VM_NEXT(); + } + } + else + { + // slow-path: handles non-array table lookup as well as __index MT calls + VM_PROTECT(luaV_gettable(L, rb, rc, ra)); + VM_NEXT(); + } + } + + VM_CASE(LOP_SETTABLE) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path: array assign + if (ttistable(rb) && ttisnumber(rc)) + { + Table* h = hvalue(rb); + + double indexd = nvalue(rc); + int index = int(indexd); + + // index has to be an exact integer and in-bounds for the array portion + if (LUAU_LIKELY(unsigned(index - 1) < unsigned(h->sizearray) && !h->metatable && !h->readonly && double(index) == indexd)) + { + setobj2t(L, &h->array[unsigned(index - 1)], ra); + luaC_barriert(L, h, ra); + VM_NEXT(); + } + else + { + // slow-path: handles out of bounds array assignments and non-integer numeric keys + VM_PROTECT(luaV_settable(L, rb, rc, ra)); + VM_NEXT(); + } + } + else + { + // slow-path: handles non-array table access as well as __newindex MT calls + VM_PROTECT(luaV_settable(L, rb, rc, ra)); + VM_NEXT(); + } + } + + VM_CASE(LOP_GETTABLEN) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + int c = LUAU_INSN_C(insn); + + // fast-path: array lookup + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable)) + { + setobj2s(L, ra, &h->array[c]); + VM_NEXT(); + } + } + + // slow-path: handles out of bounds array lookups + TValue n; + setnvalue(&n, c + 1); + VM_PROTECT(luaV_gettable(L, rb, &n, ra)); + VM_NEXT(); + } + + VM_CASE(LOP_SETTABLEN) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + int c = LUAU_INSN_C(insn); + + // fast-path: array assign + if (ttistable(rb)) + { + Table* h = hvalue(rb); + + if (LUAU_LIKELY(unsigned(c) < unsigned(h->sizearray) && !h->metatable && !h->readonly)) + { + setobj2t(L, &h->array[c], ra); + luaC_barriert(L, h, ra); + VM_NEXT(); + } + } + + // slow-path: handles out of bounds array lookups + TValue n; + setnvalue(&n, c + 1); + VM_PROTECT(luaV_settable(L, rb, &n, ra)); + VM_NEXT(); + } + + VM_CASE(LOP_NEWCLOSURE) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + Proto* pv = cl->l.p->p[LUAU_INSN_D(insn)]; + LUAU_ASSERT(unsigned(LUAU_INSN_D(insn)) < unsigned(cl->l.p->sizep)); + + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = luaF_newLclosure(L, pv->nups, cl->env, pv); + setclvalue(L, ra, ncl); + + for (int ui = 0; ui < pv->nups; ++ui) + { + Instruction uinsn = *pc++; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + + switch (LUAU_INSN_A(uinsn)) + { + case LCT_VAL: + setobj(L, &ncl->l.uprefs[ui], VM_REG(LUAU_INSN_B(uinsn))); + break; + + case LCT_REF: + setupvalue(L, &ncl->l.uprefs[ui], luaF_findupval(L, VM_REG(LUAU_INSN_B(uinsn)))); + break; + + case LCT_UPVAL: + setobj(L, &ncl->l.uprefs[ui], VM_UV(LUAU_INSN_B(uinsn))); + break; + + default: + LUAU_ASSERT(!"Unknown upvalue capture type"); + } + } + + VM_PROTECT(luaC_checkGC(L)); + VM_NEXT(); + } + + VM_CASE(LOP_NAMECALL) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + LUAU_ASSERT(ttisstring(kv)); + + if (ttistable(rb)) + { + Table* h = hvalue(rb); + // note: we can't use nodemask8 here because we need to query the main position of the table, and 8-bit nodemask8 only works + // for predictive lookups + LuaNode* n = &h->node[tsvalue(kv)->hash & (sizenode(h) - 1)]; + + const TValue* mt = 0; + const LuaNode* mtn = 0; + + // fast-path: key is in the table in expected slot + if (ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + // fast-path: key is absent from the base, table has an __index table, and it has the result in the expected slot + else if (gnext(n) == 0 && (mt = fasttm(L, hvalue(rb)->metatable, TM_INDEX)) && ttistable(mt) && + (mtn = &hvalue(mt)->node[LUAU_INSN_C(insn) & hvalue(mt)->nodemask8]) && ttisstring(gkey(mtn)) && + tsvalue(gkey(mtn)) == tsvalue(kv) && !ttisnil(gval(mtn))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(mtn)); + } + else + { + // slow-path: handles full table lookup + setobj2s(L, ra + 1, rb); + L->cachedslot = LUAU_INSN_C(insn); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + } + } + else + { + Table* mt = ttisuserdata(rb) ? uvalue(rb)->metatable : L->global->mt[ttype(rb)]; + const TValue* tmi = 0; + + // fast-path: metatable with __namecall + if (const TValue* fn = fasttm(L, mt, TM_NAMECALL)) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, fn); + + L->namecall = tsvalue(kv); + } + else if ((tmi = fasttm(L, mt, TM_INDEX)) && ttistable(tmi)) + { + Table* h = hvalue(tmi); + int slot = LUAU_INSN_C(insn) & h->nodemask8; + LuaNode* n = &h->node[slot]; + + // fast-path: metatable with __index that has method in expected slot + if (LUAU_LIKELY(ttisstring(gkey(n)) && tsvalue(gkey(n)) == tsvalue(kv) && !ttisnil(gval(n)))) + { + // note: order of copies allows rb to alias ra+1 or ra + setobj2s(L, ra + 1, rb); + setobj2s(L, ra, gval(n)); + } + else + { + // slow-path: handles slot mismatch + setobj2s(L, ra + 1, rb); + L->cachedslot = slot; + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, L->cachedslot); + } + } + else + { + // slow-path: handles non-table __index + setobj2s(L, ra + 1, rb); + VM_PROTECT(luaV_gettable(L, rb, kv, ra)); + } + } + + // intentional fallthrough to CALL + LUAU_ASSERT(LUAU_INSN_OP(*pc) == LOP_CALL); + } + + VM_CASE(LOP_CALL) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + int nparams = LUAU_INSN_B(insn) - 1; + int nresults = LUAU_INSN_C(insn) - 1; + + StkId argtop = L->top; + argtop = (nparams == LUA_MULTRET) ? argtop : ra + 1 + nparams; + + // slow-path: not a function call + if (LUAU_UNLIKELY(!ttisfunction(ra))) + { + VM_PROTECT(luau_tryfuncTM(L, ra)); + argtop++; // __call adds an extra self + } + + Closure* ccl = clvalue(ra); + L->ci->savedpc = pc; + + CallInfo* ci = incr_ci(L); + ci->func = ra; + ci->base = ra + 1; + ci->top = argtop + ccl->stacksize; // note: technically UB since we haven't reallocated the stack yet + ci->savedpc = NULL; + ci->flags = 0; + ci->nresults = nresults; + + L->base = ci->base; + L->top = argtop; + + // note: this reallocs stack, but we don't need to VM_PROTECT this + // this is because we're going to modify base/savedpc manually anyhow + // crucially, we can't use ra/argtop after this line + luaD_checkstack(L, ccl->stacksize); + + LUAU_ASSERT(ci->top <= L->stack_last); + + if (!ccl->isC) + { + Proto* p = ccl->l.p; + + // fill unused parameters with nil + StkId argi = L->top; + StkId argend = L->base + p->numparams; + while (argi < argend) + setnilvalue(argi++); /* complete missing arguments */ + L->top = p->is_vararg ? argi : ci->top; + + // reentry + pc = p->code; + cl = ccl; + base = L->base; + k = p->k; + VM_NEXT(); + } + else + { + lua_CFunction func = ccl->c.f; + int n = func(L); + + // yield + if (n < 0) + goto exit; + + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + StkId res = ci->func; + StkId vali = L->top - n; + StkId valend = L->top; + + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobjs2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? res : cip->top; + + base = L->base; // stack may have been reallocated, so we need to refresh base ptr + VM_NEXT(); + } + } + + VM_CASE(LOP_RETURN) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = &base[LUAU_INSN_A(insn)]; // note: this can point to L->top if b == LUA_MULTRET making VM_REG unsafe to use + int b = LUAU_INSN_B(insn) - 1; + + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + StkId res = ci->func; // note: we assume CALL always puts func+args and expects results to start at func + + StkId vali = ra; + StkId valend = + (b == LUA_MULTRET) ? L->top : ra + b; // copy as much as possible for MULTRET calls, and only as much as needed otherwise + + int nresults = ci->nresults; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // note: in MULTRET context nresults starts as -1 so i != 0 condition never activates intentionally + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobjs2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (nresults == LUA_MULTRET) ? res : cip->top; + + // we're done! + if (LUAU_UNLIKELY(ci->flags & LUA_CALLINFO_RETURN)) + { + L->top = res; + goto exit; + } + + LUAU_ASSERT(isLua(L->ci)); + + // reentry + pc = cip->savedpc; + cl = clvalue(cip->func); + base = L->base; + k = cl->l.p->k; + VM_NEXT(); + } + + VM_CASE(LOP_JUMP) + { + Instruction insn = *pc++; + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_JUMPIF) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + pc += l_isfalse(ra) ? 0 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_JUMPIFNOT) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + pc += l_isfalse(ra) ? LUAU_INSN_D(insn) : 0; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_JUMPIFEQ) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // Note that all jumps below jump by 1 in the "false" case to skip over aux + if (ttype(ra) == ttype(rb)) + { + switch (ttype(ra)) + { + case LUA_TNIL: + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TBOOLEAN: + pc += bvalue(ra) == bvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TLIGHTUSERDATA: + pc += pvalue(ra) == pvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TNUMBER: + pc += nvalue(ra) == nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TVECTOR: + pc += luai_veceq(vvalue(ra), vvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TSTRING: + case LUA_TFUNCTION: + case LUA_TTHREAD: + pc += gcvalue(ra) == gcvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TTABLE: + // fast-path: same metatable, no EQ metamethod + if (hvalue(ra)->metatable == hvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, hvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += hvalue(ra) == hvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + // slow path after switch() + break; + + case LUA_TUSERDATA: + // fast-path: same metatable, no EQ metamethod or C metamethod + if (uvalue(ra)->metatable == uvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, uvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += uvalue(ra) == uvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else if (ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, ra); + setobj2s(L, top + 2, rb); + int res = int(top - base); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, res)); + pc += !l_isfalse(&base[res]) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + // slow path after switch() + break; + + default:; + } + + // slow-path: tables with metatables and userdata values + // note that we don't have a fast path for userdata values without metatables, since that's very rare + int res; + VM_PROTECT(res = luaV_equalval(L, ra, rb)); + + pc += (res == 1) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + pc += 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_JUMPIFNOTEQ) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // Note that all jumps below jump by 1 in the "true" case to skip over aux + if (ttype(ra) == ttype(rb)) + { + switch (ttype(ra)) + { + case LUA_TNIL: + pc += 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TBOOLEAN: + pc += bvalue(ra) != bvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TLIGHTUSERDATA: + pc += pvalue(ra) != pvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TNUMBER: + pc += nvalue(ra) != nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TVECTOR: + pc += !luai_veceq(vvalue(ra), vvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TSTRING: + case LUA_TFUNCTION: + case LUA_TTHREAD: + pc += gcvalue(ra) != gcvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TTABLE: + // fast-path: same metatable, no EQ metamethod + if (hvalue(ra)->metatable == hvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, hvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += hvalue(ra) != hvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + // slow path after switch() + break; + + case LUA_TUSERDATA: + // fast-path: same metatable, no EQ metamethod or C metamethod + if (uvalue(ra)->metatable == uvalue(rb)->metatable) + { + const TValue* fn = fasttm(L, uvalue(ra)->metatable, TM_EQ); + + if (!fn) + { + pc += uvalue(ra) != uvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else if (ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, ra); + setobj2s(L, top + 2, rb); + int res = int(top - base); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, res)); + pc += l_isfalse(&base[res]) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + // slow path after switch() + break; + + default:; + } + + // slow-path: tables with metatables and userdata values + // note that we don't have a fast path for userdata values without metatables, since that's very rare + int res; + VM_PROTECT(res = luaV_equalval(L, ra, rb)); + + pc += (res == 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_JUMPIFLE) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "false" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += nvalue(ra) <= nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += luaV_strcmp(tsvalue(ra), tsvalue(rb)) <= 0 ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + int res; + VM_PROTECT(res = luaV_lessequal(L, ra, rb)); + + pc += (res == 1) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_JUMPIFNOTLE) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "true" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += !(nvalue(ra) <= nvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += !(luaV_strcmp(tsvalue(ra), tsvalue(rb)) <= 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + int res; + VM_PROTECT(res = luaV_lessequal(L, ra, rb)); + + pc += (res == 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_JUMPIFLT) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "false" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += nvalue(ra) < nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += luaV_strcmp(tsvalue(ra), tsvalue(rb)) < 0 ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + int res; + VM_PROTECT(res = luaV_lessthan(L, ra, rb)); + + pc += (res == 1) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_JUMPIFNOTLT) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(aux); + + // fast-path: number + // Note that all jumps below jump by 1 in the "true" case to skip over aux + if (ttisnumber(ra) && ttisnumber(rb)) + { + pc += !(nvalue(ra) < nvalue(rb)) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + // fast-path: string + else if (ttisstring(ra) && ttisstring(rb)) + { + pc += !(luaV_strcmp(tsvalue(ra), tsvalue(rb)) < 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + int res; + VM_PROTECT(res = luaV_lessthan(L, ra, rb)); + + pc += (res == 0) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_ADD) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) + nvalue(rc)); + VM_NEXT(); + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_ADD)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_ADD)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_SUB) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) - nvalue(rc)); + VM_NEXT(); + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_SUB)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_SUB)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_MUL) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) * nvalue(rc)); + VM_NEXT(); + } + else if (ttisvector(rb) && ttisnumber(rc)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(rc)); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + VM_NEXT(); + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + VM_NEXT(); + } + else if (ttisnumber(rb) && ttisvector(rc)) + { + float vb = cast_to(float, nvalue(rb)); + const float* vc = rc->value.v; + setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2]); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + StkId rbc = ttisnumber(rb) ? rc : rb; + const TValue* fn = 0; + if (ttisuserdata(rbc) && (fn = luaT_gettmbyobj(L, rbc, TM_MUL)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MUL)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_DIV) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, nvalue(rb) / nvalue(rc)); + VM_NEXT(); + } + else if (ttisvector(rb) && ttisnumber(rc)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(rc)); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + VM_NEXT(); + } + else if (ttisvector(rb) && ttisvector(rc)) + { + const float* vb = rb->value.v; + const float* vc = rc->value.v; + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + VM_NEXT(); + } + else if (ttisnumber(rb) && ttisvector(rc)) + { + float vb = cast_to(float, nvalue(rb)); + const float* vc = rc->value.v; + setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2]); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + StkId rbc = ttisnumber(rb) ? rc : rb; + const TValue* fn = 0; + if (ttisuserdata(rbc) && (fn = luaT_gettmbyobj(L, rbc, TM_DIV)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, rc); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_DIV)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_MOD) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + double nb = nvalue(rb); + double nc = nvalue(rc); + setnvalue(ra, luai_nummod(nb, nc)); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_MOD)); + VM_NEXT(); + } + } + + VM_CASE(LOP_POW) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb) && ttisnumber(rc)) + { + setnvalue(ra, pow(nvalue(rb), nvalue(rc))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rc, TM_POW)); + VM_NEXT(); + } + } + + VM_CASE(LOP_ADDK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) + nvalue(kv)); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_ADD)); + VM_NEXT(); + } + } + + VM_CASE(LOP_SUBK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) - nvalue(kv)); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_SUB)); + VM_NEXT(); + } + } + + VM_CASE(LOP_MULK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) * nvalue(kv)); + VM_NEXT(); + } + else if (ttisvector(rb)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(kv)); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_MUL)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MUL)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_DIVK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, nvalue(rb) / nvalue(kv)); + VM_NEXT(); + } + else if (ttisvector(rb)) + { + const float* vb = rb->value.v; + float vc = cast_to(float, nvalue(kv)); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_DIV)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 3 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + setobj2s(L, top + 2, kv); + L->top = top + 3; + + VM_PROTECT(luau_callTM(L, 2, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_DIV)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_MODK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + double nb = nvalue(rb); + double nk = nvalue(kv); + setnvalue(ra, luai_nummod(nb, nk)); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_MOD)); + VM_NEXT(); + } + } + + VM_CASE(LOP_POWK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + // fast-path + if (ttisnumber(rb)) + { + double nb = nvalue(rb); + double nk = nvalue(kv); + + // pow is very slow so we specialize this for ^2, ^0.5 and ^3 + double r = (nk == 2.0) ? nb * nb : (nk == 0.5) ? sqrt(nb) : (nk == 3.0) ? nb * nb * nb : pow(nb, nk); + + setnvalue(ra, r); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, kv, TM_POW)); + VM_NEXT(); + } + } + + VM_CASE(LOP_AND) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? rb : rc); + VM_NEXT(); + } + + VM_CASE(LOP_OR) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + StkId rc = VM_REG(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? rc : rb); + VM_NEXT(); + } + + VM_CASE(LOP_ANDK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? rb : kv); + VM_NEXT(); + } + + VM_CASE(LOP_ORK) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + TValue* kv = VM_KV(LUAU_INSN_C(insn)); + + setobj2s(L, ra, l_isfalse(rb) ? kv : rb); + VM_NEXT(); + } + + VM_CASE(LOP_CONCAT) + { + Instruction insn = *pc++; + int b = LUAU_INSN_B(insn); + int c = LUAU_INSN_C(insn); + + // This call may realloc the stack! So we need to query args further down + VM_PROTECT(luaV_concat(L, c - b + 1, c)); + + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + setobjs2s(L, ra, base + b); + VM_PROTECT(luaC_checkGC(L)); + VM_NEXT(); + } + + VM_CASE(LOP_NOT) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + int res = l_isfalse(rb); + setbvalue(ra, res); + VM_NEXT(); + } + + VM_CASE(LOP_MINUS) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + // fast-path + if (ttisnumber(rb)) + { + setnvalue(ra, -nvalue(rb)); + VM_NEXT(); + } + else if (ttisvector(rb)) + { + const float* vb = rb->value.v; + setvvalue(ra, -vb[0], -vb[1], -vb[2]); + VM_NEXT(); + } + else + { + // fast-path for userdata with C functions + const TValue* fn = 0; + if (ttisuserdata(rb) && (fn = luaT_gettmbyobj(L, rb, TM_UNM)) && ttisfunction(fn) && clvalue(fn)->isC) + { + // note: it's safe to push arguments past top for complicated reasons (see top of the file) + LUAU_ASSERT(L->top + 2 < L->stack + L->stacksize); + StkId top = L->top; + setobj2s(L, top + 0, fn); + setobj2s(L, top + 1, rb); + L->top = top + 2; + + VM_PROTECT(luau_callTM(L, 1, LUAU_INSN_A(insn))); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_doarith(L, ra, rb, rb, TM_UNM)); + VM_NEXT(); + } + } + } + + VM_CASE(LOP_LENGTH) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = VM_REG(LUAU_INSN_B(insn)); + + // fast-path #1: tables + if (ttistable(rb)) + { + setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + VM_NEXT(); + } + // fast-path #2: strings (not very important but easy to do) + else if (ttisstring(rb)) + { + setnvalue(ra, cast_num(tsvalue(rb)->len)); + VM_NEXT(); + } + else + { + // slow-path, may invoke C/Lua via metamethods + VM_PROTECT(luaV_dolen(L, ra, rb)); + VM_NEXT(); + } + } + + VM_CASE(LOP_NEWTABLE) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + int b = LUAU_INSN_B(insn); + uint32_t aux = *pc++; + + sethvalue(L, ra, luaH_new(L, aux, b == 0 ? 0 : (1 << (b - 1)))); + VM_PROTECT(luaC_checkGC(L)); + VM_NEXT(); + } + + VM_CASE(LOP_DUPTABLE) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + sethvalue(L, ra, luaH_clone(L, hvalue(kv))); + VM_PROTECT(luaC_checkGC(L)); + VM_NEXT(); + } + + VM_CASE(LOP_SETLIST) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + StkId rb = &base[LUAU_INSN_B(insn)]; // note: this can point to L->top if c == LUA_MULTRET making VM_REG unsafe to use + int c = LUAU_INSN_C(insn) - 1; + uint32_t index = *pc++; + + if (c == LUA_MULTRET) + { + c = int(L->top - rb); + L->top = L->ci->top; + } + + Table* h = hvalue(ra); + + if (!ttistable(ra)) + return; // temporary workaround to weaken a rather powerful exploitation primitive in case of a MITM attack on bytecode + + int last = index + c - 1; + if (last > h->sizearray) + luaH_resizearray(L, h, last); + + TValue* array = h->array; + + for (int i = 0; i < c; ++i) + setobj2t(L, &array[index + i - 1], rb + i); + + luaC_barrierfast(L, h); + VM_NEXT(); + } + + VM_CASE(LOP_FORNPREP) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + if (!ttisnumber(ra + 0) || !ttisnumber(ra + 1) || !ttisnumber(ra + 2)) + { + // slow-path: can convert arguments to numbers and trigger Lua errors + // Note: this doesn't reallocate stack so we don't need to recompute ra + VM_PROTECT(luau_prepareFORN(L, ra + 0, ra + 1, ra + 2)); + } + + double limit = nvalue(ra + 0); + double step = nvalue(ra + 1); + double idx = nvalue(ra + 2); + + // Note: make sure the loop condition is exactly the same between this and LOP_FORNLOOP so that we handle NaN/etc. consistently + pc += (step > 0 ? idx <= limit : limit <= idx) ? 0 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_FORNLOOP) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + LUAU_ASSERT(ttisnumber(ra + 0) && ttisnumber(ra + 1) && ttisnumber(ra + 2)); + + double limit = nvalue(ra + 0); + double step = nvalue(ra + 1); + double idx = nvalue(ra + 2) + step; + + setnvalue(ra + 2, idx); + + // Note: make sure the loop condition is exactly the same between this and LOP_FORNPREP so that we handle NaN/etc. consistently + if (step > 0 ? idx <= limit : limit <= idx) + { + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // fallthrough to exit + VM_NEXT(); + } + } + + VM_CASE(LOP_FORGLOOP) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + uint32_t aux = *pc; + + // note: this is a slow generic path, fast-path is FORGLOOP_INEXT/NEXT + bool stop; + VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), aux)); + + // note that we need to increment pc by 1 to exit the loop since we need to skip over aux + pc += stop ? 1 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_FORGPREP_INEXT) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + // fast-path: ipairs/inext + bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext; + if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_FORGLOOP_INEXT) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + // fast-path: ipairs/inext + if (ttistable(ra + 1) && ttislightuserdata(ra + 2)) + { + Table* h = hvalue(ra + 1); + int index = int(reinterpret_cast(pvalue(ra + 2))); + + // if 1-based index of the last iteration is in bounds, this means 0-based index of the current iteration is in bounds + if (unsigned(index) < unsigned(h->sizearray)) + { + // note that nil elements inside the array terminate the traversal + if (!ttisnil(&h->array[index])) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, &h->array[index]); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // fallthrough to exit + VM_NEXT(); + } + } + else if (h->lsizenode == 0 && ttisnil(gval(h->node))) + { + // hash part is empty: fallthrough to exit + VM_NEXT(); + } + else + { + // the table has a hash part; index + 1 may appear in it in which case we need to iterate through the hash portion as well + const TValue* val = luaH_getnum(h, index + 1); + + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, val); + + // note that nil elements inside the array terminate the traversal + pc += ttisnil(ra + 4) ? 0 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + else + { + // slow-path; can call Lua/C generators + bool stop; + VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), 2)); + + pc += stop ? 0 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_FORGPREP_NEXT) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + // fast-path: pairs/next + bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next; + if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); + } + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_FORGLOOP_NEXT) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + // fast-path: pairs/next + if (ttistable(ra + 1) && ttislightuserdata(ra + 2)) + { + Table* h = hvalue(ra + 1); + int index = int(reinterpret_cast(pvalue(ra + 2))); + + int sizearray = h->sizearray; + int sizenode = 1 << h->lsizenode; + + // first we advance index through the array portion + while (unsigned(index) < unsigned(sizearray)) + { + if (!ttisnil(&h->array[index])) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + setnvalue(ra + 3, double(index + 1)); + setobj2s(L, ra + 4, &h->array[index]); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // then we advance index through the hash portion + while (unsigned(index - sizearray) < unsigned(sizenode)) + { + LuaNode* n = &h->node[index - sizearray]; + + if (!ttisnil(gval(n))) + { + setpvalue(ra + 2, reinterpret_cast(uintptr_t(index + 1))); + getnodekey(L, ra + 3, n); + setobj2s(L, ra + 4, gval(n)); + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + index++; + } + + // fallthrough to exit + VM_NEXT(); + } + else + { + // slow-path; can call Lua/C generators + bool stop; + VM_PROTECT(stop = luau_loopFORG(L, LUAU_INSN_A(insn), 2)); + + pc += stop ? 0 : LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_GETVARARGS) + { + Instruction insn = *pc++; + int b = LUAU_INSN_B(insn) - 1; + int n = cast_int(base - L->ci->func) - cl->l.p->numparams - 1; + + if (b == LUA_MULTRET) + { + VM_PROTECT(luaD_checkstack(L, n)); + StkId ra = VM_REG(LUAU_INSN_A(insn)); // previous call may change the stack + + for (int j = 0; j < n; j++) + setobjs2s(L, ra + j, base - n + j); + + L->top = ra + n; + VM_NEXT(); + } + else + { + StkId ra = VM_REG(LUAU_INSN_A(insn)); + + for (int j = 0; j < b && j < n; j++) + setobjs2s(L, ra + j, base - n + j); + for (int j = n; j < b; j++) + setnilvalue(ra + j); + VM_NEXT(); + } + } + + VM_CASE(LOP_DUPCLOSURE) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* kv = VM_KV(LUAU_INSN_D(insn)); + + Closure* kcl = clvalue(kv); + + // clone closure if the environment is not shared + // note: we save closure to stack early in case the code below wants to capture it by value + Closure* ncl = (kcl->env == cl->env) ? kcl : luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + // this loop does three things: + // - if the closure was created anew, it just fills it with upvalues + // - if the closure from the constant table is used, it fills it with upvalues so that it can be shared in the future + // - if the closure is reused, it checks if the reuse is safe via rawequal, and falls back to duplicating the closure + // normally this would use two separate loops, for reuse check and upvalue setup, but MSVC codegen goes crazy if you do that + for (int ui = 0; ui < kcl->nupvalues; ++ui) + { + Instruction uinsn = pc[ui]; + LUAU_ASSERT(LUAU_INSN_OP(uinsn) == LOP_CAPTURE); + LUAU_ASSERT(LUAU_INSN_A(uinsn) == LCT_VAL || LUAU_INSN_A(uinsn) == LCT_UPVAL); + + TValue* uv = (LUAU_INSN_A(uinsn) == LCT_VAL) ? VM_REG(LUAU_INSN_B(uinsn)) : VM_UV(LUAU_INSN_B(uinsn)); + + // check if the existing closure is safe to reuse + if (ncl == kcl && luaO_rawequalObj(&ncl->l.uprefs[ui], uv)) + continue; + + // lazily clone the closure and update the upvalues + if (ncl == kcl && kcl->preload == 0) + { + ncl = luaF_newLclosure(L, kcl->nupvalues, cl->env, kcl->l.p); + setclvalue(L, ra, ncl); + + ui = -1; // restart the loop to fill all upvalues + continue; + } + + // this updates a newly created closure, or an existing closure created during preload, in which case we need a barrier + setobj(L, &ncl->l.uprefs[ui], uv); + luaC_barrier(L, ncl, uv); + } + + // this is a noop if ncl is newly created or shared successfully, but it has to run after the closure is preloaded for the first time + ncl->preload = 0; + + if (kcl != ncl) + VM_PROTECT(luaC_checkGC(L)); + + pc += kcl->nupvalues; + VM_NEXT(); + } + + VM_CASE(LOP_PREPVARARGS) + { + Instruction insn = *pc++; + int numparams = LUAU_INSN_A(insn); + + // all fixed parameters are copied after the top so we need more stack space + VM_PROTECT(luaD_checkstack(L, cl->stacksize + numparams)); + + // the caller must have filled extra fixed arguments with nil + LUAU_ASSERT(cast_int(L->top - base) >= numparams); + + // move fixed parameters to final position + StkId fixed = base; /* first fixed argument */ + base = L->top; /* final position of first argument */ + + for (int i = 0; i < numparams; ++i) + { + setobjs2s(L, base + i, fixed + i); + setnilvalue(fixed + i); + } + + // rewire our stack frame to point to the new base + L->ci->base = base; + L->ci->top = base + cl->stacksize; + + L->base = base; + L->top = L->ci->top; + VM_NEXT(); + } + + VM_CASE(LOP_JUMPBACK) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_LOADKX) + { + Instruction insn = *pc++; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + uint32_t aux = *pc++; + TValue* kv = VM_KV(aux); + + setobj2s(L, ra, kv); + VM_NEXT(); + } + + VM_CASE(LOP_JUMPX) + { + VM_INTERRUPT(); + Instruction insn = *pc++; + + pc += LUAU_INSN_E(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + + VM_CASE(LOP_FASTCALL) + { + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = LUAU_INSN_B(call) - 1; + int nresults = LUAU_INSN_C(call) - 1; + + nparams = (nparams == LUA_MULTRET) ? int(L->top - ra - 1) : nparams; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, ra + 1, nresults, ra + 2, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + + VM_CASE(LOP_COVERAGE) + { + Instruction insn = *pc++; + int hits = LUAU_INSN_E(insn); + + // update hits with saturated add and patch the instruction in place + hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; + ((uint32_t*)pc)[-1] = LOP_COVERAGE | (uint32_t(hits) << 8); + + VM_NEXT(); + } + + VM_CASE(LOP_CAPTURE) + { + LUAU_ASSERT(!"CAPTURE is a pseudo-opcode and must be executed as part of NEWCLOSURE"); + LUAU_UNREACHABLE(); + } + + VM_CASE(LOP_JUMPIFEQK) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* rb = VM_KV(aux); + + // Note that all jumps below jump by 1 in the "false" case to skip over aux + if (ttype(ra) == ttype(rb)) + { + switch (ttype(ra)) + { + case LUA_TNIL: + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TBOOLEAN: + pc += bvalue(ra) == bvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TNUMBER: + pc += nvalue(ra) == nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TSTRING: + pc += gcvalue(ra) == gcvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + default:; + } + + LUAU_ASSERT(!"Constant is expected to be of primitive type"); + } + else + { + pc += 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_JUMPIFNOTEQK) + { + Instruction insn = *pc++; + uint32_t aux = *pc; + StkId ra = VM_REG(LUAU_INSN_A(insn)); + TValue* rb = VM_KV(aux); + + // Note that all jumps below jump by 1 in the "true" case to skip over aux + if (ttype(ra) == ttype(rb)) + { + switch (ttype(ra)) + { + case LUA_TNIL: + pc += 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TBOOLEAN: + pc += bvalue(ra) != bvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TNUMBER: + pc += nvalue(ra) != nvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + case LUA_TSTRING: + pc += gcvalue(ra) != gcvalue(rb) ? LUAU_INSN_D(insn) : 1; + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + + default:; + } + + LUAU_ASSERT(!"Constant is expected to be of primitive type"); + } + else + { + pc += LUAU_INSN_D(insn); + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + } + + VM_CASE(LOP_FASTCALL1) + { + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + TValue* arg = VM_REG(LUAU_INSN_B(insn)); + int skip = LUAU_INSN_C(insn); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 1; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, arg, nresults, nullptr, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + + VM_CASE(LOP_FASTCALL2) + { + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn) - 1; + uint32_t aux = *pc++; + TValue* arg1 = VM_REG(LUAU_INSN_B(insn)); + TValue* arg2 = VM_REG(aux); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 2; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, arg1, nresults, arg2, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + + VM_CASE(LOP_FASTCALL2K) + { + Instruction insn = *pc++; + int bfid = LUAU_INSN_A(insn); + int skip = LUAU_INSN_C(insn) - 1; + uint32_t aux = *pc++; + TValue* arg1 = VM_REG(LUAU_INSN_B(insn)); + TValue* arg2 = VM_KV(aux); + + LUAU_ASSERT(unsigned(pc - cl->l.p->code + skip) < unsigned(cl->l.p->sizecode)); + + Instruction call = pc[skip]; + LUAU_ASSERT(LUAU_INSN_OP(call) == LOP_CALL); + + StkId ra = VM_REG(LUAU_INSN_A(call)); + + int nparams = 2; + int nresults = LUAU_INSN_C(call) - 1; + + luau_FastFunction f = luauF_table[bfid]; + + if (cl->env->safeenv && f) + { + VM_PROTECT_PC(); + + int n = f(L, ra, arg1, nresults, arg2, nparams); + + if (n >= 0) + { + L->top = (nresults == LUA_MULTRET) ? ra + n : L->ci->top; + + pc += skip + 1; // skip instructions that compute function as well as CALL + LUAU_ASSERT(unsigned(pc - cl->l.p->code) < unsigned(cl->l.p->sizecode)); + VM_NEXT(); + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + else + { + // continue execution through the fallback code + VM_NEXT(); + } + } + + VM_CASE(LOP_BREAK) + { + LUAU_ASSERT(cl->l.p->debuginsn); + + uint8_t op = cl->l.p->debuginsn[unsigned(pc - cl->l.p->code)]; + LUAU_ASSERT(op != LOP_BREAK); + + if (L->global->cb.debugbreak) + { + VM_PROTECT(luau_callhook(L, L->global->cb.debugbreak, NULL)); + + // allow debugbreak hook to put thread into error/yield state + if (L->status != 0) + goto exit; + } + + VM_CONTINUE(op); + } + +#if !VM_USE_CGOTO + default: + LUAU_ASSERT(!"Unknown opcode"); + LUAU_UNREACHABLE(); // improves switch() codegen by eliding opcode bounds checks +#endif + } + } + +exit:; +} + +void luau_execute(lua_State* L) +{ + if (L->singlestep) + luau_execute(L); + else + luau_execute(L); +} + +int luau_precall(lua_State* L, StkId func, int nresults) +{ + if (!ttisfunction(func)) + { + luau_tryfuncTM(L, func); + // L->top is incremented by tryfuncTM + } + + Closure* ccl = clvalue(func); + + CallInfo* ci = incr_ci(L); + ci->func = func; + ci->base = func + 1; + ci->top = L->top + ccl->stacksize; + ci->savedpc = NULL; + ci->flags = 0; + ci->nresults = nresults; + + L->base = ci->base; + // Note: L->top is assigned externally + + luaD_checkstack(L, ccl->stacksize); + LUAU_ASSERT(ci->top <= L->stack_last); + + if (!ccl->isC) + { + // fill unused parameters with nil + StkId argi = L->top; + StkId argend = L->base + ccl->l.p->numparams; + while (argi < argend) + setnilvalue(argi++); /* complete missing arguments */ + L->top = ccl->l.p->is_vararg ? argi : ci->top; + + L->ci->savedpc = ccl->l.p->code; + + return PCRLUA; + } + else + { + lua_CFunction func = ccl->c.f; + int n = func(L); + + // yield + if (n < 0) + return PCRYIELD; + + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // TODO: it might be worthwhile to handle the case when nresults==b explicitly? + StkId res = ci->func; + StkId vali = L->top - n; + StkId valend = L->top; + + int i; + for (i = nresults; i != 0 && vali < valend; i--) + setobjs2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = res; + + return PCRC; + } +} + +void luau_poscall(lua_State* L, StkId first) +{ + // finish interrupted execution of `OP_CALL' + // ci is our callinfo, cip is our parent + CallInfo* ci = L->ci; + CallInfo* cip = ci - 1; + + // copy return values into parent stack (but only up to nresults!), fill the rest with nil + // TODO: it might be worthwhile to handle the case when nresults==b explicitly? + StkId res = ci->func; + StkId vali = first; + StkId valend = L->top; + + int i; + for (i = ci->nresults; i != 0 && vali < valend; i--) + setobjs2s(L, res++, vali++); + while (i-- > 0) + setnilvalue(res++); + + // pop the stack frame + L->ci = cip; + L->base = cip->base; + L->top = (ci->nresults == LUA_MULTRET) ? res : cip->top; +} diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp new file mode 100644 index 0000000..0a23234 --- /dev/null +++ b/VM/src/lvmload.cpp @@ -0,0 +1,330 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lvm.h" + +#include "lstate.h" +#include "ltable.h" +#include "lfunc.h" +#include "lstring.h" +#include "lgc.h" +#include "lmem.h" +#include "lbytecode.h" + +#include + +#include + +void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) +{ + int count = id >> 30; + int id0 = count > 0 ? int(id >> 20) & 1023 : -1; + int id1 = count > 1 ? int(id >> 10) & 1023 : -1; + int id2 = count > 2 ? int(id) & 1023 : -1; + + // allocate a stack slot so that we can do table lookups + luaD_checkstack(L, 1); + setnilvalue(L->top); + L->top++; + + // global lookup into L->top-1 + TValue g; + sethvalue(L, &g, env); + luaV_gettable(L, &g, &k[id0], L->top - 1); + + // table lookup for id1 + if (id1 >= 0 && (!propagatenil || !ttisnil(L->top - 1))) + luaV_gettable(L, L->top - 1, &k[id1], L->top - 1); + + // table lookup for id2 + if (id2 >= 0 && (!propagatenil || !ttisnil(L->top - 1))) + luaV_gettable(L, L->top - 1, &k[id2], L->top - 1); +} + +template +static T read(const char* data, size_t size, size_t& offset) +{ + T result; + memcpy(&result, data + offset, sizeof(T)); + offset += sizeof(T); + + return result; +} + +static unsigned int readVarInt(const char* data, size_t size, size_t& offset) +{ + unsigned int result = 0; + unsigned int shift = 0; + + uint8_t byte; + + do + { + byte = read(data, size, offset); + result |= (byte & 127) << shift; + shift += 7; + } while (byte & 128); + + return result; +} + +static TString* readString(std::vector& strings, const char* data, size_t size, size_t& offset) +{ + unsigned int id = readVarInt(data, size, offset); + + return id == 0 ? NULL : strings[id - 1]; +} + +static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) +{ + struct ResolveImport + { + TValue* k; + uint32_t id; + + static void run(lua_State* L, void* ud) + { + ResolveImport* self = static_cast(ud); + + // note: we call getimport with nil propagation which means that accesses to table chains like A.B.C will resolve in nil + // this is technically not necessary but it reduces the number of exceptions when loading scripts that rely on getfenv/setfenv for global + // injection + luaV_getimport(L, hvalue(gt(L)), self->k, self->id, /* propagatenil= */ true); + } + }; + + ResolveImport ri = {k, id}; + if (hvalue(gt(L))->safeenv) + { + // luaD_pcall will make sure that if any C/Lua calls during import resolution fail, the thread state is restored back + int oldTop = lua_gettop(L); + int status = luaD_pcall(L, &ResolveImport::run, &ri, savestack(L, L->top), 0); + LUAU_ASSERT(oldTop + 1 == lua_gettop(L)); // if an error occurred, luaD_pcall saves it on stack + + if (status != 0) + { + // replace error object with nil + setnilvalue(L->top - 1); + } + } + else + { + setnilvalue(L->top); + L->top++; + } +} + +int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env) +{ + size_t offset = 0; + + uint8_t version = read(data, size, offset); + + // 0 means the rest of the bytecode is the error message + if (version == 0 || version != LBC_VERSION) + { + char chunkid[LUA_IDSIZE]; + luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + + if (version == 0) + lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); + else + lua_pushfstring(L, "%s: bytecode version mismatch", chunkid); + return 1; + } + + // pause GC for the duration of deserialization - some objects we're creating aren't rooted + size_t GCthreshold = L->global->GCthreshold; + L->global->GCthreshold = SIZE_MAX; + + // env is 0 for current environment and a stack relative index otherwise + LUAU_ASSERT(env <= 0 && L->top - L->base >= -env); + Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(L->top + env); + + TString* source = luaS_new(L, chunkname); + + // string table + unsigned int stringCount = readVarInt(data, size, offset); + std::vector strings(stringCount); + + for (unsigned int i = 0; i < stringCount; ++i) + { + unsigned int length = readVarInt(data, size, offset); + + strings[i] = luaS_newlstr(L, data + offset, length); + offset += length; + } + + // proto table + unsigned int protoCount = readVarInt(data, size, offset); + std::vector protos(protoCount); + + for (unsigned int i = 0; i < protoCount; ++i) + { + Proto* p = luaF_newproto(L); + p->source = source; + + p->maxstacksize = read(data, size, offset); + p->numparams = read(data, size, offset); + p->nups = read(data, size, offset); + p->is_vararg = read(data, size, offset); + + p->sizecode = readVarInt(data, size, offset); + p->code = luaM_newarray(L, p->sizecode, Instruction, p->memcat); + for (int j = 0; j < p->sizecode; ++j) + p->code[j] = read(data, size, offset); + + p->sizek = readVarInt(data, size, offset); + p->k = luaM_newarray(L, p->sizek, TValue, p->memcat); + +#ifdef HARDMEMTESTS + // this is redundant during normal runs, but resolveImportSafe can trigger GC checks under HARDMEMTESTS + // because p->k isn't fully formed at this point, we pre-fill it with nil to make subsequent setup safe + for (int j = 0; j < p->sizek; ++j) + { + setnilvalue(&p->k[j]); + } +#endif + + for (int j = 0; j < p->sizek; ++j) + { + switch (read(data, size, offset)) + { + case LBC_CONSTANT_NIL: + setnilvalue(&p->k[j]); + break; + + case LBC_CONSTANT_BOOLEAN: + { + uint8_t v = read(data, size, offset); + setbvalue(&p->k[j], v); + break; + } + + case LBC_CONSTANT_NUMBER: + { + double v = read(data, size, offset); + setnvalue(&p->k[j], v); + break; + } + + case LBC_CONSTANT_STRING: + { + TString* v = readString(strings, data, size, offset); + setsvalue2n(L, &p->k[j], v); + break; + } + + case LBC_CONSTANT_IMPORT: + { + uint32_t iid = read(data, size, offset); + resolveImportSafe(L, envt, p->k, iid); + setobj(L, &p->k[j], L->top - 1); + L->top--; + break; + } + + case LBC_CONSTANT_TABLE: + { + int keys = readVarInt(data, size, offset); + Table* h = luaH_new(L, 0, keys); + for (int i = 0; i < keys; ++i) + { + int key = readVarInt(data, size, offset); + TValue* val = luaH_set(L, h, &p->k[key]); + setnvalue(val, 0.0); + } + sethvalue(L, &p->k[j], h); + break; + } + + case LBC_CONSTANT_CLOSURE: + { + uint32_t fid = readVarInt(data, size, offset); + Closure* cl = luaF_newLclosure(L, protos[fid]->nups, envt, protos[fid]); + cl->preload = (cl->nupvalues > 0); + setclvalue(L, &p->k[j], cl); + break; + } + + default: + LUAU_ASSERT(!"Unexpected constant kind"); + } + } + + p->sizep = readVarInt(data, size, offset); + p->p = luaM_newarray(L, p->sizep, Proto*, p->memcat); + for (int j = 0; j < p->sizep; ++j) + { + uint32_t fid = readVarInt(data, size, offset); + p->p[j] = protos[fid]; + } + + p->debugname = readString(strings, data, size, offset); + + uint8_t lineinfo = read(data, size, offset); + + if (lineinfo) + { + p->linegaplog2 = read(data, size, offset); + + int intervals = ((p->sizecode - 1) >> p->linegaplog2) + 1; + int absoffset = (p->sizecode + 3) & ~3; + + p->sizelineinfo = absoffset + intervals * sizeof(int); + p->lineinfo = luaM_newarray(L, p->sizelineinfo, uint8_t, p->memcat); + p->abslineinfo = (int*)(p->lineinfo + absoffset); + + uint8_t lastoffset = 0; + for (int j = 0; j < p->sizecode; ++j) + { + lastoffset += read(data, size, offset); + p->lineinfo[j] = lastoffset; + } + + int lastLine = 0; + for (int j = 0; j < intervals; ++j) + { + lastLine += read(data, size, offset); + p->abslineinfo[j] = lastLine; + } + } + + uint8_t debuginfo = read(data, size, offset); + + if (debuginfo) + { + p->sizelocvars = readVarInt(data, size, offset); + p->locvars = luaM_newarray(L, p->sizelocvars, LocVar, p->memcat); + + for (int j = 0; j < p->sizelocvars; ++j) + { + p->locvars[j].varname = readString(strings, data, size, offset); + p->locvars[j].startpc = readVarInt(data, size, offset); + p->locvars[j].endpc = readVarInt(data, size, offset); + p->locvars[j].reg = read(data, size, offset); + } + + p->sizeupvalues = readVarInt(data, size, offset); + p->upvalues = luaM_newarray(L, p->sizeupvalues, TString*, p->memcat); + + for (int j = 0; j < p->sizeupvalues; ++j) + { + p->upvalues[j] = readString(strings, data, size, offset); + } + } + + protos[i] = p; + } + + // "main" proto is pushed to Lua stack + uint32_t mainid = readVarInt(data, size, offset); + Proto* main = protos[mainid]; + + Closure* cl = luaF_newLclosure(L, 0, envt, main); + setclvalue(L, L->top, cl); + incr_top(L); + + L->global->GCthreshold = GCthreshold; + + return 0; +} diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp new file mode 100644 index 0000000..f52e8e7 --- /dev/null +++ b/VM/src/lvmutils.cpp @@ -0,0 +1,492 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lvm.h" + +#include "lstate.h" +#include "lstring.h" +#include "ltable.h" +#include "lgc.h" +#include "ldo.h" +#include "lnumutils.h" + +#include + +/* limit for table tag-method chains (to avoid loops) */ +#define MAXTAGLOOP 100 + +const TValue* luaV_tonumber(const TValue* obj, TValue* n) +{ + double num; + if (ttisnumber(obj)) + return obj; + if (ttisstring(obj) && luaO_str2d(svalue(obj), &num)) + { + setnvalue(n, num); + return n; + } + else + return NULL; +} + +int luaV_tostring(lua_State* L, StkId obj) +{ + if (!ttisnumber(obj)) + return 0; + else + { + char s[LUAI_MAXNUMBER2STR]; + double n = nvalue(obj); + luai_num2str(s, n); + setsvalue2s(L, obj, luaS_new(L, s)); + return 1; + } +} + +const float* luaV_tovector(const TValue* obj) +{ + if (ttisvector(obj)) + return obj->value.v; + + return nullptr; +} + +static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) +{ + ptrdiff_t result = savestack(L, res); + // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated + // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua + // stack and checkstack may invalidate those pointers + // * we cannot use savestack/restorestack because the arguments are sometimes on the C++ stack + // * during stack reallocation all of the allocated stack is copied (even beyond stack_last) so these + // values will be preserved even if they go past stack_last + LUAU_ASSERT((L->top + 3) < (L->stack + L->stacksize)); + setobj2s(L, L->top, f); /* push function */ + setobj2s(L, L->top + 1, p1); /* 1st argument */ + setobj2s(L, L->top + 2, p2); /* 2nd argument */ + luaD_checkstack(L, 3); + L->top += 3; + luaD_call(L, L->top - 3, 1); + res = restorestack(L, result); + L->top--; + setobjs2s(L, res, L->top); +} + +static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) +{ + // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated + // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua + // stack and checkstack may invalidate those pointers + // * we cannot use savestack/restorestack because the arguments are sometimes on the C++ stack + // * during stack reallocation all of the allocated stack is copied (even beyond stack_last) so these + // values will be preserved even if they go past stack_last + LUAU_ASSERT((L->top + 4) < (L->stack + L->stacksize)); + setobj2s(L, L->top, f); /* push function */ + setobj2s(L, L->top + 1, p1); /* 1st argument */ + setobj2s(L, L->top + 2, p2); /* 2nd argument */ + setobj2s(L, L->top + 3, p3); /* 3th argument */ + luaD_checkstack(L, 4); + L->top += 4; + luaD_call(L, L->top - 4, 0); +} + +void luaV_gettable(lua_State* L, const TValue* t, TValue* key, StkId val) +{ + int loop; + for (loop = 0; loop < MAXTAGLOOP; loop++) + { + const TValue* tm; + if (ttistable(t)) + { /* `t' is a table? */ + Table* h = hvalue(t); + + const TValue* res = luaH_get(h, key); /* do a primitive get */ + + if (res != luaO_nilobject) + L->cachedslot = gval2slot(h, res); /* remember slot to accelerate future lookups */ + + if (!ttisnil(res) /* result is no nil? */ + || (tm = fasttm(L, h->metatable, TM_INDEX)) == NULL) + { /* or no TM? */ + setobj2s(L, val, res); + return; + } + /* t isn't a table, so see if it has an INDEX meta-method to look up the key with */ + } + else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_INDEX))) + luaG_indexerror(L, t, key); + if (ttisfunction(tm)) + { + callTMres(L, val, tm, t, key); + return; + } + t = tm; /* else repeat with `tm' */ + } + luaG_runerror(L, "loop in gettable"); +} + +void luaV_settable(lua_State* L, const TValue* t, TValue* key, StkId val) +{ + int loop; + TValue temp; + for (loop = 0; loop < MAXTAGLOOP; loop++) + { + const TValue* tm; + if (ttistable(t)) + { /* `t' is a table? */ + Table* h = hvalue(t); + + if (h->readonly) + luaG_runerror(L, "Attempt to modify a readonly table"); + + TValue* oldval = luaH_set(L, h, key); /* do a primitive set */ + + L->cachedslot = gval2slot(h, oldval); /* remember slot to accelerate future lookups */ + + if (!ttisnil(oldval) || /* result is no nil? */ + (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) + { /* or no TM? */ + setobj2t(L, oldval, val); + luaC_barriert(L, h, val); + return; + } + /* else will try the tag method */ + } + else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_NEWINDEX))) + luaG_indexerror(L, t, key); + if (ttisfunction(tm)) + { + callTM(L, tm, t, key, val); + return; + } + /* else repeat with `tm' */ + setobj(L, &temp, tm); /* avoid pointing inside table (may rehash) */ + t = &temp; + } + luaG_runerror(L, "loop in settable"); +} + +static int call_binTM(lua_State* L, const TValue* p1, const TValue* p2, StkId res, TMS event) +{ + const TValue* tm = luaT_gettmbyobj(L, p1, event); /* try first operand */ + if (ttisnil(tm)) + tm = luaT_gettmbyobj(L, p2, event); /* try second operand */ + if (ttisnil(tm)) + return 0; + callTMres(L, res, tm, p1, p2); + return 1; +} + +static const TValue* get_compTM(lua_State* L, Table* mt1, Table* mt2, TMS event) +{ + const TValue* tm1 = fasttm(L, mt1, event); + const TValue* tm2; + if (tm1 == NULL) + return NULL; /* no metamethod */ + if (mt1 == mt2) + return tm1; /* same metatables => same metamethods */ + tm2 = fasttm(L, mt2, event); + if (tm2 == NULL) + return NULL; /* no metamethod */ + if (luaO_rawequalObj(tm1, tm2)) /* same metamethods? */ + return tm1; + return NULL; +} + +static int call_orderTM(lua_State* L, const TValue* p1, const TValue* p2, TMS event) +{ + const TValue* tm1 = luaT_gettmbyobj(L, p1, event); + const TValue* tm2; + if (ttisnil(tm1)) + return -1; /* no metamethod? */ + tm2 = luaT_gettmbyobj(L, p2, event); + if (!luaO_rawequalObj(tm1, tm2)) /* different metamethods? */ + return -1; + callTMres(L, L->top, tm1, p1, p2); + return !l_isfalse(L->top); +} + +int luaV_strcmp(const TString* ls, const TString* rs) +{ + if (ls == rs) + return 0; + + const char* l = getstr(ls); + size_t ll = ls->len; + const char* r = getstr(rs); + size_t lr = rs->len; + size_t lmin = ll < lr ? ll : lr; + + int res = memcmp(l, r, lmin); + if (res != 0) + return res; + + return ll == lr ? 0 : ll < lr ? -1 : 1; +} + +int luaV_lessthan(lua_State* L, const TValue* l, const TValue* r) +{ + int res; + if (ttype(l) != ttype(r)) + luaG_ordererror(L, l, r, TM_LT); + else if (ttisnumber(l)) + return luai_numlt(nvalue(l), nvalue(r)); + else if (ttisstring(l)) + return luaV_strcmp(tsvalue(l), tsvalue(r)) < 0; + else if ((res = call_orderTM(L, l, r, TM_LT)) == -1) + luaG_ordererror(L, l, r, TM_LT); + return res; +} + +int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r) +{ + int res; + if (ttype(l) != ttype(r)) + luaG_ordererror(L, l, r, TM_LE); + else if (ttisnumber(l)) + return luai_numle(nvalue(l), nvalue(r)); + else if (ttisstring(l)) + return luaV_strcmp(tsvalue(l), tsvalue(r)) <= 0; + else if ((res = call_orderTM(L, l, r, TM_LE)) != -1) /* first try `le' */ + return res; + else if ((res = call_orderTM(L, r, l, TM_LT)) == -1) /* error if not `lt' */ + luaG_ordererror(L, l, r, TM_LE); + return !res; +} + +int luaV_equalval(lua_State* L, const TValue* t1, const TValue* t2) +{ + const TValue* tm; + LUAU_ASSERT(ttype(t1) == ttype(t2)); + switch (ttype(t1)) + { + case LUA_TNIL: + return 1; + case LUA_TNUMBER: + return luai_numeq(nvalue(t1), nvalue(t2)); + case LUA_TVECTOR: + return luai_veceq(vvalue(t1), vvalue(t2)); + case LUA_TBOOLEAN: + return bvalue(t1) == bvalue(t2); /* true must be 1 !! */ + case LUA_TLIGHTUSERDATA: + return pvalue(t1) == pvalue(t2); + case LUA_TUSERDATA: + { + tm = get_compTM(L, uvalue(t1)->metatable, uvalue(t2)->metatable, TM_EQ); + if (!tm) + return uvalue(t1) == uvalue(t2); + break; /* will try TM */ + } + case LUA_TTABLE: + { + tm = get_compTM(L, hvalue(t1)->metatable, hvalue(t2)->metatable, TM_EQ); + if (!tm) + return hvalue(t1) == hvalue(t2); + break; /* will try TM */ + } + default: + return gcvalue(t1) == gcvalue(t2); + } + callTMres(L, L->top, tm, t1, t2); /* call TM */ + return !l_isfalse(L->top); +} + +void luaV_concat(lua_State* L, int total, int last) +{ + do + { + StkId top = L->base + last + 1; + int n = 2; /* number of elements handled in this pass (at least 2) */ + if (!(ttisstring(top - 2) || ttisnumber(top - 2)) || !tostring(L, top - 1)) + { + if (!call_binTM(L, top - 2, top - 1, top - 2, TM_CONCAT)) + luaG_concaterror(L, top - 2, top - 1); + } + else if (tsvalue(top - 1)->len == 0) /* second op is empty? */ + (void)tostring(L, top - 2); /* result is first op (as string) */ + else + { + /* at least two string values; get as many as possible */ + size_t tl = tsvalue(top - 1)->len; + char* buffer; + int i; + /* collect total length */ + for (n = 1; n < total && tostring(L, top - n - 1); n++) + { + size_t l = tsvalue(top - n - 1)->len; + if (l > MAXSSIZE - tl) + luaG_runerror(L, "string length overflow"); + tl += l; + } + + char buf[LUA_BUFFERSIZE]; + TString* ts = nullptr; + + if (tl < LUA_BUFFERSIZE) + { + buffer = buf; + } + else + { + ts = luaS_bufstart(L, tl); + buffer = ts->data; + } + + tl = 0; + for (i = n; i > 0; i--) + { /* concat all strings */ + size_t l = tsvalue(top - i)->len; + memcpy(buffer + tl, svalue(top - i), l); + tl += l; + } + + if (tl < LUA_BUFFERSIZE) + { + setsvalue2s(L, top - n, luaS_newlstr(L, buffer, tl)); + } + else + { + setsvalue2s(L, top - n, luaS_buffinish(L, ts)); + } + } + total -= n - 1; /* got `n' strings to create 1 new */ + last -= n - 1; + } while (total > 1); /* repeat until only 1 result left */ +} + +void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TMS op) +{ + TValue tempb, tempc; + const TValue *b, *c; + if ((b = luaV_tonumber(rb, &tempb)) != NULL && (c = luaV_tonumber(rc, &tempc)) != NULL) + { + double nb = nvalue(b), nc = nvalue(c); + switch (op) + { + case TM_ADD: + setnvalue(ra, luai_numadd(nb, nc)); + break; + case TM_SUB: + setnvalue(ra, luai_numsub(nb, nc)); + break; + case TM_MUL: + setnvalue(ra, luai_nummul(nb, nc)); + break; + case TM_DIV: + setnvalue(ra, luai_numdiv(nb, nc)); + break; + case TM_MOD: + setnvalue(ra, luai_nummod(nb, nc)); + break; + case TM_POW: + setnvalue(ra, luai_numpow(nb, nc)); + break; + case TM_UNM: + setnvalue(ra, luai_numunm(nb)); + break; + default: + LUAU_ASSERT(0); + break; + } + } + else + { + // vector operations that we support: v + v, v - v, v * v, s * v, v * s, v / v, s / v, v / s, -v + const float* vb = luaV_tovector(rb); + const float* vc = luaV_tovector(rc); + + if (vb && vc) + { + switch (op) + { + case TM_ADD: + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + return; + case TM_SUB: + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + return; + case TM_MUL: + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + return; + case TM_DIV: + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + return; + case TM_UNM: + setvvalue(ra, -vb[0], -vb[1], -vb[2]); + return; + default: + break; + } + } + else if (vb) + { + c = luaV_tonumber(rc, &tempc); + + if (c) + { + float nc = cast_to(float, nvalue(c)); + + switch (op) + { + case TM_MUL: + setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc); + return; + case TM_DIV: + setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc); + return; + default: + break; + } + } + } + else if (vc) + { + b = luaV_tonumber(rb, &tempb); + + if (b) + { + float nb = cast_to(float, nvalue(b)); + + switch (op) + { + case TM_MUL: + setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2]); + return; + case TM_DIV: + setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2]); + return; + default: + break; + } + } + } + + if (!call_binTM(L, rb, rc, ra, op)) + { + luaG_aritherror(L, rb, rc, op); + } + } +} + +void luaV_dolen(lua_State* L, StkId ra, const TValue* rb) +{ + switch (ttype(rb)) + { + case LUA_TTABLE: + { + setnvalue(ra, cast_num(luaH_getn(hvalue(rb)))); + break; + } + case LUA_TSTRING: + { + setnvalue(ra, cast_num(tsvalue(rb)->len)); + break; + } + default: + { /* try metamethod */ + if (!call_binTM(L, rb, luaO_nilobject, ra, TM_LEN)) + luaG_typeerror(L, rb, "get length of"); + } + } +} diff --git a/bench/bench.py b/bench/bench.py new file mode 100644 index 0000000..b23ca89 --- /dev/null +++ b/bench/bench.py @@ -0,0 +1,860 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +import argparse +import os +import subprocess +import math +import sys +import re +import json + +# Taken from rotest +from color import colored, Color +from tabulate import TablePrinter, Alignment + +# Based on rotest, specialized for benchmark results +import influxbench + +try: + import matplotlib + import matplotlib.pyplot as plt +except ModuleNotFoundError: + matplotlib = None + +try: + import scipy + from scipy import stats +except ModuleNotFoundError: + print("scipy package is required") + exit(1) + +scriptdir = os.path.dirname(os.path.realpath(__file__)) +defaultVm = 'luau.exe' if os.name == "nt" else './luau' + +argumentParser = argparse.ArgumentParser(description='Benchmark Lua script execution with an option to compare different VMs') + +argumentParser.add_argument('--vm', dest='vm',default=defaultVm,help='Lua executable to test (' + defaultVm + ' by default)') +argumentParser.add_argument('--folder', dest='folder',default=os.path.join(scriptdir, 'tests'),help='Folder with tests (tests by default)') +argumentParser.add_argument('--compare', dest='vmNext',type=str,nargs='*',help='List of Lua executables to compare against') +argumentParser.add_argument('--results', dest='results',type=str,nargs='*',help='List of json result files to compare and graph') +argumentParser.add_argument('--run-test', action='store', default=None, help='Regex test filter') +argumentParser.add_argument('--extra-loops', action='store',type=int,default=0, help='Amount of times to loop over one test (one test already performs multiple runs)') +argumentParser.add_argument('--filename', action='store',type=str,default='bench', help='File name for graph and results file') + +if matplotlib != None: + argumentParser.add_argument('--absolute', dest='absolute',action='store_const',const=1,default=0,help='Display absolute values instead of relative (enabled by default when benchmarking a single VM)') + argumentParser.add_argument('--speedup', dest='speedup',action='store_const',const=1,default=0,help='Draw a speedup graph') + argumentParser.add_argument('--sort', dest='sort',action='store_const',const=1,default=0,help='Sort values from worst to best improvements, ignoring conf. int. (disabled by default)') + argumentParser.add_argument('--window', dest='window',action='store_const',const=1,default=0,help='Display window with resulting plot (disabled by default)') + argumentParser.add_argument('--graph-vertical', action='store_true',dest='graph_vertical', help="Draw graph with vertical bars instead of horizontal") + +argumentParser.add_argument('--report-metrics', dest='report_metrics', help="Send metrics about this session to InfluxDB URL upon completion.") + +argumentParser.add_argument('--print-influx-debugging', action='store_true', dest='print_influx_debugging', help="Print output to aid in debugging of influx metrics reporting.") +argumentParser.add_argument('--no-print-influx-debugging', action='store_false', dest='print_influx_debugging', help="Don't print output to aid in debugging of influx metrics reporting.") + +argumentParser.add_argument('--no-print-final-summary', action='store_false', dest='print_final_summary', help="Don't print a table summarizing the results after all tests are run") + +def arrayRange(count): + result = [] + + for i in range(count): + result.append(i) + + return result + +def arrayRangeOffset(count, offset): + result = [] + + for i in range(count): + result.append(i + offset) + + return result + +def getVmOutput(cmd): + if os.name == "nt": + try: + return subprocess.check_output("start /realtime /affinity 1 /b /wait cmd /C \"" + cmd + "\"", shell=True, cwd=scriptdir).decode() + except KeyboardInterrupt: + exit(1) + except: + return "" + else: + with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=scriptdir) as p: + # Try to lock to a single processor + if sys.platform != "darwin": + os.sched_setaffinity(p.pid, { 0 }) + + # Try to set high priority (requires sudo) + try: + os.nice(-10) + except: + pass + + return p.communicate()[0] + +def getShortVmName(name): + # Hope that the path to executable doesn't contain spaces + argumentPos = name.find(" ") + + if argumentPos != -1: + executableName = name[0:argumentPos] + arguments = name[argumentPos+1:] + + pathPos = executableName.rfind("\\") + + if pathPos == -1: + pathPos = executableName.rfind("/") + + if pathPos != -1: + executableName = executableName[pathPos+1:] + + return executableName + " " + arguments + + pathPos = name.rfind("\\") + + if pathPos == -1: + pathPos = name.rfind("/") + + if pathPos != -1: + return name[pathPos+1:] + + return name + +class TestResult: + filename = "" + vm = "" + shortVm = "" + name = "" + + values = [] + count = 0 + min = None + avg = 0 + max = None + + sampleStdDev = 0 + unbiasedEst = 0 + sampleConfidenceInterval = 0 + +def extractResult(filename, vm, output): + elements = output.split("|><|") + + # Remove test output + elements.remove(elements[0]) + + result = TestResult() + + result.filename = filename + result.vm = vm + result.shortVm = getShortVmName(vm) + + result.name = elements[0] + elements.remove(elements[0]) + + timeTable = [] + + for el in elements: + timeTable.append(float(el)) + + result.values = timeTable + result.count = len(timeTable) + + return result + +def mergeResult(lhs, rhs): + for value in rhs.values: + lhs.values.append(value) + + lhs.count = len(lhs.values) + +def mergeResults(lhs, rhs): + for a, b in zip(lhs, rhs): + mergeResult(a, b) + +def finalizeResult(result): + total = 0.0 + + # Compute basic parameters + for v in result.values: + if result.min == None or v < result.min: + result.min = v + + if result.max == None or v > result.max: + result.max = v + + total = total + v + + if result.count > 0: + result.avg = total / result.count + else: + result.avg = 0 + + # Compute standard deviation + sumOfSquares = 0 + + for v in result.values: + sumOfSquares = sumOfSquares + (v - result.avg) ** 2 + + if result.count > 1: + result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1)) + result.unbiasedEst = result.sampleStdDev * result.sampleStdDev + + # Two-tailed distribution with 95% conf. + tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) + + # Compute confidence interval + result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + else: + result.sampleStdDev = 0 + result.unbiasedEst = 0 + result.sampleConfidenceInterval = 0 + + return result + +# Full result set +allResults = [] + + +# Data for the graph +plotLegend = [] + +plotLabels = [] + +plotValueLists = [] +plotConfIntLists = [] + +# Totals +vmTotalMin = [] +vmTotalAverage = [] +vmTotalImprovement = [] +vmTotalResults = [] + +# Data for Telegraf report +mainTotalMin = 0 +mainTotalAverage = 0 +mainTotalMax = 0 + +def getExtraArguments(filepath): + try: + with open(filepath) as f: + for i in f.readlines(): + pos = i.find("--bench-args:") + if pos != -1: + return i[pos + 13:].strip() + except: + pass + + return "" + +def substituteArguments(cmd, extra): + if argumentSubstituionCallback != None: + cmd = argumentSubstituionCallback(cmd) + + if cmd.find("@EXTRA") != -1: + cmd = cmd.replace("@EXTRA", extra) + else: + cmd = cmd + " " + extra + + return cmd + +def extractResults(filename, vm, output, allowFailure): + results = [] + + splitOutput = output.split("||_||") + + if len(splitOutput) <= 1: + if allowFailure: + result = TestResult() + + result.filename = filename + result.vm = vm + result.shortVm = getShortVmName(vm) + + results.append(result) + + return results + + splitOutput.remove(splitOutput[len(splitOutput) - 1]) + + for el in splitOutput: + results.append(extractResult(filename, vm, el)) + + return results + +def analyzeResult(subdir, main, comparisons): + # Aggregate statistics + global mainTotalMin, mainTotalAverage, mainTotalMax + + mainTotalMin = mainTotalMin + main.min + mainTotalAverage = mainTotalAverage + main.avg + mainTotalMax = mainTotalMax + main.max + + if arguments.vmNext != None: + resultPrinter.add_row({ + 'Test': main.name, + 'Min': '{:8.3f}ms'.format(main.min), + 'Average': '{:8.3f}ms'.format(main.avg), + 'StdDev%': '{:8.3f}%'.format(main.sampleConfidenceInterval / main.avg * 100), + 'Driver': main.shortVm, + 'Speedup': "", + 'Significance': "", + 'P(T<=t)': "" + }) + else: + resultPrinter.add_row({ + 'Test': main.name, + 'Min': '{:8.3f}ms'.format(main.min), + 'Average': '{:8.3f}ms'.format(main.avg), + 'StdDev%': '{:8.3f}%'.format(main.sampleConfidenceInterval / main.avg * 100), + 'Driver': main.shortVm + }) + + if influxReporter != None: + influxReporter.report_result(subdir, main.name, main.filename, "SUCCESS", main.min, main.avg, main.max, main.sampleConfidenceInterval, main.shortVm, main.vm) + + print(colored(Color.YELLOW, 'SUCCESS') + ': {:<40}'.format(main.name) + ": " + '{:8.3f}'.format(main.avg) + "ms +/- " + + '{:6.3f}'.format(main.sampleConfidenceInterval / main.avg * 100) + "% on " + main.shortVm) + + plotLabels.append(main.name) + + index = 0 + + if len(plotValueLists) < index + 1: + plotValueLists.append([]) + plotConfIntLists.append([]) + + vmTotalMin.append(0.0) + vmTotalAverage.append(0.0) + vmTotalImprovement.append(0.0) + vmTotalResults.append(0) + + if arguments.absolute or arguments.speedup: + scale = 1 + else: + scale = 100 / main.avg + + plotValueLists[index].append(main.avg * scale) + plotConfIntLists[index].append(main.sampleConfidenceInterval * scale) + + vmTotalMin[index] += main.min + vmTotalAverage[index] += main.avg + + for compare in comparisons: + index = index + 1 + + if len(plotValueLists) < index + 1 and not arguments.speedup: + plotValueLists.append([]) + plotConfIntLists.append([]) + + vmTotalMin.append(0.0) + vmTotalAverage.append(0.0) + vmTotalImprovement.append(0.0) + vmTotalResults.append(0) + + if compare.min == None: + print(colored(Color.RED, 'FAILED') + ": '" + main.name + "' on '" + compare.vm + "'") + + resultPrinter.add_row({ 'Test': main.name, 'Min': "", 'Average': "FAILED", 'StdDev%': "", 'Driver': compare.shortVm, 'Speedup': "", 'Significance': "", 'P(T<=t)': "" }) + + if influxReporter != None: + influxReporter.report_result(subdir, main.filename, main.filename, "FAILED", 0.0, 0.0, 0.0, 0.0, compare.shortVm, compare.vm) + + if arguments.speedup: + plotValueLists[0].pop() + plotValueLists[0].append(0) + + plotConfIntLists[0].pop() + plotConfIntLists[0].append(0) + else: + plotValueLists[index].append(0) + plotConfIntLists[index].append(0) + + continue + + pooledStdDev = math.sqrt((main.unbiasedEst + compare.unbiasedEst) / 2) + + tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) + degreesOfFreedom = 2 * main.count - 2 + + # Two-tailed distribution with 95% conf. + tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) + + noSignificantDifference = tStat < tCritical + + pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + + if noSignificantDifference: + verdict = "likely same" + elif main.avg < compare.avg: + verdict = "likely worse" + else: + verdict = "likely better" + + speedup = (plotValueLists[0][-1] / (compare.avg * scale) - 1) + speedupColor = Color.YELLOW if speedup < 0 and noSignificantDifference else Color.RED if speedup < 0 else Color.GREEN if speedup > 0 else Color.YELLOW + + resultPrinter.add_row({ + 'Test': main.name, + 'Min': '{:8.3f}ms'.format(compare.min), + 'Average': '{:8.3f}ms'.format(compare.avg), + 'StdDev%': '{:8.3f}%'.format(compare.sampleConfidenceInterval / compare.avg * 100), + 'Driver': compare.shortVm, + 'Speedup': colored(speedupColor, '{:8.3f}%'.format(speedup * 100)), + 'Significance': verdict, + 'P(T<=t)': '---' if pValue < 0 else '{:.0f}%'.format(pValue * 100) + }) + + print(colored(Color.YELLOW, 'SUCCESS') + ': {:<40}'.format(main.name) + ": " + '{:8.3f}'.format(compare.avg) + "ms +/- " + + '{:6.3f}'.format(compare.sampleConfidenceInterval / compare.avg * 100) + "% on " + compare.shortVm + + ' ({:+7.3f}%, '.format(speedup * 100) + verdict + ")") + + if influxReporter != None: + influxReporter.report_result(subdir, main.name, main.filename, "SUCCESS", compare.min, compare.avg, compare.max, compare.sampleConfidenceInterval, compare.shortVm, compare.vm) + + if arguments.speedup: + oldValue = plotValueLists[0].pop() + newValue = compare.avg + + plotValueLists[0].append((oldValue / newValue - 1) * 100) + + plotConfIntLists[0].pop() + plotConfIntLists[0].append(0) + else: + plotValueLists[index].append(compare.avg * scale) + plotConfIntLists[index].append(compare.sampleConfidenceInterval * scale) + + vmTotalMin[index] += compare.min + vmTotalAverage[index] += compare.avg + vmTotalImprovement[index] += math.log(main.avg / compare.avg) + vmTotalResults[index] += 1 + +def runTest(subdir, filename, filepath): + filepath = os.path.abspath(filepath) + + mainVm = os.path.abspath(arguments.vm) + + # Process output will contain the test name and execution times + mainOutput = getVmOutput(substituteArguments(mainVm, getExtraArguments(filepath)) + " " + filepath) + mainResultSet = extractResults(filename, mainVm, mainOutput, False) + + if len(mainResultSet) == 0: + print(colored(Color.RED, 'FAILED') + ": '" + filepath + "' on '" + mainVm + "'") + + if arguments.vmNext != None: + resultPrinter.add_row({ 'Test': filepath, 'Min': "", 'Average': "FAILED", 'StdDev%': "", 'Driver': getShortVmName(mainVm), 'Speedup': "", 'Significance': "", 'P(T<=t)': "" }) + else: + resultPrinter.add_row({ 'Test': filepath, 'Min': "", 'Average': "FAILED", 'StdDev%': "", 'Driver': getShortVmName(mainVm) }) + + if influxReporter != None: + influxReporter.report_result(subdir, filename, filename, "FAILED", 0.0, 0.0, 0.0, 0.0, getShortVmName(mainVm), mainVm) + return + + compareResultSets = [] + + if arguments.vmNext != None: + for compareVm in arguments.vmNext: + compareVm = os.path.abspath(compareVm) + + compareOutput = getVmOutput(substituteArguments(compareVm, getExtraArguments(filepath)) + " " + filepath) + compareResultSet = extractResults(filename, compareVm, compareOutput, True) + + compareResultSets.append(compareResultSet) + + if arguments.extra_loops > 0: + # get more results + for i in range(arguments.extra_loops): + extraMainOutput = getVmOutput(substituteArguments(mainVm, getExtraArguments(filepath)) + " " + filepath) + extraMainResultSet = extractResults(filename, mainVm, extraMainOutput, False) + + mergeResults(mainResultSet, extraMainResultSet) + + if arguments.vmNext != None: + i = 0 + for compareVm in arguments.vmNext: + compareVm = os.path.abspath(compareVm) + + extraCompareOutput = getVmOutput(substituteArguments(compareVm, getExtraArguments(filepath)) + " " + filepath) + extraCompareResultSet = extractResults(filename, compareVm, extraCompareOutput, True) + + mergeResults(compareResultSets[i], extraCompareResultSet) + i += 1 + + # finalize results + for result in mainResultSet: + finalizeResult(result) + + for compareResultSet in compareResultSets: + for result in compareResultSet: + finalizeResult(result) + + # analyze results + for i in range(len(mainResultSet)): + mainResult = mainResultSet[i] + compareResults = [] + + for el in compareResultSets: + if i < len(el): + compareResults.append(el[i]) + else: + noResult = TestResult() + + noResult.filename = el[0].filename + noResult.vm = el[0].vm + noResult.shortVm = el[0].shortVm + + compareResults.append(noResult) + + analyzeResult(subdir, mainResult, compareResults) + + mergedResults = [] + mergedResults.append(mainResult) + + for el in compareResults: + mergedResults.append(el) + + allResults.append(mergedResults) + +def rearrangeSortKeyForComparison(e): + if plotValueLists[1][e] == 0: + return 1 + + return plotValueLists[0][e] / plotValueLists[1][e] + +def rearrangeSortKeyForSpeedup(e): + return plotValueLists[0][e] + +def rearrangeSortKeyDescending(e): + return -plotValueLists[0][e] + +# Re-arrange results from worst to best +def rearrange(key): + global plotLabels + + index = arrayRange(len(plotLabels)) + index = sorted(index, key=key) + + # Recreate value lists in sorted order + plotLabelsPrev = plotLabels + plotLabels = [] + + for i in index: + plotLabels.append(plotLabelsPrev[i]) + + for group in range(len(plotValueLists)): + plotValueListPrev = plotValueLists[group] + plotValueLists[group] = [] + + plotConfIntListPrev = plotConfIntLists[group] + plotConfIntLists[group] = [] + + for i in index: + plotValueLists[group].append(plotValueListPrev[i]) + plotConfIntLists[group].append(plotConfIntListPrev[i]) + +# Graph +def graph(): + if len(plotValueLists) == 0: + print("No results") + return + + ind = arrayRange(len(plotLabels)) + width = 0.8 / len(plotValueLists) + + if arguments.graph_vertical: + # Extend graph width when we have a lot of tests to draw + barcount = len(plotValueLists[0]) + plt.figure(figsize=(max(8, barcount * 0.3), 8)) + else: + # Extend graph height when we have a lot of tests to draw + barcount = len(plotValueLists[0]) + plt.figure(figsize=(8, max(8, barcount * 0.3))) + + plotBars = [] + + matplotlib.rc('xtick', labelsize=10) + matplotlib.rc('ytick', labelsize=10) + + if arguments.graph_vertical: + # Draw Y grid behind the bars + plt.rc('axes', axisbelow=True) + plt.grid(True, 'major', 'y') + + for i in range(len(plotValueLists)): + bar = plt.bar(arrayRangeOffset(len(plotLabels), i * width), plotValueLists[i], width, yerr=plotConfIntLists[i]) + plotBars.append(bar[0]) + + if arguments.absolute: + plt.ylabel('Time (ms)') + elif arguments.speedup: + plt.ylabel('Speedup (%)') + else: + plt.ylabel('Relative time (%)') + + plt.title('Benchmark') + plt.xticks(ind, plotLabels, rotation='vertical') + else: + # Draw X grid behind the bars + plt.rc('axes', axisbelow=True) + plt.grid(True, 'major', 'x') + + for i in range(len(plotValueLists)): + bar = plt.barh(arrayRangeOffset(len(plotLabels), i * width), plotValueLists[i], width, xerr=plotConfIntLists[i]) + plotBars.append(bar[0]) + + if arguments.absolute: + plt.xlabel('Time (ms)') + elif arguments.speedup: + plt.xlabel('Speedup (%)') + else: + plt.xlabel('Relative time (%)') + + plt.title('Benchmark') + plt.yticks(ind, plotLabels) + + plt.gca().invert_yaxis() + + plt.legend(plotBars, plotLegend) + + plt.tight_layout() + + plt.savefig(arguments.filename + ".png", dpi=200) + + if arguments.window: + plt.show() + +def addTotalsToTable(): + if len(vmTotalMin) == 0: + return + + if arguments.vmNext != None: + index = 0 + + resultPrinter.add_row({ + 'Test': 'Total', + 'Min': '{:8.3f}ms'.format(vmTotalMin[index]), + 'Average': '{:8.3f}ms'.format(vmTotalAverage[index]), + 'StdDev%': "---", + 'Driver': getShortVmName(os.path.abspath(arguments.vm)), + 'Speedup': "", + 'Significance': "", + 'P(T<=t)': "" + }) + + for compareVm in arguments.vmNext: + index = index + 1 + + speedup = vmTotalAverage[0] / vmTotalAverage[index] * 100 - 100 + + resultPrinter.add_row({ + 'Test': 'Total', + 'Min': '{:8.3f}ms'.format(vmTotalMin[index]), + 'Average': '{:8.3f}ms'.format(vmTotalAverage[index]), + 'StdDev%': "---", + 'Driver': getShortVmName(os.path.abspath(compareVm)), + 'Speedup': colored(Color.RED if speedup < 0 else Color.GREEN if speedup > 0 else Color.YELLOW, '{:8.3f}%'.format(speedup)), + 'Significance': "", + 'P(T<=t)': "" + }) + else: + resultPrinter.add_row({ + 'Test': 'Total', + 'Min': '{:8.3f}ms'.format(vmTotalMin[0]), + 'Average': '{:8.3f}ms'.format(vmTotalAverage[0]), + 'StdDev%': "---", + 'Driver': getShortVmName(os.path.abspath(arguments.vm)) + }) + +def writeResultsToFile(): + class TestResultEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, TestResult): + return [obj.filename, obj.vm, obj.shortVm, obj.name, obj.values, obj.count] + return json.JSONEncoder.default(self, obj) + + try: + with open(arguments.filename + ".json", "w") as allResultsFile: + allResultsFile.write(json.dumps(allResults, cls=TestResultEncoder)) + except: + print("Failed to write results to a file") + +def run(args, argsubcb): + global arguments, resultPrinter, influxReporter, argumentSubstituionCallback, allResults + arguments = args + argumentSubstituionCallback = argsubcb + + if arguments.report_metrics or arguments.print_influx_debugging: + influxReporter = influxbench.InfluxReporter(arguments) + else: + influxReporter = None + + if matplotlib == None: + arguments.absolute = 0 + arguments.speedup = 0 + arguments.sort = 0 + arguments.window = 0 + + # Load results from files + if arguments.results != None: + vmList = [] + + for result in arguments.results: + with open(result) as resultsFile: + resultArray = json.load(resultsFile) + + for test in resultArray: + for i in range(len(test)): + arr = test[i] + + tr = TestResult() + + tr.filename = arr[0] + tr.vm = arr[1] + tr.shortVm = arr[2] + tr.name = arr[3] + tr.values = arr[4] + tr.count = arr[5] + + test[i] = tr + + for test in resultArray[0]: + if vmList.count(test.vm) > 0: + pointPos = result.rfind(".") + + if pointPos != -1: + vmList.append(test.vm + " [" + result[0:pointPos] + "]") + else: + vmList.append(test.vm + " [" + result + "]") + else: + vmList.append(test.vm) + + if len(allResults) == 0: + allResults = resultArray + else: + for prevEl in allResults: + found = False + + for nextEl in resultArray: + if nextEl[0].filename == prevEl[0].filename and nextEl[0].name == prevEl[0].name: + for run in nextEl: + prevEl.append(run) + found = True + + if not found: + el = resultArray[0] + + for run in el: + result = TestResult() + + result.filename = run.filename + result.vm = run.vm + result.shortVm = run.shortVm + result.name = run.name + + prevEl.append(result) + + arguments.vmNext = [] + + for i in range(len(vmList)): + if i == 0: + arguments.vm = vmList[i] + else: + arguments.vmNext.append(vmList[i]) + + plotLegend.append(getShortVmName(arguments.vm)) + + if arguments.vmNext != None: + for compareVm in arguments.vmNext: + plotLegend.append(getShortVmName(compareVm)) + else: + arguments.absolute = 1 # When looking at one VM, I feel that relative graph doesn't make a lot of sense + + # Results table formatting + if arguments.vmNext != None: + resultPrinter = TablePrinter([ + {'label': 'Test', 'align': Alignment.LEFT}, + {'label': 'Min', 'align': Alignment.RIGHT}, + {'label': 'Average', 'align': Alignment.RIGHT}, + {'label': 'StdDev%', 'align': Alignment.RIGHT}, + {'label': 'Driver', 'align': Alignment.LEFT}, + {'label': 'Speedup', 'align': Alignment.RIGHT}, + {'label': 'Significance', 'align': Alignment.LEFT}, + {'label': 'P(T<=t)', 'align': Alignment.RIGHT} + ]) + else: + resultPrinter = TablePrinter([ + {'label': 'Test', 'align': Alignment.LEFT}, + {'label': 'Min', 'align': Alignment.RIGHT}, + {'label': 'Average', 'align': Alignment.RIGHT}, + {'label': 'StdDev%', 'align': Alignment.RIGHT}, + {'label': 'Driver', 'align': Alignment.LEFT} + ]) + + if arguments.results != None: + for resultSet in allResults: + # finalize results + for result in resultSet: + finalizeResult(result) + + # analyze results + mainResult = resultSet[0] + compareResults = [] + + for i in range(len(resultSet)): + if i != 0: + compareResults.append(resultSet[i]) + + analyzeResult('', mainResult, compareResults) + else: + for subdir, dirs, files in os.walk(arguments.folder): + for filename in files: + filepath = subdir + os.sep + filename + + if filename.endswith(".lua"): + if arguments.run_test == None or re.match(arguments.run_test, filename[:-4]): + runTest(subdir, filename, filepath) + + if arguments.sort and len(plotValueLists) > 1: + rearrange(rearrangeSortKeyForComparison) + elif arguments.sort and len(plotValueLists) == 1: + rearrange(rearrangeSortKeyDescending) + elif arguments.speedup: + rearrange(rearrangeSortKeyForSpeedup) + + plotLegend[0] = arguments.vm + " vs " + arguments.vmNext[0] + + if arguments.print_final_summary: + addTotalsToTable() + + print() + print(colored(Color.YELLOW, '==================================================RESULTS==================================================')) + resultPrinter.print(summary=False) + print(colored(Color.YELLOW, '---')) + + if len(vmTotalMin) != 0 and arguments.vmNext != None: + index = 0 + + for compareVm in arguments.vmNext: + index = index + 1 + + name = getShortVmName(os.path.abspath(compareVm)) + deltaGeoMean = math.exp(vmTotalImprovement[index] / vmTotalResults[index]) * 100 - 100 + + if deltaGeoMean > 0: + print("'{}' change is {:.3f}% positive on average".format(name, deltaGeoMean)) + else: + print("'{}' change is {:.3f}% negative on average".format(name, deltaGeoMean)) + + if matplotlib != None: + graph() + + writeResultsToFile() + + if influxReporter != None: + influxReporter.report_result(arguments.folder, "Total", "all", "SUCCESS", mainTotalMin, mainTotalAverage, mainTotalMax, 0.0, getShortVmName(arguments.vm), os.path.abspath(arguments.vm)) + influxReporter.flush(0) + + +if __name__ == "__main__": + arguments = argumentParser.parse_args() + run(arguments, None) diff --git a/bench/bench_support.lua b/bench/bench_support.lua new file mode 100644 index 0000000..171b8da --- /dev/null +++ b/bench/bench_support.lua @@ -0,0 +1,50 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +local bench = {} + +bench.runs = 20 +bench.extraRuns = 4 + +function bench.runCode(f, description) + local timeTable = {} + + for i = 1,bench.runs + bench.extraRuns do + -- try to run GC if it's available + if collectgarbage then + pcall(function() + collectgarbage() + end) + end + + local ts0 = os.clock() + + local result = f() + + local ts1 = os.clock() + + -- If test case doesn't return a duration (if only a part of code is measured) we will measure full execution time here + if not result then + result = ts1 - ts0 + end + + table.insert(timeTable, result) + end + + table.sort(timeTable) + + for i = 1,bench.extraRuns do + table.remove(timeTable, #timeTable) + end + + -- Output test name followed by each result + local report = "|><|"..description + + for _,v in ipairs(timeTable) do + report = report .. "|><|" .. (v * 1000) + end + + report = report .. "||_||" + + print(report) +end + +return bench diff --git a/bench/color.py b/bench/color.py new file mode 100644 index 0000000..363c2b2 --- /dev/null +++ b/bench/color.py @@ -0,0 +1,37 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +from enum import Enum +import sys + +class Color(Enum): + DEFAULT = 0 + RED = 1 + GREEN = 2 + BLUE = 3 + YELLOW = 4 + WHITE = 5 + +def colored_on(color:Color, message:str): + from colorama import Fore, Style + color_mappings = { + Color.DEFAULT: (Fore.WHITE, Style.NORMAL), + Color.RED: (Fore.RED, Style.NORMAL), + Color.GREEN: (Fore.GREEN, Style.NORMAL), + Color.BLUE: (Fore.BLUE, Style.BRIGHT), + Color.YELLOW: (Fore.YELLOW, Style.NORMAL), + Color.WHITE: (Fore.WHITE, Style.BRIGHT) + } + fore, style = color_mappings[color] + return fore + style + message + Style.RESET_ALL + +def colored_off(color:Color, message:str): + return message + +try: + if sys.stdout.isatty(): + import colorama + colorama.init() + colored = colored_on + else: + colored = colored_off +except: + colored = colored_off diff --git a/bench/gc/test_BinaryTree.lua b/bench/gc/test_BinaryTree.lua new file mode 100644 index 0000000..2a79738 --- /dev/null +++ b/bench/gc/test_BinaryTree.lua @@ -0,0 +1,56 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + -- The Computer Language Benchmarks Game + -- http://benchmarksgame.alioth.debian.org/ + -- contributed by Mike Pall + + local function BottomUpTree(item, depth) + if depth > 0 then + local i = item + item + depth = depth - 1 + local left, right = BottomUpTree(i-1, depth), BottomUpTree(i, depth) + return { item, left, right } + else + return { item } + end + end + + local function ItemCheck(tree) + if tree[2] then + return tree[1] + ItemCheck(tree[2]) - ItemCheck(tree[3]) + else + return tree[1] + end + end + + local N = 10 + local mindepth = 4 + local maxdepth = mindepth + 2 + if maxdepth < N then maxdepth = N end + + local ts0 = os.clock() + + do + local stretchdepth = maxdepth + 1 + local stretchtree = BottomUpTree(0, stretchdepth) + end + + local longlivedtree = BottomUpTree(0, maxdepth) + + for depth=mindepth,maxdepth,2 do + local iterations = 2 ^ (maxdepth - depth + mindepth) + local check = 0 + for i=1,iterations do + check = check + ItemCheck(BottomUpTree(1, depth)) + + ItemCheck(BottomUpTree(-1, depth)) + end + end + + local ts1 = os.clock() + + return ts1 - ts0 +end + +bench.runCode(test, "BinaryTree") \ No newline at end of file diff --git a/bench/gc/test_GC_Boehm_Trees.lua b/bench/gc/test_GC_Boehm_Trees.lua new file mode 100644 index 0000000..1451769 --- /dev/null +++ b/bench/gc/test_GC_Boehm_Trees.lua @@ -0,0 +1,77 @@ +--!non-strict +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +local stretchTreeDepth = 18 -- about 16Mb +local longLivedTreeDepth = 16 -- about 4Mb +local arraySize = 500000 --about 4Mb +local minTreeDepth = 4 +local maxTreeDepth = 16 + +-- Nodes used by a tree of a given size +function treeSize(i) + return bit32.lshift(1, i + 1) - 1 +end + +function getNumIters(i) + return 2 * treeSize(stretchTreeDepth) / treeSize(i) +end + +-- Build tree top down, assigning to older objects. +function populate(depth, thisNode) + if depth <= 0 then + return + end + + depth = depth - 1 + thisNode.left = {} + thisNode.right = {} + populate(depth, thisNode.left) + populate(depth, thisNode.right) +end + +-- Build tree bottom-up +function makeTree(depth) + if depth <= 0 then + return {} + end + + return { left = makeTree(depth - 1), right = makeTree(depth - 1) } +end + +function timeConstruction(depth) + local numIters = getNumIters(depth) + local tempTree = {} + + for i = 1, numIters do + tempTree = {} + populate(depth, tempTree) + tempTree = nil + end + + for i = 1, numIters do + tempTree = makeTree(depth) + tempTree = nil + end +end + +function test() + -- Stretch the memory space quickly + local _tempTree = makeTree(stretchTreeDepth) + _tempTree = nil + + -- Create a long lived object + local longLivedTree = {} + populate(longLivedTreeDepth, longLivedTree) + + -- Create long-lived array, filling half of it + local array = {} + for i = 1, arraySize/2 do + array[i] = 1.0 / i + end + + for d = minTreeDepth,maxTreeDepth,2 do + timeConstruction(d) + end +end + +bench.runCode(test, "GC: Boehm tree") diff --git a/bench/gc/test_GC_Tree_Pruning_Eager.lua b/bench/gc/test_GC_Tree_Pruning_Eager.lua new file mode 100644 index 0000000..514766a --- /dev/null +++ b/bench/gc/test_GC_Tree_Pruning_Eager.lua @@ -0,0 +1,50 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + local count = 1 + + local function fill_tree(tree, levels) + if not tree.left then + tree.left = { id = count } + count = count + 1 + end + + if not tree.right then + tree.right = { id = count } + count = count + 1 + end + + if levels ~= 0 then + fill_tree(tree.left, levels - 1) + fill_tree(tree.right, levels - 1) + end + end + + local function prune_tree(tree, level) + if tree.left then + if math.random() > 0.9 - level * 0.05 then + tree.left = nil + else + prune_tree(tree.left, level + 1) + end + end + + if tree.right then + if math.random() > 0.9 - level * 0.05 then + tree.right = nil + else + prune_tree(tree.right, level + 1) + end + end + end + + local tree = { id = 0 } + + for i = 1,1000 do + fill_tree(tree, 10) + + prune_tree(tree, 0) + end +end + +bench.runCode(test, "GC: tree pruning (eager fill)") diff --git a/bench/gc/test_GC_Tree_Pruning_Gen.lua b/bench/gc/test_GC_Tree_Pruning_Gen.lua new file mode 100644 index 0000000..a8d0f40 --- /dev/null +++ b/bench/gc/test_GC_Tree_Pruning_Gen.lua @@ -0,0 +1,54 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + local count = 1 + + local function fill_tree(tree, levels) + if not tree.left then + tree.left = { id = count } + count = count + 1 + end + + if not tree.right then + tree.right = { id = count } + count = count + 1 + end + + if levels ~= 0 then + fill_tree(tree.left, levels - 1) + fill_tree(tree.right, levels - 1) + end + end + + local function prune_tree(tree, level) + if tree.left then + if math.random() > 0.9 - level * 0.05 then + tree.left = nil + else + prune_tree(tree.left, level + 1) + end + end + + if tree.right then + if math.random() > 0.9 - level * 0.05 then + tree.right = nil + else + prune_tree(tree.right, level + 1) + end + end + end + + -- create a static tree + local tree = { id = 0 } + fill_tree(tree, 16) + + for i = 1,1000 do + local small_tree = { id = 0 } + + fill_tree(small_tree, 8) + + prune_tree(small_tree, 0) + end +end + +bench.runCode(test, "GC: tree pruning (eager fill, gen)") diff --git a/bench/gc/test_GC_Tree_Pruning_Lazy.lua b/bench/gc/test_GC_Tree_Pruning_Lazy.lua new file mode 100644 index 0000000..8cb6919 --- /dev/null +++ b/bench/gc/test_GC_Tree_Pruning_Lazy.lua @@ -0,0 +1,56 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + local count = 1 + + local function fill_tree(tree, levels) + local left = tree.left; + local right = tree.right; + + if not left then + left = { id = count } + count = count + 1 + end + + if not right then + right = { id = count } + count = count + 1 + end + + if levels ~= 0 then + fill_tree(left, levels - 1) + fill_tree(right, levels - 1) + end + + tree.left = left; + tree.right = right; + end + + local function prune_tree(tree, level) + if tree.left then + if math.random() > 0.9 - level * 0.05 then + tree.left = nil + else + prune_tree(tree.left, level + 1) + end + end + + if tree.right then + if math.random() > 0.9 - level * 0.05 then + tree.right = nil + else + prune_tree(tree.right, level + 1) + end + end + end + + local tree = { id = 0 } + + for i = 1,1000 do + fill_tree(tree, 10) + + prune_tree(tree, 0) + end +end + +bench.runCode(test, "GC: tree pruning (lazy fill)") diff --git a/bench/gc/test_GC_hashtable_Keyval.lua b/bench/gc/test_GC_hashtable_Keyval.lua new file mode 100644 index 0000000..fcb2482 --- /dev/null +++ b/bench/gc/test_GC_hashtable_Keyval.lua @@ -0,0 +1,22 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + local t = {} + + local max = 10000 + local iters = 50000 + + for i = 1,iters do + local is = tostring(i) + local input = string.rep(is, 1000 / #is) + + t[is] = input + + -- remove old entries + if i > max then + t[tostring(i - max)] = nil + end + end +end + +bench.runCode(test, "GC: hashtable keys and values") diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua new file mode 100644 index 0000000..4be7850 --- /dev/null +++ b/bench/gc/test_LB_mandel.lua @@ -0,0 +1,97 @@ +--[[ +MIT License + +Copyright (c) 2017 Gabriel de Quadros Ligneul + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local Complex={type="package"} + +local function complex(x,y) + return setmetatable({ re=x, im=y }, Complex.metatable) +end + +function Complex.conj(x,y) + return complex(x.re,-x.im) +end + +function Complex.norm2(x) + local n=Complex.mul(x,Complex.conj(x)) + return n.re +end + +function Complex.abs(x) + return sqrt(Complex.norm2(x)) +end + +function Complex.add(x,y) + return complex(x.re+y.re,x.im+y.im) +end + +function Complex.mul(x,y) + return complex(x.re*y.re-x.im*y.im,x.re*y.im+x.im*y.re) +end + +Complex.metatable={ + __add = Complex.add, + __mul = Complex.mul, +} + +local function abs(x) + return math.sqrt(Complex.norm2(x)) +end + +xmin=-2.0 xmax=2.0 ymin=-2.0 ymax=2.0 +N=(arg and arg[1]) or 64 + +function level(x,y) + local c=complex(x,y) + local l=0 + local z=c + repeat + z=z*z+c + l=l+1 + until abs(z)>2.0 or l>255 + return l-1 +end + +dx=(xmax-xmin)/N +dy=(ymax-ymin)/N + +print("P2") +print("# mandelbrot set",xmin,xmax,ymin,ymax,N) +print(N,N,255) +local S = 0 +for i=1,N do + local x=xmin+(i-1)*dx + for j=1,N do + local y=ymin+(j-1)*dy + S = S + level(x,y) + end + -- if i % 10 == 0 then print(collectgarbage"count") end +end +print(S) + +end + +bench.runCode(test, "mandel") diff --git a/bench/gc/test_LargeTableCtor_array.lua b/bench/gc/test_LargeTableCtor_array.lua new file mode 100644 index 0000000..535877f --- /dev/null +++ b/bench/gc/test_LargeTableCtor_array.lua @@ -0,0 +1,35 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + for i=1,4000 do + local t = + { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, + } + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "LargeTableCtor: array") \ No newline at end of file diff --git a/bench/gc/test_LargeTableCtor_hash.lua b/bench/gc/test_LargeTableCtor_hash.lua new file mode 100644 index 0000000..6faf766 --- /dev/null +++ b/bench/gc/test_LargeTableCtor_hash.lua @@ -0,0 +1,14 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + for i=1,100000 do + local t = { a = 1, b = 2, c = 3, d = 4, e = 5, f = 6 } + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "LargeTableCtor: hash") \ No newline at end of file diff --git a/bench/gc/test_Pcall_pcall_yield.lua b/bench/gc/test_Pcall_pcall_yield.lua new file mode 100644 index 0000000..ac46c79 --- /dev/null +++ b/bench/gc/test_Pcall_pcall_yield.lua @@ -0,0 +1,18 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + function test() coroutine.yield() return 1 end + + local ts0 = os.clock() + for i=0,100000 do + local co = coroutine.create(function() return pcall(test) end) + coroutine.resume(co) + coroutine.resume(co) + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "Pcall: pcall yield") \ No newline at end of file diff --git a/bench/gc/test_SunSpider_3d-raytrace.lua b/bench/gc/test_SunSpider_3d-raytrace.lua new file mode 100644 index 0000000..60e4f61 --- /dev/null +++ b/bench/gc/test_SunSpider_3d-raytrace.lua @@ -0,0 +1,502 @@ +--[[ + * Copyright (C) 2007 Apple Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local size = 30 + +function createVector(x,y,z) + return { x,y,z }; +end + +function sqrLengthVector(self) + return self[1] * self[1] + self[2] * self[2] + self[3] * self[3]; +end + +function lengthVector(self) + return math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); +end + +function addVector(self, v) + self[1] = self[1] + v[1]; + self[2] = self[2] + v[2]; + self[3] = self[3] + v[3]; + return self; +end + +function subVector(self, v) + self[1] = self[1] - v[1]; + self[2] = self[2] - v[2]; + self[3] = self[3] - v[3]; + return self; +end + +function scaleVector(self, scale) + self[1] = self[1] * scale; + self[2] = self[2] * scale; + self[3] = self[3] * scale; + return self; +end + +function normaliseVector(self) + local len = math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); + self[1] = self[1] / len; + self[2] = self[2] / len; + self[3] = self[3] / len; + return self; +end + +function add(v1, v2) + return { v1[1] + v2[1], v1[2] + v2[2], v1[3] + v2[3] }; +end + +function sub(v1, v2) + return { v1[1] - v2[1], v1[2] - v2[2], v1[3] - v2[3] }; +end + +function scalev(v1, v2) + return { v1[1] * v2[1], v1[2] * v2[2], v1[3] * v2[3] }; +end + +function dot(v1, v2) + return v1[1] * v2[1] + v1[2] * v2[2] + v1[3] * v2[3]; +end + +function scale(v, scale) + return { v[1] * scale, v[2] * scale, v[3] * scale }; +end + +function cross(v1, v2) + return { v1[2] * v2[3] - v1[3] * v2[2], + v1[3] * v2[1] - v1[1] * v2[3], + v1[1] * v2[2] - v1[2] * v2[1] }; + +end + +function normalise(v) + local len = lengthVector(v); + return { v[1] / len, v[2] / len, v[3] / len }; +end + +function transformMatrix(self, v) + local vals = self; + local x = vals[1] * v[1] + vals[2] * v[2] + vals[3] * v[3] + vals[4]; + local y = vals[5] * v[1] + vals[6] * v[2] + vals[7] * v[3] + vals[8]; + local z = vals[9] * v[1] + vals[10] * v[2] + vals[11] * v[3] + vals[12]; + return { x, y, z }; +end + +function invertMatrix(self) + local temp = {} + local tx = -self[4]; + local ty = -self[8]; + local tz = -self[12]; + for h = 0,2 do + for v = 0,2 do + temp[h + v * 4 + 1] = self[v + h * 4 + 1]; + end + end + + for i = 0,10 do + self[i + 1] = temp[i + 1]; + end + + self[4] = tx * self[1] + ty * self[2] + tz * self[3]; + self[8] = tx * self[5] + ty * self[6] + tz * self[7]; + self[12] = tx * self[9] + ty * self[10] + tz * self[11]; + return self; +end + +-- Triangle intersection using barycentric coord method +function Triangle(p1, p2, p3) + local this = {} + + local edge1 = sub(p3, p1); + local edge2 = sub(p2, p1); + local normal = cross(edge1, edge2); + if (math.abs(normal[1]) > math.abs(normal[2])) then + if (math.abs(normal[1]) > math.abs(normal[3])) then + this.axis = 0; + else + this.axis = 2; + end + else + if (math.abs(normal[2]) > math.abs(normal[3])) then + this.axis = 1; + else + this.axis = 2; + end + end + + local u = (this.axis + 1) % 3; + local v = (this.axis + 2) % 3; + local u1 = edge1[u + 1]; + local v1 = edge1[v + 1]; + + local u2 = edge2[u + 1]; + local v2 = edge2[v + 1]; + this.normal = normalise(normal); + this.nu = normal[u + 1] / normal[this.axis + 1]; + this.nv = normal[v + 1] / normal[this.axis + 1]; + this.nd = dot(normal, p1) / normal[this.axis + 1]; + local det = u1 * v2 - v1 * u2; + this.eu = p1[u + 1]; + this.ev = p1[v + 1]; + this.nu1 = u1 / det; + this.nv1 = -v1 / det; + this.nu2 = v2 / det; + this.nv2 = -u2 / det; + this.material = { 0.7, 0.7, 0.7 }; + + + this.intersect = function(self, orig, dir, near, far) + local u = (self.axis + 1) % 3; + local v = (self.axis + 2) % 3; + local d = dir[self.axis + 1] + self.nu * dir[u + 1] + self.nv * dir[v + 1]; + local t = (self.nd - orig[self.axis + 1] - self.nu * orig[u + 1] - self.nv * orig[v + 1]) / d; + + if (t < near or t > far) then + return nil; + end + + local Pu = orig[u + 1] + t * dir[u + 1] - self.eu; + local Pv = orig[v + 1] + t * dir[v + 1] - self.ev; + local a2 = Pv * self.nu1 + Pu * self.nv1; + + if (a2 < 0) then + return nil; + end + + local a3 = Pu * self.nu2 + Pv * self.nv2; + if (a3 < 0) then + return nil; + end + + if ((a2 + a3) > 1) then + return nil; + end + + return t; + end + + return this +end + +function Scene(a_triangles) + local this = {} + this.triangles = a_triangles; + this.lights = {}; + this.ambient = {0,0,0}; + this.background = {0.8,0.8,1}; + + this.intersect = function(self, origin, dir, near, far) + local closest = nil; + for i = 0,#self.triangles-1 do + local triangle = self.triangles[i + 1]; + local d = triangle:intersect(origin, dir, near, far); + if (d == nil or d > far or d < near) then + -- continue; + else + far = d; + closest = triangle; + end + end + + if (not closest) then + return { self.background[1],self.background[2],self.background[3] }; + end + + local normal = closest.normal; + local hit = add(origin, scale(dir, far)); + if (dot(dir, normal) > 0) then + normal = { -normal[1], -normal[2], -normal[3] }; + end + + local colour = nil; + if (closest.shader) then + colour = closest.shader(closest, hit, dir); + else + colour = closest.material; + end + + -- do reflection + local reflected = nil; + if (colour.reflection or 0 > 0.001) then + local reflection = addVector(scale(normal, -2*dot(dir, normal)), dir); + reflected = self:intersect(hit, reflection, 0.0001, 1000000); + if (colour.reflection >= 0.999999) then + return reflected; + end + end + + local l = { self.ambient[1], self.ambient[2], self.ambient[3] }; + + for i = 0,#self.lights-1 do + local light = self.lights[i + 1]; + local toLight = sub(light, hit); + local distance = lengthVector(toLight); + scaleVector(toLight, 1.0/distance); + distance = distance - 0.0001; + + if (self:blocked(hit, toLight, distance)) then + -- continue; + else + local nl = dot(normal, toLight); + if (nl > 0) then + addVector(l, scale(light.colour, nl)); + end + end + end + + l = scalev(l, colour); + if (reflected) then + l = addVector(scaleVector(l, 1 - colour.reflection), scaleVector(reflected, colour.reflection)); + end + + return l; + end + + this.blocked = function(self, O, D, far) + local near = 0.0001; + local closest = nil; + for i = 0,#self.triangles-1 do + local triangle = self.triangles[i + 1]; + local d = triangle:intersect(O, D, near, far); + if (d == nil or d > far or d < near) then + --continue; + else + return true; + end + end + + return false; + end + + return this +end + +local zero = { 0,0,0 }; + +-- this camera code is from notes i made ages ago, it is from *somewhere* -- i cannot remember where +-- that somewhere is +function Camera(origin, lookat, up) + local this = {} + + local zaxis = normaliseVector(subVector(lookat, origin)); + local xaxis = normaliseVector(cross(up, zaxis)); + local yaxis = normaliseVector(cross(xaxis, subVector({ 0,0,0 }, zaxis))); + local m = {}; + m[1] = xaxis[1]; m[2] = xaxis[2]; m[3] = xaxis[3]; + m[5] = yaxis[1]; m[6] = yaxis[2]; m[7] = yaxis[3]; + m[9] = zaxis[1]; m[10] = zaxis[2]; m[11] = zaxis[3]; + m[4] = 0; m[8] = 0; m[12] = 0; + invertMatrix(m); + m[4] = 0; m[8] = 0; m[12] = 0; + this.origin = origin; + this.directions = {}; + this.directions[1] = normalise({ -0.7, 0.7, 1 }); + this.directions[2] = normalise({ 0.7, 0.7, 1 }); + this.directions[3] = normalise({ 0.7, -0.7, 1 }); + this.directions[4] = normalise({ -0.7, -0.7, 1 }); + this.directions[1] = transformMatrix(m, this.directions[1]); + this.directions[2] = transformMatrix(m, this.directions[2]); + this.directions[3] = transformMatrix(m, this.directions[3]); + this.directions[4] = transformMatrix(m, this.directions[4]); + + this.generateRayPair = function(self, y) + rays = { {}, {} } + rays[1].origin = self.origin; + rays[2].origin = self.origin; + rays[1].dir = addVector(scale(self.directions[1], y), scale(self.directions[4], 1 - y)); + rays[2].dir = addVector(scale(self.directions[2], y), scale(self.directions[3], 1 - y)); + return rays; + end + + function renderRows(camera, scene, pixels, width, height, starty, stopy) + for y = starty,stopy-1 do + local rays = camera:generateRayPair(y / height); + for x = 0,width-1 do + local xp = x / width; + local origin = addVector(scale(rays[1].origin, xp), scale(rays[2].origin, 1 - xp)); + local dir = normaliseVector(addVector(scale(rays[1].dir, xp), scale(rays[2].dir, 1 - xp))); + local l = scene:intersect(origin, dir, 0, math.huge); + pixels[y + 1][x + 1] = l; + end + end + end + + this.render = function(self, scene, pixels, width, height) + local cam = self; + local row = 0; + renderRows(cam, scene, pixels, width, height, 0, height); + end + + return this +end + +function raytraceScene() + local startDate = 13154863; + local numTriangles = 2 * 6; + local triangles = {}; -- numTriangles); + local tfl = createVector(-10, 10, -10); + local tfr = createVector( 10, 10, -10); + local tbl = createVector(-10, 10, 10); + local tbr = createVector( 10, 10, 10); + local bfl = createVector(-10, -10, -10); + local bfr = createVector( 10, -10, -10); + local bbl = createVector(-10, -10, 10); + local bbr = createVector( 10, -10, 10); + + -- cube!!! + -- front + local i = 0; + + triangles[i + 1] = Triangle(tfl, tfr, bfr); i = i + 1; + triangles[i + 1] = Triangle(tfl, bfr, bfl); i = i + 1; + -- back + triangles[i + 1] = Triangle(tbl, tbr, bbr); i = i + 1; + triangles[i + 1] = Triangle(tbl, bbr, bbl); i = i + 1; + -- triangles[i-1].material = [0.7,0.2,0.2]; + -- triangles[i-1].material.reflection = 0.8; + -- left + triangles[i + 1] = Triangle(tbl, tfl, bbl); i = i + 1; + -- triangles[i-1].reflection = 0.6; + triangles[i + 1] = Triangle(tfl, bfl, bbl); i = i + 1; + -- triangles[i-1].reflection = 0.6; + -- right + triangles[i + 1] = Triangle(tbr, tfr, bbr); i = i + 1; + triangles[i + 1] = Triangle(tfr, bfr, bbr); i = i + 1; + -- top + triangles[i + 1] = Triangle(tbl, tbr, tfr); i = i + 1; + triangles[i + 1] = Triangle(tbl, tfr, tfl); i = i + 1; + -- bottom + triangles[i + 1] = Triangle(bbl, bbr, bfr); i = i + 1; + triangles[i + 1] = Triangle(bbl, bfr, bfl); i = i + 1; + + -- Floor!!!! + local green = createVector(0.0, 0.4, 0.0); + green.reflection = 0; -- + local grey = createVector(0.4, 0.4, 0.4); + grey.reflection = 1.0; + local floorShader = function(tri, pos, view) + local x = ((pos[1]/32) % 2 + 2) % 2; + local z = ((pos[3]/32 + 0.3) % 2 + 2) % 2; + if ((x < 1) ~= (z < 1)) then + --in the real world we use the fresnel term... + -- local angle = 1-dot(view, tri.normal); + -- angle *= angle; + -- angle *= angle; + -- angle *= angle; + --grey.reflection = angle; + return grey; + else + return green; + end + end + + local ffl = createVector(-1000, -30, -1000); + local ffr = createVector( 1000, -30, -1000); + local fbl = createVector(-1000, -30, 1000); + local fbr = createVector( 1000, -30, 1000); + triangles[i + 1] = Triangle(fbl, fbr, ffr); i = i + 1; + triangles[i-1 + 1].shader = floorShader; + triangles[i + 1] = Triangle(fbl, ffr, ffl); i = i + 1; + triangles[i-1 + 1].shader = floorShader; + + local _scene = Scene(triangles); + _scene.lights[1] = createVector(20, 38, -22); + _scene.lights[1].colour = createVector(0.7, 0.3, 0.3); + _scene.lights[2] = createVector(-23, 40, 17); + _scene.lights[2].colour = createVector(0.7, 0.3, 0.3); + _scene.lights[3] = createVector(23, 20, 17); + _scene.lights[3].colour = createVector(0.7, 0.7, 0.7); + _scene.ambient = createVector(0.1, 0.1, 0.1); + -- _scene.background = createVector(0.7, 0.7, 1.0); + + local pixels = {}; + for y = 0,size-1 do + pixels[y + 1] = {}; + for x = 0,size-1 do + pixels[y + 1][x + 1] = 0; + end + end + + local _camera = Camera(createVector(-40, 40, 40), createVector(0, 0, 0), createVector(0, 1, 0)); + _camera:render(_scene, pixels, size, size); + + return pixels; +end + +function arrayToCanvasCommands(pixels) + local s = 'Test\nvar pixels = ['; + for y = 0,size-1 do + s = s .. "["; + for x = 0,size-1 do + s = s .. "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"; + end + s = s .. "],"; + end + s = s .. '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ +\n\ +\n\ + var size = ' .. size .. ';\n\ + canvas.fillStyle = "red";\n\ + canvas.fillRect(0, 0, size, size);\n\ + canvas.scale(1, -1);\n\ + canvas.translate(0, -size);\n\ +\n\ + if (!canvas.setFillColor)\n\ + canvas.setFillColor = function(r, g, b, a) {\n\ + this.fillStyle = "rgb("+[Math.floor(r), Math.floor(g), Math.floor(b)]+")";\n\ + }\n\ +\n\ +for (var y = 0; y < size; y++) {\n\ + for (var x = 0; x < size; x++) {\n\ + var l = pixels[y][x];\n\ + canvas.setFillColor(l[0], l[1], l[2], 1);\n\ + canvas.fillRect(x, y, 1, 1);\n\ + }\n\ +}'; + + return s; +end + +testOutput = arrayToCanvasCommands(raytraceScene()); + +--local f = io.output("output.html") +--f:write(testOutput) +--f:close() + +local expectedLength = 11599; +local testLength = #testOutput + +if (testLength ~= expectedLength) then + assert(false, "Error: bad result: expected length " .. expectedLength .. " but got " .. testLength); +end + +end + +bench.runCode(test, "3d-raytrace") diff --git a/bench/gc/test_SunSpider_crypto-aes.lua b/bench/gc/test_SunSpider_crypto-aes.lua new file mode 100644 index 0000000..8537e3d --- /dev/null +++ b/bench/gc/test_SunSpider_crypto-aes.lua @@ -0,0 +1,436 @@ +--[[ + * AES Cipher function: encrypt 'input' with Rijndael algorithm + * + * takes byte-array 'input' (16 bytes) + * 2D byte-array key schedule 'w' (Nr+1 x Nb bytes) + * + * applies Nr rounds (10/12/14) using key schedule w for 'add round key' stage + * + * returns byte-array encrypted value (16 bytes) + */]] + + local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +-- Sbox is pre-computed multiplicative inverse in GF(2^8) used in SubBytes and KeyExpansion [§5.1.1] +local Sbox = { 0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76, + 0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0, + 0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15, + 0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75, + 0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84, + 0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf, + 0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8, + 0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2, + 0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73, + 0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb, + 0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79, + 0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08, + 0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a, + 0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e, + 0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf, + 0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16 }; + +-- Rcon is Round Constant used for the Key Expansion [1st col is 2^(r-1) in GF(2^8)] [§5.2] +local Rcon = { { 0x00, 0x00, 0x00, 0x00 }, + {0x01, 0x00, 0x00, 0x00}, + {0x02, 0x00, 0x00, 0x00}, + {0x04, 0x00, 0x00, 0x00}, + {0x08, 0x00, 0x00, 0x00}, + {0x10, 0x00, 0x00, 0x00}, + {0x20, 0x00, 0x00, 0x00}, + {0x40, 0x00, 0x00, 0x00}, + {0x80, 0x00, 0x00, 0x00}, + {0x1b, 0x00, 0x00, 0x00}, + {0x36, 0x00, 0x00, 0x00} }; + +function Cipher(input, w) -- main Cipher function [§5.1] + local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) + local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys + + local state = {{},{},{},{}}; -- initialise 4xNb byte-array 'state' with input [§3.4] + for i = 0,4*Nb-1 do state[(i % 4) + 1][math.floor(i/4) + 1] = input[i + 1]; end + + state = AddRoundKey(state, w, 0, Nb); + + for round = 1,Nr-1 do + state = SubBytes(state, Nb); + state = ShiftRows(state, Nb); + state = MixColumns(state, Nb); + state = AddRoundKey(state, w, round, Nb); + end + + state = SubBytes(state, Nb); + state = ShiftRows(state, Nb); + state = AddRoundKey(state, w, Nr, Nb); + + local output = {} -- convert state to 1-d array before returning [§3.4] + for i = 0,4*Nb-1 do output[i + 1] = state[(i % 4) + 1][math.floor(i / 4) + 1]; end + + return output; +end + + +function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] + for r = 0,3 do + for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end + end + return s; +end + + +function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] + local t = {}; + for r = 1,3 do + for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy + for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back + end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): + return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf +end + + +function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] + for c = 0,3 do + local a = {}; -- 'a' is a copy of the current column from 's' + local b = {}; -- 'b' is a•{02} in GF(2^8) + for i = 0,3 do + a[i + 1] = s[i + 1][c + 1]; + + if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then + b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); + else + b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); + end + end + -- a[n] ^ b[n] is a•{03} in GF(2^8) + s[1][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(b[1], a[2]), bit32.bxor(b[2], a[3])), a[4]); -- 2*a0 + 3*a1 + a2 + a3 + s[2][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], b[2]), bit32.bxor(a[3], b[3])), a[4]); -- a0 * 2*a1 + 3*a2 + a3 + s[3][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], a[2]), bit32.bxor(b[3], a[4])), b[4]); -- a0 + a1 + 2*a2 + 3*a3 + s[4][c + 1] = bit32.bxor(bit32.bxor(bit32.bxor(a[1], b[1]), bit32.bxor(a[2], a[3])), b[4]); -- 3*a0 + a1 + a2 + 2*a3 +end + return s; +end + + +function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] + for r = 0,3 do + for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end + end + return state; +end + + +function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] + local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) + local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys + local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys + + local w = {}; + local temp = {}; + + for i = 0,Nk do + local r = { key[4*i + 1], key[4*i + 2], key[4*i + 3], key[4*i + 4] }; + w[i + 1] = r; + end + + for i = Nk,(Nb*(Nr+1)) - 1 do + w[i + 1] = {}; + for t = 0,3 do temp[t + 1] = w[i-1 + 1][t + 1]; end + if (i % Nk == 0) then + temp = SubWord(RotWord(temp)); + for t = 0,3 do temp[t + 1] = bit32.bxor(temp[t + 1], Rcon[i/Nk + 1][t + 1]); end + elseif (Nk > 6 and i % Nk == 4) then + temp = SubWord(temp); + end + for t = 0,3 do w[i + 1][t + 1] = bit32.bxor(w[i - Nk + 1][t + 1], temp[t + 1]); end + end + + return w; +end + +function SubWord(w) -- apply SBox to 4-byte word w + for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end + return w; +end + +function RotWord(w) -- rotate 4-byte word w left by one byte + w[5] = w[1]; + for i = 0,3 do w[i + 1] = w[i + 2]; end + return w; +end + + +--[[ + * Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation + * - see http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf + * for each block + * - outputblock = cipher(counter, key) + * - cipherblock = plaintext xor outputblock + ]] + +function AESEncryptCtr(plaintext, password, nBits) + if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys + + -- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password; + -- for real-world applications, a higher security approach would be to hash the password e.g. with SHA-1 + local nBytes = nBits/8; -- no bytes in key + local pwBytes = {}; + for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end + local key = Cipher(pwBytes, KeyExpansion(pwBytes)); + + -- key is now 16/24/32 bytes long + for i = 1,nBytes-16 do + table.insert(key, key[i]) + end + + -- initialise counter block (NIST SP800-38A §B.2): millisecond time-stamp for nonce in 1st 8 bytes, + -- block counter in 2nd 8 bytes + local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES + local counterBlock = {}; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES + local nonce = 12564231564 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970 + + -- encode nonce in two stages to cater for JavaScript 32-bit limit on bitwise ops + for i = 0,3 do counterBlock[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff); end + for i = 0,3 do counterBlock[i + 4 + 1] = bit32.band(bit32.rshift(math.floor(nonce / 0x100000000), i*8), 0xff); end + + -- generate key schedule - an expansion of the key into distinct Key Rounds for each round + local keySchedule = KeyExpansion(key); + + local blockCount = math.ceil(#plaintext / blockSize); + local ciphertext = {}; -- ciphertext as array of strings + + for b = 0,blockCount-1 do + -- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes) + -- again done in two stages for 32-bit ops + for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift(b, c*8), 0xff); end + for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.rshift(math.floor(b/0x100000000), c*8) end + + local cipherCntr = Cipher(counterBlock, keySchedule); -- -- encrypt counter block -- + + -- calculate length of final block: + local blockLength = nil + + if b= self.count_max then + Toggle.activate(self) + self.counter = 0 + end + return self + end + + function NthToggle:new (start_state, max_counter) + local o = Toggle.new(self, start_state) + o.count_max = max_counter + o.counter = 0 + return o + end + + + ----------------------------------------------------------- + -- main + ----------------------------------------------------------- + + function main () + local start = os.clock() + local N = 30000 + + local val = 1 + local toggle = Toggle:new(val) + for i=1,N do + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + val = toggle:activate():value() + end + print(val and "true" or "false") + + val = 1 + local ntoggle = NthToggle:new(val, 3) + for i=1,N do + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + val = ntoggle:activate():value() + end + print(val and "true" or "false") + return os.clock() - start + end + + return main() +end + +bench.runCode(test, "MethodCalls") \ No newline at end of file diff --git a/bench/micro_tests/test_OOP_constructor.lua b/bench/micro_tests/test_OOP_constructor.lua new file mode 100644 index 0000000..6e3633c --- /dev/null +++ b/bench/micro_tests/test_OOP_constructor.lua @@ -0,0 +1,29 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local Number = {} + Number.__index = Number + + function Number.new(v) + local self = { + value = v + } + setmetatable(self, Number) + return self + end + + function Number:Get() + return self.value + end + + local ts0 = os.clock() + for i=1,100000 do + local n = Number.new(42) + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "OOP: constructor") \ No newline at end of file diff --git a/bench/micro_tests/test_OOP_method_call.lua b/bench/micro_tests/test_OOP_method_call.lua new file mode 100644 index 0000000..c2c6c4e --- /dev/null +++ b/bench/micro_tests/test_OOP_method_call.lua @@ -0,0 +1,31 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local Number = {} + Number.__index = Number + + function Number.new(v) + local self = { + value = v + } + setmetatable(self, Number) + return self + end + + function Number:Get() + return self.value + end + + local n = Number.new(42) + + local ts0 = os.clock() + for i=1,1000000 do + local nv = n:Get() + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "OOP: method call") \ No newline at end of file diff --git a/bench/micro_tests/test_OOP_virtual_constructor.lua b/bench/micro_tests/test_OOP_virtual_constructor.lua new file mode 100644 index 0000000..48b1e55 --- /dev/null +++ b/bench/micro_tests/test_OOP_virtual_constructor.lua @@ -0,0 +1,29 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local Number = {} + Number.__index = Number + + function Number:new(class, v) + local self = { + value = v + } + setmetatable(self, Number) + return self + end + + function Number:Get() + return self.value + end + + local ts0 = os.clock() + for i=1,100000 do + local n = Number:new(42) + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "OOP: virtual constructor") \ No newline at end of file diff --git a/bench/micro_tests/test_Pcall_call_return.lua b/bench/micro_tests/test_Pcall_call_return.lua new file mode 100644 index 0000000..9d07708 --- /dev/null +++ b/bench/micro_tests/test_Pcall_call_return.lua @@ -0,0 +1,14 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + function test() return 1 end + + local ts0 = os.clock() + for i=0,100000 do test() end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "Pcall: call return") \ No newline at end of file diff --git a/bench/micro_tests/test_Pcall_pcall_return.lua b/bench/micro_tests/test_Pcall_pcall_return.lua new file mode 100644 index 0000000..a6ff359 --- /dev/null +++ b/bench/micro_tests/test_Pcall_pcall_return.lua @@ -0,0 +1,14 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + function test() return 1 end + + local ts0 = os.clock() + for i=0,100000 do pcall(test) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "Pcall: pcall return") \ No newline at end of file diff --git a/bench/micro_tests/test_Pcall_pcall_yield.lua b/bench/micro_tests/test_Pcall_pcall_yield.lua new file mode 100644 index 0000000..ac46c79 --- /dev/null +++ b/bench/micro_tests/test_Pcall_pcall_yield.lua @@ -0,0 +1,18 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + function test() coroutine.yield() return 1 end + + local ts0 = os.clock() + for i=0,100000 do + local co = coroutine.create(function() return pcall(test) end) + coroutine.resume(co) + coroutine.resume(co) + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "Pcall: pcall yield") \ No newline at end of file diff --git a/bench/micro_tests/test_Pcall_xpcall_return.lua b/bench/micro_tests/test_Pcall_xpcall_return.lua new file mode 100644 index 0000000..a64eddf --- /dev/null +++ b/bench/micro_tests/test_Pcall_xpcall_return.lua @@ -0,0 +1,14 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + function test() return 1 end + + local ts0 = os.clock() + for i=0,100000 do xpcall(test, error) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "Pcall: xpcall return") \ No newline at end of file diff --git a/bench/micro_tests/test_SqrtSum_exponent.lua b/bench/micro_tests/test_SqrtSum_exponent.lua new file mode 100644 index 0000000..eaddbfd --- /dev/null +++ b/bench/micro_tests/test_SqrtSum_exponent.lua @@ -0,0 +1,13 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + local sum = 0 + for i=0,500000 do sum = sum + i^0.5 end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "SqrtSum: ^0.5") \ No newline at end of file diff --git a/bench/micro_tests/test_SqrtSum_math_sqrt.lua b/bench/micro_tests/test_SqrtSum_math_sqrt.lua new file mode 100644 index 0000000..44b61cc --- /dev/null +++ b/bench/micro_tests/test_SqrtSum_math_sqrt.lua @@ -0,0 +1,13 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + local sum = 0 + for i=0,500000 do sum = sum + math.sqrt(i) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "SqrtSum: math.sqrt") \ No newline at end of file diff --git a/bench/micro_tests/test_SqrtSum_sqrt.lua b/bench/micro_tests/test_SqrtSum_sqrt.lua new file mode 100644 index 0000000..34d8b38 --- /dev/null +++ b/bench/micro_tests/test_SqrtSum_sqrt.lua @@ -0,0 +1,14 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local sqrt = math.sqrt + local ts0 = os.clock() + local sum = 0 + for i=0,500000 do sum = sum + sqrt(i) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "SqrtSum: sqrt") \ No newline at end of file diff --git a/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua new file mode 100644 index 0000000..242edb8 --- /dev/null +++ b/bench/micro_tests/test_SqrtSum_sqrt_getfenv.lua @@ -0,0 +1,15 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local sqrt = math.sqrt + getfenv() + local ts0 = os.clock() + local sum = 0 + for i=0,500000 do sum = sum + sqrt(i) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "SqrtSum: sqrt getfenv") \ No newline at end of file diff --git a/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua new file mode 100644 index 0000000..fa8bfd0 --- /dev/null +++ b/bench/micro_tests/test_SqrtSum_sqrt_roundabout.lua @@ -0,0 +1,14 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local sqrt = nil or math.sqrt -- breaks fastcall analysis + local ts0 = os.clock() + local sum = 0 + for i=0,500000 do sum = sum + sqrt(i) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "SqrtSum: sqrt roundabout") \ No newline at end of file diff --git a/bench/micro_tests/test_TableCreate_nil.lua b/bench/micro_tests/test_TableCreate_nil.lua new file mode 100644 index 0000000..1eff20e --- /dev/null +++ b/bench/micro_tests/test_TableCreate_nil.lua @@ -0,0 +1,12 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + for i=1,100000 do table.create(100) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableCreate: nil") \ No newline at end of file diff --git a/bench/micro_tests/test_TableCreate_number.lua b/bench/micro_tests/test_TableCreate_number.lua new file mode 100644 index 0000000..620b562 --- /dev/null +++ b/bench/micro_tests/test_TableCreate_number.lua @@ -0,0 +1,12 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + for i=1,100000 do table.create(100,0) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableCreate: number") \ No newline at end of file diff --git a/bench/micro_tests/test_TableCreate_zerofill.lua b/bench/micro_tests/test_TableCreate_zerofill.lua new file mode 100644 index 0000000..08c6c91 --- /dev/null +++ b/bench/micro_tests/test_TableCreate_zerofill.lua @@ -0,0 +1,15 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + for i=1,100000 do + local t = table.create(100) + for j=1,100 do t[j] = 0 end + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableCreate: zerofill") \ No newline at end of file diff --git a/bench/micro_tests/test_TableFind_loop_ipairs.lua b/bench/micro_tests/test_TableFind_loop_ipairs.lua new file mode 100644 index 0000000..f363013 --- /dev/null +++ b/bench/micro_tests/test_TableFind_loop_ipairs.lua @@ -0,0 +1,24 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local function find(t, v) + for i,e in ipairs(t) do + if e == v then + return i + end + end + + return nil + end + + local t = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20} + + local ts0 = os.clock() + for i=1,100000 do find(t,15) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableFind: for+ipairs") \ No newline at end of file diff --git a/bench/micro_tests/test_TableFind_table_find.lua b/bench/micro_tests/test_TableFind_table_find.lua new file mode 100644 index 0000000..a7619fc --- /dev/null +++ b/bench/micro_tests/test_TableFind_table_find.lua @@ -0,0 +1,14 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local t = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20} + + local ts0 = os.clock() + for i=1,100000 do table.find(t,15) end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableFind: table.find") \ No newline at end of file diff --git a/bench/micro_tests/test_TableInsertion_index_cached.lua b/bench/micro_tests/test_TableInsertion_index_cached.lua new file mode 100644 index 0000000..7f75b02 --- /dev/null +++ b/bench/micro_tests/test_TableInsertion_index_cached.lua @@ -0,0 +1,19 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + + for i=1,300 do + local t = {} + for j=1,1000 do + t[j] = j + end + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableInsertion: t[i]") \ No newline at end of file diff --git a/bench/micro_tests/test_TableInsertion_index_len.lua b/bench/micro_tests/test_TableInsertion_index_len.lua new file mode 100644 index 0000000..b9f71e0 --- /dev/null +++ b/bench/micro_tests/test_TableInsertion_index_len.lua @@ -0,0 +1,19 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + + for i=1,300 do + local t = {} + for j=1,1000 do + t[#t+1] = j + end + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableInsertion: t[#t+1]") \ No newline at end of file diff --git a/bench/micro_tests/test_TableInsertion_table_insert.lua b/bench/micro_tests/test_TableInsertion_table_insert.lua new file mode 100644 index 0000000..9efccd4 --- /dev/null +++ b/bench/micro_tests/test_TableInsertion_table_insert.lua @@ -0,0 +1,19 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + + for i=1,300 do + local t = {} + for j=1,1000 do + table.insert(t, j) + end + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableInsertion: table.insert") \ No newline at end of file diff --git a/bench/micro_tests/test_TableInsertion_table_insert_index.lua b/bench/micro_tests/test_TableInsertion_table_insert_index.lua new file mode 100644 index 0000000..af2292b --- /dev/null +++ b/bench/micro_tests/test_TableInsertion_table_insert_index.lua @@ -0,0 +1,30 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + + for i=1,5 do + local t = {} + for j=1,1000 do + table.insert(t, 1, j) + end + end + + local ts1 = os.clock() + + for i=1,5 do + local t = {} + for j=1,1000 do + table.insert(t, 1, j) + end + + for j=1,1000 do + assert(t[j] == (1000- (j - 1) ) ) + end + end + + return ts1-ts0 +end + +bench.runCode(test, "TableInsertion: table.insert(pos)") \ No newline at end of file diff --git a/bench/micro_tests/test_TableIteration.lua b/bench/micro_tests/test_TableIteration.lua new file mode 100644 index 0000000..47a94a3 --- /dev/null +++ b/bench/micro_tests/test_TableIteration.lua @@ -0,0 +1,19 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local t = {} + + for i=1,100 do t[tostring(i)] = i end + + local ts0 = os.clock() + local sum = 0 + for i=1,10000 do + for k,v in pairs(t) do sum = sum + v end + end + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableIteration") \ No newline at end of file diff --git a/bench/micro_tests/test_TableMarshal_select.lua b/bench/micro_tests/test_TableMarshal_select.lua new file mode 100644 index 0000000..110d912 --- /dev/null +++ b/bench/micro_tests/test_TableMarshal_select.lua @@ -0,0 +1,20 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local function pack(...) + return {n = select('#', ...), ...} + end + + local ts0 = os.clock() + + for i=1,100000 do + local t = pack(1,2,3,4,5,6,7,8,9,10) + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableMarshal: {n=select,...}") \ No newline at end of file diff --git a/bench/micro_tests/test_TableMarshal_table_pack.lua b/bench/micro_tests/test_TableMarshal_table_pack.lua new file mode 100644 index 0000000..45810a3 --- /dev/null +++ b/bench/micro_tests/test_TableMarshal_table_pack.lua @@ -0,0 +1,20 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local function pack(...) + return table.pack(...) + end + + local ts0 = os.clock() + + for i=1,100000 do + local t = pack(1,2,3,4,5,6,7,8,9,10) + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableMarshal: table.pack") \ No newline at end of file diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_array.lua b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua new file mode 100644 index 0000000..67bc1ef --- /dev/null +++ b/bench/micro_tests/test_TableMarshal_table_unpack_array.lua @@ -0,0 +1,21 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local function consume(...) + end + + local t = {1,2,3,4,5,6,7,8,9,10} + + local ts0 = os.clock() + + for i=1,1000000 do + consume(table.unpack(t)) + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableMarshal: table.unpack") \ No newline at end of file diff --git a/bench/micro_tests/test_TableMarshal_table_unpack_range.lua b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua new file mode 100644 index 0000000..a678d8b --- /dev/null +++ b/bench/micro_tests/test_TableMarshal_table_unpack_range.lua @@ -0,0 +1,21 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local function consume(...) + end + + local t = {n=10,1,2,3,4,5,6,7,8,9,10} + + local ts0 = os.clock() + + for i=1,1000000 do + consume(table.unpack(t, 1, t.n)) + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableMarshal: table.unpack/n") \ No newline at end of file diff --git a/bench/micro_tests/test_TableMarshal_varargs.lua b/bench/micro_tests/test_TableMarshal_varargs.lua new file mode 100644 index 0000000..19ef81f --- /dev/null +++ b/bench/micro_tests/test_TableMarshal_varargs.lua @@ -0,0 +1,20 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local function pack(...) + return {...} + end + + local ts0 = os.clock() + + for i=1,100000 do + local t = pack(1,2,3,4,5,6,7,8,9,10) + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableMarshal: {...}") \ No newline at end of file diff --git a/bench/micro_tests/test_TableMove_empty_table.lua b/bench/micro_tests/test_TableMove_empty_table.lua new file mode 100644 index 0000000..75ce272 --- /dev/null +++ b/bench/micro_tests/test_TableMove_empty_table.lua @@ -0,0 +1,23 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + local t = table.create(250001, 0) + + for i=1,250000 do + t[i] = i + end + + local t2 = {} + + local ts0 = os.clock() + table.move(t, 1, 250000, 1, t2) + local ts1 = os.clock() + + for i=1,250000-1 do + assert(t2[i] == i) + end + + return ts1-ts0 +end + +bench.runCode(test, "TableMove: {}") diff --git a/bench/micro_tests/test_TableMove_same_table.lua b/bench/micro_tests/test_TableMove_same_table.lua new file mode 100644 index 0000000..8157657 --- /dev/null +++ b/bench/micro_tests/test_TableMove_same_table.lua @@ -0,0 +1,21 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + local t = table.create(5000001, 0) + + for i=0,5000000 do + t[i] = i + end + + local ts0 = os.clock() + table.move(t, 1, 250000, 250001, t) + local ts1 = os.clock() + + for i=250001,(500000-1) do + assert(t[i] == (i - 250001) + 1) + end + + return ts1-ts0 +end + +bench.runCode(test, "TableMove: same table") diff --git a/bench/micro_tests/test_TableMove_table_create.lua b/bench/micro_tests/test_TableMove_table_create.lua new file mode 100644 index 0000000..19dfd18 --- /dev/null +++ b/bench/micro_tests/test_TableMove_table_create.lua @@ -0,0 +1,23 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + local t = table.create(250001, 0) + + for i=1,250000 do + t[i] = i + end + + local t2 = table.create(250001, 100) + + local ts0 = os.clock() + table.move(t, 1, 250000, 1, t2) + local ts1 = os.clock() + + for i=1,250000-1 do + assert(t2[i] == i) + end + + return ts1-ts0 +end + +bench.runCode(test, "TableMove: table.create") diff --git a/bench/micro_tests/test_TableRemoval_table_remove.lua b/bench/micro_tests/test_TableRemoval_table_remove.lua new file mode 100644 index 0000000..25acd54 --- /dev/null +++ b/bench/micro_tests/test_TableRemoval_table_remove.lua @@ -0,0 +1,22 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local ts0 = os.clock() + + local iterations = 25000 + + local t = table.create(iterations, 100) + + for j=1,100 do + table.remove(t, 1) + end + + assert(#t == (iterations - 100)) + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "TableRemoval: table.remove") diff --git a/bench/micro_tests/test_UpvalueCapture.lua b/bench/micro_tests/test_UpvalueCapture.lua new file mode 100644 index 0000000..96e8f57 --- /dev/null +++ b/bench/micro_tests/test_UpvalueCapture.lua @@ -0,0 +1,19 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local tab = {} + + local ts0 = os.clock() + + for i=1, 1_000_000 do + local j = i + 1 + tab[i] = function() return i,j end + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "UpvalueCapture") \ No newline at end of file diff --git a/bench/micro_tests/test_VariadicSelect.lua b/bench/micro_tests/test_VariadicSelect.lua new file mode 100644 index 0000000..668956f --- /dev/null +++ b/bench/micro_tests/test_VariadicSelect.lua @@ -0,0 +1,26 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + function sum(...) + local res = 0 + local length = select("#", ...) + for i = 1, length do + local item = select(i, ...) + res += item + end + return res + end + + local ts0 = os.clock() + + for i=1, 100_000 do + sum(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "VariadicSelect") diff --git a/bench/micro_tests/test_string_lib.lua b/bench/micro_tests/test_string_lib.lua new file mode 100644 index 0000000..3994c77 --- /dev/null +++ b/bench/micro_tests/test_string_lib.lua @@ -0,0 +1,39 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +bench.runCode(function() + local src = string.rep("abcdefghijklmnopqrstuvwxyz", 100) + local str = "" + for i=1,1000 do + str = string.upper(src) + str = string.reverse(str) + str = string.lower(str) + end + assert(#str) +end, "string: reverse/upper/lower (large)") + +bench.runCode(function() + local str = "" + for i=1,100000 do + src = "abcdefghijklmnopqrstuvwxyz" .. i + str = string.upper(src) + str = string.reverse(str) + str = string.lower(str) + end + assert(#str) +end, "string: reverse/upper/lower (unique)") + +bench.runCode(function() + local str = "" + for i=1,1000000 do + str = string.rep("_", 19) + end + assert(#str) +end, "string: rep (small)") + +bench.runCode(function() + local str = "" + for i=1,100 do + str = string.rep("abcd", 100000) + end + assert(#str) +end, "string: rep (large)") diff --git a/bench/micro_tests/test_table_concat.lua b/bench/micro_tests/test_table_concat.lua new file mode 100644 index 0000000..430ad0a --- /dev/null +++ b/bench/micro_tests/test_table_concat.lua @@ -0,0 +1,27 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +bench.runCode(function() + for outer=1,28,3 do + for inner=1,28,3 do + local t2 = table.create(20, string.rep("n", outer)) + local str = "" + for i=1,500 do + str = table.concat(t2, string.rep("!", inner)) + end + assert(#str) + end + end +end, "table: concat (small)") + +bench.runCode(function() + for outer=1,21,3 do + for inner=1,21,3 do + local t2 = table.create(200, string.rep("n", outer)) + local str = "" + for i=1,100 do + str = table.concat(t2, string.rep("!", inner)) + end + assert(#str) + end + end +end, "table: concat (big)") diff --git a/bench/tabulate.py b/bench/tabulate.py new file mode 100644 index 0000000..fc03417 --- /dev/null +++ b/bench/tabulate.py @@ -0,0 +1,91 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +from typing import Dict +from enum import Enum +import re + +class Alignment(Enum): + LEFT = 0 + RIGHT = 1 + CENTER = 2 + +class TablePrinter(object): + def __init__(self, columns): + assert(len(columns) > 0) + self._columns = columns + self._widths = [len(col['label']) for col in self._columns] + self._rows = [] + pass + + def _convert_field_dict_to_ordered_list(self, fields:Dict[str, object]): + assert(len(fields) == len(self._columns)) + + ordered_list = [None] * len(self._columns) + column_names = [column['label'] for column in self._columns] + + for column, value in fields.items(): + index = column_names.index(column) + ordered_list[index] = value + return ordered_list + + def _print_row(self, row, align_style=None): + for i, (value, column, align_width) in enumerate(zip(row, self._columns, self._widths)): + if i > 0: + print(' | ', end='') + + actual_align_style = align_style if align_style != None else column['align'] + align_char = { + Alignment.LEFT: '<', + Alignment.CENTER: '^', + Alignment.RIGHT: '>' + }[actual_align_style] + print('{0:{align_style}{align_width}}'.format(value, align_style=align_char, align_width=align_width), end=' ') + print() + pass + + def _print_horizontal_separator(self): + for i, align_width in enumerate(self._widths): + if i > 0: + print('-+-', end='') + print('-' * (align_width+1), end='') + print() + pass + + def clean_colorama(self, str): + return re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]').sub('', str) + + def add_row(self, fields:Dict[str, object]): + fields = self._convert_field_dict_to_ordered_list(fields) + + for i, value in enumerate(fields): + + self._widths[i] = max(self._widths[i], len(self.clean_colorama(str(value)))) + + self._rows.append(fields) + + def _compute_summary_row(self): + sums = [0] * len(self._widths) + for row in self._rows: + for i, value in enumerate(row): + if not isinstance(value, int): + continue + sums[i] = sums[i] + value + sums[0] = "Total" + return sums + + def print(self, summary=False): + self._print_row([column['label'] for column in self._columns], align_style=Alignment.LEFT) + self._print_horizontal_separator() + + if summary: + summary_row = self._compute_summary_row() + for i, value in enumerate(summary_row): + self._widths[i] = max(self._widths[i], len(str(value))) + + for row in self._rows: + self._print_row(row) + + if summary: + self._print_horizontal_separator() + self._print_row(summary_row) + + pass diff --git a/bench/tests/base64.lua b/bench/tests/base64.lua new file mode 100644 index 0000000..3755c54 --- /dev/null +++ b/bench/tests/base64.lua @@ -0,0 +1,81 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + local base64 = {} + + local extract = bit32.extract + + function base64.makeencoder( s62, s63, spad ) + local encoder = {} + for b64code, char in pairs{[0]='A','B','C','D','E','F','G','H','I','J', + 'K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y', + 'Z','a','b','c','d','e','f','g','h','i','j','k','l','m','n', + 'o','p','q','r','s','t','u','v','w','x','y','z','0','1','2', + '3','4','5','6','7','8','9',s62 or '+',s63 or'/',spad or'='} do + encoder[b64code] = char:byte() + end + return encoder + end + + function base64.makedecoder( s62, s63, spad ) + local decoder = {} + for b64code, charcode in pairs( base64.makeencoder( s62, s63, spad )) do + decoder[charcode] = b64code + end + return decoder + end + + local DEFAULT_ENCODER = base64.makeencoder() + local DEFAULT_DECODER = base64.makedecoder() + + local char, concat, byte = string.char, table.concat, string.byte + + function base64.decode( b64, decoder, usecaching ) + decoder = decoder or DEFAULT_DECODER + local cache = usecaching and {} + local t, k = {}, 1 + local n = #b64 + local padding = b64:sub(-2) == '==' and 2 or b64:sub(-1) == '=' and 1 or 0 + for i = 1, padding > 0 and n-4 or n, 4 do + local a, b, c, d = byte( b64, i, i+3 ) + local s + if usecaching then + local v0 = a*0x1000000 + b*0x10000 + c*0x100 + d + s = cache[v0] + if not s then + local v = decoder[a]*0x40000 + decoder[b]*0x1000 + decoder[c]*0x40 + decoder[d] + s = char( extract(v,16,8), extract(v,8,8), extract(v,0,8)) + cache[v0] = s + end + else + local v = decoder[a]*0x40000 + decoder[b]*0x1000 + decoder[c]*0x40 + decoder[d] + s = char( extract(v,16,8), extract(v,8,8), extract(v,0,8)) + end + t[k] = s + k = k + 1 + end + if padding == 1 then + local a, b, c = byte( b64, n-3, n-1 ) + local v = decoder[a]*0x40000 + decoder[b]*0x1000 + decoder[c]*0x40 + t[k] = char( extract(v,16,8), extract(v,8,8)) + elseif padding == 2 then + local a, b = byte( b64, n-3, n-2 ) + local v = decoder[a]*0x40000 + decoder[b]*0x1000 + t[k] = char( extract(v,16,8)) + end + return concat( t ) + end + + local ts0 = os.clock() + + for i = 1, 2000 do + base64.decode("TWFuIGlzIGRpc3Rpbmd1aXNoZWQsIG5vdCBvbmx5IGJ5IGhpcyByZWFzb24sIGJ1dCBieSB0aGlzIHNpbmd1bGFyIHBhc3Npb24gZnJvbSBvdGhlciBhbmltYWxzLCB3aGljaCBpcyBhIGx1c3Qgb2YgdGhlIG1pbmQsIHRoYXQgYnkgYSBwZXJzZXZlcmFuY2Ugb2YgZGVsaWdodCBpbiB0aGUgY29udGludWVkIGFuZCBpbmRlZmF0aWdhYmxlIGdlbmVyYXRpb24gb2Yga25vd2xlZGdlLCBleGNlZWRzIHRoZSBzaG9ydCB2ZWhlbWVuY2Ugb2YgYW55IGNhcm5hbCBwbGVhc3VyZS4=") + end + + local ts1 = os.clock() + + return ts1 - ts0 +end + +bench.runCode(test, "base64") \ No newline at end of file diff --git a/bench/tests/deltablue.lua b/bench/tests/deltablue.lua new file mode 100644 index 0000000..ecf246d --- /dev/null +++ b/bench/tests/deltablue.lua @@ -0,0 +1,934 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +-- Copyright 2008 the V8 project authors. All rights reserved. +-- Copyright 1996 John Maloney and Mario Wolczko. + +-- This program is free software; you can redistribute it and/or modify +-- it under the terms of the GNU General Public License as published by +-- the Free Software Foundation; either version 2 of the License, or +-- (at your option) any later version. +-- +-- This program is distributed in the hope that it will be useful, +-- but WITHOUT ANY WARRANTY; without even the implied warranty of +-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +-- GNU General Public License for more details. +-- +-- You should have received a copy of the GNU General Public License +-- along with this program; if not, write to the Free Software +-- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + + +-- This implementation of the DeltaBlue benchmark is derived +-- from the Smalltalk implementation by John Maloney and Mario +-- Wolczko. Some parts have been translated directly, whereas +-- others have been modified more aggresively to make it feel +-- more like a JavaScript program. + + +-- +-- A JavaScript implementation of the DeltaBlue constraint-solving +-- algorithm, as described in: +-- +-- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver" +-- Bjorn N. Freeman-Benson and John Maloney +-- January 1990 Communications of the ACM, +-- also available as University of Washington TR 89-08-06. +-- +-- Beware: this benchmark is written in a grotesque style where +-- the constraint model is built by side-effects from constructors. +-- I've kept it this way to avoid deviating too much from the original +-- implementation. +-- + +function class(base) + local T = {} + T.__index = T + + if base then + T.super = base + setmetatable(T, base) + end + + function T.new(...) + local O = {} + setmetatable(O, T) + O:constructor(...) + return O + end + + return T +end + +local planner + +--- O b j e c t M o d e l --- + +local function alert (...) print(...) end + +local OrderedCollection = class() + +function OrderedCollection:constructor() + self.elms = {} +end + +function OrderedCollection:add(elm) + self.elms[#self.elms + 1] = elm +end + +function OrderedCollection:at (index) + return self.elms[index] +end + +function OrderedCollection:size () + return #self.elms +end + +function OrderedCollection:removeFirst () + local e = self.elms[#self.elms] + self.elms[#self.elms] = nil + return e +end + +function OrderedCollection:remove (elm) + local index = 0 + local skipped = 0 + + for i = 1, #self.elms do + local value = self.elms[i] + if value ~= elm then + self.elms[index] = value + index = index + 1 + else + skipped = skipped + 1 + end + end + + local l = #self.elms + for i = 1, skipped do self.elms[l - i + 1] = nil end +end + +-- +-- S t r e n g t h +-- + +-- +-- Strengths are used to measure the relative importance of constraints. +-- New strengths may be inserted in the strength hierarchy without +-- disrupting current constraints. Strengths cannot be created outside +-- this class, so pointer comparison can be used for value comparison. +-- + +local Strength = class() + +function Strength:constructor(strengthValue, name) + self.strengthValue = strengthValue + self.name = name +end + +function Strength.stronger (s1, s2) + return s1.strengthValue < s2.strengthValue +end + +function Strength.weaker (s1, s2) + return s1.strengthValue > s2.strengthValue +end + +function Strength.weakestOf (s1, s2) + return Strength.weaker(s1, s2) and s1 or s2 +end + +function Strength.strongest (s1, s2) + return Strength.stronger(s1, s2) and s1 or s2 +end + +function Strength:nextWeaker () + local v = self.strengthValue + if v == 0 then return Strength.WEAKEST + elseif v == 1 then return Strength.WEAK_DEFAULT + elseif v == 2 then return Strength.NORMAL + elseif v == 3 then return Strength.STRONG_DEFAULT + elseif v == 4 then return Strength.PREFERRED + elseif v == 5 then return Strength.REQUIRED + end +end + +-- Strength constants. +Strength.REQUIRED = Strength.new(0, "required"); +Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred"); +Strength.PREFERRED = Strength.new(2, "preferred"); +Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault"); +Strength.NORMAL = Strength.new(4, "normal"); +Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault"); +Strength.WEAKEST = Strength.new(6, "weakest"); + +-- +-- C o n s t r a i n t +-- + +-- +-- An abstract class representing a system-maintainable relationship +-- (or "constraint") between a set of variables. A constraint supplies +-- a strength instance variable; concrete subclasses provide a means +-- of storing the constrained variables and other information required +-- to represent a constraint. +-- + +local Constraint = class () + +function Constraint:constructor(strength) + self.strength = strength +end + +-- +-- Activate this constraint and attempt to satisfy it. +-- +function Constraint:addConstraint () + self:addToGraph() + planner:incrementalAdd(self) +end + +-- +-- Attempt to find a way to enforce this constraint. If successful, +-- record the solution, perhaps modifying the current dataflow +-- graph. Answer the constraint that this constraint overrides, if +-- there is one, or nil, if there isn't. +-- Assume: I am not already satisfied. +-- +function Constraint:satisfy (mark) + self:chooseMethod(mark) + if not self:isSatisfied() then + if self.strength == Strength.REQUIRED then + alert("Could not satisfy a required constraint!") + end + return nil + end + self:markInputs(mark) + local out = self:output() + local overridden = out.determinedBy + if overridden ~= nil then overridden:markUnsatisfied() end + out.determinedBy = self + if not planner:addPropagate(self, mark) then alert("Cycle encountered") end + out.mark = mark + return overridden +end + +function Constraint:destroyConstraint () + if self:isSatisfied() + then planner:incrementalRemove(self) + else self:removeFromGraph() + end +end + +-- +-- Normal constraints are not input constraints. An input constraint +-- is one that depends on external state, such as the mouse, the +-- keybord, a clock, or some arbitraty piece of imperative code. +-- +function Constraint:isInput () + return false +end + + +-- +-- U n a r y C o n s t r a i n t +-- + +-- +-- Abstract superclass for constraints having a single possible output +-- variable. +-- + +local UnaryConstraint = class(Constraint) + +function UnaryConstraint:constructor (v, strength) + UnaryConstraint.super.constructor(self, strength) + self.myOutput = v + self.satisfied = false + self:addConstraint() +end + +-- +-- Adds this constraint to the constraint graph +-- +function UnaryConstraint:addToGraph () + self.myOutput:addConstraint(self) + self.satisfied = false +end + +-- +-- Decides if this constraint can be satisfied and records that +-- decision. +-- +function UnaryConstraint:chooseMethod (mark) + self.satisfied = (self.myOutput.mark ~= mark) + and Strength.stronger(self.strength, self.myOutput.walkStrength); +end + +-- +-- Returns true if this constraint is satisfied in the current solution. +-- +function UnaryConstraint:isSatisfied () + return self.satisfied; +end + +function UnaryConstraint:markInputs (mark) + -- has no inputs +end + +-- +-- Returns the current output variable. +-- +function UnaryConstraint:output () + return self.myOutput +end + +-- +-- Calculate the walkabout strength, the stay flag, and, if it is +-- 'stay', the value for the current output of this constraint. Assume +-- this constraint is satisfied. +-- +function UnaryConstraint:recalculate () + self.myOutput.walkStrength = self.strength + self.myOutput.stay = not self:isInput() + if self.myOutput.stay then + self:execute() -- Stay optimization + end +end + +-- +-- Records that this constraint is unsatisfied +-- +function UnaryConstraint:markUnsatisfied () + self.satisfied = false +end + +function UnaryConstraint:inputsKnown () + return true +end + +function UnaryConstraint:removeFromGraph () + if self.myOutput ~= nil then + self.myOutput:removeConstraint(self) + end + self.satisfied = false +end + +-- +-- S t a y C o n s t r a i n t +-- + +-- +-- Variables that should, with some level of preference, stay the same. +-- Planners may exploit the fact that instances, if satisfied, will not +-- change their output during plan execution. This is called "stay +-- optimization". +-- + +local StayConstraint = class(UnaryConstraint) + +function StayConstraint:constructor(v, str) + StayConstraint.super.constructor(self, v, str) +end + +function StayConstraint:execute () + -- Stay constraints do nothing +end + +-- +-- E d i t C o n s t r a i n t +-- + +-- +-- A unary input constraint used to mark a variable that the client +-- wishes to change. +-- + +local EditConstraint = class (UnaryConstraint) + +function EditConstraint:constructor(v, str) + EditConstraint.super.constructor(self, v, str) +end + +-- +-- Edits indicate that a variable is to be changed by imperative code. +-- +function EditConstraint:isInput () + return true +end + +function EditConstraint:execute () + -- Edit constraints do nothing +end + +-- +-- B i n a r y C o n s t r a i n t +-- + +local Direction = {} +Direction.NONE = 0 +Direction.FORWARD = 1 +Direction.BACKWARD = -1 + +-- +-- Abstract superclass for constraints having two possible output +-- variables. +-- + +local BinaryConstraint = class(Constraint) + +function BinaryConstraint:constructor(var1, var2, strength) + BinaryConstraint.super.constructor(self, strength); + self.v1 = var1 + self.v2 = var2 + self.direction = Direction.NONE + self:addConstraint() +end + + +-- +-- Decides if this constraint can be satisfied and which way it +-- should flow based on the relative strength of the variables related, +-- and record that decision. +-- +function BinaryConstraint:chooseMethod (mark) + if self.v1.mark == mark then + self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE + end + if self.v2.mark == mark then + self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE + end + if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then + self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE + else + self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD + end +end + +-- +-- Add this constraint to the constraint graph +-- +function BinaryConstraint:addToGraph () + self.v1:addConstraint(self) + self.v2:addConstraint(self) + self.direction = Direction.NONE +end + +-- +-- Answer true if this constraint is satisfied in the current solution. +-- +function BinaryConstraint:isSatisfied () + return self.direction ~= Direction.NONE +end + +-- +-- Mark the input variable with the given mark. +-- +function BinaryConstraint:markInputs (mark) + self:input().mark = mark +end + +-- +-- Returns the current input variable +-- +function BinaryConstraint:input () + return (self.direction == Direction.FORWARD) and self.v1 or self.v2 +end + +-- +-- Returns the current output variable +-- +function BinaryConstraint:output () + return (self.direction == Direction.FORWARD) and self.v2 or self.v1 +end + +-- +-- Calculate the walkabout strength, the stay flag, and, if it is +-- 'stay', the value for the current output of this +-- constraint. Assume this constraint is satisfied. +-- +function BinaryConstraint:recalculate () + local ihn = self:input() + local out = self:output() + out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength); + out.stay = ihn.stay + if out.stay then self:execute() end +end + +-- +-- Record the fact that self constraint is unsatisfied. +-- +function BinaryConstraint:markUnsatisfied () + self.direction = Direction.NONE +end + +function BinaryConstraint:inputsKnown (mark) + local i = self:input() + return i.mark == mark or i.stay or i.determinedBy == nil +end + +function BinaryConstraint:removeFromGraph () + if (self.v1 ~= nil) then self.v1:removeConstraint(self) end + if (self.v2 ~= nil) then self.v2:removeConstraint(self) end + self.direction = Direction.NONE +end + +-- +-- S c a l e C o n s t r a i n t +-- + +-- +-- Relates two variables by the linear scaling relationship: "v2 = +-- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain +-- this relationship but the scale factor and offset are considered +-- read-only. +-- + +local ScaleConstraint = class (BinaryConstraint) + +function ScaleConstraint:constructor(src, scale, offset, dest, strength) + self.direction = Direction.NONE + self.scale = scale + self.offset = offset + ScaleConstraint.super.constructor(self, src, dest, strength) +end + + +-- +-- Adds this constraint to the constraint graph. +-- +function ScaleConstraint:addToGraph () + ScaleConstraint.super.addToGraph(self) + self.scale:addConstraint(self) + self.offset:addConstraint(self) +end + +function ScaleConstraint:removeFromGraph () + ScaleConstraint.super.removeFromGraph(self) + if (self.scale ~= nil) then self.scale:removeConstraint(self) end + if (self.offset ~= nil) then self.offset:removeConstraint(self) end +end + +function ScaleConstraint:markInputs (mark) + ScaleConstraint.super.markInputs(self, mark); + self.offset.mark = mark + self.scale.mark = mark +end + +-- +-- Enforce this constraint. Assume that it is satisfied. +-- +function ScaleConstraint:execute () + if self.direction == Direction.FORWARD then + self.v2.value = self.v1.value * self.scale.value + self.offset.value + else + self.v1.value = (self.v2.value - self.offset.value) / self.scale.value + end +end + +-- +-- Calculate the walkabout strength, the stay flag, and, if it is +-- 'stay', the value for the current output of this constraint. Assume +-- this constraint is satisfied. +-- +function ScaleConstraint:recalculate () + local ihn = self:input() + local out = self:output() + out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength) + out.stay = ihn.stay and self.scale.stay and self.offset.stay + if out.stay then self:execute() end +end + +-- +-- E q u a l i t y C o n s t r a i n t +-- + +-- +-- Constrains two variables to have the same value. +-- + +local EqualityConstraint = class (BinaryConstraint) + +function EqualityConstraint:constructor(var1, var2, strength) + EqualityConstraint.super.constructor(self, var1, var2, strength) +end + + +-- +-- Enforce this constraint. Assume that it is satisfied. +-- +function EqualityConstraint:execute () + self:output().value = self:input().value +end + +-- +-- V a r i a b l e +-- + +-- +-- A constrained variable. In addition to its value, it maintain the +-- structure of the constraint graph, the current dataflow graph, and +-- various parameters of interest to the DeltaBlue incremental +-- constraint solver. +-- +local Variable = class () + +function Variable:constructor(name, initialValue) + self.value = initialValue or 0 + self.constraints = OrderedCollection.new() + self.determinedBy = nil + self.mark = 0 + self.walkStrength = Strength.WEAKEST + self.stay = true + self.name = name +end + +-- +-- Add the given constraint to the set of all constraints that refer +-- this variable. +-- +function Variable:addConstraint (c) + self.constraints:add(c) +end + +-- +-- Removes all traces of c from this variable. +-- +function Variable:removeConstraint (c) + self.constraints:remove(c) + if self.determinedBy == c then + self.determinedBy = nil + end +end + +-- +-- P l a n n e r +-- + +-- +-- The DeltaBlue planner +-- +local Planner = class() +function Planner:constructor() + self.currentMark = 0 +end + +-- +-- Attempt to satisfy the given constraint and, if successful, +-- incrementally update the dataflow graph. Details: If satifying +-- the constraint is successful, it may override a weaker constraint +-- on its output. The algorithm attempts to resatisfy that +-- constraint using some other method. This process is repeated +-- until either a) it reaches a variable that was not previously +-- determined by any constraint or b) it reaches a constraint that +-- is too weak to be satisfied using any of its methods. The +-- variables of constraints that have been processed are marked with +-- a unique mark value so that we know where we've been. This allows +-- the algorithm to avoid getting into an infinite loop even if the +-- constraint graph has an inadvertent cycle. +-- +function Planner:incrementalAdd (c) + local mark = self:newMark() + local overridden = c:satisfy(mark) + while overridden ~= nil do + overridden = overridden:satisfy(mark) + end +end + +-- +-- Entry point for retracting a constraint. Remove the given +-- constraint and incrementally update the dataflow graph. +-- Details: Retracting the given constraint may allow some currently +-- unsatisfiable downstream constraint to be satisfied. We therefore collect +-- a list of unsatisfied downstream constraints and attempt to +-- satisfy each one in turn. This list is traversed by constraint +-- strength, strongest first, as a heuristic for avoiding +-- unnecessarily adding and then overriding weak constraints. +-- Assume: c is satisfied. +-- +function Planner:incrementalRemove (c) + local out = c:output() + c:markUnsatisfied() + c:removeFromGraph() + local unsatisfied = self:removePropagateFrom(out) + local strength = Strength.REQUIRED + repeat + for i = 1, unsatisfied:size() do + local u = unsatisfied:at(i) + if u.strength == strength then + self:incrementalAdd(u) + end + end + strength = strength:nextWeaker() + until strength == Strength.WEAKEST +end + +-- +-- Select a previously unused mark value. +-- +function Planner:newMark () + self.currentMark = self.currentMark + 1 + return self.currentMark +end + +-- +-- Extract a plan for resatisfaction starting from the given source +-- constraints, usually a set of input constraints. This method +-- assumes that stay optimization is desired; the plan will contain +-- only constraints whose output variables are not stay. Constraints +-- that do no computation, such as stay and edit constraints, are +-- not included in the plan. +-- Details: The outputs of a constraint are marked when it is added +-- to the plan under construction. A constraint may be appended to +-- the plan when all its input variables are known. A variable is +-- known if either a) the variable is marked (indicating that has +-- been computed by a constraint appearing earlier in the plan), b) +-- the variable is 'stay' (i.e. it is a constant at plan execution +-- time), or c) the variable is not determined by any +-- constraint. The last provision is for past states of history +-- variables, which are not stay but which are also not computed by +-- any constraint. +-- Assume: sources are all satisfied. +-- +local Plan -- FORWARD DECLARATION +function Planner:makePlan (sources) + local mark = self:newMark() + local plan = Plan.new() + local todo = sources + while todo:size() > 0 do + local c = todo:removeFirst() + if c:output().mark ~= mark and c:inputsKnown(mark) then + plan:addConstraint(c) + c:output().mark = mark + self:addConstraintsConsumingTo(c:output(), todo) + end + end + return plan +end + +-- +-- Extract a plan for resatisfying starting from the output of the +-- given constraints, usually a set of input constraints. +-- +function Planner:extractPlanFromConstraints (constraints) + local sources = OrderedCollection.new() + for i = 1, constraints:size() do + local c = constraints:at(i) + if c:isInput() and c:isSatisfied() then + -- not in plan already and eligible for inclusion + sources:add(c) + end + end + return self:makePlan(sources) +end + +-- +-- Recompute the walkabout strengths and stay flags of all variables +-- downstream of the given constraint and recompute the actual +-- values of all variables whose stay flag is true. If a cycle is +-- detected, remove the given constraint and answer +-- false. Otherwise, answer true. +-- Details: Cycles are detected when a marked variable is +-- encountered downstream of the given constraint. The sender is +-- assumed to have marked the inputs of the given constraint with +-- the given mark. Thus, encountering a marked node downstream of +-- the output constraint means that there is a path from the +-- constraint's output to one of its inputs. +-- +function Planner:addPropagate (c, mark) + local todo = OrderedCollection.new() + todo:add(c) + while todo:size() > 0 do + local d = todo:removeFirst() + if d:output().mark == mark then + self:incrementalRemove(c) + return false + end + d:recalculate() + self:addConstraintsConsumingTo(d:output(), todo) + end + return true +end + + +-- +-- Update the walkabout strengths and stay flags of all variables +-- downstream of the given constraint. Answer a collection of +-- unsatisfied constraints sorted in order of decreasing strength. +-- +function Planner:removePropagateFrom (out) + out.determinedBy = nil + out.walkStrength = Strength.WEAKEST + out.stay = true + local unsatisfied = OrderedCollection.new() + local todo = OrderedCollection.new() + todo:add(out) + while todo:size() > 0 do + local v = todo:removeFirst() + for i = 1, v.constraints:size() do + local c = v.constraints:at(i) + if not c:isSatisfied() then unsatisfied:add(c) end + end + local determining = v.determinedBy + for i = 1, v.constraints:size() do + local next = v.constraints:at(i); + if next ~= determining and next:isSatisfied() then + next:recalculate() + todo:add(next:output()) + end + end + end + return unsatisfied +end + +function Planner:addConstraintsConsumingTo (v, coll) + local determining = v.determinedBy + local cc = v.constraints + for i = 1, cc:size() do + local c = cc:at(i) + if c ~= determining and c:isSatisfied() then + coll:add(c) + end + end +end + +-- +-- P l a n +-- + +-- +-- A Plan is an ordered list of constraints to be executed in sequence +-- to resatisfy all currently satisfiable constraints in the face of +-- one or more changing inputs. +-- +Plan = class() +function Plan:constructor() + self.v = OrderedCollection.new() +end + +function Plan:addConstraint (c) + self.v:add(c) +end + +function Plan:size () + return self.v:size() +end + +function Plan:constraintAt (index) + return self.v:at(index) +end + +function Plan:execute () + for i = 1, self:size() do + local c = self:constraintAt(i) + c:execute() + end +end + +-- +-- M a i n +-- + +-- +-- This is the standard DeltaBlue benchmark. A long chain of equality +-- constraints is constructed with a stay constraint on one end. An +-- edit constraint is then added to the opposite end and the time is +-- measured for adding and removing this constraint, and extracting +-- and executing a constraint satisfaction plan. There are two cases. +-- In case 1, the added constraint is stronger than the stay +-- constraint and values must propagate down the entire length of the +-- chain. In case 2, the added constraint is weaker than the stay +-- constraint so it cannot be accomodated. The cost in this case is, +-- of course, very low. Typical situations lie somewhere between these +-- two extremes. +-- +local function chainTest(n) + planner = Planner.new() + local prev = nil + local first = nil + local last = nil + + -- Build chain of n equality constraints + for i = 0, n do + local name = "v" .. i; + local v = Variable.new(name) + if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end + if i == 0 then first = v end + if i == n then last = v end + prev = v + end + + StayConstraint.new(last, Strength.STRONG_DEFAULT) + local edit = EditConstraint.new(first, Strength.PREFERRED) + local edits = OrderedCollection.new() + edits:add(edit) + local plan = planner:extractPlanFromConstraints(edits) + for i = 0, 99 do + first.value = i + plan:execute() + if last.value ~= i then + alert("Chain test failed.") + end + end +end + +local function change(v, newValue) + local edit = EditConstraint.new(v, Strength.PREFERRED) + local edits = OrderedCollection.new() + edits:add(edit) + local plan = planner:extractPlanFromConstraints(edits) + for i = 1, 10 do + v.value = newValue + plan:execute() + end + edit:destroyConstraint() +end + +-- +-- This test constructs a two sets of variables related to each +-- other by a simple linear transformation (scale and offset). The +-- time is measured to change a variable on either side of the +-- mapping and to change the scale and offset factors. +-- +local function projectionTest(n) + planner = Planner.new(); + local scale = Variable.new("scale", 10); + local offset = Variable.new("offset", 1000); + local src = nil + local dst = nil; + + local dests = OrderedCollection.new(); + for i = 0, n - 1 do + src = Variable.new("src" .. i, i); + dst = Variable.new("dst" .. i, i); + dests:add(dst); + StayConstraint.new(src, Strength.NORMAL); + ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED); + end + + change(src, 17) + if dst.value ~= 1170 then alert("Projection 1 failed") end + change(dst, 1050) + if src.value ~= 5 then alert("Projection 2 failed") end + change(scale, 5) + for i = 0, n - 2 do + if dests:at(i + 1).value ~= i * 5 + 1000 then + alert("Projection 3 failed") + end + end + change(offset, 2000) + for i = 0, n - 2 do + if dests:at(i + 1).value ~= i * 5 + 2000 then + alert("Projection 4 failed") + end + end +end + +function test() + local t0 = os.clock() + chainTest(1000); + projectionTest(1000); + local t1 = os.clock() + return t1-t0 +end + +bench.runCode(test, "deltablue") diff --git a/bench/tests/life.lua b/bench/tests/life.lua new file mode 100644 index 0000000..51586ad --- /dev/null +++ b/bench/tests/life.lua @@ -0,0 +1,122 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + -- life.lua + -- original by Dave Bollinger posted to lua-l + -- modified to use ANSI terminal escape sequences + -- modified to use for instead of while + + -- local write=io.write + + ALIVE="O" DEAD="-" + + --function delay() -- NOTE: SYSTEM-DEPENDENT, adjust as necessary + -- for i=1,10000 do end + -- local i=os.clock()+1 while(os.clock() 0 do + local xm1,x,xp1,xi=self.w-1,self.w,1,self.w + while xi > 0 do + local sum = self[ym1][xm1] + self[ym1][x] + self[ym1][xp1] + + self[y][xm1] + self[y][xp1] + + self[yp1][xm1] + self[yp1][x] + self[yp1][xp1] + next[y][x] = ((sum==2) and self[y][x]) or ((sum==3) and 1) or 0 + xm1,x,xp1,xi = x,xp1,xp1+1,xi-1 + end + ym1,y,yp1,yi = y,yp1,yp1+1,yi-1 + end + end + + -- output the array to screen + --function _CELLS:draw() + -- local out="" -- accumulate to reduce flicker + -- for y=1,self.h do + -- for x=1,self.w do + -- out=out..(((self[y][x]>0) and ALIVE) or DEAD) + -- end + -- out=out.."\n" + -- end + -- write(out) + --end + + -- constructor + function CELLS(w,h) + local c = ARRAY2D(w,h) + c.spawn = _CELLS.spawn + c.evolve = _CELLS.evolve + c.draw = _CELLS.draw + return c + end + + -- + -- shapes suitable for use with spawn() above + -- + HEART = { 1,0,1,1,0,1,1,1,1; w=3,h=3 } + GLIDER = { 0,0,1,1,0,1,0,1,1; w=3,h=3 } + EXPLODE = { 0,1,0,1,1,1,1,0,1,0,1,0; w=3,h=4 } + FISH = { 0,1,1,1,1,1,0,0,0,1,0,0,0,0,1,1,0,0,1,0; w=5,h=4 } + BUTTERFLY = { 1,0,0,0,1,0,1,1,1,0,1,0,0,0,1,1,0,1,0,1,1,0,0,0,1; w=5,h=5 } + + -- the main routine + function LIFE(w,h) + -- create two arrays + local thisgen = CELLS(w,h) + local nextgen = CELLS(w,h) + + -- create some life + -- about 1000 generations of fun, then a glider steady-state + thisgen:spawn(GLIDER,5,4) + thisgen:spawn(EXPLODE,25,10) + thisgen:spawn(FISH,4,12) + + -- run until break + local gen=1 + -- write("\027[2J") -- ANSI clear screen + while 1 do + thisgen:evolve(nextgen) + thisgen,nextgen = nextgen,thisgen + --write("\027[H") -- ANSI home cursor + --thisgen:draw() + --write("Life - generation ",gen,"\n") + gen=gen+1 + if gen>1000 then break end + --delay() -- no delay + end + end + + + local ts0 = os.clock() + LIFE(40,20) + local ts1 = os.clock() + + return ts1 - ts0 +end + +bench.runCode(test, "life") \ No newline at end of file diff --git a/bench/tests/qsort.lua b/bench/tests/qsort.lua new file mode 100644 index 0000000..9903717 --- /dev/null +++ b/bench/tests/qsort.lua @@ -0,0 +1,79 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + -- two implementations of a sort function + -- this is an example only. Lua has now a built-in function "sort" + + -- extracted from Programming Pearls, page 110 + function qsort(x,l,u,f) + if ly end) + --show("after reverse selection sort",x) + qsort(x,1,n,function (x,y) return x 0 then + local i = item + item + depth = depth - 1 + local left, right = BottomUpTree(i-1, depth), BottomUpTree(i, depth) + return { item, left, right } + else + return { item } + end +end + +local function ItemCheck(tree) + if tree[2] then + return tree[1] + ItemCheck(tree[2]) - ItemCheck(tree[3]) + else + return tree[1] + end +end + +local N = tonumber(arg and arg[1]) or 10 + +local mindepth = 4 +local maxdepth = mindepth + 2 +if maxdepth < N then maxdepth = N end + +do + local stretchdepth = maxdepth + 1 + local stretchtree = BottomUpTree(0, stretchdepth) + print(string.format("stretch tree of depth %d\t check: %d\n", + stretchdepth, ItemCheck(stretchtree))) +end + +local longlivedtree = BottomUpTree(0, maxdepth) + +for depth=mindepth,maxdepth,2 do + local iterations = 2 ^ (maxdepth - depth + mindepth) + local check = 0 + for i=1,iterations do + check = check + ItemCheck(BottomUpTree(1, depth)) + + ItemCheck(BottomUpTree(-1, depth)) + end + print(string.format("%d\t trees of depth %d\t check: %d\n", + iterations*2, depth, check)) +end + +print(string.format("long lived tree of depth %d\t check: %d\n", + maxdepth, ItemCheck(longlivedtree))) + +end + +bench.runCode(test, "binary-trees") diff --git a/bench/tests/shootout/fannkuch-redux.lua b/bench/tests/shootout/fannkuch-redux.lua new file mode 100644 index 0000000..e23f3a2 --- /dev/null +++ b/bench/tests/shootout/fannkuch-redux.lua @@ -0,0 +1,78 @@ +--[[ +MIT License + +Copyright (c) 2017 Gabriel de Quadros Ligneul + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] +-- The Computer Language Benchmarks Game +-- http://benchmarksgame.alioth.debian.org/ +-- contributed by Mike Pall + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local function fannkuch(n) + local p, q, s, sign, maxflips, sum = {}, {}, {}, 1, 0, 0 + for i=1,n do p[i] = i; q[i] = i; s[i] = i end + repeat + -- Copy and flip. + local q1 = p[1] -- Cache 1st element. + if q1 ~= 1 then + for i=2,n do q[i] = p[i] end -- Work on a copy. + local flips = 1 + repeat + local qq = q[q1] + if qq == 1 then -- ... until 1st element is 1. + sum = sum + sign*flips + if flips > maxflips then maxflips = flips end -- New maximum? + break + end + q[q1] = q1 + if q1 >= 4 then + local i, j = 2, q1 - 1 + repeat q[i], q[j] = q[j], q[i]; i = i + 1; j = j - 1; until i >= j + end + q1 = qq; flips = flips + 1 + until false + end + if sign == 1 then + p[2], p[1] = p[1], p[2]; sign = -1 -- Rotate 1<-2. + else + p[2], p[3] = p[3], p[2]; sign = 1 -- Rotate 1<-2 and 1<-2<-3. + for i=3,n do + local sx = s[i] + if sx ~= 1 then s[i] = sx-1; break end + if i == n then return sum, maxflips end -- Out of permutations. + s[i] = i + -- Rotate 1<-...<-i+1. + local t = p[1]; for j=1,i do p[j] = p[j+1] end; p[i+1] = t + end + end + until false +end + +local n = tonumber(arg and arg[1]) or 8 +local sum, flips = fannkuch(n) +print(sum, "\nPfannkuchen(", n, ") = ", flips, "\n") + +end + +bench.runCode(test, "fannkuchen-redux") diff --git a/bench/tests/shootout/fixpoint-fact.lua b/bench/tests/shootout/fixpoint-fact.lua new file mode 100644 index 0000000..8e60e92 --- /dev/null +++ b/bench/tests/shootout/fixpoint-fact.lua @@ -0,0 +1,56 @@ +--[[ +MIT License + +Copyright (c) 2017 Gabriel de Quadros Ligneul + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + +-- fixed-point operator +local Z = function (le) + local a = function (f) + return le(function (x) return f(f)(x) end) + end + return a(a) + end + + +-- non-recursive factorial + +local F = function (f) + return function (n) + if n == 0 then return 1 + else return n*f(n-1) end + end + end + +local fat = Z(F) + +local s = 0 +for i = 1, (arg and arg[1]) or 1000 do s = s + fat(i) end +--print(s) + + +end + +bench.runCode(test, "fixpoint-fact") diff --git a/bench/tests/shootout/heapsort.lua b/bench/tests/shootout/heapsort.lua new file mode 100644 index 0000000..fe85859 --- /dev/null +++ b/bench/tests/shootout/heapsort.lua @@ -0,0 +1,79 @@ +--[[ +MIT License + +Copyright (c) 2017 Gabriel de Quadros Ligneul + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + +local random, floor = math.random, math.floor +floor = math.ifloor or floor + +function heapsort(n, ra) + local j, i, rra + local l = floor(n/2) + 1 + -- local l = (n//2) + 1 + local ir = n; + while 1 do + if l > 1 then + l = l - 1 + rra = ra[l] + else + rra = ra[ir] + ra[ir] = ra[1] + ir = ir - 1 + if (ir == 1) then + ra[1] = rra + return + end + end + i = l + j = l * 2 + while j <= ir do + if (j < ir) and (ra[j] < ra[j+1]) then + j = j + 1 + end + if rra < ra[j] then + ra[i] = ra[j] + i = j + j = j + i + else + j = ir + 1 + end + end + ra[i] = rra + end +end + +local Num = tonumber((arg and arg[1])) or 4 +for i=1,Num do + local N = tonumber((arg and arg[2])) or 10000 + local a = {} + for i=1,N do a[i] = random() end + heapsort(N, a) + for i=1,N-1 do assert(a[i] <= a[i+1]) end +end + +end + +bench.runCode(test, "heapsort") diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua new file mode 100644 index 0000000..4be7850 --- /dev/null +++ b/bench/tests/shootout/mandel.lua @@ -0,0 +1,97 @@ +--[[ +MIT License + +Copyright (c) 2017 Gabriel de Quadros Ligneul + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local Complex={type="package"} + +local function complex(x,y) + return setmetatable({ re=x, im=y }, Complex.metatable) +end + +function Complex.conj(x,y) + return complex(x.re,-x.im) +end + +function Complex.norm2(x) + local n=Complex.mul(x,Complex.conj(x)) + return n.re +end + +function Complex.abs(x) + return sqrt(Complex.norm2(x)) +end + +function Complex.add(x,y) + return complex(x.re+y.re,x.im+y.im) +end + +function Complex.mul(x,y) + return complex(x.re*y.re-x.im*y.im,x.re*y.im+x.im*y.re) +end + +Complex.metatable={ + __add = Complex.add, + __mul = Complex.mul, +} + +local function abs(x) + return math.sqrt(Complex.norm2(x)) +end + +xmin=-2.0 xmax=2.0 ymin=-2.0 ymax=2.0 +N=(arg and arg[1]) or 64 + +function level(x,y) + local c=complex(x,y) + local l=0 + local z=c + repeat + z=z*z+c + l=l+1 + until abs(z)>2.0 or l>255 + return l-1 +end + +dx=(xmax-xmin)/N +dy=(ymax-ymin)/N + +print("P2") +print("# mandelbrot set",xmin,xmax,ymin,ymax,N) +print(N,N,255) +local S = 0 +for i=1,N do + local x=xmin+(i-1)*dx + for j=1,N do + local y=ymin+(j-1)*dy + S = S + level(x,y) + end + -- if i % 10 == 0 then print(collectgarbage"count") end +end +print(S) + +end + +bench.runCode(test, "mandel") diff --git a/bench/tests/shootout/n-body.lua b/bench/tests/shootout/n-body.lua new file mode 100644 index 0000000..4034158 --- /dev/null +++ b/bench/tests/shootout/n-body.lua @@ -0,0 +1,131 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + --The Computer Language Benchmarks Game + -- https://salsa.debian.org/benchmarksgame-team/benchmarksgame/ + --contributed by Mike Pall + + local PI = 3.141592653589793 + local SOLAR_MASS = 4 * PI * PI + local DAYS_PER_YEAR = 365.24 + local bodies = { + { --Sun + x = 0, + y = 0, + z = 0, + vx = 0, + vy = 0, + vz = 0, + mass = SOLAR_MASS + }, + { --Jupiter + x = 4.84143144246472090e+00, + y = -1.16032004402742839e+00, + z = -1.03622044471123109e-01, + vx = 1.66007664274403694e-03 * DAYS_PER_YEAR, + vy = 7.69901118419740425e-03 * DAYS_PER_YEAR, + vz = -6.90460016972063023e-05 * DAYS_PER_YEAR, + mass = 9.54791938424326609e-04 * SOLAR_MASS + }, + { --Saturn + x = 8.34336671824457987e+00, + y = 4.12479856412430479e+00, + z = -4.03523417114321381e-01, + vx = -2.76742510726862411e-03 * DAYS_PER_YEAR, + vy = 4.99852801234917238e-03 * DAYS_PER_YEAR, + vz = 2.30417297573763929e-05 * DAYS_PER_YEAR, + mass = 2.85885980666130812e-04 * SOLAR_MASS + }, + { --Uranus + x = 1.28943695621391310e+01, + y = -1.51111514016986312e+01, + z = -2.23307578892655734e-01, + vx = 2.96460137564761618e-03 * DAYS_PER_YEAR, + vy = 2.37847173959480950e-03 * DAYS_PER_YEAR, + vz = -2.96589568540237556e-05 * DAYS_PER_YEAR, + mass = 4.36624404335156298e-05 * SOLAR_MASS + }, + { --Neptune + x = 1.53796971148509165e+01, + y = -2.59193146099879641e+01, + z = 1.79258772950371181e-01, + vx = 2.68067772490389322e-03 * DAYS_PER_YEAR, + vy = 1.62824170038242295e-03 * DAYS_PER_YEAR, + vz = -9.51592254519715870e-05 * DAYS_PER_YEAR, + mass = 5.15138902046611451e-05 * SOLAR_MASS + } + } + + local function advance(bodies, nbody, dt) + for i = 1, nbody do + local bi = bodies[i] + local bix, biy, biz, bimass = bi.x, bi.y, bi.z, bi.mass + local bivx, bivy, bivz = bi.vx, bi.vy, bi.vz + for j = i + 1, nbody do + local bj = bodies[j] + local dx, dy, dz = bix - bj.x, biy - bj.y, biz - bj.z + local distance = math.sqrt(dx*dx + dy*dy + dz*dz) + local mag = dt / (distance * distance * distance) + local bim, bjm = bimass*mag, bj.mass*mag + bivx = bivx - (dx * bjm) + bivy = bivy - (dy * bjm) + bivz = bivz - (dz * bjm) + bj.vx = bj.vx + (dx * bim) + bj.vy = bj.vy + (dy * bim) + bj.vz = bj.vz + (dz * bim) + end + bi.vx = bivx + bi.vy = bivy + bi.vz = bivz + end + for i = 1, nbody do + local bi = bodies[i] + bi.x = bi.x + (dt * bi.vx) + bi.y = bi.y + (dt * bi.vy) + bi.z = bi.z + (dt * bi.vz) + end + end + + local function energy(bodies, nbody) + local e = 0 + for i = 1, nbody do + local bi = bodies[i] + local vx, vy, vz, bim = bi.vx, bi.vy, bi.vz, bi.mass + e = e + (0.5 * bim * (vx*vx + vy*vy + vz*vz)) + for j = i + 1, nbody do + local bj = bodies[j] + local dx, dy, dz = bi.x - bj.x, bi.y - bj.y, bi.z - bj.z + local distance = math.sqrt(dx*dx + dy*dy + dz*dz) + e = e - ((bim * bj.mass) / distance) + end + end + return e + end + + local function offsetMomentum(b, nbody) + local px, py, pz = 0, 0, 0 + for i = 1, nbody do + local bi = b[i] + local bim = bi.mass + px = px + (bi.vx * bim) + py = py + (bi.vy * bim) + pz = pz + (bi.vz * bim) + end + b[1].vx = -px / SOLAR_MASS + b[1].vy = -py / SOLAR_MASS + b[1].vz = -pz / SOLAR_MASS + end + + local N = 20000 + local nbody = #bodies + + local ts0 = os.clock() + offsetMomentum(bodies, nbody) + for i = 1, N do advance(bodies, nbody, 0.01) end + local ts1 = os.clock() + + return ts1 - ts0 +end + +bench.runCode(test, "n-body") \ No newline at end of file diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua new file mode 100644 index 0000000..de962a7 --- /dev/null +++ b/bench/tests/shootout/qt.lua @@ -0,0 +1,334 @@ +--[[ +MIT License + +Copyright (c) 2017 Gabriel de Quadros Ligneul + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] +-- Julia sets via interval cell-mapping (quadtree version) + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +--require"julia" local f=f + +local io=io +local root,exterior +local cx,cy +local Rxmin,Rxmax,Rymin,Rymax=-2.0,2.0,-2.0,2.0 +local white=1.0 +local black=0.0 +local gray=0.5 +local N=0 +local nE=0 +local E={} +local write=print + +local function output(a1,a2,a3,a4,a5,a6) + --[[write( + a1 or ""," ", + a2 or ""," ", + a3 or ""," ", + a4 or ""," ", + a5 or ""," ", + a6 or ""," \n")]] +end + +local function imul(xmin,xmax,ymin,ymax) + local mm=xmin*ymin + local mM=xmin*ymax + local Mm=xmax*ymin + local MM=xmax*ymax + local m,M=mm,mm + if m>mM then m=mM elseif MMm then m=Mm elseif MMM then m=MM elseif M4.0 +end + +local function inside(xmin,xmax,ymin,ymax) + return xmin^2+ymin^2<=4.0 and xmin^2+ymax^2<=4.0 and + xmax^2+ymin^2<=4.0 and xmax^2+ymax^2<=4.0 +end + +local function newcell() + return {nil,nil,nil,nil,color=gray} +end + +local function addedge(a,b) + nE=nE+1 + E[nE]=b +end + +local function refine(q) + if q.color==gray then + if q[1]==nil then + q[1]=newcell() + q[2]=newcell() + q[3]=newcell() + q[4]=newcell() + else + refine(q[1]) + refine(q[2]) + refine(q[3]) + refine(q[4]) + end + end +end + +local function clip(q,xmin,xmax,ymin,ymax,o,oxmin,oxmax,oymin,oymax) + local ixmin,ixmax,iymin,iymax + if xmin>oxmin then ixmin=xmin else ixmin=oxmin end + if xmax=ixmax then return end + if ymin>oymin then iymin=ymin else iymin=oymin end + if ymax N then -- all queens have been placed? + printsolution(a) + else -- try to place n-th queen + for c = 1, N do + if isplaceok(a, n, c) then + a[n] = c -- place n-th queen at column 'c' + addqueen(a, n + 1) + end + end + end +end + + +-- run the program +addqueen({}, 1) + +end + +bench.runCode(test, "queen") diff --git a/bench/tests/shootout/scimark.lua b/bench/tests/shootout/scimark.lua new file mode 100644 index 0000000..41d97bb --- /dev/null +++ b/bench/tests/shootout/scimark.lua @@ -0,0 +1,442 @@ +------------------------------------------------------------------------------ +-- Lua SciMark (2010-12-20). +-- +-- A literal translation of SciMark 2.0a, written in Java and C. +-- Credits go to the original authors Roldan Pozo and Bruce Miller. +-- See: http://math.nist.gov/scimark2/ +------------------------------------------------------------------------------ +-- Copyright (C) 2006-2010 Mike Pall. All rights reserved. +-- +-- Permission is hereby granted, free of charge, to any person obtaining +-- a copy of this software and associated documentation files (the +-- "Software"), to deal in the Software without restriction, including +-- without limitation the rights to use, copy, modify, merge, publish, +-- distribute, sublicense, and/or sell copies of the Software, and to +-- permit persons to whom the Software is furnished to do so, subject to +-- the following conditions: +-- +-- The above copyright notice and this permission notice shall be +-- included in all copies or substantial portions of the Software. +-- +-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +-- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +-- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +-- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +-- CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +-- TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +-- SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +-- +-- [ MIT license: http://www.opensource.org/licenses/mit-license.php ] +------------------------------------------------------------------------------ + +------------------------------------------------------------------------------ +-- Modificatin to be compatible with Lua 5.3 +------------------------------------------------------------------------------ + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +if table and table.unpack then + unpack = table.unpack +end + +------------------------------------------------------------------------------ + +local SCIMARK_VERSION = "2010-12-10" +local SCIMARK_COPYRIGHT = "Copyright (C) 2006-2010 Mike Pall" + +local MIN_TIME = 0.2 +local RANDOM_SEED = 101009 -- Must be odd. +local SIZE_SELECT = "small" + +local benchmarks = { + "FFT", "SOR", "MC", "SPARSE", "LU", + small = { + FFT = { 1024 }, + SOR = { 100 }, + MC = { }, + SPARSE = { 1000, 5000 }, + LU = { 100 }, + }, + large = { + FFT = { 1048576 }, + SOR = { 1000 }, + MC = { }, + SPARSE = { 100000, 1000000 }, + LU = { 1000 }, + }, +} + +local abs, log, sin, floor = math.abs, math.log, math.sin, math.floor +local pi, clock = math.pi, os.clock +local format = string.format + +------------------------------------------------------------------------------ +-- Select array type: Lua tables or native (FFI) arrays +------------------------------------------------------------------------------ + +local darray, iarray + +local function array_init() + if jit and jit.status and jit.status() then + local ok, ffi = pcall(require, "ffi") + if ok then + darray = ffi.typeof("double[?]") + iarray = ffi.typeof("int[?]") + return + end + end + function darray(n) return {} end + iarray = darray +end + +------------------------------------------------------------------------------ +-- This is a Lagged Fibonacci Pseudo-random Number Generator with +-- j, k, M = 5, 17, 31. Pretty weak, but same as C/Java SciMark. +------------------------------------------------------------------------------ + +local rand, rand_init + +if jit and jit.status and jit.status() then + -- LJ2 has bit operations and zero-based arrays (internally). + local bit = require("bit") + local band, sar = bit.band, bit.arshift + function rand_init(seed) + local Rm, Rj, Ri = iarray(17), 16, 11 + for i=0,16 do Rm[i] = 0 end + for i=16,0,-1 do + seed = band(seed*9069, 0x7fffffff) + Rm[i] = seed + end + function rand() + local i = band(Ri+1, sar(Ri-16, 31)) + local j = band(Rj+1, sar(Rj-16, 31)) + Ri, Rj = i, j + local k = band(Rm[i] - Rm[j], 0x7fffffff) + Rm[j] = k + return k * (1.0/2147483647.0) + end + end +else + -- Better for standard Lua with one-based arrays and without bit operations. + function rand_init(seed) + local Rm, Rj = {}, 1 + for i=1,17 do Rm[i] = 0 end + for i=17,1,-1 do + seed = (seed*9069) % (2^31) + Rm[i] = seed + end + function rand() + local j, m = Rj, Rm + local h = j - 5 + if h < 1 then h = h + 17 end + local k = m[h] - m[j] + if k < 0 then k = k + 2147483647 end + m[j] = k + if j < 17 then Rj = j + 1 else Rj = 1 end + return k * (1.0/2147483647.0) + end + end +end + +local function random_vector(n) + local v = darray(n+1) + for x=1,n do v[x] = rand() end + return v +end + +local function random_matrix(m, n) + local a = {} + for y=1,m do + local v = darray(n+1) + a[y] = v + for x=1,n do v[x] = rand() end + end + return a +end + +------------------------------------------------------------------------------ +-- FFT: Fast Fourier Transform. +------------------------------------------------------------------------------ + +local function fft_bitreverse(v, n) + local j = 0 + for i=0,2*n-4,2 do + if i < j then + v[i+1], v[i+2], v[j+1], v[j+2] = v[j+1], v[j+2], v[i+1], v[i+2] + end + local k = n + while k <= j do j = j - k; k = k / 2 end + j = j + k + end +end + +local function fft_transform(v, n, dir) + if n <= 1 then return end + fft_bitreverse(v, n) + local dual = 1 + repeat + local dual2 = 2*dual + for i=1,2*n-1,2*dual2 do + local j = i+dual2 + local ir, ii = v[i], v[i+1] + local jr, ji = v[j], v[j+1] + v[j], v[j+1] = ir - jr, ii - ji + v[i], v[i+1] = ir + jr, ii + ji + end + local theta = dir * pi / dual + local s, s2 = sin(theta), 2.0 * sin(theta * 0.5)^2 + local wr, wi = 1.0, 0.0 + for a=3,dual2-1,2 do + wr, wi = wr - s*wi - s2*wr, wi + s*wr - s2*wi + for i=a,a+2*(n-dual2),2*dual2 do + local j = i+dual2 + local jr, ji = v[j], v[j+1] + local dr, di = wr*jr - wi*ji, wr*ji + wi*jr + local ir, ii = v[i], v[i+1] + v[j], v[j+1] = ir - dr, ii - di + v[i], v[i+1] = ir + dr, ii + di + end + end + dual = dual2 + until dual >= n +end + +function benchmarks.FFT(n) + local l2n = log(n)/log(2) + if l2n % 1 ~= 0 then + io.stderr:write("Error: FFT data length is not a power of 2\n") + os.exit(1) + end + local v = random_vector(n*2) + return function(cycles) + local norm = 1.0 / n + for p=1,cycles do + fft_transform(v, n, -1) + fft_transform(v, n, 1) + for i=1,n*2 do v[i] = v[i] * norm end + end + return ((5*n-2)*l2n + 2*(n+1)) * cycles + end +end + +------------------------------------------------------------------------------ +-- SOR: Jacobi Successive Over-Relaxation. +------------------------------------------------------------------------------ + +local function sor_run(mat, m, n, cycles, omega) + local om4, om1 = omega*0.25, 1.0-omega + m = m - 1 + n = n - 1 + for i=1,cycles do + for y=2,m do + local v, vp, vn = mat[y], mat[y-1], mat[y+1] + for x=2,n do + v[x] = om4*((vp[x]+vn[x])+(v[x-1]+v[x+1])) + om1*v[x] + end + end + end +end + +function benchmarks.SOR(n) + local mat = random_matrix(n, n) + return function(cycles) + sor_run(mat, n, n, cycles, 1.25) + return (n-1)*(n-1)*cycles*6 + end +end + +------------------------------------------------------------------------------ +-- MC: Monte Carlo Integration. +------------------------------------------------------------------------------ + +local function mc_integrate(cycles) + local under_curve = 0 + local rand = rand + for i=1,cycles do + local x = rand() + local y = rand() + if x*x + y*y <= 1.0 then under_curve = under_curve + 1 end + end + return (under_curve/cycles) * 4 +end + +function benchmarks.MC() + return function(cycles) + local res = mc_integrate(cycles) + assert(math.sqrt(cycles)*math.abs(res-math.pi) < 5.0, "bad MC result") + return cycles * 4 -- Way off, but same as SciMark in C/Java. + end +end + +------------------------------------------------------------------------------ +-- Sparse Matrix Multiplication. +------------------------------------------------------------------------------ + +local function sparse_mult(n, cycles, vy, val, row, col, vx) + for p=1,cycles do + for r=1,n do + local sum = 0 + for i=row[r],row[r+1]-1 do sum = sum + vx[col[i]] * val[i] end + vy[r] = sum + end + end +end + +function benchmarks.SPARSE(n, nz) + local nr = floor(nz/n) + local anz = nr*n + local vx = random_vector(n) + local val = random_vector(anz) + local vy, col, row = darray(n+1), iarray(nz+1), iarray(n+2) + row[1] = 1 + for r=1,n do + local step = floor(r/nr) + if step < 1 then step = 1 end + local rr = row[r] + row[r+1] = rr+nr + for i=0,nr-1 do col[rr+i] = 1+i*step end + end + return function(cycles) + sparse_mult(n, cycles, vy, val, row, col, vx) + return anz*cycles*2 + end +end + +------------------------------------------------------------------------------ +-- LU: Dense Matrix Factorization. +------------------------------------------------------------------------------ + +local function lu_factor(a, pivot, m, n) + local min_m_n = m < n and m or n + for j=1,min_m_n do + local jp, t = j, abs(a[j][j]) + for i=j+1,m do + local ab = abs(a[i][j]) + if ab > t then + jp = i + t = ab + end + end + pivot[j] = jp + if a[jp][j] == 0 then error("zero pivot") end + if jp ~= j then a[j], a[jp] = a[jp], a[j] end + if j < m then + local recp = 1.0 / a[j][j] + for k=j+1,m do + local v = a[k] + v[j] = v[j] * recp + end + end + if j < min_m_n then + for i=j+1,m do + local vi, vj = a[i], a[j] + local eij = vi[j] + for k=j+1,n do vi[k] = vi[k] - eij * vj[k] end + end + end + end +end + +local function matrix_alloc(m, n) + local a = {} + for y=1,m do a[y] = darray(n+1) end + return a +end + +local function matrix_copy(dst, src, m, n) + for y=1,m do + local vd, vs = dst[y], src[y] + for x=1,n do vd[x] = vs[x] end + end +end + +function benchmarks.LU(n) + local mat = random_matrix(n, n) + local tmp = matrix_alloc(n, n) + local pivot = iarray(n+1) + return function(cycles) + for i=1,cycles do + matrix_copy(tmp, mat, n, n) + lu_factor(tmp, pivot, n, n) + end + return 2.0/3.0*n*n*n*cycles + end +end + +------------------------------------------------------------------------------ +-- Main program. +------------------------------------------------------------------------------ + +local function printf(...) + print(format(...)) +end + +local function fmtparams(p1, p2) + if p2 then return format("[%d, %d]", p1, p2) + elseif p1 then return format("[%d]", p1) end + return "" +end + +local function measure(min_time, name, ...) + array_init() + rand_init(RANDOM_SEED) + local run = benchmarks[name](...) + --[[local cycles = 1 + repeat + local tm = clock() + local flops = run(cycles, ...) + tm = clock() - tm + if tm >= min_time then + local res = flops / tm * 1.0e-6 + local p1, p2 = ... + printf("%-7s %8.2f %s\n", name, res, fmtparams(...)) + return res + end + cycles = cycles * 2 + until false]] + + run(10, ...) + return 10 +end + +printf("Lua SciMark %s based on SciMark 2.0a. %s.\n\n", + SCIMARK_VERSION, SCIMARK_COPYRIGHT) + +while arg and arg[1] do + local a = table.remove(arg, 1) + if a == "-noffi" then + package.preload.ffi = nil + elseif a == "-small" then + SIZE_SELECT = "small" + elseif a == "-large" then + SIZE_SELECT = "large" + elseif benchmarks[a] then + local p = benchmarks[SIZE_SELECT][a] + measure(MIN_TIME, a, tonumber(arg[1]) or p[1], tonumber(arg[2]) or p[2]) + return + else + printf("Usage: scimark [-noffi] [-small|-large] [BENCH params...]\n\n") + printf("BENCH -small -large\n") + printf("---------------------------------------\n") + for _,name in ipairs(benchmarks) do + printf("%-7s %-13s %s\n", name, + fmtparams(unpack(benchmarks.small[name])), + fmtparams(unpack(benchmarks.large[name]))) + end + printf("\n") + os.exit(1) + end +end + +local params = benchmarks[SIZE_SELECT] +local sum = 0 +for _,name in ipairs(benchmarks) do + sum = sum + measure(MIN_TIME, name, unpack(params[name])) +end +--printf("\nSciMark %8.2f [%s problem sizes]\n", sum / #benchmarks, SIZE_SELECT) + +end + +bench.runCode(test, "scimark") diff --git a/bench/tests/shootout/spectral-norm.lua b/bench/tests/shootout/spectral-norm.lua new file mode 100644 index 0000000..6d217aa --- /dev/null +++ b/bench/tests/shootout/spectral-norm.lua @@ -0,0 +1,74 @@ +--[[ +MIT License + +Copyright (c) 2017 Gabriel de Quadros Ligneul + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +]] +-- The Computer Language Benchmarks Game +-- http://benchmarksgame.alioth.debian.org/ +-- contributed by Mike Pall + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local function A(i, j) + local ij = i+j-1 + return 1.0 / (ij * (ij-1) * 0.5 + i) +end + +local function Av(x, y, N) + for i=1,N do + local a = 0 + for j=1,N do a = a + x[j] * A(i, j) end + y[i] = a + end +end + +local function Atv(x, y, N) + for i=1,N do + local a = 0 + for j=1,N do a = a + x[j] * A(j, i) end + y[i] = a + end +end + +local function AtAv(x, y, t, N) + Av(x, t, N) + Atv(t, y, N) +end + +local N = tonumber(arg and arg[1]) or 100 +local u, v, t = {}, {}, {} +for i=1,N do u[i] = 1 end + +for i=1,10 do AtAv(u, v, t, N) AtAv(v, u, t, N) end + +local vBv, vv = 0, 0 +for i=1,N do + local ui, vi = u[i], v[i] + vBv = vBv + ui*vi + vv = vv + vi*vi +end +print(string.format("%0.9f\n", math.sqrt(vBv / vv))) + +end + +bench.runCode(test, "spectral-norm") diff --git a/bench/tests/sieve.lua b/bench/tests/sieve.lua new file mode 100644 index 0000000..718ec48 --- /dev/null +++ b/bench/tests/sieve.lua @@ -0,0 +1,44 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + -- the sieve of of Eratosthenes programmed with coroutines + -- typical usage: lua -e N=1000 sieve.lua | column + + -- generate all the numbers from 2 to n + function gen (n) + return coroutine.wrap(function () + for i=2,n do coroutine.yield(i) end + end) + end + + -- filter the numbers generated by `g', removing multiples of `p' + function filter (p, g) + return coroutine.wrap(function () + while 1 do + local n = g() + if n == nil then return end + if n % p ~= 0 then coroutine.yield(n) end + end + end) + end + + local ts0 = os.clock() + + for loops=1,100 do + N = 1000 + x = gen(N) -- generate primes up to N + while 1 do + local n = x() -- pick a number until done + if n == nil then break end + -- print(n) -- must be a prime number + x = filter(n, x) -- now remove its multiples + end + end + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "sieve") \ No newline at end of file diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua new file mode 100644 index 0000000..6d40406 --- /dev/null +++ b/bench/tests/sunspider/3d-cube.lua @@ -0,0 +1,381 @@ +-- 3D Cube Rotation +-- http://www.speich.net/computer/moztesting/3d.htm +-- Created by Simon Speich + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local Q = {} +local MTrans = {}; -- transformation matrix +local MQube = {} -- position information of qube +local I = {} -- entity matrix +local Origin = {} +local Testing = {} +local LoopTimer; + +local validation = { + [20] = 2889, + [40] = 2889, + [80] = 2889, + [160] = 2889 +}; + +local DisplArea = {} +DisplArea.Width = 300; +DisplArea.Height = 300; + +function DrawLine(From, To) + local x1 = From.V[1]; + local x2 = To.V[1]; + local y1 = From.V[2]; + local y2 = To.V[2]; + local dx = math.abs(x2 - x1); + local dy = math.abs(y2 - y1); + local x = x1; + local y = y1; + local IncX1, IncY1; + local IncX2, IncY2; + local Den; + local Num; + local NumAdd; + local NumPix; + + if (x2 >= x1) then IncX1 = 1; IncX2 = 1; + else IncX1 = -1; IncX2 = -1; end + + if (y2 >= y1) then IncY1 = 1; IncY2 = 1; + else IncY1 = -1; IncY2 = -1; end + + if (dx >= dy) then + IncX1 = 0; + IncY2 = 0; + Den = dx; + Num = dx / 2; + NumAdd = dy; + NumPix = dx; + else + IncX2 = 0; + IncY1 = 0; + Den = dy; + Num = dy / 2; + NumAdd = dx; + NumPix = dy; + end + + NumPix = math.floor(Q.LastPx + NumPix + 0.5); + + local i = Q.LastPx; + while i < NumPix do + Num = Num + NumAdd; + if (Num >= Den) then + Num = Num - Den; + x = x + IncX1; + y = y + IncY1; + end + x = x + IncX2; + y = y + IncY2; + + i = i + 1; + end + Q.LastPx = NumPix; +end + +function CalcCross(V0, V1) + local Cross = {}; + Cross[1] = V0[2]*V1[3] - V0[3]*V1[2]; + Cross[2] = V0[3]*V1[1] - V0[1]*V1[3]; + Cross[3] = V0[1]*V1[2] - V0[2]*V1[1]; + return Cross; +end + +function CalcNormal(V0, V1, V2) + local A = {}; local B = {}; + for i = 1,3 do + A[i] = V0[i] - V1[i]; + B[i] = V2[i] - V1[i]; + end + A = CalcCross(A, B); + local Length = math.sqrt(A[1]*A[1] + A[2]*A[2] + A[3]*A[3]); + for i = 1,3 do A[i] = A[i] / Length; end + A[4] = 1; + return A; +end + +function CreateP(X,Y,Z) + local result = {} + result.V = {X,Y,Z,1}; + return result +end + +-- multiplies two matrices +function MMulti(M1, M2) + local M = {{},{},{},{}}; + local i = 1; + local j = 1; + while i <= 4 do + j = 1; + while j <= 4 do + M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; j = j + 1 + end + + i = i + 1 + end + return M; +end + +-- multiplies matrix with vector +function VMulti(M, V) + local Vect = {}; + local i = 1; + while i <= 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; i = i + 1 end + return Vect; +end + +function VMulti2(M, V) + local Vect = {}; + local i = 1; + while i < 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; i = i + 1 end + return Vect; +end + +-- add to matrices +function MAdd(M1, M2) + local M = {{},{},{},{}}; + local i = 1; + local j = 1; + while i <= 4 do + j = 1; + while j <= 4 do M[i][j] = M1[i][j] + M2[i][j]; j = j + 1 end + + i = i + 1 + end + return M; +end + +function Translate(M, Dx, Dy, Dz) + local T = { + {1,0,0,Dx}, + {0,1,0,Dy}, + {0,0,1,Dz}, + {0,0,0,1} + }; + return MMulti(T, M); +end + +function RotateX(M, Phi) + local a = Phi; + a = a * math.pi / 180; + local Cos = math.cos(a); + local Sin = math.sin(a); + local R = { + {1,0,0,0}, + {0,Cos,-Sin,0}, + {0,Sin,Cos,0}, + {0,0,0,1} + }; + return MMulti(R, M); +end + +function RotateY(M, Phi) + local a = Phi; + a = a * math.pi / 180; + local Cos = math.cos(a); + local Sin = math.sin(a); + local R = { + {Cos,0,Sin,0}, + {0,1,0,0}, + {-Sin,0,Cos,0}, + {0,0,0,1} + }; + return MMulti(R, M); +end + +function RotateZ(M, Phi) + local a = Phi; + a = a * math.pi / 180; + local Cos = math.cos(a); + local Sin = math.sin(a); + local R = { + {Cos,-Sin,0,0}, + {Sin,Cos,0,0}, + {0,0,1,0}, + {0,0,0,1} + }; + return MMulti(R, M); +end + +function DrawQube() + -- calc current normals + local CurN = {}; + local i = 5; + Q.LastPx = 0; + while i > -1 do CurN[i+1] = VMulti2(MQube, Q.Normal[i+1]); i = i - 1 end + if (CurN[1][3] < 0) then + if (not Q.Line[1]) then DrawLine(Q[1], Q[2]); Q.Line[1] = true; end + if (not Q.Line[2]) then DrawLine(Q[2], Q[3]); Q.Line[2] = true; end + if (not Q.Line[3]) then DrawLine(Q[3], Q[4]); Q.Line[3] = true; end + if (not Q.Line[4]) then DrawLine(Q[4], Q[1]); Q.Line[4] = true; end + end + if (CurN[2][3] < 0) then + if (not Q.Line[3]) then DrawLine(Q[4], Q[3]); Q.Line[3] = true; end + if (not Q.Line[10]) then DrawLine(Q[3], Q[7]); Q.Line[10] = true; end + if (not Q.Line[7]) then DrawLine(Q[7], Q[8]); Q.Line[7] = true; end + if (not Q.Line[11]) then DrawLine(Q[8], Q[4]); Q.Line[11] = true; end + end + if (CurN[3][3] < 0) then + if (not Q.Line[5]) then DrawLine(Q[5], Q[6]); Q.Line[6] = true; end + if (not Q.Line[6]) then DrawLine(Q[6], Q[7]); Q.Line[6] = true; end + if (not Q.Line[7]) then DrawLine(Q[7], Q[8]); Q.Line[7] = true; end + if (not Q.Line[8]) then DrawLine(Q[8], Q[5]); Q.Line[8] = true; end + end + if (CurN[4][3] < 0) then + if (not Q.Line[5]) then DrawLine(Q[5], Q[6]); Q.Line[5] = true; end + if (not Q.Line[9]) then DrawLine(Q[6], Q[2]); Q.Line[9] = true; end + if (not Q.Line[1]) then DrawLine(Q[2], Q[1]); Q.Line[1] = true; end + if (not Q.Line[12]) then DrawLine(Q[1], Q[5]); Q.Line[12] = true; end + end + if (CurN[5][3] < 0) then + if (not Q.Line[12]) then DrawLine(Q[5], Q[1]); Q.Line[12] = true; end + if (not Q.Line[4]) then DrawLine(Q[1], Q[4]); Q.Line[4] = true; end + if (not Q.Line[11]) then DrawLine(Q[4], Q[8]); Q.Line[11] = true; end + if (not Q.Line[8]) then DrawLine(Q[8], Q[5]); Q.Line[8] = true; end + end + if (CurN[6][3] < 0) then + if (not Q.Line[9]) then DrawLine(Q[2], Q[6]); Q.Line[9] = true; end + if (not Q.Line[6]) then DrawLine(Q[6], Q[7]); Q.Line[6] = true; end + if (not Q.Line[10]) then DrawLine(Q[7], Q[3]); Q.Line[10] = true; end + if (not Q.Line[2]) then DrawLine(Q[3], Q[2]); Q.Line[2] = true; end + end + Q.Line = {false,false,false,false,false,false,false,false,false,false,false,false} + Q.LastPx = 0; +end + +function Loop() + if (Testing.LoopCount > Testing.LoopMax) then return; end + local TestingStr = tostring(Testing.LoopCount); + while (#TestingStr < 3) do TestingStr = "0" .. TestingStr; end + MTrans = Translate(I, -Q[9].V[1], -Q[9].V[2], -Q[9].V[3]); + MTrans = RotateX(MTrans, 1); + MTrans = RotateY(MTrans, 3); + MTrans = RotateZ(MTrans, 5); + MTrans = Translate(MTrans, Q[9].V[1], Q[9].V[2], Q[9].V[3]); + MQube = MMulti(MTrans, MQube); + local i = 8; + while i > -1 do + Q[i+1].V = VMulti(MTrans, Q[i+1].V); + i = i - 1 + end + DrawQube(); + Testing.LoopCount = Testing.LoopCount + 1; + Loop(); +end + +function Init(CubeSize) + -- init/reset vars + Origin.V = {150,150,20,1}; + Testing.LoopCount = 0; + Testing.LoopMax = 50; + Testing.TimeMax = 0; + Testing.TimeAvg = 0; + Testing.TimeMin = 0; + Testing.TimeTemp = 0; + Testing.TimeTotal = 0; + Testing.Init = false; + + -- transformation matrix + MTrans = { + {1,0,0,0}, + {0,1,0,0}, + {0,0,1,0}, + {0,0,0,1} + }; + + -- position information of qube + MQube = { + {1,0,0,0}, + {0,1,0,0}, + {0,0,1,0}, + {0,0,0,1} + }; + + -- entity matrix + I = { + {1,0,0,0}, + {0,1,0,0}, + {0,0,1,0}, + {0,0,0,1} + }; + + -- create qube + Q[1] = CreateP(-CubeSize,-CubeSize, CubeSize); + Q[2] = CreateP(-CubeSize, CubeSize, CubeSize); + Q[3] = CreateP( CubeSize, CubeSize, CubeSize); + Q[4] = CreateP( CubeSize,-CubeSize, CubeSize); + Q[5] = CreateP(-CubeSize,-CubeSize,-CubeSize); + Q[6] = CreateP(-CubeSize, CubeSize,-CubeSize); + Q[7] = CreateP( CubeSize, CubeSize,-CubeSize); + Q[8] = CreateP( CubeSize,-CubeSize,-CubeSize); + + -- center of gravity + Q[9] = CreateP(0, 0, 0); + + -- anti-clockwise edge check + Q.Edge = {{1,2,3},{4,5,7},{8,7,6},{5,6,2},{5,1,4},{2,6,7}}; + + -- calculate squad normals + Q.Normal = {}; + for i = 1,#Q.Edge do + Q.Normal[i] = CalcNormal(Q[Q.Edge[i][1]].V, Q[Q.Edge[i][2]].V, Q[Q.Edge[i][3]].V); + end + + -- line drawn ? + Q.Line = {false,false,false,false,false,false,false,false,false,false,false,false}; + + -- create line pixels + Q.NumPx = 9 * 2 * CubeSize; + for i = 1,Q.NumPx do CreateP(0,0,0); end + + MTrans = Translate(MTrans, Origin.V[1], Origin.V[2], Origin.V[3]); + MQube = MMulti(MTrans, MQube); + + local i = 0; + while i < 9 do + Q[i+1].V = VMulti(MTrans, Q[i+1].V); + i = i + 1 + end + DrawQube(); + Testing.Init = true; + Loop(); + + -- Perform a simple sum-based verification. + local sum = 0; + for i = 1,#Q do + local vector = Q[i].V; + for j = 1,#vector do + sum = sum + vector[j]; + end + end + if (math.floor(sum) ~= validation[CubeSize]) then + assert(false, "Error: bad vector sum for CubeSize = " .. CubeSize .. "; expected " .. validation[CubeSize] .. " but got " .. math.floor(sum)) + end +end + +local i = 20 +while i <= 160 do + Init(i); + i = i * 2 +end + +Q = nil; +MTrans = nil; +MQube = nil; +I = nil; +Origin = nil; +Testing = nil; +LoopTime = nil; +DisplArea = nil; + +end + +bench.runCode(test, "3d-cube") diff --git a/bench/tests/sunspider/3d-morph.lua b/bench/tests/sunspider/3d-morph.lua new file mode 100644 index 0000000..f73f173 --- /dev/null +++ b/bench/tests/sunspider/3d-morph.lua @@ -0,0 +1,75 @@ +--[[ + * Copyright (C) 2007 Apple Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +]] + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local loops = 15 +local nx = 120 +local nz = 120 + +function morph(a, f) + local PI2nx = math.pi * 8/nx + local sin = math.sin + local f30 = -(50 * sin(f*math.pi*2)) + + for i = 0,nz-1 do + for j = 0,nx-1 do + a[3*(i*nx+j)+1] = sin((j-1) * PI2nx ) * -f30 + end + end +end + + +local a = {} +for i = 0,nx*nz*3-1 do + a[i] = 0 +end + +for i = 0,loops-1 do + morph(a, i/loops) +end + +testOutput = 0; +for i = 0,nx-1 do + testOutput = testOutput + a[3*(i*nx+i)+1]; +end + +a = nil; + +-- This has to be an approximate test since ECMAscript doesn't formally specify +-- what sin() returns. Even if it did specify something like for example what Java 7 +-- says - that sin() has to return a value within 1 ulp of exact - then we still +-- would not be able to do an exact test here since that would allow for just enough +-- low-bit slop to create possibly big errors due to testOutput being a sum. +local epsilon = 1e-13; +if (math.abs(testOutput) >= epsilon) then + assert(false, "Error: bad test output: expected magnitude below " .. epsilon .. " but got " .. testOutput); +end + +end + +bench.runCode(test, "3d-morph") diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua new file mode 100644 index 0000000..60e4f61 --- /dev/null +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -0,0 +1,502 @@ +--[[ + * Copyright (C) 2007 Apple Inc. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +local size = 30 + +function createVector(x,y,z) + return { x,y,z }; +end + +function sqrLengthVector(self) + return self[1] * self[1] + self[2] * self[2] + self[3] * self[3]; +end + +function lengthVector(self) + return math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); +end + +function addVector(self, v) + self[1] = self[1] + v[1]; + self[2] = self[2] + v[2]; + self[3] = self[3] + v[3]; + return self; +end + +function subVector(self, v) + self[1] = self[1] - v[1]; + self[2] = self[2] - v[2]; + self[3] = self[3] - v[3]; + return self; +end + +function scaleVector(self, scale) + self[1] = self[1] * scale; + self[2] = self[2] * scale; + self[3] = self[3] * scale; + return self; +end + +function normaliseVector(self) + local len = math.sqrt(self[1] * self[1] + self[2] * self[2] + self[3] * self[3]); + self[1] = self[1] / len; + self[2] = self[2] / len; + self[3] = self[3] / len; + return self; +end + +function add(v1, v2) + return { v1[1] + v2[1], v1[2] + v2[2], v1[3] + v2[3] }; +end + +function sub(v1, v2) + return { v1[1] - v2[1], v1[2] - v2[2], v1[3] - v2[3] }; +end + +function scalev(v1, v2) + return { v1[1] * v2[1], v1[2] * v2[2], v1[3] * v2[3] }; +end + +function dot(v1, v2) + return v1[1] * v2[1] + v1[2] * v2[2] + v1[3] * v2[3]; +end + +function scale(v, scale) + return { v[1] * scale, v[2] * scale, v[3] * scale }; +end + +function cross(v1, v2) + return { v1[2] * v2[3] - v1[3] * v2[2], + v1[3] * v2[1] - v1[1] * v2[3], + v1[1] * v2[2] - v1[2] * v2[1] }; + +end + +function normalise(v) + local len = lengthVector(v); + return { v[1] / len, v[2] / len, v[3] / len }; +end + +function transformMatrix(self, v) + local vals = self; + local x = vals[1] * v[1] + vals[2] * v[2] + vals[3] * v[3] + vals[4]; + local y = vals[5] * v[1] + vals[6] * v[2] + vals[7] * v[3] + vals[8]; + local z = vals[9] * v[1] + vals[10] * v[2] + vals[11] * v[3] + vals[12]; + return { x, y, z }; +end + +function invertMatrix(self) + local temp = {} + local tx = -self[4]; + local ty = -self[8]; + local tz = -self[12]; + for h = 0,2 do + for v = 0,2 do + temp[h + v * 4 + 1] = self[v + h * 4 + 1]; + end + end + + for i = 0,10 do + self[i + 1] = temp[i + 1]; + end + + self[4] = tx * self[1] + ty * self[2] + tz * self[3]; + self[8] = tx * self[5] + ty * self[6] + tz * self[7]; + self[12] = tx * self[9] + ty * self[10] + tz * self[11]; + return self; +end + +-- Triangle intersection using barycentric coord method +function Triangle(p1, p2, p3) + local this = {} + + local edge1 = sub(p3, p1); + local edge2 = sub(p2, p1); + local normal = cross(edge1, edge2); + if (math.abs(normal[1]) > math.abs(normal[2])) then + if (math.abs(normal[1]) > math.abs(normal[3])) then + this.axis = 0; + else + this.axis = 2; + end + else + if (math.abs(normal[2]) > math.abs(normal[3])) then + this.axis = 1; + else + this.axis = 2; + end + end + + local u = (this.axis + 1) % 3; + local v = (this.axis + 2) % 3; + local u1 = edge1[u + 1]; + local v1 = edge1[v + 1]; + + local u2 = edge2[u + 1]; + local v2 = edge2[v + 1]; + this.normal = normalise(normal); + this.nu = normal[u + 1] / normal[this.axis + 1]; + this.nv = normal[v + 1] / normal[this.axis + 1]; + this.nd = dot(normal, p1) / normal[this.axis + 1]; + local det = u1 * v2 - v1 * u2; + this.eu = p1[u + 1]; + this.ev = p1[v + 1]; + this.nu1 = u1 / det; + this.nv1 = -v1 / det; + this.nu2 = v2 / det; + this.nv2 = -u2 / det; + this.material = { 0.7, 0.7, 0.7 }; + + + this.intersect = function(self, orig, dir, near, far) + local u = (self.axis + 1) % 3; + local v = (self.axis + 2) % 3; + local d = dir[self.axis + 1] + self.nu * dir[u + 1] + self.nv * dir[v + 1]; + local t = (self.nd - orig[self.axis + 1] - self.nu * orig[u + 1] - self.nv * orig[v + 1]) / d; + + if (t < near or t > far) then + return nil; + end + + local Pu = orig[u + 1] + t * dir[u + 1] - self.eu; + local Pv = orig[v + 1] + t * dir[v + 1] - self.ev; + local a2 = Pv * self.nu1 + Pu * self.nv1; + + if (a2 < 0) then + return nil; + end + + local a3 = Pu * self.nu2 + Pv * self.nv2; + if (a3 < 0) then + return nil; + end + + if ((a2 + a3) > 1) then + return nil; + end + + return t; + end + + return this +end + +function Scene(a_triangles) + local this = {} + this.triangles = a_triangles; + this.lights = {}; + this.ambient = {0,0,0}; + this.background = {0.8,0.8,1}; + + this.intersect = function(self, origin, dir, near, far) + local closest = nil; + for i = 0,#self.triangles-1 do + local triangle = self.triangles[i + 1]; + local d = triangle:intersect(origin, dir, near, far); + if (d == nil or d > far or d < near) then + -- continue; + else + far = d; + closest = triangle; + end + end + + if (not closest) then + return { self.background[1],self.background[2],self.background[3] }; + end + + local normal = closest.normal; + local hit = add(origin, scale(dir, far)); + if (dot(dir, normal) > 0) then + normal = { -normal[1], -normal[2], -normal[3] }; + end + + local colour = nil; + if (closest.shader) then + colour = closest.shader(closest, hit, dir); + else + colour = closest.material; + end + + -- do reflection + local reflected = nil; + if (colour.reflection or 0 > 0.001) then + local reflection = addVector(scale(normal, -2*dot(dir, normal)), dir); + reflected = self:intersect(hit, reflection, 0.0001, 1000000); + if (colour.reflection >= 0.999999) then + return reflected; + end + end + + local l = { self.ambient[1], self.ambient[2], self.ambient[3] }; + + for i = 0,#self.lights-1 do + local light = self.lights[i + 1]; + local toLight = sub(light, hit); + local distance = lengthVector(toLight); + scaleVector(toLight, 1.0/distance); + distance = distance - 0.0001; + + if (self:blocked(hit, toLight, distance)) then + -- continue; + else + local nl = dot(normal, toLight); + if (nl > 0) then + addVector(l, scale(light.colour, nl)); + end + end + end + + l = scalev(l, colour); + if (reflected) then + l = addVector(scaleVector(l, 1 - colour.reflection), scaleVector(reflected, colour.reflection)); + end + + return l; + end + + this.blocked = function(self, O, D, far) + local near = 0.0001; + local closest = nil; + for i = 0,#self.triangles-1 do + local triangle = self.triangles[i + 1]; + local d = triangle:intersect(O, D, near, far); + if (d == nil or d > far or d < near) then + --continue; + else + return true; + end + end + + return false; + end + + return this +end + +local zero = { 0,0,0 }; + +-- this camera code is from notes i made ages ago, it is from *somewhere* -- i cannot remember where +-- that somewhere is +function Camera(origin, lookat, up) + local this = {} + + local zaxis = normaliseVector(subVector(lookat, origin)); + local xaxis = normaliseVector(cross(up, zaxis)); + local yaxis = normaliseVector(cross(xaxis, subVector({ 0,0,0 }, zaxis))); + local m = {}; + m[1] = xaxis[1]; m[2] = xaxis[2]; m[3] = xaxis[3]; + m[5] = yaxis[1]; m[6] = yaxis[2]; m[7] = yaxis[3]; + m[9] = zaxis[1]; m[10] = zaxis[2]; m[11] = zaxis[3]; + m[4] = 0; m[8] = 0; m[12] = 0; + invertMatrix(m); + m[4] = 0; m[8] = 0; m[12] = 0; + this.origin = origin; + this.directions = {}; + this.directions[1] = normalise({ -0.7, 0.7, 1 }); + this.directions[2] = normalise({ 0.7, 0.7, 1 }); + this.directions[3] = normalise({ 0.7, -0.7, 1 }); + this.directions[4] = normalise({ -0.7, -0.7, 1 }); + this.directions[1] = transformMatrix(m, this.directions[1]); + this.directions[2] = transformMatrix(m, this.directions[2]); + this.directions[3] = transformMatrix(m, this.directions[3]); + this.directions[4] = transformMatrix(m, this.directions[4]); + + this.generateRayPair = function(self, y) + rays = { {}, {} } + rays[1].origin = self.origin; + rays[2].origin = self.origin; + rays[1].dir = addVector(scale(self.directions[1], y), scale(self.directions[4], 1 - y)); + rays[2].dir = addVector(scale(self.directions[2], y), scale(self.directions[3], 1 - y)); + return rays; + end + + function renderRows(camera, scene, pixels, width, height, starty, stopy) + for y = starty,stopy-1 do + local rays = camera:generateRayPair(y / height); + for x = 0,width-1 do + local xp = x / width; + local origin = addVector(scale(rays[1].origin, xp), scale(rays[2].origin, 1 - xp)); + local dir = normaliseVector(addVector(scale(rays[1].dir, xp), scale(rays[2].dir, 1 - xp))); + local l = scene:intersect(origin, dir, 0, math.huge); + pixels[y + 1][x + 1] = l; + end + end + end + + this.render = function(self, scene, pixels, width, height) + local cam = self; + local row = 0; + renderRows(cam, scene, pixels, width, height, 0, height); + end + + return this +end + +function raytraceScene() + local startDate = 13154863; + local numTriangles = 2 * 6; + local triangles = {}; -- numTriangles); + local tfl = createVector(-10, 10, -10); + local tfr = createVector( 10, 10, -10); + local tbl = createVector(-10, 10, 10); + local tbr = createVector( 10, 10, 10); + local bfl = createVector(-10, -10, -10); + local bfr = createVector( 10, -10, -10); + local bbl = createVector(-10, -10, 10); + local bbr = createVector( 10, -10, 10); + + -- cube!!! + -- front + local i = 0; + + triangles[i + 1] = Triangle(tfl, tfr, bfr); i = i + 1; + triangles[i + 1] = Triangle(tfl, bfr, bfl); i = i + 1; + -- back + triangles[i + 1] = Triangle(tbl, tbr, bbr); i = i + 1; + triangles[i + 1] = Triangle(tbl, bbr, bbl); i = i + 1; + -- triangles[i-1].material = [0.7,0.2,0.2]; + -- triangles[i-1].material.reflection = 0.8; + -- left + triangles[i + 1] = Triangle(tbl, tfl, bbl); i = i + 1; + -- triangles[i-1].reflection = 0.6; + triangles[i + 1] = Triangle(tfl, bfl, bbl); i = i + 1; + -- triangles[i-1].reflection = 0.6; + -- right + triangles[i + 1] = Triangle(tbr, tfr, bbr); i = i + 1; + triangles[i + 1] = Triangle(tfr, bfr, bbr); i = i + 1; + -- top + triangles[i + 1] = Triangle(tbl, tbr, tfr); i = i + 1; + triangles[i + 1] = Triangle(tbl, tfr, tfl); i = i + 1; + -- bottom + triangles[i + 1] = Triangle(bbl, bbr, bfr); i = i + 1; + triangles[i + 1] = Triangle(bbl, bfr, bfl); i = i + 1; + + -- Floor!!!! + local green = createVector(0.0, 0.4, 0.0); + green.reflection = 0; -- + local grey = createVector(0.4, 0.4, 0.4); + grey.reflection = 1.0; + local floorShader = function(tri, pos, view) + local x = ((pos[1]/32) % 2 + 2) % 2; + local z = ((pos[3]/32 + 0.3) % 2 + 2) % 2; + if ((x < 1) ~= (z < 1)) then + --in the real world we use the fresnel term... + -- local angle = 1-dot(view, tri.normal); + -- angle *= angle; + -- angle *= angle; + -- angle *= angle; + --grey.reflection = angle; + return grey; + else + return green; + end + end + + local ffl = createVector(-1000, -30, -1000); + local ffr = createVector( 1000, -30, -1000); + local fbl = createVector(-1000, -30, 1000); + local fbr = createVector( 1000, -30, 1000); + triangles[i + 1] = Triangle(fbl, fbr, ffr); i = i + 1; + triangles[i-1 + 1].shader = floorShader; + triangles[i + 1] = Triangle(fbl, ffr, ffl); i = i + 1; + triangles[i-1 + 1].shader = floorShader; + + local _scene = Scene(triangles); + _scene.lights[1] = createVector(20, 38, -22); + _scene.lights[1].colour = createVector(0.7, 0.3, 0.3); + _scene.lights[2] = createVector(-23, 40, 17); + _scene.lights[2].colour = createVector(0.7, 0.3, 0.3); + _scene.lights[3] = createVector(23, 20, 17); + _scene.lights[3].colour = createVector(0.7, 0.7, 0.7); + _scene.ambient = createVector(0.1, 0.1, 0.1); + -- _scene.background = createVector(0.7, 0.7, 1.0); + + local pixels = {}; + for y = 0,size-1 do + pixels[y + 1] = {}; + for x = 0,size-1 do + pixels[y + 1][x + 1] = 0; + end + end + + local _camera = Camera(createVector(-40, 40, 40), createVector(0, 0, 0), createVector(0, 1, 0)); + _camera:render(_scene, pixels, size, size); + + return pixels; +end + +function arrayToCanvasCommands(pixels) + local s = 'Test\nvar pixels = ['; + for y = 0,size-1 do + s = s .. "["; + for x = 0,size-1 do + s = s .. "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"; + end + s = s .. "],"; + end + s = s .. '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ +\n\ +\n\ + var size = ' .. size .. ';\n\ + canvas.fillStyle = "red";\n\ + canvas.fillRect(0, 0, size, size);\n\ + canvas.scale(1, -1);\n\ + canvas.translate(0, -size);\n\ +\n\ + if (!canvas.setFillColor)\n\ + canvas.setFillColor = function(r, g, b, a) {\n\ + this.fillStyle = "rgb("+[Math.floor(r), Math.floor(g), Math.floor(b)]+")";\n\ + }\n\ +\n\ +for (var y = 0; y < size; y++) {\n\ + for (var x = 0; x < size; x++) {\n\ + var l = pixels[y][x];\n\ + canvas.setFillColor(l[0], l[1], l[2], 1);\n\ + canvas.fillRect(x, y, 1, 1);\n\ + }\n\ +}'; + + return s; +end + +testOutput = arrayToCanvasCommands(raytraceScene()); + +--local f = io.output("output.html") +--f:write(testOutput) +--f:close() + +local expectedLength = 11599; +local testLength = #testOutput + +if (testLength ~= expectedLength) then + assert(false, "Error: bad result: expected length " .. expectedLength .. " but got " .. testLength); +end + +end + +bench.runCode(test, "3d-raytrace") diff --git a/bench/tests/sunspider/access-binary-trees.lua b/bench/tests/sunspider/access-binary-trees.lua new file mode 100644 index 0000000..9eb9358 --- /dev/null +++ b/bench/tests/sunspider/access-binary-trees.lua @@ -0,0 +1,69 @@ +--[[ + The Great Computer Language Shootout + http://shootout.alioth.debian.org/ + contributed by Isaac Gouy +]] + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +function TreeNode(left,right,item) + local this = {} + this.left = left; + this.right = right; + this.item = item; + + this.itemCheck = function(self) + if (self.left==nil) then return self.item; + else return self.item + self.left:itemCheck() - self.right:itemCheck(); end + end + + return this +end + +function bottomUpTree(item,depth) + if (depth>0) then + return TreeNode( + bottomUpTree(2*item-1, depth-1) + ,bottomUpTree(2*item, depth-1) + ,item + ); + else + return TreeNode(nil,nil,item); + end +end + +local ret = 0; + +for n = 4,7,1 do + local minDepth = 4; + local maxDepth = math.max(minDepth + 2, n); + local stretchDepth = maxDepth + 1; + + local check = bottomUpTree(0,stretchDepth):itemCheck(); + + local longLivedTree = bottomUpTree(0,maxDepth); + + for depth = minDepth,maxDepth,2 do + local iterations = 2.0 ^ (maxDepth - depth + minDepth - 1) -- 1 << (maxDepth - depth + minDepth); + + check = 0; + for i = 1,iterations do + check = check + bottomUpTree(i,depth):itemCheck(); + check = check + bottomUpTree(-i,depth):itemCheck(); + end + end + + ret = ret + longLivedTree:itemCheck(); +end + +local expected = -4; + +if (ret ~= expected) then + assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. ret); +end + +end + +bench.runCode(test, "access-binary-trees") diff --git a/bench/tests/sunspider/controlflow-recursive.lua b/bench/tests/sunspider/controlflow-recursive.lua new file mode 100644 index 0000000..d079162 --- /dev/null +++ b/bench/tests/sunspider/controlflow-recursive.lua @@ -0,0 +1,42 @@ +--[[ + The Great Computer Language Shootout + http://shootout.alioth.debian.org/ + contributed by Isaac Gouy +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +function ack(m,n) + if (m==0) then return n+1; end + if (n==0) then return ack(m-1,1); end + return ack(m-1, ack(m,n-1) ); +end + +function fib(n) + if (n < 2) then return 1; end + return fib(n-2) + fib(n-1); +end + +function tak(x,y,z) + if (y >= x) then return z; end + return tak(tak(x-1,y,z), tak(y-1,z,x), tak(z-1,x,y)); +end + +local result = 0; + +for i = 3,5 do + result = result + ack(3,i); + result = result + fib(17.0+i); + result = result + tak(3*i+3,2*i+2,i+1); +end + +local expected = 57775; + +if (result ~= expected) then + assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. result); +end + +end + +bench.runCode(test, "controlflow-recursive") diff --git a/bench/tests/sunspider/crypto-aes.lua b/bench/tests/sunspider/crypto-aes.lua new file mode 100644 index 0000000..3b28972 --- /dev/null +++ b/bench/tests/sunspider/crypto-aes.lua @@ -0,0 +1,362 @@ +--[[ + * AES Cipher function: encrypt 'input' with Rijndael algorithm + * + * takes byte-array 'input' (16 bytes) + * 2D byte-array key schedule 'w' (Nr+1 x Nb bytes) + * + * applies Nr rounds (10/12/14) using key schedule w for 'add round key' stage + * + * returns byte-array encrypted value (16 bytes) + */]] + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +-- Sbox is pre-computed multiplicative inverse in GF(2^8) used in SubBytes and KeyExpansion [§5.1.1] +local Sbox = { 0x63,0x7c,0x77,0x7b,0xf2,0x6b,0x6f,0xc5,0x30,0x01,0x67,0x2b,0xfe,0xd7,0xab,0x76, + 0xca,0x82,0xc9,0x7d,0xfa,0x59,0x47,0xf0,0xad,0xd4,0xa2,0xaf,0x9c,0xa4,0x72,0xc0, + 0xb7,0xfd,0x93,0x26,0x36,0x3f,0xf7,0xcc,0x34,0xa5,0xe5,0xf1,0x71,0xd8,0x31,0x15, + 0x04,0xc7,0x23,0xc3,0x18,0x96,0x05,0x9a,0x07,0x12,0x80,0xe2,0xeb,0x27,0xb2,0x75, + 0x09,0x83,0x2c,0x1a,0x1b,0x6e,0x5a,0xa0,0x52,0x3b,0xd6,0xb3,0x29,0xe3,0x2f,0x84, + 0x53,0xd1,0x00,0xed,0x20,0xfc,0xb1,0x5b,0x6a,0xcb,0xbe,0x39,0x4a,0x4c,0x58,0xcf, + 0xd0,0xef,0xaa,0xfb,0x43,0x4d,0x33,0x85,0x45,0xf9,0x02,0x7f,0x50,0x3c,0x9f,0xa8, + 0x51,0xa3,0x40,0x8f,0x92,0x9d,0x38,0xf5,0xbc,0xb6,0xda,0x21,0x10,0xff,0xf3,0xd2, + 0xcd,0x0c,0x13,0xec,0x5f,0x97,0x44,0x17,0xc4,0xa7,0x7e,0x3d,0x64,0x5d,0x19,0x73, + 0x60,0x81,0x4f,0xdc,0x22,0x2a,0x90,0x88,0x46,0xee,0xb8,0x14,0xde,0x5e,0x0b,0xdb, + 0xe0,0x32,0x3a,0x0a,0x49,0x06,0x24,0x5c,0xc2,0xd3,0xac,0x62,0x91,0x95,0xe4,0x79, + 0xe7,0xc8,0x37,0x6d,0x8d,0xd5,0x4e,0xa9,0x6c,0x56,0xf4,0xea,0x65,0x7a,0xae,0x08, + 0xba,0x78,0x25,0x2e,0x1c,0xa6,0xb4,0xc6,0xe8,0xdd,0x74,0x1f,0x4b,0xbd,0x8b,0x8a, + 0x70,0x3e,0xb5,0x66,0x48,0x03,0xf6,0x0e,0x61,0x35,0x57,0xb9,0x86,0xc1,0x1d,0x9e, + 0xe1,0xf8,0x98,0x11,0x69,0xd9,0x8e,0x94,0x9b,0x1e,0x87,0xe9,0xce,0x55,0x28,0xdf, + 0x8c,0xa1,0x89,0x0d,0xbf,0xe6,0x42,0x68,0x41,0x99,0x2d,0x0f,0xb0,0x54,0xbb,0x16 }; + +-- Rcon is Round Constant used for the Key Expansion [1st col is 2^(r-1) in GF(2^8)] [§5.2] +local Rcon = { { 0x00, 0x00, 0x00, 0x00 }, + {0x01, 0x00, 0x00, 0x00}, + {0x02, 0x00, 0x00, 0x00}, + {0x04, 0x00, 0x00, 0x00}, + {0x08, 0x00, 0x00, 0x00}, + {0x10, 0x00, 0x00, 0x00}, + {0x20, 0x00, 0x00, 0x00}, + {0x40, 0x00, 0x00, 0x00}, + {0x80, 0x00, 0x00, 0x00}, + {0x1b, 0x00, 0x00, 0x00}, + {0x36, 0x00, 0x00, 0x00} }; + +function Cipher(input, w) -- main Cipher function [§5.1] + local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) + local Nr = #w / Nb - 1; -- no of rounds: 10/12/14 for 128/192/256-bit keys + + local state = {{},{},{},{}}; -- initialise 4xNb byte-array 'state' with input [§3.4] + for i = 0,4*Nb-1 do state[(i % 4) + 1][math.floor(i/4) + 1] = input[i + 1]; end + + state = AddRoundKey(state, w, 0, Nb); + + for round = 1,Nr-1 do + state = SubBytes(state, Nb); + state = ShiftRows(state, Nb); + state = MixColumns(state, Nb); + state = AddRoundKey(state, w, round, Nb); + end + + state = SubBytes(state, Nb); + state = ShiftRows(state, Nb); + state = AddRoundKey(state, w, Nr, Nb); + + local output = {} -- convert state to 1-d array before returning [§3.4] + for i = 0,4*Nb-1 do output[i + 1] = state[(i % 4) + 1][math.floor(i / 4) + 1]; end + + return output; +end + + +function SubBytes(s, Nb) -- apply SBox to state S [§5.1.1] + for r = 0,3 do + for c = 0,Nb-1 do s[r + 1][c + 1] = Sbox[s[r + 1][c + 1] + 1]; end + end + return s; +end + + +function ShiftRows(s, Nb) -- shift row r of state S left by r bytes [§5.1.2] + local t = {}; + for r = 1,3 do + for c = 0,3 do t[c + 1] = s[r + 1][((c + r) % Nb) + 1] end; -- shift into temp copy + for c = 0,3 do s[r + 1][c + 1] = t[c + 1]; end -- and copy back + end -- note that this will work for Nb=4,5,6, but not 7,8 (always 4 for AES): + return s; -- see fp.gladman.plus.com/cryptography_technology/rijndael/aes.spec.311.pdf +end + + +function MixColumns(s, Nb) -- combine bytes of each col of state S [§5.1.3] + for c = 0,3 do + local a = {}; -- 'a' is a copy of the current column from 's' + local b = {}; -- 'b' is a•{02} in GF(2^8) + for i = 0,3 do + a[i + 1] = s[i + 1][c + 1]; + + if bit32.band(s[i + 1][c + 1], 0x80) ~= 0 then + b[i + 1] = bit32.bxor(bit32.lshift(s[i + 1][c + 1], 1), 0x011b); + else + b[i + 1] = bit32.lshift(s[i + 1][c + 1], 1); + end + end + -- a[n] ^ b[n] is a•{03} in GF(2^8) + s[1][c + 1] = bit32.bxor(b[1], a[2], b[2], a[3], a[4]); -- 2*a0 + 3*a1 + a2 + a3 + s[2][c + 1] = bit32.bxor(a[1], b[2], a[3], b[3], a[4]); -- a0 * 2*a1 + 3*a2 + a3 + s[3][c + 1] = bit32.bxor(a[1], a[2], b[3], a[4], b[4]); -- a0 + a1 + 2*a2 + 3*a3 + s[4][c + 1] = bit32.bxor(a[1], b[1], a[2], a[3], b[4]); -- 3*a0 + a1 + a2 + 2*a3 +end + return s; +end + + +function AddRoundKey(state, w, rnd, Nb) -- xor Round Key into state S [§5.1.4] + for r = 0,3 do + for c = 0,Nb-1 do state[r + 1][c + 1] = bit32.bxor(state[r + 1][c + 1], w[rnd*4+c + 1][r + 1]); end + end + return state; +end + + +function KeyExpansion(key) -- generate Key Schedule (byte-array Nr+1 x Nb) from Key [§5.2] + local Nb = 4; -- block size (in words): no of columns in state (fixed at 4 for AES) + local Nk = #key / 4 -- key length (in words): 4/6/8 for 128/192/256-bit keys + local Nr = Nk + 6; -- no of rounds: 10/12/14 for 128/192/256-bit keys + + local w = {}; + local temp = {}; + + for i = 0,Nk do + local r = { key[4*i + 1], key[4*i + 2], key[4*i + 3], key[4*i + 4] }; + w[i + 1] = r; + end + + for i = Nk,(Nb*(Nr+1)) - 1 do + w[i + 1] = {}; + for t = 0,3 do temp[t + 1] = w[i-1 + 1][t + 1]; end + if (i % Nk == 0) then + temp = SubWord(RotWord(temp)); + for t = 0,3 do temp[t + 1] = bit32.bxor(temp[t + 1], Rcon[i/Nk + 1][t + 1]); end + elseif (Nk > 6 and i % Nk == 4) then + temp = SubWord(temp); + end + for t = 0,3 do w[i + 1][t + 1] = bit32.bxor(w[i - Nk + 1][t + 1], temp[t + 1]); end + end + + return w; +end + +function SubWord(w) -- apply SBox to 4-byte word w + for i = 0,3 do w[i + 1] = Sbox[w[i + 1] + 1]; end + return w; +end + +function RotWord(w) -- rotate 4-byte word w left by one byte + w[5] = w[1]; + for i = 0,3 do w[i + 1] = w[i + 2]; end + return w; +end + + +--[[ + * Use AES to encrypt 'plaintext' with 'password' using 'nBits' key, in 'Counter' mode of operation + * - see http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf + * for each block + * - outputblock = cipher(counter, key) + * - cipherblock = plaintext xor outputblock + ]] + +function AESEncryptCtr(plaintext, password, nBits) + if (not (nBits==128 or nBits==192 or nBits==256)) then return ''; end -- standard allows 128/192/256 bit keys + + -- for this example script, generate the key by applying Cipher to 1st 16/24/32 chars of password; + -- for real-world applications, a higher security approach would be to hash the password e.g. with SHA-1 + local nBytes = nBits/8; -- no bytes in key + local pwBytes = {}; + for i = 0,nBytes-1 do pwBytes[i + 1] = bit32.band(string.byte(password, i + 1), 0xff); end + local key = Cipher(pwBytes, KeyExpansion(pwBytes)); + + -- key is now 16/24/32 bytes long + for i = 1,nBytes-16 do + table.insert(key, key[i]) + end + + -- initialise counter block (NIST SP800-38A §B.2): millisecond time-stamp for nonce in 1st 8 bytes, + -- block counter in 2nd 8 bytes + local blockSize = 16; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES + local counterBlock = {}; -- block size fixed at 16 bytes / 128 bits (Nb=4) for AES + local nonce = 12564231564 -- (new Date()).getTime(); -- milliseconds since 1-Jan-1970 + + -- encode nonce in two stages to cater for JavaScript 32-bit limit on bitwise ops + for i = 0,3 do counterBlock[i + 1] = bit32.band(bit32.rshift(nonce, i * 8), 0xff); end + for i = 0,3 do counterBlock[i + 4 + 1] = bit32.band(bit32.rshift(math.floor(nonce / 0x100000000), i*8), 0xff); end + + -- generate key schedule - an expansion of the key into distinct Key Rounds for each round + local keySchedule = KeyExpansion(key); + + local blockCount = math.ceil(#plaintext / blockSize); + local ciphertext = {}; -- ciphertext as array of strings + + for b = 0,blockCount-1 do + -- set counter (block #) in last 8 bytes of counter block (leaving nonce in 1st 8 bytes) + -- again done in two stages for 32-bit ops + for c = 0,3 do counterBlock[15-c + 1] = bit32.band(bit32.rshift(b, c*8), 0xff); end + for c = 0,3 do counterBlock[15-c-4 + 1] = bit32.rshift(math.floor(b/0x100000000), c*8) end + + local cipherCntr = Cipher(counterBlock, keySchedule); -- -- encrypt counter block -- + + -- calculate length of final block: + local blockLength = nil + + if b maxFlipsCount) then + maxFlipsCount = flipsCount; + for i = 1,n do maxPerm[i] = perm1[i]; end + end + end + + while (true) do + if (r == n) then return maxFlipsCount; end + + local perm0 = perm1[1]; + local i = 0; + while (i < r) do + local j = i + 1; + perm1[i + 1] = perm1[j + 1]; + i = j; + end + perm1[r + 1] = perm0; + + count[r + 1] = count[r + 1] - 1; + if (count[r + 1] > 0) then break; end + r = r + 1; + end + end + + return 0 +end + +local n = 8; +local ret = fannkuch(n); + +local expected = 22; +if (ret ~= expected) then + assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. ret); +end + +end + +bench.runCode(test, "fannkuch") diff --git a/bench/tests/sunspider/math-cordic.lua b/bench/tests/sunspider/math-cordic.lua new file mode 100644 index 0000000..94a64f4 --- /dev/null +++ b/bench/tests/sunspider/math-cordic.lua @@ -0,0 +1,104 @@ +--[[ + * Copyright (C) Rich Moore. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY CONTRIBUTORS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + ]] + + local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +--. Start CORDIC + +local AG_CONST = 0.6072529350; + +function FIXED(X) + return X * 65536.0; +end + +function FLOAT(X) + return X / 65536.0; +end + +function DEG2RAD(X) + return 0.017453 * (X); +end + +local Angles = { + FIXED(45.0), FIXED(26.565), FIXED(14.0362), FIXED(7.12502), + FIXED(3.57633), FIXED(1.78991), FIXED(0.895174), FIXED(0.447614), + FIXED(0.223811), FIXED(0.111906), FIXED(0.055953), + FIXED(0.027977) +}; + +local Target = 28.027; + +function cordicsincos(Target) + local X; + local Y; + local TargetAngle; + local CurrAngle; + + X = FIXED(AG_CONST); -- AG_CONST * cos(0) + Y = 0; -- AG_CONST * sin(0) + + TargetAngle = FIXED(Target); + CurrAngle = 0; + for Step = 0,11 do + local NewX; + if (TargetAngle > CurrAngle) then + NewX = X - bit32.rshift(math.floor(Y), Step) -- (Y >> Step); + Y = bit32.rshift(math.floor(X), Step) + Y; + X = NewX; + CurrAngle = CurrAngle + Angles[Step + 1]; + else + NewX = X + bit32.rshift(math.floor(Y), Step) + Y = -bit32.rshift(math.floor(X), Step) + Y; + X = NewX; + CurrAngle = CurrAngle - Angles[Step + 1]; + end + end + + return FLOAT(X) * FLOAT(Y); +end + +-- End CORDIC + +local total = 0; + +function cordic( runs ) + for i = 1,runs do + total = total + cordicsincos(Target); + end +end + +cordic(25000); + +local expected = 10362.570468755888; + +if (total ~= expected) then + assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. total); +end + +end + +bench.runCode(test, "math-cordic") diff --git a/bench/tests/sunspider/math-partial-sums.lua b/bench/tests/sunspider/math-partial-sums.lua new file mode 100644 index 0000000..3c22287 --- /dev/null +++ b/bench/tests/sunspider/math-partial-sums.lua @@ -0,0 +1,53 @@ +--[[ + The Great Computer Language Shootout + http://shootout.alioth.debian.org/ + contributed by Isaac Gouy +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +function partial(n) + local a1, a2, a3, a4, a5, a6, a7, a8, a9 = 0, 0, 0, 0, 0, 0, 0, 0, 0; + local twothirds = 2.0/3.0; + local alt = -1.0; + local k2, k3, sk, ck = 0, 0, 0, 0; + + for k = 1,n do + k2 = k*k; + k3 = k2*k; + sk = math.sin(k); + ck = math.cos(k); + alt = -alt; + + a1 = a1 + math.pow(twothirds,k-1); + a2 = a2 + math.pow(k,-0.5); + a3 = a3 + 1.0/(k*(k+1.0)); + a4 = a4 + 1.0/(k3 * sk*sk); + a5 = a5 + 1.0/(k3 * ck*ck); + a6 = a6 + 1.0/k; + a7 = a7 + 1.0/k2; + a8 = a8 + alt/k; + a9 = a9 + alt/(2*k -1); + end + + return a6 + a7 + a8 + a9; +end + +local total = 0; +local i = 1024 + +while i <= 16384 do + total = total + partial(i); + i = i * 2 +end + +local expected = 60.08994194659945; + +if (total ~= expected) then + assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. total); +end + +end + +bench.runCode(test, "math-partial-sums") diff --git a/bench/tests/sunspider/math-spectral-norm.lua b/bench/tests/sunspider/math-spectral-norm.lua new file mode 100644 index 0000000..7d7ec16 --- /dev/null +++ b/bench/tests/sunspider/math-spectral-norm.lua @@ -0,0 +1,72 @@ +--[[ +The Great Computer Language Shootout +http://shootout.alioth.debian.org/ + +contributed by Ian Osgood +]] +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + +function A(i,j) + return 1/((i+j)*(i+j+1)/2+i+1); +end + +function Au(u,v) + for i = 0,#u-1 do + local t = 0; + for j = 0,#u-1 do + t = t + A(i,j) * u[j + 1]; + end + v[i + 1] = t; + end +end + +function Atu(u,v) + for i = 0,#u-1 do + local t = 0; + for j = 0,#u-1 do + t = t + A(j,i) * u[j + 1]; + end + v[i + 1] = t; + end +end + +function AtAu(u,v,w) + Au(u,w); + Atu(w,v); +end + +function spectralnorm(n) + local u, v, w, vv, vBv = {}, {}, {}, 0, 0; + for i = 1,n do + u[i] = 1; v[i] = 0; w[i] = 0; + end + for i = 0,9 do + AtAu(u,v,w); + AtAu(v,u,w); + end + for i = 1,n do + vBv = vBv + u[i]*v[i]; + vv = vv + v[i]*v[i]; + end + return math.sqrt(vBv/vv); +end + +local total = 0; +local i = 6 + +while i <= 48 do + total = total + spectralnorm(i); + i = i * 2 +end + +local expected = 5.086694231303284; + +if (total ~= expected) then + assert(false, "ERROR: bad result: expected " .. expected .. " but got " .. total) +end + +end + +bench.runCode(test, "math-spectral-norm") diff --git a/bench/tests/sunspider/n-body-oop.lua b/bench/tests/sunspider/n-body-oop.lua new file mode 100644 index 0000000..adcc15a --- /dev/null +++ b/bench/tests/sunspider/n-body-oop.lua @@ -0,0 +1,169 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +local PI = 3.141592653589793 +local SOLAR_MASS = 4 * PI * PI +local DAYS_PER_YEAR = 365.24 + +local Body = {} +Body.__index = Body + +function Body.new(x, y, z, vx, vy, vz, mass) + local self = {} + self.x = x + self.y = y + self.z = z + self.vx = vx + self.vy = vy + self.vz = vz + self.mass = mass + return setmetatable(self, Body) +end + +function Body:offsetMomentum(px, py, pz) + self.vx = -px / SOLAR_MASS + self.vy = -py / SOLAR_MASS + self.vz = -pz / SOLAR_MASS + + return self +end + +local function Jupiter() + return Body.new( + 4.841431442464721e0, + -1.1603200440274284e0, + -1.036220444711231e-1, + 1.660076642744037e-3 * DAYS_PER_YEAR, + 7.6990111841974045e-3 * DAYS_PER_YEAR, + -6.90460016972063e-5 * DAYS_PER_YEAR, + 9.547919384243267e-4 * SOLAR_MASS + ) +end +local function Saturn() + return Body.new(8.34336671824458e0, 4.124798564124305e0, -4.035234171143213e-1, -2.767425107268624e-3 * DAYS_PER_YEAR, 4.998528012349173e-3 * DAYS_PER_YEAR, 2.3041729757376395e-5 * DAYS_PER_YEAR, 2.8588598066613082e-4 * SOLAR_MASS) +end +local function Uranus() + return Body.new(1.2894369562139132e1, -1.511115140169863e1, -2.2330757889265573e-1, 2.964601375647616e-3 * DAYS_PER_YEAR, 2.3784717395948096e-3 * DAYS_PER_YEAR, -2.9658956854023755e-5 * DAYS_PER_YEAR, 4.366244043351563e-5 * SOLAR_MASS) +end +local function Neptune() + return Body.new(1.5379697114850917e1, -2.5919314609987962e1, 1.7925877295037118e-1, 2.680677724903893e-3 * DAYS_PER_YEAR, 1.628241700382423e-3 * DAYS_PER_YEAR, -9.515922545197158e-5 * DAYS_PER_YEAR, 5.151389020466114e-5 * SOLAR_MASS) +end +local function Sun() + return Body.new(0, 0, 0, 0, 0, 0, SOLAR_MASS) +end + +local NBodySystem = {} +NBodySystem.__index = NBodySystem + + +function NBodySystem.new(bodies) + local self = {} + self.bodies = bodies + + local px = 0 + local py = 0 + local pz = 0 + local size = #self.bodies + + for i=1, size do + local b = self.bodies[i] + local m = b.mass + + px = px + b.vx * m + py = py + b.vy * m + pz = pz + b.vz * m + end + + self.bodies[1]:offsetMomentum(px, py, pz) + + return setmetatable(self, NBodySystem) +end + +function NBodySystem:advance(dt) + local dx, dy, dz, distance, mag + local size = #self.bodies + + for i=1, size do + local bodyi = self.bodies[i] + for j=i+1, size do + local bodyj = self.bodies[j] + dx = bodyi.x - bodyj.x + dy = bodyi.y - bodyj.y + dz = bodyi.z - bodyj.z + + distance = math.sqrt(dx*dx + dy*dy + dz*dz) + mag = dt / (distance * distance * distance) + + bodyi.vx -= dx * bodyj.mass * mag + bodyi.vy -= dy * bodyj.mass * mag + bodyi.vz -= dz * bodyj.mass * mag + + bodyj.vx += dx * bodyi.mass * mag + bodyj.vy += dy * bodyi.mass * mag + bodyj.vz += dz * bodyi.mass * mag + end + end + for i=1, size do + local body = self.bodies[i] + + body.x = body.x + dt * body.vx + body.y = body.y + dt * body.vy + body.z = body.z + dt * body.vz + end +end + +function NBodySystem:energy() + local dx, dy, dz, distance + local e = 0.0 + local size = #self.bodies + + for i=1, size do + local bodyi = self.bodies[i] + + e = e + 0.5 * bodyi.mass * (bodyi.vx * bodyi.vx + bodyi.vy * bodyi.vy + bodyi.vz * bodyi.vz) + + for j=i+1, size do + local bodyj = self.bodies[j] + dx = bodyi.x - bodyj.x + dy = bodyi.y - bodyj.y + dz = bodyi.z - bodyj.z + + distance = math.sqrt(dx*dx + dy*dy + dz*dz) + e -= (bodyi.mass * bodyj.mass) / distance + end + end + + return e +end + +local function run() + local ret = 0 + local n = 3 + while n <= 24 do + (function() + local bodies = NBodySystem.new({ + Sun(),Jupiter(),Saturn(),Uranus(),Neptune() + }) + local max = n * 100 + + ret += bodies:energy() + for i=1, max do + bodies:advance(0.01) + end + ret += bodies:energy() + end)() + n *= 2 + end + local expected = -1.3524862408537381 + + if ret ~= expected then + error('ERROR: bad result: expected ' .. expected .. ' but got ' .. ret) + end +end + +function runIteration() + for i=1, 5 do + run() + end +end + +bench.runCode(runIteration, "n-body-oop") diff --git a/bench/tests/tictactoe.lua b/bench/tests/tictactoe.lua new file mode 100644 index 0000000..ae63f5f --- /dev/null +++ b/bench/tests/tictactoe.lua @@ -0,0 +1,228 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + -- https://github.com/stefandd/Tic4 + local negaMax = {maxdepth = 4, minsearchpos = 0, numsearchpos = 0} + negaMax.__index = negaMax + + function negaMax:evaluate(board, depth) + --[[ + What can be confusing is how the heuristic value of the current node is calculated. In this implementation, this value is always calculated from the point of view of player A, whose color value is one. In other words, higher heuristic values always represent situations more favorable for player A. This is the same behavior as the normal minimax algorithm. The heuristic value is not necessarily the same as a node's return value due to value negation by negamax and the color parameter. The negamax node's return value is a heuristic score from the point of view of the node's current player. + + Negamax scores match minimax scores for nodes where player A is about to play, and where player A is the maximizing player in the minimax equivalent. Negamax always searches for the maximum value for all its nodes. Hence for player B nodes, the minimax score is a negation of its negamax score. Player B is the minimizing player in the minimax equivalent. + + Variations in negamax implementations may omit the color parameter. In this case, the heuristic evaluation function must return values from the point of view of the node's current player. + --]] + print ("This function needs to be implemented!") + end + + function negaMax:move_candidates(board, side_to_move) + print ("This function needs to be implemented!") + end + + function negaMax:make_move(board, side_to_move, move) + print ("This function needs to be implemented!") + end + + function negaMax:negaMax(board, side_to_move, depth, alpha, beta) -- side_to_move: e.g. 1 is blue, -1 is read + -- + -- init vars for root call + -- + if not depth then -- root call + depth = 0 + alpha = -math.huge + beta = math.huge + self.numsearchpos = 0 -- reset call counter + end + -- + -- test if the node is terminal (i.e. full board or win) + -- + local best_move = -1 + local score, is_term_node = self:evaluate(board, depth) + -- we abort the recursion if this is a terminal node, or if one of the search abort conditions are met + -- + if is_term_node or depth == self.maxdepth then + return side_to_move*score, best_move, is_term_node + end + -- + -- if not terminal node, eval child nodes + -- + local moves = self:move_candidates(board, side_to_move) + score = -math.huge + + for _, analyzed_move in pairs(moves) do -- iterate over all boards + self.numsearchpos = self.numsearchpos + 1 + local b = self:make_move(board, side_to_move, analyzed_move) + local move_score, _, _ = -self:negaMax(b, -side_to_move, depth+1, -beta, -alpha) + if move_score > score then + score = move_score + best_move = analyzed_move + end + -- disable alpha-beta pruning + -- + alpha = math.max(alpha, score) + if alpha >= beta then + break + end + -- + end + if depth == 0 then + -- + -- exit for root call (depth == 0) + -- + -- debug stuff + --print(string.format("---- Negamax: node info, depth: %d, side: %d, score: %d, best move: %d", depth, side_to_move, score, best_move)) + --print_board(board) + --print(string.format("----")) + print("Analyzed positions: " .. self.numsearchpos) + end + return score, best_move, game_over + end + + local empty_board = {0,0,0,0, + 0,0,0,0, + 0,0,0,0, + 0,0,0,0} -- 16 empty positions + + ----------- helper methods + + function copy_board(board) + local copy = {} + for i = 1, #board do + copy[i] = board[i] + end + return copy + end + + function print_board(board) + gboard = {} + for i = 1, #board do + if board[i] == 0 then gboard[i] = '.' + elseif board[i] == 1 then gboard[i] = 'x' + else gboard[i] = 'o' + end + end + print(string.format("\n%s %s %s %s\n%s %s %s %s\n%s %s %s %s\n%s %s %s %s\n", unpack(gboard))) + end + + function is_board_full(board) + for i = 1, #board do + if board[i] == 0 then + return false + end + end + return true + end + + ----------- implement negaMax methods + + negaMax.index_quadruplets = { + {1,2,3,4}, {5,6,7,8}, {9,10,11,12}, -- rows + {13,14,15,16}, {1,5,9,13}, {2,6,10,14}, -- cols + {3,7,11,15}, {4,8,12,16}, {1,6,11,16}, {4,7,10,13}, -- diags + {1,2,5,6}, {2,3,6,7}, {3,4,7,8}, -- squares + {5,6,9,10}, {6,7,10,11}, {7,8,11,12}, + {9,10,13,14}, {10,11,14,15}, {11,12,15,16} + } + + function negaMax:evaluate(board, depth) -- return format is score, is_terminal_position + --[[ + What can be confusing is how the heuristic value of the current node is calculated. In this implementation, this value is always calculated from the point of view of player A, whose color value is one. In other words, higher heuristic values always represent situations more favorable for player A. This is the same behavior as the normal minimax algorithm. The heuristic value is not necessarily the same as a node's return value due to value negation by negamax and the color parameter. The negamax node's return value is a heuristic score from the point of view of the node's current player. + + Negamax scores match minimax scores for nodes where player A is about to play, and where player A is the maximizing player in the minimax equivalent. Negamax always searches for the maximum value for all its nodes. Hence for player B nodes, the minimax score is a negation of its negamax score. Player B is the minimizing player in the minimax equivalent. + + Variations in negamax implementations may omit the color parameter. In this case, the heuristic evaluation function must return values from the point of view of the node's current player. + --]] + local player_plus_score, player_minus_score = 0, 0 + local game_won = false + for _, curr_qdr in pairs(negaMax.index_quadruplets) do -- iterate over all index quadruplets + -- count the empty positions and positions occupied by the side whos move it is + local player_plus_fields, player_minus_fields, empties = 0, 0, 0 + for _, index in pairs(curr_qdr) do -- iterate over all indices + if board[index] == 0 then + empties = empties + 1 + elseif board[index] == 1 then + player_plus_fields = player_plus_fields + 1 + elseif board[index] == -1 then + player_minus_fields = player_minus_fields + 1 + end + end + -- evaluate the quadruplets score by looking at empty vs occupied positions + if empties == 3 then + if player_plus_fields == 1 then + player_plus_score = player_plus_score + 3 + elseif player_minus_fields == 1 then + player_minus_score = player_minus_score + 3 + end + elseif empties == 2 then + if player_plus_fields == 2 then + player_plus_score = player_plus_score + 13 + elseif player_minus_fields == 2 then + player_minus_score = player_minus_score + 13 + end + elseif empties == 1 then + if player_plus_fields == 3 then + player_plus_score = player_plus_score + 31 + elseif player_minus_fields == 3 then + player_minus_score = player_minus_score + 31 + end + elseif empties == 0 then + -- check for winning situations + if player_plus_fields == 4 then + player_plus_score = 999-depth + player_minus_score = 0 + game_won = true + break + elseif player_minus_fields == 4 then + -- this should not happen if there is a proper terminal node detection! + player_plus_score = 0 + player_minus_score = 999-depth + game_won = true + break + end + end + end + -- return format is score, is_terminal_position + if not game_won and is_board_full(board) then + return 0, true -- DRAW + else + return (player_plus_score - player_minus_score), game_won -- >0 is good for player 1 [+], <0 means good for the other player (player 2 [-])) + end + end + + function negaMax:move_candidates(board, side_to_move) + local moves = {} + for i = 1, #board do + if board[i] == 0 then -- empty? + moves[#moves + 1] = i -- save move that was made + end + end + return moves + end + + function negaMax:make_move(board, side_to_move, move) + local copy = copy_board(board) + copy[move] = side_to_move + return copy + end + + local human_player = 1 + local AI_player = -human_player + local game_board = copy_board(empty_board) + local curr_move = -1 + local curr_player = human_player -- human player goes first + local score = 0 + local stop_loop = false + local game_over = false + + negaMax.maxdepth = 5 + + local t0 = os.clock() + score, curr_move = negaMax:negaMax(game_board, curr_player) + local t1 = os.clock() + + return t1-t0 +end + +bench.runCode(test, "tictactoe") \ No newline at end of file diff --git a/bench/tests/trig.lua b/bench/tests/trig.lua new file mode 100644 index 0000000..702e699 --- /dev/null +++ b/bench/tests/trig.lua @@ -0,0 +1,71 @@ +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +function test() + + function updateTransforms(matrixArray, amount, offset, scale, time) + local i = 0 + + for x=0,amount-1 do + for y=0,amount-1 do + for z=0,amount-1 do + local tx = offset - x + local ty = offset - y + local tz = offset - z + + local rx = 0 + local ry = ( math.sin( x / 4 + time ) + math.sin( y / 4 + time ) + math.sin( z / 4 + time ) ) + local rz = ry * 2 + + local ch = math.cos(rx) + local sh = math.sin(rx) + local ca = math.cos(ry) + local sa = math.sin(ry) + local cb = math.cos(rz) + local sb = math.sin(rz) + + local m00 = ch * ca + local m01 = sh*sb - ch*sa*cb + local m02 = ch*sa*sb + sh*cb + local m10 = sa + local m11 = ca*cb + local m12 = -ca*sb + local m20 = -sh*ca + local m21 = sh*sa*cb + ch*sb + local m22 = -sh*sa*sb + ch*cb + + matrixArray[i * 16 + 1] = m00 * scale + matrixArray[i * 16 + 2] = m01 * scale + matrixArray[i * 16 + 3] = m02 * scale + matrixArray[i * 16 + 4] = 0 + matrixArray[i * 16 + 5] = m10 * scale + matrixArray[i * 16 + 6] = m11 * scale + matrixArray[i * 16 + 7] = m12 * scale + matrixArray[i * 16 + 8] = 0 + matrixArray[i * 16 + 9] = m20 * scale + matrixArray[i * 16 + 10] = m21 * scale + matrixArray[i * 16 + 11] = m22 * scale + matrixArray[i * 16 + 12] = 0 + matrixArray[i * 16 + 13] = tx + matrixArray[i * 16 + 14] = ty + matrixArray[i * 16 + 15] = tz + matrixArray[i * 16 + 16] = 1 + + i = i + 1 + end + end + end + end + + local N = 40 + local array = table.create(N*N*N*16) + + local ts0 = os.clock() + + updateTransforms(array, N, -N/2, 0.5, 1/60) + + local ts1 = os.clock() + + return ts1-ts0 +end + +bench.runCode(test, "trig") \ No newline at end of file diff --git a/extern/.clang-format b/extern/.clang-format new file mode 100644 index 0000000..9d15924 --- /dev/null +++ b/extern/.clang-format @@ -0,0 +1,2 @@ +DisableFormat: true +SortIncludes: false diff --git a/extern/doctest.h b/extern/doctest.h new file mode 100644 index 0000000..f9e9c5c --- /dev/null +++ b/extern/doctest.h @@ -0,0 +1,6580 @@ +// ====================================================================== lgtm [cpp/missing-header-guard] +// == DO NOT MODIFY THIS FILE BY HAND - IT IS AUTO GENERATED BY CMAKE! == +// ====================================================================== +// +// doctest.h - the lightest feature-rich C++ single-header testing framework for unit tests and TDD +// +// Copyright (c) 2016-2021 Viktor Kirilov +// +// Distributed under the MIT Software License +// See accompanying file LICENSE.txt or copy at +// https://opensource.org/licenses/MIT +// +// The documentation can be found at the library's page: +// https://github.com/onqtam/doctest/blob/master/doc/markdown/readme.md +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= +// +// The library is heavily influenced by Catch - https://github.com/catchorg/Catch2 +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/catchorg/Catch2/blob/master/LICENSE.txt +// +// The concept of subcases (sections in Catch) and expression decomposition are from there. +// Some parts of the code are taken directly: +// - stringification - the detection of "ostream& operator<<(ostream&, const T&)" and StringMaker<> +// - the Approx() helper class for floating point comparison +// - colors in the console +// - breaking into a debugger +// - signal / SEH handling +// - timer +// - XmlWriter class - thanks to Phil Nash for allowing the direct reuse (AKA copy/paste) +// +// The expression decomposing templates are taken from lest - https://github.com/martinmoene/lest +// which uses the Boost Software License - Version 1.0 +// see here - https://github.com/martinmoene/lest/blob/master/LICENSE.txt +// +// ================================================================================================= +// ================================================================================================= +// ================================================================================================= + +#ifndef DOCTEST_LIBRARY_INCLUDED +#define DOCTEST_LIBRARY_INCLUDED + +// ================================================================================================= +// == VERSION ====================================================================================== +// ================================================================================================= + +#define DOCTEST_VERSION_MAJOR 2 +#define DOCTEST_VERSION_MINOR 4 +#define DOCTEST_VERSION_PATCH 6 +#define DOCTEST_VERSION_STR "2.4.6" + +#define DOCTEST_VERSION \ + (DOCTEST_VERSION_MAJOR * 10000 + DOCTEST_VERSION_MINOR * 100 + DOCTEST_VERSION_PATCH) + +// ================================================================================================= +// == COMPILER VERSION ============================================================================= +// ================================================================================================= + +// ideas for the version stuff are taken from here: https://github.com/cxxstuff/cxx_detect + +#define DOCTEST_COMPILER(MAJOR, MINOR, PATCH) ((MAJOR)*10000000 + (MINOR)*100000 + (PATCH)) + +// GCC/Clang and GCC/MSVC are mutually exclusive, but Clang/MSVC are not because of clang-cl... +#if defined(_MSC_VER) && defined(_MSC_FULL_VER) +#if _MSC_VER == _MSC_FULL_VER / 10000 +#define DOCTEST_MSVC DOCTEST_COMPILER(_MSC_VER / 100, _MSC_VER % 100, _MSC_FULL_VER % 10000) +#else // MSVC +#define DOCTEST_MSVC \ + DOCTEST_COMPILER(_MSC_VER / 100, (_MSC_FULL_VER / 100000) % 100, _MSC_FULL_VER % 100000) +#endif // MSVC +#endif // MSVC +#if defined(__clang__) && defined(__clang_minor__) +#define DOCTEST_CLANG DOCTEST_COMPILER(__clang_major__, __clang_minor__, __clang_patchlevel__) +#elif defined(__GNUC__) && defined(__GNUC_MINOR__) && defined(__GNUC_PATCHLEVEL__) && \ + !defined(__INTEL_COMPILER) +#define DOCTEST_GCC DOCTEST_COMPILER(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#endif // GCC + +#ifndef DOCTEST_MSVC +#define DOCTEST_MSVC 0 +#endif // DOCTEST_MSVC +#ifndef DOCTEST_CLANG +#define DOCTEST_CLANG 0 +#endif // DOCTEST_CLANG +#ifndef DOCTEST_GCC +#define DOCTEST_GCC 0 +#endif // DOCTEST_GCC + +// ================================================================================================= +// == COMPILER WARNINGS HELPERS ==================================================================== +// ================================================================================================= + +#if DOCTEST_CLANG +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push") +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(clang diagnostic ignored w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop") +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH DOCTEST_CLANG_SUPPRESS_WARNING(w) +#else // DOCTEST_CLANG +#define DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +#define DOCTEST_CLANG_SUPPRESS_WARNING(w) +#define DOCTEST_CLANG_SUPPRESS_WARNING_POP +#define DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_CLANG + +#if DOCTEST_GCC +#define DOCTEST_PRAGMA_TO_STR(x) _Pragma(#x) +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH _Pragma("GCC diagnostic push") +#define DOCTEST_GCC_SUPPRESS_WARNING(w) DOCTEST_PRAGMA_TO_STR(GCC diagnostic ignored w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP _Pragma("GCC diagnostic pop") +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_GCC_SUPPRESS_WARNING_PUSH DOCTEST_GCC_SUPPRESS_WARNING(w) +#else // DOCTEST_GCC +#define DOCTEST_GCC_SUPPRESS_WARNING_PUSH +#define DOCTEST_GCC_SUPPRESS_WARNING(w) +#define DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_GCC + +#if DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH __pragma(warning(push)) +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) __pragma(warning(disable : w)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP __pragma(warning(pop)) +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH DOCTEST_MSVC_SUPPRESS_WARNING(w) +#else // DOCTEST_MSVC +#define DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +#define DOCTEST_MSVC_SUPPRESS_WARNING(w) +#define DOCTEST_MSVC_SUPPRESS_WARNING_POP +#define DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(w) +#endif // DOCTEST_MSVC + +// ================================================================================================= +// == COMPILER WARNINGS ============================================================================ +// ================================================================================================= + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdeprecated") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-local-typedef") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") +DOCTEST_GCC_SUPPRESS_WARNING("-Wctor-dtor-privacy") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnon-virtual-dtor") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-local-typedefs") +DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-promo") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4616) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4619) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4996) // The compiler encountered a deprecated declaration +DOCTEST_MSVC_SUPPRESS_WARNING(4706) // assignment within conditional expression +DOCTEST_MSVC_SUPPRESS_WARNING(4512) // 'class' : assignment operator could not be generated +DOCTEST_MSVC_SUPPRESS_WARNING(4127) // conditional expression is constant +DOCTEST_MSVC_SUPPRESS_WARNING(4820) // padding +DOCTEST_MSVC_SUPPRESS_WARNING(4625) // copy constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4626) // assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5027) // move assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5026) // move constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4623) // default constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4640) // construction of local static object is not thread-safe +// static analysis +DOCTEST_MSVC_SUPPRESS_WARNING(26439) // This kind of function may not throw. Declare it 'noexcept' +DOCTEST_MSVC_SUPPRESS_WARNING(26495) // Always initialize a member variable +DOCTEST_MSVC_SUPPRESS_WARNING(26451) // Arithmetic overflow ... +DOCTEST_MSVC_SUPPRESS_WARNING(26444) // Avoid unnamed objects with custom construction and dtr... +DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' + +// 4548 - expression before comma has no effect; expected expression with side - effect +// 4265 - class has virtual functions, but destructor is not virtual +// 4986 - exception specification does not match previous declaration +// 4350 - behavior change: 'member1' called instead of 'member2' +// 4668 - 'x' is not defined as a preprocessor macro, replacing with '0' for '#if/#elif' +// 4365 - conversion from 'int' to 'unsigned long', signed/unsigned mismatch +// 4774 - format string expected in argument 'x' is not a string literal +// 4820 - padding in structs + +// only 4 should be disabled globally: +// - 4514 # unreferenced inline function has been removed +// - 4571 # SEH related +// - 4710 # function not inlined +// - 4711 # function 'x' selected for automatic inline expansion + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN \ + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH \ + DOCTEST_MSVC_SUPPRESS_WARNING(4548) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4265) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4986) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4350) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4668) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4365) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4774) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4820) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4625) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4626) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5027) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5026) \ + DOCTEST_MSVC_SUPPRESS_WARNING(4623) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5039) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5045) \ + DOCTEST_MSVC_SUPPRESS_WARNING(5105) + +#define DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END DOCTEST_MSVC_SUPPRESS_WARNING_POP + +// ================================================================================================= +// == FEATURE DETECTION ============================================================================ +// ================================================================================================= + +// general compiler feature support table: https://en.cppreference.com/w/cpp/compiler_support +// MSVC C++11 feature support table: https://msdn.microsoft.com/en-us/library/hh567368.aspx +// GCC C++11 feature support table: https://gcc.gnu.org/projects/cxx-status.html +// MSVC version table: +// https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B#Internal_version_numbering +// MSVC++ 14.2 (16) _MSC_VER == 1920 (Visual Studio 2019) +// MSVC++ 14.1 (15) _MSC_VER == 1910 (Visual Studio 2017) +// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) +// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013) +// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012) +// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010) +// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) +// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) + +#if DOCTEST_MSVC && !defined(DOCTEST_CONFIG_WINDOWS_SEH) +#define DOCTEST_CONFIG_WINDOWS_SEH +#endif // MSVC +#if defined(DOCTEST_CONFIG_NO_WINDOWS_SEH) && defined(DOCTEST_CONFIG_WINDOWS_SEH) +#undef DOCTEST_CONFIG_WINDOWS_SEH +#endif // DOCTEST_CONFIG_NO_WINDOWS_SEH + +#if !defined(_WIN32) && !defined(__QNX__) && !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && \ + !defined(__EMSCRIPTEN__) +#define DOCTEST_CONFIG_POSIX_SIGNALS +#endif // _WIN32 +#if defined(DOCTEST_CONFIG_NO_POSIX_SIGNALS) && defined(DOCTEST_CONFIG_POSIX_SIGNALS) +#undef DOCTEST_CONFIG_POSIX_SIGNALS +#endif // DOCTEST_CONFIG_NO_POSIX_SIGNALS + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#if !defined(__cpp_exceptions) && !defined(__EXCEPTIONS) && !defined(_CPPUNWIND) +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // no exceptions +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS +#define DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#if defined(DOCTEST_CONFIG_NO_EXCEPTIONS) && !defined(DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS) +#define DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS && !DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#if defined(DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN) && !defined(DOCTEST_CONFIG_IMPLEMENT) +#define DOCTEST_CONFIG_IMPLEMENT +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +#if defined(_WIN32) || defined(__CYGWIN__) +#if DOCTEST_MSVC +#define DOCTEST_SYMBOL_EXPORT __declspec(dllexport) +#define DOCTEST_SYMBOL_IMPORT __declspec(dllimport) +#else // MSVC +#define DOCTEST_SYMBOL_EXPORT __attribute__((dllexport)) +#define DOCTEST_SYMBOL_IMPORT __attribute__((dllimport)) +#endif // MSVC +#else // _WIN32 +#define DOCTEST_SYMBOL_EXPORT __attribute__((visibility("default"))) +#define DOCTEST_SYMBOL_IMPORT +#endif // _WIN32 + +#ifdef DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#ifdef DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_EXPORT +#else // DOCTEST_CONFIG_IMPLEMENT +#define DOCTEST_INTERFACE DOCTEST_SYMBOL_IMPORT +#endif // DOCTEST_CONFIG_IMPLEMENT +#else // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL +#define DOCTEST_INTERFACE +#endif // DOCTEST_CONFIG_IMPLEMENTATION_IN_DLL + +#define DOCTEST_EMPTY + +#if DOCTEST_MSVC +#define DOCTEST_NOINLINE __declspec(noinline) +#define DOCTEST_UNUSED +#define DOCTEST_ALIGNMENT(x) +#elif DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 5, 0) +#define DOCTEST_NOINLINE +#define DOCTEST_UNUSED +#define DOCTEST_ALIGNMENT(x) +#else +#define DOCTEST_NOINLINE __attribute__((noinline)) +#define DOCTEST_UNUSED __attribute__((unused)) +#define DOCTEST_ALIGNMENT(x) __attribute__((aligned(x))) +#endif + +#ifndef DOCTEST_NORETURN +#define DOCTEST_NORETURN [[noreturn]] +#endif // DOCTEST_NORETURN + +#ifndef DOCTEST_NOEXCEPT +#define DOCTEST_NOEXCEPT noexcept +#endif // DOCTEST_NOEXCEPT + +// ================================================================================================= +// == FEATURE DETECTION END ======================================================================== +// ================================================================================================= + +// internal macros for string concatenation and anonymous variable name generation +#define DOCTEST_CAT_IMPL(s1, s2) s1##s2 +#define DOCTEST_CAT(s1, s2) DOCTEST_CAT_IMPL(s1, s2) +#ifdef __COUNTER__ // not standard and may be missing for some compilers +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __COUNTER__) +#else // __COUNTER__ +#define DOCTEST_ANONYMOUS(x) DOCTEST_CAT(x, __LINE__) +#endif // __COUNTER__ + +#define DOCTEST_TOSTR(x) #x + +#ifndef DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x& +#else // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE +#define DOCTEST_REF_WRAP(x) x +#endif // DOCTEST_CONFIG_ASSERTION_PARAMETERS_BY_VALUE + +// not using __APPLE__ because... this is how Catch does it +#ifdef __MAC_OS_X_VERSION_MIN_REQUIRED +#define DOCTEST_PLATFORM_MAC +#elif defined(__IPHONE_OS_VERSION_MIN_REQUIRED) +#define DOCTEST_PLATFORM_IPHONE +#elif defined(_WIN32) +#define DOCTEST_PLATFORM_WINDOWS +#else // DOCTEST_PLATFORM +#define DOCTEST_PLATFORM_LINUX +#endif // DOCTEST_PLATFORM + +#define DOCTEST_GLOBAL_NO_WARNINGS(var) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wglobal-constructors") \ + DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-variable") \ + static const int var DOCTEST_UNUSED // NOLINT(fuchsia-statically-constructed-objects,cert-err58-cpp) +#define DOCTEST_GLOBAL_NO_WARNINGS_END() DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#ifndef DOCTEST_BREAK_INTO_DEBUGGER +// should probably take a look at https://github.com/scottt/debugbreak +#ifdef DOCTEST_PLATFORM_LINUX +#if defined(__GNUC__) && (defined(__i386) || defined(__x86_64)) +// Break at the location of the failing check if possible +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT (hicpp-no-assembler) +#else +#include +#define DOCTEST_BREAK_INTO_DEBUGGER() raise(SIGTRAP) +#endif +#elif defined(DOCTEST_PLATFORM_MAC) +#if defined(__x86_64) || defined(__x86_64__) || defined(__amd64__) || defined(__i386) +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("int $3\n" : :) // NOLINT (hicpp-no-assembler) +#else +#define DOCTEST_BREAK_INTO_DEBUGGER() __asm__("brk #0"); // NOLINT (hicpp-no-assembler) +#endif +#elif DOCTEST_MSVC +#define DOCTEST_BREAK_INTO_DEBUGGER() __debugbreak() +#elif defined(__MINGW32__) +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wredundant-decls") +extern "C" __declspec(dllimport) void __stdcall DebugBreak(); +DOCTEST_GCC_SUPPRESS_WARNING_POP +#define DOCTEST_BREAK_INTO_DEBUGGER() ::DebugBreak() +#else // linux +#define DOCTEST_BREAK_INTO_DEBUGGER() (static_cast(0)) +#endif // linux +#endif // DOCTEST_BREAK_INTO_DEBUGGER + +// this is kept here for backwards compatibility since the config option was changed +#ifdef DOCTEST_CONFIG_USE_IOSFWD +#define DOCTEST_CONFIG_USE_STD_HEADERS +#endif // DOCTEST_CONFIG_USE_IOSFWD + +#ifdef DOCTEST_CONFIG_USE_STD_HEADERS +#ifndef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#define DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#include +#include +#include +#else // DOCTEST_CONFIG_USE_STD_HEADERS + +#if DOCTEST_CLANG +// to detect if libc++ is being used with clang (the _LIBCPP_VERSION identifier) +#include +#endif // clang + +#ifdef _LIBCPP_VERSION +#define DOCTEST_STD_NAMESPACE_BEGIN _LIBCPP_BEGIN_NAMESPACE_STD +#define DOCTEST_STD_NAMESPACE_END _LIBCPP_END_NAMESPACE_STD +#else // _LIBCPP_VERSION +#define DOCTEST_STD_NAMESPACE_BEGIN namespace std { +#define DOCTEST_STD_NAMESPACE_END } +#endif // _LIBCPP_VERSION + +// Forward declaring 'X' in namespace std is not permitted by the C++ Standard. +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4643) + +DOCTEST_STD_NAMESPACE_BEGIN // NOLINT (cert-dcl58-cpp) +typedef decltype(nullptr) nullptr_t; +template +struct char_traits; +template <> +struct char_traits; +template +class basic_ostream; +typedef basic_ostream> ostream; +template +class tuple; +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 +template +class allocator; +template +class basic_string; +using string = basic_string, allocator>; +#endif // VS 2019 +DOCTEST_STD_NAMESPACE_END + +DOCTEST_MSVC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_USE_STD_HEADERS + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#include +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + +namespace doctest { + +DOCTEST_INTERFACE extern bool is_running_in_test; + +// A 24 byte string class (can be as small as 17 for x64 and 13 for x86) that can hold strings with length +// of up to 23 chars on the stack before going on the heap - the last byte of the buffer is used for: +// - "is small" bit - the highest bit - if "0" then it is small - otherwise its "1" (128) +// - if small - capacity left before going on the heap - using the lowest 5 bits +// - if small - 2 bits are left unused - the second and third highest ones +// - if small - acts as a null terminator if strlen() is 23 (24 including the null terminator) +// and the "is small" bit remains "0" ("as well as the capacity left") so its OK +// Idea taken from this lecture about the string implementation of facebook/folly - fbstring +// https://www.youtube.com/watch?v=kPR8h4-qZdk +// TODO: +// - optimizations - like not deleting memory unnecessarily in operator= and etc. +// - resize/reserve/clear +// - substr +// - replace +// - back/front +// - iterator stuff +// - find & friends +// - push_back/pop_back +// - assign/insert/erase +// - relational operators as free functions - taking const char* as one of the params +class DOCTEST_INTERFACE String +{ + static const unsigned len = 24; //!OCLINT avoid private static members + static const unsigned last = len - 1; //!OCLINT avoid private static members + + struct view // len should be more than sizeof(view) - because of the final byte for flags + { + char* ptr; + unsigned size; + unsigned capacity; + }; + + union + { + char buf[len]; + view data; + }; + + bool isOnStack() const { return (buf[last] & 128) == 0; } + void setOnHeap(); + void setLast(unsigned in = last); + + void copy(const String& other); + +public: + String(); + ~String(); + + // cppcheck-suppress noExplicitConstructor + String(const char* in); + String(const char* in, unsigned in_size); + + String(const String& other); + String& operator=(const String& other); + + String& operator+=(const String& other); + String operator+(const String& other) const; + + String(String&& other); + String& operator=(String&& other); + + char operator[](unsigned i) const; + char& operator[](unsigned i); + + // the only functions I'm willing to leave in the interface - available for inlining + const char* c_str() const { return const_cast(this)->c_str(); } // NOLINT + char* c_str() { + if(isOnStack()) + return reinterpret_cast(buf); + return data.ptr; + } + + unsigned size() const; + unsigned capacity() const; + + int compare(const char* other, bool no_case = false) const; + int compare(const String& other, bool no_case = false) const; +}; + +DOCTEST_INTERFACE bool operator==(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator!=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator<=(const String& lhs, const String& rhs); +DOCTEST_INTERFACE bool operator>=(const String& lhs, const String& rhs); + +DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, const String& in); + +namespace Color { + enum Enum + { + None = 0, + White, + Red, + Green, + Blue, + Cyan, + Yellow, + Grey, + + Bright = 0x10, + + BrightRed = Bright | Red, + BrightGreen = Bright | Green, + LightGrey = Bright | Grey, + BrightWhite = Bright | White + }; + + DOCTEST_INTERFACE std::ostream& operator<<(std::ostream& s, Color::Enum code); +} // namespace Color + +namespace assertType { + enum Enum + { + // macro traits + + is_warn = 1, + is_check = 2 * is_warn, + is_require = 2 * is_check, + + is_normal = 2 * is_require, + is_throws = 2 * is_normal, + is_throws_as = 2 * is_throws, + is_throws_with = 2 * is_throws_as, + is_nothrow = 2 * is_throws_with, + + is_false = 2 * is_nothrow, + is_unary = 2 * is_false, // not checked anywhere - used just to distinguish the types + + is_eq = 2 * is_unary, + is_ne = 2 * is_eq, + + is_lt = 2 * is_ne, + is_gt = 2 * is_lt, + + is_ge = 2 * is_gt, + is_le = 2 * is_ge, + + // macro types + + DT_WARN = is_normal | is_warn, + DT_CHECK = is_normal | is_check, + DT_REQUIRE = is_normal | is_require, + + DT_WARN_FALSE = is_normal | is_false | is_warn, + DT_CHECK_FALSE = is_normal | is_false | is_check, + DT_REQUIRE_FALSE = is_normal | is_false | is_require, + + DT_WARN_THROWS = is_throws | is_warn, + DT_CHECK_THROWS = is_throws | is_check, + DT_REQUIRE_THROWS = is_throws | is_require, + + DT_WARN_THROWS_AS = is_throws_as | is_warn, + DT_CHECK_THROWS_AS = is_throws_as | is_check, + DT_REQUIRE_THROWS_AS = is_throws_as | is_require, + + DT_WARN_THROWS_WITH = is_throws_with | is_warn, + DT_CHECK_THROWS_WITH = is_throws_with | is_check, + DT_REQUIRE_THROWS_WITH = is_throws_with | is_require, + + DT_WARN_THROWS_WITH_AS = is_throws_with | is_throws_as | is_warn, + DT_CHECK_THROWS_WITH_AS = is_throws_with | is_throws_as | is_check, + DT_REQUIRE_THROWS_WITH_AS = is_throws_with | is_throws_as | is_require, + + DT_WARN_NOTHROW = is_nothrow | is_warn, + DT_CHECK_NOTHROW = is_nothrow | is_check, + DT_REQUIRE_NOTHROW = is_nothrow | is_require, + + DT_WARN_EQ = is_normal | is_eq | is_warn, + DT_CHECK_EQ = is_normal | is_eq | is_check, + DT_REQUIRE_EQ = is_normal | is_eq | is_require, + + DT_WARN_NE = is_normal | is_ne | is_warn, + DT_CHECK_NE = is_normal | is_ne | is_check, + DT_REQUIRE_NE = is_normal | is_ne | is_require, + + DT_WARN_GT = is_normal | is_gt | is_warn, + DT_CHECK_GT = is_normal | is_gt | is_check, + DT_REQUIRE_GT = is_normal | is_gt | is_require, + + DT_WARN_LT = is_normal | is_lt | is_warn, + DT_CHECK_LT = is_normal | is_lt | is_check, + DT_REQUIRE_LT = is_normal | is_lt | is_require, + + DT_WARN_GE = is_normal | is_ge | is_warn, + DT_CHECK_GE = is_normal | is_ge | is_check, + DT_REQUIRE_GE = is_normal | is_ge | is_require, + + DT_WARN_LE = is_normal | is_le | is_warn, + DT_CHECK_LE = is_normal | is_le | is_check, + DT_REQUIRE_LE = is_normal | is_le | is_require, + + DT_WARN_UNARY = is_normal | is_unary | is_warn, + DT_CHECK_UNARY = is_normal | is_unary | is_check, + DT_REQUIRE_UNARY = is_normal | is_unary | is_require, + + DT_WARN_UNARY_FALSE = is_normal | is_false | is_unary | is_warn, + DT_CHECK_UNARY_FALSE = is_normal | is_false | is_unary | is_check, + DT_REQUIRE_UNARY_FALSE = is_normal | is_false | is_unary | is_require, + }; +} // namespace assertType + +DOCTEST_INTERFACE const char* assertString(assertType::Enum at); +DOCTEST_INTERFACE const char* failureString(assertType::Enum at); +DOCTEST_INTERFACE const char* skipPathFromFilename(const char* file); + +struct DOCTEST_INTERFACE TestCaseData +{ + String m_file; // the file in which the test was registered (using String - see #350) + unsigned m_line; // the line where the test was registered + const char* m_name; // name of the test case + const char* m_test_suite; // the test suite in which the test was added + const char* m_description; + bool m_skip; + bool m_no_breaks; + bool m_no_output; + bool m_may_fail; + bool m_should_fail; + int m_expected_failures; + double m_timeout; +}; + +struct DOCTEST_INTERFACE AssertData +{ + // common - for all asserts + const TestCaseData* m_test_case; + assertType::Enum m_at; + const char* m_file; + int m_line; + const char* m_expr; + bool m_failed; + + // exception-related - for all asserts + bool m_threw; + String m_exception; + + // for normal asserts + String m_decomp; + + // for specific exception-related asserts + bool m_threw_as; + const char* m_exception_type; + const char* m_exception_string; +}; + +struct DOCTEST_INTERFACE MessageData +{ + String m_string; + const char* m_file; + int m_line; + assertType::Enum m_severity; +}; + +struct DOCTEST_INTERFACE SubcaseSignature +{ + String m_name; + const char* m_file; + int m_line; + + bool operator<(const SubcaseSignature& other) const; +}; + +struct DOCTEST_INTERFACE IContextScope +{ + IContextScope(); + virtual ~IContextScope(); + virtual void stringify(std::ostream*) const = 0; +}; + +namespace detail { + struct DOCTEST_INTERFACE TestCase; +} // namespace detail + +struct ContextOptions //!OCLINT too many fields +{ + std::ostream* cout; // stdout stream - std::cout by default + std::ostream* cerr; // stderr stream - std::cerr by default + String binary_name; // the test binary name + + const detail::TestCase* currentTest = nullptr; + + // == parameters from the command line + String out; // output filename + String order_by; // how tests should be ordered + unsigned rand_seed; // the seed for rand ordering + + unsigned first; // the first (matching) test to be executed + unsigned last; // the last (matching) test to be executed + + int abort_after; // stop tests after this many failed assertions + int subcase_filter_levels; // apply the subcase filters for the first N levels + + bool success; // include successful assertions in output + bool case_sensitive; // if filtering should be case sensitive + bool exit; // if the program should be exited after the tests are ran/whatever + bool duration; // print the time duration of each test case + bool no_throw; // to skip exceptions-related assertion macros + bool no_exitcode; // if the framework should return 0 as the exitcode + bool no_run; // to not run the tests at all (can be done with an "*" exclude) + bool no_version; // to not print the version of the framework + bool no_colors; // if output to the console should be colorized + bool force_colors; // forces the use of colors even when a tty cannot be detected + bool no_breaks; // to not break into the debugger + bool no_skip; // don't skip test cases which are marked to be skipped + bool gnu_file_line; // if line numbers should be surrounded with :x: and not (x): + bool no_path_in_filenames; // if the path to files should be removed from the output + bool no_line_numbers; // if source code line numbers should be omitted from the output + bool no_debug_output; // no output in the debug console when a debugger is attached + bool no_skipped_summary; // don't print "skipped" in the summary !!! UNDOCUMENTED !!! + bool no_time_in_output; // omit any time/timestamps from output !!! UNDOCUMENTED !!! + + bool help; // to print the help + bool version; // to print the version + bool count; // if only the count of matching tests is to be retrieved + bool list_test_cases; // to list all tests matching the filters + bool list_test_suites; // to list all suites matching the filters + bool list_reporters; // lists all registered reporters +}; + +namespace detail { + template + struct enable_if + {}; + + template + struct enable_if + { typedef TYPE type; }; + + // clang-format off + template struct remove_reference { typedef T type; }; + template struct remove_reference { typedef T type; }; + template struct remove_reference { typedef T type; }; + + template U declval(int); + + template T declval(long); + + template auto declval() DOCTEST_NOEXCEPT -> decltype(declval(0)) ; + + template struct is_lvalue_reference { const static bool value=false; }; + template struct is_lvalue_reference { const static bool value=true; }; + + template + inline T&& forward(typename remove_reference::type& t) DOCTEST_NOEXCEPT + { + return static_cast(t); + } + + template + inline T&& forward(typename remove_reference::type&& t) DOCTEST_NOEXCEPT + { + static_assert(!is_lvalue_reference::value, + "Can not forward an rvalue as an lvalue."); + return static_cast(t); + } + + template struct remove_const { typedef T type; }; + template struct remove_const { typedef T type; }; +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template struct is_enum : public std::is_enum {}; + template struct underlying_type : public std::underlying_type {}; +#else + // Use compiler intrinsics + template struct is_enum { constexpr static bool value = __is_enum(T); }; + template struct underlying_type { typedef __underlying_type(T) type; }; +#endif + // clang-format on + + template + struct deferred_false + // cppcheck-suppress unusedStructMember + { static const bool value = false; }; + + namespace has_insertion_operator_impl { + std::ostream &os(); + template + DOCTEST_REF_WRAP(T) val(); + + template + struct check { + static constexpr bool value = false; + }; + + template + struct check(), void())> { + static constexpr bool value = true; + }; + } // namespace has_insertion_operator_impl + + template + using has_insertion_operator = has_insertion_operator_impl::check; + + DOCTEST_INTERFACE void my_memcpy(void* dest, const void* src, unsigned num); + + DOCTEST_INTERFACE std::ostream* getTlsOss(); // returns a thread-local ostringstream + DOCTEST_INTERFACE String getTlsOssResult(); + + template + struct StringMakerBase + { + template + static String convert(const DOCTEST_REF_WRAP(T)) { + return "{?}"; + } + }; + + template <> + struct StringMakerBase + { + template + static String convert(const DOCTEST_REF_WRAP(T) in) { + *getTlsOss() << in; + return getTlsOssResult(); + } + }; + + DOCTEST_INTERFACE String rawMemoryToString(const void* object, unsigned size); + + template + String rawMemoryToString(const DOCTEST_REF_WRAP(T) object) { + return rawMemoryToString(&object, sizeof(object)); + } + + template + const char* type_to_string() { + return "<>"; + } +} // namespace detail + +template +struct StringMaker : public detail::StringMakerBase::value> +{}; + +template +struct StringMaker +{ + template + static String convert(U* p) { + if(p) + return detail::rawMemoryToString(p); + return "NULL"; + } +}; + +template +struct StringMaker +{ + static String convert(R C::*p) { + if(p) + return detail::rawMemoryToString(p); + return "NULL"; + } +}; + +template ::value, bool>::type = true> +String toString(const DOCTEST_REF_WRAP(T) value) { + return StringMaker::convert(value); +} + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +DOCTEST_INTERFACE String toString(char* in); +DOCTEST_INTERFACE String toString(const char* in); +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +DOCTEST_INTERFACE String toString(bool in); +DOCTEST_INTERFACE String toString(float in); +DOCTEST_INTERFACE String toString(double in); +DOCTEST_INTERFACE String toString(double long in); + +DOCTEST_INTERFACE String toString(char in); +DOCTEST_INTERFACE String toString(char signed in); +DOCTEST_INTERFACE String toString(char unsigned in); +DOCTEST_INTERFACE String toString(int short in); +DOCTEST_INTERFACE String toString(int short unsigned in); +DOCTEST_INTERFACE String toString(int in); +DOCTEST_INTERFACE String toString(int unsigned in); +DOCTEST_INTERFACE String toString(int long in); +DOCTEST_INTERFACE String toString(int long unsigned in); +DOCTEST_INTERFACE String toString(int long long in); +DOCTEST_INTERFACE String toString(int long long unsigned in); +DOCTEST_INTERFACE String toString(std::nullptr_t in); + +template ::value, bool>::type = true> +String toString(const DOCTEST_REF_WRAP(T) value) { + typedef typename detail::underlying_type::type UT; + return toString(static_cast(value)); +} + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 +DOCTEST_INTERFACE String toString(const std::string& in); +#endif // VS 2019 + +class DOCTEST_INTERFACE Approx +{ +public: + explicit Approx(double value); + + Approx operator()(double value) const; + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + explicit Approx(const T& value, + typename detail::enable_if::value>::type* = + static_cast(nullptr)) { + *this = Approx(static_cast(value)); + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& epsilon(double newEpsilon); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename detail::enable_if::value, Approx&>::type epsilon( + const T& newEpsilon) { + m_epsilon = static_cast(newEpsilon); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + Approx& scale(double newScale); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + template + typename detail::enable_if::value, Approx&>::type scale( + const T& newScale) { + m_scale = static_cast(newScale); + return *this; + } +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format off + DOCTEST_INTERFACE friend bool operator==(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator==(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator!=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator!=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator<=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator<=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator>=(double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator>=(const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator< (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator< (const Approx & lhs, double rhs); + DOCTEST_INTERFACE friend bool operator> (double lhs, const Approx & rhs); + DOCTEST_INTERFACE friend bool operator> (const Approx & lhs, double rhs); + + DOCTEST_INTERFACE friend String toString(const Approx& in); + +#ifdef DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS +#define DOCTEST_APPROX_PREFIX \ + template friend typename detail::enable_if::value, bool>::type + + DOCTEST_APPROX_PREFIX operator==(const T& lhs, const Approx& rhs) { return operator==(double(lhs), rhs); } + DOCTEST_APPROX_PREFIX operator==(const Approx& lhs, const T& rhs) { return operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator!=(const T& lhs, const Approx& rhs) { return !operator==(lhs, rhs); } + DOCTEST_APPROX_PREFIX operator!=(const Approx& lhs, const T& rhs) { return !operator==(rhs, lhs); } + DOCTEST_APPROX_PREFIX operator<=(const T& lhs, const Approx& rhs) { return double(lhs) < rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator<=(const Approx& lhs, const T& rhs) { return lhs.m_value < double(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const T& lhs, const Approx& rhs) { return double(lhs) > rhs.m_value || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator>=(const Approx& lhs, const T& rhs) { return lhs.m_value > double(rhs) || lhs == rhs; } + DOCTEST_APPROX_PREFIX operator< (const T& lhs, const Approx& rhs) { return double(lhs) < rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator< (const Approx& lhs, const T& rhs) { return lhs.m_value < double(rhs) && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const T& lhs, const Approx& rhs) { return double(lhs) > rhs.m_value && lhs != rhs; } + DOCTEST_APPROX_PREFIX operator> (const Approx& lhs, const T& rhs) { return lhs.m_value > double(rhs) && lhs != rhs; } +#undef DOCTEST_APPROX_PREFIX +#endif // DOCTEST_CONFIG_INCLUDE_TYPE_TRAITS + + // clang-format on + +private: + double m_epsilon; + double m_scale; + double m_value; +}; + +DOCTEST_INTERFACE String toString(const Approx& in); + +DOCTEST_INTERFACE const ContextOptions* getContextOptions(); + +#if !defined(DOCTEST_CONFIG_DISABLE) + +namespace detail { + // clang-format off +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + template struct decay_array { typedef T type; }; + template struct decay_array { typedef T* type; }; + template struct decay_array { typedef T* type; }; + + template struct not_char_pointer { enum { value = 1 }; }; + template<> struct not_char_pointer { enum { value = 0 }; }; + template<> struct not_char_pointer { enum { value = 0 }; }; + + template struct can_use_op : public not_char_pointer::type> {}; +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + + struct DOCTEST_INTERFACE TestFailureException + { + }; + + DOCTEST_INTERFACE bool checkIfShouldThrow(assertType::Enum at); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_INTERFACE void throwException(); + + struct DOCTEST_INTERFACE Subcase + { + SubcaseSignature m_signature; + bool m_entered = false; + + Subcase(const String& name, const char* file, int line); + ~Subcase(); + + operator bool() const; + }; + + template + String stringifyBinaryExpr(const DOCTEST_REF_WRAP(L) lhs, const char* op, + const DOCTEST_REF_WRAP(R) rhs) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + return doctest::toString(lhs) + op + doctest::toString(rhs); + } + +#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-comparison") +#endif + +// This will check if there is any way it could find a operator like member or friend and uses it. +// If not it doesn't find the operator or if the operator at global scope is defined after +// this template, the template won't be instantiated due to SFINAE. Once the template is not +// instantiated it can look for global operator using normal conversions. +#define SFINAE_OP(ret,op) decltype((void)(doctest::detail::declval() op doctest::detail::declval()),static_cast(0)) + +#define DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(op, op_str, op_macro) \ + template \ + DOCTEST_NOINLINE SFINAE_OP(Result,op) operator op(R&& rhs) { \ + bool res = op_macro(doctest::detail::forward(lhs), doctest::detail::forward(rhs)); \ + if(m_at & assertType::is_false) \ + res = !res; \ + if(!res || doctest::getContextOptions()->success) \ + return Result(res, stringifyBinaryExpr(lhs, op_str, rhs)); \ + return Result(res); \ + } + + // more checks could be added - like in Catch: + // https://github.com/catchorg/Catch2/pull/1480/files + // https://github.com/catchorg/Catch2/pull/1481/files +#define DOCTEST_FORBIT_EXPRESSION(rt, op) \ + template \ + rt& operator op(const R&) { \ + static_assert(deferred_false::value, \ + "Expression Too Complex Please Rewrite As Binary Comparison!"); \ + return *this; \ + } + + struct DOCTEST_INTERFACE Result + { + bool m_passed; + String m_decomp; + + Result(bool passed, const String& decomposition = String()); + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Result, &) + DOCTEST_FORBIT_EXPRESSION(Result, ^) + DOCTEST_FORBIT_EXPRESSION(Result, |) + DOCTEST_FORBIT_EXPRESSION(Result, &&) + DOCTEST_FORBIT_EXPRESSION(Result, ||) + DOCTEST_FORBIT_EXPRESSION(Result, ==) + DOCTEST_FORBIT_EXPRESSION(Result, !=) + DOCTEST_FORBIT_EXPRESSION(Result, <) + DOCTEST_FORBIT_EXPRESSION(Result, >) + DOCTEST_FORBIT_EXPRESSION(Result, <=) + DOCTEST_FORBIT_EXPRESSION(Result, >=) + DOCTEST_FORBIT_EXPRESSION(Result, =) + DOCTEST_FORBIT_EXPRESSION(Result, +=) + DOCTEST_FORBIT_EXPRESSION(Result, -=) + DOCTEST_FORBIT_EXPRESSION(Result, *=) + DOCTEST_FORBIT_EXPRESSION(Result, /=) + DOCTEST_FORBIT_EXPRESSION(Result, %=) + DOCTEST_FORBIT_EXPRESSION(Result, <<=) + DOCTEST_FORBIT_EXPRESSION(Result, >>=) + DOCTEST_FORBIT_EXPRESSION(Result, &=) + DOCTEST_FORBIT_EXPRESSION(Result, ^=) + DOCTEST_FORBIT_EXPRESSION(Result, |=) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_PUSH + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_CLANG_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_GCC_SUPPRESS_WARNING_PUSH + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") + DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-compare") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wdouble-promotion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") + //DOCTEST_GCC_SUPPRESS_WARNING("-Wfloat-equal") + + DOCTEST_MSVC_SUPPRESS_WARNING_PUSH + // https://stackoverflow.com/questions/39479163 what's the difference between 4018 and 4389 + DOCTEST_MSVC_SUPPRESS_WARNING(4388) // signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4389) // 'operator' : signed/unsigned mismatch + DOCTEST_MSVC_SUPPRESS_WARNING(4018) // 'expression' : signed/unsigned mismatch + //DOCTEST_MSVC_SUPPRESS_WARNING(4805) // 'operation' : unsafe mix of type 'type' and type 'type' in operation + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + // clang-format off +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE bool +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_COMPARISON_RETURN_TYPE typename enable_if::value || can_use_op::value, bool>::type + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + inline bool eq(const char* lhs, const char* rhs) { return String(lhs) == String(rhs); } + inline bool ne(const char* lhs, const char* rhs) { return String(lhs) != String(rhs); } + inline bool lt(const char* lhs, const char* rhs) { return String(lhs) < String(rhs); } + inline bool gt(const char* lhs, const char* rhs) { return String(lhs) > String(rhs); } + inline bool le(const char* lhs, const char* rhs) { return String(lhs) <= String(rhs); } + inline bool ge(const char* lhs, const char* rhs) { return String(lhs) >= String(rhs); } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + // clang-format on + +#define DOCTEST_RELATIONAL_OP(name, op) \ + template \ + DOCTEST_COMPARISON_RETURN_TYPE name(const DOCTEST_REF_WRAP(L) lhs, \ + const DOCTEST_REF_WRAP(R) rhs) { \ + return lhs op rhs; \ + } + + DOCTEST_RELATIONAL_OP(eq, ==) + DOCTEST_RELATIONAL_OP(ne, !=) + DOCTEST_RELATIONAL_OP(lt, <) + DOCTEST_RELATIONAL_OP(gt, >) + DOCTEST_RELATIONAL_OP(le, <=) + DOCTEST_RELATIONAL_OP(ge, >=) + +#ifndef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) l == r +#define DOCTEST_CMP_NE(l, r) l != r +#define DOCTEST_CMP_GT(l, r) l > r +#define DOCTEST_CMP_LT(l, r) l < r +#define DOCTEST_CMP_GE(l, r) l >= r +#define DOCTEST_CMP_LE(l, r) l <= r +#else // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +#define DOCTEST_CMP_EQ(l, r) eq(l, r) +#define DOCTEST_CMP_NE(l, r) ne(l, r) +#define DOCTEST_CMP_GT(l, r) gt(l, r) +#define DOCTEST_CMP_LT(l, r) lt(l, r) +#define DOCTEST_CMP_GE(l, r) ge(l, r) +#define DOCTEST_CMP_LE(l, r) le(l, r) +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + + template + // cppcheck-suppress copyCtorAndEqOperator + struct Expression_lhs + { + L lhs; + assertType::Enum m_at; + + explicit Expression_lhs(L&& in, assertType::Enum at) + : lhs(doctest::detail::forward(in)) + , m_at(at) {} + + DOCTEST_NOINLINE operator Result() { +// this is needed only foc MSVC 2015: +// https://ci.appveyor.com/project/onqtam/doctest/builds/38181202 +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4800) // 'int': forcing value to bool + bool res = static_cast(lhs); +DOCTEST_MSVC_SUPPRESS_WARNING_POP + if(m_at & assertType::is_false) //!OCLINT bitwise operator in conditional + res = !res; + + if(!res || getContextOptions()->success) + return Result(res, doctest::toString(lhs)); + return Result(res); + } + + /* This is required for user-defined conversions from Expression_lhs to L */ + //operator L() const { return lhs; } + operator L() const { return lhs; } + + // clang-format off + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(==, " == ", DOCTEST_CMP_EQ) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(!=, " != ", DOCTEST_CMP_NE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>, " > ", DOCTEST_CMP_GT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<, " < ", DOCTEST_CMP_LT) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(>=, " >= ", DOCTEST_CMP_GE) //!OCLINT bitwise operator in conditional + DOCTEST_DO_BINARY_EXPRESSION_COMPARISON(<=, " <= ", DOCTEST_CMP_LE) //!OCLINT bitwise operator in conditional + // clang-format on + + // forbidding some expressions based on this table: https://en.cppreference.com/w/cpp/language/operator_precedence + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &&) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ||) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, =) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, +=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, -=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, *=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, /=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, %=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, &=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, ^=) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, |=) + // these 2 are unfortunate because they should be allowed - they have higher precedence over the comparisons, but the + // ExpressionDecomposer class uses the left shift operator to capture the left operand of the binary expression... + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, <<) + DOCTEST_FORBIT_EXPRESSION(Expression_lhs, >>) + }; + +#ifndef DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_NO_COMPARISON_WARNING_SUPPRESSION + +#if DOCTEST_CLANG && DOCTEST_CLANG < DOCTEST_COMPILER(3, 6, 0) +DOCTEST_CLANG_SUPPRESS_WARNING_POP +#endif + + struct DOCTEST_INTERFACE ExpressionDecomposer + { + assertType::Enum m_at; + + ExpressionDecomposer(assertType::Enum at); + + // The right operator for capturing expressions is "<=" instead of "<<" (based on the operator precedence table) + // but then there will be warnings from GCC about "-Wparentheses" and since "_Pragma()" is problematic this will stay for now... + // https://github.com/catchorg/Catch2/issues/870 + // https://github.com/catchorg/Catch2/issues/565 + template + Expression_lhs operator<<(L &&operand) { + return Expression_lhs(doctest::detail::forward(operand), m_at); + } + }; + + struct DOCTEST_INTERFACE TestSuite + { + const char* m_test_suite; + const char* m_description; + bool m_skip; + bool m_no_breaks; + bool m_no_output; + bool m_may_fail; + bool m_should_fail; + int m_expected_failures; + double m_timeout; + + TestSuite& operator*(const char* in); + + template + TestSuite& operator*(const T& in) { + in.fill(*this); + return *this; + } + }; + + typedef void (*funcType)(); + + struct DOCTEST_INTERFACE TestCase : public TestCaseData + { + funcType m_test; // a function pointer to the test case + + const char* m_type; // for templated test cases - gets appended to the real name + int m_template_id; // an ID used to distinguish between the different versions of a templated test case + String m_full_name; // contains the name (only for templated test cases!) + the template type + + TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const char* type = "", int template_id = -1); + + TestCase(const TestCase& other); + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + TestCase& operator=(const TestCase& other); + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& operator*(const char* in); + + template + TestCase& operator*(const T& in) { + in.fill(*this); + return *this; + } + + bool operator<(const TestCase& other) const; + }; + + // forward declarations of functions used by the macros + DOCTEST_INTERFACE int regTest(const TestCase& tc); + DOCTEST_INTERFACE int setTestSuite(const TestSuite& ts); + DOCTEST_INTERFACE bool isDebuggerActive(); + + template + int instantiationHelper(const T&) { return 0; } + + namespace binaryAssertComparison { + enum Enum + { + eq = 0, + ne, + gt, + lt, + ge, + le + }; + } // namespace binaryAssertComparison + + // clang-format off + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L), const DOCTEST_REF_WRAP(R) ) const { return false; } }; + +#define DOCTEST_BINARY_RELATIONAL_OP(n, op) \ + template struct RelationalComparator { bool operator()(const DOCTEST_REF_WRAP(L) lhs, const DOCTEST_REF_WRAP(R) rhs) const { return op(lhs, rhs); } }; + // clang-format on + + DOCTEST_BINARY_RELATIONAL_OP(0, doctest::detail::eq) + DOCTEST_BINARY_RELATIONAL_OP(1, doctest::detail::ne) + DOCTEST_BINARY_RELATIONAL_OP(2, doctest::detail::gt) + DOCTEST_BINARY_RELATIONAL_OP(3, doctest::detail::lt) + DOCTEST_BINARY_RELATIONAL_OP(4, doctest::detail::ge) + DOCTEST_BINARY_RELATIONAL_OP(5, doctest::detail::le) + + struct DOCTEST_INTERFACE ResultBuilder : public AssertData + { + ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type = "", const char* exception_string = ""); + + void setResult(const Result& res); + + template + DOCTEST_NOINLINE void binary_assert(const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + m_failed = !RelationalComparator()(lhs, rhs); + if(m_failed || getContextOptions()->success) + m_decomp = stringifyBinaryExpr(lhs, ", ", rhs); + } + + template + DOCTEST_NOINLINE void unary_assert(const DOCTEST_REF_WRAP(L) val) { + m_failed = !val; + + if(m_at & assertType::is_false) //!OCLINT bitwise operator in conditional + m_failed = !m_failed; + + if(m_failed || getContextOptions()->success) + m_decomp = toString(val); + } + + void translateException(); + + bool log(); + void react() const; + }; + + namespace assertAction { + enum Enum + { + nothing = 0, + dbgbreak = 1, + shouldthrow = 2 + }; + } // namespace assertAction + + DOCTEST_INTERFACE void failed_out_of_a_testing_context(const AssertData& ad); + + DOCTEST_INTERFACE void decomp_assert(assertType::Enum at, const char* file, int line, + const char* expr, Result result); + +#define DOCTEST_ASSERT_OUT_OF_TESTS(decomp) \ + do { \ + if(!is_running_in_test) { \ + if(failed) { \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + rb.m_decomp = decomp; \ + failed_out_of_a_testing_context(rb); \ + if(isDebuggerActive() && !getContextOptions()->no_breaks) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(checkIfShouldThrow(at)) \ + throwException(); \ + } \ + return; \ + } \ + } while(false) + +#define DOCTEST_ASSERT_IN_TESTS(decomp) \ + ResultBuilder rb(at, file, line, expr); \ + rb.m_failed = failed; \ + if(rb.m_failed || getContextOptions()->success) \ + rb.m_decomp = decomp; \ + if(rb.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + if(rb.m_failed && checkIfShouldThrow(at)) \ + throwException() + + template + DOCTEST_NOINLINE void binary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) lhs, + const DOCTEST_REF_WRAP(R) rhs) { + bool failed = !RelationalComparator()(lhs, rhs); + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + DOCTEST_ASSERT_IN_TESTS(stringifyBinaryExpr(lhs, ", ", rhs)); + } + + template + DOCTEST_NOINLINE void unary_assert(assertType::Enum at, const char* file, int line, + const char* expr, const DOCTEST_REF_WRAP(L) val) { + bool failed = !val; + + if(at & assertType::is_false) //!OCLINT bitwise operator in conditional + failed = !failed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(toString(val)); + DOCTEST_ASSERT_IN_TESTS(toString(val)); + } + + struct DOCTEST_INTERFACE IExceptionTranslator + { + IExceptionTranslator(); + virtual ~IExceptionTranslator(); + virtual bool translate(String&) const = 0; + }; + + template + class ExceptionTranslator : public IExceptionTranslator //!OCLINT destructor of virtual class + { + public: + explicit ExceptionTranslator(String (*translateFunction)(T)) + : m_translateFunction(translateFunction) {} + + bool translate(String& res) const override { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { + throw; // lgtm [cpp/rethrow-no-exception] + // cppcheck-suppress catchExceptionByValue + } catch(T ex) { // NOLINT + res = m_translateFunction(ex); //!OCLINT parameter reassignment + return true; + } catch(...) {} //!OCLINT - empty catch statement +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + static_cast(res); // to silence -Wunused-parameter + return false; + } + + private: + String (*m_translateFunction)(T); + }; + + DOCTEST_INTERFACE void registerExceptionTranslatorImpl(const IExceptionTranslator* et); + + template + struct StringStreamBase + { + template + static void convert(std::ostream* s, const T& in) { + *s << toString(in); + } + + // always treat char* as a string in this context - no matter + // if DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING is defined + static void convert(std::ostream* s, const char* in) { *s << String(in); } + }; + + template <> + struct StringStreamBase + { + template + static void convert(std::ostream* s, const T& in) { + *s << in; + } + }; + + template + struct StringStream : public StringStreamBase::value> + {}; + + template + void toStream(std::ostream* s, const T& value) { + StringStream::convert(s, value); + } + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + DOCTEST_INTERFACE void toStream(std::ostream* s, char* in); + DOCTEST_INTERFACE void toStream(std::ostream* s, const char* in); +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + DOCTEST_INTERFACE void toStream(std::ostream* s, bool in); + DOCTEST_INTERFACE void toStream(std::ostream* s, float in); + DOCTEST_INTERFACE void toStream(std::ostream* s, double in); + DOCTEST_INTERFACE void toStream(std::ostream* s, double long in); + + DOCTEST_INTERFACE void toStream(std::ostream* s, char in); + DOCTEST_INTERFACE void toStream(std::ostream* s, char signed in); + DOCTEST_INTERFACE void toStream(std::ostream* s, char unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int short in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int short unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long unsigned in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long long in); + DOCTEST_INTERFACE void toStream(std::ostream* s, int long long unsigned in); + + // ContextScope base class used to allow implementing methods of ContextScope + // that don't depend on the template parameter in doctest.cpp. + class DOCTEST_INTERFACE ContextScopeBase : public IContextScope { + protected: + ContextScopeBase(); + + void destroy(); + }; + + template class ContextScope : public ContextScopeBase + { + const L lambda_; + + public: + explicit ContextScope(const L &lambda) : lambda_(lambda) {} + + ContextScope(ContextScope &&other) : lambda_(other.lambda_) {} + + void stringify(std::ostream* s) const override { lambda_(s); } + + ~ContextScope() override { destroy(); } + }; + + struct DOCTEST_INTERFACE MessageBuilder : public MessageData + { + std::ostream* m_stream; + + MessageBuilder(const char* file, int line, assertType::Enum severity); + MessageBuilder() = delete; + ~MessageBuilder(); + + // the preferred way of chaining parameters for stringification + template + MessageBuilder& operator,(const T& in) { + toStream(m_stream, in); + return *this; + } + + // kept here just for backwards-compatibility - the comma operator should be preferred now + template + MessageBuilder& operator<<(const T& in) { return this->operator,(in); } + + // the `,` operator has the lowest operator precedence - if `<<` is used by the user then + // the `,` operator will be called last which is not what we want and thus the `*` operator + // is used first (has higher operator precedence compared to `<<`) so that we guarantee that + // an operator of the MessageBuilder class is called first before the rest of the parameters + template + MessageBuilder& operator*(const T& in) { return this->operator,(in); } + + bool log(); + void react(); + }; + + template + ContextScope MakeContextScope(const L &lambda) { + return ContextScope(lambda); + } +} // namespace detail + +#define DOCTEST_DEFINE_DECORATOR(name, type, def) \ + struct name \ + { \ + type data; \ + name(type in = def) \ + : data(in) {} \ + void fill(detail::TestCase& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + void fill(detail::TestSuite& state) const { state.DOCTEST_CAT(m_, name) = data; } \ + } + +DOCTEST_DEFINE_DECORATOR(test_suite, const char*, ""); +DOCTEST_DEFINE_DECORATOR(description, const char*, ""); +DOCTEST_DEFINE_DECORATOR(skip, bool, true); +DOCTEST_DEFINE_DECORATOR(no_breaks, bool, true); +DOCTEST_DEFINE_DECORATOR(no_output, bool, true); +DOCTEST_DEFINE_DECORATOR(timeout, double, 0); +DOCTEST_DEFINE_DECORATOR(may_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(should_fail, bool, true); +DOCTEST_DEFINE_DECORATOR(expected_failures, int, 0); + +template +int registerExceptionTranslator(String (*translateFunction)(T)) { + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") + static detail::ExceptionTranslator exceptionTranslator(translateFunction); + DOCTEST_CLANG_SUPPRESS_WARNING_POP + detail::registerExceptionTranslatorImpl(&exceptionTranslator); + return 0; +} + +} // namespace doctest + +// in a separate namespace outside of doctest because the DOCTEST_TEST_SUITE macro +// introduces an anonymous namespace in which getCurrentTestSuite gets overridden +namespace doctest_detail_test_suite_ns { +DOCTEST_INTERFACE doctest::detail::TestSuite& getCurrentTestSuite(); +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +#else // DOCTEST_CONFIG_DISABLE +template +int registerExceptionTranslator(String (*)(T)) { + return 0; +} +#endif // DOCTEST_CONFIG_DISABLE + +namespace detail { + typedef void (*assert_handler)(const AssertData&); + struct ContextState; +} // namespace detail + +class DOCTEST_INTERFACE Context +{ + detail::ContextState* p; + + void parseArgs(int argc, const char* const* argv, bool withDefaults = false); + +public: + explicit Context(int argc = 0, const char* const* argv = nullptr); + + ~Context(); + + void applyCommandLine(int argc, const char* const* argv); + + void addFilter(const char* filter, const char* value); + void clearFilters(); + void setOption(const char* option, int value); + void setOption(const char* option, const char* value); + + bool shouldExit(); + + void setAsDefaultForAssertsOutOfTestCases(); + + void setAssertHandler(detail::assert_handler ah); + + int run(); +}; + +namespace TestCaseFailureReason { + enum Enum + { + None = 0, + AssertFailure = 1, // an assertion has failed in the test case + Exception = 2, // test case threw an exception + Crash = 4, // a crash... + TooManyFailedAsserts = 8, // the abort-after option + Timeout = 16, // see the timeout decorator + ShouldHaveFailedButDidnt = 32, // see the should_fail decorator + ShouldHaveFailedAndDid = 64, // see the should_fail decorator + DidntFailExactlyNumTimes = 128, // see the expected_failures decorator + FailedExactlyNumTimes = 256, // see the expected_failures decorator + CouldHaveFailedAndDid = 512 // see the may_fail decorator + }; +} // namespace TestCaseFailureReason + +struct DOCTEST_INTERFACE CurrentTestCaseStats +{ + int numAssertsCurrentTest; + int numAssertsFailedCurrentTest; + double seconds; + int failure_flags; // use TestCaseFailureReason::Enum +}; + +struct DOCTEST_INTERFACE TestCaseException +{ + String error_string; + bool is_crash; +}; + +struct DOCTEST_INTERFACE TestRunStats +{ + unsigned numTestCases; + unsigned numTestCasesPassingFilters; + unsigned numTestSuitesPassingFilters; + unsigned numTestCasesFailed; + int numAsserts; + int numAssertsFailed; +}; + +struct QueryData +{ + const TestRunStats* run_stats = nullptr; + const TestCaseData** data = nullptr; + unsigned num_data = 0; +}; + +struct DOCTEST_INTERFACE IReporter +{ + // The constructor has to accept "const ContextOptions&" as a single argument + // which has most of the options for the run + a pointer to the stdout stream + // Reporter(const ContextOptions& in) + + // called when a query should be reported (listing test cases, printing the version, etc.) + virtual void report_query(const QueryData&) = 0; + + // called when the whole test run starts + virtual void test_run_start() = 0; + // called when the whole test run ends (caching a pointer to the input doesn't make sense here) + virtual void test_run_end(const TestRunStats&) = 0; + + // called when a test case is started (safe to cache a pointer to the input) + virtual void test_case_start(const TestCaseData&) = 0; + // called when a test case is reentered because of unfinished subcases (safe to cache a pointer to the input) + virtual void test_case_reenter(const TestCaseData&) = 0; + // called when a test case has ended + virtual void test_case_end(const CurrentTestCaseStats&) = 0; + + // called when an exception is thrown from the test case (or it crashes) + virtual void test_case_exception(const TestCaseException&) = 0; + + // called whenever a subcase is entered (don't cache pointers to the input) + virtual void subcase_start(const SubcaseSignature&) = 0; + // called whenever a subcase is exited (don't cache pointers to the input) + virtual void subcase_end() = 0; + + // called for each assert (don't cache pointers to the input) + virtual void log_assert(const AssertData&) = 0; + // called for each message (don't cache pointers to the input) + virtual void log_message(const MessageData&) = 0; + + // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator + // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) + virtual void test_case_skipped(const TestCaseData&) = 0; + + // doctest will not be managing the lifetimes of reporters given to it but this would still be nice to have + virtual ~IReporter(); + + // can obtain all currently active contexts and stringify them if one wishes to do so + static int get_num_active_contexts(); + static const IContextScope* const* get_active_contexts(); + + // can iterate through contexts which have been stringified automatically in their destructors when an exception has been thrown + static int get_num_stringified_contexts(); + static const String* get_stringified_contexts(); +}; + +namespace detail { + typedef IReporter* (*reporterCreatorFunc)(const ContextOptions&); + + DOCTEST_INTERFACE void registerReporterImpl(const char* name, int prio, reporterCreatorFunc c, bool isReporter); + + template + IReporter* reporterCreator(const ContextOptions& o) { + return new Reporter(o); + } +} // namespace detail + +template +int registerReporter(const char* name, int priority, bool isReporter) { + detail::registerReporterImpl(name, priority, detail::reporterCreator, isReporter); + return 0; +} +} // namespace doctest + +// if registering is not disabled +#if !defined(DOCTEST_CONFIG_DISABLE) + +// common code in asserts - for convenience +#define DOCTEST_ASSERT_LOG_AND_REACT(b) \ + if(b.log()) \ + DOCTEST_BREAK_INTO_DEBUGGER(); \ + b.react() + +#ifdef DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) x; +#else // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS +#define DOCTEST_WRAP_IN_TRY(x) \ + try { \ + x; \ + } catch(...) { _DOCTEST_RB.translateException(); } +#endif // DOCTEST_CONFIG_NO_TRY_CATCH_IN_ASSERTS + +#ifdef DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) \ + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wuseless-cast") \ + static_cast(__VA_ARGS__); \ + DOCTEST_GCC_SUPPRESS_WARNING_POP +#else // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS +#define DOCTEST_CAST_TO_VOID(...) __VA_ARGS__; +#endif // DOCTEST_CONFIG_VOID_CAST_EXPRESSIONS + +// registers the test by initializing a dummy var with a function +#define DOCTEST_REGISTER_FUNCTION(global_prefix, f, decorators) \ + global_prefix DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ + doctest::detail::regTest( \ + doctest::detail::TestCase( \ + f, __FILE__, __LINE__, \ + doctest_detail_test_suite_ns::getCurrentTestSuite()) * \ + decorators); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, decorators) \ + namespace { \ + struct der : public base \ + { \ + void f(); \ + }; \ + static void func() { \ + der v; \ + v.f(); \ + } \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, func, decorators) \ + } \ + inline DOCTEST_NOINLINE void der::f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, decorators) \ + static void f(); \ + DOCTEST_REGISTER_FUNCTION(DOCTEST_EMPTY, f, decorators) \ + static void f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(f, proxy, decorators) \ + static doctest::detail::funcType proxy() { return f; } \ + DOCTEST_REGISTER_FUNCTION(inline const, proxy(), decorators) \ + static void f() + +// for registering tests +#define DOCTEST_TEST_CASE(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), decorators) + +// for registering tests in classes - requires C++17 for inline variables! +#if __cplusplus >= 201703L || (DOCTEST_MSVC >= DOCTEST_COMPILER(19, 12, 0) && _MSVC_LANG >= 201703L) +#define DOCTEST_TEST_CASE_CLASS(decorators) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION_IN_CLASS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_PROXY_), \ + decorators) +#else // DOCTEST_TEST_CASE_CLASS +#define DOCTEST_TEST_CASE_CLASS(...) \ + TEST_CASES_CAN_BE_REGISTERED_IN_CLASSES_ONLY_IN_CPP17_MODE_OR_WITH_VS_2017_OR_NEWER +#endif // DOCTEST_TEST_CASE_CLASS + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(c, decorators) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(_DOCTEST_ANON_CLASS_), c, \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), decorators) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING_IMPL(...) \ + template <> \ + inline const char* type_to_string<__VA_ARGS__>() { \ + return "<" #__VA_ARGS__ ">"; \ + } +#define DOCTEST_TYPE_TO_STRING(...) \ + namespace doctest { namespace detail { \ + DOCTEST_TYPE_TO_STRING_IMPL(__VA_ARGS__) \ + } \ + } \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, iter, func) \ + template \ + static void func(); \ + namespace { \ + template \ + struct iter; \ + template \ + struct iter> \ + { \ + iter(const char* file, unsigned line, int index) { \ + doctest::detail::regTest(doctest::detail::TestCase(func, file, line, \ + doctest_detail_test_suite_ns::getCurrentTestSuite(), \ + doctest::detail::type_to_string(), \ + int(line) * 1000 + index) \ + * dec); \ + iter>(file, line, index + 1); \ + } \ + }; \ + template <> \ + struct iter> \ + { \ + iter(const char*, unsigned, int) {} \ + }; \ + } \ + template \ + static void func() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(dec, T, id) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(id, ITERATOR), \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)) + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, anon, ...) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_CAT(anon, DUMMY)) = \ + doctest::detail::instantiationHelper(DOCTEST_CAT(id, ITERATOR)<__VA_ARGS__>(__FILE__, __LINE__, 0));\ + DOCTEST_GLOBAL_NO_WARNINGS_END() + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), std::tuple<__VA_ARGS__>) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(id, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), __VA_ARGS__) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, anon, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_DEFINE_IMPL(dec, T, DOCTEST_CAT(anon, ITERATOR), anon); \ + DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE_IMPL(anon, anon, std::tuple<__VA_ARGS__>) \ + template \ + static void anon() + +#define DOCTEST_TEST_CASE_TEMPLATE(dec, T, ...) \ + DOCTEST_TEST_CASE_TEMPLATE_IMPL(dec, T, DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_), __VA_ARGS__) + +// for subcases +#define DOCTEST_SUBCASE(name) \ + if(const doctest::detail::Subcase & DOCTEST_ANONYMOUS(_DOCTEST_ANON_SUBCASE_) DOCTEST_UNUSED = \ + doctest::detail::Subcase(name, __FILE__, __LINE__)) + +// for grouping tests in test suites by using code blocks +#define DOCTEST_TEST_SUITE_IMPL(decorators, ns_name) \ + namespace ns_name { namespace doctest_detail_test_suite_ns { \ + static DOCTEST_NOINLINE doctest::detail::TestSuite& getCurrentTestSuite() { \ + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4640) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wexit-time-destructors") \ + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmissing-field-initializers") \ + static doctest::detail::TestSuite data{}; \ + static bool inited = false; \ + DOCTEST_MSVC_SUPPRESS_WARNING_POP \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP \ + DOCTEST_GCC_SUPPRESS_WARNING_POP \ + if(!inited) { \ + data* decorators; \ + inited = true; \ + } \ + return data; \ + } \ + } \ + } \ + namespace ns_name + +#define DOCTEST_TEST_SUITE(decorators) \ + DOCTEST_TEST_SUITE_IMPL(decorators, DOCTEST_ANONYMOUS(_DOCTEST_ANON_SUITE_)) + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(decorators) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * decorators); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_VAR_)) = \ + doctest::detail::setTestSuite(doctest::detail::TestSuite() * ""); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for registering exception translators +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(translatorName, signature) \ + inline doctest::String translatorName(signature); \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_)) = \ + doctest::registerExceptionTranslator(translatorName); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() \ + doctest::String translatorName(signature) + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + DOCTEST_REGISTER_EXCEPTION_TRANSLATOR_IMPL(DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_), \ + signature) + +// for registering reporters +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_REPORTER_)) = \ + doctest::registerReporter(name, priority, true); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for registering listeners +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) \ + DOCTEST_GLOBAL_NO_WARNINGS(DOCTEST_ANONYMOUS(_DOCTEST_ANON_REPORTER_)) = \ + doctest::registerReporter(name, priority, false); \ + DOCTEST_GLOBAL_NO_WARNINGS_END() typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for logging +#define DOCTEST_INFO(...) \ + DOCTEST_INFO_IMPL(DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_), DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_), \ + __VA_ARGS__) + +#define DOCTEST_INFO_IMPL(mb_name, s_name, ...) \ + auto DOCTEST_ANONYMOUS(_DOCTEST_CAPTURE_) = doctest::detail::MakeContextScope( \ + [&](std::ostream* s_name) { \ + doctest::detail::MessageBuilder mb_name(__FILE__, __LINE__, doctest::assertType::is_warn); \ + mb_name.m_stream = s_name; \ + mb_name * __VA_ARGS__; \ + }) + +#define DOCTEST_CAPTURE(x) DOCTEST_INFO(#x " := ", x) + +#define DOCTEST_ADD_AT_IMPL(type, file, line, mb, ...) \ + do { \ + doctest::detail::MessageBuilder mb(file, line, doctest::assertType::type); \ + mb * __VA_ARGS__; \ + DOCTEST_ASSERT_LOG_AND_REACT(mb); \ + } while(false) + +// clang-format off +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_warn, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_check, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), __VA_ARGS__) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_AT_IMPL(is_require, file, line, DOCTEST_ANONYMOUS(_DOCTEST_MESSAGE_), __VA_ARGS__) +// clang-format on + +#define DOCTEST_MESSAGE(...) DOCTEST_ADD_MESSAGE_AT(__FILE__, __LINE__, __VA_ARGS__) +#define DOCTEST_FAIL_CHECK(...) DOCTEST_ADD_FAIL_CHECK_AT(__FILE__, __LINE__, __VA_ARGS__) +#define DOCTEST_FAIL(...) DOCTEST_ADD_FAIL_AT(__FILE__, __LINE__, __VA_ARGS__) + +#define DOCTEST_TO_LVALUE(...) __VA_ARGS__ // Not removed to keep backwards compatibility. + +#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_ASSERT_IMPLEMENT_2(assert_type, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(_DOCTEST_RB.setResult( \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB) \ + DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ + do { \ + DOCTEST_ASSERT_IMPLEMENT_2(assert_type, __VA_ARGS__); \ + } while(false) + +#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +// necessary for _MESSAGE +#define DOCTEST_ASSERT_IMPLEMENT_2 DOCTEST_ASSERT_IMPLEMENT_1 + +#define DOCTEST_ASSERT_IMPLEMENT_1(assert_type, ...) \ + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Woverloaded-shift-op-parentheses") \ + doctest::detail::decomp_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, \ + doctest::detail::ExpressionDecomposer(doctest::assertType::assert_type) \ + << __VA_ARGS__) DOCTEST_CLANG_SUPPRESS_WARNING_POP + +#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_WARN(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN, __VA_ARGS__) +#define DOCTEST_CHECK(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK, __VA_ARGS__) +#define DOCTEST_REQUIRE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE, __VA_ARGS__) +#define DOCTEST_WARN_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_WARN_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_CHECK_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_FALSE(...) DOCTEST_ASSERT_IMPLEMENT_1(DT_REQUIRE_FALSE, __VA_ARGS__) + +// clang-format off +#define DOCTEST_WARN_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN, cond); } while(false) +#define DOCTEST_CHECK_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK, cond); } while(false) +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE, cond); } while(false) +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_WARN_FALSE, cond); } while(false) +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_CHECK_FALSE, cond); } while(false) +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_ASSERT_IMPLEMENT_2(DT_REQUIRE_FALSE, cond); } while(false) +// clang-format on + +#define DOCTEST_ASSERT_THROWS_AS(expr, assert_type, message, ...) \ + do { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #expr, #__VA_ARGS__, message); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(const typename doctest::detail::remove_const< \ + typename doctest::detail::remove_reference<__VA_ARGS__>::type>::type&) { \ + _DOCTEST_RB.translateException(); \ + _DOCTEST_RB.m_threw_as = true; \ + } catch(...) { _DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } \ + } while(false) + +#define DOCTEST_ASSERT_THROWS_WITH(expr, expr_str, assert_type, ...) \ + do { \ + if(!doctest::getContextOptions()->no_throw) { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, expr_str, "", __VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(expr) \ + } catch(...) { _DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } \ + } while(false) + +#define DOCTEST_ASSERT_NOTHROW(assert_type, ...) \ + do { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + try { \ + DOCTEST_CAST_TO_VOID(__VA_ARGS__) \ + } catch(...) { _DOCTEST_RB.translateException(); } \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } while(false) + +// clang-format off +#define DOCTEST_WARN_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_WARN_THROWS, "") +#define DOCTEST_CHECK_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_CHECK_THROWS, "") +#define DOCTEST_REQUIRE_THROWS(...) DOCTEST_ASSERT_THROWS_WITH((__VA_ARGS__), #__VA_ARGS__, DT_REQUIRE_THROWS, "") + +#define DOCTEST_WARN_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_AS, "", __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_AS, "", __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_WARN_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_CHECK_THROWS_WITH, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) DOCTEST_ASSERT_THROWS_WITH(expr, #expr, DT_REQUIRE_THROWS_WITH, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_WARN_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_CHECK_THROWS_WITH_AS, message, __VA_ARGS__) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, message, ...) DOCTEST_ASSERT_THROWS_AS(expr, DT_REQUIRE_THROWS_WITH_AS, message, __VA_ARGS__) + +#define DOCTEST_WARN_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_WARN_NOTHROW, __VA_ARGS__) +#define DOCTEST_CHECK_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_CHECK_NOTHROW, __VA_ARGS__) +#define DOCTEST_REQUIRE_NOTHROW(...) DOCTEST_ASSERT_NOTHROW(DT_REQUIRE_NOTHROW, __VA_ARGS__) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS(expr); } while(false) +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS(expr); } while(false) +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS(expr); } while(false) +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_AS(expr, ex); } while(false) +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_AS(expr, ex); } while(false) +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_AS(expr, ex); } while(false) +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH(expr, with); } while(false) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH(expr, with); } while(false) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH(expr, with); } while(false) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_THROWS_WITH_AS(expr, with, ex); } while(false) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ex); } while(false) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ex); } while(false) +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_WARN_NOTHROW(expr); } while(false) +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_CHECK_NOTHROW(expr); } while(false) +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) do { DOCTEST_INFO(__VA_ARGS__); DOCTEST_REQUIRE_NOTHROW(expr); } while(false) +// clang-format on + +#ifndef DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_BINARY_ASSERT(assert_type, comp, ...) \ + do { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY( \ + _DOCTEST_RB.binary_assert( \ + __VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } while(false) + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + do { \ + doctest::detail::ResultBuilder _DOCTEST_RB(doctest::assertType::assert_type, __FILE__, \ + __LINE__, #__VA_ARGS__); \ + DOCTEST_WRAP_IN_TRY(_DOCTEST_RB.unary_assert(__VA_ARGS__)) \ + DOCTEST_ASSERT_LOG_AND_REACT(_DOCTEST_RB); \ + } while(false) + +#else // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_BINARY_ASSERT(assert_type, comparison, ...) \ + doctest::detail::binary_assert( \ + doctest::assertType::assert_type, __FILE__, __LINE__, #__VA_ARGS__, __VA_ARGS__) + +#define DOCTEST_UNARY_ASSERT(assert_type, ...) \ + doctest::detail::unary_assert(doctest::assertType::assert_type, __FILE__, __LINE__, \ + #__VA_ARGS__, __VA_ARGS__) + +#endif // DOCTEST_CONFIG_SUPER_FAST_ASSERTS + +#define DOCTEST_WARN_EQ(...) DOCTEST_BINARY_ASSERT(DT_WARN_EQ, eq, __VA_ARGS__) +#define DOCTEST_CHECK_EQ(...) DOCTEST_BINARY_ASSERT(DT_CHECK_EQ, eq, __VA_ARGS__) +#define DOCTEST_REQUIRE_EQ(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_EQ, eq, __VA_ARGS__) +#define DOCTEST_WARN_NE(...) DOCTEST_BINARY_ASSERT(DT_WARN_NE, ne, __VA_ARGS__) +#define DOCTEST_CHECK_NE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_NE, ne, __VA_ARGS__) +#define DOCTEST_REQUIRE_NE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_NE, ne, __VA_ARGS__) +#define DOCTEST_WARN_GT(...) DOCTEST_BINARY_ASSERT(DT_WARN_GT, gt, __VA_ARGS__) +#define DOCTEST_CHECK_GT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GT, gt, __VA_ARGS__) +#define DOCTEST_REQUIRE_GT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GT, gt, __VA_ARGS__) +#define DOCTEST_WARN_LT(...) DOCTEST_BINARY_ASSERT(DT_WARN_LT, lt, __VA_ARGS__) +#define DOCTEST_CHECK_LT(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LT, lt, __VA_ARGS__) +#define DOCTEST_REQUIRE_LT(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LT, lt, __VA_ARGS__) +#define DOCTEST_WARN_GE(...) DOCTEST_BINARY_ASSERT(DT_WARN_GE, ge, __VA_ARGS__) +#define DOCTEST_CHECK_GE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_GE, ge, __VA_ARGS__) +#define DOCTEST_REQUIRE_GE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_GE, ge, __VA_ARGS__) +#define DOCTEST_WARN_LE(...) DOCTEST_BINARY_ASSERT(DT_WARN_LE, le, __VA_ARGS__) +#define DOCTEST_CHECK_LE(...) DOCTEST_BINARY_ASSERT(DT_CHECK_LE, le, __VA_ARGS__) +#define DOCTEST_REQUIRE_LE(...) DOCTEST_BINARY_ASSERT(DT_REQUIRE_LE, le, __VA_ARGS__) + +#define DOCTEST_WARN_UNARY(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY, __VA_ARGS__) +#define DOCTEST_WARN_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_WARN_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_CHECK_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_CHECK_UNARY_FALSE, __VA_ARGS__) +#define DOCTEST_REQUIRE_UNARY_FALSE(...) DOCTEST_UNARY_ASSERT(DT_REQUIRE_UNARY_FALSE, __VA_ARGS__) + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS + +#undef DOCTEST_WARN_THROWS +#undef DOCTEST_CHECK_THROWS +#undef DOCTEST_REQUIRE_THROWS +#undef DOCTEST_WARN_THROWS_AS +#undef DOCTEST_CHECK_THROWS_AS +#undef DOCTEST_REQUIRE_THROWS_AS +#undef DOCTEST_WARN_THROWS_WITH +#undef DOCTEST_CHECK_THROWS_WITH +#undef DOCTEST_REQUIRE_THROWS_WITH +#undef DOCTEST_WARN_THROWS_WITH_AS +#undef DOCTEST_CHECK_THROWS_WITH_AS +#undef DOCTEST_REQUIRE_THROWS_WITH_AS +#undef DOCTEST_WARN_NOTHROW +#undef DOCTEST_CHECK_NOTHROW +#undef DOCTEST_REQUIRE_NOTHROW + +#undef DOCTEST_WARN_THROWS_MESSAGE +#undef DOCTEST_CHECK_THROWS_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_MESSAGE +#undef DOCTEST_WARN_THROWS_AS_MESSAGE +#undef DOCTEST_CHECK_THROWS_AS_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_AS_MESSAGE +#undef DOCTEST_WARN_THROWS_WITH_MESSAGE +#undef DOCTEST_CHECK_THROWS_WITH_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_WITH_MESSAGE +#undef DOCTEST_WARN_THROWS_WITH_AS_MESSAGE +#undef DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE +#undef DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE +#undef DOCTEST_WARN_NOTHROW_MESSAGE +#undef DOCTEST_CHECK_NOTHROW_MESSAGE +#undef DOCTEST_REQUIRE_NOTHROW_MESSAGE + +#ifdef DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#define DOCTEST_WARN_THROWS(...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS(...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS(...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_AS(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) +#define DOCTEST_WARN_NOTHROW(...) (static_cast(0)) +#define DOCTEST_CHECK_NOTHROW(...) (static_cast(0)) +#define DOCTEST_REQUIRE_NOTHROW(...) (static_cast(0)) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) + +#else // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#undef DOCTEST_REQUIRE +#undef DOCTEST_REQUIRE_FALSE +#undef DOCTEST_REQUIRE_MESSAGE +#undef DOCTEST_REQUIRE_FALSE_MESSAGE +#undef DOCTEST_REQUIRE_EQ +#undef DOCTEST_REQUIRE_NE +#undef DOCTEST_REQUIRE_GT +#undef DOCTEST_REQUIRE_LT +#undef DOCTEST_REQUIRE_GE +#undef DOCTEST_REQUIRE_LE +#undef DOCTEST_REQUIRE_UNARY +#undef DOCTEST_REQUIRE_UNARY_FALSE + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS_BUT_WITH_ALL_ASSERTS + +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + +// ================================================================================================= +// == WHAT FOLLOWS IS VERSIONS OF THE MACROS THAT DO NOT DO ANY REGISTERING! == +// == THIS CAN BE ENABLED BY DEFINING DOCTEST_CONFIG_DISABLE GLOBALLY! == +// ================================================================================================= +#else // DOCTEST_CONFIG_DISABLE + +#define DOCTEST_IMPLEMENT_FIXTURE(der, base, func, name) \ + namespace { \ + template \ + struct der : public base \ + { void f(); }; \ + } \ + template \ + inline void der::f() + +#define DOCTEST_CREATE_AND_REGISTER_FUNCTION(f, name) \ + template \ + static inline void f() + +// for registering tests +#define DOCTEST_TEST_CASE(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) + +// for registering tests in classes +#define DOCTEST_TEST_CASE_CLASS(name) \ + DOCTEST_CREATE_AND_REGISTER_FUNCTION(DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) + +// for registering tests with a fixture +#define DOCTEST_TEST_CASE_FIXTURE(x, name) \ + DOCTEST_IMPLEMENT_FIXTURE(DOCTEST_ANONYMOUS(_DOCTEST_ANON_CLASS_), x, \ + DOCTEST_ANONYMOUS(_DOCTEST_ANON_FUNC_), name) + +// for converting types to strings without the header and demangling +#define DOCTEST_TYPE_TO_STRING(...) typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) +#define DOCTEST_TYPE_TO_STRING_IMPL(...) + +// for typed tests +#define DOCTEST_TEST_CASE_TEMPLATE(name, type, ...) \ + template \ + inline void DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, type, id) \ + template \ + inline void DOCTEST_ANONYMOUS(_DOCTEST_ANON_TMP_)() + +#define DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, ...) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, ...) \ + typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for subcases +#define DOCTEST_SUBCASE(name) + +// for a testsuite block +#define DOCTEST_TEST_SUITE(name) namespace + +// for starting a testsuite block +#define DOCTEST_TEST_SUITE_BEGIN(name) typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +// for ending a testsuite block +#define DOCTEST_TEST_SUITE_END typedef int DOCTEST_ANONYMOUS(_DOCTEST_ANON_FOR_SEMICOLON_) + +#define DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) \ + template \ + static inline doctest::String DOCTEST_ANONYMOUS(_DOCTEST_ANON_TRANSLATOR_)(signature) + +#define DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define DOCTEST_REGISTER_LISTENER(name, priority, reporter) + +#define DOCTEST_INFO(...) (static_cast(0)) +#define DOCTEST_CAPTURE(x) (static_cast(0)) +#define DOCTEST_ADD_MESSAGE_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_CHECK_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_ADD_FAIL_AT(file, line, ...) (static_cast(0)) +#define DOCTEST_MESSAGE(...) (static_cast(0)) +#define DOCTEST_FAIL_CHECK(...) (static_cast(0)) +#define DOCTEST_FAIL(...) (static_cast(0)) + +#define DOCTEST_WARN(...) (static_cast(0)) +#define DOCTEST_CHECK(...) (static_cast(0)) +#define DOCTEST_REQUIRE(...) (static_cast(0)) +#define DOCTEST_WARN_FALSE(...) (static_cast(0)) +#define DOCTEST_CHECK_FALSE(...) (static_cast(0)) +#define DOCTEST_REQUIRE_FALSE(...) (static_cast(0)) + +#define DOCTEST_WARN_MESSAGE(cond, ...) (static_cast(0)) +#define DOCTEST_CHECK_MESSAGE(cond, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_MESSAGE(cond, ...) (static_cast(0)) +#define DOCTEST_WARN_FALSE_MESSAGE(cond, ...) (static_cast(0)) +#define DOCTEST_CHECK_FALSE_MESSAGE(cond, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_FALSE_MESSAGE(cond, ...) (static_cast(0)) + +#define DOCTEST_WARN_THROWS(...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS(...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS(...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_AS(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_AS(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_AS(expr, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH(expr, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, ...) (static_cast(0)) +#define DOCTEST_WARN_NOTHROW(...) (static_cast(0)) +#define DOCTEST_CHECK_NOTHROW(...) (static_cast(0)) +#define DOCTEST_REQUIRE_NOTHROW(...) (static_cast(0)) + +#define DOCTEST_WARN_THROWS_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) (static_cast(0)) +#define DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) +#define DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) (static_cast(0)) +#define DOCTEST_WARN_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_CHECK_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) +#define DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, ...) (static_cast(0)) + +#define DOCTEST_WARN_EQ(...) (static_cast(0)) +#define DOCTEST_CHECK_EQ(...) (static_cast(0)) +#define DOCTEST_REQUIRE_EQ(...) (static_cast(0)) +#define DOCTEST_WARN_NE(...) (static_cast(0)) +#define DOCTEST_CHECK_NE(...) (static_cast(0)) +#define DOCTEST_REQUIRE_NE(...) (static_cast(0)) +#define DOCTEST_WARN_GT(...) (static_cast(0)) +#define DOCTEST_CHECK_GT(...) (static_cast(0)) +#define DOCTEST_REQUIRE_GT(...) (static_cast(0)) +#define DOCTEST_WARN_LT(...) (static_cast(0)) +#define DOCTEST_CHECK_LT(...) (static_cast(0)) +#define DOCTEST_REQUIRE_LT(...) (static_cast(0)) +#define DOCTEST_WARN_GE(...) (static_cast(0)) +#define DOCTEST_CHECK_GE(...) (static_cast(0)) +#define DOCTEST_REQUIRE_GE(...) (static_cast(0)) +#define DOCTEST_WARN_LE(...) (static_cast(0)) +#define DOCTEST_CHECK_LE(...) (static_cast(0)) +#define DOCTEST_REQUIRE_LE(...) (static_cast(0)) + +#define DOCTEST_WARN_UNARY(...) (static_cast(0)) +#define DOCTEST_CHECK_UNARY(...) (static_cast(0)) +#define DOCTEST_REQUIRE_UNARY(...) (static_cast(0)) +#define DOCTEST_WARN_UNARY_FALSE(...) (static_cast(0)) +#define DOCTEST_CHECK_UNARY_FALSE(...) (static_cast(0)) +#define DOCTEST_REQUIRE_UNARY_FALSE(...) (static_cast(0)) + +#endif // DOCTEST_CONFIG_DISABLE + +// clang-format off +// KEPT FOR BACKWARDS COMPATIBILITY - FORWARDING TO THE RIGHT MACROS +#define DOCTEST_FAST_WARN_EQ DOCTEST_WARN_EQ +#define DOCTEST_FAST_CHECK_EQ DOCTEST_CHECK_EQ +#define DOCTEST_FAST_REQUIRE_EQ DOCTEST_REQUIRE_EQ +#define DOCTEST_FAST_WARN_NE DOCTEST_WARN_NE +#define DOCTEST_FAST_CHECK_NE DOCTEST_CHECK_NE +#define DOCTEST_FAST_REQUIRE_NE DOCTEST_REQUIRE_NE +#define DOCTEST_FAST_WARN_GT DOCTEST_WARN_GT +#define DOCTEST_FAST_CHECK_GT DOCTEST_CHECK_GT +#define DOCTEST_FAST_REQUIRE_GT DOCTEST_REQUIRE_GT +#define DOCTEST_FAST_WARN_LT DOCTEST_WARN_LT +#define DOCTEST_FAST_CHECK_LT DOCTEST_CHECK_LT +#define DOCTEST_FAST_REQUIRE_LT DOCTEST_REQUIRE_LT +#define DOCTEST_FAST_WARN_GE DOCTEST_WARN_GE +#define DOCTEST_FAST_CHECK_GE DOCTEST_CHECK_GE +#define DOCTEST_FAST_REQUIRE_GE DOCTEST_REQUIRE_GE +#define DOCTEST_FAST_WARN_LE DOCTEST_WARN_LE +#define DOCTEST_FAST_CHECK_LE DOCTEST_CHECK_LE +#define DOCTEST_FAST_REQUIRE_LE DOCTEST_REQUIRE_LE + +#define DOCTEST_FAST_WARN_UNARY DOCTEST_WARN_UNARY +#define DOCTEST_FAST_CHECK_UNARY DOCTEST_CHECK_UNARY +#define DOCTEST_FAST_REQUIRE_UNARY DOCTEST_REQUIRE_UNARY +#define DOCTEST_FAST_WARN_UNARY_FALSE DOCTEST_WARN_UNARY_FALSE +#define DOCTEST_FAST_CHECK_UNARY_FALSE DOCTEST_CHECK_UNARY_FALSE +#define DOCTEST_FAST_REQUIRE_UNARY_FALSE DOCTEST_REQUIRE_UNARY_FALSE + +#define DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id,__VA_ARGS__) +// clang-format on + +// BDD style macros +// clang-format off +#define DOCTEST_SCENARIO(name) DOCTEST_TEST_CASE(" Scenario: " name) +#define DOCTEST_SCENARIO_CLASS(name) DOCTEST_TEST_CASE_CLASS(" Scenario: " name) +#define DOCTEST_SCENARIO_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(" Scenario: " name, T, __VA_ARGS__) +#define DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(" Scenario: " name, T, id) + +#define DOCTEST_GIVEN(name) DOCTEST_SUBCASE(" Given: " name) +#define DOCTEST_WHEN(name) DOCTEST_SUBCASE(" When: " name) +#define DOCTEST_AND_WHEN(name) DOCTEST_SUBCASE("And when: " name) +#define DOCTEST_THEN(name) DOCTEST_SUBCASE(" Then: " name) +#define DOCTEST_AND_THEN(name) DOCTEST_SUBCASE(" And: " name) +// clang-format on + +// == SHORT VERSIONS OF THE MACROS +#if !defined(DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES) + +#define TEST_CASE(name) DOCTEST_TEST_CASE(name) +#define TEST_CASE_CLASS(name) DOCTEST_TEST_CASE_CLASS(name) +#define TEST_CASE_FIXTURE(x, name) DOCTEST_TEST_CASE_FIXTURE(x, name) +#define TYPE_TO_STRING(...) DOCTEST_TYPE_TO_STRING(__VA_ARGS__) +#define TEST_CASE_TEMPLATE(name, T, ...) DOCTEST_TEST_CASE_TEMPLATE(name, T, __VA_ARGS__) +#define TEST_CASE_TEMPLATE_DEFINE(name, T, id) DOCTEST_TEST_CASE_TEMPLATE_DEFINE(name, T, id) +#define TEST_CASE_TEMPLATE_INVOKE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INVOKE(id, __VA_ARGS__) +#define TEST_CASE_TEMPLATE_APPLY(id, ...) DOCTEST_TEST_CASE_TEMPLATE_APPLY(id, __VA_ARGS__) +#define SUBCASE(name) DOCTEST_SUBCASE(name) +#define TEST_SUITE(decorators) DOCTEST_TEST_SUITE(decorators) +#define TEST_SUITE_BEGIN(name) DOCTEST_TEST_SUITE_BEGIN(name) +#define TEST_SUITE_END DOCTEST_TEST_SUITE_END +#define REGISTER_EXCEPTION_TRANSLATOR(signature) DOCTEST_REGISTER_EXCEPTION_TRANSLATOR(signature) +#define REGISTER_REPORTER(name, priority, reporter) DOCTEST_REGISTER_REPORTER(name, priority, reporter) +#define REGISTER_LISTENER(name, priority, reporter) DOCTEST_REGISTER_LISTENER(name, priority, reporter) +#define INFO(...) DOCTEST_INFO(__VA_ARGS__) +#define CAPTURE(x) DOCTEST_CAPTURE(x) +#define ADD_MESSAGE_AT(file, line, ...) DOCTEST_ADD_MESSAGE_AT(file, line, __VA_ARGS__) +#define ADD_FAIL_CHECK_AT(file, line, ...) DOCTEST_ADD_FAIL_CHECK_AT(file, line, __VA_ARGS__) +#define ADD_FAIL_AT(file, line, ...) DOCTEST_ADD_FAIL_AT(file, line, __VA_ARGS__) +#define MESSAGE(...) DOCTEST_MESSAGE(__VA_ARGS__) +#define FAIL_CHECK(...) DOCTEST_FAIL_CHECK(__VA_ARGS__) +#define FAIL(...) DOCTEST_FAIL(__VA_ARGS__) +#define TO_LVALUE(...) DOCTEST_TO_LVALUE(__VA_ARGS__) + +#define WARN(...) DOCTEST_WARN(__VA_ARGS__) +#define WARN_FALSE(...) DOCTEST_WARN_FALSE(__VA_ARGS__) +#define WARN_THROWS(...) DOCTEST_WARN_THROWS(__VA_ARGS__) +#define WARN_THROWS_AS(expr, ...) DOCTEST_WARN_THROWS_AS(expr, __VA_ARGS__) +#define WARN_THROWS_WITH(expr, ...) DOCTEST_WARN_THROWS_WITH(expr, __VA_ARGS__) +#define WARN_THROWS_WITH_AS(expr, with, ...) DOCTEST_WARN_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define WARN_NOTHROW(...) DOCTEST_WARN_NOTHROW(__VA_ARGS__) +#define CHECK(...) DOCTEST_CHECK(__VA_ARGS__) +#define CHECK_FALSE(...) DOCTEST_CHECK_FALSE(__VA_ARGS__) +#define CHECK_THROWS(...) DOCTEST_CHECK_THROWS(__VA_ARGS__) +#define CHECK_THROWS_AS(expr, ...) DOCTEST_CHECK_THROWS_AS(expr, __VA_ARGS__) +#define CHECK_THROWS_WITH(expr, ...) DOCTEST_CHECK_THROWS_WITH(expr, __VA_ARGS__) +#define CHECK_THROWS_WITH_AS(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define CHECK_NOTHROW(...) DOCTEST_CHECK_NOTHROW(__VA_ARGS__) +#define REQUIRE(...) DOCTEST_REQUIRE(__VA_ARGS__) +#define REQUIRE_FALSE(...) DOCTEST_REQUIRE_FALSE(__VA_ARGS__) +#define REQUIRE_THROWS(...) DOCTEST_REQUIRE_THROWS(__VA_ARGS__) +#define REQUIRE_THROWS_AS(expr, ...) DOCTEST_REQUIRE_THROWS_AS(expr, __VA_ARGS__) +#define REQUIRE_THROWS_WITH(expr, ...) DOCTEST_REQUIRE_THROWS_WITH(expr, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_AS(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_AS(expr, with, __VA_ARGS__) +#define REQUIRE_NOTHROW(...) DOCTEST_REQUIRE_NOTHROW(__VA_ARGS__) + +#define WARN_MESSAGE(cond, ...) DOCTEST_WARN_MESSAGE(cond, __VA_ARGS__) +#define WARN_FALSE_MESSAGE(cond, ...) DOCTEST_WARN_FALSE_MESSAGE(cond, __VA_ARGS__) +#define WARN_THROWS_MESSAGE(expr, ...) DOCTEST_WARN_THROWS_MESSAGE(expr, __VA_ARGS__) +#define WARN_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_WARN_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define WARN_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_WARN_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_WARN_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define WARN_NOTHROW_MESSAGE(expr, ...) DOCTEST_WARN_NOTHROW_MESSAGE(expr, __VA_ARGS__) +#define CHECK_MESSAGE(cond, ...) DOCTEST_CHECK_MESSAGE(cond, __VA_ARGS__) +#define CHECK_FALSE_MESSAGE(cond, ...) DOCTEST_CHECK_FALSE_MESSAGE(cond, __VA_ARGS__) +#define CHECK_THROWS_MESSAGE(expr, ...) DOCTEST_CHECK_THROWS_MESSAGE(expr, __VA_ARGS__) +#define CHECK_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_CHECK_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define CHECK_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_CHECK_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_CHECK_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define CHECK_NOTHROW_MESSAGE(expr, ...) DOCTEST_CHECK_NOTHROW_MESSAGE(expr, __VA_ARGS__) +#define REQUIRE_MESSAGE(cond, ...) DOCTEST_REQUIRE_MESSAGE(cond, __VA_ARGS__) +#define REQUIRE_FALSE_MESSAGE(cond, ...) DOCTEST_REQUIRE_FALSE_MESSAGE(cond, __VA_ARGS__) +#define REQUIRE_THROWS_MESSAGE(expr, ...) DOCTEST_REQUIRE_THROWS_MESSAGE(expr, __VA_ARGS__) +#define REQUIRE_THROWS_AS_MESSAGE(expr, ex, ...) DOCTEST_REQUIRE_THROWS_AS_MESSAGE(expr, ex, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_MESSAGE(expr, with, ...) DOCTEST_REQUIRE_THROWS_WITH_MESSAGE(expr, with, __VA_ARGS__) +#define REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, ...) DOCTEST_REQUIRE_THROWS_WITH_AS_MESSAGE(expr, with, ex, __VA_ARGS__) +#define REQUIRE_NOTHROW_MESSAGE(expr, ...) DOCTEST_REQUIRE_NOTHROW_MESSAGE(expr, __VA_ARGS__) + +#define SCENARIO(name) DOCTEST_SCENARIO(name) +#define SCENARIO_CLASS(name) DOCTEST_SCENARIO_CLASS(name) +#define SCENARIO_TEMPLATE(name, T, ...) DOCTEST_SCENARIO_TEMPLATE(name, T, __VA_ARGS__) +#define SCENARIO_TEMPLATE_DEFINE(name, T, id) DOCTEST_SCENARIO_TEMPLATE_DEFINE(name, T, id) +#define GIVEN(name) DOCTEST_GIVEN(name) +#define WHEN(name) DOCTEST_WHEN(name) +#define AND_WHEN(name) DOCTEST_AND_WHEN(name) +#define THEN(name) DOCTEST_THEN(name) +#define AND_THEN(name) DOCTEST_AND_THEN(name) + +#define WARN_EQ(...) DOCTEST_WARN_EQ(__VA_ARGS__) +#define CHECK_EQ(...) DOCTEST_CHECK_EQ(__VA_ARGS__) +#define REQUIRE_EQ(...) DOCTEST_REQUIRE_EQ(__VA_ARGS__) +#define WARN_NE(...) DOCTEST_WARN_NE(__VA_ARGS__) +#define CHECK_NE(...) DOCTEST_CHECK_NE(__VA_ARGS__) +#define REQUIRE_NE(...) DOCTEST_REQUIRE_NE(__VA_ARGS__) +#define WARN_GT(...) DOCTEST_WARN_GT(__VA_ARGS__) +#define CHECK_GT(...) DOCTEST_CHECK_GT(__VA_ARGS__) +#define REQUIRE_GT(...) DOCTEST_REQUIRE_GT(__VA_ARGS__) +#define WARN_LT(...) DOCTEST_WARN_LT(__VA_ARGS__) +#define CHECK_LT(...) DOCTEST_CHECK_LT(__VA_ARGS__) +#define REQUIRE_LT(...) DOCTEST_REQUIRE_LT(__VA_ARGS__) +#define WARN_GE(...) DOCTEST_WARN_GE(__VA_ARGS__) +#define CHECK_GE(...) DOCTEST_CHECK_GE(__VA_ARGS__) +#define REQUIRE_GE(...) DOCTEST_REQUIRE_GE(__VA_ARGS__) +#define WARN_LE(...) DOCTEST_WARN_LE(__VA_ARGS__) +#define CHECK_LE(...) DOCTEST_CHECK_LE(__VA_ARGS__) +#define REQUIRE_LE(...) DOCTEST_REQUIRE_LE(__VA_ARGS__) +#define WARN_UNARY(...) DOCTEST_WARN_UNARY(__VA_ARGS__) +#define CHECK_UNARY(...) DOCTEST_CHECK_UNARY(__VA_ARGS__) +#define REQUIRE_UNARY(...) DOCTEST_REQUIRE_UNARY(__VA_ARGS__) +#define WARN_UNARY_FALSE(...) DOCTEST_WARN_UNARY_FALSE(__VA_ARGS__) +#define CHECK_UNARY_FALSE(...) DOCTEST_CHECK_UNARY_FALSE(__VA_ARGS__) +#define REQUIRE_UNARY_FALSE(...) DOCTEST_REQUIRE_UNARY_FALSE(__VA_ARGS__) + +// KEPT FOR BACKWARDS COMPATIBILITY +#define FAST_WARN_EQ(...) DOCTEST_FAST_WARN_EQ(__VA_ARGS__) +#define FAST_CHECK_EQ(...) DOCTEST_FAST_CHECK_EQ(__VA_ARGS__) +#define FAST_REQUIRE_EQ(...) DOCTEST_FAST_REQUIRE_EQ(__VA_ARGS__) +#define FAST_WARN_NE(...) DOCTEST_FAST_WARN_NE(__VA_ARGS__) +#define FAST_CHECK_NE(...) DOCTEST_FAST_CHECK_NE(__VA_ARGS__) +#define FAST_REQUIRE_NE(...) DOCTEST_FAST_REQUIRE_NE(__VA_ARGS__) +#define FAST_WARN_GT(...) DOCTEST_FAST_WARN_GT(__VA_ARGS__) +#define FAST_CHECK_GT(...) DOCTEST_FAST_CHECK_GT(__VA_ARGS__) +#define FAST_REQUIRE_GT(...) DOCTEST_FAST_REQUIRE_GT(__VA_ARGS__) +#define FAST_WARN_LT(...) DOCTEST_FAST_WARN_LT(__VA_ARGS__) +#define FAST_CHECK_LT(...) DOCTEST_FAST_CHECK_LT(__VA_ARGS__) +#define FAST_REQUIRE_LT(...) DOCTEST_FAST_REQUIRE_LT(__VA_ARGS__) +#define FAST_WARN_GE(...) DOCTEST_FAST_WARN_GE(__VA_ARGS__) +#define FAST_CHECK_GE(...) DOCTEST_FAST_CHECK_GE(__VA_ARGS__) +#define FAST_REQUIRE_GE(...) DOCTEST_FAST_REQUIRE_GE(__VA_ARGS__) +#define FAST_WARN_LE(...) DOCTEST_FAST_WARN_LE(__VA_ARGS__) +#define FAST_CHECK_LE(...) DOCTEST_FAST_CHECK_LE(__VA_ARGS__) +#define FAST_REQUIRE_LE(...) DOCTEST_FAST_REQUIRE_LE(__VA_ARGS__) + +#define FAST_WARN_UNARY(...) DOCTEST_FAST_WARN_UNARY(__VA_ARGS__) +#define FAST_CHECK_UNARY(...) DOCTEST_FAST_CHECK_UNARY(__VA_ARGS__) +#define FAST_REQUIRE_UNARY(...) DOCTEST_FAST_REQUIRE_UNARY(__VA_ARGS__) +#define FAST_WARN_UNARY_FALSE(...) DOCTEST_FAST_WARN_UNARY_FALSE(__VA_ARGS__) +#define FAST_CHECK_UNARY_FALSE(...) DOCTEST_FAST_CHECK_UNARY_FALSE(__VA_ARGS__) +#define FAST_REQUIRE_UNARY_FALSE(...) DOCTEST_FAST_REQUIRE_UNARY_FALSE(__VA_ARGS__) + +#define TEST_CASE_TEMPLATE_INSTANTIATE(id, ...) DOCTEST_TEST_CASE_TEMPLATE_INSTANTIATE(id, __VA_ARGS__) + +#endif // DOCTEST_CONFIG_NO_SHORT_MACRO_NAMES + +#if !defined(DOCTEST_CONFIG_DISABLE) + +// this is here to clear the 'current test suite' for the current translation unit - at the top +DOCTEST_TEST_SUITE_END(); + +// add stringification for primitive/fundamental types +namespace doctest { namespace detail { + DOCTEST_TYPE_TO_STRING_IMPL(bool) + DOCTEST_TYPE_TO_STRING_IMPL(float) + DOCTEST_TYPE_TO_STRING_IMPL(double) + DOCTEST_TYPE_TO_STRING_IMPL(long double) + DOCTEST_TYPE_TO_STRING_IMPL(char) + DOCTEST_TYPE_TO_STRING_IMPL(signed char) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned char) +#if !DOCTEST_MSVC || defined(_NATIVE_WCHAR_T_DEFINED) + DOCTEST_TYPE_TO_STRING_IMPL(wchar_t) +#endif // not MSVC or wchar_t support enabled + DOCTEST_TYPE_TO_STRING_IMPL(short int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned short int) + DOCTEST_TYPE_TO_STRING_IMPL(int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned int) + DOCTEST_TYPE_TO_STRING_IMPL(long int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned long int) + DOCTEST_TYPE_TO_STRING_IMPL(long long int) + DOCTEST_TYPE_TO_STRING_IMPL(unsigned long long int) +}} // namespace doctest::detail + +#endif // DOCTEST_CONFIG_DISABLE + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_LIBRARY_INCLUDED + +#ifndef DOCTEST_SINGLE_HEADER +#define DOCTEST_SINGLE_HEADER +#endif // DOCTEST_SINGLE_HEADER + +#if defined(DOCTEST_CONFIG_IMPLEMENT) || !defined(DOCTEST_SINGLE_HEADER) + +#ifndef DOCTEST_SINGLE_HEADER +#include "doctest_fwd.h" +#endif // DOCTEST_SINGLE_HEADER + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wunused-macros") + +#ifndef DOCTEST_LIBRARY_IMPLEMENTATION +#define DOCTEST_LIBRARY_IMPLEMENTATION + +DOCTEST_CLANG_SUPPRESS_WARNING_POP + +DOCTEST_CLANG_SUPPRESS_WARNING_PUSH +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wpadded") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wweak-vtables") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wglobal-constructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wexit-time-destructors") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-prototypes") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wshorten-64-to-32") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-variable-declarations") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wcovered-switch-default") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-noreturn") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-local-typedef") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wdisabled-macro-expansion") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-braces") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wc++98-compat-pedantic") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wunused-member-function") +DOCTEST_CLANG_SUPPRESS_WARNING("-Wnonportable-system-include-path") + +DOCTEST_GCC_SUPPRESS_WARNING_PUSH +DOCTEST_GCC_SUPPRESS_WARNING("-Wunknown-pragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Wpragmas") +DOCTEST_GCC_SUPPRESS_WARNING("-Wconversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Weffc++") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsign-conversion") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-overflow") +DOCTEST_GCC_SUPPRESS_WARNING("-Wstrict-aliasing") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-field-initializers") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-braces") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmissing-declarations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-enum") +DOCTEST_GCC_SUPPRESS_WARNING("-Wswitch-default") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunsafe-loop-optimizations") +DOCTEST_GCC_SUPPRESS_WARNING("-Wold-style-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-local-typedefs") +DOCTEST_GCC_SUPPRESS_WARNING("-Wuseless-cast") +DOCTEST_GCC_SUPPRESS_WARNING("-Wunused-function") +DOCTEST_GCC_SUPPRESS_WARNING("-Wmultiple-inheritance") +DOCTEST_GCC_SUPPRESS_WARNING("-Wnoexcept") +DOCTEST_GCC_SUPPRESS_WARNING("-Wsuggest-attribute") + +DOCTEST_MSVC_SUPPRESS_WARNING_PUSH +DOCTEST_MSVC_SUPPRESS_WARNING(4616) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4619) // invalid compiler warning +DOCTEST_MSVC_SUPPRESS_WARNING(4996) // The compiler encountered a deprecated declaration +DOCTEST_MSVC_SUPPRESS_WARNING(4267) // 'var' : conversion from 'x' to 'y', possible loss of data +DOCTEST_MSVC_SUPPRESS_WARNING(4706) // assignment within conditional expression +DOCTEST_MSVC_SUPPRESS_WARNING(4512) // 'class' : assignment operator could not be generated +DOCTEST_MSVC_SUPPRESS_WARNING(4127) // conditional expression is constant +DOCTEST_MSVC_SUPPRESS_WARNING(4530) // C++ exception handler used, but unwind semantics not enabled +DOCTEST_MSVC_SUPPRESS_WARNING(4577) // 'noexcept' used with no exception handling mode specified +DOCTEST_MSVC_SUPPRESS_WARNING(4774) // format string expected in argument is not a string literal +DOCTEST_MSVC_SUPPRESS_WARNING(4365) // conversion from 'int' to 'unsigned', signed/unsigned mismatch +DOCTEST_MSVC_SUPPRESS_WARNING(4820) // padding in structs +DOCTEST_MSVC_SUPPRESS_WARNING(4640) // construction of local static object is not thread-safe +DOCTEST_MSVC_SUPPRESS_WARNING(5039) // pointer to potentially throwing function passed to extern C +DOCTEST_MSVC_SUPPRESS_WARNING(5045) // Spectre mitigation stuff +DOCTEST_MSVC_SUPPRESS_WARNING(4626) // assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5027) // move assignment operator was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(5026) // move constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4625) // copy constructor was implicitly defined as deleted +DOCTEST_MSVC_SUPPRESS_WARNING(4800) // forcing value to bool 'true' or 'false' (performance warning) +// static analysis +DOCTEST_MSVC_SUPPRESS_WARNING(26439) // This kind of function may not throw. Declare it 'noexcept' +DOCTEST_MSVC_SUPPRESS_WARNING(26495) // Always initialize a member variable +DOCTEST_MSVC_SUPPRESS_WARNING(26451) // Arithmetic overflow ... +DOCTEST_MSVC_SUPPRESS_WARNING(26444) // Avoid unnamed objects with custom construction and dtor... +DOCTEST_MSVC_SUPPRESS_WARNING(26812) // Prefer 'enum class' over 'enum' + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_BEGIN + +// required includes - will go only in one translation unit! +#include +#include +#include +// borland (Embarcadero) compiler requires math.h and not cmath - https://github.com/onqtam/doctest/pull/37 +#ifdef __BORLANDC__ +#include +#endif // __BORLANDC__ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef DOCTEST_PLATFORM_MAC +#include +#include +#include +#endif // DOCTEST_PLATFORM_MAC + +#ifdef DOCTEST_PLATFORM_WINDOWS + +// defines for a leaner windows.h +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif // WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +// not sure what AfxWin.h is for - here I do what Catch does +#ifdef __AFXDLL +#include +#else +#include +#endif +#include + +#else // DOCTEST_PLATFORM_WINDOWS + +#include +#include + +#endif // DOCTEST_PLATFORM_WINDOWS + +// this is a fix for https://github.com/onqtam/doctest/issues/348 +// https://mail.gnome.org/archives/xml/2012-January/msg00000.html +#if !defined(HAVE_UNISTD_H) && !defined(STDOUT_FILENO) +#define STDOUT_FILENO fileno(stdout) +#endif // HAVE_UNISTD_H + +DOCTEST_MAKE_STD_HEADERS_CLEAN_FROM_WARNINGS_ON_WALL_END + +// counts the number of elements in a C array +#define DOCTEST_COUNTOF(x) (sizeof(x) / sizeof(x[0])) + +#ifdef DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_disabled +#else // DOCTEST_CONFIG_DISABLE +#define DOCTEST_BRANCH_ON_DISABLED(if_disabled, if_not_disabled) if_not_disabled +#endif // DOCTEST_CONFIG_DISABLE + +#ifndef DOCTEST_CONFIG_OPTIONS_PREFIX +#define DOCTEST_CONFIG_OPTIONS_PREFIX "dt-" +#endif + +#ifndef DOCTEST_THREAD_LOCAL +#define DOCTEST_THREAD_LOCAL thread_local +#endif + +#ifndef DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES +#define DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES 32 +#endif + +#ifndef DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE +#define DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE 64 +#endif + +#ifdef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS +#define DOCTEST_OPTIONS_PREFIX_DISPLAY DOCTEST_CONFIG_OPTIONS_PREFIX +#else +#define DOCTEST_OPTIONS_PREFIX_DISPLAY "" +#endif + +#if defined(WINAPI_FAMILY) && (WINAPI_FAMILY == WINAPI_FAMILY_APP) +#define DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS +#endif + +namespace doctest { + +bool is_running_in_test = false; + +namespace { + using namespace detail; + // case insensitive strcmp + int stricmp(const char* a, const char* b) { + for(;; a++, b++) { + const int d = tolower(*a) - tolower(*b); + if(d != 0 || !*a) + return d; + } + } + + template + String fpToString(T value, int precision) { + std::ostringstream oss; + oss << std::setprecision(precision) << std::fixed << value; + std::string d = oss.str(); + size_t i = d.find_last_not_of('0'); + if(i != std::string::npos && i != d.size() - 1) { + if(d[i] == '.') + i++; + d = d.substr(0, i + 1); + } + return d.c_str(); + } + + struct Endianness + { + enum Arch + { + Big, + Little + }; + + static Arch which() { + int x = 1; + // casting any data pointer to char* is allowed + auto ptr = reinterpret_cast(&x); + if(*ptr) + return Little; + return Big; + } + }; +} // namespace + +namespace detail { + void my_memcpy(void* dest, const void* src, unsigned num) { memcpy(dest, src, num); } + + String rawMemoryToString(const void* object, unsigned size) { + // Reverse order for little endian architectures + int i = 0, end = static_cast(size), inc = 1; + if(Endianness::which() == Endianness::Little) { + i = end - 1; + end = inc = -1; + } + + unsigned const char* bytes = static_cast(object); + std::ostringstream oss; + oss << "0x" << std::setfill('0') << std::hex; + for(; i != end; i += inc) + oss << std::setw(2) << static_cast(bytes[i]); + return oss.str().c_str(); + } + + DOCTEST_THREAD_LOCAL std::ostringstream g_oss; // NOLINT(cert-err58-cpp) + + std::ostream* getTlsOss() { + g_oss.clear(); // there shouldn't be anything worth clearing in the flags + g_oss.str(""); // the slow way of resetting a string stream + //g_oss.seekp(0); // optimal reset - as seen here: https://stackoverflow.com/a/624291/3162383 + return &g_oss; + } + + String getTlsOssResult() { + //g_oss << std::ends; // needed - as shown here: https://stackoverflow.com/a/624291/3162383 + return g_oss.str().c_str(); + } + +#ifndef DOCTEST_CONFIG_DISABLE + +namespace timer_large_integer +{ + +#if defined(DOCTEST_PLATFORM_WINDOWS) + typedef ULONGLONG type; +#else // DOCTEST_PLATFORM_WINDOWS + using namespace std; + typedef uint64_t type; +#endif // DOCTEST_PLATFORM_WINDOWS +} + +typedef timer_large_integer::type ticks_t; + +#ifdef DOCTEST_CONFIG_GETCURRENTTICKS + ticks_t getCurrentTicks() { return DOCTEST_CONFIG_GETCURRENTTICKS(); } +#elif defined(DOCTEST_PLATFORM_WINDOWS) + ticks_t getCurrentTicks() { + static LARGE_INTEGER hz = {0}, hzo = {0}; + if(!hz.QuadPart) { + QueryPerformanceFrequency(&hz); + QueryPerformanceCounter(&hzo); + } + LARGE_INTEGER t; + QueryPerformanceCounter(&t); + return ((t.QuadPart - hzo.QuadPart) * LONGLONG(1000000)) / hz.QuadPart; + } +#else // DOCTEST_PLATFORM_WINDOWS + ticks_t getCurrentTicks() { + timeval t; + gettimeofday(&t, nullptr); + return static_cast(t.tv_sec) * 1000000 + static_cast(t.tv_usec); + } +#endif // DOCTEST_PLATFORM_WINDOWS + + struct Timer + { + void start() { m_ticks = getCurrentTicks(); } + unsigned int getElapsedMicroseconds() const { + return static_cast(getCurrentTicks() - m_ticks); + } + //unsigned int getElapsedMilliseconds() const { + // return static_cast(getElapsedMicroseconds() / 1000); + //} + double getElapsedSeconds() const { return static_cast(getCurrentTicks() - m_ticks) / 1000000.0; } + + private: + ticks_t m_ticks = 0; + }; + +#ifdef DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS + template + using AtomicOrMultiLaneAtomic = std::atomic; +#else // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS + // Provides a multilane implementation of an atomic variable that supports add, sub, load, + // store. Instead of using a single atomic variable, this splits up into multiple ones, + // each sitting on a separate cache line. The goal is to provide a speedup when most + // operations are modifying. It achieves this with two properties: + // + // * Multiple atomics are used, so chance of congestion from the same atomic is reduced. + // * Each atomic sits on a separate cache line, so false sharing is reduced. + // + // The disadvantage is that there is a small overhead due to the use of TLS, and load/store + // is slower because all atomics have to be accessed. + template + class MultiLaneAtomic + { + struct CacheLineAlignedAtomic + { + std::atomic atomic{}; + char padding[DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE - sizeof(std::atomic)]; + }; + CacheLineAlignedAtomic m_atomics[DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES]; + + static_assert(sizeof(CacheLineAlignedAtomic) == DOCTEST_MULTI_LANE_ATOMICS_CACHE_LINE_SIZE, + "guarantee one atomic takes exactly one cache line"); + + public: + T operator++() DOCTEST_NOEXCEPT { return fetch_add(1) + 1; } + + T operator++(int) DOCTEST_NOEXCEPT { return fetch_add(1); } + + T fetch_add(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + return myAtomic().fetch_add(arg, order); + } + + T fetch_sub(T arg, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + return myAtomic().fetch_sub(arg, order); + } + + operator T() const DOCTEST_NOEXCEPT { return load(); } + + T load(std::memory_order order = std::memory_order_seq_cst) const DOCTEST_NOEXCEPT { + auto result = T(); + for(auto const& c : m_atomics) { + result += c.atomic.load(order); + } + return result; + } + + T operator=(T desired) DOCTEST_NOEXCEPT { + store(desired); + return desired; + } + + void store(T desired, std::memory_order order = std::memory_order_seq_cst) DOCTEST_NOEXCEPT { + // first value becomes desired", all others become 0. + for(auto& c : m_atomics) { + c.atomic.store(desired, order); + desired = {}; + } + } + + private: + // Each thread has a different atomic that it operates on. If more than NumLanes threads + // use this, some will use the same atomic. So performance will degrate a bit, but still + // everything will work. + // + // The logic here is a bit tricky. The call should be as fast as possible, so that there + // is minimal to no overhead in determining the correct atomic for the current thread. + // + // 1. A global static counter laneCounter counts continuously up. + // 2. Each successive thread will use modulo operation of that counter so it gets an atomic + // assigned in a round-robin fashion. + // 3. This tlsLaneIdx is stored in the thread local data, so it is directly available with + // little overhead. + std::atomic& myAtomic() DOCTEST_NOEXCEPT { + static std::atomic laneCounter; + DOCTEST_THREAD_LOCAL size_t tlsLaneIdx = + laneCounter++ % DOCTEST_MULTI_LANE_ATOMICS_THREAD_LANES; + + return m_atomics[tlsLaneIdx].atomic; + } + }; + + template + using AtomicOrMultiLaneAtomic = MultiLaneAtomic; +#endif // DOCTEST_CONFIG_NO_MULTI_LANE_ATOMICS + + // this holds both parameters from the command line and runtime data for tests + struct ContextState : ContextOptions, TestRunStats, CurrentTestCaseStats + { + AtomicOrMultiLaneAtomic numAssertsCurrentTest_atomic; + AtomicOrMultiLaneAtomic numAssertsFailedCurrentTest_atomic; + + std::vector> filters = decltype(filters)(9); // 9 different filters + + std::vector reporters_currently_used; + + assert_handler ah = nullptr; + + Timer timer; + + std::vector stringifiedContexts; // logging from INFO() due to an exception + + // stuff for subcases + std::vector subcasesStack; + std::set subcasesPassed; + int subcasesCurrentMaxLevel; + bool should_reenter; + std::atomic shouldLogCurrentException; + + void resetRunData() { + numTestCases = 0; + numTestCasesPassingFilters = 0; + numTestSuitesPassingFilters = 0; + numTestCasesFailed = 0; + numAsserts = 0; + numAssertsFailed = 0; + numAssertsCurrentTest = 0; + numAssertsFailedCurrentTest = 0; + } + + void finalizeTestCaseData() { + seconds = timer.getElapsedSeconds(); + + // update the non-atomic counters + numAsserts += numAssertsCurrentTest_atomic; + numAssertsFailed += numAssertsFailedCurrentTest_atomic; + numAssertsCurrentTest = numAssertsCurrentTest_atomic; + numAssertsFailedCurrentTest = numAssertsFailedCurrentTest_atomic; + + if(numAssertsFailedCurrentTest) + failure_flags |= TestCaseFailureReason::AssertFailure; + + if(Approx(currentTest->m_timeout).epsilon(DBL_EPSILON) != 0 && + Approx(seconds).epsilon(DBL_EPSILON) > currentTest->m_timeout) + failure_flags |= TestCaseFailureReason::Timeout; + + if(currentTest->m_should_fail) { + if(failure_flags) { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedAndDid; + } else { + failure_flags |= TestCaseFailureReason::ShouldHaveFailedButDidnt; + } + } else if(failure_flags && currentTest->m_may_fail) { + failure_flags |= TestCaseFailureReason::CouldHaveFailedAndDid; + } else if(currentTest->m_expected_failures > 0) { + if(numAssertsFailedCurrentTest == currentTest->m_expected_failures) { + failure_flags |= TestCaseFailureReason::FailedExactlyNumTimes; + } else { + failure_flags |= TestCaseFailureReason::DidntFailExactlyNumTimes; + } + } + + bool ok_to_fail = (TestCaseFailureReason::ShouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::CouldHaveFailedAndDid & failure_flags) || + (TestCaseFailureReason::FailedExactlyNumTimes & failure_flags); + + // if any subcase has failed - the whole test case has failed + if(failure_flags && !ok_to_fail) + numTestCasesFailed++; + } + }; + + ContextState* g_cs = nullptr; + + // used to avoid locks for the debug output + // TODO: figure out if this is indeed necessary/correct - seems like either there still + // could be a race or that there wouldn't be a race even if using the context directly + DOCTEST_THREAD_LOCAL bool g_no_colors; + +#endif // DOCTEST_CONFIG_DISABLE +} // namespace detail + +void String::setOnHeap() { *reinterpret_cast(&buf[last]) = 128; } +void String::setLast(unsigned in) { buf[last] = char(in); } + +void String::copy(const String& other) { + using namespace std; + if(other.isOnStack()) { + memcpy(buf, other.buf, len); + } else { + setOnHeap(); + data.size = other.data.size; + data.capacity = data.size + 1; + data.ptr = new char[data.capacity]; + memcpy(data.ptr, other.data.ptr, data.size + 1); + } +} + +String::String() { + buf[0] = '\0'; + setLast(); +} + +String::~String() { + if(!isOnStack()) + delete[] data.ptr; + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +} + +String::String(const char* in) + : String(in, strlen(in)) {} + +String::String(const char* in, unsigned in_size) { + using namespace std; + if(in_size <= last) { + memcpy(buf, in, in_size); + buf[in_size] = '\0'; + setLast(last - in_size); + } else { + setOnHeap(); + data.size = in_size; + data.capacity = data.size + 1; + data.ptr = new char[data.capacity]; + memcpy(data.ptr, in, in_size); + data.ptr[in_size] = '\0'; + } +} + +String::String(const String& other) { copy(other); } + +String& String::operator=(const String& other) { + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + + copy(other); + } + + return *this; +} + +String& String::operator+=(const String& other) { + const unsigned my_old_size = size(); + const unsigned other_size = other.size(); + const unsigned total_size = my_old_size + other_size; + using namespace std; + if(isOnStack()) { + if(total_size < len) { + // append to the current stack space + memcpy(buf + my_old_size, other.c_str(), other_size + 1); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + setLast(last - total_size); + } else { + // alloc new chunk + char* temp = new char[total_size + 1]; + // copy current data to new location before writing in the union + memcpy(temp, buf, my_old_size); // skip the +1 ('\0') for speed + // update data in union + setOnHeap(); + data.size = total_size; + data.capacity = data.size + 1; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } else { + if(data.capacity > total_size) { + // append to the current heap block + data.size = total_size; + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } else { + // resize + data.capacity *= 2; + if(data.capacity <= total_size) + data.capacity = total_size + 1; + // alloc new chunk + char* temp = new char[data.capacity]; + // copy current data to new location before releasing it + memcpy(temp, data.ptr, my_old_size); // skip the +1 ('\0') for speed + // release old chunk + delete[] data.ptr; + // update the rest of the union members + data.size = total_size; + data.ptr = temp; + // transfer the rest of the data + memcpy(data.ptr + my_old_size, other.c_str(), other_size + 1); + } + } + + return *this; +} + +// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +String String::operator+(const String& other) const { return String(*this) += other; } + +String::String(String&& other) { + using namespace std; + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); +} + +String& String::operator=(String&& other) { + using namespace std; + if(this != &other) { + if(!isOnStack()) + delete[] data.ptr; + memcpy(buf, other.buf, len); + other.buf[0] = '\0'; + other.setLast(); + } + return *this; +} + +char String::operator[](unsigned i) const { + return const_cast(this)->operator[](i); // NOLINT +} + +char& String::operator[](unsigned i) { + if(isOnStack()) + return reinterpret_cast(buf)[i]; + return data.ptr[i]; +} + +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wmaybe-uninitialized") +unsigned String::size() const { + if(isOnStack()) + return last - (unsigned(buf[last]) & 31); // using "last" would work only if "len" is 32 + return data.size; +} +DOCTEST_GCC_SUPPRESS_WARNING_POP + +unsigned String::capacity() const { + if(isOnStack()) + return len; + return data.capacity; +} + +int String::compare(const char* other, bool no_case) const { + if(no_case) + return doctest::stricmp(c_str(), other); + return std::strcmp(c_str(), other); +} + +int String::compare(const String& other, bool no_case) const { + return compare(other.c_str(), no_case); +} + +// clang-format off +bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; } +bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; } +bool operator< (const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; } +bool operator> (const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; } +bool operator<=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) < 0 : true; } +bool operator>=(const String& lhs, const String& rhs) { return (lhs != rhs) ? lhs.compare(rhs) > 0 : true; } +// clang-format on + +std::ostream& operator<<(std::ostream& s, const String& in) { return s << in.c_str(); } + +namespace { + void color_to_stream(std::ostream&, Color::Enum) DOCTEST_BRANCH_ON_DISABLED({}, ;) +} // namespace + +namespace Color { + std::ostream& operator<<(std::ostream& s, Color::Enum code) { + color_to_stream(s, code); + return s; + } +} // namespace Color + +// clang-format off +const char* assertString(assertType::Enum at) { + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4062) // enum 'x' in switch of enum 'y' is not handled + switch(at) { //!OCLINT missing default in switch statements + case assertType::DT_WARN : return "WARN"; + case assertType::DT_CHECK : return "CHECK"; + case assertType::DT_REQUIRE : return "REQUIRE"; + + case assertType::DT_WARN_FALSE : return "WARN_FALSE"; + case assertType::DT_CHECK_FALSE : return "CHECK_FALSE"; + case assertType::DT_REQUIRE_FALSE : return "REQUIRE_FALSE"; + + case assertType::DT_WARN_THROWS : return "WARN_THROWS"; + case assertType::DT_CHECK_THROWS : return "CHECK_THROWS"; + case assertType::DT_REQUIRE_THROWS : return "REQUIRE_THROWS"; + + case assertType::DT_WARN_THROWS_AS : return "WARN_THROWS_AS"; + case assertType::DT_CHECK_THROWS_AS : return "CHECK_THROWS_AS"; + case assertType::DT_REQUIRE_THROWS_AS : return "REQUIRE_THROWS_AS"; + + case assertType::DT_WARN_THROWS_WITH : return "WARN_THROWS_WITH"; + case assertType::DT_CHECK_THROWS_WITH : return "CHECK_THROWS_WITH"; + case assertType::DT_REQUIRE_THROWS_WITH : return "REQUIRE_THROWS_WITH"; + + case assertType::DT_WARN_THROWS_WITH_AS : return "WARN_THROWS_WITH_AS"; + case assertType::DT_CHECK_THROWS_WITH_AS : return "CHECK_THROWS_WITH_AS"; + case assertType::DT_REQUIRE_THROWS_WITH_AS : return "REQUIRE_THROWS_WITH_AS"; + + case assertType::DT_WARN_NOTHROW : return "WARN_NOTHROW"; + case assertType::DT_CHECK_NOTHROW : return "CHECK_NOTHROW"; + case assertType::DT_REQUIRE_NOTHROW : return "REQUIRE_NOTHROW"; + + case assertType::DT_WARN_EQ : return "WARN_EQ"; + case assertType::DT_CHECK_EQ : return "CHECK_EQ"; + case assertType::DT_REQUIRE_EQ : return "REQUIRE_EQ"; + case assertType::DT_WARN_NE : return "WARN_NE"; + case assertType::DT_CHECK_NE : return "CHECK_NE"; + case assertType::DT_REQUIRE_NE : return "REQUIRE_NE"; + case assertType::DT_WARN_GT : return "WARN_GT"; + case assertType::DT_CHECK_GT : return "CHECK_GT"; + case assertType::DT_REQUIRE_GT : return "REQUIRE_GT"; + case assertType::DT_WARN_LT : return "WARN_LT"; + case assertType::DT_CHECK_LT : return "CHECK_LT"; + case assertType::DT_REQUIRE_LT : return "REQUIRE_LT"; + case assertType::DT_WARN_GE : return "WARN_GE"; + case assertType::DT_CHECK_GE : return "CHECK_GE"; + case assertType::DT_REQUIRE_GE : return "REQUIRE_GE"; + case assertType::DT_WARN_LE : return "WARN_LE"; + case assertType::DT_CHECK_LE : return "CHECK_LE"; + case assertType::DT_REQUIRE_LE : return "REQUIRE_LE"; + + case assertType::DT_WARN_UNARY : return "WARN_UNARY"; + case assertType::DT_CHECK_UNARY : return "CHECK_UNARY"; + case assertType::DT_REQUIRE_UNARY : return "REQUIRE_UNARY"; + case assertType::DT_WARN_UNARY_FALSE : return "WARN_UNARY_FALSE"; + case assertType::DT_CHECK_UNARY_FALSE : return "CHECK_UNARY_FALSE"; + case assertType::DT_REQUIRE_UNARY_FALSE : return "REQUIRE_UNARY_FALSE"; + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP + return ""; +} +// clang-format on + +const char* failureString(assertType::Enum at) { + if(at & assertType::is_warn) //!OCLINT bitwise operator in conditional + return "WARNING"; + if(at & assertType::is_check) //!OCLINT bitwise operator in conditional + return "ERROR"; + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return "FATAL ERROR"; + return ""; +} + +DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wnull-dereference") +// depending on the current options this will remove the path of filenames +const char* skipPathFromFilename(const char* file) { +#ifndef DOCTEST_CONFIG_DISABLE + if(getContextOptions()->no_path_in_filenames) { + auto back = std::strrchr(file, '\\'); + auto forward = std::strrchr(file, '/'); + if(back || forward) { + if(back > forward) + forward = back; + return forward + 1; + } + } +#endif // DOCTEST_CONFIG_DISABLE + return file; +} +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +bool SubcaseSignature::operator<(const SubcaseSignature& other) const { + if(m_line != other.m_line) + return m_line < other.m_line; + if(std::strcmp(m_file, other.m_file) != 0) + return std::strcmp(m_file, other.m_file) < 0; + return m_name.compare(other.m_name) < 0; +} + +IContextScope::IContextScope() = default; +IContextScope::~IContextScope() = default; + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +String toString(char* in) { return toString(static_cast(in)); } +// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +String toString(const char* in) { return String("\"") + (in ? in : "{null string}") + "\""; } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING +String toString(bool in) { return in ? "true" : "false"; } +String toString(float in) { return fpToString(in, 5) + "f"; } +String toString(double in) { return fpToString(in, 10); } +String toString(double long in) { return fpToString(in, 15); } + +#define DOCTEST_TO_STRING_OVERLOAD(type, fmt) \ + String toString(type in) { \ + char buf[64]; \ + std::sprintf(buf, fmt, in); \ + return buf; \ + } + +DOCTEST_TO_STRING_OVERLOAD(char, "%d") +DOCTEST_TO_STRING_OVERLOAD(char signed, "%d") +DOCTEST_TO_STRING_OVERLOAD(char unsigned, "%u") +DOCTEST_TO_STRING_OVERLOAD(int short, "%d") +DOCTEST_TO_STRING_OVERLOAD(int short unsigned, "%u") +DOCTEST_TO_STRING_OVERLOAD(int, "%d") +DOCTEST_TO_STRING_OVERLOAD(unsigned, "%u") +DOCTEST_TO_STRING_OVERLOAD(int long, "%ld") +DOCTEST_TO_STRING_OVERLOAD(int long unsigned, "%lu") +DOCTEST_TO_STRING_OVERLOAD(int long long, "%lld") +DOCTEST_TO_STRING_OVERLOAD(int long long unsigned, "%llu") + +String toString(std::nullptr_t) { return "NULL"; } + +#if DOCTEST_MSVC >= DOCTEST_COMPILER(19, 20, 0) +// see this issue on why this is needed: https://github.com/onqtam/doctest/issues/183 +String toString(const std::string& in) { return in.c_str(); } +#endif // VS 2019 + +Approx::Approx(double value) + : m_epsilon(static_cast(std::numeric_limits::epsilon()) * 100) + , m_scale(1.0) + , m_value(value) {} + +Approx Approx::operator()(double value) const { + Approx approx(value); + approx.epsilon(m_epsilon); + approx.scale(m_scale); + return approx; +} + +Approx& Approx::epsilon(double newEpsilon) { + m_epsilon = newEpsilon; + return *this; +} +Approx& Approx::scale(double newScale) { + m_scale = newScale; + return *this; +} + +bool operator==(double lhs, const Approx& rhs) { + // Thanks to Richard Harris for his help refining this formula + return std::fabs(lhs - rhs.m_value) < + rhs.m_epsilon * (rhs.m_scale + std::max(std::fabs(lhs), std::fabs(rhs.m_value))); +} +bool operator==(const Approx& lhs, double rhs) { return operator==(rhs, lhs); } +bool operator!=(double lhs, const Approx& rhs) { return !operator==(lhs, rhs); } +bool operator!=(const Approx& lhs, double rhs) { return !operator==(rhs, lhs); } +bool operator<=(double lhs, const Approx& rhs) { return lhs < rhs.m_value || lhs == rhs; } +bool operator<=(const Approx& lhs, double rhs) { return lhs.m_value < rhs || lhs == rhs; } +bool operator>=(double lhs, const Approx& rhs) { return lhs > rhs.m_value || lhs == rhs; } +bool operator>=(const Approx& lhs, double rhs) { return lhs.m_value > rhs || lhs == rhs; } +bool operator<(double lhs, const Approx& rhs) { return lhs < rhs.m_value && lhs != rhs; } +bool operator<(const Approx& lhs, double rhs) { return lhs.m_value < rhs && lhs != rhs; } +bool operator>(double lhs, const Approx& rhs) { return lhs > rhs.m_value && lhs != rhs; } +bool operator>(const Approx& lhs, double rhs) { return lhs.m_value > rhs && lhs != rhs; } + +String toString(const Approx& in) { + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + return String("Approx( ") + doctest::toString(in.m_value) + " )"; +} +const ContextOptions* getContextOptions() { return DOCTEST_BRANCH_ON_DISABLED(nullptr, g_cs); } + +} // namespace doctest + +#ifdef DOCTEST_CONFIG_DISABLE +namespace doctest { +Context::Context(int, const char* const*) {} +Context::~Context() = default; +void Context::applyCommandLine(int, const char* const*) {} +void Context::addFilter(const char*, const char*) {} +void Context::clearFilters() {} +void Context::setOption(const char*, int) {} +void Context::setOption(const char*, const char*) {} +bool Context::shouldExit() { return false; } +void Context::setAsDefaultForAssertsOutOfTestCases() {} +void Context::setAssertHandler(detail::assert_handler) {} +int Context::run() { return 0; } + +IReporter::~IReporter() = default; + +int IReporter::get_num_active_contexts() { return 0; } +const IContextScope* const* IReporter::get_active_contexts() { return nullptr; } +int IReporter::get_num_stringified_contexts() { return 0; } +const String* IReporter::get_stringified_contexts() { return nullptr; } + +int registerReporter(const char*, int, IReporter*) { return 0; } + +} // namespace doctest +#else // DOCTEST_CONFIG_DISABLE + +#if !defined(DOCTEST_CONFIG_COLORS_NONE) +#if !defined(DOCTEST_CONFIG_COLORS_WINDOWS) && !defined(DOCTEST_CONFIG_COLORS_ANSI) +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_CONFIG_COLORS_WINDOWS +#else // linux +#define DOCTEST_CONFIG_COLORS_ANSI +#endif // platform +#endif // DOCTEST_CONFIG_COLORS_WINDOWS && DOCTEST_CONFIG_COLORS_ANSI +#endif // DOCTEST_CONFIG_COLORS_NONE + +namespace doctest_detail_test_suite_ns { +// holds the current test suite +doctest::detail::TestSuite& getCurrentTestSuite() { + static doctest::detail::TestSuite data{}; + return data; +} +} // namespace doctest_detail_test_suite_ns + +namespace doctest { +namespace { + // the int (priority) is part of the key for automatic sorting - sadly one can register a + // reporter with a duplicate name and a different priority but hopefully that won't happen often :| + typedef std::map, reporterCreatorFunc> reporterMap; + + reporterMap& getReporters() { + static reporterMap data; + return data; + } + reporterMap& getListeners() { + static reporterMap data; + return data; + } +} // namespace +namespace detail { +#define DOCTEST_ITERATE_THROUGH_REPORTERS(function, ...) \ + for(auto& curr_rep : g_cs->reporters_currently_used) \ + curr_rep->function(__VA_ARGS__) + + bool checkIfShouldThrow(assertType::Enum at) { + if(at & assertType::is_require) //!OCLINT bitwise operator in conditional + return true; + + if((at & assertType::is_check) //!OCLINT bitwise operator in conditional + && getContextOptions()->abort_after > 0 && + (g_cs->numAssertsFailed + g_cs->numAssertsFailedCurrentTest_atomic) >= + getContextOptions()->abort_after) + return true; + + return false; + } + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + DOCTEST_NORETURN void throwException() { + g_cs->shouldLogCurrentException = false; + throw TestFailureException(); + } // NOLINT(cert-err60-cpp) +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + void throwException() {} +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +} // namespace detail + +namespace { + using namespace detail; + // matching of a string against a wildcard mask (case sensitivity configurable) taken from + // https://www.codeproject.com/Articles/1088/Wildcard-string-compare-globbing + int wildcmp(const char* str, const char* wild, bool caseSensitive) { + const char* cp = str; + const char* mp = wild; + + while((*str) && (*wild != '*')) { + if((caseSensitive ? (*wild != *str) : (tolower(*wild) != tolower(*str))) && + (*wild != '?')) { + return 0; + } + wild++; + str++; + } + + while(*str) { + if(*wild == '*') { + if(!*++wild) { + return 1; + } + mp = wild; + cp = str + 1; + } else if((caseSensitive ? (*wild == *str) : (tolower(*wild) == tolower(*str))) || + (*wild == '?')) { + wild++; + str++; + } else { + wild = mp; //!OCLINT parameter reassignment + str = cp++; //!OCLINT parameter reassignment + } + } + + while(*wild == '*') { + wild++; + } + return !*wild; + } + + //// C string hash function (djb2) - taken from http://www.cse.yorku.ca/~oz/hash.html + //unsigned hashStr(unsigned const char* str) { + // unsigned long hash = 5381; + // char c; + // while((c = *str++)) + // hash = ((hash << 5) + hash) + c; // hash * 33 + c + // return hash; + //} + + // checks if the name matches any of the filters (and can be configured what to do when empty) + bool matchesAny(const char* name, const std::vector& filters, bool matchEmpty, + bool caseSensitive) { + if(filters.empty() && matchEmpty) + return true; + for(auto& curr : filters) + if(wildcmp(name, curr.c_str(), caseSensitive)) + return true; + return false; + } +} // namespace +namespace detail { + + Subcase::Subcase(const String& name, const char* file, int line) + : m_signature({name, file, line}) { + auto* s = g_cs; + + // check subcase filters + if(s->subcasesStack.size() < size_t(s->subcase_filter_levels)) { + if(!matchesAny(m_signature.m_name.c_str(), s->filters[6], true, s->case_sensitive)) + return; + if(matchesAny(m_signature.m_name.c_str(), s->filters[7], false, s->case_sensitive)) + return; + } + + // if a Subcase on the same level has already been entered + if(s->subcasesStack.size() < size_t(s->subcasesCurrentMaxLevel)) { + s->should_reenter = true; + return; + } + + // push the current signature to the stack so we can check if the + // current stack + the current new subcase have been traversed + s->subcasesStack.push_back(m_signature); + if(s->subcasesPassed.count(s->subcasesStack) != 0) { + // pop - revert to previous stack since we've already passed this + s->subcasesStack.pop_back(); + return; + } + + s->subcasesCurrentMaxLevel = s->subcasesStack.size(); + m_entered = true; + + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_start, m_signature); + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + Subcase::~Subcase() { + if(m_entered) { + // only mark the subcase stack as passed if no subcases have been skipped + if(g_cs->should_reenter == false) + g_cs->subcasesPassed.insert(g_cs->subcasesStack); + g_cs->subcasesStack.pop_back(); + +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) + if(std::uncaught_exceptions() > 0 +#else + if(std::uncaught_exception() +#endif + && g_cs->shouldLogCurrentException) { + DOCTEST_ITERATE_THROUGH_REPORTERS( + test_case_exception, {"exception thrown in subcase - will translate later " + "when the whole test case has been exited (cannot " + "translate while there is an active exception)", + false}); + g_cs->shouldLogCurrentException = false; + } + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + Subcase::operator bool() const { return m_entered; } + + Result::Result(bool passed, const String& decomposition) + : m_passed(passed) + , m_decomp(decomposition) {} + + ExpressionDecomposer::ExpressionDecomposer(assertType::Enum at) + : m_at(at) {} + + TestSuite& TestSuite::operator*(const char* in) { + m_test_suite = in; + // clear state + m_description = nullptr; + m_skip = false; + m_no_breaks = false; + m_no_output = false; + m_may_fail = false; + m_should_fail = false; + m_expected_failures = 0; + m_timeout = 0; + return *this; + } + + TestCase::TestCase(funcType test, const char* file, unsigned line, const TestSuite& test_suite, + const char* type, int template_id) { + m_file = file; + m_line = line; + m_name = nullptr; // will be later overridden in operator* + m_test_suite = test_suite.m_test_suite; + m_description = test_suite.m_description; + m_skip = test_suite.m_skip; + m_no_breaks = test_suite.m_no_breaks; + m_no_output = test_suite.m_no_output; + m_may_fail = test_suite.m_may_fail; + m_should_fail = test_suite.m_should_fail; + m_expected_failures = test_suite.m_expected_failures; + m_timeout = test_suite.m_timeout; + + m_test = test; + m_type = type; + m_template_id = template_id; + } + + TestCase::TestCase(const TestCase& other) + : TestCaseData() { + *this = other; + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(26434) // hides a non-virtual function + DOCTEST_MSVC_SUPPRESS_WARNING(26437) // Do not slice + TestCase& TestCase::operator=(const TestCase& other) { + static_cast(*this) = static_cast(other); + + m_test = other.m_test; + m_type = other.m_type; + m_template_id = other.m_template_id; + m_full_name = other.m_full_name; + + if(m_template_id != -1) + m_name = m_full_name.c_str(); + return *this; + } + DOCTEST_MSVC_SUPPRESS_WARNING_POP + + TestCase& TestCase::operator*(const char* in) { + m_name = in; + // make a new name with an appended type for templated test case + if(m_template_id != -1) { + m_full_name = String(m_name) + m_type; + // redirect the name to point to the newly constructed full name + m_name = m_full_name.c_str(); + } + return *this; + } + + bool TestCase::operator<(const TestCase& other) const { + // this will be used only to differentiate between test cases - not relevant for sorting + if(m_line != other.m_line) + return m_line < other.m_line; + const int name_cmp = strcmp(m_name, other.m_name); + if(name_cmp != 0) + return name_cmp < 0; + const int file_cmp = m_file.compare(other.m_file); + if(file_cmp != 0) + return file_cmp < 0; + return m_template_id < other.m_template_id; + } + + // all the registered tests + std::set& getRegisteredTests() { + static std::set data; + return data; + } +} // namespace detail +namespace { + using namespace detail; + // for sorting tests by file/line + bool fileOrderComparator(const TestCase* lhs, const TestCase* rhs) { + // this is needed because MSVC gives different case for drive letters + // for __FILE__ when evaluated in a header and a source file + const int res = lhs->m_file.compare(rhs->m_file, bool(DOCTEST_MSVC)); + if(res != 0) + return res < 0; + if(lhs->m_line != rhs->m_line) + return lhs->m_line < rhs->m_line; + return lhs->m_template_id < rhs->m_template_id; + } + + // for sorting tests by suite/file/line + bool suiteOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = std::strcmp(lhs->m_test_suite, rhs->m_test_suite); + if(res != 0) + return res < 0; + return fileOrderComparator(lhs, rhs); + } + + // for sorting tests by name/suite/file/line + bool nameOrderComparator(const TestCase* lhs, const TestCase* rhs) { + const int res = std::strcmp(lhs->m_name, rhs->m_name); + if(res != 0) + return res < 0; + return suiteOrderComparator(lhs, rhs); + } + +#ifdef DOCTEST_CONFIG_COLORS_WINDOWS + HANDLE g_stdoutHandle; + WORD g_origFgAttrs; + WORD g_origBgAttrs; + bool g_attrsInitted = false; + + int colors_init() { + if(!g_attrsInitted) { + g_stdoutHandle = GetStdHandle(STD_OUTPUT_HANDLE); + g_attrsInitted = true; + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo(g_stdoutHandle, &csbiInfo); + g_origFgAttrs = csbiInfo.wAttributes & ~(BACKGROUND_GREEN | BACKGROUND_RED | + BACKGROUND_BLUE | BACKGROUND_INTENSITY); + g_origBgAttrs = csbiInfo.wAttributes & ~(FOREGROUND_GREEN | FOREGROUND_RED | + FOREGROUND_BLUE | FOREGROUND_INTENSITY); + } + return 0; + } + + int dumy_init_console_colors = colors_init(); +#endif // DOCTEST_CONFIG_COLORS_WINDOWS + + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + void color_to_stream(std::ostream& s, Color::Enum code) { + static_cast(s); // for DOCTEST_CONFIG_COLORS_NONE or DOCTEST_CONFIG_COLORS_WINDOWS + static_cast(code); // for DOCTEST_CONFIG_COLORS_NONE +#ifdef DOCTEST_CONFIG_COLORS_ANSI + if(g_no_colors || + (isatty(STDOUT_FILENO) == false && getContextOptions()->force_colors == false)) + return; + + auto col = ""; + // clang-format off + switch(code) { //!OCLINT missing break in switch statement / unnecessary default statement in covered switch statement + case Color::Red: col = "[0;31m"; break; + case Color::Green: col = "[0;32m"; break; + case Color::Blue: col = "[0;34m"; break; + case Color::Cyan: col = "[0;36m"; break; + case Color::Yellow: col = "[0;33m"; break; + case Color::Grey: col = "[1;30m"; break; + case Color::LightGrey: col = "[0;37m"; break; + case Color::BrightRed: col = "[1;31m"; break; + case Color::BrightGreen: col = "[1;32m"; break; + case Color::BrightWhite: col = "[1;37m"; break; + case Color::Bright: // invalid + case Color::None: + case Color::White: + default: col = "[0m"; + } + // clang-format on + s << "\033" << col; +#endif // DOCTEST_CONFIG_COLORS_ANSI + +#ifdef DOCTEST_CONFIG_COLORS_WINDOWS + if(g_no_colors || + (isatty(fileno(stdout)) == false && getContextOptions()->force_colors == false)) + return; + +#define DOCTEST_SET_ATTR(x) SetConsoleTextAttribute(g_stdoutHandle, x | g_origBgAttrs) + + // clang-format off + switch (code) { + case Color::White: DOCTEST_SET_ATTR(FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::Red: DOCTEST_SET_ATTR(FOREGROUND_RED); break; + case Color::Green: DOCTEST_SET_ATTR(FOREGROUND_GREEN); break; + case Color::Blue: DOCTEST_SET_ATTR(FOREGROUND_BLUE); break; + case Color::Cyan: DOCTEST_SET_ATTR(FOREGROUND_BLUE | FOREGROUND_GREEN); break; + case Color::Yellow: DOCTEST_SET_ATTR(FOREGROUND_RED | FOREGROUND_GREEN); break; + case Color::Grey: DOCTEST_SET_ATTR(0); break; + case Color::LightGrey: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY); break; + case Color::BrightRed: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_RED); break; + case Color::BrightGreen: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN); break; + case Color::BrightWhite: DOCTEST_SET_ATTR(FOREGROUND_INTENSITY | FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_BLUE); break; + case Color::None: + case Color::Bright: // invalid + default: DOCTEST_SET_ATTR(g_origFgAttrs); + } + // clang-format on +#endif // DOCTEST_CONFIG_COLORS_WINDOWS + } + DOCTEST_CLANG_SUPPRESS_WARNING_POP + + std::vector& getExceptionTranslators() { + static std::vector data; + return data; + } + + String translateActiveException() { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + String res; + auto& translators = getExceptionTranslators(); + for(auto& curr : translators) + if(curr->translate(res)) + return res; + // clang-format off + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wcatch-value") + try { + throw; + } catch(std::exception& ex) { + return ex.what(); + } catch(std::string& msg) { + return msg.c_str(); + } catch(const char* msg) { + return msg; + } catch(...) { + return "unknown exception"; + } + DOCTEST_GCC_SUPPRESS_WARNING_POP +// clang-format on +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + return ""; +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } +} // namespace + +namespace detail { + // used by the macros for registering tests + int regTest(const TestCase& tc) { + getRegisteredTests().insert(tc); + return 0; + } + + // sets the current test suite + int setTestSuite(const TestSuite& ts) { + doctest_detail_test_suite_ns::getCurrentTestSuite() = ts; + return 0; + } + +#ifdef DOCTEST_IS_DEBUGGER_ACTIVE + bool isDebuggerActive() { return DOCTEST_IS_DEBUGGER_ACTIVE(); } +#else // DOCTEST_IS_DEBUGGER_ACTIVE +#ifdef DOCTEST_PLATFORM_LINUX + class ErrnoGuard { + public: + ErrnoGuard() : m_oldErrno(errno) {} + ~ErrnoGuard() { errno = m_oldErrno; } + private: + int m_oldErrno; + }; + // See the comments in Catch2 for the reasoning behind this implementation: + // https://github.com/catchorg/Catch2/blob/v2.13.1/include/internal/catch_debugger.cpp#L79-L102 + bool isDebuggerActive() { + ErrnoGuard guard; + std::ifstream in("/proc/self/status"); + for(std::string line; std::getline(in, line);) { + static const int PREFIX_LEN = 11; + if(line.compare(0, PREFIX_LEN, "TracerPid:\t") == 0) { + return line.length() > PREFIX_LEN && line[PREFIX_LEN] != '0'; + } + } + return false; + } +#elif defined(DOCTEST_PLATFORM_MAC) + // The following function is taken directly from the following technical note: + // https://developer.apple.com/library/archive/qa/qa1361/_index.html + // Returns true if the current process is being debugged (either + // running under the debugger or has a debugger attached post facto). + bool isDebuggerActive() { + int mib[4]; + kinfo_proc info; + size_t size; + // Initialize the flags so that, if sysctl fails for some bizarre + // reason, we get a predictable result. + info.kp_proc.p_flag = 0; + // Initialize mib, which tells sysctl the info we want, in this case + // we're looking for information about a specific process ID. + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_PID; + mib[3] = getpid(); + // Call sysctl. + size = sizeof(info); + if(sysctl(mib, DOCTEST_COUNTOF(mib), &info, &size, 0, 0) != 0) { + std::cerr << "\nCall to sysctl failed - unable to determine if debugger is active **\n"; + return false; + } + // We're being debugged if the P_TRACED flag is set. + return ((info.kp_proc.p_flag & P_TRACED) != 0); + } +#elif DOCTEST_MSVC || defined(__MINGW32__) || defined(__MINGW64__) + bool isDebuggerActive() { return ::IsDebuggerPresent() != 0; } +#else + bool isDebuggerActive() { return false; } +#endif // Platform +#endif // DOCTEST_IS_DEBUGGER_ACTIVE + + void registerExceptionTranslatorImpl(const IExceptionTranslator* et) { + if(std::find(getExceptionTranslators().begin(), getExceptionTranslators().end(), et) == + getExceptionTranslators().end()) + getExceptionTranslators().push_back(et); + } + +#ifdef DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + void toStream(std::ostream* s, char* in) { *s << in; } + void toStream(std::ostream* s, const char* in) { *s << in; } +#endif // DOCTEST_CONFIG_TREAT_CHAR_STAR_AS_STRING + void toStream(std::ostream* s, bool in) { *s << std::boolalpha << in << std::noboolalpha; } + void toStream(std::ostream* s, float in) { *s << in; } + void toStream(std::ostream* s, double in) { *s << in; } + void toStream(std::ostream* s, double long in) { *s << in; } + + void toStream(std::ostream* s, char in) { *s << in; } + void toStream(std::ostream* s, char signed in) { *s << in; } + void toStream(std::ostream* s, char unsigned in) { *s << in; } + void toStream(std::ostream* s, int short in) { *s << in; } + void toStream(std::ostream* s, int short unsigned in) { *s << in; } + void toStream(std::ostream* s, int in) { *s << in; } + void toStream(std::ostream* s, int unsigned in) { *s << in; } + void toStream(std::ostream* s, int long in) { *s << in; } + void toStream(std::ostream* s, int long unsigned in) { *s << in; } + void toStream(std::ostream* s, int long long in) { *s << in; } + void toStream(std::ostream* s, int long long unsigned in) { *s << in; } + + DOCTEST_THREAD_LOCAL std::vector g_infoContexts; // for logging with INFO() + + ContextScopeBase::ContextScopeBase() { + g_infoContexts.push_back(this); + } + + DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4996) // std::uncaught_exception is deprecated in C++17 + DOCTEST_GCC_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + + // destroy cannot be inlined into the destructor because that would mean calling stringify after + // ContextScope has been destroyed (base class destructors run after derived class destructors). + // Instead, ContextScope calls this method directly from its destructor. + void ContextScopeBase::destroy() { +#if defined(__cpp_lib_uncaught_exceptions) && __cpp_lib_uncaught_exceptions >= 201411L && (!defined(__MAC_OS_X_VERSION_MIN_REQUIRED) || __MAC_OS_X_VERSION_MIN_REQUIRED >= 101200) + if(std::uncaught_exceptions() > 0) { +#else + if(std::uncaught_exception()) { +#endif + std::ostringstream s; + this->stringify(&s); + g_cs->stringifiedContexts.push_back(s.str().c_str()); + } + g_infoContexts.pop_back(); + } + + DOCTEST_CLANG_SUPPRESS_WARNING_POP + DOCTEST_GCC_SUPPRESS_WARNING_POP + DOCTEST_MSVC_SUPPRESS_WARNING_POP +} // namespace detail +namespace { + using namespace detail; + +#if !defined(DOCTEST_CONFIG_POSIX_SIGNALS) && !defined(DOCTEST_CONFIG_WINDOWS_SEH) + struct FatalConditionHandler + { + static void reset() {} + static void allocateAltStackMem() {} + static void freeAltStackMem() {} + }; +#else // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + + void reportFatal(const std::string&); + +#ifdef DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + DWORD id; + const char* name; + }; + // There is no 1-1 mapping between signals and windows exceptions. + // Windows can easily distinguish between SO and SigSegV, + // but SigInt, SigTerm, etc are handled differently. + SignalDefs signalDefs[] = { + {static_cast(EXCEPTION_ILLEGAL_INSTRUCTION), + "SIGILL - Illegal instruction signal"}, + {static_cast(EXCEPTION_STACK_OVERFLOW), "SIGSEGV - Stack overflow"}, + {static_cast(EXCEPTION_ACCESS_VIOLATION), + "SIGSEGV - Segmentation violation signal"}, + {static_cast(EXCEPTION_INT_DIVIDE_BY_ZERO), "Divide by zero error"}, + }; + + struct FatalConditionHandler + { + static LONG CALLBACK handleException(PEXCEPTION_POINTERS ExceptionInfo) { + // Multiple threads may enter this filter/handler at once. We want the error message to be printed on the + // console just once no matter how many threads have crashed. + static std::mutex mutex; + static bool execute = true; + { + std::lock_guard lock(mutex); + if(execute) { + bool reported = false; + for(size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + if(ExceptionInfo->ExceptionRecord->ExceptionCode == signalDefs[i].id) { + reportFatal(signalDefs[i].name); + reported = true; + break; + } + } + if(reported == false) + reportFatal("Unhandled SEH exception caught"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + } + execute = false; + } + std::exit(EXIT_FAILURE); + } + + static void allocateAltStackMem() {} + static void freeAltStackMem() {} + + FatalConditionHandler() { + isSet = true; + // 32k seems enough for doctest to handle stack overflow, + // but the value was found experimentally, so there is no strong guarantee + guaranteeSize = 32 * 1024; + // Register an unhandled exception filter + previousTop = SetUnhandledExceptionFilter(handleException); + // Pass in guarantee size to be filled + SetThreadStackGuarantee(&guaranteeSize); + + // On Windows uncaught exceptions from another thread, exceptions from + // destructors, or calls to std::terminate are not a SEH exception + + // The terminal handler gets called when: + // - std::terminate is called FROM THE TEST RUNNER THREAD + // - an exception is thrown from a destructor FROM THE TEST RUNNER THREAD + original_terminate_handler = std::get_terminate(); + std::set_terminate([]() DOCTEST_NOEXCEPT { + reportFatal("Terminate handler called"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + std::exit(EXIT_FAILURE); // explicitly exit - otherwise the SIGABRT handler may be called as well + }); + + // SIGABRT is raised when: + // - std::terminate is called FROM A DIFFERENT THREAD + // - an exception is thrown from a destructor FROM A DIFFERENT THREAD + // - an uncaught exception is thrown FROM A DIFFERENT THREAD + prev_sigabrt_handler = std::signal(SIGABRT, [](int signal) DOCTEST_NOEXCEPT { + if(signal == SIGABRT) { + reportFatal("SIGABRT - Abort (abnormal termination) signal"); + if(isDebuggerActive() && !g_cs->no_breaks) + DOCTEST_BREAK_INTO_DEBUGGER(); + std::exit(EXIT_FAILURE); + } + }); + + // The following settings are taken from google test, and more + // specifically from UnitTest::Run() inside of gtest.cc + + // the user does not want to see pop-up dialogs about crashes + prev_error_mode_1 = SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | + SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); + // This forces the abort message to go to stderr in all circumstances. + prev_error_mode_2 = _set_error_mode(_OUT_TO_STDERR); + // In the debug version, Visual Studio pops up a separate dialog + // offering a choice to debug the aborted program - we want to disable that. + prev_abort_behavior = _set_abort_behavior(0x0, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); + // In debug mode, the Windows CRT can crash with an assertion over invalid + // input (e.g. passing an invalid file descriptor). The default handling + // for these assertions is to pop up a dialog and wait for user input. + // Instead ask the CRT to dump such assertions to stderr non-interactively. + prev_report_mode = _CrtSetReportMode(_CRT_ASSERT, _CRTDBG_MODE_FILE | _CRTDBG_MODE_DEBUG); + prev_report_file = _CrtSetReportFile(_CRT_ASSERT, _CRTDBG_FILE_STDERR); + } + + static void reset() { + if(isSet) { + // Unregister handler and restore the old guarantee + SetUnhandledExceptionFilter(previousTop); + SetThreadStackGuarantee(&guaranteeSize); + std::set_terminate(original_terminate_handler); + std::signal(SIGABRT, prev_sigabrt_handler); + SetErrorMode(prev_error_mode_1); + _set_error_mode(prev_error_mode_2); + _set_abort_behavior(prev_abort_behavior, _WRITE_ABORT_MSG | _CALL_REPORTFAULT); + static_cast(_CrtSetReportMode(_CRT_ASSERT, prev_report_mode)); + static_cast(_CrtSetReportFile(_CRT_ASSERT, prev_report_file)); + isSet = false; + } + } + + ~FatalConditionHandler() { reset(); } + + private: + static UINT prev_error_mode_1; + static int prev_error_mode_2; + static unsigned int prev_abort_behavior; + static int prev_report_mode; + static _HFILE prev_report_file; + static void (*prev_sigabrt_handler)(int); + static std::terminate_handler original_terminate_handler; + static bool isSet; + static ULONG guaranteeSize; + static LPTOP_LEVEL_EXCEPTION_FILTER previousTop; + }; + + UINT FatalConditionHandler::prev_error_mode_1; + int FatalConditionHandler::prev_error_mode_2; + unsigned int FatalConditionHandler::prev_abort_behavior; + int FatalConditionHandler::prev_report_mode; + _HFILE FatalConditionHandler::prev_report_file; + void (*FatalConditionHandler::prev_sigabrt_handler)(int); + std::terminate_handler FatalConditionHandler::original_terminate_handler; + bool FatalConditionHandler::isSet = false; + ULONG FatalConditionHandler::guaranteeSize = 0; + LPTOP_LEVEL_EXCEPTION_FILTER FatalConditionHandler::previousTop = nullptr; + +#else // DOCTEST_PLATFORM_WINDOWS + + struct SignalDefs + { + int id; + const char* name; + }; + SignalDefs signalDefs[] = {{SIGINT, "SIGINT - Terminal interrupt signal"}, + {SIGILL, "SIGILL - Illegal instruction signal"}, + {SIGFPE, "SIGFPE - Floating point error signal"}, + {SIGSEGV, "SIGSEGV - Segmentation violation signal"}, + {SIGTERM, "SIGTERM - Termination request signal"}, + {SIGABRT, "SIGABRT - Abort (abnormal termination) signal"}}; + + struct FatalConditionHandler + { + static bool isSet; + static struct sigaction oldSigActions[DOCTEST_COUNTOF(signalDefs)]; + static stack_t oldSigStack; + static size_t altStackSize; + static char* altStackMem; + + static void handleSignal(int sig) { + const char* name = ""; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + SignalDefs& def = signalDefs[i]; + if(sig == def.id) { + name = def.name; + break; + } + } + reset(); + reportFatal(name); + raise(sig); + } + + static void allocateAltStackMem() { + altStackMem = new char[altStackSize]; + } + + static void freeAltStackMem() { + delete[] altStackMem; + } + + FatalConditionHandler() { + isSet = true; + stack_t sigStack; + sigStack.ss_sp = altStackMem; + sigStack.ss_size = altStackSize; + sigStack.ss_flags = 0; + sigaltstack(&sigStack, &oldSigStack); + struct sigaction sa = {}; + sa.sa_handler = handleSignal; // NOLINT + sa.sa_flags = SA_ONSTACK; + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &sa, &oldSigActions[i]); + } + } + + ~FatalConditionHandler() { reset(); } + static void reset() { + if(isSet) { + // Set signals back to previous values -- hopefully nobody overwrote them in the meantime + for(std::size_t i = 0; i < DOCTEST_COUNTOF(signalDefs); ++i) { + sigaction(signalDefs[i].id, &oldSigActions[i], nullptr); + } + // Return the old stack + sigaltstack(&oldSigStack, nullptr); + isSet = false; + } + } + }; + + bool FatalConditionHandler::isSet = false; + struct sigaction FatalConditionHandler::oldSigActions[DOCTEST_COUNTOF(signalDefs)] = {}; + stack_t FatalConditionHandler::oldSigStack = {}; + size_t FatalConditionHandler::altStackSize = 4 * SIGSTKSZ; + char* FatalConditionHandler::altStackMem = nullptr; + +#endif // DOCTEST_PLATFORM_WINDOWS +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH + +} // namespace + +namespace { + using namespace detail; + +#ifdef DOCTEST_PLATFORM_WINDOWS +#define DOCTEST_OUTPUT_DEBUG_STRING(text) ::OutputDebugStringA(text) +#else + // TODO: integration with XCode and other IDEs +#define DOCTEST_OUTPUT_DEBUG_STRING(text) // NOLINT(clang-diagnostic-unused-macros) +#endif // Platform + + void addAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsCurrentTest_atomic++; + } + + void addFailedAssert(assertType::Enum at) { + if((at & assertType::is_warn) == 0) //!OCLINT bitwise operator in conditional + g_cs->numAssertsFailedCurrentTest_atomic++; + } + +#if defined(DOCTEST_CONFIG_POSIX_SIGNALS) || defined(DOCTEST_CONFIG_WINDOWS_SEH) + void reportFatal(const std::string& message) { + g_cs->failure_flags |= TestCaseFailureReason::Crash; + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, {message.c_str(), true}); + + while(g_cs->subcasesStack.size()) { + g_cs->subcasesStack.pop_back(); + DOCTEST_ITERATE_THROUGH_REPORTERS(subcase_end, DOCTEST_EMPTY); + } + + g_cs->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } +#endif // DOCTEST_CONFIG_POSIX_SIGNALS || DOCTEST_CONFIG_WINDOWS_SEH +} // namespace +namespace detail { + + ResultBuilder::ResultBuilder(assertType::Enum at, const char* file, int line, const char* expr, + const char* exception_type, const char* exception_string) { + m_test_case = g_cs->currentTest; + m_at = at; + m_file = file; + m_line = line; + m_expr = expr; + m_failed = true; + m_threw = false; + m_threw_as = false; + m_exception_type = exception_type; + m_exception_string = exception_string; +#if DOCTEST_MSVC + if(m_expr[0] == ' ') // this happens when variadic macros are disabled under MSVC + ++m_expr; +#endif // MSVC + } + + void ResultBuilder::setResult(const Result& res) { + m_decomp = res.m_decomp; + m_failed = !res.m_passed; + } + + void ResultBuilder::translateException() { + m_threw = true; + m_exception = translateActiveException(); + } + + bool ResultBuilder::log() { + if(m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw; + } else if((m_at & assertType::is_throws_as) && (m_at & assertType::is_throws_with)) { //!OCLINT + m_failed = !m_threw_as || (m_exception != m_exception_string); + } else if(m_at & assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + m_failed = !m_threw_as; + } else if(m_at & assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + m_failed = m_exception != m_exception_string; + } else if(m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + m_failed = m_threw; + } + + if(m_exception.size()) + m_exception = String("\"") + m_exception + "\""; + + if(is_running_in_test) { + addAssert(m_at); + DOCTEST_ITERATE_THROUGH_REPORTERS(log_assert, *this); + + if(m_failed) + addFailedAssert(m_at); + } else if(m_failed) { + failed_out_of_a_testing_context(*this); + } + + return m_failed && isDebuggerActive() && !getContextOptions()->no_breaks && + (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger + } + + void ResultBuilder::react() const { + if(m_failed && checkIfShouldThrow(m_at)) + throwException(); + } + + void failed_out_of_a_testing_context(const AssertData& ad) { + if(g_cs->ah) + g_cs->ah(ad); + else + std::abort(); + } + + void decomp_assert(assertType::Enum at, const char* file, int line, const char* expr, + Result result) { + bool failed = !result.m_passed; + + // ################################################################################### + // IF THE DEBUGGER BREAKS HERE - GO 1 LEVEL UP IN THE CALLSTACK FOR THE FAILING ASSERT + // THIS IS THE EFFECT OF HAVING 'DOCTEST_CONFIG_SUPER_FAST_ASSERTS' DEFINED + // ################################################################################### + DOCTEST_ASSERT_OUT_OF_TESTS(result.m_decomp); + DOCTEST_ASSERT_IN_TESTS(result.m_decomp); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) + } + + MessageBuilder::MessageBuilder(const char* file, int line, assertType::Enum severity) { + m_stream = getTlsOss(); + m_file = file; + m_line = line; + m_severity = severity; + } + + IExceptionTranslator::IExceptionTranslator() = default; + IExceptionTranslator::~IExceptionTranslator() = default; + + bool MessageBuilder::log() { + m_string = getTlsOssResult(); + DOCTEST_ITERATE_THROUGH_REPORTERS(log_message, *this); + + const bool isWarn = m_severity & assertType::is_warn; + + // warn is just a message in this context so we don't treat it as an assert + if(!isWarn) { + addAssert(m_severity); + addFailedAssert(m_severity); + } + + return isDebuggerActive() && !getContextOptions()->no_breaks && !isWarn && + (g_cs->currentTest == nullptr || !g_cs->currentTest->m_no_breaks); // break into debugger + } + + void MessageBuilder::react() { + if(m_severity & assertType::is_require) //!OCLINT bitwise operator in conditional + throwException(); + } + + MessageBuilder::~MessageBuilder() = default; +} // namespace detail +namespace { + using namespace detail; + + template + DOCTEST_NORETURN void throw_exception(Ex const& e) { +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + throw e; +#else // DOCTEST_CONFIG_NO_EXCEPTIONS + std::cerr << "doctest will terminate because it needed to throw an exception.\n" + << "The message was: " << e.what() << '\n'; + std::terminate(); +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + } + +#ifndef DOCTEST_INTERNAL_ERROR +#define DOCTEST_INTERNAL_ERROR(msg) \ + throw_exception(std::logic_error( \ + __FILE__ ":" DOCTEST_TOSTR(__LINE__) ": Internal doctest error: " msg)) +#endif // DOCTEST_INTERNAL_ERROR + + // clang-format off + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. +// ================================================================================================= + + class XmlEncode { + public: + enum ForWhat { ForTextNodes, ForAttributes }; + + XmlEncode( std::string const& str, ForWhat forWhat = ForTextNodes ); + + void encodeTo( std::ostream& os ) const; + + friend std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ); + + private: + std::string m_str; + ForWhat m_forWhat; + }; + + class XmlWriter { + public: + + class ScopedElement { + public: + ScopedElement( XmlWriter* writer ); + + ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT; + ScopedElement& operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT; + + ~ScopedElement(); + + ScopedElement& writeText( std::string const& text, bool indent = true ); + + template + ScopedElement& writeAttribute( std::string const& name, T const& attribute ) { + m_writer->writeAttribute( name, attribute ); + return *this; + } + + private: + mutable XmlWriter* m_writer = nullptr; + }; + + XmlWriter( std::ostream& os = std::cout ); + ~XmlWriter(); + + XmlWriter( XmlWriter const& ) = delete; + XmlWriter& operator=( XmlWriter const& ) = delete; + + XmlWriter& startElement( std::string const& name ); + + ScopedElement scopedElement( std::string const& name ); + + XmlWriter& endElement(); + + XmlWriter& writeAttribute( std::string const& name, std::string const& attribute ); + + XmlWriter& writeAttribute( std::string const& name, const char* attribute ); + + XmlWriter& writeAttribute( std::string const& name, bool attribute ); + + template + XmlWriter& writeAttribute( std::string const& name, T const& attribute ) { + std::stringstream rss; + rss << attribute; + return writeAttribute( name, rss.str() ); + } + + XmlWriter& writeText( std::string const& text, bool indent = true ); + + //XmlWriter& writeComment( std::string const& text ); + + //void writeStylesheetRef( std::string const& url ); + + //XmlWriter& writeBlankLine(); + + void ensureTagClosed(); + + private: + + void writeDeclaration(); + + void newlineIfNecessary(); + + bool m_tagIsOpen = false; + bool m_needsNewline = false; + std::vector m_tags; + std::string m_indent; + std::ostream& m_os; + }; + +// ================================================================================================= +// The following code has been taken verbatim from Catch2/include/internal/catch_xmlwriter.h/cpp +// This is done so cherry-picking bug fixes is trivial - even the style/formatting is untouched. +// ================================================================================================= + +using uchar = unsigned char; + +namespace { + + size_t trailingBytes(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return 2; + } + if ((c & 0xF0) == 0xE0) { + return 3; + } + if ((c & 0xF8) == 0xF0) { + return 4; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + uint32_t headerValue(unsigned char c) { + if ((c & 0xE0) == 0xC0) { + return c & 0x1F; + } + if ((c & 0xF0) == 0xE0) { + return c & 0x0F; + } + if ((c & 0xF8) == 0xF0) { + return c & 0x07; + } + DOCTEST_INTERNAL_ERROR("Invalid multibyte utf-8 start byte encountered"); + } + + void hexEscapeChar(std::ostream& os, unsigned char c) { + std::ios_base::fmtflags f(os.flags()); + os << "\\x" + << std::uppercase << std::hex << std::setfill('0') << std::setw(2) + << static_cast(c); + os.flags(f); + } + +} // anonymous namespace + + XmlEncode::XmlEncode( std::string const& str, ForWhat forWhat ) + : m_str( str ), + m_forWhat( forWhat ) + {} + + void XmlEncode::encodeTo( std::ostream& os ) const { + // Apostrophe escaping not necessary if we always use " to write attributes + // (see: https://www.w3.org/TR/xml/#syntax) + + for( std::size_t idx = 0; idx < m_str.size(); ++ idx ) { + uchar c = m_str[idx]; + switch (c) { + case '<': os << "<"; break; + case '&': os << "&"; break; + + case '>': + // See: https://www.w3.org/TR/xml/#syntax + if (idx > 2 && m_str[idx - 1] == ']' && m_str[idx - 2] == ']') + os << ">"; + else + os << c; + break; + + case '\"': + if (m_forWhat == ForAttributes) + os << """; + else + os << c; + break; + + default: + // Check for control characters and invalid utf-8 + + // Escape control characters in standard ascii + // see https://stackoverflow.com/questions/404107/why-are-control-characters-illegal-in-xml-1-0 + if (c < 0x09 || (c > 0x0D && c < 0x20) || c == 0x7F) { + hexEscapeChar(os, c); + break; + } + + // Plain ASCII: Write it to stream + if (c < 0x7F) { + os << c; + break; + } + + // UTF-8 territory + // Check if the encoding is valid and if it is not, hex escape bytes. + // Important: We do not check the exact decoded values for validity, only the encoding format + // First check that this bytes is a valid lead byte: + // This means that it is not encoded as 1111 1XXX + // Or as 10XX XXXX + if (c < 0xC0 || + c >= 0xF8) { + hexEscapeChar(os, c); + break; + } + + auto encBytes = trailingBytes(c); + // Are there enough bytes left to avoid accessing out-of-bounds memory? + if (idx + encBytes - 1 >= m_str.size()) { + hexEscapeChar(os, c); + break; + } + // The header is valid, check data + // The next encBytes bytes must together be a valid utf-8 + // This means: bitpattern 10XX XXXX and the extracted value is sane (ish) + bool valid = true; + uint32_t value = headerValue(c); + for (std::size_t n = 1; n < encBytes; ++n) { + uchar nc = m_str[idx + n]; + valid &= ((nc & 0xC0) == 0x80); + value = (value << 6) | (nc & 0x3F); + } + + if ( + // Wrong bit pattern of following bytes + (!valid) || + // Overlong encodings + (value < 0x80) || + ( value < 0x800 && encBytes > 2) || // removed "0x80 <= value &&" because redundant + (0x800 < value && value < 0x10000 && encBytes > 3) || + // Encoded value out of range + (value >= 0x110000) + ) { + hexEscapeChar(os, c); + break; + } + + // If we got here, this is in fact a valid(ish) utf-8 sequence + for (std::size_t n = 0; n < encBytes; ++n) { + os << m_str[idx + n]; + } + idx += encBytes - 1; + break; + } + } + } + + std::ostream& operator << ( std::ostream& os, XmlEncode const& xmlEncode ) { + xmlEncode.encodeTo( os ); + return os; + } + + XmlWriter::ScopedElement::ScopedElement( XmlWriter* writer ) + : m_writer( writer ) + {} + + XmlWriter::ScopedElement::ScopedElement( ScopedElement&& other ) DOCTEST_NOEXCEPT + : m_writer( other.m_writer ){ + other.m_writer = nullptr; + } + XmlWriter::ScopedElement& XmlWriter::ScopedElement::operator=( ScopedElement&& other ) DOCTEST_NOEXCEPT { + if ( m_writer ) { + m_writer->endElement(); + } + m_writer = other.m_writer; + other.m_writer = nullptr; + return *this; + } + + + XmlWriter::ScopedElement::~ScopedElement() { + if( m_writer ) + m_writer->endElement(); + } + + XmlWriter::ScopedElement& XmlWriter::ScopedElement::writeText( std::string const& text, bool indent ) { + m_writer->writeText( text, indent ); + return *this; + } + + XmlWriter::XmlWriter( std::ostream& os ) : m_os( os ) + { + writeDeclaration(); + } + + XmlWriter::~XmlWriter() { + while( !m_tags.empty() ) + endElement(); + } + + XmlWriter& XmlWriter::startElement( std::string const& name ) { + ensureTagClosed(); + newlineIfNecessary(); + m_os << m_indent << '<' << name; + m_tags.push_back( name ); + m_indent += " "; + m_tagIsOpen = true; + return *this; + } + + XmlWriter::ScopedElement XmlWriter::scopedElement( std::string const& name ) { + ScopedElement scoped( this ); + startElement( name ); + return scoped; + } + + XmlWriter& XmlWriter::endElement() { + newlineIfNecessary(); + m_indent = m_indent.substr( 0, m_indent.size()-2 ); + if( m_tagIsOpen ) { + m_os << "/>"; + m_tagIsOpen = false; + } + else { + m_os << m_indent << ""; + } + m_os << std::endl; + m_tags.pop_back(); + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, std::string const& attribute ) { + if( !name.empty() && !attribute.empty() ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, const char* attribute ) { + if( !name.empty() && attribute && attribute[0] != '\0' ) + m_os << ' ' << name << "=\"" << XmlEncode( attribute, XmlEncode::ForAttributes ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeAttribute( std::string const& name, bool attribute ) { + m_os << ' ' << name << "=\"" << ( attribute ? "true" : "false" ) << '"'; + return *this; + } + + XmlWriter& XmlWriter::writeText( std::string const& text, bool indent ) { + if( !text.empty() ){ + bool tagWasOpen = m_tagIsOpen; + ensureTagClosed(); + if( tagWasOpen && indent ) + m_os << m_indent; + m_os << XmlEncode( text ); + m_needsNewline = true; + } + return *this; + } + + //XmlWriter& XmlWriter::writeComment( std::string const& text ) { + // ensureTagClosed(); + // m_os << m_indent << ""; + // m_needsNewline = true; + // return *this; + //} + + //void XmlWriter::writeStylesheetRef( std::string const& url ) { + // m_os << "\n"; + //} + + //XmlWriter& XmlWriter::writeBlankLine() { + // ensureTagClosed(); + // m_os << '\n'; + // return *this; + //} + + void XmlWriter::ensureTagClosed() { + if( m_tagIsOpen ) { + m_os << ">" << std::endl; + m_tagIsOpen = false; + } + } + + void XmlWriter::writeDeclaration() { + m_os << "\n"; + } + + void XmlWriter::newlineIfNecessary() { + if( m_needsNewline ) { + m_os << std::endl; + m_needsNewline = false; + } + } + +// ================================================================================================= +// End of copy-pasted code from Catch +// ================================================================================================= + + // clang-format on + + struct XmlReporter : public IReporter + { + XmlWriter xml; + std::mutex mutex; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + XmlReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + std::stringstream ss; + for(int i = 0; i < num_contexts; ++i) { + contexts[i]->stringify(&ss); + xml.scopedElement("Info").writeText(ss.str()); + ss.str(""); + } + } + } + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } + + void test_case_start_impl(const TestCaseData& in) { + bool open_ts_tag = false; + if(tc != nullptr) { // we have already opened a test suite + if(std::strcmp(tc->m_test_suite, in.m_test_suite) != 0) { + xml.endElement(); + open_ts_tag = true; + } + } + else { + open_ts_tag = true; // first test case ==> first test suite + } + + if(open_ts_tag) { + xml.startElement("TestSuite"); + xml.writeAttribute("name", in.m_test_suite); + } + + tc = ∈ + xml.startElement("TestCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file.c_str())) + .writeAttribute("line", line(in.m_line)) + .writeAttribute("description", in.m_description); + + if(Approx(in.m_timeout) != 0) + xml.writeAttribute("timeout", in.m_timeout); + if(in.m_may_fail) + xml.writeAttribute("may_fail", true); + if(in.m_should_fail) + xml.writeAttribute("should_fail", true); + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + test_run_start(); + if(opt.list_reporters) { + for(auto& curr : getListeners()) + xml.scopedElement("Listener") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + for(auto& curr : getReporters()) + xml.scopedElement("Reporter") + .writeAttribute("priority", curr.first.first) + .writeAttribute("name", curr.first.second); + } else if(opt.count || opt.list_test_cases) { + for(unsigned i = 0; i < in.num_data; ++i) { + xml.scopedElement("TestCase").writeAttribute("name", in.data[i]->m_name) + .writeAttribute("testsuite", in.data[i]->m_test_suite) + .writeAttribute("filename", skipPathFromFilename(in.data[i]->m_file.c_str())) + .writeAttribute("line", line(in.data[i]->m_line)); + } + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + } else if(opt.list_test_suites) { + for(unsigned i = 0; i < in.num_data; ++i) + xml.scopedElement("TestSuite").writeAttribute("name", in.data[i]->m_test_suite); + xml.scopedElement("OverallResultsTestCases") + .writeAttribute("unskipped", in.run_stats->numTestCasesPassingFilters); + xml.scopedElement("OverallResultsTestSuites") + .writeAttribute("unskipped", in.run_stats->numTestSuitesPassingFilters); + } + xml.endElement(); + } + + void test_run_start() override { + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + + xml.startElement("doctest").writeAttribute("binary", binary_name); + if(opt.no_version == false) + xml.writeAttribute("version", DOCTEST_VERSION_STR); + + // only the consequential ones (TODO: filters) + xml.scopedElement("Options") + .writeAttribute("order_by", opt.order_by.c_str()) + .writeAttribute("rand_seed", opt.rand_seed) + .writeAttribute("first", opt.first) + .writeAttribute("last", opt.last) + .writeAttribute("abort_after", opt.abort_after) + .writeAttribute("subcase_filter_levels", opt.subcase_filter_levels) + .writeAttribute("case_sensitive", opt.case_sensitive) + .writeAttribute("no_throw", opt.no_throw) + .writeAttribute("no_skip", opt.no_skip); + } + + void test_run_end(const TestRunStats& p) override { + if(tc) // the TestSuite tag - only if there has been at least 1 test case + xml.endElement(); + + xml.scopedElement("OverallResultsAsserts") + .writeAttribute("successes", p.numAsserts - p.numAssertsFailed) + .writeAttribute("failures", p.numAssertsFailed); + + xml.startElement("OverallResultsTestCases") + .writeAttribute("successes", + p.numTestCasesPassingFilters - p.numTestCasesFailed) + .writeAttribute("failures", p.numTestCasesFailed); + if(opt.no_skipped_summary == false) + xml.writeAttribute("skipped", p.numTestCases - p.numTestCasesPassingFilters); + xml.endElement(); + + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + test_case_start_impl(in); + xml.ensureTagClosed(); + } + + void test_case_reenter(const TestCaseData&) override {} + + void test_case_end(const CurrentTestCaseStats& st) override { + xml.startElement("OverallResultsAsserts") + .writeAttribute("successes", + st.numAssertsCurrentTest - st.numAssertsFailedCurrentTest) + .writeAttribute("failures", st.numAssertsFailedCurrentTest); + if(opt.duration) + xml.writeAttribute("duration", st.seconds); + if(tc->m_expected_failures) + xml.writeAttribute("expected_failures", tc->m_expected_failures); + xml.endElement(); + + xml.endElement(); + } + + void test_case_exception(const TestCaseException& e) override { + std::lock_guard lock(mutex); + + xml.scopedElement("Exception") + .writeAttribute("crash", e.is_crash) + .writeText(e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + std::lock_guard lock(mutex); + + xml.startElement("SubCase") + .writeAttribute("name", in.m_name) + .writeAttribute("filename", skipPathFromFilename(in.m_file)) + .writeAttribute("line", line(in.m_line)); + xml.ensureTagClosed(); + } + + void subcase_end() override { xml.endElement(); } + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed && !opt.success) + return; + + std::lock_guard lock(mutex); + + xml.startElement("Expression") + .writeAttribute("success", !rb.m_failed) + .writeAttribute("type", assertString(rb.m_at)) + .writeAttribute("filename", skipPathFromFilename(rb.m_file)) + .writeAttribute("line", line(rb.m_line)); + + xml.scopedElement("Original").writeText(rb.m_expr); + + if(rb.m_threw) + xml.scopedElement("Exception").writeText(rb.m_exception.c_str()); + + if(rb.m_at & assertType::is_throws_as) + xml.scopedElement("ExpectedException").writeText(rb.m_exception_type); + if(rb.m_at & assertType::is_throws_with) + xml.scopedElement("ExpectedExceptionString").writeText(rb.m_exception_string); + if((rb.m_at & assertType::is_normal) && !rb.m_threw) + xml.scopedElement("Expanded").writeText(rb.m_decomp.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void log_message(const MessageData& mb) override { + std::lock_guard lock(mutex); + + xml.startElement("Message") + .writeAttribute("type", failureString(mb.m_severity)) + .writeAttribute("filename", skipPathFromFilename(mb.m_file)) + .writeAttribute("line", line(mb.m_line)); + + xml.scopedElement("Text").writeText(mb.m_string.c_str()); + + log_contexts(); + + xml.endElement(); + } + + void test_case_skipped(const TestCaseData& in) override { + if(opt.no_skipped_summary == false) { + test_case_start_impl(in); + xml.writeAttribute("skipped", "true"); + xml.endElement(); + } + } + }; + + DOCTEST_REGISTER_REPORTER("xml", 0, XmlReporter); + + void fulltext_log_assert_to_stream(std::ostream& s, const AssertData& rb) { + if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) == + 0) //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << " ) " + << Color::None; + + if(rb.m_at & assertType::is_throws) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "threw as expected!" : "did NOT throw at all!") << "\n"; + } else if((rb.m_at & assertType::is_throws_as) && + (rb.m_at & assertType::is_throws_with)) { //!OCLINT + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string << "\", " << rb.m_exception_type << " ) " << Color::None; + if(rb.m_threw) { + if(!rb.m_failed) { + s << "threw as expected!\n"; + } else { + s << "threw a DIFFERENT exception! (contents: " << rb.m_exception << ")\n"; + } + } else { + s << "did NOT throw at all!\n"; + } + } else if(rb.m_at & + assertType::is_throws_as) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", " + << rb.m_exception_type << " ) " << Color::None + << (rb.m_threw ? (rb.m_threw_as ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & + assertType::is_throws_with) { //!OCLINT bitwise operator in conditional + s << Color::Cyan << assertString(rb.m_at) << "( " << rb.m_expr << ", \"" + << rb.m_exception_string << "\" ) " << Color::None + << (rb.m_threw ? (!rb.m_failed ? "threw as expected!" : + "threw a DIFFERENT exception: ") : + "did NOT throw at all!") + << Color::Cyan << rb.m_exception << "\n"; + } else if(rb.m_at & assertType::is_nothrow) { //!OCLINT bitwise operator in conditional + s << (rb.m_threw ? "THREW exception: " : "didn't throw!") << Color::Cyan + << rb.m_exception << "\n"; + } else { + s << (rb.m_threw ? "THREW exception: " : + (!rb.m_failed ? "is correct!\n" : "is NOT correct!\n")); + if(rb.m_threw) + s << rb.m_exception << "\n"; + else + s << " values: " << assertString(rb.m_at) << "( " << rb.m_decomp << " )\n"; + } + } + + // TODO: + // - log_message() + // - respond to queries + // - honor remaining options + // - more attributes in tags + struct JUnitReporter : public IReporter + { + XmlWriter xml; + std::mutex mutex; + Timer timer; + std::vector deepestSubcaseStackNames; + + struct JUnitTestCaseData + { + static std::string getCurrentTimestamp() { + // Beware, this is not reentrant because of backward compatibility issues + // Also, UTC only, again because of backward compatibility (%z is C++11) + time_t rawtime; + std::time(&rawtime); + auto const timeStampSize = sizeof("2017-01-16T17:06:45Z"); + + std::tm timeInfo; +#ifdef DOCTEST_PLATFORM_WINDOWS + gmtime_s(&timeInfo, &rawtime); +#else // DOCTEST_PLATFORM_WINDOWS + gmtime_r(&rawtime, &timeInfo); +#endif // DOCTEST_PLATFORM_WINDOWS + + char timeStamp[timeStampSize]; + const char* const fmt = "%Y-%m-%dT%H:%M:%SZ"; + + std::strftime(timeStamp, timeStampSize, fmt, &timeInfo); + return std::string(timeStamp); + } + + struct JUnitTestMessage + { + JUnitTestMessage(const std::string& _message, const std::string& _type, const std::string& _details) + : message(_message), type(_type), details(_details) {} + + JUnitTestMessage(const std::string& _message, const std::string& _details) + : message(_message), type(), details(_details) {} + + std::string message, type, details; + }; + + struct JUnitTestCase + { + JUnitTestCase(const std::string& _classname, const std::string& _name) + : classname(_classname), name(_name), time(0), failures() {} + + std::string classname, name; + double time; + std::vector failures, errors; + }; + + void add(const std::string& classname, const std::string& name) { + testcases.emplace_back(classname, name); + } + + void appendSubcaseNamesToLastTestcase(std::vector nameStack) { + for(auto& curr: nameStack) + if(curr.size()) + testcases.back().name += std::string("/") + curr.c_str(); + } + + void addTime(double time) { + if(time < 1e-4) + time = 0; + testcases.back().time = time; + totalSeconds += time; + } + + void addFailure(const std::string& message, const std::string& type, const std::string& details) { + testcases.back().failures.emplace_back(message, type, details); + ++totalFailures; + } + + void addError(const std::string& message, const std::string& details) { + testcases.back().errors.emplace_back(message, details); + ++totalErrors; + } + + std::vector testcases; + double totalSeconds = 0; + int totalErrors = 0, totalFailures = 0; + }; + + JUnitTestCaseData testCaseData; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc = nullptr; + + JUnitReporter(const ContextOptions& co) + : xml(*co.cout) + , opt(co) {} + + unsigned line(unsigned l) const { return opt.no_line_numbers ? 0 : l; } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData&) override {} + + void test_run_start() override {} + + void test_run_end(const TestRunStats& p) override { + // remove .exe extension - mainly to have the same output on UNIX and Windows + std::string binary_name = skipPathFromFilename(opt.binary_name.c_str()); +#ifdef DOCTEST_PLATFORM_WINDOWS + if(binary_name.rfind(".exe") != std::string::npos) + binary_name = binary_name.substr(0, binary_name.length() - 4); +#endif // DOCTEST_PLATFORM_WINDOWS + xml.startElement("testsuites"); + xml.startElement("testsuite").writeAttribute("name", binary_name) + .writeAttribute("errors", testCaseData.totalErrors) + .writeAttribute("failures", testCaseData.totalFailures) + .writeAttribute("tests", p.numAsserts); + if(opt.no_time_in_output == false) { + xml.writeAttribute("time", testCaseData.totalSeconds); + xml.writeAttribute("timestamp", JUnitTestCaseData::getCurrentTimestamp()); + } + if(opt.no_version == false) + xml.writeAttribute("doctest_version", DOCTEST_VERSION_STR); + + for(const auto& testCase : testCaseData.testcases) { + xml.startElement("testcase") + .writeAttribute("classname", testCase.classname) + .writeAttribute("name", testCase.name); + if(opt.no_time_in_output == false) + xml.writeAttribute("time", testCase.time); + // This is not ideal, but it should be enough to mimic gtest's junit output. + xml.writeAttribute("status", "run"); + + for(const auto& failure : testCase.failures) { + xml.scopedElement("failure") + .writeAttribute("message", failure.message) + .writeAttribute("type", failure.type) + .writeText(failure.details, false); + } + + for(const auto& error : testCase.errors) { + xml.scopedElement("error") + .writeAttribute("message", error.message) + .writeText(error.details); + } + + xml.endElement(); + } + xml.endElement(); + xml.endElement(); + } + + void test_case_start(const TestCaseData& in) override { + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + timer.start(); + } + + void test_case_reenter(const TestCaseData& in) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + + timer.start(); + testCaseData.add(skipPathFromFilename(in.m_file.c_str()), in.m_name); + } + + void test_case_end(const CurrentTestCaseStats&) override { + testCaseData.addTime(timer.getElapsedSeconds()); + testCaseData.appendSubcaseNamesToLastTestcase(deepestSubcaseStackNames); + deepestSubcaseStackNames.clear(); + } + + void test_case_exception(const TestCaseException& e) override { + std::lock_guard lock(mutex); + testCaseData.addError("exception", e.error_string.c_str()); + } + + void subcase_start(const SubcaseSignature& in) override { + std::lock_guard lock(mutex); + deepestSubcaseStackNames.push_back(in.m_name); + } + + void subcase_end() override {} + + void log_assert(const AssertData& rb) override { + if(!rb.m_failed) // report only failures & ignore the `success` option + return; + + std::lock_guard lock(mutex); + + std::ostringstream os; + os << skipPathFromFilename(rb.m_file) << (opt.gnu_file_line ? ":" : "(") + << line(rb.m_line) << (opt.gnu_file_line ? ":" : "):") << std::endl; + + fulltext_log_assert_to_stream(os, rb); + log_contexts(os); + testCaseData.addFailure(rb.m_decomp.c_str(), assertString(rb.m_at), os.str()); + } + + void log_message(const MessageData&) override {} + + void test_case_skipped(const TestCaseData&) override {} + + void log_contexts(std::ostringstream& s) { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + + s << " logged: "; + for(int i = 0; i < num_contexts; ++i) { + s << (i == 0 ? "" : " "); + contexts[i]->stringify(&s); + s << std::endl; + } + } + } + }; + + DOCTEST_REGISTER_REPORTER("junit", 0, JUnitReporter); + + struct Whitespace + { + int nrSpaces; + explicit Whitespace(int nr) + : nrSpaces(nr) {} + }; + + std::ostream& operator<<(std::ostream& out, const Whitespace& ws) { + if(ws.nrSpaces != 0) + out << std::setw(ws.nrSpaces) << ' '; + return out; + } + + struct ConsoleReporter : public IReporter + { + std::ostream& s; + bool hasLoggedCurrentTestStart; + std::vector subcasesStack; + size_t currentSubcaseLevel; + std::mutex mutex; + + // caching pointers/references to objects of these types - safe to do + const ContextOptions& opt; + const TestCaseData* tc; + + ConsoleReporter(const ContextOptions& co) + : s(*co.cout) + , opt(co) {} + + ConsoleReporter(const ContextOptions& co, std::ostream& ostr) + : s(ostr) + , opt(co) {} + + // ========================================================================================= + // WHAT FOLLOWS ARE HELPERS USED BY THE OVERRIDES OF THE VIRTUAL METHODS OF THE INTERFACE + // ========================================================================================= + + void separator_to_stream() { + s << Color::Yellow + << "===============================================================================" + "\n"; + } + + const char* getSuccessOrFailString(bool success, assertType::Enum at, + const char* success_str) { + if(success) + return success_str; + return failureString(at); + } + + Color::Enum getSuccessOrFailColor(bool success, assertType::Enum at) { + return success ? Color::BrightGreen : + (at & assertType::is_warn) ? Color::Yellow : Color::Red; + } + + void successOrFailColoredStringToStream(bool success, assertType::Enum at, + const char* success_str = "SUCCESS") { + s << getSuccessOrFailColor(success, at) + << getSuccessOrFailString(success, at, success_str) << ": "; + } + + void log_contexts() { + int num_contexts = get_num_active_contexts(); + if(num_contexts) { + auto contexts = get_active_contexts(); + + s << Color::None << " logged: "; + for(int i = 0; i < num_contexts; ++i) { + s << (i == 0 ? "" : " "); + contexts[i]->stringify(&s); + s << "\n"; + } + } + + s << "\n"; + } + + // this was requested to be made virtual so users could override it + virtual void file_line_to_stream(const char* file, int line, + const char* tail = "") { + s << Color::LightGrey << skipPathFromFilename(file) << (opt.gnu_file_line ? ":" : "(") + << (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option + << (opt.gnu_file_line ? ":" : "):") << tail; + } + + void logTestStart() { + if(hasLoggedCurrentTestStart) + return; + + separator_to_stream(); + file_line_to_stream(tc->m_file.c_str(), tc->m_line, "\n"); + if(tc->m_description) + s << Color::Yellow << "DESCRIPTION: " << Color::None << tc->m_description << "\n"; + if(tc->m_test_suite && tc->m_test_suite[0] != '\0') + s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n"; + if(strncmp(tc->m_name, " Scenario:", 11) != 0) + s << Color::Yellow << "TEST CASE: "; + s << Color::None << tc->m_name << "\n"; + + for(size_t i = 0; i < currentSubcaseLevel; ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + + if(currentSubcaseLevel != subcasesStack.size()) { + s << Color::Yellow << "\nDEEPEST SUBCASE STACK REACHED (DIFFERENT FROM THE CURRENT ONE):\n" << Color::None; + for(size_t i = 0; i < subcasesStack.size(); ++i) { + if(subcasesStack[i].m_name[0] != '\0') + s << " " << subcasesStack[i].m_name << "\n"; + } + } + + s << "\n"; + + hasLoggedCurrentTestStart = true; + } + + void printVersion() { + if(opt.no_version == false) + s << Color::Cyan << "[doctest] " << Color::None << "doctest version is \"" + << DOCTEST_VERSION_STR << "\"\n"; + } + + void printIntro() { + printVersion(); + s << Color::Cyan << "[doctest] " << Color::None + << "run with \"--" DOCTEST_OPTIONS_PREFIX_DISPLAY "help\" for options\n"; + } + + void printHelp() { + int sizePrefixDisplay = static_cast(strlen(DOCTEST_OPTIONS_PREFIX_DISPLAY)); + printVersion(); + // clang-format off + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "boolean values: \"1/on/yes/true\" or \"0/off/no/false\"\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filter values: \"str1,str2,str3\" (comma separated strings)\n"; + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "filters use wildcards for matching strings\n"; + s << Color::Cyan << "[doctest] " << Color::None; + s << "something passes a filter if any of the strings in a filter matches\n"; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "ALL FLAGS, OPTIONS AND FILTERS ALSO AVAILABLE WITH A \"" DOCTEST_CONFIG_OPTIONS_PREFIX "\" PREFIX!!!\n"; +#endif + s << Color::Cyan << "[doctest]\n" << Color::None; + s << Color::Cyan << "[doctest] " << Color::None; + s << "Query flags - the program quits after them. Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "?, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "help, -" DOCTEST_OPTIONS_PREFIX_DISPLAY "h " + << Whitespace(sizePrefixDisplay*0) << "prints this message\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "v, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "version " + << Whitespace(sizePrefixDisplay*1) << "prints the version\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "c, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "count " + << Whitespace(sizePrefixDisplay*1) << "prints the number of matching tests\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ltc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-cases " + << Whitespace(sizePrefixDisplay*1) << "lists all matching tests by name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-test-suites " + << Whitespace(sizePrefixDisplay*1) << "lists all matching test suites\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "lr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "list-reporters " + << Whitespace(sizePrefixDisplay*1) << "lists all registered reporters\n\n"; + // ================================================================================== << 79 + s << Color::Cyan << "[doctest] " << Color::None; + s << "The available / options/filters are:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-case-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their file\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sfe, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "source-file-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their file\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ts, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite= " + << Whitespace(sizePrefixDisplay*1) << "filters tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "tse, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "test-suite-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT tests by their test suite\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase= " + << Whitespace(sizePrefixDisplay*1) << "filters subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "sce, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-exclude= " + << Whitespace(sizePrefixDisplay*1) << "filters OUT subcases by their name\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "r, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "reporters= " + << Whitespace(sizePrefixDisplay*1) << "reporters to use (console is default)\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "o, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "out= " + << Whitespace(sizePrefixDisplay*1) << "output filename\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ob, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "order-by= " + << Whitespace(sizePrefixDisplay*1) << "how the tests should be ordered\n"; + s << Whitespace(sizePrefixDisplay*3) << " - [file/suite/name/rand/none]\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "rs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "rand-seed= " + << Whitespace(sizePrefixDisplay*1) << "seed for random ordering\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "f, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "first= " + << Whitespace(sizePrefixDisplay*1) << "the first test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "l, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "last= " + << Whitespace(sizePrefixDisplay*1) << "the last test passing the filters to\n"; + s << Whitespace(sizePrefixDisplay*3) << " execute - for range-based execution\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "aa, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "abort-after= " + << Whitespace(sizePrefixDisplay*1) << "stop after failed assertions\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "scfl,--" DOCTEST_OPTIONS_PREFIX_DISPLAY "subcase-filter-levels= " + << Whitespace(sizePrefixDisplay*1) << "apply filters for the first levels\n"; + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "Bool options - can be used like flags and true is assumed. Available:\n\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "s, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "success= " + << Whitespace(sizePrefixDisplay*1) << "include successful assertions in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "cs, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "case-sensitive= " + << Whitespace(sizePrefixDisplay*1) << "filters being treated as case sensitive\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "e, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "exit= " + << Whitespace(sizePrefixDisplay*1) << "exits after the tests finish\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "d, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "duration= " + << Whitespace(sizePrefixDisplay*1) << "prints the time duration of each test\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nt, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-throw= " + << Whitespace(sizePrefixDisplay*1) << "skips exceptions-related assert checks\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ne, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-exitcode= " + << Whitespace(sizePrefixDisplay*1) << "returns (or exits) always with success\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nr, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-run= " + << Whitespace(sizePrefixDisplay*1) << "skips all runtime doctest operations\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nv, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-version= " + << Whitespace(sizePrefixDisplay*1) << "omit the framework version in the output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-colors= " + << Whitespace(sizePrefixDisplay*1) << "disables colors in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "fc, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "force-colors= " + << Whitespace(sizePrefixDisplay*1) << "use colors even when not in a tty\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nb, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-breaks= " + << Whitespace(sizePrefixDisplay*1) << "disables breakpoints in debuggers\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "ns, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-skip= " + << Whitespace(sizePrefixDisplay*1) << "don't skip test cases marked as skip\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "gfl, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "gnu-file-line= " + << Whitespace(sizePrefixDisplay*1) << ":n: vs (n): for line numbers in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "npf, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-path-filenames= " + << Whitespace(sizePrefixDisplay*1) << "only filenames and no paths in output\n"; + s << " -" DOCTEST_OPTIONS_PREFIX_DISPLAY "nln, --" DOCTEST_OPTIONS_PREFIX_DISPLAY "no-line-numbers= " + << Whitespace(sizePrefixDisplay*1) << "0 instead of real line numbers in output\n"; + // ================================================================================== << 79 + // clang-format on + + s << Color::Cyan << "\n[doctest] " << Color::None; + s << "for more information visit the project documentation\n\n"; + } + + void printRegisteredReporters() { + printVersion(); + auto printReporters = [this] (const reporterMap& reporters, const char* type) { + if(reporters.size()) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all registered " << type << "\n"; + for(auto& curr : reporters) + s << "priority: " << std::setw(5) << curr.first.first + << " name: " << curr.first.second << "\n"; + } + }; + printReporters(getListeners(), "listeners"); + printReporters(getReporters(), "reporters"); + } + + void list_query_results() { + separator_to_stream(); + if(opt.count || opt.list_test_cases) { + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + } else if(opt.list_test_suites) { + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "test suites with unskipped test cases passing the current filters: " + << g_cs->numTestSuitesPassingFilters << "\n"; + } + } + + // ========================================================================================= + // WHAT FOLLOWS ARE OVERRIDES OF THE VIRTUAL METHODS OF THE REPORTER INTERFACE + // ========================================================================================= + + void report_query(const QueryData& in) override { + if(opt.version) { + printVersion(); + } else if(opt.help) { + printHelp(); + } else if(opt.list_reporters) { + printRegisteredReporters(); + } else if(opt.count || opt.list_test_cases) { + if(opt.list_test_cases) { + s << Color::Cyan << "[doctest] " << Color::None + << "listing all test case names\n"; + separator_to_stream(); + } + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_name << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + + } else if(opt.list_test_suites) { + s << Color::Cyan << "[doctest] " << Color::None << "listing all test suites\n"; + separator_to_stream(); + + for(unsigned i = 0; i < in.num_data; ++i) + s << Color::None << in.data[i]->m_test_suite << "\n"; + + separator_to_stream(); + + s << Color::Cyan << "[doctest] " << Color::None + << "unskipped test cases passing the current filters: " + << g_cs->numTestCasesPassingFilters << "\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "test suites with unskipped test cases passing the current filters: " + << g_cs->numTestSuitesPassingFilters << "\n"; + } + } + + void test_run_start() override { printIntro(); } + + void test_run_end(const TestRunStats& p) override { + separator_to_stream(); + s << std::dec; + + auto totwidth = int(std::ceil(log10((std::max(p.numTestCasesPassingFilters, static_cast(p.numAsserts))) + 1))); + auto passwidth = int(std::ceil(log10((std::max(p.numTestCasesPassingFilters - p.numTestCasesFailed, static_cast(p.numAsserts - p.numAssertsFailed))) + 1))); + auto failwidth = int(std::ceil(log10((std::max(p.numTestCasesFailed, static_cast(p.numAssertsFailed))) + 1))); + const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0; + s << Color::Cyan << "[doctest] " << Color::None << "test cases: " << std::setw(totwidth) + << p.numTestCasesPassingFilters << " | " + << ((p.numTestCasesPassingFilters == 0 || anythingFailed) ? Color::None : + Color::Green) + << std::setw(passwidth) << p.numTestCasesPassingFilters - p.numTestCasesFailed << " passed" + << Color::None << " | " << (p.numTestCasesFailed > 0 ? Color::Red : Color::None) + << std::setw(failwidth) << p.numTestCasesFailed << " failed" << Color::None << " |"; + if(opt.no_skipped_summary == false) { + const int numSkipped = p.numTestCases - p.numTestCasesPassingFilters; + s << " " << (numSkipped == 0 ? Color::None : Color::Yellow) << numSkipped + << " skipped" << Color::None; + } + s << "\n"; + s << Color::Cyan << "[doctest] " << Color::None << "assertions: " << std::setw(totwidth) + << p.numAsserts << " | " + << ((p.numAsserts == 0 || anythingFailed) ? Color::None : Color::Green) + << std::setw(passwidth) << (p.numAsserts - p.numAssertsFailed) << " passed" << Color::None + << " | " << (p.numAssertsFailed > 0 ? Color::Red : Color::None) << std::setw(failwidth) + << p.numAssertsFailed << " failed" << Color::None << " |\n"; + s << Color::Cyan << "[doctest] " << Color::None + << "Status: " << (p.numTestCasesFailed > 0 ? Color::Red : Color::Green) + << ((p.numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl; + } + + void test_case_start(const TestCaseData& in) override { + hasLoggedCurrentTestStart = false; + tc = ∈ + subcasesStack.clear(); + currentSubcaseLevel = 0; + } + + void test_case_reenter(const TestCaseData&) override { + subcasesStack.clear(); + } + + void test_case_end(const CurrentTestCaseStats& st) override { + if(tc->m_no_output) + return; + + // log the preamble of the test case only if there is something + // else to print - something other than that an assert has failed + if(opt.duration || + (st.failure_flags && st.failure_flags != TestCaseFailureReason::AssertFailure)) + logTestStart(); + + if(opt.duration) + s << Color::None << std::setprecision(6) << std::fixed << st.seconds + << " s: " << tc->m_name << "\n"; + + if(st.failure_flags & TestCaseFailureReason::Timeout) + s << Color::Red << "Test case exceeded time limit of " << std::setprecision(6) + << std::fixed << tc->m_timeout << "!\n"; + + if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedButDidnt) { + s << Color::Red << "Should have failed but didn't! Marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::ShouldHaveFailedAndDid) { + s << Color::Yellow << "Failed as expected so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::CouldHaveFailedAndDid) { + s << Color::Yellow << "Allowed to fail so marking it as not failed\n"; + } else if(st.failure_flags & TestCaseFailureReason::DidntFailExactlyNumTimes) { + s << Color::Red << "Didn't fail exactly " << tc->m_expected_failures + << " times so marking it as failed!\n"; + } else if(st.failure_flags & TestCaseFailureReason::FailedExactlyNumTimes) { + s << Color::Yellow << "Failed exactly " << tc->m_expected_failures + << " times as expected so marking it as not failed!\n"; + } + if(st.failure_flags & TestCaseFailureReason::TooManyFailedAsserts) { + s << Color::Red << "Aborting - too many failed asserts!\n"; + } + s << Color::None; // lgtm [cpp/useless-expression] + } + + void test_case_exception(const TestCaseException& e) override { + if(tc->m_no_output) + return; + + logTestStart(); + + file_line_to_stream(tc->m_file.c_str(), tc->m_line, " "); + successOrFailColoredStringToStream(false, e.is_crash ? assertType::is_require : + assertType::is_check); + s << Color::Red << (e.is_crash ? "test case CRASHED: " : "test case THREW exception: ") + << Color::Cyan << e.error_string << "\n"; + + int num_stringified_contexts = get_num_stringified_contexts(); + if(num_stringified_contexts) { + auto stringified_contexts = get_stringified_contexts(); + s << Color::None << " logged: "; + for(int i = num_stringified_contexts; i > 0; --i) { + s << (i == num_stringified_contexts ? "" : " ") + << stringified_contexts[i - 1] << "\n"; + } + } + s << "\n" << Color::None; + } + + void subcase_start(const SubcaseSignature& subc) override { + std::lock_guard lock(mutex); + subcasesStack.push_back(subc); + ++currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void subcase_end() override { + std::lock_guard lock(mutex); + --currentSubcaseLevel; + hasLoggedCurrentTestStart = false; + } + + void log_assert(const AssertData& rb) override { + if((!rb.m_failed && !opt.success) || tc->m_no_output) + return; + + std::lock_guard lock(mutex); + + logTestStart(); + + file_line_to_stream(rb.m_file, rb.m_line, " "); + successOrFailColoredStringToStream(!rb.m_failed, rb.m_at); + + fulltext_log_assert_to_stream(s, rb); + + log_contexts(); + } + + void log_message(const MessageData& mb) override { + if(tc->m_no_output) + return; + + std::lock_guard lock(mutex); + + logTestStart(); + + file_line_to_stream(mb.m_file, mb.m_line, " "); + s << getSuccessOrFailColor(false, mb.m_severity) + << getSuccessOrFailString(mb.m_severity & assertType::is_warn, mb.m_severity, + "MESSAGE") << ": "; + s << Color::None << mb.m_string << "\n"; + log_contexts(); + } + + void test_case_skipped(const TestCaseData&) override {} + }; + + DOCTEST_REGISTER_REPORTER("console", 0, ConsoleReporter); + +#ifdef DOCTEST_PLATFORM_WINDOWS + struct DebugOutputWindowReporter : public ConsoleReporter + { + DOCTEST_THREAD_LOCAL static std::ostringstream oss; + + DebugOutputWindowReporter(const ContextOptions& co) + : ConsoleReporter(co, oss) {} + +#define DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(func, type, arg) \ + void func(type arg) override { \ + bool with_col = g_no_colors; \ + g_no_colors = false; \ + ConsoleReporter::func(arg); \ + if(oss.tellp() != std::streampos{}) { \ + DOCTEST_OUTPUT_DEBUG_STRING(oss.str().c_str()); \ + oss.str(""); \ + } \ + g_no_colors = with_col; \ + } + + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_start, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_run_end, const TestRunStats&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_start, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_reenter, const TestCaseData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_end, const CurrentTestCaseStats&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_exception, const TestCaseException&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_start, const SubcaseSignature&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(subcase_end, DOCTEST_EMPTY, DOCTEST_EMPTY) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_assert, const AssertData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(log_message, const MessageData&, in) + DOCTEST_DEBUG_OUTPUT_REPORTER_OVERRIDE(test_case_skipped, const TestCaseData&, in) + }; + + DOCTEST_THREAD_LOCAL std::ostringstream DebugOutputWindowReporter::oss; +#endif // DOCTEST_PLATFORM_WINDOWS + + // the implementation of parseOption() + bool parseOptionImpl(int argc, const char* const* argv, const char* pattern, String* value) { + // going from the end to the beginning and stopping on the first occurrence from the end + for(int i = argc; i > 0; --i) { + auto index = i - 1; + auto temp = std::strstr(argv[index], pattern); + if(temp && (value || strlen(temp) == strlen(pattern))) { //!OCLINT prefer early exits and continue + // eliminate matches in which the chars before the option are not '-' + bool noBadCharsFound = true; + auto curr = argv[index]; + while(curr != temp) { + if(*curr++ != '-') { + noBadCharsFound = false; + break; + } + } + if(noBadCharsFound && argv[index][0] == '-') { + if(value) { + // parsing the value of an option + temp += strlen(pattern); + const unsigned len = strlen(temp); + if(len) { + *value = temp; + return true; + } + } else { + // just a flag - no value + return true; + } + } + } + } + return false; + } + + // parses an option and returns the string after the '=' character + bool parseOption(int argc, const char* const* argv, const char* pattern, String* value = nullptr, + const String& defaultVal = String()) { + if(value) + *value = defaultVal; +#ifndef DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + // offset (normally 3 for "dt-") to skip prefix + if(parseOptionImpl(argc, argv, pattern + strlen(DOCTEST_CONFIG_OPTIONS_PREFIX), value)) + return true; +#endif // DOCTEST_CONFIG_NO_UNPREFIXED_OPTIONS + return parseOptionImpl(argc, argv, pattern, value); + } + + // locates a flag on the command line + bool parseFlag(int argc, const char* const* argv, const char* pattern) { + return parseOption(argc, argv, pattern); + } + + // parses a comma separated list of words after a pattern in one of the arguments in argv + bool parseCommaSepArgs(int argc, const char* const* argv, const char* pattern, + std::vector& res) { + String filtersString; + if(parseOption(argc, argv, pattern, &filtersString)) { + // tokenize with "," as a separator + // cppcheck-suppress strtokCalled + DOCTEST_CLANG_SUPPRESS_WARNING_WITH_PUSH("-Wdeprecated-declarations") + auto pch = std::strtok(filtersString.c_str(), ","); // modifies the string + while(pch != nullptr) { + if(strlen(pch)) + res.push_back(pch); + // uses the strtok() internal state to go to the next token + // cppcheck-suppress strtokCalled + pch = std::strtok(nullptr, ","); + } + DOCTEST_CLANG_SUPPRESS_WARNING_POP + return true; + } + return false; + } + + enum optionType + { + option_bool, + option_int + }; + + // parses an int/bool option from the command line + bool parseIntOption(int argc, const char* const* argv, const char* pattern, optionType type, + int& res) { + String parsedValue; + if(!parseOption(argc, argv, pattern, &parsedValue)) + return false; + + if(type == 0) { + // boolean + const char positive[][5] = {"1", "true", "on", "yes"}; // 5 - strlen("true") + 1 + const char negative[][6] = {"0", "false", "off", "no"}; // 6 - strlen("false") + 1 + + // if the value matches any of the positive/negative possibilities + for(unsigned i = 0; i < 4; i++) { + if(parsedValue.compare(positive[i], true) == 0) { + res = 1; //!OCLINT parameter reassignment + return true; + } + if(parsedValue.compare(negative[i], true) == 0) { + res = 0; //!OCLINT parameter reassignment + return true; + } + } + } else { + // integer + // TODO: change this to use std::stoi or something else! currently it uses undefined behavior - assumes '0' on failed parse... + int theInt = std::atoi(parsedValue.c_str()); // NOLINT + if(theInt != 0) { + res = theInt; //!OCLINT parameter reassignment + return true; + } + } + return false; + } +} // namespace + +Context::Context(int argc, const char* const* argv) + : p(new detail::ContextState) { + parseArgs(argc, argv, true); + if(argc) + p->binary_name = argv[0]; +} + +Context::~Context() { + if(g_cs == p) + g_cs = nullptr; + delete p; +} + +void Context::applyCommandLine(int argc, const char* const* argv) { + parseArgs(argc, argv); + if(argc) + p->binary_name = argv[0]; +} + +// parses args +void Context::parseArgs(int argc, const char* const* argv, bool withDefaults) { + using namespace detail; + + // clang-format off + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sf=", p->filters[0]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "source-file-exclude=",p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sfe=", p->filters[1]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ts=", p->filters[2]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-suite-exclude=", p->filters[3]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tse=", p->filters[3]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tc=", p->filters[4]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "test-case-exclude=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "tce=", p->filters[5]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sc=", p->filters[6]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "subcase-exclude=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "sce=", p->filters[7]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "reporters=", p->filters[8]); + parseCommaSepArgs(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "r=", p->filters[8]); + // clang-format on + + int intRes = 0; + String strRes; + +#define DOCTEST_PARSE_AS_BOOL_OR_FLAG(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_bool, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_bool, intRes)) \ + p->var = static_cast(intRes); \ + else if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name) || \ + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname)) \ + p->var = true; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_INT_OPTION(name, sname, var, default) \ + if(parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", option_int, intRes) || \ + parseIntOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", option_int, intRes)) \ + p->var = intRes; \ + else if(withDefaults) \ + p->var = default + +#define DOCTEST_PARSE_STR_OPTION(name, sname, var, default) \ + if(parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX name "=", &strRes, default) || \ + parseOption(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX sname "=", &strRes, default) || \ + withDefaults) \ + p->var = strRes + + // clang-format off + DOCTEST_PARSE_STR_OPTION("out", "o", out, ""); + DOCTEST_PARSE_STR_OPTION("order-by", "ob", order_by, "file"); + DOCTEST_PARSE_INT_OPTION("rand-seed", "rs", rand_seed, 0); + + DOCTEST_PARSE_INT_OPTION("first", "f", first, 0); + DOCTEST_PARSE_INT_OPTION("last", "l", last, UINT_MAX); + + DOCTEST_PARSE_INT_OPTION("abort-after", "aa", abort_after, 0); + DOCTEST_PARSE_INT_OPTION("subcase-filter-levels", "scfl", subcase_filter_levels, INT_MAX); + + DOCTEST_PARSE_AS_BOOL_OR_FLAG("success", "s", success, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("case-sensitive", "cs", case_sensitive, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("exit", "e", exit, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("duration", "d", duration, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-throw", "nt", no_throw, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-exitcode", "ne", no_exitcode, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-run", "nr", no_run, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-version", "nv", no_version, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-colors", "nc", no_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("force-colors", "fc", force_colors, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-breaks", "nb", no_breaks, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skip", "ns", no_skip, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("gnu-file-line", "gfl", gnu_file_line, !bool(DOCTEST_MSVC)); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-path-filenames", "npf", no_path_in_filenames, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-line-numbers", "nln", no_line_numbers, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-debug-output", "ndo", no_debug_output, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-skipped-summary", "nss", no_skipped_summary, false); + DOCTEST_PARSE_AS_BOOL_OR_FLAG("no-time-in-output", "ntio", no_time_in_output, false); + // clang-format on + + if(withDefaults) { + p->help = false; + p->version = false; + p->count = false; + p->list_test_cases = false; + p->list_test_suites = false; + p->list_reporters = false; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "help") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "h") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "?")) { + p->help = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "version") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "v")) { + p->version = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "count") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "c")) { + p->count = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-cases") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "ltc")) { + p->list_test_cases = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-test-suites") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lts")) { + p->list_test_suites = true; + p->exit = true; + } + if(parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "list-reporters") || + parseFlag(argc, argv, DOCTEST_CONFIG_OPTIONS_PREFIX "lr")) { + p->list_reporters = true; + p->exit = true; + } +} + +// allows the user to add procedurally to the filters from the command line +void Context::addFilter(const char* filter, const char* value) { setOption(filter, value); } + +// allows the user to clear all filters from the command line +void Context::clearFilters() { + for(auto& curr : p->filters) + curr.clear(); +} + +// allows the user to override procedurally the int/bool options from the command line +void Context::setOption(const char* option, int value) { + setOption(option, toString(value).c_str()); + // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks) +} + +// allows the user to override procedurally the string options from the command line +void Context::setOption(const char* option, const char* value) { + auto argv = String("-") + option + "=" + value; + auto lvalue = argv.c_str(); + parseArgs(1, &lvalue); +} + +// users should query this in their main() and exit the program if true +bool Context::shouldExit() { return p->exit; } + +void Context::setAsDefaultForAssertsOutOfTestCases() { g_cs = p; } + +void Context::setAssertHandler(detail::assert_handler ah) { p->ah = ah; } + +// the main function that does all the filtering and test running +int Context::run() { + using namespace detail; + + // save the old context state in case such was setup - for using asserts out of a testing context + auto old_cs = g_cs; + // this is the current contest + g_cs = p; + is_running_in_test = true; + + g_no_colors = p->no_colors; + p->resetRunData(); + + // stdout by default + p->cout = &std::cout; + p->cerr = &std::cerr; + + // or to a file if specified + std::fstream fstr; + if(p->out.size()) { + fstr.open(p->out.c_str(), std::fstream::out); + p->cout = &fstr; + } + + FatalConditionHandler::allocateAltStackMem(); + + auto cleanup_and_return = [&]() { + FatalConditionHandler::freeAltStackMem(); + + if(fstr.is_open()) + fstr.close(); + + // restore context + g_cs = old_cs; + is_running_in_test = false; + + // we have to free the reporters which were allocated when the run started + for(auto& curr : p->reporters_currently_used) + delete curr; + p->reporters_currently_used.clear(); + + if(p->numTestCasesFailed && !p->no_exitcode) + return EXIT_FAILURE; + return EXIT_SUCCESS; + }; + + // setup default reporter if none is given through the command line + if(p->filters[8].empty()) + p->filters[8].push_back("console"); + + // check to see if any of the registered reporters has been selected + for(auto& curr : getReporters()) { + if(matchesAny(curr.first.second.c_str(), p->filters[8], false, p->case_sensitive)) + p->reporters_currently_used.push_back(curr.second(*g_cs)); + } + + // TODO: check if there is nothing in reporters_currently_used + + // prepend all listeners + for(auto& curr : getListeners()) + p->reporters_currently_used.insert(p->reporters_currently_used.begin(), curr.second(*g_cs)); + +#ifdef DOCTEST_PLATFORM_WINDOWS + if(isDebuggerActive() && p->no_debug_output == false) + p->reporters_currently_used.push_back(new DebugOutputWindowReporter(*g_cs)); +#endif // DOCTEST_PLATFORM_WINDOWS + + // handle version, help and no_run + if(p->no_run || p->version || p->help || p->list_reporters) { + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, QueryData()); + + return cleanup_and_return(); + } + + std::vector testArray; + for(auto& curr : getRegisteredTests()) + testArray.push_back(&curr); + p->numTestCases = testArray.size(); + + // sort the collected records + if(!testArray.empty()) { + if(p->order_by.compare("file", true) == 0) { + std::sort(testArray.begin(), testArray.end(), fileOrderComparator); + } else if(p->order_by.compare("suite", true) == 0) { + std::sort(testArray.begin(), testArray.end(), suiteOrderComparator); + } else if(p->order_by.compare("name", true) == 0) { + std::sort(testArray.begin(), testArray.end(), nameOrderComparator); + } else if(p->order_by.compare("rand", true) == 0) { + std::srand(p->rand_seed); + + // random_shuffle implementation + const auto first = &testArray[0]; + for(size_t i = testArray.size() - 1; i > 0; --i) { + int idxToSwap = std::rand() % (i + 1); // NOLINT + + const auto temp = first[i]; + + first[i] = first[idxToSwap]; + first[idxToSwap] = temp; + } + } else if(p->order_by.compare("none", true) == 0) { + // means no sorting - beneficial for death tests which call into the executable + // with a specific test case in mind - we don't want to slow down the startup times + } + } + + std::set testSuitesPassingFilt; + + bool query_mode = p->count || p->list_test_cases || p->list_test_suites; + std::vector queryResults; + + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_start, DOCTEST_EMPTY); + + // invoke the registered functions if they match the filter criteria (or just count them) + for(auto& curr : testArray) { + const auto& tc = *curr; + + bool skip_me = false; + if(tc.m_skip && !p->no_skip) + skip_me = true; + + if(!matchesAny(tc.m_file.c_str(), p->filters[0], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_file.c_str(), p->filters[1], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_test_suite, p->filters[2], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_test_suite, p->filters[3], false, p->case_sensitive)) + skip_me = true; + if(!matchesAny(tc.m_name, p->filters[4], true, p->case_sensitive)) + skip_me = true; + if(matchesAny(tc.m_name, p->filters[5], false, p->case_sensitive)) + skip_me = true; + + if(!skip_me) + p->numTestCasesPassingFilters++; + + // skip the test if it is not in the execution range + if((p->last < p->numTestCasesPassingFilters && p->first <= p->last) || + (p->first > p->numTestCasesPassingFilters)) + skip_me = true; + + if(skip_me) { + if(!query_mode) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_skipped, tc); + continue; + } + + // do not execute the test if we are to only count the number of filter passing tests + if(p->count) + continue; + + // print the name of the test and don't execute it + if(p->list_test_cases) { + queryResults.push_back(&tc); + continue; + } + + // print the name of the test suite if not done already and don't execute it + if(p->list_test_suites) { + if((testSuitesPassingFilt.count(tc.m_test_suite) == 0) && tc.m_test_suite[0] != '\0') { + queryResults.push_back(&tc); + testSuitesPassingFilt.insert(tc.m_test_suite); + p->numTestSuitesPassingFilters++; + } + continue; + } + + // execute the test if it passes all the filtering + { + p->currentTest = &tc; + + p->failure_flags = TestCaseFailureReason::None; + p->seconds = 0; + + // reset atomic counters + p->numAssertsFailedCurrentTest_atomic = 0; + p->numAssertsCurrentTest_atomic = 0; + + p->subcasesPassed.clear(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_start, tc); + + p->timer.start(); + + bool run_test = true; + + do { + // reset some of the fields for subcases (except for the set of fully passed ones) + p->should_reenter = false; + p->subcasesCurrentMaxLevel = 0; + p->subcasesStack.clear(); + + p->shouldLogCurrentException = true; + + // reset stuff for logging with INFO() + p->stringifiedContexts.clear(); + +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + try { +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS +// MSVC 2015 diagnoses fatalConditionHandler as unused (because reset() is a static method) +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4101) // unreferenced local variable + FatalConditionHandler fatalConditionHandler; // Handle signals + // execute the test + tc.m_test(); + fatalConditionHandler.reset(); +DOCTEST_MSVC_SUPPRESS_WARNING_POP +#ifndef DOCTEST_CONFIG_NO_EXCEPTIONS + } catch(const TestFailureException&) { + p->failure_flags |= TestCaseFailureReason::AssertFailure; + } catch(...) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_exception, + {translateActiveException(), false}); + p->failure_flags |= TestCaseFailureReason::Exception; + } +#endif // DOCTEST_CONFIG_NO_EXCEPTIONS + + // exit this loop if enough assertions have failed - even if there are more subcases + if(p->abort_after > 0 && + p->numAssertsFailed + p->numAssertsFailedCurrentTest_atomic >= p->abort_after) { + run_test = false; + p->failure_flags |= TestCaseFailureReason::TooManyFailedAsserts; + } + + if(p->should_reenter && run_test) + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_reenter, tc); + if(!p->should_reenter) + run_test = false; + } while(run_test); + + p->finalizeTestCaseData(); + + DOCTEST_ITERATE_THROUGH_REPORTERS(test_case_end, *g_cs); + + p->currentTest = nullptr; + + // stop executing tests if enough assertions have failed + if(p->abort_after > 0 && p->numAssertsFailed >= p->abort_after) + break; + } + } + + if(!query_mode) { + DOCTEST_ITERATE_THROUGH_REPORTERS(test_run_end, *g_cs); + } else { + QueryData qdata; + qdata.run_stats = g_cs; + qdata.data = queryResults.data(); + qdata.num_data = unsigned(queryResults.size()); + DOCTEST_ITERATE_THROUGH_REPORTERS(report_query, qdata); + } + + // see these issues on the reasoning for this: + // - https://github.com/onqtam/doctest/issues/143#issuecomment-414418903 + // - https://github.com/onqtam/doctest/issues/126 + auto DOCTEST_FIX_FOR_MACOS_LIBCPP_IOSFWD_STRING_LINK_ERRORS = []() DOCTEST_NOINLINE + { std::cout << std::string(); }; + DOCTEST_FIX_FOR_MACOS_LIBCPP_IOSFWD_STRING_LINK_ERRORS(); + + return cleanup_and_return(); +} + +IReporter::~IReporter() = default; + +int IReporter::get_num_active_contexts() { return detail::g_infoContexts.size(); } +const IContextScope* const* IReporter::get_active_contexts() { + return get_num_active_contexts() ? &detail::g_infoContexts[0] : nullptr; +} + +int IReporter::get_num_stringified_contexts() { return detail::g_cs->stringifiedContexts.size(); } +const String* IReporter::get_stringified_contexts() { + return get_num_stringified_contexts() ? &detail::g_cs->stringifiedContexts[0] : nullptr; +} + +namespace detail { + void registerReporterImpl(const char* name, int priority, reporterCreatorFunc c, bool isReporter) { + if(isReporter) + getReporters().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + else + getListeners().insert(reporterMap::value_type(reporterMap::key_type(priority, name), c)); + } +} // namespace detail + +} // namespace doctest + +#endif // DOCTEST_CONFIG_DISABLE + +#ifdef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +DOCTEST_MSVC_SUPPRESS_WARNING_WITH_PUSH(4007) // 'function' : must be 'attribute' - see issue #182 +int main(int argc, char** argv) { return doctest::Context(argc, argv).run(); } +DOCTEST_MSVC_SUPPRESS_WARNING_POP +#endif // DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN + +DOCTEST_CLANG_SUPPRESS_WARNING_POP +DOCTEST_MSVC_SUPPRESS_WARNING_POP +DOCTEST_GCC_SUPPRESS_WARNING_POP + +#endif // DOCTEST_LIBRARY_IMPLEMENTATION +#endif // DOCTEST_CONFIG_IMPLEMENT diff --git a/extern/linenoise.hpp b/extern/linenoise.hpp new file mode 100644 index 0000000..ae36eb0 --- /dev/null +++ b/extern/linenoise.hpp @@ -0,0 +1,2415 @@ +/* + * linenoise.hpp -- Multi-platfrom C++ header-only linenoise library. + * + * All credits and commendations have to go to the authors of the + * following excellent libraries. + * + * - linenoise.h and linenose.c (https://github.com/antirez/linenoise) + * - ANSI.c (https://github.com/adoxa/ansicon) + * - Win32_ANSI.h and Win32_ANSI.c (https://github.com/MSOpenTech/redis) + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2015 yhirose + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* linenoise.h -- guerrilla line editing library against the idea that a + * line editing lib needs to be 20,000 lines of C code. + * + * See linenoise.c for more information. + * + * ------------------------------------------------------------------------ + * + * Copyright (c) 2010, Salvatore Sanfilippo + * Copyright (c) 2010, Pieter Noordhuis + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * ANSI.c - ANSI escape sequence console driver. + * + * Copyright (C) 2005-2014 Jason Hood + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the author be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + * + * Jason Hood + * jadoxa@yahoo.com.au + */ + +/* + * Win32_ANSI.h and Win32_ANSI.c + * + * Derived from ANSI.c by Jason Hood, from his ansicon project (https://github.com/adoxa/ansicon), with modifications. + * + * Copyright (c), Microsoft Open Technologies, Inc. + * All rights reserved. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * - Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * - Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef LINENOISE_HPP +#define LINENOISE_HPP + +#ifndef _WIN32 +#include +#include +#include +#else +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#ifndef STDIN_FILENO +#define STDIN_FILENO (_fileno(stdin)) +#endif +#ifndef STDOUT_FILENO +#define STDOUT_FILENO 1 +#endif +#define isatty _isatty +#define write win32_write +#define read _read +#pragma warning(push) +#pragma warning(disable : 4996) +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace linenoise { + +typedef std::function&)> CompletionCallback; + +#ifdef _WIN32 + +namespace ansi { + +#define lenof(array) (sizeof(array)/sizeof(*(array))) + +typedef struct +{ + BYTE foreground; // ANSI base color (0 to 7; add 30) + BYTE background; // ANSI base color (0 to 7; add 40) + BYTE bold; // console FOREGROUND_INTENSITY bit + BYTE underline; // console BACKGROUND_INTENSITY bit + BYTE rvideo; // swap foreground/bold & background/underline + BYTE concealed; // set foreground/bold to background/underline + BYTE reverse; // swap console foreground & background attributes +} GRM, *PGRM; // Graphic Rendition Mode + + +inline bool is_digit(char c) { return '0' <= c && c <= '9'; } + +// ========== Global variables and constants + +HANDLE hConOut; // handle to CONOUT$ + +const char ESC = '\x1B'; // ESCape character +const char BEL = '\x07'; +const char SO = '\x0E'; // Shift Out +const char SI = '\x0F'; // Shift In + +const int MAX_ARG = 16; // max number of args in an escape sequence +int state; // automata state +WCHAR prefix; // escape sequence prefix ( '[', ']' or '(' ); +WCHAR prefix2; // secondary prefix ( '?' or '>' ); +WCHAR suffix; // escape sequence suffix +int es_argc; // escape sequence args count +int es_argv[MAX_ARG]; // escape sequence args +WCHAR Pt_arg[MAX_PATH * 2]; // text parameter for Operating System Command +int Pt_len; +BOOL shifted; + + +// DEC Special Graphics Character Set from +// http://vt100.net/docs/vt220-rm/table2-4.html +// Some of these may not look right, depending on the font and code page (in +// particular, the Control Pictures probably won't work at all). +const WCHAR G1[] = +{ + ' ', // _ - blank + L'\x2666', // ` - Black Diamond Suit + L'\x2592', // a - Medium Shade + L'\x2409', // b - HT + L'\x240c', // c - FF + L'\x240d', // d - CR + L'\x240a', // e - LF + L'\x00b0', // f - Degree Sign + L'\x00b1', // g - Plus-Minus Sign + L'\x2424', // h - NL + L'\x240b', // i - VT + L'\x2518', // j - Box Drawings Light Up And Left + L'\x2510', // k - Box Drawings Light Down And Left + L'\x250c', // l - Box Drawings Light Down And Right + L'\x2514', // m - Box Drawings Light Up And Right + L'\x253c', // n - Box Drawings Light Vertical And Horizontal + L'\x00af', // o - SCAN 1 - Macron + L'\x25ac', // p - SCAN 3 - Black Rectangle + L'\x2500', // q - SCAN 5 - Box Drawings Light Horizontal + L'_', // r - SCAN 7 - Low Line + L'_', // s - SCAN 9 - Low Line + L'\x251c', // t - Box Drawings Light Vertical And Right + L'\x2524', // u - Box Drawings Light Vertical And Left + L'\x2534', // v - Box Drawings Light Up And Horizontal + L'\x252c', // w - Box Drawings Light Down And Horizontal + L'\x2502', // x - Box Drawings Light Vertical + L'\x2264', // y - Less-Than Or Equal To + L'\x2265', // z - Greater-Than Or Equal To + L'\x03c0', // { - Greek Small Letter Pi + L'\x2260', // | - Not Equal To + L'\x00a3', // } - Pound Sign + L'\x00b7', // ~ - Middle Dot +}; + +#define FIRST_G1 '_' +#define LAST_G1 '~' + + +// color constants + +#define FOREGROUND_BLACK 0 +#define FOREGROUND_WHITE FOREGROUND_RED|FOREGROUND_GREEN|FOREGROUND_BLUE + +#define BACKGROUND_BLACK 0 +#define BACKGROUND_WHITE BACKGROUND_RED|BACKGROUND_GREEN|BACKGROUND_BLUE + +const BYTE foregroundcolor[8] = + { + FOREGROUND_BLACK, // black foreground + FOREGROUND_RED, // red foreground + FOREGROUND_GREEN, // green foreground + FOREGROUND_RED | FOREGROUND_GREEN, // yellow foreground + FOREGROUND_BLUE, // blue foreground + FOREGROUND_BLUE | FOREGROUND_RED, // magenta foreground + FOREGROUND_BLUE | FOREGROUND_GREEN, // cyan foreground + FOREGROUND_WHITE // white foreground + }; + +const BYTE backgroundcolor[8] = + { + BACKGROUND_BLACK, // black background + BACKGROUND_RED, // red background + BACKGROUND_GREEN, // green background + BACKGROUND_RED | BACKGROUND_GREEN, // yellow background + BACKGROUND_BLUE, // blue background + BACKGROUND_BLUE | BACKGROUND_RED, // magenta background + BACKGROUND_BLUE | BACKGROUND_GREEN, // cyan background + BACKGROUND_WHITE, // white background + }; + +const BYTE attr2ansi[8] = // map console attribute to ANSI number +{ + 0, // black + 4, // blue + 2, // green + 6, // cyan + 1, // red + 5, // magenta + 3, // yellow + 7 // white +}; + +GRM grm; + +// saved cursor position +COORD SavePos; + +// ========== Print Buffer functions + +#define BUFFER_SIZE 2048 + +int nCharInBuffer; +WCHAR ChBuffer[BUFFER_SIZE]; + +//----------------------------------------------------------------------------- +// FlushBuffer() +// Writes the buffer to the console and empties it. +//----------------------------------------------------------------------------- + +inline void FlushBuffer(void) +{ + DWORD nWritten; + if (nCharInBuffer <= 0) return; + WriteConsoleW(hConOut, ChBuffer, nCharInBuffer, &nWritten, NULL); + nCharInBuffer = 0; +} + +//----------------------------------------------------------------------------- +// PushBuffer( WCHAR c ) +// Adds a character in the buffer. +//----------------------------------------------------------------------------- + +inline void PushBuffer(WCHAR c) +{ + if (shifted && c >= FIRST_G1 && c <= LAST_G1) + c = G1[c - FIRST_G1]; + ChBuffer[nCharInBuffer] = c; + if (++nCharInBuffer == BUFFER_SIZE) + FlushBuffer(); +} + +//----------------------------------------------------------------------------- +// SendSequence( LPCWSTR seq ) +// Send the string to the input buffer. +//----------------------------------------------------------------------------- + +inline void SendSequence(LPCWSTR seq) +{ + DWORD out; + INPUT_RECORD in; + HANDLE hStdIn = GetStdHandle(STD_INPUT_HANDLE); + + in.EventType = KEY_EVENT; + in.Event.KeyEvent.bKeyDown = TRUE; + in.Event.KeyEvent.wRepeatCount = 1; + in.Event.KeyEvent.wVirtualKeyCode = 0; + in.Event.KeyEvent.wVirtualScanCode = 0; + in.Event.KeyEvent.dwControlKeyState = 0; + for (; *seq; ++seq) + { + in.Event.KeyEvent.uChar.UnicodeChar = *seq; + WriteConsoleInput(hStdIn, &in, 1, &out); + } +} + +// ========== Print functions + +//----------------------------------------------------------------------------- +// InterpretEscSeq() +// Interprets the last escape sequence scanned by ParseAndPrintANSIString +// prefix escape sequence prefix +// es_argc escape sequence args count +// es_argv[] escape sequence args array +// suffix escape sequence suffix +// +// for instance, with \e[33;45;1m we have +// prefix = '[', +// es_argc = 3, es_argv[0] = 33, es_argv[1] = 45, es_argv[2] = 1 +// suffix = 'm' +//----------------------------------------------------------------------------- + +inline void InterpretEscSeq(void) +{ + int i; + WORD attribut; + CONSOLE_SCREEN_BUFFER_INFO Info; + CONSOLE_CURSOR_INFO CursInfo; + DWORD len, NumberOfCharsWritten; + COORD Pos; + SMALL_RECT Rect; + CHAR_INFO CharInfo; + + if (prefix == '[') + { + if (prefix2 == '?' && (suffix == 'h' || suffix == 'l')) + { + if (es_argc == 1 && es_argv[0] == 25) + { + GetConsoleCursorInfo(hConOut, &CursInfo); + CursInfo.bVisible = (suffix == 'h'); + SetConsoleCursorInfo(hConOut, &CursInfo); + return; + } + } + // Ignore any other \e[? or \e[> sequences. + if (prefix2 != 0) + return; + + GetConsoleScreenBufferInfo(hConOut, &Info); + switch (suffix) + { + case 'm': + if (es_argc == 0) es_argv[es_argc++] = 0; + for (i = 0; i < es_argc; i++) + { + if (30 <= es_argv[i] && es_argv[i] <= 37) + grm.foreground = es_argv[i] - 30; + else if (40 <= es_argv[i] && es_argv[i] <= 47) + grm.background = es_argv[i] - 40; + else switch (es_argv[i]) + { + case 0: + case 39: + case 49: + { + WCHAR def[4]; + int a; + *def = '7'; def[1] = '\0'; + GetEnvironmentVariableW(L"ANSICON_DEF", def, lenof(def)); + a = wcstol(def, NULL, 16); + grm.reverse = FALSE; + if (a < 0) + { + grm.reverse = TRUE; + a = -a; + } + if (es_argv[i] != 49) + grm.foreground = attr2ansi[a & 7]; + if (es_argv[i] != 39) + grm.background = attr2ansi[(a >> 4) & 7]; + if (es_argv[i] == 0) + { + if (es_argc == 1) + { + grm.bold = a & FOREGROUND_INTENSITY; + grm.underline = a & BACKGROUND_INTENSITY; + } + else + { + grm.bold = 0; + grm.underline = 0; + } + grm.rvideo = 0; + grm.concealed = 0; + } + } + break; + + case 1: grm.bold = FOREGROUND_INTENSITY; break; + case 5: // blink + case 4: grm.underline = BACKGROUND_INTENSITY; break; + case 7: grm.rvideo = 1; break; + case 8: grm.concealed = 1; break; + case 21: // oops, this actually turns on double underline + case 22: grm.bold = 0; break; + case 25: + case 24: grm.underline = 0; break; + case 27: grm.rvideo = 0; break; + case 28: grm.concealed = 0; break; + } + } + if (grm.concealed) + { + if (grm.rvideo) + { + attribut = foregroundcolor[grm.foreground] + | backgroundcolor[grm.foreground]; + if (grm.bold) + attribut |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; + } + else + { + attribut = foregroundcolor[grm.background] + | backgroundcolor[grm.background]; + if (grm.underline) + attribut |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; + } + } + else if (grm.rvideo) + { + attribut = foregroundcolor[grm.background] + | backgroundcolor[grm.foreground]; + if (grm.bold) + attribut |= BACKGROUND_INTENSITY; + if (grm.underline) + attribut |= FOREGROUND_INTENSITY; + } + else + attribut = foregroundcolor[grm.foreground] | grm.bold + | backgroundcolor[grm.background] | grm.underline; + if (grm.reverse) + attribut = ((attribut >> 4) & 15) | ((attribut & 15) << 4); + SetConsoleTextAttribute(hConOut, attribut); + return; + + case 'J': + if (es_argc == 0) es_argv[es_argc++] = 0; // ESC[J == ESC[0J + if (es_argc != 1) return; + switch (es_argv[0]) + { + case 0: // ESC[0J erase from cursor to end of display + len = (Info.dwSize.Y - Info.dwCursorPosition.Y - 1) * Info.dwSize.X + + Info.dwSize.X - Info.dwCursorPosition.X - 1; + FillConsoleOutputCharacter(hConOut, ' ', len, + Info.dwCursorPosition, + &NumberOfCharsWritten); + FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, + Info.dwCursorPosition, + &NumberOfCharsWritten); + return; + + case 1: // ESC[1J erase from start to cursor. + Pos.X = 0; + Pos.Y = 0; + len = Info.dwCursorPosition.Y * Info.dwSize.X + + Info.dwCursorPosition.X + 1; + FillConsoleOutputCharacter(hConOut, ' ', len, Pos, + &NumberOfCharsWritten); + FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, Pos, + &NumberOfCharsWritten); + return; + + case 2: // ESC[2J Clear screen and home cursor + Pos.X = 0; + Pos.Y = 0; + len = Info.dwSize.X * Info.dwSize.Y; + FillConsoleOutputCharacter(hConOut, ' ', len, Pos, + &NumberOfCharsWritten); + FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, Pos, + &NumberOfCharsWritten); + SetConsoleCursorPosition(hConOut, Pos); + return; + + default: + return; + } + + case 'K': + if (es_argc == 0) es_argv[es_argc++] = 0; // ESC[K == ESC[0K + if (es_argc != 1) return; + switch (es_argv[0]) + { + case 0: // ESC[0K Clear to end of line + len = Info.dwSize.X - Info.dwCursorPosition.X + 1; + FillConsoleOutputCharacter(hConOut, ' ', len, + Info.dwCursorPosition, + &NumberOfCharsWritten); + FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, + Info.dwCursorPosition, + &NumberOfCharsWritten); + return; + + case 1: // ESC[1K Clear from start of line to cursor + Pos.X = 0; + Pos.Y = Info.dwCursorPosition.Y; + FillConsoleOutputCharacter(hConOut, ' ', + Info.dwCursorPosition.X + 1, Pos, + &NumberOfCharsWritten); + FillConsoleOutputAttribute(hConOut, Info.wAttributes, + Info.dwCursorPosition.X + 1, Pos, + &NumberOfCharsWritten); + return; + + case 2: // ESC[2K Clear whole line. + Pos.X = 0; + Pos.Y = Info.dwCursorPosition.Y; + FillConsoleOutputCharacter(hConOut, ' ', Info.dwSize.X, Pos, + &NumberOfCharsWritten); + FillConsoleOutputAttribute(hConOut, Info.wAttributes, + Info.dwSize.X, Pos, + &NumberOfCharsWritten); + return; + + default: + return; + } + + case 'X': // ESC[#X Erase # characters. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[X == ESC[1X + if (es_argc != 1) return; + FillConsoleOutputCharacter(hConOut, ' ', es_argv[0], + Info.dwCursorPosition, + &NumberOfCharsWritten); + FillConsoleOutputAttribute(hConOut, Info.wAttributes, es_argv[0], + Info.dwCursorPosition, + &NumberOfCharsWritten); + return; + + case 'L': // ESC[#L Insert # blank lines. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[L == ESC[1L + if (es_argc != 1) return; + Rect.Left = 0; + Rect.Top = Info.dwCursorPosition.Y; + Rect.Right = Info.dwSize.X - 1; + Rect.Bottom = Info.dwSize.Y - 1; + Pos.X = 0; + Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; + CharInfo.Char.UnicodeChar = ' '; + CharInfo.Attributes = Info.wAttributes; + ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); + return; + + case 'M': // ESC[#M Delete # lines. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[M == ESC[1M + if (es_argc != 1) return; + if (es_argv[0] > Info.dwSize.Y - Info.dwCursorPosition.Y) + es_argv[0] = Info.dwSize.Y - Info.dwCursorPosition.Y; + Rect.Left = 0; + Rect.Top = Info.dwCursorPosition.Y + es_argv[0]; + Rect.Right = Info.dwSize.X - 1; + Rect.Bottom = Info.dwSize.Y - 1; + Pos.X = 0; + Pos.Y = Info.dwCursorPosition.Y; + CharInfo.Char.UnicodeChar = ' '; + CharInfo.Attributes = Info.wAttributes; + ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); + return; + + case 'P': // ESC[#P Delete # characters. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[P == ESC[1P + if (es_argc != 1) return; + if (Info.dwCursorPosition.X + es_argv[0] > Info.dwSize.X - 1) + es_argv[0] = Info.dwSize.X - Info.dwCursorPosition.X; + Rect.Left = Info.dwCursorPosition.X + es_argv[0]; + Rect.Top = Info.dwCursorPosition.Y; + Rect.Right = Info.dwSize.X - 1; + Rect.Bottom = Info.dwCursorPosition.Y; + CharInfo.Char.UnicodeChar = ' '; + CharInfo.Attributes = Info.wAttributes; + ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Info.dwCursorPosition, + &CharInfo); + return; + + case '@': // ESC[#@ Insert # blank characters. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[@ == ESC[1@ + if (es_argc != 1) return; + if (Info.dwCursorPosition.X + es_argv[0] > Info.dwSize.X - 1) + es_argv[0] = Info.dwSize.X - Info.dwCursorPosition.X; + Rect.Left = Info.dwCursorPosition.X; + Rect.Top = Info.dwCursorPosition.Y; + Rect.Right = Info.dwSize.X - 1 - es_argv[0]; + Rect.Bottom = Info.dwCursorPosition.Y; + Pos.X = Info.dwCursorPosition.X + es_argv[0]; + Pos.Y = Info.dwCursorPosition.Y; + CharInfo.Char.UnicodeChar = ' '; + CharInfo.Attributes = Info.wAttributes; + ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); + return; + + case 'k': // ESC[#k + case 'A': // ESC[#A Moves cursor up # lines + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[A == ESC[1A + if (es_argc != 1) return; + Pos.Y = Info.dwCursorPosition.Y - es_argv[0]; + if (Pos.Y < 0) Pos.Y = 0; + Pos.X = Info.dwCursorPosition.X; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 'e': // ESC[#e + case 'B': // ESC[#B Moves cursor down # lines + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[B == ESC[1B + if (es_argc != 1) return; + Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; + if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; + Pos.X = Info.dwCursorPosition.X; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 'a': // ESC[#a + case 'C': // ESC[#C Moves cursor forward # spaces + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[C == ESC[1C + if (es_argc != 1) return; + Pos.X = Info.dwCursorPosition.X + es_argv[0]; + if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; + Pos.Y = Info.dwCursorPosition.Y; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 'j': // ESC[#j + case 'D': // ESC[#D Moves cursor back # spaces + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[D == ESC[1D + if (es_argc != 1) return; + Pos.X = Info.dwCursorPosition.X - es_argv[0]; + if (Pos.X < 0) Pos.X = 0; + Pos.Y = Info.dwCursorPosition.Y; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 'E': // ESC[#E Moves cursor down # lines, column 1. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[E == ESC[1E + if (es_argc != 1) return; + Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; + if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; + Pos.X = 0; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 'F': // ESC[#F Moves cursor up # lines, column 1. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[F == ESC[1F + if (es_argc != 1) return; + Pos.Y = Info.dwCursorPosition.Y - es_argv[0]; + if (Pos.Y < 0) Pos.Y = 0; + Pos.X = 0; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case '`': // ESC[#` + case 'G': // ESC[#G Moves cursor column # in current row. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[G == ESC[1G + if (es_argc != 1) return; + Pos.X = es_argv[0] - 1; + if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; + if (Pos.X < 0) Pos.X = 0; + Pos.Y = Info.dwCursorPosition.Y; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 'd': // ESC[#d Moves cursor row #, current column. + if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[d == ESC[1d + if (es_argc != 1) return; + Pos.Y = es_argv[0] - 1; + if (Pos.Y < 0) Pos.Y = 0; + if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 'f': // ESC[#;#f + case 'H': // ESC[#;#H Moves cursor to line #, column # + if (es_argc == 0) + es_argv[es_argc++] = 1; // ESC[H == ESC[1;1H + if (es_argc == 1) + es_argv[es_argc++] = 1; // ESC[#H == ESC[#;1H + if (es_argc > 2) return; + Pos.X = es_argv[1] - 1; + if (Pos.X < 0) Pos.X = 0; + if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; + Pos.Y = es_argv[0] - 1; + if (Pos.Y < 0) Pos.Y = 0; + if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; + SetConsoleCursorPosition(hConOut, Pos); + return; + + case 's': // ESC[s Saves cursor position for recall later + if (es_argc != 0) return; + SavePos = Info.dwCursorPosition; + return; + + case 'u': // ESC[u Return to saved cursor position + if (es_argc != 0) return; + SetConsoleCursorPosition(hConOut, SavePos); + return; + + case 'n': // ESC[#n Device status report + if (es_argc != 1) return; // ESC[n == ESC[0n -> ignored + switch (es_argv[0]) + { + case 5: // ESC[5n Report status + SendSequence(L"\33[0n"); // "OK" + return; + + case 6: // ESC[6n Report cursor position + { + WCHAR buf[32]; + swprintf(buf, 32, L"\33[%d;%dR", Info.dwCursorPosition.Y + 1, + Info.dwCursorPosition.X + 1); + SendSequence(buf); + } + return; + + default: + return; + } + + case 't': // ESC[#t Window manipulation + if (es_argc != 1) return; + if (es_argv[0] == 21) // ESC[21t Report xterm window's title + { + WCHAR buf[MAX_PATH * 2]; + DWORD len = GetConsoleTitleW(buf + 3, lenof(buf) - 3 - 2); + // Too bad if it's too big or fails. + buf[0] = ESC; + buf[1] = ']'; + buf[2] = 'l'; + buf[3 + len] = ESC; + buf[3 + len + 1] = '\\'; + buf[3 + len + 2] = '\0'; + SendSequence(buf); + } + return; + + default: + return; + } + } + else // (prefix == ']') + { + // Ignore any \e]? or \e]> sequences. + if (prefix2 != 0) + return; + + if (es_argc == 1 && es_argv[0] == 0) // ESC]0;titleST + { + SetConsoleTitleW(Pt_arg); + } + } +} + +//----------------------------------------------------------------------------- +// ParseAndPrintANSIString(hDev, lpBuffer, nNumberOfBytesToWrite) +// Parses the string lpBuffer, interprets the escapes sequences and prints the +// characters in the device hDev (console). +// The lexer is a three states automata. +// If the number of arguments es_argc > MAX_ARG, only the MAX_ARG-1 firsts and +// the last arguments are processed (no es_argv[] overflow). +//----------------------------------------------------------------------------- + +inline BOOL ParseAndPrintANSIString(HANDLE hDev, LPCVOID lpBuffer, DWORD nNumberOfBytesToWrite, LPDWORD lpNumberOfBytesWritten) +{ + DWORD i; + LPCSTR s; + + if (hDev != hConOut) // reinit if device has changed + { + hConOut = hDev; + state = 1; + shifted = FALSE; + } + for (i = nNumberOfBytesToWrite, s = (LPCSTR)lpBuffer; i > 0; i--, s++) + { + if (state == 1) + { + if (*s == ESC) state = 2; + else if (*s == SO) shifted = TRUE; + else if (*s == SI) shifted = FALSE; + else PushBuffer(*s); + } + else if (state == 2) + { + if (*s == ESC); // \e\e...\e == \e + else if ((*s == '[') || (*s == ']')) + { + FlushBuffer(); + prefix = *s; + prefix2 = 0; + state = 3; + Pt_len = 0; + *Pt_arg = '\0'; + } + else if (*s == ')' || *s == '(') state = 6; + else state = 1; + } + else if (state == 3) + { + if (is_digit(*s)) + { + es_argc = 0; + es_argv[0] = *s - '0'; + state = 4; + } + else if (*s == ';') + { + es_argc = 1; + es_argv[0] = 0; + es_argv[1] = 0; + state = 4; + } + else if (*s == '?' || *s == '>') + { + prefix2 = *s; + } + else + { + es_argc = 0; + suffix = *s; + InterpretEscSeq(); + state = 1; + } + } + else if (state == 4) + { + if (is_digit(*s)) + { + es_argv[es_argc] = 10 * es_argv[es_argc] + (*s - '0'); + } + else if (*s == ';') + { + if (es_argc < MAX_ARG - 1) es_argc++; + es_argv[es_argc] = 0; + if (prefix == ']') + state = 5; + } + else + { + es_argc++; + suffix = *s; + InterpretEscSeq(); + state = 1; + } + } + else if (state == 5) + { + if (*s == BEL) + { + Pt_arg[Pt_len] = '\0'; + InterpretEscSeq(); + state = 1; + } + else if (*s == '\\' && Pt_len > 0 && Pt_arg[Pt_len - 1] == ESC) + { + Pt_arg[--Pt_len] = '\0'; + InterpretEscSeq(); + state = 1; + } + else if (Pt_len < lenof(Pt_arg) - 1) + Pt_arg[Pt_len++] = *s; + } + else if (state == 6) + { + // Ignore it (ESC ) 0 is implicit; nothing else is supported). + state = 1; + } + } + FlushBuffer(); + if (lpNumberOfBytesWritten != NULL) + *lpNumberOfBytesWritten = nNumberOfBytesToWrite - i; + return (i == 0); +} + +} // namespace ansi + +HANDLE hOut; +HANDLE hIn; +DWORD consolemodeIn = 0; + +inline int win32read(int *c) { + DWORD foo; + INPUT_RECORD b; + KEY_EVENT_RECORD e; + BOOL altgr; + + while (1) { + if (!ReadConsoleInput(hIn, &b, 1, &foo)) return 0; + if (!foo) return 0; + + if (b.EventType == KEY_EVENT && b.Event.KeyEvent.bKeyDown) { + + e = b.Event.KeyEvent; + *c = b.Event.KeyEvent.uChar.AsciiChar; + + altgr = e.dwControlKeyState & (LEFT_CTRL_PRESSED | RIGHT_ALT_PRESSED); + + if (e.dwControlKeyState & (LEFT_CTRL_PRESSED | RIGHT_CTRL_PRESSED) && !altgr) { + + /* Ctrl+Key */ + switch (*c) { + case 'D': + *c = 4; + return 1; + case 'C': + *c = 3; + return 1; + case 'H': + *c = 8; + return 1; + case 'T': + *c = 20; + return 1; + case 'B': /* ctrl-b, left_arrow */ + *c = 2; + return 1; + case 'F': /* ctrl-f right_arrow*/ + *c = 6; + return 1; + case 'P': /* ctrl-p up_arrow*/ + *c = 16; + return 1; + case 'N': /* ctrl-n down_arrow*/ + *c = 14; + return 1; + case 'U': /* Ctrl+u, delete the whole line. */ + *c = 21; + return 1; + case 'K': /* Ctrl+k, delete from current to end of line. */ + *c = 11; + return 1; + case 'A': /* Ctrl+a, go to the start of the line */ + *c = 1; + return 1; + case 'E': /* ctrl+e, go to the end of the line */ + *c = 5; + return 1; + } + + /* Other Ctrl+KEYs ignored */ + } else { + + switch (e.wVirtualKeyCode) { + + case VK_ESCAPE: /* ignore - send ctrl-c, will return -1 */ + *c = 3; + return 1; + case VK_RETURN: /* enter */ + *c = 13; + return 1; + case VK_LEFT: /* left */ + *c = 2; + return 1; + case VK_RIGHT: /* right */ + *c = 6; + return 1; + case VK_UP: /* up */ + *c = 16; + return 1; + case VK_DOWN: /* down */ + *c = 14; + return 1; + case VK_HOME: + *c = 1; + return 1; + case VK_END: + *c = 5; + return 1; + case VK_BACK: + *c = 8; + return 1; + case VK_DELETE: + *c = 4; /* same as Ctrl+D above */ + return 1; + default: + if (*c) return 1; + } + } + } + } + + return -1; /* Makes compiler happy */ +} + +inline int win32_write(int fd, const void *buffer, unsigned int count) { + if (fd == _fileno(stdout)) { + DWORD bytesWritten = 0; + if (FALSE != ansi::ParseAndPrintANSIString(GetStdHandle(STD_OUTPUT_HANDLE), buffer, (DWORD)count, &bytesWritten)) { + return (int)bytesWritten; + } else { + errno = GetLastError(); + return 0; + } + } else if (fd == _fileno(stderr)) { + DWORD bytesWritten = 0; + if (FALSE != ansi::ParseAndPrintANSIString(GetStdHandle(STD_ERROR_HANDLE), buffer, (DWORD)count, &bytesWritten)) { + return (int)bytesWritten; + } else { + errno = GetLastError(); + return 0; + } + } else { + return _write(fd, buffer, count); + } +} +#endif // _WIN32 + +#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 +#define LINENOISE_MAX_LINE 4096 +static const char *unsupported_term[] = {"dumb","cons25","emacs",NULL}; +static CompletionCallback completionCallback; + +#ifndef _WIN32 +static struct termios orig_termios; /* In order to restore at exit.*/ +#endif +static bool rawmode = false; /* For atexit() function to check if restore is needed*/ +static bool mlmode = false; /* Multi line mode. Default is single line. */ +static bool atexit_registered = false; /* Register atexit just 1 time. */ +static size_t history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; +static std::vector history; + +/* The linenoiseState structure represents the state during line editing. + * We pass this state to functions implementing specific editing + * functionalities. */ +struct linenoiseState { + int ifd; /* Terminal stdin file descriptor. */ + int ofd; /* Terminal stdout file descriptor. */ + char *buf; /* Edited line buffer. */ + int buflen; /* Edited line buffer size. */ + std::string prompt; /* Prompt to display. */ + int pos; /* Current cursor position. */ + int oldcolpos; /* Previous refresh cursor column position. */ + int len; /* Current edited line length. */ + int cols; /* Number of columns in terminal. */ + int maxrows; /* Maximum num of rows used so far (multiline mode) */ + int history_index; /* The history index we are currently editing. */ +}; + +enum KEY_ACTION { + KEY_NULL = 0, /* NULL */ + CTRL_A = 1, /* Ctrl+a */ + CTRL_B = 2, /* Ctrl-b */ + CTRL_C = 3, /* Ctrl-c */ + CTRL_D = 4, /* Ctrl-d */ + CTRL_E = 5, /* Ctrl-e */ + CTRL_F = 6, /* Ctrl-f */ + CTRL_H = 8, /* Ctrl-h */ + TAB = 9, /* Tab */ + CTRL_K = 11, /* Ctrl+k */ + CTRL_L = 12, /* Ctrl+l */ + ENTER = 13, /* Enter */ + CTRL_N = 14, /* Ctrl-n */ + CTRL_P = 16, /* Ctrl-p */ + CTRL_T = 20, /* Ctrl-t */ + CTRL_U = 21, /* Ctrl+u */ + CTRL_W = 23, /* Ctrl+w */ + ESC = 27, /* Escape */ + BACKSPACE = 127 /* Backspace */ +}; + +void linenoiseAtExit(void); +bool AddHistory(const char *line); +void refreshLine(struct linenoiseState *l); + +/* ============================ UTF8 utilities ============================== */ + +static unsigned long unicodeWideCharTable[][2] = { + { 0x1100, 0x115F }, { 0x2329, 0x232A }, { 0x2E80, 0x2E99, }, { 0x2E9B, 0x2EF3, }, + { 0x2F00, 0x2FD5, }, { 0x2FF0, 0x2FFB, }, { 0x3000, 0x303E, }, { 0x3041, 0x3096, }, + { 0x3099, 0x30FF, }, { 0x3105, 0x312D, }, { 0x3131, 0x318E, }, { 0x3190, 0x31BA, }, + { 0x31C0, 0x31E3, }, { 0x31F0, 0x321E, }, { 0x3220, 0x3247, }, { 0x3250, 0x4DBF, }, + { 0x4E00, 0xA48C, }, { 0xA490, 0xA4C6, }, { 0xA960, 0xA97C, }, { 0xAC00, 0xD7A3, }, + { 0xF900, 0xFAFF, }, { 0xFE10, 0xFE19, }, { 0xFE30, 0xFE52, }, { 0xFE54, 0xFE66, }, + { 0xFE68, 0xFE6B, }, { 0xFF01, 0xFFE6, }, + { 0x1B000, 0x1B001, }, { 0x1F200, 0x1F202, }, { 0x1F210, 0x1F23A, }, + { 0x1F240, 0x1F248, }, { 0x1F250, 0x1F251, }, { 0x20000, 0x3FFFD, }, +}; + +static int unicodeWideCharTableSize = sizeof(unicodeWideCharTable) / sizeof(unicodeWideCharTable[0]); + +static int unicodeIsWideChar(unsigned long cp) +{ + int i; + for (i = 0; i < unicodeWideCharTableSize; i++) { + if (unicodeWideCharTable[i][0] <= cp && cp <= unicodeWideCharTable[i][1]) { + return 1; + } + } + return 0; +} + +static unsigned long unicodeCombiningCharTable[] = { + 0x0300,0x0301,0x0302,0x0303,0x0304,0x0305,0x0306,0x0307, + 0x0308,0x0309,0x030A,0x030B,0x030C,0x030D,0x030E,0x030F, + 0x0310,0x0311,0x0312,0x0313,0x0314,0x0315,0x0316,0x0317, + 0x0318,0x0319,0x031A,0x031B,0x031C,0x031D,0x031E,0x031F, + 0x0320,0x0321,0x0322,0x0323,0x0324,0x0325,0x0326,0x0327, + 0x0328,0x0329,0x032A,0x032B,0x032C,0x032D,0x032E,0x032F, + 0x0330,0x0331,0x0332,0x0333,0x0334,0x0335,0x0336,0x0337, + 0x0338,0x0339,0x033A,0x033B,0x033C,0x033D,0x033E,0x033F, + 0x0340,0x0341,0x0342,0x0343,0x0344,0x0345,0x0346,0x0347, + 0x0348,0x0349,0x034A,0x034B,0x034C,0x034D,0x034E,0x034F, + 0x0350,0x0351,0x0352,0x0353,0x0354,0x0355,0x0356,0x0357, + 0x0358,0x0359,0x035A,0x035B,0x035C,0x035D,0x035E,0x035F, + 0x0360,0x0361,0x0362,0x0363,0x0364,0x0365,0x0366,0x0367, + 0x0368,0x0369,0x036A,0x036B,0x036C,0x036D,0x036E,0x036F, + 0x0483,0x0484,0x0485,0x0486,0x0487,0x0591,0x0592,0x0593, + 0x0594,0x0595,0x0596,0x0597,0x0598,0x0599,0x059A,0x059B, + 0x059C,0x059D,0x059E,0x059F,0x05A0,0x05A1,0x05A2,0x05A3, + 0x05A4,0x05A5,0x05A6,0x05A7,0x05A8,0x05A9,0x05AA,0x05AB, + 0x05AC,0x05AD,0x05AE,0x05AF,0x05B0,0x05B1,0x05B2,0x05B3, + 0x05B4,0x05B5,0x05B6,0x05B7,0x05B8,0x05B9,0x05BA,0x05BB, + 0x05BC,0x05BD,0x05BF,0x05C1,0x05C2,0x05C4,0x05C5,0x05C7, + 0x0610,0x0611,0x0612,0x0613,0x0614,0x0615,0x0616,0x0617, + 0x0618,0x0619,0x061A,0x064B,0x064C,0x064D,0x064E,0x064F, + 0x0650,0x0651,0x0652,0x0653,0x0654,0x0655,0x0656,0x0657, + 0x0658,0x0659,0x065A,0x065B,0x065C,0x065D,0x065E,0x065F, + 0x0670,0x06D6,0x06D7,0x06D8,0x06D9,0x06DA,0x06DB,0x06DC, + 0x06DF,0x06E0,0x06E1,0x06E2,0x06E3,0x06E4,0x06E7,0x06E8, + 0x06EA,0x06EB,0x06EC,0x06ED,0x0711,0x0730,0x0731,0x0732, + 0x0733,0x0734,0x0735,0x0736,0x0737,0x0738,0x0739,0x073A, + 0x073B,0x073C,0x073D,0x073E,0x073F,0x0740,0x0741,0x0742, + 0x0743,0x0744,0x0745,0x0746,0x0747,0x0748,0x0749,0x074A, + 0x07A6,0x07A7,0x07A8,0x07A9,0x07AA,0x07AB,0x07AC,0x07AD, + 0x07AE,0x07AF,0x07B0,0x07EB,0x07EC,0x07ED,0x07EE,0x07EF, + 0x07F0,0x07F1,0x07F2,0x07F3,0x0816,0x0817,0x0818,0x0819, + 0x081B,0x081C,0x081D,0x081E,0x081F,0x0820,0x0821,0x0822, + 0x0823,0x0825,0x0826,0x0827,0x0829,0x082A,0x082B,0x082C, + 0x082D,0x0859,0x085A,0x085B,0x08E3,0x08E4,0x08E5,0x08E6, + 0x08E7,0x08E8,0x08E9,0x08EA,0x08EB,0x08EC,0x08ED,0x08EE, + 0x08EF,0x08F0,0x08F1,0x08F2,0x08F3,0x08F4,0x08F5,0x08F6, + 0x08F7,0x08F8,0x08F9,0x08FA,0x08FB,0x08FC,0x08FD,0x08FE, + 0x08FF,0x0900,0x0901,0x0902,0x093A,0x093C,0x0941,0x0942, + 0x0943,0x0944,0x0945,0x0946,0x0947,0x0948,0x094D,0x0951, + 0x0952,0x0953,0x0954,0x0955,0x0956,0x0957,0x0962,0x0963, + 0x0981,0x09BC,0x09C1,0x09C2,0x09C3,0x09C4,0x09CD,0x09E2, + 0x09E3,0x0A01,0x0A02,0x0A3C,0x0A41,0x0A42,0x0A47,0x0A48, + 0x0A4B,0x0A4C,0x0A4D,0x0A51,0x0A70,0x0A71,0x0A75,0x0A81, + 0x0A82,0x0ABC,0x0AC1,0x0AC2,0x0AC3,0x0AC4,0x0AC5,0x0AC7, + 0x0AC8,0x0ACD,0x0AE2,0x0AE3,0x0B01,0x0B3C,0x0B3F,0x0B41, + 0x0B42,0x0B43,0x0B44,0x0B4D,0x0B56,0x0B62,0x0B63,0x0B82, + 0x0BC0,0x0BCD,0x0C00,0x0C3E,0x0C3F,0x0C40,0x0C46,0x0C47, + 0x0C48,0x0C4A,0x0C4B,0x0C4C,0x0C4D,0x0C55,0x0C56,0x0C62, + 0x0C63,0x0C81,0x0CBC,0x0CBF,0x0CC6,0x0CCC,0x0CCD,0x0CE2, + 0x0CE3,0x0D01,0x0D41,0x0D42,0x0D43,0x0D44,0x0D4D,0x0D62, + 0x0D63,0x0DCA,0x0DD2,0x0DD3,0x0DD4,0x0DD6,0x0E31,0x0E34, + 0x0E35,0x0E36,0x0E37,0x0E38,0x0E39,0x0E3A,0x0E47,0x0E48, + 0x0E49,0x0E4A,0x0E4B,0x0E4C,0x0E4D,0x0E4E,0x0EB1,0x0EB4, + 0x0EB5,0x0EB6,0x0EB7,0x0EB8,0x0EB9,0x0EBB,0x0EBC,0x0EC8, + 0x0EC9,0x0ECA,0x0ECB,0x0ECC,0x0ECD,0x0F18,0x0F19,0x0F35, + 0x0F37,0x0F39,0x0F71,0x0F72,0x0F73,0x0F74,0x0F75,0x0F76, + 0x0F77,0x0F78,0x0F79,0x0F7A,0x0F7B,0x0F7C,0x0F7D,0x0F7E, + 0x0F80,0x0F81,0x0F82,0x0F83,0x0F84,0x0F86,0x0F87,0x0F8D, + 0x0F8E,0x0F8F,0x0F90,0x0F91,0x0F92,0x0F93,0x0F94,0x0F95, + 0x0F96,0x0F97,0x0F99,0x0F9A,0x0F9B,0x0F9C,0x0F9D,0x0F9E, + 0x0F9F,0x0FA0,0x0FA1,0x0FA2,0x0FA3,0x0FA4,0x0FA5,0x0FA6, + 0x0FA7,0x0FA8,0x0FA9,0x0FAA,0x0FAB,0x0FAC,0x0FAD,0x0FAE, + 0x0FAF,0x0FB0,0x0FB1,0x0FB2,0x0FB3,0x0FB4,0x0FB5,0x0FB6, + 0x0FB7,0x0FB8,0x0FB9,0x0FBA,0x0FBB,0x0FBC,0x0FC6,0x102D, + 0x102E,0x102F,0x1030,0x1032,0x1033,0x1034,0x1035,0x1036, + 0x1037,0x1039,0x103A,0x103D,0x103E,0x1058,0x1059,0x105E, + 0x105F,0x1060,0x1071,0x1072,0x1073,0x1074,0x1082,0x1085, + 0x1086,0x108D,0x109D,0x135D,0x135E,0x135F,0x1712,0x1713, + 0x1714,0x1732,0x1733,0x1734,0x1752,0x1753,0x1772,0x1773, + 0x17B4,0x17B5,0x17B7,0x17B8,0x17B9,0x17BA,0x17BB,0x17BC, + 0x17BD,0x17C6,0x17C9,0x17CA,0x17CB,0x17CC,0x17CD,0x17CE, + 0x17CF,0x17D0,0x17D1,0x17D2,0x17D3,0x17DD,0x180B,0x180C, + 0x180D,0x18A9,0x1920,0x1921,0x1922,0x1927,0x1928,0x1932, + 0x1939,0x193A,0x193B,0x1A17,0x1A18,0x1A1B,0x1A56,0x1A58, + 0x1A59,0x1A5A,0x1A5B,0x1A5C,0x1A5D,0x1A5E,0x1A60,0x1A62, + 0x1A65,0x1A66,0x1A67,0x1A68,0x1A69,0x1A6A,0x1A6B,0x1A6C, + 0x1A73,0x1A74,0x1A75,0x1A76,0x1A77,0x1A78,0x1A79,0x1A7A, + 0x1A7B,0x1A7C,0x1A7F,0x1AB0,0x1AB1,0x1AB2,0x1AB3,0x1AB4, + 0x1AB5,0x1AB6,0x1AB7,0x1AB8,0x1AB9,0x1ABA,0x1ABB,0x1ABC, + 0x1ABD,0x1B00,0x1B01,0x1B02,0x1B03,0x1B34,0x1B36,0x1B37, + 0x1B38,0x1B39,0x1B3A,0x1B3C,0x1B42,0x1B6B,0x1B6C,0x1B6D, + 0x1B6E,0x1B6F,0x1B70,0x1B71,0x1B72,0x1B73,0x1B80,0x1B81, + 0x1BA2,0x1BA3,0x1BA4,0x1BA5,0x1BA8,0x1BA9,0x1BAB,0x1BAC, + 0x1BAD,0x1BE6,0x1BE8,0x1BE9,0x1BED,0x1BEF,0x1BF0,0x1BF1, + 0x1C2C,0x1C2D,0x1C2E,0x1C2F,0x1C30,0x1C31,0x1C32,0x1C33, + 0x1C36,0x1C37,0x1CD0,0x1CD1,0x1CD2,0x1CD4,0x1CD5,0x1CD6, + 0x1CD7,0x1CD8,0x1CD9,0x1CDA,0x1CDB,0x1CDC,0x1CDD,0x1CDE, + 0x1CDF,0x1CE0,0x1CE2,0x1CE3,0x1CE4,0x1CE5,0x1CE6,0x1CE7, + 0x1CE8,0x1CED,0x1CF4,0x1CF8,0x1CF9,0x1DC0,0x1DC1,0x1DC2, + 0x1DC3,0x1DC4,0x1DC5,0x1DC6,0x1DC7,0x1DC8,0x1DC9,0x1DCA, + 0x1DCB,0x1DCC,0x1DCD,0x1DCE,0x1DCF,0x1DD0,0x1DD1,0x1DD2, + 0x1DD3,0x1DD4,0x1DD5,0x1DD6,0x1DD7,0x1DD8,0x1DD9,0x1DDA, + 0x1DDB,0x1DDC,0x1DDD,0x1DDE,0x1DDF,0x1DE0,0x1DE1,0x1DE2, + 0x1DE3,0x1DE4,0x1DE5,0x1DE6,0x1DE7,0x1DE8,0x1DE9,0x1DEA, + 0x1DEB,0x1DEC,0x1DED,0x1DEE,0x1DEF,0x1DF0,0x1DF1,0x1DF2, + 0x1DF3,0x1DF4,0x1DF5,0x1DFC,0x1DFD,0x1DFE,0x1DFF,0x20D0, + 0x20D1,0x20D2,0x20D3,0x20D4,0x20D5,0x20D6,0x20D7,0x20D8, + 0x20D9,0x20DA,0x20DB,0x20DC,0x20E1,0x20E5,0x20E6,0x20E7, + 0x20E8,0x20E9,0x20EA,0x20EB,0x20EC,0x20ED,0x20EE,0x20EF, + 0x20F0,0x2CEF,0x2CF0,0x2CF1,0x2D7F,0x2DE0,0x2DE1,0x2DE2, + 0x2DE3,0x2DE4,0x2DE5,0x2DE6,0x2DE7,0x2DE8,0x2DE9,0x2DEA, + 0x2DEB,0x2DEC,0x2DED,0x2DEE,0x2DEF,0x2DF0,0x2DF1,0x2DF2, + 0x2DF3,0x2DF4,0x2DF5,0x2DF6,0x2DF7,0x2DF8,0x2DF9,0x2DFA, + 0x2DFB,0x2DFC,0x2DFD,0x2DFE,0x2DFF,0x302A,0x302B,0x302C, + 0x302D,0x3099,0x309A,0xA66F,0xA674,0xA675,0xA676,0xA677, + 0xA678,0xA679,0xA67A,0xA67B,0xA67C,0xA67D,0xA69E,0xA69F, + 0xA6F0,0xA6F1,0xA802,0xA806,0xA80B,0xA825,0xA826,0xA8C4, + 0xA8E0,0xA8E1,0xA8E2,0xA8E3,0xA8E4,0xA8E5,0xA8E6,0xA8E7, + 0xA8E8,0xA8E9,0xA8EA,0xA8EB,0xA8EC,0xA8ED,0xA8EE,0xA8EF, + 0xA8F0,0xA8F1,0xA926,0xA927,0xA928,0xA929,0xA92A,0xA92B, + 0xA92C,0xA92D,0xA947,0xA948,0xA949,0xA94A,0xA94B,0xA94C, + 0xA94D,0xA94E,0xA94F,0xA950,0xA951,0xA980,0xA981,0xA982, + 0xA9B3,0xA9B6,0xA9B7,0xA9B8,0xA9B9,0xA9BC,0xA9E5,0xAA29, + 0xAA2A,0xAA2B,0xAA2C,0xAA2D,0xAA2E,0xAA31,0xAA32,0xAA35, + 0xAA36,0xAA43,0xAA4C,0xAA7C,0xAAB0,0xAAB2,0xAAB3,0xAAB4, + 0xAAB7,0xAAB8,0xAABE,0xAABF,0xAAC1,0xAAEC,0xAAED,0xAAF6, + 0xABE5,0xABE8,0xABED,0xFB1E,0xFE00,0xFE01,0xFE02,0xFE03, + 0xFE04,0xFE05,0xFE06,0xFE07,0xFE08,0xFE09,0xFE0A,0xFE0B, + 0xFE0C,0xFE0D,0xFE0E,0xFE0F,0xFE20,0xFE21,0xFE22,0xFE23, + 0xFE24,0xFE25,0xFE26,0xFE27,0xFE28,0xFE29,0xFE2A,0xFE2B, + 0xFE2C,0xFE2D,0xFE2E,0xFE2F, + 0x101FD,0x102E0,0x10376,0x10377,0x10378,0x10379,0x1037A,0x10A01, + 0x10A02,0x10A03,0x10A05,0x10A06,0x10A0C,0x10A0D,0x10A0E,0x10A0F, + 0x10A38,0x10A39,0x10A3A,0x10A3F,0x10AE5,0x10AE6,0x11001,0x11038, + 0x11039,0x1103A,0x1103B,0x1103C,0x1103D,0x1103E,0x1103F,0x11040, + 0x11041,0x11042,0x11043,0x11044,0x11045,0x11046,0x1107F,0x11080, + 0x11081,0x110B3,0x110B4,0x110B5,0x110B6,0x110B9,0x110BA,0x11100, + 0x11101,0x11102,0x11127,0x11128,0x11129,0x1112A,0x1112B,0x1112D, + 0x1112E,0x1112F,0x11130,0x11131,0x11132,0x11133,0x11134,0x11173, + 0x11180,0x11181,0x111B6,0x111B7,0x111B8,0x111B9,0x111BA,0x111BB, + 0x111BC,0x111BD,0x111BE,0x111CA,0x111CB,0x111CC,0x1122F,0x11230, + 0x11231,0x11234,0x11236,0x11237,0x112DF,0x112E3,0x112E4,0x112E5, + 0x112E6,0x112E7,0x112E8,0x112E9,0x112EA,0x11300,0x11301,0x1133C, + 0x11340,0x11366,0x11367,0x11368,0x11369,0x1136A,0x1136B,0x1136C, + 0x11370,0x11371,0x11372,0x11373,0x11374,0x114B3,0x114B4,0x114B5, + 0x114B6,0x114B7,0x114B8,0x114BA,0x114BF,0x114C0,0x114C2,0x114C3, + 0x115B2,0x115B3,0x115B4,0x115B5,0x115BC,0x115BD,0x115BF,0x115C0, + 0x115DC,0x115DD,0x11633,0x11634,0x11635,0x11636,0x11637,0x11638, + 0x11639,0x1163A,0x1163D,0x1163F,0x11640,0x116AB,0x116AD,0x116B0, + 0x116B1,0x116B2,0x116B3,0x116B4,0x116B5,0x116B7,0x1171D,0x1171E, + 0x1171F,0x11722,0x11723,0x11724,0x11725,0x11727,0x11728,0x11729, + 0x1172A,0x1172B,0x16AF0,0x16AF1,0x16AF2,0x16AF3,0x16AF4,0x16B30, + 0x16B31,0x16B32,0x16B33,0x16B34,0x16B35,0x16B36,0x16F8F,0x16F90, + 0x16F91,0x16F92,0x1BC9D,0x1BC9E,0x1D167,0x1D168,0x1D169,0x1D17B, + 0x1D17C,0x1D17D,0x1D17E,0x1D17F,0x1D180,0x1D181,0x1D182,0x1D185, + 0x1D186,0x1D187,0x1D188,0x1D189,0x1D18A,0x1D18B,0x1D1AA,0x1D1AB, + 0x1D1AC,0x1D1AD,0x1D242,0x1D243,0x1D244,0x1DA00,0x1DA01,0x1DA02, + 0x1DA03,0x1DA04,0x1DA05,0x1DA06,0x1DA07,0x1DA08,0x1DA09,0x1DA0A, + 0x1DA0B,0x1DA0C,0x1DA0D,0x1DA0E,0x1DA0F,0x1DA10,0x1DA11,0x1DA12, + 0x1DA13,0x1DA14,0x1DA15,0x1DA16,0x1DA17,0x1DA18,0x1DA19,0x1DA1A, + 0x1DA1B,0x1DA1C,0x1DA1D,0x1DA1E,0x1DA1F,0x1DA20,0x1DA21,0x1DA22, + 0x1DA23,0x1DA24,0x1DA25,0x1DA26,0x1DA27,0x1DA28,0x1DA29,0x1DA2A, + 0x1DA2B,0x1DA2C,0x1DA2D,0x1DA2E,0x1DA2F,0x1DA30,0x1DA31,0x1DA32, + 0x1DA33,0x1DA34,0x1DA35,0x1DA36,0x1DA3B,0x1DA3C,0x1DA3D,0x1DA3E, + 0x1DA3F,0x1DA40,0x1DA41,0x1DA42,0x1DA43,0x1DA44,0x1DA45,0x1DA46, + 0x1DA47,0x1DA48,0x1DA49,0x1DA4A,0x1DA4B,0x1DA4C,0x1DA4D,0x1DA4E, + 0x1DA4F,0x1DA50,0x1DA51,0x1DA52,0x1DA53,0x1DA54,0x1DA55,0x1DA56, + 0x1DA57,0x1DA58,0x1DA59,0x1DA5A,0x1DA5B,0x1DA5C,0x1DA5D,0x1DA5E, + 0x1DA5F,0x1DA60,0x1DA61,0x1DA62,0x1DA63,0x1DA64,0x1DA65,0x1DA66, + 0x1DA67,0x1DA68,0x1DA69,0x1DA6A,0x1DA6B,0x1DA6C,0x1DA75,0x1DA84, + 0x1DA9B,0x1DA9C,0x1DA9D,0x1DA9E,0x1DA9F,0x1DAA1,0x1DAA2,0x1DAA3, + 0x1DAA4,0x1DAA5,0x1DAA6,0x1DAA7,0x1DAA8,0x1DAA9,0x1DAAA,0x1DAAB, + 0x1DAAC,0x1DAAD,0x1DAAE,0x1DAAF,0x1E8D0,0x1E8D1,0x1E8D2,0x1E8D3, + 0x1E8D4,0x1E8D5,0x1E8D6,0xE0100,0xE0101,0xE0102,0xE0103,0xE0104, + 0xE0105,0xE0106,0xE0107,0xE0108,0xE0109,0xE010A,0xE010B,0xE010C, + 0xE010D,0xE010E,0xE010F,0xE0110,0xE0111,0xE0112,0xE0113,0xE0114, + 0xE0115,0xE0116,0xE0117,0xE0118,0xE0119,0xE011A,0xE011B,0xE011C, + 0xE011D,0xE011E,0xE011F,0xE0120,0xE0121,0xE0122,0xE0123,0xE0124, + 0xE0125,0xE0126,0xE0127,0xE0128,0xE0129,0xE012A,0xE012B,0xE012C, + 0xE012D,0xE012E,0xE012F,0xE0130,0xE0131,0xE0132,0xE0133,0xE0134, + 0xE0135,0xE0136,0xE0137,0xE0138,0xE0139,0xE013A,0xE013B,0xE013C, + 0xE013D,0xE013E,0xE013F,0xE0140,0xE0141,0xE0142,0xE0143,0xE0144, + 0xE0145,0xE0146,0xE0147,0xE0148,0xE0149,0xE014A,0xE014B,0xE014C, + 0xE014D,0xE014E,0xE014F,0xE0150,0xE0151,0xE0152,0xE0153,0xE0154, + 0xE0155,0xE0156,0xE0157,0xE0158,0xE0159,0xE015A,0xE015B,0xE015C, + 0xE015D,0xE015E,0xE015F,0xE0160,0xE0161,0xE0162,0xE0163,0xE0164, + 0xE0165,0xE0166,0xE0167,0xE0168,0xE0169,0xE016A,0xE016B,0xE016C, + 0xE016D,0xE016E,0xE016F,0xE0170,0xE0171,0xE0172,0xE0173,0xE0174, + 0xE0175,0xE0176,0xE0177,0xE0178,0xE0179,0xE017A,0xE017B,0xE017C, + 0xE017D,0xE017E,0xE017F,0xE0180,0xE0181,0xE0182,0xE0183,0xE0184, + 0xE0185,0xE0186,0xE0187,0xE0188,0xE0189,0xE018A,0xE018B,0xE018C, + 0xE018D,0xE018E,0xE018F,0xE0190,0xE0191,0xE0192,0xE0193,0xE0194, + 0xE0195,0xE0196,0xE0197,0xE0198,0xE0199,0xE019A,0xE019B,0xE019C, + 0xE019D,0xE019E,0xE019F,0xE01A0,0xE01A1,0xE01A2,0xE01A3,0xE01A4, + 0xE01A5,0xE01A6,0xE01A7,0xE01A8,0xE01A9,0xE01AA,0xE01AB,0xE01AC, + 0xE01AD,0xE01AE,0xE01AF,0xE01B0,0xE01B1,0xE01B2,0xE01B3,0xE01B4, + 0xE01B5,0xE01B6,0xE01B7,0xE01B8,0xE01B9,0xE01BA,0xE01BB,0xE01BC, + 0xE01BD,0xE01BE,0xE01BF,0xE01C0,0xE01C1,0xE01C2,0xE01C3,0xE01C4, + 0xE01C5,0xE01C6,0xE01C7,0xE01C8,0xE01C9,0xE01CA,0xE01CB,0xE01CC, + 0xE01CD,0xE01CE,0xE01CF,0xE01D0,0xE01D1,0xE01D2,0xE01D3,0xE01D4, + 0xE01D5,0xE01D6,0xE01D7,0xE01D8,0xE01D9,0xE01DA,0xE01DB,0xE01DC, + 0xE01DD,0xE01DE,0xE01DF,0xE01E0,0xE01E1,0xE01E2,0xE01E3,0xE01E4, + 0xE01E5,0xE01E6,0xE01E7,0xE01E8,0xE01E9,0xE01EA,0xE01EB,0xE01EC, + 0xE01ED,0xE01EE,0xE01EF, +}; + +static int unicodeCombiningCharTableSize = sizeof(unicodeCombiningCharTable) / sizeof(unicodeCombiningCharTable[0]); + +inline int unicodeIsCombiningChar(unsigned long cp) +{ + int i; + for (i = 0; i < unicodeCombiningCharTableSize; i++) { + if (unicodeCombiningCharTable[i] == cp) { + return 1; + } + } + return 0; +} + +/* Get length of previous UTF8 character + */ +inline int unicodePrevUTF8CharLen(char* buf, int pos) +{ + int end = pos--; + while (pos >= 0 && ((unsigned char)buf[pos] & 0xC0) == 0x80) { + pos--; + } + return end - pos; +} + +/* Get length of previous UTF8 character + */ +inline int unicodeUTF8CharLen(char* buf, int buf_len, int pos) +{ + if (pos == buf_len) { return 0; } + unsigned char ch = buf[pos]; + if (ch < 0x80) { return 1; } + else if (ch < 0xE0) { return 2; } + else if (ch < 0xF0) { return 3; } + else { return 4; } +} + +/* Convert UTF8 to Unicode code point + */ +inline int unicodeUTF8CharToCodePoint( + const char* buf, + int len, + int* cp) +{ + if (len) { + unsigned char byte = buf[0]; + if ((byte & 0x80) == 0) { + *cp = byte; + return 1; + } else if ((byte & 0xE0) == 0xC0) { + if (len >= 2) { + *cp = (((unsigned long)(buf[0] & 0x1F)) << 6) | + ((unsigned long)(buf[1] & 0x3F)); + return 2; + } + } else if ((byte & 0xF0) == 0xE0) { + if (len >= 3) { + *cp = (((unsigned long)(buf[0] & 0x0F)) << 12) | + (((unsigned long)(buf[1] & 0x3F)) << 6) | + ((unsigned long)(buf[2] & 0x3F)); + return 3; + } + } else if ((byte & 0xF8) == 0xF0) { + if (len >= 4) { + *cp = (((unsigned long)(buf[0] & 0x07)) << 18) | + (((unsigned long)(buf[1] & 0x3F)) << 12) | + (((unsigned long)(buf[2] & 0x3F)) << 6) | + ((unsigned long)(buf[3] & 0x3F)); + return 4; + } + } + } + return 0; +} + +/* Get length of grapheme + */ +inline int unicodeGraphemeLen(char* buf, int buf_len, int pos) +{ + if (pos == buf_len) { + return 0; + } + int beg = pos; + pos += unicodeUTF8CharLen(buf, buf_len, pos); + while (pos < buf_len) { + int len = unicodeUTF8CharLen(buf, buf_len, pos); + int cp = 0; + unicodeUTF8CharToCodePoint(buf + pos, len, &cp); + if (!unicodeIsCombiningChar(cp)) { + return pos - beg; + } + pos += len; + } + return pos - beg; +} + +/* Get length of previous grapheme + */ +inline int unicodePrevGraphemeLen(char* buf, int pos) +{ + if (pos == 0) { + return 0; + } + int end = pos; + while (pos > 0) { + int len = unicodePrevUTF8CharLen(buf, pos); + pos -= len; + int cp = 0; + unicodeUTF8CharToCodePoint(buf + pos, len, &cp); + if (!unicodeIsCombiningChar(cp)) { + return end - pos; + } + } + return 0; +} + +inline int isAnsiEscape(const char* buf, int buf_len, int* len) +{ + if (buf_len > 2 && !memcmp("\033[", buf, 2)) { + int off = 2; + while (off < buf_len) { + switch (buf[off++]) { + case 'A': case 'B': case 'C': case 'D': + case 'E': case 'F': case 'G': case 'H': + case 'J': case 'K': case 'S': case 'T': + case 'f': case 'm': + *len = off; + return 1; + } + } + } + return 0; +} + +/* Get column position for the single line mode. + */ +inline int unicodeColumnPos(const char* buf, int buf_len) +{ + int ret = 0; + + int off = 0; + while (off < buf_len) { + int len; + if (isAnsiEscape(buf + off, buf_len - off, &len)) { + off += len; + continue; + } + + int cp = 0; + len = unicodeUTF8CharToCodePoint(buf + off, buf_len - off, &cp); + + if (!unicodeIsCombiningChar(cp)) { + ret += unicodeIsWideChar(cp) ? 2 : 1; + } + + off += len; + } + + return ret; +} + +/* Get column position for the multi line mode. + */ +inline int unicodeColumnPosForMultiLine(char* buf, int buf_len, int pos, int cols, int ini_pos) +{ + int ret = 0; + int colwid = ini_pos; + + int off = 0; + while (off < buf_len) { + int cp = 0; + int len = unicodeUTF8CharToCodePoint(buf + off, buf_len - off, &cp); + + int wid = 0; + if (!unicodeIsCombiningChar(cp)) { + wid = unicodeIsWideChar(cp) ? 2 : 1; + } + + int dif = (int)(colwid + wid) - (int)cols; + if (dif > 0) { + ret += dif; + colwid = wid; + } else if (dif == 0) { + colwid = 0; + } else { + colwid += wid; + } + + if (off >= pos) { + break; + } + + off += len; + ret += wid; + } + + return ret; +} + +/* Read UTF8 character from file. + */ +inline int unicodeReadUTF8Char(int fd, char* buf, int* cp) +{ + int nread = read(fd,&buf[0],1); + + if (nread <= 0) { return nread; } + + unsigned char byte = buf[0]; + + if ((byte & 0x80) == 0) { + ; + } else if ((byte & 0xE0) == 0xC0) { + nread = read(fd,&buf[1],1); + if (nread <= 0) { return nread; } + } else if ((byte & 0xF0) == 0xE0) { + nread = read(fd,&buf[1],2); + if (nread <= 0) { return nread; } + } else if ((byte & 0xF8) == 0xF0) { + nread = read(fd,&buf[1],3); + if (nread <= 0) { return nread; } + } else { + return -1; + } + + return unicodeUTF8CharToCodePoint(buf, 4, cp); +} + +/* ======================= Low level terminal handling ====================== */ + +/* Set if to use or not the multi line mode. */ +inline void SetMultiLine(bool ml) { + mlmode = ml; +} + +/* Return true if the terminal name is in the list of terminals we know are + * not able to understand basic escape sequences. */ +inline bool isUnsupportedTerm(void) { +#ifndef _WIN32 + char *term = getenv("TERM"); + int j; + + if (term == NULL) return false; + for (j = 0; unsupported_term[j]; j++) + if (!strcasecmp(term,unsupported_term[j])) return true; +#endif + return false; +} + +/* Raw mode: 1960 magic shit. */ +inline bool enableRawMode(int fd) { +#ifndef _WIN32 + struct termios raw; + + if (!isatty(STDIN_FILENO)) goto fatal; + if (!atexit_registered) { + atexit(linenoiseAtExit); + atexit_registered = true; + } + if (tcgetattr(fd,&orig_termios) == -1) goto fatal; + + raw = orig_termios; /* modify the original mode */ + /* input modes: no break, no CR to NL, no parity check, no strip char, + * no start/stop output control. */ + raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + /* output modes - disable post processing */ + // NOTE: Multithreaded issue #20 (https://github.com/yhirose/cpp-linenoise/issues/20) + // raw.c_oflag &= ~(OPOST); + /* control modes - set 8 bit chars */ + raw.c_cflag |= (CS8); + /* local modes - echoing off, canonical off, no extended functions, + * no signal chars (^Z,^C) */ + raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); + /* control chars - set return condition: min number of bytes and timer. + * We want read to return every single byte, without timeout. */ + raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ + + /* put terminal in raw mode after flushing */ + if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; + rawmode = true; +#else + if (!atexit_registered) { + /* Cleanup them at exit */ + atexit(linenoiseAtExit); + atexit_registered = true; + + /* Init windows console handles only once */ + hOut = GetStdHandle(STD_OUTPUT_HANDLE); + if (hOut==INVALID_HANDLE_VALUE) goto fatal; + } + + DWORD consolemodeOut; + if (!GetConsoleMode(hOut, &consolemodeOut)) { + CloseHandle(hOut); + errno = ENOTTY; + return false; + }; + + hIn = GetStdHandle(STD_INPUT_HANDLE); + if (hIn == INVALID_HANDLE_VALUE) { + CloseHandle(hOut); + errno = ENOTTY; + return false; + } + + GetConsoleMode(hIn, &consolemodeIn); + /* Enable raw mode */ + SetConsoleMode(hIn, consolemodeIn & ~ENABLE_PROCESSED_INPUT); + + rawmode = true; +#endif + return true; + +fatal: + errno = ENOTTY; + return false; +} + +inline void disableRawMode(int fd) { +#ifdef _WIN32 + if (consolemodeIn) { + SetConsoleMode(hIn, consolemodeIn); + consolemodeIn = 0; + } + rawmode = false; +#else + /* Don't even check the return value as it's too late. */ + if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) + rawmode = false; +#endif +} + +/* Use the ESC [6n escape sequence to query the horizontal cursor position + * and return it. On error -1 is returned, on success the position of the + * cursor. */ +inline int getCursorPosition(int ifd, int ofd) { + char buf[32]; + int cols, rows; + unsigned int i = 0; + + /* Report cursor location */ + if (write(ofd, "\x1b[6n", 4) != 4) return -1; + + /* Read the response: ESC [ rows ; cols R */ + while (i < sizeof(buf)-1) { + if (read(ifd,buf+i,1) != 1) break; + if (buf[i] == 'R') break; + i++; + } + buf[i] = '\0'; + + /* Parse it. */ + if (buf[0] != ESC || buf[1] != '[') return -1; + if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; + return cols; +} + +/* Try to get the number of columns in the current terminal, or assume 80 + * if it fails. */ +inline int getColumns(int ifd, int ofd) { +#ifdef _WIN32 + CONSOLE_SCREEN_BUFFER_INFO b; + + if (!GetConsoleScreenBufferInfo(hOut, &b)) return 80; + return b.srWindow.Right - b.srWindow.Left; +#else + struct winsize ws; + + if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { + /* ioctl() failed. Try to query the terminal itself. */ + int start, cols; + + /* Get the initial position so we can restore it later. */ + start = getCursorPosition(ifd,ofd); + if (start == -1) goto failed; + + /* Go to right margin and get position. */ + if (write(ofd,"\x1b[999C",6) != 6) goto failed; + cols = getCursorPosition(ifd,ofd); + if (cols == -1) goto failed; + + /* Restore position. */ + if (cols > start) { + char seq[32]; + snprintf(seq,32,"\x1b[%dD",cols-start); + if (write(ofd,seq,strlen(seq)) == -1) { + /* Can't recover... */ + } + } + return cols; + } else { + return ws.ws_col; + } + +failed: + return 80; +#endif +} + +/* Clear the screen. Used to handle ctrl+l */ +inline void linenoiseClearScreen(void) { + if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { + /* nothing to do, just to avoid warning. */ + } +} + +/* Beep, used for completion when there is nothing to complete or when all + * the choices were already shown. */ +inline void linenoiseBeep(void) { + fprintf(stderr, "\x7"); + fflush(stderr); +} + +/* ============================== Completion ================================ */ + +/* This is an helper function for linenoiseEdit() and is called when the + * user types the key in order to complete the string currently in the + * input. + * + * The state of the editing is encapsulated into the pointed linenoiseState + * structure as described in the structure definition. */ +inline int completeLine(struct linenoiseState *ls, char *cbuf, int *c) { + std::vector lc; + int nread = 0, nwritten; + *c = 0; + + completionCallback(ls->buf,lc); + if (lc.empty()) { + linenoiseBeep(); + } else { + int stop = 0, i = 0; + + while(!stop) { + /* Show completion or original buffer */ + if (i < static_cast(lc.size())) { + struct linenoiseState saved = *ls; + + ls->len = ls->pos = static_cast(lc[i].size()); + ls->buf = &lc[i][0]; + refreshLine(ls); + ls->len = saved.len; + ls->pos = saved.pos; + ls->buf = saved.buf; + } else { + refreshLine(ls); + } + + //nread = read(ls->ifd,&c,1); +#ifdef _WIN32 + nread = win32read(c); + if (nread == 1) { + cbuf[0] = *c; + } +#else + nread = unicodeReadUTF8Char(ls->ifd,cbuf,c); +#endif + if (nread <= 0) { + *c = -1; + return nread; + } + + switch(*c) { + case 9: /* tab */ + i = (i+1) % (lc.size()+1); + if (i == static_cast(lc.size())) linenoiseBeep(); + break; + case 27: /* escape */ + /* Re-show original buffer */ + if (i < static_cast(lc.size())) refreshLine(ls); + stop = 1; + break; + default: + /* Update buffer and return */ + if (i < static_cast(lc.size())) { + nwritten = snprintf(ls->buf,ls->buflen,"%s",&lc[i][0]); + ls->len = ls->pos = nwritten; + } + stop = 1; + break; + } + } + } + + return nread; +} + +/* Register a callback function to be called for tab-completion. */ +inline void SetCompletionCallback(CompletionCallback fn) { + completionCallback = fn; +} + +/* =========================== Line editing ================================= */ + +/* Single line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. */ +inline void refreshSingleLine(struct linenoiseState *l) { + char seq[64]; + int pcolwid = unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length())); + int fd = l->ofd; + char *buf = l->buf; + int len = l->len; + int pos = l->pos; + std::string ab; + + while((pcolwid+unicodeColumnPos(buf, pos)) >= l->cols) { + int glen = unicodeGraphemeLen(buf, len, 0); + buf += glen; + len -= glen; + pos -= glen; + } + while (pcolwid+unicodeColumnPos(buf, len) > l->cols) { + len -= unicodePrevGraphemeLen(buf, len); + } + + /* Cursor to left edge */ + snprintf(seq,64,"\r"); + ab += seq; + /* Write the prompt and the current buffer content */ + ab += l->prompt; + ab.append(buf, len); + /* Erase to right */ + snprintf(seq,64,"\x1b[0K"); + ab += seq; + /* Move cursor to original position. */ + snprintf(seq,64,"\r\x1b[%dC", (int)(unicodeColumnPos(buf, pos)+pcolwid)); + ab += seq; + if (write(fd,ab.c_str(), static_cast(ab.length())) == -1) {} /* Can't recover from write error. */ +} + +/* Multi line low level line refresh. + * + * Rewrite the currently edited line accordingly to the buffer content, + * cursor position, and number of columns of the terminal. */ +inline void refreshMultiLine(struct linenoiseState *l) { + char seq[64]; + int pcolwid = unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length())); + int colpos = unicodeColumnPosForMultiLine(l->buf, l->len, l->len, l->cols, pcolwid); + int colpos2; /* cursor column position. */ + int rows = (pcolwid+colpos+l->cols-1)/l->cols; /* rows used by current buf. */ + int rpos = (pcolwid+l->oldcolpos+l->cols)/l->cols; /* cursor relative row. */ + int rpos2; /* rpos after refresh. */ + int col; /* colum position, zero-based. */ + int old_rows = (int)l->maxrows; + int fd = l->ofd, j; + std::string ab; + + /* Update maxrows if needed. */ + if (rows > (int)l->maxrows) l->maxrows = rows; + + /* First step: clear all the lines used before. To do so start by + * going to the last row. */ + if (old_rows-rpos > 0) { + snprintf(seq,64,"\x1b[%dB", old_rows-rpos); + ab += seq; + } + + /* Now for every row clear it, go up. */ + for (j = 0; j < old_rows-1; j++) { + snprintf(seq,64,"\r\x1b[0K\x1b[1A"); + ab += seq; + } + + /* Clean the top line. */ + snprintf(seq,64,"\r\x1b[0K"); + ab += seq; + + /* Write the prompt and the current buffer content */ + ab += l->prompt; + ab.append(l->buf, l->len); + + /* Get text width to cursor position */ + colpos2 = unicodeColumnPosForMultiLine(l->buf, l->len, l->pos, l->cols, pcolwid); + + /* If we are at the very end of the screen with our prompt, we need to + * emit a newline and move the prompt to the first column. */ + if (l->pos && + l->pos == l->len && + (colpos2+pcolwid) % l->cols == 0) + { + ab += "\n"; + snprintf(seq,64,"\r"); + ab += seq; + rows++; + if (rows > (int)l->maxrows) l->maxrows = rows; + } + + /* Move cursor to right position. */ + rpos2 = (pcolwid+colpos2+l->cols)/l->cols; /* current cursor relative row. */ + + /* Go up till we reach the expected positon. */ + if (rows-rpos2 > 0) { + snprintf(seq,64,"\x1b[%dA", rows-rpos2); + ab += seq; + } + + /* Set column. */ + col = (pcolwid + colpos2) % l->cols; + if (col) + snprintf(seq,64,"\r\x1b[%dC", col); + else + snprintf(seq,64,"\r"); + ab += seq; + + l->oldcolpos = colpos2; + + if (write(fd,ab.c_str(), static_cast(ab.length())) == -1) {} /* Can't recover from write error. */ +} + +/* Calls the two low level functions refreshSingleLine() or + * refreshMultiLine() according to the selected mode. */ +inline void refreshLine(struct linenoiseState *l) { + if (mlmode) + refreshMultiLine(l); + else + refreshSingleLine(l); +} + +/* Insert the character 'c' at cursor current position. + * + * On error writing to the terminal -1 is returned, otherwise 0. */ +inline int linenoiseEditInsert(struct linenoiseState *l, const char* cbuf, int clen) { + if (l->len < l->buflen) { + if (l->len == l->pos) { + memcpy(&l->buf[l->pos],cbuf,clen); + l->pos+=clen; + l->len+=clen;; + l->buf[l->len] = '\0'; + if ((!mlmode && unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length()))+unicodeColumnPos(l->buf,l->len) < l->cols) /* || mlmode */) { + /* Avoid a full update of the line in the + * trivial case. */ + if (write(l->ofd,cbuf,clen) == -1) return -1; + } else { + refreshLine(l); + } + } else { + memmove(l->buf+l->pos+clen,l->buf+l->pos,l->len-l->pos); + memcpy(&l->buf[l->pos],cbuf,clen); + l->pos+=clen; + l->len+=clen; + l->buf[l->len] = '\0'; + refreshLine(l); + } + } + return 0; +} + +/* Move cursor on the left. */ +inline void linenoiseEditMoveLeft(struct linenoiseState *l) { + if (l->pos > 0) { + l->pos -= unicodePrevGraphemeLen(l->buf, l->pos); + refreshLine(l); + } +} + +/* Move cursor on the right. */ +inline void linenoiseEditMoveRight(struct linenoiseState *l) { + if (l->pos != l->len) { + l->pos += unicodeGraphemeLen(l->buf, l->len, l->pos); + refreshLine(l); + } +} + +/* Move cursor to the start of the line. */ +inline void linenoiseEditMoveHome(struct linenoiseState *l) { + if (l->pos != 0) { + l->pos = 0; + refreshLine(l); + } +} + +/* Move cursor to the end of the line. */ +inline void linenoiseEditMoveEnd(struct linenoiseState *l) { + if (l->pos != l->len) { + l->pos = l->len; + refreshLine(l); + } +} + +/* Substitute the currently edited line with the next or previous history + * entry as specified by 'dir'. */ +#define LINENOISE_HISTORY_NEXT 0 +#define LINENOISE_HISTORY_PREV 1 +inline void linenoiseEditHistoryNext(struct linenoiseState *l, int dir) { + if (history.size() > 1) { + /* Update the current history entry before to + * overwrite it with the next one. */ + history[history.size() - 1 - l->history_index] = l->buf; + /* Show the new entry */ + l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; + if (l->history_index < 0) { + l->history_index = 0; + return; + } else if (l->history_index >= (int)history.size()) { + l->history_index = static_cast(history.size())-1; + return; + } + memset(l->buf, 0, l->buflen); + strcpy(l->buf,history[history.size() - 1 - l->history_index].c_str()); + l->len = l->pos = static_cast(strlen(l->buf)); + refreshLine(l); + } +} + +/* Delete the character at the right of the cursor without altering the cursor + * position. Basically this is what happens with the "Delete" keyboard key. */ +inline void linenoiseEditDelete(struct linenoiseState *l) { + if (l->len > 0 && l->pos < l->len) { + int glen = unicodeGraphemeLen(l->buf,l->len,l->pos); + memmove(l->buf+l->pos,l->buf+l->pos+glen,l->len-l->pos-glen); + l->len-=glen; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Backspace implementation. */ +inline void linenoiseEditBackspace(struct linenoiseState *l) { + if (l->pos > 0 && l->len > 0) { + int glen = unicodePrevGraphemeLen(l->buf,l->pos); + memmove(l->buf+l->pos-glen,l->buf+l->pos,l->len-l->pos); + l->pos-=glen; + l->len-=glen; + l->buf[l->len] = '\0'; + refreshLine(l); + } +} + +/* Delete the previosu word, maintaining the cursor at the start of the + * current word. */ +inline void linenoiseEditDeletePrevWord(struct linenoiseState *l) { + int old_pos = l->pos; + int diff; + + while (l->pos > 0 && l->buf[l->pos-1] == ' ') + l->pos--; + while (l->pos > 0 && l->buf[l->pos-1] != ' ') + l->pos--; + diff = old_pos - l->pos; + memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); + l->len -= diff; + refreshLine(l); +} + +/* This function is the core of the line editing capability of linenoise. + * It expects 'fd' to be already in "raw mode" so that every key pressed + * will be returned ASAP to read(). + * + * The resulting string is put into 'buf' when the user type enter, or + * when ctrl+d is typed. + * + * The function returns the length of the current buffer. */ +inline int linenoiseEdit(int stdin_fd, int stdout_fd, char *buf, int buflen, const char *prompt) +{ + struct linenoiseState l; + + /* Populate the linenoise state that we pass to functions implementing + * specific editing functionalities. */ + l.ifd = stdin_fd; + l.ofd = stdout_fd; + l.buf = buf; + l.buflen = buflen; + l.prompt = prompt; + l.oldcolpos = l.pos = 0; + l.len = 0; + l.cols = getColumns(stdin_fd, stdout_fd); + l.maxrows = 0; + l.history_index = 0; + + /* Buffer starts empty. */ + l.buf[0] = '\0'; + l.buflen--; /* Make sure there is always space for the nulterm */ + + /* The latest history entry is always our current buffer, that + * initially is just an empty string. */ + AddHistory(""); + + if (write(l.ofd,prompt, static_cast(l.prompt.length())) == -1) return -1; + while(1) { + int c; + char cbuf[4]; + int nread; + char seq[3]; + +#ifdef _WIN32 + nread = win32read(&c); + if (nread == 1) { + cbuf[0] = c; + } +#else + nread = unicodeReadUTF8Char(l.ifd,cbuf,&c); +#endif + if (nread <= 0) return (int)l.len; + + /* Only autocomplete when the callback is set. It returns < 0 when + * there was an error reading from fd. Otherwise it will return the + * character that should be handled next. */ + if (c == 9 && completionCallback != NULL) { + nread = completeLine(&l,cbuf,&c); + /* Return on errors */ + if (c < 0) return l.len; + /* Read next character when 0 */ + if (c == 0) continue; + } + + switch(c) { + case ENTER: /* enter */ + if (!history.empty()) history.pop_back(); + if (mlmode) linenoiseEditMoveEnd(&l); + return (int)l.len; + case CTRL_C: /* ctrl-c */ + errno = EAGAIN; + return -1; + case BACKSPACE: /* backspace */ + case 8: /* ctrl-h */ + linenoiseEditBackspace(&l); + break; + case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the + line is empty, act as end-of-file. */ + if (l.len > 0) { + linenoiseEditDelete(&l); + } else { + history.pop_back(); + return -1; + } + break; + case CTRL_T: /* ctrl-t, swaps current character with previous. */ + if (l.pos > 0 && l.pos < l.len) { + char aux = buf[l.pos-1]; + buf[l.pos-1] = buf[l.pos]; + buf[l.pos] = aux; + if (l.pos != l.len-1) l.pos++; + refreshLine(&l); + } + break; + case CTRL_B: /* ctrl-b */ + linenoiseEditMoveLeft(&l); + break; + case CTRL_F: /* ctrl-f */ + linenoiseEditMoveRight(&l); + break; + case CTRL_P: /* ctrl-p */ + linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_PREV); + break; + case CTRL_N: /* ctrl-n */ + linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_NEXT); + break; + case ESC: /* escape sequence */ + /* Read the next two bytes representing the escape sequence. + * Use two calls to handle slow terminals returning the two + * chars at different times. */ + if (read(l.ifd,seq,1) == -1) break; + if (read(l.ifd,seq+1,1) == -1) break; + + /* ESC [ sequences. */ + if (seq[0] == '[') { + if (seq[1] >= '0' && seq[1] <= '9') { + /* Extended escape, read additional byte. */ + if (read(l.ifd,seq+2,1) == -1) break; + if (seq[2] == '~') { + switch(seq[1]) { + case '3': /* Delete key. */ + linenoiseEditDelete(&l); + break; + } + } + } else { + switch(seq[1]) { + case 'A': /* Up */ + linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_PREV); + break; + case 'B': /* Down */ + linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_NEXT); + break; + case 'C': /* Right */ + linenoiseEditMoveRight(&l); + break; + case 'D': /* Left */ + linenoiseEditMoveLeft(&l); + break; + case 'H': /* Home */ + linenoiseEditMoveHome(&l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(&l); + break; + } + } + } + + /* ESC O sequences. */ + else if (seq[0] == 'O') { + switch(seq[1]) { + case 'H': /* Home */ + linenoiseEditMoveHome(&l); + break; + case 'F': /* End*/ + linenoiseEditMoveEnd(&l); + break; + } + } + break; + default: + if (linenoiseEditInsert(&l,cbuf,nread)) return -1; + break; + case CTRL_U: /* Ctrl+u, delete the whole line. */ + buf[0] = '\0'; + l.pos = l.len = 0; + refreshLine(&l); + break; + case CTRL_K: /* Ctrl+k, delete from current to end of line. */ + buf[l.pos] = '\0'; + l.len = l.pos; + refreshLine(&l); + break; + case CTRL_A: /* Ctrl+a, go to the start of the line */ + linenoiseEditMoveHome(&l); + break; + case CTRL_E: /* ctrl+e, go to the end of the line */ + linenoiseEditMoveEnd(&l); + break; + case CTRL_L: /* ctrl+l, clear screen */ + linenoiseClearScreen(); + refreshLine(&l); + break; + case CTRL_W: /* ctrl+w, delete previous word */ + linenoiseEditDeletePrevWord(&l); + break; + } + } + return l.len; +} + +/* This function calls the line editing function linenoiseEdit() using + * the STDIN file descriptor set in raw mode. */ +inline bool linenoiseRaw(const char *prompt, std::string& line) { + bool quit = false; + + if (!isatty(STDIN_FILENO)) { + /* Not a tty: read from file / pipe. */ + std::getline(std::cin, line); + } else { + /* Interactive editing. */ + if (enableRawMode(STDIN_FILENO) == false) { + return quit; + } + + char buf[LINENOISE_MAX_LINE]; + auto count = linenoiseEdit(STDIN_FILENO, STDOUT_FILENO, buf, LINENOISE_MAX_LINE, prompt); + if (count == -1) { + quit = true; + } else { + line.assign(buf, count); + } + + disableRawMode(STDIN_FILENO); + printf("\n"); + } + return quit; +} + +/* The high level function that is the main API of the linenoise library. + * This function checks if the terminal has basic capabilities, just checking + * for a blacklist of stupid terminals, and later either calls the line + * editing function or uses dummy fgets() so that you will be able to type + * something even in the most desperate of the conditions. */ +inline bool Readline(const char *prompt, std::string& line) { + if (isUnsupportedTerm()) { + printf("%s",prompt); + fflush(stdout); + std::getline(std::cin, line); + return false; + } else { + return linenoiseRaw(prompt, line); + } +} + +inline std::string Readline(const char *prompt, bool& quit) { + std::string line; + quit = Readline(prompt, line); + return line; +} + +inline std::string Readline(const char *prompt) { + bool quit; // dummy + return Readline(prompt, quit); +} + +/* ================================ History ================================= */ + +/* At exit we'll try to fix the terminal to the initial conditions. */ +inline void linenoiseAtExit(void) { + disableRawMode(STDIN_FILENO); +} + +/* This is the API call to add a new entry in the linenoise history. + * It uses a fixed array of char pointers that are shifted (memmoved) + * when the history max length is reached in order to remove the older + * entry and make room for the new one, so it is not exactly suitable for huge + * histories, but will work well for a few hundred of entries. + * + * Using a circular buffer is smarter, but a bit more complex to handle. */ +inline bool AddHistory(const char* line) { + if (history_max_len == 0) return false; + + /* Don't add duplicated lines. */ + if (!history.empty() && history.back() == line) return false; + + /* If we reached the max length, remove the older line. */ + if (history.size() == history_max_len) { + history.erase(history.begin()); + } + history.push_back(line); + + return true; +} + +/* Set the maximum length for the history. This function can be called even + * if there is already some history, the function will make sure to retain + * just the latest 'len' elements if the new history length value is smaller + * than the amount of items already inside the history. */ +inline bool SetHistoryMaxLen(size_t len) { + if (len < 1) return false; + history_max_len = len; + if (len < history.size()) { + history.resize(len); + } + return true; +} + +/* Save the history in the specified file. On success *true* is returned + * otherwise *false* is returned. */ +inline bool SaveHistory(const char* path) { + std::ofstream f(path); // TODO: need 'std::ios::binary'? + if (!f) return false; + for (const auto& h: history) { + f << h << std::endl; + } + return true; +} + +/* Load the history from the specified file. If the file does not exist + * zero is returned and no operation is performed. + * + * If the file exists and the operation succeeded *true* is returned, otherwise + * on error *false* is returned. */ +inline bool LoadHistory(const char* path) { + std::ifstream f(path); + if (!f) return false; + std::string line; + while (std::getline(f, line)) { + AddHistory(line.c_str()); + } + return true; +} + +inline const std::vector& GetHistory() { + return history; +} + +} // namespace linenoise + +#ifdef _WIN32 +#undef isatty +#undef write +#undef read +#pragma warning(pop) +#endif + +#endif /* __LINENOISE_HPP */ diff --git a/fuzz/basic.lua b/fuzz/basic.lua new file mode 100644 index 0000000..8b51a4e --- /dev/null +++ b/fuzz/basic.lua @@ -0,0 +1,7 @@ +local function test(t) + for k,v in pairs(t) do + print(k,v) + end +end + +test({a = 1}) diff --git a/fuzz/compiler.cpp b/fuzz/compiler.cpp new file mode 100644 index 0000000..ae915b1 --- /dev/null +++ b/fuzz/compiler.cpp @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include +#include "Luau/Compiler.h" +#include "Luau/Common.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + Luau::compile(std::string(reinterpret_cast(Data), Size)); + return 0; +} diff --git a/fuzz/format.cpp b/fuzz/format.cpp new file mode 100644 index 0000000..3ad3912 --- /dev/null +++ b/fuzz/format.cpp @@ -0,0 +1,20 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" + +#include +#include + +namespace Luau +{ +void fuzzFormatString(const char* data, size_t size); +} // namespace Luau + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + // copy data to heap to make sure ASAN can catch out of bounds access + std::vector str(Data, Data + Size); + + Luau::fuzzFormatString(str.data(), str.size()); + + return 0; +} diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp new file mode 100644 index 0000000..55e0888 --- /dev/null +++ b/fuzz/linter.cpp @@ -0,0 +1,39 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include +#include "Luau/TypeInfer.h" +#include "Luau/Linter.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Common.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + Luau::ParseOptions options; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); + + // "static" here is to accelerate fuzzing process by only creating and populating the type environment once + static Luau::NullModuleResolver moduleResolver; + static Luau::InternalErrorReporter iceHandler; + static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); + static int once = (Luau::registerBuiltinTypes(sharedEnv), 1); + (void)once; + static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); + (void)once2; + + if (parseResult.errors.empty()) + { + Luau::TypeChecker typeck(&moduleResolver, &iceHandler); + typeck.globalScope = sharedEnv.globalScope; + + Luau::LintOptions lintOptions; + lintOptions.warningMask = ~0ull; + + Luau::lint(parseResult.root, names, typeck.globalScope, nullptr, lintOptions); + } + + return 0; +} diff --git a/fuzz/luau.proto b/fuzz/luau.proto new file mode 100644 index 0000000..41a1d07 --- /dev/null +++ b/fuzz/luau.proto @@ -0,0 +1,342 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +syntax = "proto2"; +package luau; + +message Expr { + oneof expr_oneof { + ExprGroup group = 1; + ExprConstantNil nil = 2; + ExprConstantBool bool = 3; + ExprConstantNumber number = 4; + ExprConstantString string = 5; + ExprLocal local = 6; + ExprGlobal global = 7; + ExprVarargs varargs = 8; + ExprCall call = 9; + ExprIndexName index_name = 10; + ExprIndexExpr index_expr = 11; + ExprFunction function = 12; + ExprTable table = 13; + ExprUnary unary = 14; + ExprBinary binary = 15; + } +} + +message ExprPrefix { + oneof expr_oneof { + ExprGroup group = 1; + ExprLocal local = 2; + ExprGlobal global = 3; + ExprCall call = 4; + ExprIndexName index_name = 5; + ExprIndexExpr index_expr = 6; + } +} + +message Local { + required int32 name = 1; +} + +message Typename { + required int32 index = 1; +} + +message Name { + oneof name_oneof { + int32 builtin = 1; + int32 custom = 2; + } +} + +message ExprGroup { + required Expr expr = 1; +} + +message ExprConstantNil { +} + +message ExprConstantBool { + required bool val = 1; +} + +message ExprConstantNumber { + required int32 val = 1; +} + +message ExprConstantString { + required string val = 1; +} + +message ExprLocal { + required Local var = 1; +} + +message ExprGlobal { + required Name name = 1; +} + +message ExprVarargs { +} + +message ExprCall { + required ExprPrefix func = 1; + required bool self = 2; + repeated Expr args = 3; +} + +message ExprIndexName { + required ExprPrefix expr = 1; + required Name index = 2; +} + +message ExprIndexExpr { + required ExprPrefix expr = 1; + required Expr index = 2; +} + +message ExprFunction { + repeated Local args = 1; + required bool vararg = 2; + required StatBlock body = 3; + repeated Type types = 4; + repeated Type rettypes = 5; +} + +message TableItem { + oneof item_oneof { + Name key_name = 1; + Expr key_expr = 2; + } + required Expr value = 3; +} + +message ExprTable { + repeated TableItem items = 1; +} + +message ExprUnary { + enum Op { + Not = 0; + Minus = 1; + Len = 2; + } + + required Op op = 1; + required Expr expr = 2; +} + +message ExprBinary { + enum Op { + Add = 0; + Sub = 1; + Mul = 2; + Div = 3; + Mod = 4; + Pow = 5; + Concat = 6; + CompareNe = 7; + CompareEq = 8; + CompareLt = 9; + CompareLe = 10; + CompareGt = 11; + CompareGe = 12; + And = 13; + Or = 14; + } + + required Op op = 1; + required Expr left = 2; + required Expr right = 3; +} + +message LValue { + oneof lvalue_oneof { + ExprLocal local = 1; + ExprGlobal global = 2; + ExprIndexName index_name = 3; + ExprIndexExpr index_expr = 4; + } +} + +message Stat { + oneof stat_oneof { + StatBlock block = 1; + StatIf if = 2; + StatWhile while = 3; + StatRepeat repeat = 4; + StatBreak break = 5; + StatContinue continue = 6; + StatReturn return = 7; + StatCall call = 8; + StatLocal local = 9; + StatFor for = 10; + StatForIn for_in = 11; + StatAssign assign = 12; + StatCompoundAssign compound_assign = 13; + StatFunction function = 14; + StatLocalFunction local_function = 15; + StatTypeAlias type_alias = 16; + } +} + +message StatBlock { + repeated Stat body = 1; +} + +message StatIf { + required Expr cond = 1; + required StatBlock then = 2; + oneof else_oneof { + StatBlock else = 3; + StatIf elseif = 4; + } +} + +message StatWhile { + required Expr cond = 1; + required StatBlock body = 2; +} + +message StatRepeat { + required StatBlock body = 1; + required Expr cond = 2; +} + +message StatBreak { +} + +message StatContinue { +} + +message StatReturn { + repeated Expr list = 1; +} + +message StatCall { + required ExprCall expr = 1; +} + +message StatLocal { + repeated Local vars = 1; + repeated Expr values = 2; + repeated Type types = 3; +} + +message StatFor { + required Local var = 1; + required Expr from = 2; + required Expr to = 3; + optional Expr step = 4; + required StatBlock body = 5; +} + +message StatForIn { + repeated Local vars = 1; + repeated Expr values = 2; + required StatBlock body = 5; +} + +message StatAssign { + repeated LValue vars = 1; + repeated Expr values = 2; +} + +message StatCompoundAssign { + enum Op { + Add = 0; + Sub = 1; + Mul = 2; + Div = 3; + Mod = 4; + Pow = 5; + Concat = 6; + }; + + required Op op = 1; + required LValue var = 2; + required Expr value = 3; +} + +message StatFunction { + required LValue var = 1; + required ExprFunction func = 2; + required bool self = 3; +} + +message StatLocalFunction { + required Local var = 1; + required ExprFunction func = 2; +} + +message StatTypeAlias { + required Typename name = 1; + required Type type = 2; + repeated Typename generics = 3; +} + +message Type { + oneof type_oneof { + TypePrimitive primitive = 1; + TypeLiteral literal = 2; + TypeTable table = 3; + TypeFunction function = 4; + TypeTypeof typeof = 5; + TypeUnion union = 6; + TypeIntersection intersection = 7; + TypeClass class = 8; + TypeRef ref = 9; + } +} + +message TypePrimitive { + required int32 kind = 1; +} + +message TypeLiteral { + required Typename name = 1; + repeated Typename generics = 2; +} + +message TypeTableItem { + required Name key = 1; + required Type type = 2; +} + +message TypeTableIndexer { + required Type key = 1; + required Type value = 2; +} + +message TypeTable { + repeated TypeTableItem items = 1; + optional TypeTableIndexer indexer = 2; +} + +message TypeFunction { + repeated Type args = 1; + repeated Type rets = 2; + // TODO: vararg? +} + +message TypeTypeof { + required Expr expr = 1; +} + +message TypeUnion { + required Type left = 1; + required Type right = 2; +} + +message TypeIntersection { + required Type left = 1; + required Type right = 2; +} + +message TypeClass { + required int32 kind = 1; +} + +message TypeRef { + required Local prefix = 1; + required Typename index = 2; +} diff --git a/fuzz/parser.cpp b/fuzz/parser.cpp new file mode 100644 index 0000000..ef0dc0c --- /dev/null +++ b/fuzz/parser.cpp @@ -0,0 +1,15 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include +#include "Luau/Parser.h" +#include "Luau/Common.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + Luau::ParseOptions options; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); + return 0; +} diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp new file mode 100644 index 0000000..6c230b6 --- /dev/null +++ b/fuzz/proto.cpp @@ -0,0 +1,266 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "src/libfuzzer/libfuzzer_macro.h" +#include "luau.pb.h" + +#include "Luau/TypeInfer.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/ModuleResolver.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Compiler.h" +#include "Luau/Linter.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/Common.h" +#include "Luau/ToString.h" + +#include "lua.h" +#include "lualib.h" + +#include + +// Select components to fuzz +const bool kFuzzCompiler = true; +const bool kFuzzLinter = true; +const bool kFuzzTypeck = true; +const bool kFuzzVM = true; +const bool kFuzzTypes = true; + +static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); + +std::string protoprint(const luau::StatBlock& stat, bool types); + +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) +LUAU_FASTINT(LuauCheckRecursionLimit) +LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) +LUAU_FASTINT(LuauTypeInferIterationLimit) +LUAU_FASTINT(LuauTarjanChildLimit) + +std::chrono::milliseconds kInterruptTimeout(10); +std::chrono::time_point interruptDeadline; + +size_t kHeapLimit = 512 * 1024 * 1024; +size_t heapSize = 0; + +void interrupt(lua_State* L, int gc) +{ + if (gc >= 0) + return; + + if (std::chrono::system_clock::now() > interruptDeadline) + { + lua_checkstack(L, 1); + luaL_error(L, "execution timed out"); + } +} + +void* allocate(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize) +{ + if (nsize == 0) + { + heapSize -= osize; + free(ptr); + return NULL; + } + else + { + if (heapSize - osize + nsize > kHeapLimit) + return NULL; + + heapSize -= osize; + heapSize += nsize; + + return realloc(ptr, nsize); + } +} + +lua_State* createGlobalState() +{ + lua_State* L = lua_newstate(allocate, NULL); + + lua_callbacks(L)->interrupt = interrupt; + + luaL_openlibs(L); + luaL_sandbox(L); + + return L; +} + +int registerTypes(Luau::TypeChecker& env) +{ + using namespace Luau; + using std::nullopt; + + Luau::registerBuiltinTypes(env); + + TypeArena& arena = env.globalTypes; + + // Vector3 stub + TypeId vector3MetaType = arena.addType(TableTypeVar{}); + + TypeId vector3InstanceType = arena.addType(ClassTypeVar{"Vector3", {}, nullopt, vector3MetaType, {}, {}}); + getMutable(vector3InstanceType)->props = { + {"X", {env.numberType}}, + {"Y", {env.numberType}}, + {"Z", {env.numberType}}, + }; + + getMutable(vector3MetaType)->props = { + {"__add", {makeFunction(arena, nullopt, {vector3InstanceType, vector3InstanceType}, {vector3InstanceType})}}, + }; + + env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; + + // Instance stub + TypeId instanceType = arena.addType(ClassTypeVar{"Instance", {}, nullopt, nullopt, {}, {}}); + getMutable(instanceType)->props = { + {"Name", {env.stringType}}, + }; + + env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + + // Part stub + TypeId partType = arena.addType(ClassTypeVar{"Part", {}, instanceType, nullopt, {}, {}}); + getMutable(partType)->props = { + {"Position", {vector3InstanceType}}, + }; + + env.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, partType}; + + for (const auto& [_, fun] : env.globalScope->exportedTypeBindings) + persist(fun.type); + + return 0; +} + +static std::string debugsource; + +DEFINE_PROTO_FUZZER(const luau::StatBlock& message) +{ + FInt::LuauTypeInferRecursionLimit.value = 100; + FInt::LuauTypeInferTypePackLoopLimit.value = 100; + FInt::LuauCheckRecursionLimit.value = 100; + FInt::LuauTypeInferIterationLimit.value = 1000; + FInt::LuauTarjanChildLimit.value = 1000; + FInt::LuauTableTypeMaximumStringifierLength.value = 100; + + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + std::string source = protoprint(message, kFuzzTypes); + + // stash source in a global for easier crash dump debugging + debugsource = source; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.c_str(), source.size(), names, allocator); + + // "static" here is to accelerate fuzzing process by only creating and populating the type environment once + static Luau::NullModuleResolver moduleResolver; + static Luau::InternalErrorReporter iceHandler; + static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); + static int once = registerTypes(sharedEnv); + (void)once; + static int once2 = (Luau::freeze(sharedEnv.globalTypes), 0); + (void)once2; + + iceHandler.onInternalError = [](const char* error) { + printf("ICE: %s\n", error); + LUAU_ASSERT(!"ICE"); + }; + + static bool debug = getenv("LUAU_DEBUG") != 0; + + if (debug) + { + fprintf(stdout, "--\n%s\n", source.c_str()); + fflush(stdout); + } + + std::string bytecode; + + // compile + if (kFuzzCompiler && parseResult.errors.empty()) + { + Luau::CompileOptions compileOptions; + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, parseResult.root, names, compileOptions); + bytecode = bcb.getBytecode(); + } + catch (const Luau::CompileError&) + { + // not all valid ASTs can be compiled due to limits on number of registers + } + } + + // typecheck + if (kFuzzTypeck && parseResult.root) + { + Luau::SourceModule sourceModule; + sourceModule.root = parseResult.root; + sourceModule.mode = Luau::Mode::Nonstrict; + + Luau::TypeChecker typeck(&moduleResolver, &iceHandler); + typeck.globalScope = sharedEnv.globalScope; + + Luau::ModulePtr module = nullptr; + + try + { + module = typeck.check(sourceModule, Luau::Mode::Nonstrict); + } + catch (std::exception&) + { + // This catches internal errors that the type checker currently (unfortunately) throws in some cases + } + + // lint (note that we need access to types so we need to do this with typeck in scope) + if (kFuzzLinter) + { + Luau::LintOptions lintOptions = {~0u}; + Luau::lint(parseResult.root, names, sharedEnv.globalScope, module.get(), lintOptions); + } + } + + // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down + // note: it's important for typeck to be destroyed at this point! + if (kFuzzTypeck) + { + for (auto& p : sharedEnv.globalScope->bindings) + { + Luau::ToStringOptions opts; + opts.exhaustive = true; + opts.maxTableLength = 0; + opts.maxTypeLength = 0; + + toString(p.second.typeId, opts); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas + } + } + + // run resulting bytecode + if (kFuzzVM && bytecode.size()) + { + static lua_State* globalState = createGlobalState(); + + lua_State* L = lua_newthread(globalState); + luaL_sandboxthread(L); + + if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0) + { + interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; + + lua_resume(L, NULL, 0); + } + + lua_pop(globalState, 1); + + // we'd expect full GC to reclaim all memory allocated by the script + lua_gc(globalState, LUA_GCCOLLECT, 0); + LUAU_ASSERT(heapSize < 256 * 1024); + } +} diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp new file mode 100644 index 0000000..2c861a5 --- /dev/null +++ b/fuzz/protoprint.cpp @@ -0,0 +1,951 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "luau.pb.h" + +static const std::string kNames[] = { + "_G", + "_LOADED", + "_VERSION", + "__add", + "__call", + "__concat", + "__div", + "__eq", + "__index", + "__le", + "__len", + "__lt", + "__mod", + "__mode", + "__mul", + "__namecall", + "__newindex", + "__pow", + "__sub", + "__type", + "__unm", + "abs", + "acos", + "arshift", + "asin", + "assert", + "atan", + "atan2", + "band", + "bit32", + "bnot", + "boolean", + "bor", + "btest", + "bxor", + "byte", + "ceil", + "char", + "charpattern", + "clock", + "codepoint", + "codes", + "concat", + "coroutine", + "cos", + "cosh", + "create", + "date", + "debug", + "deg", + "difftime", + "error", + "exp", + "extract", + "find", + "floor", + "fmod", + "foreach", + "foreachi", + "format", + "frexp", + "function", + "gcinfo", + "getfenv", + "getinfo", + "getmetatable", + "getn", + "gmatch", + "gsub", + "huge", + "insert", + "ipairs", + "isyieldable", + "ldexp", + "len", + "loadstring", + "log", + "log10", + "lower", + "lrotate", + "lshift", + "match", + "math", + "max", + "maxn", + "min", + "modf", + "move", + "newproxy", + "next", + "nil", + "number", + "offset", + "os", + "pack", + "packsize", + "pairs", + "pcall", + "pi", + "pow", + "print", + "rad", + "random", + "randomseed", + "rawequal", + "rawget", + "rawset", + "remove", + "rep", + "replace", + "require", + "resume", + "reverse", + "rrotate", + "rshift", + "running", + "select", + "setfenv", + "setmetatable", + "sin", + "sinh", + "sort", + "split", + "sqrt", + "status", + "stdin", + "string", + "sub", + "table", + "tan", + "tanh", + "thread", + "time", + "tonumber", + "tostring", + "traceback", + "type", + "typeof", + "unpack", + "upper", + "userdata", + "utf8", + "vector", + "wrap", + "xpcall", + "yield", +}; + +static const std::string kTypes[] = { + "any", + "nil", + "number", + "string", + "boolean", + "thread", +}; + +static const std::string kClasses[] = { + "Vector3", + "Instance", + "Part", +}; + +struct ProtoToLuau +{ + struct Function + { + int loops = 0; + bool vararg = false; + }; + + std::string source; + std::vector functions; + bool types = false; + + ProtoToLuau() + { + Function top = {}; + top.vararg = true; + functions.push_back(top); + } + + void ident(const luau::Name& name) + { + if (name.has_builtin()) + { + size_t index = size_t(name.builtin()) % std::size(kNames); + source += kNames[index]; + } + else if (name.has_custom()) + { + source += 'n'; + source += std::to_string(name.custom() & 0xff); + } + else + { + source += '_'; + } + } + + void ident(const luau::Typename& name) + { + source += 't'; + source += std::to_string(name.index() & 0xff); + } + + void print(const luau::Expr& expr) + { + if (expr.has_group()) + print(expr.group()); + else if (expr.has_nil()) + print(expr.nil()); + else if (expr.has_bool_()) + print(expr.bool_()); + else if (expr.has_number()) + print(expr.number()); + else if (expr.has_string()) + print(expr.string()); + else if (expr.has_local()) + print(expr.local()); + else if (expr.has_global()) + print(expr.global()); + else if (expr.has_varargs()) + print(expr.varargs()); + else if (expr.has_call()) + print(expr.call()); + else if (expr.has_index_name()) + print(expr.index_name()); + else if (expr.has_index_expr()) + print(expr.index_expr()); + else if (expr.has_function()) + print(expr.function()); + else if (expr.has_table()) + print(expr.table()); + else if (expr.has_unary()) + print(expr.unary()); + else if (expr.has_binary()) + print(expr.binary()); + else + source += "_"; + } + + void print(const luau::ExprPrefix& expr) + { + if (expr.has_group()) + print(expr.group()); + else if (expr.has_local()) + print(expr.local()); + else if (expr.has_global()) + print(expr.global()); + else if (expr.has_call()) + print(expr.call()); + else if (expr.has_index_name()) + print(expr.index_name()); + else if (expr.has_index_expr()) + print(expr.index_expr()); + else + source += "_"; + } + + void print(const luau::ExprGroup& expr) + { + source += '('; + print(expr.expr()); + source += ')'; + } + + void print(const luau::ExprConstantNil& expr) + { + source += "nil"; + } + + void print(const luau::ExprConstantBool& expr) + { + source += expr.val() ? "true" : "false"; + } + + void print(const luau::ExprConstantNumber& expr) + { + source += std::to_string(expr.val()); + } + + void print(const luau::ExprConstantString& expr) + { + source += '"'; + for (char ch : expr.val()) + if (isalpha(ch)) + source += ch; + source += '"'; + } + + void print(const luau::Local& var) + { + source += 'l'; + source += std::to_string(var.name() & 0xff); + } + + void print(const luau::ExprLocal& expr) + { + print(expr.var()); + } + + void print(const luau::ExprGlobal& expr) + { + ident(expr.name()); + } + + void print(const luau::ExprVarargs& expr) + { + if (functions.back().vararg) + source += "..."; + else + source += "_"; + } + + void print(const luau::ExprCall& expr) + { + if (expr.func().has_index_name()) + print(expr.func().index_name(), expr.self()); + else + print(expr.func()); + source += '('; + for (int i = 0; i < expr.args_size(); ++i) + { + if (i != 0) + source += ','; + print(expr.args(i)); + } + source += ')'; + } + + void print(const luau::ExprIndexName& expr, bool self = false) + { + print(expr.expr()); + source += self ? ':' : '.'; + ident(expr.index()); + } + + void print(const luau::ExprIndexExpr& expr) + { + print(expr.expr()); + source += '['; + print(expr.index()); + source += ']'; + } + + void function(const luau::ExprFunction& expr) + { + source += "("; + for (int i = 0; i < expr.args_size(); ++i) + { + if (i != 0) + source += ','; + print(expr.args(i)); + + if (types && i < expr.types_size()) + { + source += ':'; + print(expr.types(i)); + } + } + if (expr.vararg()) + { + if (expr.args_size()) + source += ','; + source += "..."; + } + source += ')'; + if (types && expr.rettypes_size()) + { + source += ':'; + if (expr.rettypes_size() > 1) + source += '('; + for (size_t i = 0; i < expr.rettypes_size(); ++i) + { + if (i != 0) + source += ','; + print(expr.rettypes(i)); + } + if (expr.rettypes_size() > 1) + source += ')'; + } + source += '\n'; + + Function func = {}; + func.vararg = expr.vararg(); + functions.push_back(func); + + print(expr.body()); + + functions.pop_back(); + + source += "end"; + } + + void print(const luau::ExprFunction& expr) + { + source += "function"; + function(expr); + } + + void print(const luau::ExprTable& expr) + { + source += '{'; + for (int i = 0; i < expr.items_size(); ++i) + { + if (expr.items(i).has_key_name()) + { + ident(expr.items(i).key_name()); + source += '='; + } + else if (expr.items(i).has_key_expr()) + { + source += "["; + print(expr.items(i).key_expr()); + source += "]="; + } + + print(expr.items(i).value()); + source += ','; + } + source += '}'; + } + + void print(const luau::ExprUnary& expr) + { + if (expr.op() == luau::ExprUnary::Not) + source += "not "; + else if (expr.op() == luau::ExprUnary::Minus) + source += "- "; + else if (expr.op() == luau::ExprUnary::Len) + source += "# "; + + print(expr.expr()); + } + + void print(const luau::ExprBinary& expr) + { + print(expr.left()); + + if (expr.op() == luau::ExprBinary::Add) + source += " + "; + else if (expr.op() == luau::ExprBinary::Sub) + source += " - "; + else if (expr.op() == luau::ExprBinary::Mul) + source += " * "; + else if (expr.op() == luau::ExprBinary::Div) + source += " / "; + else if (expr.op() == luau::ExprBinary::Mod) + source += " % "; + else if (expr.op() == luau::ExprBinary::Pow) + source += " ^ "; + else if (expr.op() == luau::ExprBinary::Concat) + source += " .. "; + else if (expr.op() == luau::ExprBinary::CompareNe) + source += " ~= "; + else if (expr.op() == luau::ExprBinary::CompareEq) + source += " == "; + else if (expr.op() == luau::ExprBinary::CompareLt) + source += " < "; + else if (expr.op() == luau::ExprBinary::CompareLe) + source += " <= "; + else if (expr.op() == luau::ExprBinary::CompareGt) + source += " > "; + else if (expr.op() == luau::ExprBinary::CompareGe) + source += " >= "; + else if (expr.op() == luau::ExprBinary::And) + source += " and "; + else if (expr.op() == luau::ExprBinary::Or) + source += " or "; + + print(expr.right()); + } + + void print(const luau::LValue& expr) + { + if (expr.has_local()) + print(expr.local()); + else if (expr.has_global()) + print(expr.global()); + else if (expr.has_index_name()) + print(expr.index_name()); + else if (expr.has_index_expr()) + print(expr.index_expr()); + else + source += "_"; + } + + void print(const luau::Stat& stat) + { + if (stat.has_block()) + print(stat.block()); + else if (stat.has_if_()) + print(stat.if_()); + else if (stat.has_while_()) + print(stat.while_()); + else if (stat.has_repeat()) + print(stat.repeat()); + else if (stat.has_break_()) + print(stat.break_()); + else if (stat.has_continue_()) + print(stat.continue_()); + else if (stat.has_return_()) + print(stat.return_()); + else if (stat.has_call()) + print(stat.call()); + else if (stat.has_local()) + print(stat.local()); + else if (stat.has_for_()) + print(stat.for_()); + else if (stat.has_for_in()) + print(stat.for_in()); + else if (stat.has_assign()) + print(stat.assign()); + else if (stat.has_compound_assign()) + print(stat.compound_assign()); + else if (stat.has_function()) + print(stat.function()); + else if (stat.has_local_function()) + print(stat.local_function()); + else if (stat.has_type_alias()) + print(stat.type_alias()); + else + source += "do end\n"; + } + + void print(const luau::StatBlock& stat) + { + for (int i = 0; i < stat.body_size(); ++i) + { + if (stat.body(i).has_block()) + { + source += "do\n"; + print(stat.body(i)); + source += "end\n"; + } + else + { + print(stat.body(i)); + + // parser will reject code with break/continue/return being non-trailing statements in a block + if (stat.body(i).has_break_() || stat.body(i).has_continue_() || stat.body(i).has_return_()) + break; + } + } + } + + void print(const luau::StatIf& stat) + { + source += "if "; + print(stat.cond()); + source += " then\n"; + print(stat.then()); + + if (stat.has_else_()) + { + source += "else\n"; + print(stat.else_()); + source += "end\n"; + } + else if (stat.has_elseif()) + { + source += "else"; + print(stat.elseif()); + } + else + { + source += "end\n"; + } + } + + void print(const luau::StatWhile& stat) + { + source += "while "; + print(stat.cond()); + source += " do\n"; + + functions.back().loops++; + print(stat.body()); + functions.back().loops--; + + source += "end\n"; + } + + void print(const luau::StatRepeat& stat) + { + source += "repeat\n"; + + functions.back().loops++; + print(stat.body()); + functions.back().loops--; + + source += "until "; + print(stat.cond()); + source += "\n"; + } + + void print(const luau::StatBreak& stat) + { + if (functions.back().loops) + source += "break\n"; + else + source += "do end\n"; + } + + void print(const luau::StatContinue& stat) + { + if (functions.back().loops) + source += "continue\n"; + else + source += "do end\n"; + } + + void print(const luau::StatReturn& stat) + { + source += "return "; + for (int i = 0; i < stat.list_size(); ++i) + { + if (i != 0) + source += ','; + print(stat.list(i)); + } + source += "\n"; + } + + void print(const luau::StatCall& stat) + { + print(stat.expr()); + source += '\n'; + } + + void print(const luau::StatLocal& stat) + { + source += "local "; + + if (stat.vars_size() == 0) + source += '_'; + + for (int i = 0; i < stat.vars_size(); ++i) + { + if (i != 0) + source += ','; + print(stat.vars(i)); + + if (types && i < stat.types_size()) + { + source += ':'; + print(stat.types(i)); + } + } + + if (stat.values_size() != 0) + source += " = "; + + for (int i = 0; i < stat.values_size(); ++i) + { + if (i != 0) + source += ','; + print(stat.values(i)); + } + source += '\n'; + } + + void print(const luau::StatFor& stat) + { + source += "for "; + print(stat.var()); + source += '='; + print(stat.from()); + source += ','; + print(stat.to()); + if (stat.has_step()) + { + source += ','; + print(stat.step()); + } + source += " do\n"; + + functions.back().loops++; + print(stat.body()); + functions.back().loops--; + + source += "end\n"; + } + + void print(const luau::StatForIn& stat) + { + source += "for "; + + if (stat.vars_size() == 0) + source += '_'; + + for (int i = 0; i < stat.vars_size(); ++i) + { + if (i != 0) + source += ','; + print(stat.vars(i)); + } + + source += " in "; + + if (stat.values_size() == 0) + source += "..."; + + for (int i = 0; i < stat.values_size(); ++i) + { + if (i != 0) + source += ','; + print(stat.values(i)); + } + + source += " do\n"; + + functions.back().loops++; + print(stat.body()); + functions.back().loops--; + + source += "end\n"; + } + + void print(const luau::StatAssign& stat) + { + if (stat.vars_size() == 0) + source += '_'; + + for (int i = 0; i < stat.vars_size(); ++i) + { + if (i != 0) + source += ','; + print(stat.vars(i)); + } + + source += " = "; + + if (stat.values_size() == 0) + source += "nil"; + + for (int i = 0; i < stat.values_size(); ++i) + { + if (i != 0) + source += ','; + print(stat.values(i)); + } + source += '\n'; + } + + void print(const luau::StatCompoundAssign& stat) + { + print(stat.var()); + + if (stat.op() == luau::StatCompoundAssign::Add) + source += " += "; + else if (stat.op() == luau::StatCompoundAssign::Sub) + source += " -= "; + else if (stat.op() == luau::StatCompoundAssign::Mul) + source += " *= "; + else if (stat.op() == luau::StatCompoundAssign::Div) + source += " /= "; + else if (stat.op() == luau::StatCompoundAssign::Mod) + source += " %= "; + else if (stat.op() == luau::StatCompoundAssign::Pow) + source += " ^= "; + else if (stat.op() == luau::StatCompoundAssign::Concat) + source += " ..= "; + + print(stat.value()); + source += '\n'; + } + + void print(const luau::StatFunction& stat) + { + source += "function "; + if (stat.var().has_index_name()) + print(stat.var().index_name(), stat.self()); + else if (stat.var().has_index_expr()) + source += '_'; // function foo[bar]() is invalid syntax + else + print(stat.var()); + function(stat.func()); + source += '\n'; + } + + void print(const luau::StatLocalFunction& stat) + { + source += "local function "; + print(stat.var()); + function(stat.func()); + source += '\n'; + } + + void print(const luau::StatTypeAlias& stat) + { + source += "type "; + ident(stat.name()); + + if (stat.generics_size()) + { + source += '<'; + for (size_t i = 0; i < stat.generics_size(); ++i) + { + if (i != 0) + source += ','; + ident(stat.generics(i)); + } + source += '>'; + } + + source += " = "; + print(stat.type()); + source += '\n'; + } + + void print(const luau::Type& type) + { + if (type.has_primitive()) + print(type.primitive()); + else if (type.has_literal()) + print(type.literal()); + else if (type.has_table()) + print(type.table()); + else if (type.has_function()) + print(type.function()); + else if (type.has_typeof()) + print(type.typeof()); + else if (type.has_union_()) + print(type.union_()); + else if (type.has_intersection()) + print(type.intersection()); + else if (type.has_class_()) + print(type.class_()); + else if (type.has_ref()) + print(type.ref()); + else + source += "any"; + } + + void print(const luau::TypePrimitive& type) + { + size_t index = size_t(type.kind()) % std::size(kTypes); + source += kTypes[index]; + } + + void print(const luau::TypeLiteral& type) + { + ident(type.name()); + + if (type.generics_size()) + { + source += '<'; + for (size_t i = 0; i < type.generics_size(); ++i) + { + if (i != 0) + source += ','; + ident(type.generics(i)); + } + source += '>'; + } + } + + void print(const luau::TypeTable& type) + { + source += '{'; + for (size_t i = 0; i < type.items_size(); ++i) + { + ident(type.items(i).key()); + source += ':'; + print(type.items(i).type()); + source += ','; + } + if (type.has_indexer()) + { + source += '['; + print(type.indexer().key()); + source += "]:"; + print(type.indexer().value()); + } + source += '}'; + } + + void print(const luau::TypeFunction& type) + { + source += '('; + for (size_t i = 0; i < type.args_size(); ++i) + { + if (i != 0) + source += ','; + print(type.args(i)); + } + source += ")->"; + if (type.rets_size() != 1) + source += '('; + for (size_t i = 0; i < type.rets_size(); ++i) + { + if (i != 0) + source += ','; + print(type.rets(i)); + } + if (type.rets_size() != 1) + source += ')'; + } + + void print(const luau::TypeTypeof& type) + { + source += "typeof("; + print(type.expr()); + source += ')'; + } + + void print(const luau::TypeUnion& type) + { + source += '('; + print(type.left()); + source += ")|("; + print(type.right()); + source += ')'; + } + + void print(const luau::TypeIntersection& type) + { + source += '('; + print(type.left()); + source += ")&("; + print(type.right()); + source += ')'; + } + + void print(const luau::TypeClass& type) + { + size_t index = size_t(type.kind()) % std::size(kClasses); + source += kClasses[index]; + } + + void print(const luau::TypeRef& type) + { + print(type.prefix()); + source += '.'; + ident(type.index()); + } +}; + +std::string protoprint(const luau::StatBlock& stat, bool types) +{ + ProtoToLuau printer; + printer.types = types; + printer.print(stat); + return printer.source; +} diff --git a/fuzz/prototest.cpp b/fuzz/prototest.cpp new file mode 100644 index 0000000..804e708 --- /dev/null +++ b/fuzz/prototest.cpp @@ -0,0 +1,12 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "src/libfuzzer/libfuzzer_macro.h" +#include "luau.pb.h" + +std::string protoprint(const luau::StatBlock& stat, bool types); + +DEFINE_PROTO_FUZZER(const luau::StatBlock& message) +{ + std::string source = protoprint(message, true); + + printf("%s\n", source.c_str()); +} diff --git a/fuzz/syntax.dict b/fuzz/syntax.dict new file mode 100644 index 0000000..3e87680 --- /dev/null +++ b/fuzz/syntax.dict @@ -0,0 +1,21 @@ +"and" +"break" +"do" +"else" +"elseif" +"end" +"false" +"for" +"function" +"if" +"in" +"local" +"nil" +"not" +"or" +"repeat" +"return" +"then" +"true" +"until" +"while" diff --git a/fuzz/transpiler.cpp b/fuzz/transpiler.cpp new file mode 100644 index 0000000..ccc1a4f --- /dev/null +++ b/fuzz/transpiler.cpp @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include +#include "Luau/Transpiler.h" +#include "Luau/Common.h" + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + Luau::transpile(std::string_view(reinterpret_cast(Data), Size)); + return 0; +} diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp new file mode 100644 index 0000000..5020c77 --- /dev/null +++ b/fuzz/typeck.cpp @@ -0,0 +1,52 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include +#include "Luau/TypeInfer.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Common.h" + +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + FInt::LuauTypeInferRecursionLimit.value = 100; + FInt::LuauTypeInferTypePackLoopLimit.value = 100; + + Luau::ParseOptions options; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + Luau::ParseResult parseResult = Luau::Parser::parse(reinterpret_cast(Data), Size, names, allocator, options); + + // "static" here is to accelerate fuzzing process by only creating and populating the type environment once + static Luau::NullModuleResolver moduleResolver; + static Luau::InternalErrorReporter iceHandler; + static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); + static int once = (Luau::registerBuiltinTypes(sharedEnv), 1); + (void)once; + static int once2 = (Luau::freeze(sharedEnv.globalTypes), 1); + (void)once2; + + if (parseResult.errors.empty()) + { + Luau::SourceModule module; + module.root = parseResult.root; + module.mode = Luau::Mode::Nonstrict; + + Luau::TypeChecker typeck(&moduleResolver, &iceHandler); + typeck.globalScope = sharedEnv.globalScope; + + try + { + typeck.check(module, Luau::Mode::Nonstrict); + } + catch (std::exception&) + { + // This catches internal errors that the type checker currently (unfortunately) throws in some cases + } + } + + return 0; +} diff --git a/lua_LICENSE.txt b/lua_LICENSE.txt new file mode 100644 index 0000000..0754dfd --- /dev/null +++ b/lua_LICENSE.txt @@ -0,0 +1,19 @@ +Copyright © 1994–2019 Lua.org, PUC-Rio. + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp new file mode 100644 index 0000000..dd49e67 --- /dev/null +++ b/tests/AstQuery.test.cpp @@ -0,0 +1,81 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "Luau/AstQuery.h" + +#include "doctest.h" + +using namespace Luau; + +struct DocumentationSymbolFixture : Fixture +{ + std::optional getDocSymbol(const std::string& source, Position position) + { + check(source); + + SourceModule* sourceModule = getMainSourceModule(); + ModulePtr module = getMainModule(); + + return getDocumentationSymbolAtPosition(*sourceModule, *module, position); + } +}; + +TEST_SUITE_BEGIN("AstQuery::getDocumentationSymbolAtPosition"); + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "binding") +{ + std::optional global = getDocSymbol(R"( + local a = string.sub() + )", + Position(1, 21)); + + CHECK_EQ(global, "@luau/global/string"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") +{ + std::optional substring = getDocSymbol(R"( + local a = string.sub() + )", + Position(1, 27)); + + CHECK_EQ(substring, "@luau/global/string.sub"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") +{ + ScopedFastFlag sffs[] = { + {"LuauDontMutatePersistentFunctions", true}, + {"LuauPersistDefinitionFileTypes", true}, + }; + + loadDefinition(R"( + declare function Connect(fn: (string) -> ()) + )"); + + std::optional substring = getDocSymbol(R"( + Connect(function(abc) + end) + )", + Position(1, 27)); + + CHECK_EQ(substring, "@test/global/Connect/param/0/param/0"); +} + +TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") +{ + ScopedFastFlag sffs{"LuauStoreMatchingOverloadFnType", true}; + + loadDefinition(R"( + declare foo: ((string) -> number) & ((number) -> string) + )"); + + std::optional symbol = getDocSymbol(R"( + foo("asdf") + )", + Position(1, 10)); + + CHECK_EQ(symbol, "@test/global/foo/overload/(string) -> number"); +} + +TEST_SUITE_END(); diff --git a/tests/AstVisitor.test.cpp b/tests/AstVisitor.test.cpp new file mode 100644 index 0000000..35bad57 --- /dev/null +++ b/tests/AstVisitor.test.cpp @@ -0,0 +1,117 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "Luau/Ast.h" + +#include "doctest.h" + +using namespace Luau; + +namespace +{ + +class AstVisitorTracking : public AstVisitor +{ +private: + std::vector visitedNodes; + std::set seen; + +public: + bool visit(AstNode* n) override + { + visitedNodes.push_back(n); + return true; + } + + AstNode* operator[](size_t index) + { + REQUIRE(index < visitedNodes.size()); + + seen.insert(index); + return visitedNodes[index]; + } + + ~AstVisitorTracking() + { + std::string s = "Seen " + std::to_string(seen.size()) + " nodes but got " + std::to_string(visitedNodes.size()); + CHECK_MESSAGE(seen.size() == visitedNodes.size(), s); + } +}; + +class AstTypeVisitorTrackingWiths : public AstVisitorTracking +{ +public: + using AstVisitorTracking::visit; + bool visit(AstType* n) override + { + return visit((AstNode*)n); + } +}; + +} // namespace + +TEST_SUITE_BEGIN("AstVisitorTest"); + +TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsAreNotVisited") +{ + AstStatBlock* block = parse(R"( + local a: A + )"); + + AstVisitorTracking v; + block->visit(&v); + + CHECK(v[0]->is()); + CHECK(v[1]->is()); + // We should not have v[2] that points to the annotation + // We should not have v[3] that points to the type argument 'number' in A. +} + +TEST_CASE_FIXTURE(Fixture, "LocalTwoBindings") +{ + AstStatBlock* block = parse(R"( + local a, b + )"); + + AstVisitorTracking v; + block->visit(&v); + + CHECK(v[0]->is()); + CHECK(v[1]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "LocalTwoAnnotatedBindings") +{ + AstStatBlock* block = parse(R"( + local a: A, b: B + )"); + + AstTypeVisitorTrackingWiths v; + block->visit(&v); + + CHECK(v[0]->is()); + CHECK(v[1]->is()); + CHECK(v[2]->is()); + CHECK(v[3]->is()); + CHECK(v[4]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "LocalTwoAnnotatedBindingsWithTwoValues") +{ + AstStatBlock* block = parse(R"( + local a: A, b: B = 1, 2 + )"); + + AstTypeVisitorTrackingWiths v; + block->visit(&v); + + CHECK(v[0]->is()); + CHECK(v[1]->is()); + CHECK(v[2]->is()); + CHECK(v[3]->is()); + CHECK(v[4]->is()); + CHECK(v[5]->is()); + CHECK(v[6]->is()); +} + +TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp new file mode 100644 index 0000000..9cd642c --- /dev/null +++ b/tests/Autocomplete.test.cpp @@ -0,0 +1,2576 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Autocomplete.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" +#include "Luau/StringUtils.h" + +#include "Fixture.h" + +#include "doctest.h" + +#include + +LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) +LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) + +using namespace Luau; + +static std::optional nullCallback(std::string tag, std::optional ptr) +{ + return std::nullopt; +} + +struct ACFixture : Fixture +{ + AutocompleteResult autocomplete(unsigned row, unsigned column) + { + return Luau::autocomplete(frontend, "MainModule", Position{row, column}, nullCallback); + } + + AutocompleteResult autocomplete(char marker) + { + auto i = markerPosition.find(marker); + LUAU_ASSERT(i != markerPosition.end()); + const Position& pos = i->second; + return Luau::autocomplete(frontend, "MainModule", pos, nullCallback); + } + + CheckResult check(const std::string& source) + { + markerPosition.clear(); + std::string filteredSource; + filteredSource.reserve(source.size()); + + Position curPos(0, 0); + for (char c : source) + { + if (c == '@' && !filteredSource.empty()) + { + char prevChar = filteredSource.back(); + filteredSource.pop_back(); + curPos.column--; // Adjust column position since we removed a character from the output + LUAU_ASSERT("Illegal marker character" && prevChar >= '0' && prevChar <= '9'); + LUAU_ASSERT("Duplicate marker found" && markerPosition.count(prevChar) == 0); + markerPosition.insert(std::pair{prevChar, curPos}); + } + else + { + filteredSource.push_back(c); + if (c == '\n') + { + curPos.line++; + curPos.column = 0; + } + else + { + curPos.column++; + } + } + } + + return Fixture::check(filteredSource); + } + + // Maps a marker character (0-9 inclusive) to a position in the source code. + std::map markerPosition; +}; + +TEST_SUITE_BEGIN("AutocompleteTest"); + +TEST_CASE_FIXTURE(ACFixture, "empty_program") +{ + check(" "); + + auto ac = autocomplete(0, 1); + + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "local_initializer") +{ + check("local a = "); + + auto ac = autocomplete(0, 10); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "leave_numbers_alone") +{ + check("local a = 3.1"); + + auto ac = autocomplete(0, 12); + CHECK(ac.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "user_defined_globals") +{ + check("local myLocal = 4; "); + + auto ac = autocomplete(0, 19); + + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") +{ + check(R"( + local myLocal = 4 + function abc() + local myInnerLocal = 1 + + end + )"); + + auto ac = autocomplete(3, 0); + CHECK(ac.entryMap.count("myLocal")); + CHECK(!ac.entryMap.count("myInnerLocal")); + + ac = autocomplete(4, 0); + CHECK(ac.entryMap.count("myLocal")); + CHECK(ac.entryMap.count("myInnerLocal")); + + ac = autocomplete(6, 0); + CHECK(ac.entryMap.count("myLocal")); + CHECK(!ac.entryMap.count("myInnerLocal")); +} + +TEST_CASE_FIXTURE(ACFixture, "recursive_function") +{ + check(R"( + function foo() + end + )"); + + auto ac = autocomplete(2, 0); + CHECK(ac.entryMap.count("foo")); +} + +TEST_CASE_FIXTURE(ACFixture, "nested_recursive_function") +{ + check(R"( + local function outer() + local function inner() + end + end + )"); + + auto ac = autocomplete(3, 0); + CHECK(ac.entryMap.count("inner")); + CHECK(ac.entryMap.count("outer")); +} + +TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") +{ + check(R"( + local function abc() + + end + )"); + + auto ac = autocomplete(2, 0); + + CHECK(ac.entryMap.count("abc")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); + + check(R"( + local abc = function() + + end + )"); + + ac = autocomplete(2, 0); + + CHECK(ac.entryMap.count("abc")); // FIXME: This is actually incorrect! + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "global_functions_are_not_scoped_lexically") +{ + check(R"( + if true then + function abc() + + end + end + )"); + + auto ac = autocomplete(6, 0); + + CHECK(!ac.entryMap.empty()); + CHECK(ac.entryMap.count("abc")); + CHECK(ac.entryMap.count("table")); + CHECK(ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") +{ + check(R"( + if true then + local function abc() + + end + end + )"); + + auto ac = autocomplete(6, 0); + + CHECK_NE(0, ac.entryMap.size()); + CHECK(!ac.entryMap.count("abc")); +} + +TEST_CASE_FIXTURE(ACFixture, "function_parameters") +{ + check(R"( + function abc(test) + + end + )"); + + auto ac = autocomplete(3, 0); + + CHECK(ac.entryMap.count("test")); +} + +TEST_CASE_FIXTURE(ACFixture, "get_member_completions") +{ + check(R"( + local a = table. -- Line 1 + -- | Column 23 + )"); + + auto ac = autocomplete(1, 24); + + CHECK_EQ(16, ac.entryMap.size()); + CHECK(ac.entryMap.count("find")); + CHECK(ac.entryMap.count("pack")); + CHECK(!ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "nested_member_completions") +{ + check(R"( + local tbl = { abc = { def = 1234, egh = false } } + tbl.abc. + )"); + + auto ac = autocomplete(2, 17); + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("def")); + CHECK(ac.entryMap.count("egh")); +} + +TEST_CASE_FIXTURE(ACFixture, "unsealed_table") +{ + check(R"( + local tbl = {} + tbl.prop = 5 + tbl. + )"); + + auto ac = autocomplete(3, 12); + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("prop")); +} + +TEST_CASE_FIXTURE(ACFixture, "unsealed_table_2") +{ + check(R"( + local tbl = {} + local inner = { prop = 5 } + tbl.inner = inner + tbl.inner. + )"); + + auto ac = autocomplete(4, 19); + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("prop")); +} + +TEST_CASE_FIXTURE(ACFixture, "cyclic_table") +{ + check(R"( + local abc = {} + local def = { abc = abc } + abc.def = def + abc.def. + )"); + + auto ac = autocomplete(4, 17); + CHECK(ac.entryMap.count("abc")); +} + +TEST_CASE_FIXTURE(ACFixture, "table_union") +{ + check(R"( + type t1 = { a1 : string, b2 : number } + type t2 = { b2 : string, c3 : string } + function func(abc : t1 | t2) + abc. + end + )"); + + auto ac = autocomplete(4, 18); + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("b2")); +} + +TEST_CASE_FIXTURE(ACFixture, "table_intersection") +{ + check(R"( + type t1 = { a1 : string, b2 : number } + type t2 = { b2 : string, c3 : string } + function func(abc : t1 & t2) + abc. + end + )"); + + auto ac = autocomplete(4, 18); + CHECK_EQ(3, ac.entryMap.size()); + CHECK(ac.entryMap.count("a1")); + CHECK(ac.entryMap.count("b2")); + CHECK(ac.entryMap.count("c3")); +} + +TEST_CASE_FIXTURE(ACFixture, "get_string_completions") +{ + check(R"( + local a = ("foo"): -- Line 1 + -- | Column 26 + )"); + + auto ac = autocomplete(1, 26); + + CHECK_EQ(17, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") +{ + check(""); + + auto ac = autocomplete(0, 0); + + CHECK_NE(0, ac.entryMap.size()); + + CHECK(ac.entryMap.count("table")); +} + +TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_the_very_start_of_the_script") +{ + check(R"( + + function aaa() end + )"); + + auto ac = autocomplete(0, 0); + + CHECK(ac.entryMap.count("table")); +} + +TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") +{ + check(R"( + local game = { GetService=function(s) return 'hello' end } + + function a() + game: + end + )"); + + auto ac = autocomplete(4, 19); + + CHECK_NE(0, ac.entryMap.size()); + + CHECK(!ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") +{ + check(R"( + if table: + )"); + + auto ac = autocomplete(1, 19); + + CHECK_NE(0, ac.entryMap.size()); + CHECK(ac.entryMap.count("concat")); + CHECK(!ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "statement_between_two_statements") +{ + check(R"( + function getmyscripts() end + + g + + getmyscripts() + )"); + + auto ac = autocomplete(3, 9); + + CHECK_NE(0, ac.entryMap.size()); + + CHECK(ac.entryMap.count("getmyscripts")); +} + +TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") +{ + check(R"( + local A = {one=1} + + function B() + local A = {two=2} + + A + end + )"); + + auto ac = autocomplete(6, 15); + + CHECK(ac.entryMap.count("A")); + + TypeId t = follow(*ac.entryMap["A"].type); + const TableTypeVar* tt = get(t); + REQUIRE(tt); + + CHECK(tt->props.count("two")); +} + +TEST_CASE_FIXTURE(ACFixture, "recommend_statement_starting_keywords") +{ + check(""); + auto ac = autocomplete(0, 0); + CHECK(ac.entryMap.count("local")); + + check("local i = "); + auto ac2 = autocomplete(0, 10); + CHECK(!ac2.entryMap.count("local")); +} + +TEST_CASE_FIXTURE(ACFixture, "do_not_overwrite_context_sensitive_kws") +{ + check(R"( + local function continue() + end + + + )"); + + auto ac = autocomplete(5, 0); + + AutocompleteEntry entry = ac.entryMap["continue"]; + CHECK(entry.kind == AutocompleteEntryKind::Binding); +} + +TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") +{ + check(R"( + --!strict + local foo = {} + function foo:bar() end + + --[[ + foo: + ]] + )"); + + auto ac = autocomplete(6, 16); + + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comment") +{ + check(R"( + --!strict + )"); + + auto ac = autocomplete(1, 17); + + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment") +{ + ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; + + check(R"( + --[[ + )"); + + auto ac = autocomplete(1, 13); + + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file") +{ + ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; + + check("--[["); + + auto ac = autocomplete(0, 4); + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") +{ + check(R"( + for x = + )"); + + auto ac1 = autocomplete(1, 14); + CHECK_EQ(ac1.entryMap.count("do"), 0); + CHECK_EQ(ac1.entryMap.count("end"), 0); + + check(R"( + for x = 1 + )"); + + auto ac2 = autocomplete(1, 15); + CHECK_EQ(ac2.entryMap.count("do"), 0); + CHECK_EQ(ac2.entryMap.count("end"), 0); + + check(R"( + for x = 1, 2 + )"); + + auto ac3 = autocomplete(1, 18); + CHECK_EQ(1, ac3.entryMap.size()); + CHECK_EQ(ac3.entryMap.count("do"), 1); + + check(R"( + for x = 1, 2, + )"); + + auto ac4 = autocomplete(1, 19); + CHECK_EQ(ac4.entryMap.count("do"), 0); + CHECK_EQ(ac4.entryMap.count("end"), 0); + + check(R"( + for x = 1, 2, 5 + )"); + + auto ac5 = autocomplete(1, 22); + CHECK_EQ(ac5.entryMap.count("do"), 1); + CHECK_EQ(ac5.entryMap.count("end"), 0); + + check(R"( + for x = 1, 2, 5 f + )"); + + auto ac6 = autocomplete(1, 25); + CHECK_EQ(ac6.entryMap.size(), 1); + CHECK_EQ(ac6.entryMap.count("do"), 1); + + check(R"( + for x = 1, 2, 5 do + )"); + + auto ac7 = autocomplete(1, 32); + CHECK_EQ(ac7.entryMap.count("end"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") +{ + check(R"( + for + )"); + + auto ac1 = autocomplete(1, 12); + CHECK_EQ(0, ac1.entryMap.size()); + + check(R"( + for x + )"); + + auto ac2 = autocomplete(1, 13); + CHECK_EQ(0, ac2.entryMap.size()); + + auto ac2a = autocomplete(1, 14); + CHECK_EQ(1, ac2a.entryMap.size()); + CHECK_EQ(1, ac2a.entryMap.count("in")); + + check(R"( + for x in y + )"); + + auto ac3 = autocomplete(1, 18); + CHECK_EQ(ac3.entryMap.count("table"), 1); + CHECK_EQ(ac3.entryMap.count("do"), 0); + + check(R"( + for x in y + )"); + + auto ac4 = autocomplete(1, 19); + CHECK_EQ(ac4.entryMap.size(), 1); + CHECK_EQ(ac4.entryMap.count("do"), 1); + + check(R"( + for x in f f + )"); + + auto ac5 = autocomplete(1, 20); + CHECK_EQ(ac5.entryMap.size(), 1); + CHECK_EQ(ac5.entryMap.count("do"), 1); + + check(R"( + for x in y do + )"); + + auto ac6 = autocomplete(1, 23); + CHECK_EQ(ac6.entryMap.count("in"), 0); + CHECK_EQ(ac6.entryMap.count("table"), 1); + CHECK_EQ(ac6.entryMap.count("end"), 1); + CHECK_EQ(ac6.entryMap.count("function"), 1); + + check(R"( + for x in y do e + )"); + + auto ac7 = autocomplete(1, 23); + CHECK_EQ(ac7.entryMap.count("in"), 0); + CHECK_EQ(ac7.entryMap.count("table"), 1); + CHECK_EQ(ac7.entryMap.count("end"), 1); + CHECK_EQ(ac7.entryMap.count("function"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") +{ + check(R"( + while + )"); + + auto ac1 = autocomplete(1, 13); + CHECK_EQ(ac1.entryMap.count("do"), 0); + CHECK_EQ(ac1.entryMap.count("end"), 0); + + check(R"( + while true + )"); + + auto ac2 = autocomplete(1, 19); + CHECK_EQ(1, ac2.entryMap.size()); + CHECK_EQ(ac2.entryMap.count("do"), 1); + + check(R"( + while true do + )"); + + auto ac3 = autocomplete(1, 23); + CHECK_EQ(ac3.entryMap.count("end"), 1); + + check(R"( + while true d + )"); + + auto ac4 = autocomplete(1, 20); + CHECK_EQ(1, ac4.entryMap.size()); + CHECK_EQ(ac4.entryMap.count("do"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") +{ + check(R"( + if + )"); + + auto ac1 = autocomplete(1, 13); + CHECK_EQ(ac1.entryMap.count("then"), 0); + CHECK_EQ(ac1.entryMap.count("function"), + 1); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. + CHECK_EQ(ac1.entryMap.count("table"), 1); + CHECK_EQ(ac1.entryMap.count("else"), 0); + CHECK_EQ(ac1.entryMap.count("elseif"), 0); + CHECK_EQ(ac1.entryMap.count("end"), 0); + + check(R"( + if x + )"); + + auto ac2 = autocomplete(1, 14); + CHECK_EQ(ac2.entryMap.count("then"), 1); + CHECK_EQ(ac2.entryMap.count("function"), 0); + CHECK_EQ(ac2.entryMap.count("else"), 0); + CHECK_EQ(ac2.entryMap.count("elseif"), 0); + CHECK_EQ(ac2.entryMap.count("end"), 0); + + check(R"( + if x t + )"); + + auto ac3 = autocomplete(1, 14); + CHECK_EQ(1, ac3.entryMap.size()); + CHECK_EQ(ac3.entryMap.count("then"), 1); + + check(R"( + if x then + + end + )"); + + auto ac4 = autocomplete(2, 0); + CHECK_EQ(ac4.entryMap.count("then"), 0); + CHECK_EQ(ac4.entryMap.count("else"), 1); + CHECK_EQ(ac4.entryMap.count("function"), 1); + CHECK_EQ(ac4.entryMap.count("elseif"), 1); + CHECK_EQ(ac4.entryMap.count("end"), 0); + + check(R"( + if x then + t + end + )"); + + auto ac4a = autocomplete(2, 13); + CHECK_EQ(ac4a.entryMap.count("then"), 0); + CHECK_EQ(ac4a.entryMap.count("table"), 1); + CHECK_EQ(ac4a.entryMap.count("else"), 1); + CHECK_EQ(ac4a.entryMap.count("elseif"), 1); + + check(R"( + if x then + + elseif x then + end + )"); + + auto ac5 = autocomplete(2, 0); + CHECK_EQ(ac5.entryMap.count("then"), 0); + CHECK_EQ(ac5.entryMap.count("function"), 1); + CHECK_EQ(ac5.entryMap.count("else"), 0); + CHECK_EQ(ac5.entryMap.count("elseif"), 0); + CHECK_EQ(ac5.entryMap.count("end"), 0); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_in_repeat") +{ + check(R"( + repeat + )"); + + auto ac = autocomplete(1, 16); + CHECK_EQ(ac.entryMap.count("table"), 1); + CHECK_EQ(ac.entryMap.count("until"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_expression") +{ + check(R"( + repeat + until + )"); + + auto ac = autocomplete(2, 16); + CHECK_EQ(ac.entryMap.count("table"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "local_names") +{ + check(R"( + local ab + )"); + + auto ac1 = autocomplete(1, 16); + CHECK_EQ(ac1.entryMap.size(), 1); + CHECK_EQ(ac1.entryMap.count("function"), 1); + + check(R"( + local ab, cd + )"); + + auto ac2 = autocomplete(1, 20); + CHECK(ac2.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_fn_exprs") +{ + check(R"( + local function f() + )"); + + auto ac = autocomplete(1, 28); + CHECK_EQ(ac.entryMap.count("end"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") +{ + check(R"( + local a = function() local bar = foo en + )"); + + auto ac = autocomplete(1, 47); + CHECK_EQ(ac.entryMap.count("end"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") +{ + check(R"( + repeat + for x + )"); + + auto ac1 = autocomplete(2, 18); + CHECK_EQ(ac1.entryMap.count("in"), 1); + CHECK_EQ(ac1.entryMap.count("until"), 0); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_repeat_middle_keyword") +{ + check(R"( + repeat + )"); + + auto ac1 = autocomplete(1, 15); + CHECK_EQ(ac1.entryMap.count("do"), 1); + CHECK_EQ(ac1.entryMap.count("function"), 1); + CHECK_EQ(ac1.entryMap.count("until"), 1); + + check(R"( + repeat f f + )"); + + auto ac2 = autocomplete(1, 18); + CHECK_EQ(ac2.entryMap.count("function"), 1); + CHECK_EQ(ac2.entryMap.count("until"), 1); + + check(R"( + repeat + u + until + )"); + + auto ac3 = autocomplete(2, 13); + CHECK_EQ(ac3.entryMap.count("until"), 0); +} + +TEST_CASE_FIXTURE(ACFixture, "local_function") +{ + check(R"( + local f + )"); + + auto ac1 = autocomplete(1, 15); + CHECK_EQ(ac1.entryMap.size(), 1); + CHECK_EQ(ac1.entryMap.count("function"), 1); + + check(R"( + local f, cd + )"); + + auto ac2 = autocomplete(1, 15); + CHECK(ac2.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "local_function") +{ + check(R"( + local function + )"); + + auto ac = autocomplete(1, 23); + CHECK(ac.entryMap.empty()); + + check(R"( + local function s + )"); + + ac = autocomplete(1, 23); + CHECK(ac.entryMap.empty()); + + ac = autocomplete(1, 24); + CHECK(ac.entryMap.empty()); + + check(R"( + local function () + )"); + + ac = autocomplete(1, 23); + CHECK(ac.entryMap.empty()); + + ac = autocomplete(1, 25); + CHECK(ac.entryMap.count("end")); + + check(R"( + local function something + )"); + + ac = autocomplete(1, 32); + CHECK(ac.entryMap.empty()); + + check(R"( + local tbl = {} + function tbl.something() end + )"); + + ac = autocomplete(2, 30); + CHECK(ac.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "local_function_params") +{ + check(R"( + local function abc(def) + )"); + + CHECK(autocomplete(1, 23).entryMap.empty()); + CHECK(autocomplete(1, 24).entryMap.empty()); + CHECK(autocomplete(1, 27).entryMap.empty()); + CHECK(autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete(1, 31).entryMap.empty()); + + CHECK(!autocomplete(1, 32).entryMap.empty()); + + check(R"( + local function abc(def) + end + )"); + + for (unsigned int i = 23; i < 31; ++i) + { + CHECK(autocomplete(1, i).entryMap.empty()); + } + CHECK(!autocomplete(1, 32).entryMap.empty()); + + auto ac2 = autocomplete(2, 0); + CHECK_EQ(ac2.entryMap.count("abc"), 1); + CHECK_EQ(ac2.entryMap.count("def"), 1); + + check(R"( + local function abc(def, ghi) + end + )"); + + auto ac3 = autocomplete(1, 35); + CHECK(ac3.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "global_function_params") +{ + check(R"( + function abc(def) + )"); + + for (unsigned int i = 17; i < 25; ++i) + { + CHECK(autocomplete(1, i).entryMap.empty()); + } + CHECK(!autocomplete(1, 26).entryMap.empty()); + + check(R"( + function abc(def) + end + )"); + + for (unsigned int i = 17; i < 25; ++i) + { + CHECK(autocomplete(1, i).entryMap.empty()); + } + CHECK(!autocomplete(1, 26).entryMap.empty()); + + check(R"( + function abc(def) + + end + )"); + + auto ac2 = autocomplete(2, 0); + CHECK_EQ(ac2.entryMap.count("abc"), 1); + CHECK_EQ(ac2.entryMap.count("def"), 1); + + check(R"( + function abc(def, ghi) + end + )"); + + auto ac3 = autocomplete(1, 29); + CHECK(ac3.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "arguments_to_global_lambda") +{ + check(R"( + abc = function(def, ghi) + end + )"); + + auto ac = autocomplete(1, 31); + CHECK(ac.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "function_expr_params") +{ + check(R"( + abc = function(def) + )"); + + for (unsigned int i = 20; i < 27; ++i) + { + CHECK(autocomplete(1, i).entryMap.empty()); + } + CHECK(!autocomplete(1, 28).entryMap.empty()); + + check(R"( + abc = function(def) + end + )"); + + for (unsigned int i = 20; i < 27; ++i) + { + CHECK(autocomplete(1, i).entryMap.empty()); + } + CHECK(!autocomplete(1, 28).entryMap.empty()); + + check(R"( + abc = function(def) + + end + )"); + + auto ac2 = autocomplete(2, 0); + CHECK_EQ(ac2.entryMap.count("def"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "local_initializer") +{ + check(R"( + local a = t + )"); + + auto ac = autocomplete(1, 19); + CHECK_EQ(ac.entryMap.count("table"), 1); + CHECK_EQ(ac.entryMap.count("true"), 1); +} + +TEST_CASE_FIXTURE(ACFixture, "local_initializer_2") +{ + check(R"( + local a= + )"); + + auto ac = autocomplete(1, 16); + CHECK(ac.entryMap.count("table")); +} + +TEST_CASE_FIXTURE(ACFixture, "get_member_completions") +{ + check(R"( + local a = 12.3 + )"); + + auto ac = autocomplete(1, 21); + CHECK(ac.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "sometimes_the_metatable_is_an_error") +{ + check(R"( + local T = {} + T.__index = T + + function T.new() + return setmetatable({x=6}, X) -- oops! + end + local t = T.new() + t. + )"); + + autocomplete(8, 12); + // Don't crash! +} + +TEST_CASE_FIXTURE(ACFixture, "local_types_builtin") +{ + check(R"( +local a: n +local b: string = "don't trip" + )"); + + auto ac = autocomplete(1, 10); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); +} + +TEST_CASE_FIXTURE(ACFixture, "private_types") +{ + check(R"( +do + type num = number + local a: nu + local b: num +end +local a: nu + )"); + + auto ac = autocomplete(3, 14); + + CHECK(ac.entryMap.count("num")); + CHECK(ac.entryMap.count("number")); + + ac = autocomplete(4, 15); + + CHECK(ac.entryMap.count("num")); + CHECK(ac.entryMap.count("number")); + + ac = autocomplete(6, 11); + + CHECK(!ac.entryMap.count("num")); + CHECK(ac.entryMap.count("number")); +} + +TEST_CASE_FIXTURE(ACFixture, "type_scoping_easy") +{ + check(R"( +type Table = { a: number, b: number } +do + type Table = { x: string, y: string } + local a: T +end + )"); + + auto ac = autocomplete(4, 14); + + REQUIRE(ac.entryMap.count("Table")); + REQUIRE(ac.entryMap["Table"].type); + const TableTypeVar* tv = get(follow(*ac.entryMap["Table"].type)); + REQUIRE(tv); + CHECK(tv->props.count("x")); +} + +TEST_CASE_FIXTURE(ACFixture, "modules_with_types") +{ + fileResolver.source["Module/A"] = R"( +export type A = { x: number, y: number } +export type B = { z: number, w: number } +return {} + )"; + + LUAU_REQUIRE_NO_ERRORS(frontend.check("Module/A")); + + fileResolver.source["Module/B"] = R"( +local aaa = require(script.Parent.A) +local a: aa + )"; + + frontend.check("Module/B"); + + auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 11}, nullCallback); + + CHECK(ac.entryMap.count("aaa")); +} + +TEST_CASE_FIXTURE(ACFixture, "module_type_members") +{ + fileResolver.source["Module/A"] = R"( +export type A = { x: number, y: number } +export type B = { z: number, w: number } +return {} + )"; + + LUAU_REQUIRE_NO_ERRORS(frontend.check("Module/A")); + + fileResolver.source["Module/B"] = R"( +local aaa = require(script.Parent.A) +local a: aaa. + )"; + + frontend.check("Module/B"); + + auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 13}, nullCallback); + + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("A")); + CHECK(ac.entryMap.count("B")); +} + +TEST_CASE_FIXTURE(ACFixture, "argument_types") +{ + check(R"( +local function f(a: n +local b: string = "don't trip" + )"); + + auto ac = autocomplete(1, 21); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); +} + +TEST_CASE_FIXTURE(ACFixture, "return_types") +{ + check(R"( +local function f(a: number): n +local b: string = "don't trip" + )"); + + auto ac = autocomplete(1, 30); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); +} + +TEST_CASE_FIXTURE(ACFixture, "as_types") +{ + check(R"( +local a: any = 5 +local b: number = (a :: n + )"); + + auto ac = autocomplete(2, 25); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); +} + +TEST_CASE_FIXTURE(ACFixture, "function_type_types") +{ + check(R"( +local a: (n +local b: (number, (n +local c: (number, (number) -> n +local d: (number, (number) -> (number, n +local e: (n: n + )"); + + auto ac = autocomplete(1, 11); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); + + ac = autocomplete(2, 20); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); + + ac = autocomplete(3, 31); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); + + ac = autocomplete(4, 40); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); + + ac = autocomplete(5, 14); + + CHECK(ac.entryMap.count("nil")); + CHECK(ac.entryMap.count("number")); +} + +TEST_CASE_FIXTURE(ACFixture, "generic_types") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); + + check(R"( +function f(a: T +local b: string = "don't trip" + )"); + + auto ac = autocomplete(1, 25); + + CHECK(ac.entryMap.count("Tee")); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_argument") +{ + // local + check(R"( +local function target(a: number, b: string) return a + #b end + +local one = 4 +local two = "hello" +return target(o + )"); + + auto ac = autocomplete(5, 15); + + CHECK(ac.entryMap.count("one")); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::None); + + check(R"( +local function target(a: number, b: string) return a + #b end + +local one = 4 +local two = "hello" +return target(one, t + )"); + + ac = autocomplete(5, 20); + + CHECK(ac.entryMap.count("two")); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::None); + + // member + check(R"( +local function target(a: number, b: string) return a + #b end + +local a = { one = 4, two = "hello" } +return target(a. + )"); + + ac = autocomplete(4, 16); + + CHECK(ac.entryMap.count("one")); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::None); + + check(R"( +local function target(a: number, b: string) return a + #b end + +local a = { one = 4, two = "hello" } +return target(a.one, a. + )"); + + ac = autocomplete(4, 23); + + CHECK(ac.entryMap.count("two")); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::None); + + // union match + check(R"( +local function target(a: string?) return #b end + +local a = { one = 4, two = "hello" } +return target(a. + )"); + + ac = autocomplete(4, 16); + + CHECK(ac.entryMap.count("two")); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_table") +{ + check(R"( +type Foo = { a: number, b: string } +local a = { one = 4, two = "hello" } +local b: Foo = { a = a. + )"); + + auto ac = autocomplete(3, 23); + + CHECK(ac.entryMap.count("one")); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::None); + + check(R"( +type Foo = { a: number, b: string } +local a = { one = 4, two = "hello" } +local b: Foo = { b = a. + )"); + + ac = autocomplete(3, 23); + + CHECK(ac.entryMap.count("two")); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_return_types") +{ + check(R"( +local function target(a: number, b: string) return a + #b end +local function bar1(a: number) return -a end +local function bar2(a: string) reutrn a .. 'x' end + +return target(b + )"); + + auto ac = autocomplete(5, 15); + + CHECK(ac.entryMap.count("bar1")); + CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::None); + + check(R"( +local function target(a: number, b: string) return a + #b end +local function bar1(a: number) return -a end +local function bar2(a: string) return a .. 'x' end + +return target(bar1, b + )"); + + ac = autocomplete(5, 21); + + CHECK(ac.entryMap.count("bar2")); + CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::None); + + check(R"( +local function target(a: number, b: string) return a + #b end +local function bar1(a: number): (...number) return -a, a end +local function bar2(a: string) reutrn a .. 'x' end + +return target(b + )"); + + ac = autocomplete(5, 15); + + CHECK(ac.entryMap.count("bar1")); + CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_local_type_suggestion") +{ + check(R"( +local b: s = "str" + )"); + + auto ac = autocomplete(1, 10); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function f() return "str" end +local b: s = f() + )"); + + ac = autocomplete(2, 10); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local b: s, c: n = "str", 2 + )"); + + ac = autocomplete(1, 10); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(1, 16); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function f() return 1, "str", 3 end +local a: b, b: n, c: s, d: n = false, f() + )"); + + ac = autocomplete(2, 10); + + CHECK(ac.entryMap.count("boolean")); + CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(2, 16); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(2, 22); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(2, 28); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function f(): ...number return 1, 2, 3 end +local a: boolean, b: n = false, f() + )"); + + ac = autocomplete(2, 22); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_type_suggestion") +{ + check(R"( +local b: (n) -> number = function(a: number, b: string) return a + #b end + )"); + + auto ac = autocomplete(1, 11); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local b: (number, s = function(a: number, b: string) return a + #b end + )"); + + ac = autocomplete(1, 19); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local b: (number, string) -> b = function(a: number, b: string): boolean return a + #b == 0 end + )"); + + ac = autocomplete(1, 30); + + CHECK(ac.entryMap.count("boolean")); + CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local b: (number, ...s) = function(a: number, ...: string) return a end + )"); + + ac = autocomplete(1, 22); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" end + )"); + + ac = autocomplete(1, 25); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_full_type_suggestion") +{ + check(R"( +local b: = "str" + )"); + + auto ac = autocomplete(1, 8); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(1, 9); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local b: = function(a: number) return -a end + )"); + + ac = autocomplete(1, 9); + + CHECK(ac.entryMap.count("(number) -> number")); + CHECK(ac.entryMap["(number) -> number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") +{ + check(R"( +local function target(a: number, b: string) return a + #b end + +local function d(a: n, b) + return target(a, b) +end + )"); + + auto ac = autocomplete(3, 21); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(a: number, b: string) return a + #b end + +local function d(a, b: s) + return target(a, b) +end + )"); + + ac = autocomplete(3, 24); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(a: number, b: string) return a + #b end + +local function d(a: , b) + return target(a, b) +end + )"); + + ac = autocomplete(3, 19); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(3, 20); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(a: number, b: string) return a + #b end + +local function d(a, b: ): number + return target(a, b) +end + )"); + + ac = autocomplete(3, 23); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(3, 24); + + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion") +{ + check(R"( +local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end + +local x = target(function(a: + )"); + + auto ac = autocomplete(3, 29); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end + +local x = target(function(a: n + )"); + + ac = autocomplete(3, 30); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end + +local x = target(function(a: n, b: ) + return a + #b +end) + )"); + + ac = autocomplete(3, 30); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(3, 35); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(callback: (...number) -> number) return callback(1, 2, 3) end + +local x = target(function(a: n) + return a +end + )"); + + ac = autocomplete(3, 30); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestion") +{ + check(R"( +local function target(callback: (...number) -> number) return callback(1, 2, 3) end + +local x = target(function(...:n) + return a +end + )"); + + auto ac = autocomplete(3, 31); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(callback: (...number) -> number) return callback(1, 2, 3) end + +local x = target(function(a:number, b:number, ...:) + return a + b +end + )"); + + ac = autocomplete(3, 50); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") +{ + check(R"( +local function target(callback: () -> number) return callback() end + +local x = target(function(): n + return 1 +end + )"); + + auto ac = autocomplete(3, 30); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(callback: () -> (number, number)) return callback() end + +local x = target(function(): (number, n + return 1, 2 +end + )"); + + ac = autocomplete(3, 39); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion") +{ + check(R"( +local function target(callback: () -> ...number) return callback() end + +local x = target(function(): ...n + return 1, 2, 3 +end + )"); + + auto ac = autocomplete(3, 33); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local function target(callback: () -> ...number) return callback() end + +local x = target(function(): (number, number, ...n + return 1, 2, 3 +end + )"); + + ac = autocomplete(3, 50); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion_optional") +{ + check(R"( +local function target(callback: nil | (a: number, b: string) -> number) return callback(4, "hello") end + +local x = target(function(a: + )"); + + auto ac = autocomplete(3, 29); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion_self") +{ + check(R"( +local t = {} +t.x = 5 +function t:target(callback: (a: number, b: string) -> number) return callback(self.x, "hello") end + +local x = t:target(function(a: , b: ) end) +local y = t.target(t, function(a: number, b: ) end) + )"); + + auto ac = autocomplete(5, 31); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(5, 35); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(6, 45); + + CHECK(ac.entryMap.count("string")); + CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "do_not_suggest_internal_module_type") +{ + fileResolver.source["Module/A"] = R"( +type done = { x: number, y: number } +local function a(a: (done) -> number) return a({x=1, y=2}) end +local function b(a: ((done) -> number) -> number) return a(function(done) return 1 end) end +return {a = a, b = b} + )"; + + LUAU_REQUIRE_NO_ERRORS(frontend.check("Module/A")); + + fileResolver.source["Module/B"] = R"( +local ex = require(script.Parent.A) +ex.a(function(x: + )"; + + frontend.check("Module/B"); + + auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 16}, nullCallback); + + CHECK(!ac.entryMap.count("done")); + + fileResolver.source["Module/C"] = R"( +local ex = require(script.Parent.A) +ex.b(function(x: + )"; + + frontend.check("Module/C"); + + ac = Luau::autocomplete(frontend, "Module/C", Position{2, 16}, nullCallback); + + CHECK(!ac.entryMap.count("(done) -> number")); +} + +TEST_CASE_FIXTURE(ACFixture, "suggest_external_module_type") +{ + fileResolver.source["Module/A"] = R"( +export type done = { x: number, y: number } +local function a(a: (done) -> number) return a({x=1, y=2}) end +local function b(a: ((done) -> number) -> number) return a(function(done) return 1 end) end +return {a = a, b = b} + )"; + + LUAU_REQUIRE_NO_ERRORS(frontend.check("Module/A")); + + fileResolver.source["Module/B"] = R"( +local ex = require(script.Parent.A) +ex.a(function(x: + )"; + + frontend.check("Module/B"); + + auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 16}, nullCallback); + + CHECK(!ac.entryMap.count("done")); + CHECK(ac.entryMap.count("ex.done")); + CHECK(ac.entryMap["ex.done"].typeCorrect == TypeCorrectKind::Correct); + + fileResolver.source["Module/C"] = R"( +local ex = require(script.Parent.A) +ex.b(function(x: + )"; + + frontend.check("Module/C"); + + ac = Luau::autocomplete(frontend, "Module/C", Position{2, 16}, nullCallback); + + CHECK(!ac.entryMap.count("(done) -> number")); + CHECK(ac.entryMap.count("(ex.done) -> number")); + CHECK(ac.entryMap["(ex.done) -> number"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "do_not_suggest_synthetic_table_name") +{ + check(R"( +local foo = { a = 1, b = 2 } +local bar: = foo + )"); + + auto ac = autocomplete(2, 11); + + CHECK(!ac.entryMap.count("foo")); +} + +// CLI-45692: Remove UnfrozenFixture here +TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_function_no_parenthesis") +{ + check(R"( +local function target(a: (number) -> number) return a(4) end +local function bar1(a: number) return -a end +local function bar2(a: string) reutrn a .. 'x' end + +return target(b + )"); + + auto ac = autocomplete(frontend, "MainModule", Position{5, 15}, nullCallback); + + CHECK(ac.entryMap.count("bar1")); + CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["bar1"].parens == ParenthesesRecommendation::None); + CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") +{ + check(R"( +local function f(a: { x: number, y: number }) return a.x + a.y end +local fp: = f + )"); + + auto ac = autocomplete(2, 10); + + CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); +} + +// CLI-45692: Remove UnfrozenFixture here +TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_keywords") +{ + check(R"( +local function a(x: boolean) end +local function b(x: number?) end +local function c(x: (number) -> string) end +local function d(x: ((number) -> string)?) end +local function e(x: ((number) -> string) & ((boolean) -> number)) end + +local tru = {} +local ni = false + +local ac = a(t) +local bc = b(n) +local cc = c(f) +local dc = d(f) +local ec = e(f) + )"); + + auto ac = autocomplete(frontend, "MainModule", Position{10, 14}, nullCallback); + CHECK(ac.entryMap.count("tru")); + CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); + CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(frontend, "MainModule", Position{11, 14}, nullCallback); + CHECK(ac.entryMap.count("ni")); + CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); + CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(frontend, "MainModule", Position{12, 14}, nullCallback); + CHECK(ac.entryMap.count("false")); + CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); + CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(frontend, "MainModule", Position{13, 14}, nullCallback); + CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); + + ac = autocomplete(frontend, "MainModule", Position{14, 14}, nullCallback); + CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); +} + +TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_for_overloads") +{ + check(R"( +local target: ((number) -> string) & ((string) -> number)) + +local one = 4 +local two = "hello" +return target(o) + )"); + + auto ac = autocomplete(5, 15); + + CHECK(ac.entryMap.count("one")); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); + + check(R"( +local target: ((number) -> string) & ((number) -> number)) + +local one = 4 +local two = "hello" +return target(o) + )"); + + ac = autocomplete(5, 15); + + CHECK(ac.entryMap.count("one")); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::None); + + check(R"( +local target: ((number, number) -> string) & ((string) -> number)) + +local one = 4 +local two = "hello" +return target(1, o) + )"); + + ac = autocomplete(5, 18); + + CHECK(ac.entryMap.count("one")); + CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); + CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::None); +} + +TEST_CASE_FIXTURE(ACFixture, "optional_members") +{ + check(R"( +local a = { x = 2, y = 3 } +type A = typeof(a) +local b: A? = a +return b. + )"); + + auto ac = autocomplete(4, 9); + + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + check(R"( +local a = { x = 2, y = 3 } +type A = typeof(a) +local b: nil | A = a +return b. + )"); + + ac = autocomplete(4, 9); + + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); + CHECK(ac.entryMap.count("y")); + + check(R"( +local b: nil | nil +return b. + )"); + + ac = autocomplete(2, 9); + + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "no_function_name_suggestions") +{ + check(R"( +function na + )"); + + auto ac = autocomplete(1, 11); + + CHECK(ac.entryMap.empty()); + + check(R"( +local function + )"); + + ac = autocomplete(1, 15); + + CHECK(ac.entryMap.empty()); + + check(R"( +local function na + )"); + + ac = autocomplete(1, 17); + + CHECK(ac.entryMap.empty()); +} + +TEST_CASE_FIXTURE(ACFixture, "skip_current_local") +{ + check(R"( +local other = 1 +local name = na + )"); + + auto ac = autocomplete(2, 15); + + CHECK(!ac.entryMap.count("name")); + CHECK(ac.entryMap.count("other")); + + check(R"( +local other = 1 +local name, test = na + )"); + + ac = autocomplete(2, 21); + + CHECK(!ac.entryMap.count("name")); + CHECK(!ac.entryMap.count("test")); + CHECK(ac.entryMap.count("other")); +} + +TEST_CASE_FIXTURE(ACFixture, "keyword_members") +{ + check(R"( +local a = { done = 1, forever = 2 } +local b = a.do +local c = a.for +local d = a. +do +end + )"); + + auto ac = autocomplete(2, 14); + + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("done")); + CHECK(ac.entryMap.count("forever")); + + ac = autocomplete(3, 15); + + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("done")); + CHECK(ac.entryMap.count("forever")); + + ac = autocomplete(4, 12); + + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("done")); + CHECK(ac.entryMap.count("forever")); +} + +TEST_CASE_FIXTURE(ACFixture, "keyword_methods") +{ + check(R"( +local a = {} +function a:done() end +local b = a:do + )"); + + auto ac = autocomplete(3, 14); + + CHECK_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("done")); +} + +TEST_CASE_FIXTURE(ACFixture, "keyword_types") +{ + fileResolver.source["Module/A"] = R"( +export type done = { x: number, y: number } +export type other = { z: number, w: number } +return {} + )"; + + LUAU_REQUIRE_NO_ERRORS(frontend.check("Module/A")); + + fileResolver.source["Module/B"] = R"( +local aaa = require(script.Parent.A) +local a: aaa.do + )"; + + frontend.check("Module/B"); + + auto ac = Luau::autocomplete(frontend, "Module/B", Position{2, 15}, nullCallback); + + CHECK_EQ(2, ac.entryMap.size()); + CHECK(ac.entryMap.count("done")); + CHECK(ac.entryMap.count("other")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") +{ + std::string_view source = R"( + local a = table. -- Line 1 + -- | Column 23 + )"; + + auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; + + CHECK_EQ(16, ac.entryMap.size()); + CHECK(ac.entryMap.count("find")); + CHECK(ac.entryMap.count("pack")); + CHECK(!ac.entryMap.count("math")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_require") +{ + ScopedFastFlag luauResolveModuleNameWithoutACurrentModule("LuauResolveModuleNameWithoutACurrentModule", true); + + std::string_view source = R"( + local a = require(w -- Line 1 + -- | Column 27 + )"; + + // CLI-43699 require shouldn't crash inside autocompleteSource + auto ac = autocompleteSource(frontend, source, Position{1, 27}, nullCallback).result; +} + +TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_comments") +{ + std::string_view source = "--!str"; + + auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result; + CHECK_EQ(0, ac.entryMap.size()); +} + +TEST_CASE_FIXTURE(ACFixture, "autocompleteProp_index_function_metamethod_is_variadic") +{ + std::string_view source = R"( + type Foo = {x: number} + local t = {} + setmetatable(t, { + __index = function(index: string): ...Foo + return {x = 1}, {x = 2} + end + }) + + local a = t. -- Line 9 + -- | Column 20 + )"; + + auto ac = autocompleteSource(frontend, source, Position{9, 20}, nullCallback).result; + REQUIRE_EQ(1, ac.entryMap.size()); + CHECK(ac.entryMap.count("x")); +} + +TEST_CASE_FIXTURE(ACFixture, "if_then_else_full_keywords") +{ + check(R"( +local thenceforth = false +local elsewhere = false +local doover = false +local endurance = true + +if 1 then +else +end + +while false do +end + +repeat +until + )"); + + auto ac = autocomplete(6, 9); + CHECK(ac.entryMap.size() == 1); + CHECK(ac.entryMap.count("then")); + + ac = autocomplete(7, 4); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + + ac = autocomplete(10, 14); + CHECK(ac.entryMap.count("do")); + + ac = autocomplete(13, 6); + CHECK(ac.entryMap.count("do")); + + // FIXME: ideally we want to handle start and end of all statements as well +} + +TEST_CASE_FIXTURE(ACFixture, "if_then_else_elseif_completions") +{ + ScopedFastFlag sff{"ElseElseIfCompletionImprovements", true}; + + check(R"( +local elsewhere = false + +if true then + return 1 +el +end + )"); + + auto ac = autocomplete(5, 2); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + CHECK(ac.entryMap.count("elsewhere") == 0); + + check(R"( +local elsewhere = false + +if true then + return 1 +else + return 2 +el +end + )"); + + ac = autocomplete(7, 2); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + CHECK(ac.entryMap.count("elsewhere")); + + check(R"( +local elsewhere = false + +if true then + print("1") +elif true then + print("2") +el +end + )"); + ac = autocomplete(7, 2); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + CHECK(ac.entryMap.count("elsewhere")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_not_the_var_we_are_defining") +{ + std::string_view source = "abc,de"; + + auto ac = autocompleteSource(frontend, source, Position{0, 6}, nullCallback).result; + CHECK(!ac.entryMap.count("de")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_recursive_function") +{ + { + std::string_view global = R"(function abc() + +end +)"; + + auto ac = autocompleteSource(frontend, global, Position{1, 0}, nullCallback).result; + CHECK(ac.entryMap.count("abc")); + } + + { + std::string_view local = R"(local function abc() + +end +)"; + + auto ac = autocompleteSource(frontend, local, Position{1, 0}, nullCallback).result; + CHECK(ac.entryMap.count("abc")); + } +} + +TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") +{ + check(R"( +type Test = { first: number, second: number } +local t: Test = { f } + )"); + + auto ac = autocomplete(2, 19); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + + // Intersection + check(R"( +type Test = { first: number } & { second: number } +local t: Test = { f } + )"); + + ac = autocomplete(2, 19); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + + // Union + check(R"( +type Test = { first: number, second: number } | { second: number, third: number } +local t: Test = { s } + )"); + + ac = autocomplete(2, 19); + CHECK(ac.entryMap.count("second")); + CHECK(!ac.entryMap.count("first")); + CHECK(!ac.entryMap.count("third")); + + // No parenthesis suggestion + check(R"( +type Test = { first: (number) -> number, second: number } +local t: Test = { f } + )"); + + ac = autocomplete(2, 19); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap["first"].parens == ParenthesesRecommendation::None); + + // When key is changed + check(R"( +type Test = { first: number, second: number } +local t: Test = { f = 2 } + )"); + + ac = autocomplete(2, 19); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + + // Alternative key syntax + check(R"( +type Test = { first: number, second: number } +local t: Test = { ["f"] } + )"); + + ac = autocomplete(2, 21); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + + // Not an alternative key syntax + check(R"( +type Test = { first: number, second: number } +local t: Test = { "f" } + )"); + + ac = autocomplete(2, 20); + CHECK(!ac.entryMap.count("first")); + CHECK(!ac.entryMap.count("second")); + + // Skip keys that are already defined + check(R"( +type Test = { first: number, second: number } +local t: Test = { first = 2, s } + )"); + + ac = autocomplete(2, 30); + CHECK(!ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + + // Don't skip active key + check(R"( +type Test = { first: number, second: number } +local t: Test = { first } + )"); + + ac = autocomplete(2, 23); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + + // Inference after first key + check(R"( +local t = { + { first = 5, second = 10 }, + { f } +} + )"); + + ac = autocomplete(3, 7); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); + + check(R"( +local t = { + [2] = { first = 5, second = 10 }, + [5] = { f } +} + )"); + + ac = autocomplete(3, 13); + CHECK(ac.entryMap.count("first")); + CHECK(ac.entryMap.count("second")); +} + +TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +{ + loadDefinition(R"( + declare y: { + x: number, + } + )"); + + fileResolver.source["Module/A"] = R"( + local a = y. + )"; + + frontend.check("Module/A"); + + auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + + REQUIRE(ac.entryMap.count("x")); + CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + { + check(R"( +local temp = false +local even = true; +local a = true +a = if t1@emp then t +a = if temp t2@ +a = if temp then e3@ +a = if temp then even e4@ +a = if temp then even elseif t5@ +a = if temp then even elseif true t6@ +a = if temp then even elseif true then t7@ +a = if temp then even elseif true then temp e8@ +a = if temp then even elseif true then temp else e9@ + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('2'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('3'); + CHECK(ac.entryMap.count("even")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('4'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + + ac = autocomplete('5'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('6'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('7'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('8'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + + ac = autocomplete('9'); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + } +} + +TEST_SUITE_END(); diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp new file mode 100644 index 0000000..dbe80f2 --- /dev/null +++ b/tests/BuiltinDefinitions.test.cpp @@ -0,0 +1,45 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); + +TEST_CASE_FIXTURE(Fixture, "lib_documentation_symbols") +{ + for (const auto& [name, binding] : typeChecker.globalScope->bindings) + { + std::string nameString(name.c_str()); + std::string expectedRootSymbol = "@luau/global/" + nameString; + std::optional actualRootSymbol = binding.documentationSymbol; + CHECK_MESSAGE( + actualRootSymbol == expectedRootSymbol, "expected symbol ", expectedRootSymbol, " for global ", nameString, ", got ", actualRootSymbol); + + const TableTypeVar::Props* props = nullptr; + if (const TableTypeVar* ttv = get(binding.typeId)) + { + props = &ttv->props; + } + else if (const ClassTypeVar* ctv = get(binding.typeId)) + { + props = &ctv->props; + } + + if (props) + { + for (const auto& [propName, prop] : *props) + { + std::string fullPropName = nameString + "." + propName; + std::string expectedPropSymbol = expectedRootSymbol + "." + propName; + std::optional actualPropSymbol = prop.documentationSymbol; + CHECK_MESSAGE(actualPropSymbol == expectedPropSymbol, "expected symbol ", expectedPropSymbol, " for ", fullPropName, ", got ", + actualPropSymbol); + } + } + } +} diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp new file mode 100644 index 0000000..54a31a6 --- /dev/null +++ b/tests/Compiler.test.cpp @@ -0,0 +1,3662 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Compiler.h" +#include "Luau/BytecodeBuilder.h" +#include "Luau/StringUtils.h" + +#include "ScopedFlags.h" + +#include "doctest.h" + +#include +#include + +LUAU_FASTFLAG(LuauPreloadClosures) +LUAU_FASTFLAG(LuauPreloadClosuresFenv) +LUAU_FASTFLAG(LuauPreloadClosuresUpval) + +using namespace Luau; + +static std::string compileFunction(const char* source, uint32_t id) +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, source); + + return bcb.dumpFunction(id); +} + +static std::string compileFunction0(const char* source) +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, source); + + return bcb.dumpFunction(0); +} + +static std::string compileFunction0Coverage(const char* source, int level) +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + + Luau::CompileOptions opts; + opts.coverageLevel = level; + Luau::compileOrThrow(bcb, source, opts); + + return bcb.dumpFunction(0); +} + +TEST_SUITE_BEGIN("Compiler"); + +TEST_CASE("CompileToBytecode") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, "return 5, 6.5"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADN R0 5 +LOADK R1 K0 +RETURN R0 2 +)"); +} + +TEST_CASE("LocalsDirectReference") +{ + CHECK_EQ("\n" + compileFunction0("local a return a"), R"( +LOADNIL R0 +RETURN R0 1 +)"); +} + +TEST_CASE("BasicFunction") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, "local function foo(a, b) return b end"); + + if (FFlag::LuauPreloadClosures) + { + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( +DUPCLOSURE R0 K0 +RETURN R0 0 +)"); + } + else + { + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( +NEWCLOSURE R0 P0 +RETURN R0 0 +)"); + } + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +RETURN R1 1 +)"); +} + +TEST_CASE("BasicFunctionCall") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, "local function foo(a, b) return b end function test() return foo(2) end"); + + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( +GETUPVAL R0 0 +LOADN R1 2 +CALL R0 1 -1 +RETURN R0 -1 +)"); +} + +TEST_CASE("FunctionCallOptimization") +{ + // direct call into local + CHECK_EQ("\n" + compileFunction0("local foo = math.foo()"), R"( +GETIMPORT R0 2 +CALL R0 0 1 +RETURN R0 0 +)"); + + // direct call into temp + CHECK_EQ("\n" + compileFunction0("local foo = math.foo(math.bar())"), R"( +GETIMPORT R0 2 +GETIMPORT R1 4 +CALL R1 0 -1 +CALL R0 -1 1 +RETURN R0 0 +)"); + + // can't directly call into local since foo might be used as arguments of caller + CHECK_EQ("\n" + compileFunction0("local foo foo = math.foo(foo)"), R"( +LOADNIL R0 +GETIMPORT R1 2 +MOVE R2 R0 +CALL R1 1 1 +MOVE R0 R1 +RETURN R0 0 +)"); +} + +TEST_CASE("ReflectionBytecode") +{ + CHECK_EQ("\n" + compileFunction0(R"( +local part = Instance.new('Part', workspace) +part.Size = Vector3.new(1, 2, 3) +return part.Size.Z * part:GetMass() +)"), + R"( +GETIMPORT R0 2 +LOADK R1 K3 +GETIMPORT R2 5 +CALL R0 2 1 +GETIMPORT R1 7 +LOADN R2 1 +LOADN R3 2 +LOADN R4 3 +CALL R1 3 1 +SETTABLEKS R1 R0 K8 +GETTABLEKS R3 R0 K8 +GETTABLEKS R2 R3 K9 +NAMECALL R3 R0 K10 +CALL R3 1 1 +MUL R1 R2 R3 +RETURN R1 1 +)"); +} + +TEST_CASE("ImportCall") +{ + CHECK_EQ("\n" + compileFunction0("return math.max(1, 2)"), R"( +LOADN R1 1 +FASTCALL2K 18 R1 K0 +4 +LOADK R2 K0 +GETIMPORT R0 3 +CALL R0 2 -1 +RETURN R0 -1 +)"); +} + +TEST_CASE("FakeImportCall") +{ + const char* source = "math = {} function math.max() return 0 end function test() return math.max(1, 2) end"; + + CHECK_EQ("\n" + compileFunction(source, 1), R"( +GETGLOBAL R1 K0 +GETTABLEKS R0 R1 K1 +LOADN R1 1 +LOADN R2 2 +CALL R0 2 -1 +RETURN R0 -1 +)"); +} + +TEST_CASE("AssignmentLocal") +{ + CHECK_EQ("\n" + compileFunction0("local a a = 2"), R"( +LOADNIL R0 +LOADN R0 2 +RETURN R0 0 +)"); +} + +TEST_CASE("AssignmentGlobal") +{ + CHECK_EQ("\n" + compileFunction0("a = 2"), R"( +LOADN R0 2 +SETGLOBAL R0 K0 +RETURN R0 0 +)"); +} + +TEST_CASE("AssignmentTable") +{ + const char* source = "local c = ... local a = {} a.b = 2 a.b = c"; + + CHECK_EQ("\n" + compileFunction0(source), R"( +GETVARARGS R0 1 +NEWTABLE R1 1 0 +LOADN R2 2 +SETTABLEKS R2 R1 K0 +SETTABLEKS R0 R1 K0 +RETURN R0 0 +)"); +} + +TEST_CASE("ConcatChainOptimization") +{ + CHECK_EQ("\n" + compileFunction0("return '1' .. '2'"), R"( +LOADK R1 K0 +LOADK R2 K1 +CONCAT R0 R1 R2 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return '1' .. '2' .. '3'"), R"( +LOADK R1 K0 +LOADK R2 K1 +LOADK R3 K2 +CONCAT R0 R1 R3 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return ('1' .. '2') .. '3'"), R"( +LOADK R3 K0 +LOADK R4 K1 +CONCAT R1 R3 R4 +LOADK R2 K2 +CONCAT R0 R1 R2 +RETURN R0 1 +)"); +} + +TEST_CASE("RepeatLocals") +{ + CHECK_EQ("\n" + compileFunction0("repeat local a a = 5 until a - 4 < 0 or a - 4 >= 0"), R"( +LOADNIL R0 +LOADN R0 5 +SUBK R1 R0 K0 +LOADN R2 0 +JUMPIFLT R1 R2 +6 +SUBK R1 R0 K0 +LOADN R2 0 +JUMPIFLE R2 R1 +2 +JUMPBACK -11 +RETURN R0 0 +)"); +} + +TEST_CASE("ForBytecode") +{ + // basic for loop: variable directly refers to internal iteration index (R2) + CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( +LOADN R2 1 +LOADN R0 5 +LOADN R1 1 +FORNPREP R0 +5 +GETIMPORT R3 1 +MOVE R4 R2 +CALL R3 1 0 +FORNLOOP R0 -5 +RETURN R0 0 +)"); + + // when you assign the variable internally, we freak out and copy the variable so that you aren't changing the loop behavior + CHECK_EQ("\n" + compileFunction0("for i=1,5 do i = 7 print(i) end"), R"( +LOADN R2 1 +LOADN R0 5 +LOADN R1 1 +FORNPREP R0 +7 +MOVE R3 R2 +LOADN R3 7 +GETIMPORT R4 1 +MOVE R5 R3 +CALL R4 1 0 +FORNLOOP R0 -7 +RETURN R0 0 +)"); + + // basic for-in loop, generic version + CHECK_EQ("\n" + compileFunction0("for word in string.gmatch(\"Hello Lua user\", \"%a+\") do print(word) end"), R"( +GETIMPORT R0 2 +LOADK R1 K3 +LOADK R2 K4 +CALL R0 2 3 +JUMP +4 +GETIMPORT R5 6 +MOVE R6 R3 +CALL R5 1 0 +FORGLOOP R0 -5 1 +RETURN R0 0 +)"); + + // basic for-in loop, using inext specialization + CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do print(k,v) end"), R"( +GETIMPORT R0 1 +NEWTABLE R1 0 0 +CALL R0 1 3 +FORGPREP_INEXT R0 +5 +GETIMPORT R5 3 +MOVE R6 R3 +MOVE R7 R4 +CALL R5 2 0 +FORGLOOP_INEXT R0 -6 +RETURN R0 0 +)"); + + // basic for-in loop, using next specialization + CHECK_EQ("\n" + compileFunction0("for k,v in pairs({}) do print(k,v) end"), R"( +GETIMPORT R0 1 +NEWTABLE R1 0 0 +CALL R0 1 3 +FORGPREP_NEXT R0 +5 +GETIMPORT R5 3 +MOVE R6 R3 +MOVE R7 R4 +CALL R5 2 0 +FORGLOOP_NEXT R0 -6 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("for k,v in next,{} do print(k,v) end"), R"( +GETIMPORT R0 1 +NEWTABLE R1 0 0 +LOADNIL R2 +FORGPREP_NEXT R0 +5 +GETIMPORT R5 3 +MOVE R6 R3 +MOVE R7 R4 +CALL R5 2 0 +FORGLOOP_NEXT R0 -6 +RETURN R0 0 +)"); +} + +TEST_CASE("ForBytecodeBuiltin") +{ + // we generally recognize builtins like pairs/ipairs and emit special opcodes + CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( +GETIMPORT R0 1 +NEWTABLE R1 0 0 +CALL R0 1 3 +FORGPREP_INEXT R0 +0 +FORGLOOP_INEXT R0 -1 +RETURN R0 0 +)"); + + // ... even if they are using a local variable + CHECK_EQ("\n" + compileFunction0("local ip = ipairs for k,v in ip({}) do end"), R"( +GETIMPORT R0 1 +MOVE R1 R0 +NEWTABLE R2 0 0 +CALL R1 1 3 +FORGPREP_INEXT R1 +0 +FORGLOOP_INEXT R1 -1 +RETURN R0 0 +)"); + + // ... even when it's an upvalue + CHECK_EQ("\n" + compileFunction0("local ip = ipairs function foo() for k,v in ip({}) do end end"), R"( +GETUPVAL R0 0 +NEWTABLE R1 0 0 +CALL R0 1 3 +FORGPREP_INEXT R0 +0 +FORGLOOP_INEXT R0 -1 +RETURN R0 0 +)"); + + // but if it's reassigned then all bets are off + CHECK_EQ("\n" + compileFunction0("local ip = ipairs ip = pairs for k,v in ip({}) do end"), R"( +GETIMPORT R0 1 +GETIMPORT R0 3 +MOVE R1 R0 +NEWTABLE R2 0 0 +CALL R1 1 3 +JUMP +0 +FORGLOOP R1 -1 2 +RETURN R0 0 +)"); + + // or if the global is hijacked + CHECK_EQ("\n" + compileFunction0("ipairs = pairs for k,v in ipairs({}) do end"), R"( +GETIMPORT R0 1 +SETGLOBAL R0 K2 +GETGLOBAL R0 K2 +NEWTABLE R1 0 0 +CALL R0 1 3 +JUMP +0 +FORGLOOP R0 -1 2 +RETURN R0 0 +)"); + + // or if we don't even know the global to begin with + CHECK_EQ("\n" + compileFunction0("for k,v in unknown({}) do end"), R"( +GETIMPORT R0 1 +NEWTABLE R1 0 0 +CALL R0 1 3 +JUMP +0 +FORGLOOP R0 -1 2 +RETURN R0 0 +)"); +} + +TEST_CASE("TableLiterals") +{ + // empty table, note it's computed directly to target + CHECK_EQ("\n" + compileFunction0("return {}"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 +)"); + + // we can't compute directly to target since that'd overwrite the local + CHECK_EQ("\n" + compileFunction0("local a a = {a} return a"), R"( +LOADNIL R0 +NEWTABLE R1 0 1 +MOVE R2 R0 +SETLIST R1 R2 1 [1] +MOVE R0 R1 +RETURN R0 1 +)"); + + // short list + CHECK_EQ("\n" + compileFunction0("return {1,2,3}"), R"( +NEWTABLE R0 0 3 +LOADN R1 1 +LOADN R2 2 +LOADN R3 3 +SETLIST R0 R1 3 [1] +RETURN R0 1 +)"); + + // long list, split into two chunks + CHECK_EQ("\n" + compileFunction0("return {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17}"), R"( +NEWTABLE R0 0 17 +LOADN R1 1 +LOADN R2 2 +LOADN R3 3 +LOADN R4 4 +LOADN R5 5 +LOADN R6 6 +LOADN R7 7 +LOADN R8 8 +LOADN R9 9 +LOADN R10 10 +LOADN R11 11 +LOADN R12 12 +LOADN R13 13 +LOADN R14 14 +LOADN R15 15 +LOADN R16 16 +SETLIST R0 R1 16 [1] +LOADN R1 17 +SETLIST R0 R1 1 [17] +RETURN R0 1 +)"); + + // varargs; -1 indicates multret treatment; note that we don't allocate space for the ... + CHECK_EQ("\n" + compileFunction0("return {...}"), R"( +NEWTABLE R0 0 0 +GETVARARGS R1 -1 +SETLIST R0 R1 -1 [1] +RETURN R0 1 +)"); + + // varargs with other elements; -1 indicates multret treatment; note that we don't allocate space for the ... + CHECK_EQ("\n" + compileFunction0("return {1,2,3,...}"), R"( +NEWTABLE R0 0 3 +LOADN R1 1 +LOADN R2 2 +LOADN R3 3 +GETVARARGS R4 -1 +SETLIST R0 R1 -1 [1] +RETURN R0 1 +)"); + + // basic literals; note that we use DUPTABLE instead of NEWTABLE + CHECK_EQ("\n" + compileFunction0("return {a=1,b=2,c=3}"), R"( +DUPTABLE R0 3 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +LOADN R1 2 +SETTABLEKS R1 R0 K1 +LOADN R1 3 +SETTABLEKS R1 R0 K2 +RETURN R0 1 +)"); + + // literals+array + CHECK_EQ("\n" + compileFunction0("return {a=1,b=2,3,4}"), R"( +NEWTABLE R0 2 2 +LOADN R3 1 +SETTABLEKS R3 R0 K0 +LOADN R3 2 +SETTABLEKS R3 R0 K1 +LOADN R1 3 +LOADN R2 4 +SETLIST R0 R1 2 [1] +RETURN R0 1 +)"); + + // expression assignment + CHECK_EQ("\n" + compileFunction0("a = 7 return {[a]=42}"), R"( +LOADN R0 7 +SETGLOBAL R0 K0 +NEWTABLE R0 1 0 +GETGLOBAL R1 K0 +LOADN R2 42 +SETTABLE R2 R0 R1 +RETURN R0 1 +)"); + + // table template caching; two DUPTABLES out of three use the same slot. Note that caching is order dependent + CHECK_EQ("\n" + compileFunction0("return {a=1,b=2},{b=3,a=4},{a=5,b=6}"), R"( +DUPTABLE R0 2 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +LOADN R1 2 +SETTABLEKS R1 R0 K1 +DUPTABLE R1 3 +LOADN R2 3 +SETTABLEKS R2 R1 K1 +LOADN R2 4 +SETTABLEKS R2 R1 K0 +DUPTABLE R2 2 +LOADN R3 5 +SETTABLEKS R3 R2 K0 +LOADN R3 6 +SETTABLEKS R3 R2 K1 +RETURN R0 3 +)"); +} + +TEST_CASE("TableLiteralsNumberIndex") +{ + // tables with [x] compile to SETTABLEN if the index is short + CHECK_EQ("\n" + compileFunction0("return {[2] = 2, [256] = 256, [0] = 0, [257] = 257}"), R"( +NEWTABLE R0 4 0 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 256 +SETTABLEN R1 R0 256 +LOADN R1 0 +LOADN R2 0 +SETTABLE R2 R0 R1 +LOADN R1 257 +LOADN R2 257 +SETTABLE R2 R0 R1 +RETURN R0 1 +)"); + + // tables with [x] where x is sequential compile to correctly sized array + SETTABLEN + CHECK_EQ("\n" + compileFunction0("return {[1] = 1, [2] = 2}"), R"( +NEWTABLE R0 0 2 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); + + // when index chain starts with 0, or isn't sequential, we disable the optimization + CHECK_EQ("\n" + compileFunction0("return {[0] = 0, [1] = 1, [2] = 2, [42] = 42}"), R"( +NEWTABLE R0 4 0 +LOADN R1 0 +LOADN R2 0 +SETTABLE R2 R0 R1 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 2 +SETTABLEN R1 R0 2 +LOADN R1 42 +SETTABLEN R1 R0 42 +RETURN R0 1 +)"); + + // we disable this optimization when the table has list elements for simplicity + CHECK_EQ("\n" + compileFunction0("return {[1] = 1, [2] = 2, 3}"), R"( +NEWTABLE R0 2 1 +LOADN R2 1 +SETTABLEN R2 R0 1 +LOADN R2 2 +SETTABLEN R2 R0 2 +LOADN R1 3 +SETLIST R0 R1 1 [1] +RETURN R0 1 +)"); + + // we can also correctly predict the array length for mixed tables + CHECK_EQ("\n" + compileFunction0("return {key = 1, value = 2, [1] = 42}"), R"( +NEWTABLE R0 2 1 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +LOADN R1 2 +SETTABLEKS R1 R0 K1 +LOADN R1 42 +SETTABLEN R1 R0 1 +RETURN R0 1 +)"); +} + +TEST_CASE("EmptyTableHashSizePredictionOptimization") +{ + const char* hashSizeSource = R"( +local t = {} +t.a = 1 +t.b = 1 +t.c = 1 +t.d = 1 +t.e = 1 +t.f = 1 +t.g = 1 +t.h = 1 +t.i = 1 +)"; + + const char* hashSizeSource2 = R"( +local t = {} +t.x = 1 +t.x = 2 +t.x = 3 +t.x = 4 +t.x = 5 +t.x = 6 +t.x = 7 +t.x = 8 +t.x = 9 +)"; + + const char* arraySizeSource = R"( +local t = {} +t[1] = 1 +t[2] = 1 +t[3] = 1 +t[4] = 1 +t[5] = 1 +t[6] = 1 +t[7] = 1 +t[8] = 1 +t[9] = 1 +t[10] = 1 +)"; + + CHECK_EQ("\n" + compileFunction0(hashSizeSource), R"( +NEWTABLE R0 16 0 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +LOADN R1 1 +SETTABLEKS R1 R0 K1 +LOADN R1 1 +SETTABLEKS R1 R0 K2 +LOADN R1 1 +SETTABLEKS R1 R0 K3 +LOADN R1 1 +SETTABLEKS R1 R0 K4 +LOADN R1 1 +SETTABLEKS R1 R0 K5 +LOADN R1 1 +SETTABLEKS R1 R0 K6 +LOADN R1 1 +SETTABLEKS R1 R0 K7 +LOADN R1 1 +SETTABLEKS R1 R0 K8 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(hashSizeSource2), R"( +NEWTABLE R0 1 0 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +LOADN R1 2 +SETTABLEKS R1 R0 K0 +LOADN R1 3 +SETTABLEKS R1 R0 K0 +LOADN R1 4 +SETTABLEKS R1 R0 K0 +LOADN R1 5 +SETTABLEKS R1 R0 K0 +LOADN R1 6 +SETTABLEKS R1 R0 K0 +LOADN R1 7 +SETTABLEKS R1 R0 K0 +LOADN R1 8 +SETTABLEKS R1 R0 K0 +LOADN R1 9 +SETTABLEKS R1 R0 K0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(arraySizeSource), R"( +NEWTABLE R0 0 10 +LOADN R1 1 +SETTABLEN R1 R0 1 +LOADN R1 1 +SETTABLEN R1 R0 2 +LOADN R1 1 +SETTABLEN R1 R0 3 +LOADN R1 1 +SETTABLEN R1 R0 4 +LOADN R1 1 +SETTABLEN R1 R0 5 +LOADN R1 1 +SETTABLEN R1 R0 6 +LOADN R1 1 +SETTABLEN R1 R0 7 +LOADN R1 1 +SETTABLEN R1 R0 8 +LOADN R1 1 +SETTABLEN R1 R0 9 +LOADN R1 1 +SETTABLEN R1 R0 10 +RETURN R0 0 +)"); +} + +TEST_CASE("TableSizePredictionSetMetatable") +{ + CHECK_EQ("\n" + compileFunction0(R"( +local t = setmetatable({}, nil) +t.field1 = 1 +t.field2 = 2 +return t +)"), + R"( +GETIMPORT R0 1 +NEWTABLE R1 2 0 +LOADNIL R2 +CALL R0 2 1 +LOADN R1 1 +SETTABLEKS R1 R0 K2 +LOADN R1 2 +SETTABLEKS R1 R0 K3 +RETURN R0 1 +)"); +} + +TEST_CASE("ReflectionEnums") +{ + CHECK_EQ("\n" + compileFunction0("return Enum.EasingStyle.Linear"), R"( +GETIMPORT R0 3 +RETURN R0 1 +)"); +} + +TEST_CASE("CaptureSelf") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, R"( +local MaterialsListClass = {} + +function MaterialsListClass:_MakeToolTip(guiElement, text) + local function updateTooltipPosition() + self._tweakingTooltipFrame = 5 + end + + updateTooltipPosition() +end + +return MaterialsListClass +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( +NEWCLOSURE R3 P0 +CAPTURE VAL R0 +MOVE R4 R3 +CALL R4 0 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETUPVAL R0 0 +LOADN R1 5 +SETTABLEKS R1 R0 K0 +RETURN R0 0 +)"); +} + +TEST_CASE("ConditionalBasic") +{ + CHECK_EQ("\n" + compileFunction0("local a = ... if a then return 5 end"), R"( +GETVARARGS R0 1 +JUMPIFNOT R0 +2 +LOADN R1 5 +RETURN R1 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... if not a then return 5 end"), R"( +GETVARARGS R0 1 +JUMPIF R0 +2 +LOADN R1 5 +RETURN R1 1 +RETURN R0 0 +)"); +} + +TEST_CASE("ConditionalCompare") +{ + CHECK_EQ("\n" + compileFunction0("local a, b = ... if a < b then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFNOTLT R0 R1 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = ... if a <= b then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFNOTLE R0 R1 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = ... if a > b then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFNOTLT R1 R0 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = ... if a >= b then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFNOTLE R1 R0 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = ... if a == b then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFNOTEQ R0 R1 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = ... if a ~= b then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFEQ R0 R1 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); +} + +TEST_CASE("ConditionalNot") +{ + CHECK_EQ("\n" + compileFunction0("local a, b = ... if not (not (a < b)) then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFNOTLT R0 R1 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = ... if not (not (not (a < b))) then return 5 end"), R"( +GETVARARGS R0 2 +JUMPIFLT R0 R1 +3 +LOADN R2 5 +RETURN R2 1 +RETURN R0 0 +)"); +} + +TEST_CASE("ConditionalAndOr") +{ + CHECK_EQ("\n" + compileFunction0("local a, b, c = ... if a < b and b < c then return 5 end"), R"( +GETVARARGS R0 3 +JUMPIFNOTLT R0 R1 +5 +JUMPIFNOTLT R1 R2 +3 +LOADN R3 5 +RETURN R3 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b, c = ... if a < b or b < c then return 5 end"), R"( +GETVARARGS R0 3 +JUMPIFLT R0 R1 +3 +JUMPIFNOTLT R1 R2 +3 +LOADN R3 5 +RETURN R3 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a,b,c,d = ... if (a or b) and not (c and d) then return 5 end"), R"( +GETVARARGS R0 4 +JUMPIF R0 +1 +JUMPIFNOT R1 +4 +JUMPIFNOT R2 +1 +JUMPIF R3 +2 +LOADN R4 5 +RETURN R4 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a,b,c = ... if a or not b or c then return 5 end"), R"( +GETVARARGS R0 3 +JUMPIF R0 +2 +JUMPIFNOT R1 +1 +JUMPIFNOT R2 +2 +LOADN R3 5 +RETURN R3 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a,b,c = ... if a and not b and c then return 5 end"), R"( +GETVARARGS R0 3 +JUMPIFNOT R0 +4 +JUMPIF R1 +3 +JUMPIFNOT R2 +2 +LOADN R3 5 +RETURN R3 1 +RETURN R0 0 +)"); +} + +TEST_CASE("AndOr") +{ + // codegen for constant, local, global for and + CHECK_EQ("\n" + compileFunction0("local a = 1 a = a and 2 return a"), R"( +LOADN R0 1 +ANDK R0 R0 K0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = 1 local b = ... a = a and b return a"), R"( +LOADN R0 1 +GETVARARGS R1 1 +AND R0 R0 R1 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = 1 b = 2 a = a and b return a"), R"( +LOADN R0 1 +LOADN R1 2 +SETGLOBAL R1 K0 +MOVE R1 R0 +JUMPIFNOT R1 +2 +GETGLOBAL R1 K0 +MOVE R0 R1 +RETURN R0 1 +)"); + + // codegen for constant, local, global for or + CHECK_EQ("\n" + compileFunction0("local a = 1 a = a or 2 return a"), R"( +LOADN R0 1 +ORK R0 R0 K0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = 1 local b = ... a = a or b return a"), R"( +LOADN R0 1 +GETVARARGS R1 1 +OR R0 R0 R1 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = 1 b = 2 a = a or b return a"), R"( +LOADN R0 1 +LOADN R1 2 +SETGLOBAL R1 K0 +MOVE R1 R0 +JUMPIF R1 +2 +GETGLOBAL R1 K0 +MOVE R0 R1 +RETURN R0 1 +)"); + + // codegen without a temp variable for and/or when we know we can assign directly into the target register + // note: `a = a` assignment is to disable constant folding for testing purposes + CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a and b return c"), R"( +LOADN R0 1 +MOVE R0 R0 +LOADN R1 2 +SETGLOBAL R1 K0 +MOVE R1 R0 +JUMPIFNOT R1 +2 +GETGLOBAL R1 K0 +RETURN R1 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = 1 a = a b = 2 local c = a or b return c"), R"( +LOADN R0 1 +MOVE R0 R0 +LOADN R1 2 +SETGLOBAL R1 K0 +MOVE R1 R0 +JUMPIF R1 +2 +GETGLOBAL R1 K0 +RETURN R1 1 +)"); +} + +TEST_CASE("AndOrChainCodegen") +{ + const char* source = R"( + return + (1 - verticalGradientTurbulence < waterLevel + .015 and Enum.Material.Sand) + or (sandbank>0 and sandbank<1 and Enum.Material.Sand)--this for canyonbase sandbanks + or Enum.Material.Sandstone + )"; + + CHECK_EQ("\n" + compileFunction0(source), R"( +LOADN R2 1 +GETIMPORT R3 1 +SUB R1 R2 R3 +GETIMPORT R3 4 +ADDK R2 R3 K2 +JUMPIFNOTLT R1 R2 +4 +GETIMPORT R0 8 +JUMPIF R0 +15 +GETIMPORT R1 10 +LOADN R2 0 +JUMPIFNOTLT R2 R1 +9 +GETIMPORT R1 10 +LOADN R2 1 +JUMPIFNOTLT R1 R2 +4 +GETIMPORT R0 8 +JUMPIF R0 +2 +GETIMPORT R0 12 +RETURN R0 1 +)"); +} + +TEST_CASE("IfElseExpression") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + // codegen for a true constant condition + CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"( +LOADN R0 10 +RETURN R0 1 +)"); + + // codegen for a false constant condition + CHECK_EQ("\n" + compileFunction0("return if false then 10 else 20"), R"( +LOADN R0 20 +RETURN R0 1 +)"); + + // codegen for a false (in this case 'nil') constant condition + CHECK_EQ("\n" + compileFunction0("return if nil then 10 else 20"), R"( +LOADN R0 20 +RETURN R0 1 +)"); + + // codegen constant if-else expression used with a binary operation involving another constant + // The test verifies that everything constant folds down to a single constant + CHECK_EQ("\n" + compileFunction0("return 7 + if true then 10 else 20"), R"( +LOADN R0 17 +RETURN R0 1 +)"); + + // codegen for a non-constant condition + CHECK_EQ("\n" + compileFunction0("return if condition then 10 else 20"), R"( +GETIMPORT R1 1 +JUMPIFNOT R1 +2 +LOADN R0 10 +RETURN R0 1 +LOADN R0 20 +RETURN R0 1 +)"); + + // codegen for a non-constant condition using an assignment + CHECK_EQ("\n" + compileFunction0("result = if condition then 10 else 20"), R"( +GETIMPORT R1 1 +JUMPIFNOT R1 +2 +LOADN R0 10 +JUMP +1 +LOADN R0 20 +SETGLOBAL R0 K2 +RETURN R0 0 +)"); + + // codegen for a non-constant condition using an assignment to a local variable + CHECK_EQ("\n" + compileFunction0("local result = if condition then 10 else 20"), R"( +GETIMPORT R1 1 +JUMPIFNOT R1 +2 +LOADN R0 10 +RETURN R0 0 +LOADN R0 20 +RETURN R0 0 +)"); + + // codegen for an if-else expression with multiple elseif's + CHECK_EQ("\n" + compileFunction0("result = if condition1 then 10 elseif condition2 then 20 elseif condition3 then 30 else 40"), R"( +GETIMPORT R1 1 +JUMPIFNOT R1 +2 +LOADN R0 10 +JUMP +11 +GETIMPORT R1 3 +JUMPIFNOT R1 +2 +LOADN R0 20 +JUMP +6 +GETIMPORT R1 5 +JUMPIFNOT R1 +2 +LOADN R0 30 +JUMP +1 +LOADN R0 40 +SETGLOBAL R0 K6 +RETURN R0 0 +)"); +} + +TEST_CASE("ConstantFoldArith") +{ + CHECK_EQ("\n" + compileFunction0("return 10 + 2"), R"( +LOADN R0 12 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return 10 - 2"), R"( +LOADN R0 8 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return 10 * 2"), R"( +LOADN R0 20 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return 10 / 2"), R"( +LOADN R0 5 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return 10 % 2"), R"( +LOADN R0 0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return 10 ^ 2"), R"( +LOADN R0 100 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return -(2 - 5)"), R"( +LOADN R0 3 +RETURN R0 1 +)"); + + // nested arith expression with groups + CHECK_EQ("\n" + compileFunction0("return (2 + 2) * 2"), R"( +LOADN R0 8 +RETURN R0 1 +)"); +} + +TEST_CASE("ConstantFoldCompare") +{ + // ordered comparisons + CHECK_EQ("\n" + compileFunction0("return 1 < 1, 1 < 2"), R"( +LOADB R0 0 +LOADB R1 1 +RETURN R0 2 +)"); + CHECK_EQ("\n" + compileFunction0("return 1 <= 1, 1 <= 2"), R"( +LOADB R0 1 +LOADB R1 1 +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction0("return 1 > 1, 1 > 2"), R"( +LOADB R0 0 +LOADB R1 0 +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction0("return 1 >= 1, 1 >= 2"), R"( +LOADB R0 1 +LOADB R1 0 +RETURN R0 2 +)"); + + // equality comparisons + CHECK_EQ("\n" + compileFunction0("return nil == 1, nil ~= 1, nil == nil, nil ~= nil"), R"( +LOADB R0 0 +LOADB R1 1 +LOADB R2 1 +LOADB R3 0 +RETURN R0 4 +)"); + + CHECK_EQ("\n" + compileFunction0("return 2 == 1, 2 ~= 1, 1 == 1, 1 ~= 1"), R"( +LOADB R0 0 +LOADB R1 1 +LOADB R2 1 +LOADB R3 0 +RETURN R0 4 +)"); + + CHECK_EQ("\n" + compileFunction0("return true == false, true ~= false, true == true, true ~= true"), R"( +LOADB R0 0 +LOADB R1 1 +LOADB R2 1 +LOADB R3 0 +RETURN R0 4 +)"); + + CHECK_EQ("\n" + compileFunction0("return 'a' == 'b', 'a' ~= 'b', 'a' == 'a', 'a' ~= 'a'"), R"( +LOADB R0 0 +LOADB R1 1 +LOADB R2 1 +LOADB R3 0 +RETURN R0 4 +)"); +} + +TEST_CASE("ConstantFoldLocal") +{ + // local constant propagation, including upvalues, and no propagation for mutated locals + CHECK_EQ("\n" + compileFunction0("local a = 1 return a + a"), R"( +LOADN R0 2 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = 1 a = a + a return a"), R"( +LOADN R0 1 +ADD R0 R0 R0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local a = 1 function foo() return a + a end", 0), R"( +LOADN R0 2 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction("local a = 1 function foo() return a + a end function bar() a = 5 end", 0), R"( +GETUPVAL R1 0 +GETUPVAL R2 0 +ADD R0 R1 R2 +RETURN R0 1 +)"); + + // local values for multiple assignments + CHECK_EQ("\n" + compileFunction0("local a return a"), R"( +LOADNIL R0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = 1, 3 return a + 1, b"), R"( +LOADN R0 2 +LOADN R1 3 +RETURN R0 2 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = 1 return a + 1, b"), R"( +LOADN R0 2 +LOADNIL R1 +RETURN R0 2 +)"); + + // local values for multiple assignments w/multret + CHECK_EQ("\n" + compileFunction0("local a, b = ... return a + 1, b"), R"( +GETVARARGS R0 2 +ADDK R2 R0 K0 +MOVE R3 R1 +RETURN R2 2 +)"); + + CHECK_EQ("\n" + compileFunction0("local a, b = 1, ... return a + 1, b"), R"( +LOADN R0 1 +GETVARARGS R1 1 +LOADN R2 2 +MOVE R3 R1 +RETURN R2 2 +)"); +} + +TEST_CASE("ConstantFoldAndOr") +{ + // and/or constant folding when both sides are constant + CHECK_EQ("\n" + compileFunction0("return true and 2"), R"( +LOADN R0 2 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return false and 2"), R"( +LOADB R0 0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return nil and 2"), R"( +LOADNIL R0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return true or 2"), R"( +LOADB R0 1 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return false or 2"), R"( +LOADN R0 2 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return nil or 2"), R"( +LOADN R0 2 +RETURN R0 1 +)"); + + // and/or constant folding when left hand side is constant + CHECK_EQ("\n" + compileFunction0("return true and a"), R"( +GETIMPORT R0 1 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return false and a"), R"( +LOADB R0 0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return true or a"), R"( +LOADB R0 1 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return false or a"), R"( +GETIMPORT R0 1 +RETURN R0 1 +)"); + + // constant fold parts in chains of and/or statements + CHECK_EQ("\n" + compileFunction0("return a and true and b"), R"( +GETIMPORT R0 1 +JUMPIFNOT R0 +2 +GETIMPORT R0 3 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return a or false or b"), R"( +GETIMPORT R0 1 +JUMPIF R0 +2 +GETIMPORT R0 3 +RETURN R0 1 +)"); +} + +TEST_CASE("ConstantFoldConditionalAndOr") +{ + CHECK_EQ("\n" + compileFunction0("local a = ... if false or a then print(1) end"), R"( +GETVARARGS R0 1 +JUMPIFNOT R0 +4 +GETIMPORT R1 1 +LOADN R2 1 +CALL R1 1 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... if not (false or a) then print(1) end"), R"( +GETVARARGS R0 1 +JUMPIF R0 +4 +GETIMPORT R1 1 +LOADN R2 1 +CALL R1 1 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... if true and a then print(1) end"), R"( +GETVARARGS R0 1 +JUMPIFNOT R0 +4 +GETIMPORT R1 1 +LOADN R2 1 +CALL R1 1 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = ... if not (true and a) then print(1) end"), R"( +GETVARARGS R0 1 +JUMPIF R0 +4 +GETIMPORT R1 1 +LOADN R2 1 +CALL R1 1 0 +RETURN R0 0 +)"); +} + +TEST_CASE("ConstantFoldFlowControl") +{ + // if + CHECK_EQ("\n" + compileFunction0("if true then print(1) end"), R"( +GETIMPORT R0 1 +LOADN R1 1 +CALL R0 1 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("if false then print(1) end"), R"( +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("if true then print(1) else print(2) end"), R"( +GETIMPORT R0 1 +LOADN R1 1 +CALL R0 1 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("if false then print(1) else print(2) end"), R"( +GETIMPORT R0 1 +LOADN R1 2 +CALL R0 1 0 +RETURN R0 0 +)"); + + // while + CHECK_EQ("\n" + compileFunction0("while true do print(1) end"), R"( +GETIMPORT R0 1 +LOADN R1 1 +CALL R0 1 0 +JUMPBACK -5 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("while false do print(1) end"), R"( +RETURN R0 0 +)"); + + // repeat + CHECK_EQ("\n" + compileFunction0("repeat print(1) until true"), R"( +GETIMPORT R0 1 +LOADN R1 1 +CALL R0 1 0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("repeat print(1) until false"), R"( +GETIMPORT R0 1 +LOADN R1 1 +CALL R0 1 0 +JUMPBACK -5 +RETURN R0 0 +)"); + + // there's an odd case in repeat..until compilation where we evaluate the expression that is always false for side-effects of the left hand side + CHECK_EQ("\n" + compileFunction0("repeat print(1) until five and false"), R"( +GETIMPORT R0 1 +LOADN R1 1 +CALL R0 1 0 +GETIMPORT R0 3 +JUMPIFNOT R0 +0 +JUMPBACK -8 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopBreak") +{ + // default codegen: compile breaks as unconditional jumps + CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break else end end"), R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFNOTLT R0 R1 +3 +RETURN R0 0 +JUMP +0 +JUMPBACK -9 +RETURN R0 0 +)"); + + // optimization: if then body is a break statement, flip the branches + CHECK_EQ("\n" + compileFunction0("while true do if math.random() < 0.5 then break end end"), R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 +2 +JUMPBACK -7 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopContinue") +{ + // default codegen: compile continue as unconditional jumps + CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue else end break until false error()"), R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFNOTLT R0 R1 +5 +JUMP +2 +JUMP +2 +JUMP +1 +JUMPBACK -10 +GETIMPORT R0 5 +CALL R0 0 0 +RETURN R0 0 +)"); + + // optimization: if then body is a continue statement, flip the branches + CHECK_EQ("\n" + compileFunction0("repeat if math.random() < 0.5 then continue end break until false error()"), R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R0 R1 +2 +JUMP +1 +JUMPBACK -8 +GETIMPORT R0 5 +CALL R0 0 0 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopContinueUntil") +{ + // it's valid to use locals defined inside the loop in until expression if they're defined before continue + CHECK_EQ("\n" + compileFunction0("repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until r < 0.5"), R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R1 R0 +2 +ADDK R0 R0 K4 +LOADK R1 K3 +JUMPIFLT R0 R1 +2 +JUMPBACK -11 +RETURN R0 0 +)"); + + // it's however invalid to use locals if they are defined after continue + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, R"( +repeat + local r = math.random() + if r > 0.5 then + continue + end + local rr = r + 0.3 +until rr < 0.5 +)"); + + CHECK(!"Expected CompileError"); + } + catch (Luau::CompileError& e) + { + CHECK_EQ(e.getLocation().begin.line + 1, 8); + CHECK_EQ( + std::string(e.what()), "Local rr used in the repeat..until condition is undefined because continue statement on line 5 jumps over it"); + } + + // but it's okay if continue is inside a non-repeat..until loop, or inside a loop that doesn't use the local (here `continue` just terminates + // inner loop) + CHECK_EQ("\n" + compileFunction0( + "repeat local r = math.random() repeat if r > 0.5 then continue end r = r - 0.1 until true r = r + 0.3 until r < 0.5"), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R1 R0 +2 +SUBK R0 R0 K4 +ADDK R0 R0 K5 +LOADK R1 K3 +JUMPIFLT R0 R1 +2 +JUMPBACK -12 +RETURN R0 0 +)"); + + // and it's also okay to use a local defined in the until expression as long as it's inside a function! + CHECK_EQ( + "\n" + compileFunction( + "repeat local r = math.random() if r > 0.5 then continue end r = r + 0.3 until (function() local a = r return a < 0.5 end)()", 1), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFNOTLT R1 R0 +3 +CLOSEUPVALS R0 +JUMP +1 +ADDK R0 R0 K4 +NEWCLOSURE R1 P0 +CAPTURE REF R0 +CALL R1 0 1 +JUMPIF R1 +2 +CLOSEUPVALS R0 +JUMPBACK -15 +CLOSEUPVALS R0 +RETURN R0 0 +)"); + + // but not if the function just refers to an upvalue + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, R"( +repeat + local r = math.random() + if r > 0.5 then + continue + end + local rr = r + 0.3 +until (function() return rr end)() < 0.5 +)"); + + CHECK(!"Expected CompileError"); + } + catch (Luau::CompileError& e) + { + CHECK_EQ(e.getLocation().begin.line + 1, 8); + CHECK_EQ( + std::string(e.what()), "Local rr used in the repeat..until condition is undefined because continue statement on line 5 jumps over it"); + } + + // unless that upvalue is from an outer scope + CHECK_EQ("\n" + compileFunction0("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then " + "continue end r = r + 0.3 until stop or r < 0.5 end"), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFLT R1 R0 +2 +ADDK R0 R0 K4 +GETUPVAL R1 0 +JUMPIF R1 +4 +LOADK R1 K3 +JUMPIFLT R0 R1 +2 +JUMPBACK -13 +RETURN R0 0 +)"); + + // including upvalue references from a function expression + CHECK_EQ("\n" + compileFunction("local stop = false stop = true function test() repeat local r = math.random() if r > 0.5 then continue " + "end r = r + 0.3 until (function() return stop or r < 0.5 end)() end", + 1), + R"( +GETIMPORT R0 2 +CALL R0 0 1 +LOADK R1 K3 +JUMPIFNOTLT R1 R0 +3 +CLOSEUPVALS R0 +JUMP +1 +ADDK R0 R0 K4 +NEWCLOSURE R1 P0 +CAPTURE UPVAL U0 +CAPTURE REF R0 +CALL R1 0 1 +JUMPIF R1 +2 +CLOSEUPVALS R0 +JUMPBACK -16 +CLOSEUPVALS R0 +RETURN R0 0 +)"); +} + +TEST_CASE("LoopContinueUntilOops") +{ + // this used to crash the compiler :( + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, R"( +local _ +repeat +continue +until not _ +)"); + } + catch (Luau::CompileError& e) + { + CHECK_EQ( + std::string(e.what()), "Local _ used in the repeat..until condition is undefined because continue statement on line 4 jumps over it"); + } +} + +TEST_CASE("AndOrOptimizations") +{ + // the OR/ORK optimization triggers for cutoff since lhs is simple + CHECK_EQ("\n" + compileFunction(R"( +local function advancedRidgedFilter(value, cutoff) + local cutoff = cutoff or .5 + value = value - cutoff + return 1 - (value < 0 and -value or value) * 1 / (1 - cutoff) +end +)", + 0), + R"( +ORK R2 R1 K0 +SUB R0 R0 R2 +LOADN R4 1 +LOADN R8 0 +JUMPIFNOTLT R0 R8 +3 +MINUS R7 R0 +JUMPIF R7 +1 +MOVE R7 R0 +MULK R6 R7 K1 +LOADN R8 1 +SUB R7 R8 R2 +DIV R5 R6 R7 +SUB R3 R4 R5 +RETURN R3 1 +)"); + + // sometimes we need to compute a boolean; this uses LOADB with an offset + CHECK_EQ("\n" + compileFunction(R"( +function thinSurface(surfaceGradient, surfaceThickness) + return surfaceGradient > .5 - surfaceThickness*.4 and surfaceGradient < .5 + surfaceThickness*.4 +end +)", + 0), + R"( +LOADB R2 0 +LOADK R4 K0 +MULK R5 R1 K1 +SUB R3 R4 R5 +JUMPIFNOTLT R3 R0 +8 +LOADK R4 K0 +MULK R5 R1 K1 +ADD R3 R4 R5 +JUMPIFLT R0 R3 +2 +LOADB R2 0 +1 +LOADB R2 1 +RETURN R2 1 +)"); + + // sometimes we need to compute a boolean; this uses LOADB with an offset for the last op, note that first op is compiled better + CHECK_EQ("\n" + compileFunction(R"( +function thickSurface(surfaceGradient, surfaceThickness) + return surfaceGradient < .5 - surfaceThickness*.4 or surfaceGradient > .5 + surfaceThickness*.4 +end +)", + 0), + R"( +LOADB R2 1 +LOADK R4 K0 +MULK R5 R1 K1 +SUB R3 R4 R5 +JUMPIFLT R0 R3 +8 +LOADK R4 K0 +MULK R5 R1 K1 +ADD R3 R4 R5 +JUMPIFLT R3 R0 +2 +LOADB R2 0 +1 +LOADB R2 1 +RETURN R2 1 +)"); + + // trivial ternary if with constants + CHECK_EQ("\n" + compileFunction(R"( +function testSurface(surface) + return surface and 1 or 0 +end +)", + 0), + R"( +JUMPIFNOT R0 +2 +LOADN R1 1 +RETURN R1 1 +LOADN R1 0 +RETURN R1 1 +)"); + + // canonical saturate + CHECK_EQ("\n" + compileFunction(R"( +function saturate(x) + return x < 0 and 0 or x > 1 and 1 or x +end +)", + 0), + R"( +LOADN R2 0 +JUMPIFNOTLT R0 R2 +3 +LOADN R1 0 +RETURN R1 1 +LOADN R2 1 +JUMPIFNOTLT R2 R0 +3 +LOADN R1 1 +RETURN R1 1 +MOVE R1 R0 +RETURN R1 1 +)"); +} + +TEST_CASE("JumpFold") +{ + // jump-to-return folding to return + CHECK_EQ("\n" + compileFunction0("return a and 1 or 0"), R"( +GETIMPORT R1 1 +JUMPIFNOT R1 +2 +LOADN R0 1 +RETURN R0 1 +LOADN R0 0 +RETURN R0 1 +)"); + + // conditional jump in the inner if() folding to jump out of the expression (JUMPIFNOT+5 skips over all jumps, JUMP+1 skips over JUMP+0) + CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end d()"), R"( +GETIMPORT R0 1 +JUMPIFNOT R0 +8 +GETIMPORT R0 3 +JUMPIFNOT R0 +5 +GETIMPORT R0 3 +CALL R0 0 0 +JUMP +1 +JUMP +0 +GETIMPORT R0 5 +CALL R0 0 0 +RETURN R0 0 +)"); + + // same as example before but the unconditional jumps are folded with RETURN + CHECK_EQ("\n" + compileFunction0("if a then if b then b() else end else end"), R"( +GETIMPORT R0 1 +JUMPIFNOT R0 +8 +GETIMPORT R0 3 +JUMPIFNOT R0 +5 +GETIMPORT R0 3 +CALL R0 0 0 +RETURN R0 0 +RETURN R0 0 +RETURN R0 0 +)"); + + // in this example, we do *not* have a JUMP after RETURN in the if branch + // this is important since, even though this jump is never reached, jump folding needs to be able to analyze it + CHECK_EQ("\n" + compileFunction(R"( +local function getPerlin(x, y, z, seed, scale, raw) +local seed = seed or 0 +local scale = scale or 1 +if not raw then +return math.noise(x / scale + (seed * 17) + masterSeed, y / scale - masterSeed, z / scale - seed*seed)*.5 + .5 --accounts for bleeding from interpolated line +else +return math.noise(x / scale + (seed * 17) + masterSeed, y / scale - masterSeed, z / scale - seed*seed) +end +end +)", + 0), + R"( +ORK R6 R3 K0 +ORK R7 R4 K1 +JUMPIF R5 +19 +GETIMPORT R10 5 +DIV R13 R0 R7 +MULK R14 R6 K6 +ADD R12 R13 R14 +GETIMPORT R13 8 +ADD R11 R12 R13 +DIV R13 R1 R7 +GETIMPORT R14 8 +SUB R12 R13 R14 +DIV R14 R2 R7 +MUL R15 R6 R6 +SUB R13 R14 R15 +CALL R10 3 1 +MULK R9 R10 K2 +ADDK R8 R9 K2 +RETURN R8 1 +GETIMPORT R8 5 +DIV R11 R0 R7 +MULK R12 R6 K6 +ADD R10 R11 R12 +GETIMPORT R11 8 +ADD R9 R10 R11 +DIV R11 R1 R7 +GETIMPORT R12 8 +SUB R10 R11 R12 +DIV R12 R2 R7 +MUL R13 R6 R6 +SUB R11 R12 R13 +CALL R8 3 -1 +RETURN R8 -1 +)"); +} + +static std::string rep(const std::string& s, size_t n) +{ + std::string r; + r.reserve(s.length() * n); + for (size_t i = 0; i < n; ++i) + r += s; + return r; +} + +TEST_CASE("RecursionParse") +{ + // The test forcibly pushes the stack limit during compilation; in NoOpt, the stack consumption is much larger so we need to reduce the limit to + // not overflow the C stack. When ASAN is enabled, stack consumption increases even more. +#if defined(LUAU_ENABLE_ASAN) + ScopedFastInt flag("LuauRecursionLimit", 200); +#elif defined(_NOOPT) || defined(_DEBUG) + ScopedFastInt flag("LuauRecursionLimit", 300); +#endif + + Luau::BytecodeBuilder bcb; + + try + { + Luau::compileOrThrow(bcb, "a=" + rep("{", 1500) + rep("}", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "function a" + rep(".a", 1500) + "() end"); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your function name to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "a=1" + rep("+1", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, "a=" + rep("(", 1500) + "1" + rep(")", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); + } + + try + { + Luau::compileOrThrow(bcb, rep("do ", 1500) + "print()" + rep(" end", 1500)); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); + } +} + +TEST_CASE("ArrayIndexLiteral") +{ + CHECK_EQ("\n" + compileFunction0("local arr = {} return arr[0], arr[1], arr[256], arr[257]"), R"( +NEWTABLE R0 0 0 +LOADN R2 0 +GETTABLE R1 R0 R2 +GETTABLEN R2 R0 1 +GETTABLEN R3 R0 256 +LOADN R5 257 +GETTABLE R4 R0 R5 +RETURN R1 4 +)"); + + CHECK_EQ("\n" + compileFunction0("local arr = {} local b = ... arr[0] = b arr[1] = b arr[256] = b arr[257] = b"), R"( +NEWTABLE R0 0 1 +GETVARARGS R1 1 +LOADN R2 0 +SETTABLE R1 R0 R2 +SETTABLEN R1 R0 1 +SETTABLEN R1 R0 256 +LOADN R2 257 +SETTABLE R1 R0 R2 +RETURN R0 0 +)"); +} + +TEST_CASE("NestedFunctionCalls") +{ + CHECK_EQ("\n" + compileFunction0("function clamp(t,a,b) return math.min(math.max(t,a),b) end"), R"( +FASTCALL2 18 R0 R1 +5 +MOVE R5 R0 +MOVE R6 R1 +GETIMPORT R4 2 +CALL R4 2 1 +FASTCALL2 19 R4 R2 +4 +MOVE R5 R2 +GETIMPORT R3 4 +CALL R3 2 -1 +RETURN R3 -1 +)"); +} + +TEST_CASE("UpvaluesLoopsBytecode") +{ + CHECK_EQ("\n" + compileFunction(R"( +function test() + for i=1,10 do + i = i + foo(function() return i end) + if bar then + break + end + end + return 0 +end +)", + 1), + R"( +LOADN R2 1 +LOADN R0 10 +LOADN R1 1 +FORNPREP R0 +14 +MOVE R3 R2 +MOVE R3 R3 +GETIMPORT R4 1 +NEWCLOSURE R5 P0 +CAPTURE REF R3 +CALL R4 1 0 +GETIMPORT R4 3 +JUMPIFNOT R4 +2 +CLOSEUPVALS R3 +JUMP +2 +CLOSEUPVALS R3 +FORNLOOP R0 -14 +LOADN R0 0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction(R"( +function test() + for i in ipairs(data) do + i = i + foo(function() return i end) + if bar then + break + end + end + return 0 +end +)", + 1), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R0 1 3 +FORGPREP_INEXT R0 +12 +MOVE R3 R3 +GETIMPORT R5 5 +NEWCLOSURE R6 P0 +CAPTURE REF R3 +CALL R5 1 0 +GETIMPORT R5 7 +JUMPIFNOT R5 +2 +CLOSEUPVALS R3 +JUMP +2 +CLOSEUPVALS R3 +FORGLOOP_INEXT R0 -13 +LOADN R0 0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction(R"( +function test() + local i = 0 + while i < 5 do + local j + j = i + foo(function() return j end) + i = i + 1 + if bar then + break + end + end + return 0 +end +)", + 1), + R"( +LOADN R0 0 +LOADN R1 5 +JUMPIFNOTLT R0 R1 +16 +LOADNIL R1 +MOVE R1 R0 +GETIMPORT R2 1 +NEWCLOSURE R3 P0 +CAPTURE REF R1 +CALL R2 1 0 +ADDK R0 R0 K2 +GETIMPORT R2 4 +JUMPIFNOT R2 +2 +CLOSEUPVALS R1 +JUMP +2 +CLOSEUPVALS R1 +JUMPBACK -18 +LOADN R1 0 +RETURN R1 1 +)"); + + CHECK_EQ("\n" + compileFunction(R"( +function test() + local i = 0 + repeat + local j + j = i + foo(function() return j end) + i = i + 1 + if bar then + break + end + until i < 5 + return 0 +end +)", + 1), + R"( +LOADN R0 0 +LOADNIL R1 +MOVE R1 R0 +GETIMPORT R2 1 +NEWCLOSURE R3 P0 +CAPTURE REF R1 +CALL R2 1 0 +ADDK R0 R0 K2 +GETIMPORT R2 4 +JUMPIFNOT R2 +2 +CLOSEUPVALS R1 +JUMP +6 +LOADN R2 5 +JUMPIFLT R0 R2 +3 +CLOSEUPVALS R1 +JUMPBACK -18 +CLOSEUPVALS R1 +LOADN R1 0 +RETURN R1 1 +)"); +} + +TEST_CASE("TypeAliasing") +{ + Luau::BytecodeBuilder bcb; + Luau::CompileOptions options; + Luau::ParseOptions parseOptions; + CHECK_NOTHROW(Luau::compileOrThrow(bcb, "type A = number local a: A = 1", options, parseOptions)); +} + +TEST_CASE("DebugLineInfo") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local kSelectedBiomes = { + ['Mountains'] = true, + ['Canyons'] = true, + ['Dunes'] = true, + ['Arctic'] = true, + ['Lavaflow'] = true, + ['Hills'] = true, + ['Plains'] = true, + ['Marsh'] = true, + ['Water'] = true, +} +local result = "" +for k in pairs(kSelectedBiomes) do + result = result .. k +end +return result +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: NEWTABLE R0 16 0 +3: LOADB R1 1 +3: SETTABLEKS R1 R0 K0 +4: LOADB R1 1 +4: SETTABLEKS R1 R0 K1 +5: LOADB R1 1 +5: SETTABLEKS R1 R0 K2 +6: LOADB R1 1 +6: SETTABLEKS R1 R0 K3 +7: LOADB R1 1 +7: SETTABLEKS R1 R0 K4 +8: LOADB R1 1 +8: SETTABLEKS R1 R0 K5 +9: LOADB R1 1 +9: SETTABLEKS R1 R0 K6 +10: LOADB R1 1 +10: SETTABLEKS R1 R0 K7 +11: LOADB R1 1 +11: SETTABLEKS R1 R0 K8 +13: LOADK R1 K9 +14: GETIMPORT R2 11 +14: MOVE R3 R0 +14: CALL R2 1 3 +14: FORGPREP_NEXT R2 +3 +15: MOVE R7 R1 +15: MOVE R8 R5 +15: CONCAT R1 R7 R8 +14: FORGLOOP_NEXT R2 -4 +17: RETURN R1 1 +)"); +} + +TEST_CASE("DebugLineInfoFor") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +for +i +in +1 +, +2 +, +3 +do +print(i) +end +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +5: LOADN R0 1 +7: LOADN R1 2 +9: LOADN R2 3 +9: JUMP +4 +11: GETIMPORT R5 1 +11: MOVE R6 R3 +11: CALL R5 1 0 +2: FORGLOOP R0 -5 1 +13: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoWhile") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local count = 0 +while true do + count += 1 + if count > 1 then + print("done!") + break + end +end +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: LOADN R0 0 +4: ADDK R0 R0 K0 +5: LOADN R1 1 +5: JUMPIFNOTLT R1 R0 +6 +6: GETIMPORT R1 2 +6: LOADK R2 K3 +6: CALL R1 1 0 +10: RETURN R0 0 +3: JUMPBACK -10 +10: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoRepeatUntil") +{ + CHECK_EQ("\n" + compileFunction0Coverage(R"( +local f = 0 +repeat + f += 1 + if f == 1 then + print(f) + else + f = 0 + end +until f == 0 +)", + 0), + R"( +2: LOADN R0 0 +4: ADDK R0 R0 K0 +5: JUMPIFNOTEQK R0 K0 +6 +6: GETIMPORT R1 2 +6: MOVE R2 R0 +6: CALL R1 1 0 +6: JUMP +1 +8: LOADN R0 0 +10: JUMPIFEQK R0 K3 +2 +10: JUMPBACK -12 +11: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoSubTable") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Value1, Value2, Value3 = ... +local Table = {} + +Table.SubTable["Key"] = { + Key1 = Value1, + Key2 = Value2, + Key3 = Value3, + Key4 = true, +} +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 3 +3: NEWTABLE R3 0 0 +5: GETTABLEKS R4 R3 K0 +5: DUPTABLE R5 5 +6: SETTABLEKS R0 R5 K1 +7: SETTABLEKS R1 R5 K2 +8: SETTABLEKS R2 R5 K3 +9: LOADB R6 1 +9: SETTABLEKS R6 R5 K4 +5: SETTABLEKS R5 R4 K6 +11: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoCall") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo = ... + +Foo:Bar( + 1, + 2, + 3) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 1 +5: LOADN R3 1 +6: LOADN R4 2 +7: LOADN R5 3 +4: NAMECALL R1 R0 K0 +4: CALL R1 4 0 +8: RETURN R0 0 +)"); +} + +TEST_CASE("DebugSource") +{ + const char* source = R"( +local kSelectedBiomes = { + ['Mountains'] = true, + ['Canyons'] = true, + ['Dunes'] = true, + ['Arctic'] = true, + ['Lavaflow'] = true, + ['Hills'] = true, + ['Plains'] = true, + ['Marsh'] = true, + ['Water'] = true, +} +local result = "" +for k in pairs(kSelectedBiomes) do + result = result .. k +end +return result +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpSource(source); + + Luau::compileOrThrow(bcb, source); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( + 2: local kSelectedBiomes = { +NEWTABLE R0 16 0 + 3: ['Mountains'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K0 + 4: ['Canyons'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K1 + 5: ['Dunes'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K2 + 6: ['Arctic'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K3 + 7: ['Lavaflow'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K4 + 8: ['Hills'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K5 + 9: ['Plains'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K6 + 10: ['Marsh'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K7 + 11: ['Water'] = true, +LOADB R1 1 +SETTABLEKS R1 R0 K8 + 13: local result = "" +LOADK R1 K9 + 14: for k in pairs(kSelectedBiomes) do +GETIMPORT R2 11 +MOVE R3 R0 +CALL R2 1 3 +FORGPREP_NEXT R2 +3 + 15: result = result .. k +MOVE R7 R1 +MOVE R8 R5 +CONCAT R1 R7 R8 + 14: for k in pairs(kSelectedBiomes) do +FORGLOOP_NEXT R2 -4 + 17: return result +RETURN R1 1 +)"); +} + +TEST_CASE("DebugLocals") +{ + const char* source = R"( +function foo(e, f) + local a = 1 + for i=1,3 do + print(i) + end + for k,v in pairs() do + print(k, v) + end + do + local b = 2 + print(b) + end + do + local c = 2 + print(b) + end + local function inner() + return inner, a + end + return a +end +)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines | Luau::BytecodeBuilder::Dump_Locals); + bcb.setDumpSource(source); + + Luau::CompileOptions options; + options.debugLevel = 2; + + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( +local 0: reg 5, start pc 5 line 5, end pc 8 line 5 +local 1: reg 6, start pc 14 line 8, end pc 18 line 8 +local 2: reg 7, start pc 14 line 8, end pc 18 line 8 +local 3: reg 3, start pc 21 line 12, end pc 24 line 12 +local 4: reg 3, start pc 26 line 16, end pc 30 line 16 +local 5: reg 0, start pc 0 line 3, end pc 34 line 21 +local 6: reg 1, start pc 0 line 3, end pc 34 line 21 +local 7: reg 2, start pc 1 line 4, end pc 34 line 21 +local 8: reg 3, start pc 34 line 21, end pc 34 line 21 +3: LOADN R2 1 +4: LOADN R5 1 +4: LOADN R3 3 +4: LOADN R4 1 +4: FORNPREP R3 +5 +5: GETIMPORT R6 1 +5: MOVE R7 R5 +5: CALL R6 1 0 +4: FORNLOOP R3 -5 +7: GETIMPORT R3 3 +7: CALL R3 0 3 +7: FORGPREP_NEXT R3 +5 +8: GETIMPORT R8 1 +8: MOVE R9 R6 +8: MOVE R10 R7 +8: CALL R8 2 0 +7: FORGLOOP_NEXT R3 -6 +11: LOADN R3 2 +12: GETIMPORT R4 1 +12: LOADN R5 2 +12: CALL R4 1 0 +15: LOADN R3 2 +16: GETIMPORT R4 1 +16: GETIMPORT R5 5 +16: CALL R4 1 0 +18: NEWCLOSURE R3 P0 +18: CAPTURE VAL R3 +18: CAPTURE VAL R2 +21: RETURN R2 1 +)"); +} + +TEST_CASE("AssignmentConflict") +{ + // assignments are left to right + CHECK_EQ("\n" + compileFunction0("local a, b a, b = 1, 2"), R"( +LOADNIL R0 +LOADNIL R1 +LOADN R2 1 +LOADN R3 2 +MOVE R0 R2 +MOVE R1 R3 +RETURN R0 0 +)"); + + // if assignment of a local invalidates a direct register reference in later assignments, the value is evacuated to a temp register + CHECK_EQ("\n" + compileFunction0("local a a, a[1] = 1, 2"), R"( +LOADNIL R0 +MOVE R1 R0 +LOADN R2 1 +LOADN R3 2 +MOVE R0 R2 +SETTABLEN R3 R1 1 +RETURN R0 0 +)"); + + // note that this doesn't happen if the local assignment happens last naturally + CHECK_EQ("\n" + compileFunction0("local a a[1], a = 1, 2"), R"( +LOADNIL R0 +LOADN R1 1 +LOADN R2 2 +SETTABLEN R1 R0 1 +MOVE R0 R2 +RETURN R0 0 +)"); + + // this will happen if assigned register is used in any table expression, including as an object... + CHECK_EQ("\n" + compileFunction0("local a a, a.foo = 1, 2"), R"( +LOADNIL R0 +MOVE R1 R0 +LOADN R2 1 +LOADN R3 2 +MOVE R0 R2 +SETTABLEKS R3 R1 K0 +RETURN R0 0 +)"); + + // ... or a table index ... + CHECK_EQ("\n" + compileFunction0("local a a, foo[a] = 1, 2"), R"( +LOADNIL R0 +GETIMPORT R1 1 +MOVE R2 R0 +LOADN R3 1 +LOADN R4 2 +MOVE R0 R3 +SETTABLE R4 R1 R2 +RETURN R0 0 +)"); + + // ... or both ... + CHECK_EQ("\n" + compileFunction0("local a a, a[a] = 1, 2"), R"( +LOADNIL R0 +MOVE R1 R0 +LOADN R2 1 +LOADN R3 2 +MOVE R0 R2 +SETTABLE R3 R1 R1 +RETURN R0 0 +)"); + + // ... or both with two different locals ... + CHECK_EQ("\n" + compileFunction0("local a, b a, b, a[b] = 1, 2, 3"), R"( +LOADNIL R0 +LOADNIL R1 +MOVE R2 R0 +MOVE R3 R1 +LOADN R4 1 +LOADN R5 2 +LOADN R6 3 +MOVE R0 R4 +MOVE R1 R5 +SETTABLE R6 R2 R3 +RETURN R0 0 +)"); + + // however note that if it participates in an expression on the left hand side, there's no point reassigning it since we'd compute the expr value + // into a temp register + CHECK_EQ("\n" + compileFunction0("local a a, foo[a + 1] = 1, 2"), R"( +LOADNIL R0 +GETIMPORT R1 1 +ADDK R2 R0 K2 +LOADN R3 1 +LOADN R4 2 +MOVE R0 R3 +SETTABLE R4 R1 R2 +RETURN R0 0 +)"); +} + +TEST_CASE("FastcallBytecode") +{ + // direct global call + CHECK_EQ("\n" + compileFunction0("return math.abs(-5)"), R"( +LOADN R1 -5 +FASTCALL1 2 R1 +2 +GETIMPORT R0 2 +CALL R0 1 -1 +RETURN R0 -1 +)"); + + // call through a local variable + CHECK_EQ("\n" + compileFunction0("local abs = math.abs return abs(-5)"), R"( +GETIMPORT R0 2 +LOADN R2 -5 +FASTCALL1 2 R2 +1 +MOVE R1 R0 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // call through an upvalue + CHECK_EQ("\n" + compileFunction0("local abs = math.abs function foo() return abs(-5) end return foo()"), R"( +LOADN R1 -5 +FASTCALL1 2 R1 +1 +GETUPVAL R0 0 +CALL R0 1 -1 +RETURN R0 -1 +)"); + + // mutating the global in the script breaks the optimization + CHECK_EQ("\n" + compileFunction0("math = {} return math.abs(-5)"), R"( +NEWTABLE R0 0 0 +SETGLOBAL R0 K0 +GETGLOBAL R1 K0 +GETTABLEKS R0 R1 K1 +LOADN R1 -5 +CALL R0 1 -1 +RETURN R0 -1 +)"); + + // mutating the local in the script breaks the optimization + CHECK_EQ("\n" + compileFunction0("local abs = math.abs abs = nil return abs(-5)"), R"( +GETIMPORT R0 2 +LOADNIL R0 +MOVE R1 R0 +LOADN R2 -5 +CALL R1 1 -1 +RETURN R1 -1 +)"); + + // mutating the global in the script breaks the optimization, even if you do this after computing the local (for simplicity) + CHECK_EQ("\n" + compileFunction0("local abs = math.abs math = {} return abs(-5)"), R"( +GETGLOBAL R1 K0 +GETTABLEKS R0 R1 K1 +NEWTABLE R1 0 0 +SETGLOBAL R1 K0 +MOVE R1 R0 +LOADN R2 -5 +CALL R1 1 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("LotsOfParameters") +{ + const char* source = R"( +select("#",1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1) +)"; + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, source); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Out of registers when trying to allocate 265 registers: exceeded limit 255"); + } +} + +TEST_CASE("LotsOfIndexers") +{ + const char* source = R"( +function u(t)for t in s(t.l.l.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.g.l.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.l.l.t.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.r.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.g.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.r.n.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.l.n.l.l.l.n.l.l.l.l.l.l.l.n.l.l.l.l.l.l.l.l.l.l..l,l)do end +end +)"; + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, source); + CHECK(!"Expected exception"); + } + catch (std::exception& e) + { + CHECK_EQ(std::string(e.what()), "Out of registers when trying to allocate 1 registers: exceeded limit 255"); + } +} + +TEST_CASE("AsConstant") +{ + const char* source = R"( +--!strict +return (1 + 2) :: number +)"; + + Luau::CompileOptions options; + Luau::ParseOptions parseOptions; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, source, options, parseOptions); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADN R0 3 +RETURN R0 1 +)"); +} + +TEST_CASE("PreserveNegZero") +{ + CHECK_EQ("\n" + compileFunction0("return 0"), R"( +LOADN R0 0 +RETURN R0 1 +)"); + + CHECK_EQ("\n" + compileFunction0("return -0"), R"( +LOADK R0 K0 +RETURN R0 1 +)"); +} + +TEST_CASE("CaptureImmutable") +{ + // capture argument: note capture by value + CHECK_EQ("\n" + compileFunction("function foo(a, b) return function() return a end end", 1), R"( +NEWCLOSURE R2 P0 +CAPTURE VAL R0 +RETURN R2 1 +)"); + + // capture mutable argument: note capture by reference + close + CHECK_EQ("\n" + compileFunction("function foo(a, b) a = 1 return function() return a end end", 1), R"( +LOADN R0 1 +NEWCLOSURE R2 P0 +CAPTURE REF R0 +CLOSEUPVALS R0 +RETURN R2 1 +)"); + + // capture two arguments, one mutable, one immutable + CHECK_EQ("\n" + compileFunction("function foo(a, b) a = 1 return function() return a + b end end", 1), R"( +LOADN R0 1 +NEWCLOSURE R2 P0 +CAPTURE REF R0 +CAPTURE VAL R1 +CLOSEUPVALS R0 +RETURN R2 1 +)"); + + // capture self + CHECK_EQ("\n" + compileFunction("function bar:foo(a, b) return function() return self end end", 1), R"( +NEWCLOSURE R3 P0 +CAPTURE VAL R0 +RETURN R3 1 +)"); + + // capture mutable self (who mutates self?!?) + CHECK_EQ("\n" + compileFunction("function bar:foo(a, b) self = 42 return function() return self end end", 1), R"( +LOADN R0 42 +NEWCLOSURE R3 P0 +CAPTURE REF R0 +CLOSEUPVALS R0 +RETURN R3 1 +)"); + + // capture upvalue: one mutable, one immutable + CHECK_EQ("\n" + compileFunction("local a, b = math.rand() a = 42 function foo() return function() return a + b end end", 1), R"( +NEWCLOSURE R0 P0 +CAPTURE UPVAL U0 +CAPTURE UPVAL U1 +RETURN R0 1 +)"); + + if (FFlag::LuauPreloadClosuresUpval) + { + // recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( +DUPCLOSURE R0 K0 +CAPTURE VAL R0 +RETURN R0 0 +)"); + + // multi-level recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( +DUPCLOSURE R0 K0 +CAPTURE UPVAL U0 +RETURN R0 1 +)"); + + // multi-level recursive capture where function isn't top-level + // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + local function bar() + return function() return bar() end + end +end +)", + 1), + R"( +NEWCLOSURE R0 P0 +CAPTURE UPVAL U0 +RETURN R0 1 +)"); + } + else + { + // recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( +NEWCLOSURE R0 P0 +CAPTURE VAL R0 +RETURN R0 0 +)"); + } +} + +TEST_CASE("OutOfLocals") +{ + std::string source; + + for (int i = 0; i < 200; ++i) + { + formatAppend(source, "local foo%d\n", i); + } + + source += "local bar\n"; + + Luau::CompileOptions options; + options.debugLevel = 2; // make sure locals aren't elided by requesting their debug info + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, source, options); + + CHECK(!"Expected CompileError"); + } + catch (Luau::CompileError& e) + { + CHECK_EQ(e.getLocation().begin.line + 1, 201); + CHECK_EQ(std::string(e.what()), "Out of local registers when trying to allocate bar: exceeded limit 200"); + } +} + +TEST_CASE("OutOfUpvalues") +{ + std::string source; + + for (int i = 0; i < 150; ++i) + { + formatAppend(source, "local foo%d\n", i); + formatAppend(source, "foo%d = 42\n", i); + } + + source += "function foo()\n"; + + for (int i = 0; i < 150; ++i) + { + formatAppend(source, "local bar%d\n", i); + formatAppend(source, "bar%d = 42\n", i); + } + + source += "function bar()\n"; + + for (int i = 0; i < 150; ++i) + { + formatAppend(source, "print(foo%d, bar%d)\n", i, i); + } + + source += "end\nend\n"; + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, source); + + CHECK(!"Expected CompileError"); + } + catch (Luau::CompileError& e) + { + CHECK_EQ(e.getLocation().begin.line + 1, 201); + CHECK_EQ(std::string(e.what()), "Out of upvalue registers when trying to allocate foo100: exceeded limit 200"); + } +} + +TEST_CASE("OutOfRegisters") +{ + std::string source; + + source += "print(\n"; + + for (int i = 0; i < 150; ++i) + { + formatAppend(source, "%d,\n", i); + } + + source += "table.pack(\n"; + + for (int i = 0; i < 150; ++i) + { + formatAppend(source, "%d,\n", i); + } + + source += "42))\n"; + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, source); + + CHECK(!"Expected CompileError"); + } + catch (Luau::CompileError& e) + { + CHECK_EQ(e.getLocation().begin.line + 1, 152); + CHECK_EQ(std::string(e.what()), "Out of registers when trying to allocate 152 registers: exceeded limit 255"); + } +} + +TEST_CASE("FastCallImportFallback") +{ + std::string source = "local t = {}\n"; + + // we need to exhaust the 10-bit constant space to block GETIMPORT from being emitted + for (int i = 1; i <= 1024; ++i) + { + formatAppend(source, "t[%d] = \"%d\"\n", i, i); + } + + source += "return math.abs(-1)\n"; + + std::string code = compileFunction0(source.c_str()); + + std::vector insns = Luau::split(code, '\n'); + + CHECK_EQ(insns[insns.size() - 9], "LOADN R1 1024"); + CHECK_EQ(insns[insns.size() - 8], "LOADK R2 K1023"); + CHECK_EQ(insns[insns.size() - 7], "SETTABLE R2 R0 R1"); + CHECK_EQ(insns[insns.size() - 6], "LOADN R2 -1"); + CHECK_EQ(insns[insns.size() - 5], "FASTCALL1 2 R2 +4"); + CHECK_EQ(insns[insns.size() - 4], "GETGLOBAL R3 K1024"); // note: it's important that this doesn't overwrite R2 + CHECK_EQ(insns[insns.size() - 3], "GETTABLEKS R1 R3 K1025"); + CHECK_EQ(insns[insns.size() - 2], "CALL R1 1 -1"); + CHECK_EQ(insns[insns.size() - 1], "RETURN R1 -1"); +} + +TEST_CASE("CompoundAssignment") +{ + // globals vs constants + CHECK_EQ("\n" + compileFunction0("a += 1"), R"( +GETGLOBAL R0 K0 +ADDK R0 R0 K1 +SETGLOBAL R0 K0 +RETURN R0 0 +)"); + + // globals vs expressions + CHECK_EQ("\n" + compileFunction0("a -= a"), R"( +GETGLOBAL R0 K0 +GETGLOBAL R1 K0 +SUB R0 R0 R1 +SETGLOBAL R0 K0 +RETURN R0 0 +)"); + + // locals vs constants + CHECK_EQ("\n" + compileFunction0("local a = 1 a *= 2"), R"( +LOADN R0 1 +MULK R0 R0 K0 +RETURN R0 0 +)"); + + // locals vs locals + CHECK_EQ("\n" + compileFunction0("local a = 1 a /= a"), R"( +LOADN R0 1 +DIV R0 R0 R0 +RETURN R0 0 +)"); + + // locals vs expressions + CHECK_EQ("\n" + compileFunction0("local a = 1 a /= a + 1"), R"( +LOADN R0 1 +ADDK R1 R0 K0 +DIV R0 R0 R1 +RETURN R0 0 +)"); + + // upvalues + CHECK_EQ("\n" + compileFunction0("local a = 1 function foo() a += 4 end"), R"( +GETUPVAL R0 0 +ADDK R0 R0 K0 +SETUPVAL R0 0 +RETURN R0 0 +)"); + + // table variants (indexed by string, number, variable) + CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( +NEWTABLE R0 1 0 +GETTABLEKS R1 R0 K0 +ADDK R1 R1 K1 +SETTABLEKS R1 R0 K0 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( +NEWTABLE R0 0 1 +GETTABLEN R1 R0 1 +ADDK R1 R1 K0 +SETTABLEN R1 R0 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = {} a[a] += 5"), R"( +NEWTABLE R0 0 0 +GETTABLE R1 R0 R0 +ADDK R1 R1 K0 +SETTABLE R1 R0 R0 +RETURN R0 0 +)"); + + // left hand side is evaluated once + CHECK_EQ("\n" + compileFunction0("foo()[bar()] += 5"), R"( +GETIMPORT R0 1 +CALL R0 0 1 +GETIMPORT R1 3 +CALL R1 0 1 +GETTABLE R2 R0 R1 +ADDK R2 R2 K4 +SETTABLE R2 R0 R1 +RETURN R0 0 +)"); +} + +TEST_CASE("CompoundAssignmentConcat") +{ + // basic concat + CHECK_EQ("\n" + compileFunction0("local a = '' a ..= 'a'"), R"( +LOADK R0 K0 +MOVE R1 R0 +LOADK R2 K1 +CONCAT R0 R1 R2 +RETURN R0 0 +)"); + + // concat chains + CHECK_EQ("\n" + compileFunction0("local a = '' a ..= 'a' .. 'b'"), R"( +LOADK R0 K0 +MOVE R1 R0 +LOADK R2 K1 +LOADK R3 K2 +CONCAT R0 R1 R3 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0("local a = '' a ..= 'a' .. 'b' .. 'c'"), R"( +LOADK R0 K0 +MOVE R1 R0 +LOADK R2 K1 +LOADK R3 K2 +LOADK R4 K3 +CONCAT R0 R1 R4 +RETURN R0 0 +)"); + + // concat on non-local + CHECK_EQ("\n" + compileFunction0("_VERSION ..= 'a' .. 'b'"), R"( +GETGLOBAL R1 K0 +LOADK R2 K1 +LOADK R3 K2 +CONCAT R0 R1 R3 +SETGLOBAL R0 K0 +RETURN R0 0 +)"); +} + +TEST_CASE("JumpTrampoline") +{ + std::string source; + source += "local sum = 0\n"; + source += "for i=1,3 do\n"; + for (int i = 0; i < 10000; ++i) + { + source += "sum = sum + i\n"; + source += "if sum > 150000 then break end\n"; + } + source += "end\n"; + source += "return sum\n"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::compileOrThrow(bcb, source.c_str()); + + std::stringstream bcs(bcb.dumpFunction(0)); + + std::vector insns; + std::string insn; + while ((std::getline)(bcs, insn)) + insns.push_back(insn); + + // FORNPREP and early JUMPs (break) need to go through a trampoline + CHECK_EQ(insns[0], "LOADN R0 0"); + CHECK_EQ(insns[1], "LOADN R3 1"); + CHECK_EQ(insns[2], "LOADN R1 3"); + CHECK_EQ(insns[3], "LOADN R2 1"); + CHECK_EQ(insns[4], "JUMP +1"); + CHECK_EQ(insns[5], "JUMPX +54542"); + CHECK_EQ(insns[6], "FORNPREP R1 -2"); + CHECK_EQ(insns[7], "ADD R0 R0 R3"); + CHECK_EQ(insns[8], "LOADK R4 K0"); + CHECK_EQ(insns[9], "JUMP +1"); + CHECK_EQ(insns[10], "JUMPX +54537"); + CHECK_EQ(insns[11], "JUMPIFLT R4 R0 -2"); + CHECK_EQ(insns[12], "ADD R0 R0 R3"); + CHECK_EQ(insns[13], "LOADK R4 K0"); + CHECK_EQ(insns[14], "JUMP +1"); + CHECK_EQ(insns[15], "JUMPX +54531"); + CHECK_EQ(insns[16], "JUMPIFLT R4 R0 -2"); + + // FORNLOOP has to go through a trampoline since the jump is back to the beginning of the function + // however, late JUMPs (break) don't need a trampoline since the loop end is really close by + CHECK_EQ(insns[44539], "ADD R0 R0 R3"); + CHECK_EQ(insns[44540], "LOADK R4 K0"); + CHECK_EQ(insns[44541], "JUMPIFLT R4 R0 +8"); + CHECK_EQ(insns[44542], "ADD R0 R0 R3"); + CHECK_EQ(insns[44543], "LOADK R4 K0"); + CHECK_EQ(insns[44544], "JUMPIFLT R4 R0 +4"); + CHECK_EQ(insns[44545], "JUMP +1"); + CHECK_EQ(insns[44546], "JUMPX -54540"); + CHECK_EQ(insns[44547], "FORNLOOP R1 -2"); + CHECK_EQ(insns[44548], "RETURN R0 1"); +} + +TEST_CASE("CompileBytecode") +{ + // This is a coverage test, it just exercises bytecode dumping for correct and malformed code + Luau::compile("return 5"); + Luau::compile("this is not valid lua, right?"); +} + +TEST_CASE("NestedNamecall") +{ + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +return obj:Method(1):Method(2):Method(3) +)"), + R"( +GETVARARGS R0 1 +LOADN R3 1 +NAMECALL R1 R0 K0 +CALL R1 2 1 +LOADN R3 2 +NAMECALL R1 R1 K0 +CALL R1 2 1 +LOADN R3 3 +NAMECALL R1 R1 K0 +CALL R1 2 -1 +RETURN R1 -1 +)"); +} + +TEST_CASE("ElideLocals") +{ + // simple local elision: all locals are constant + CHECK_EQ("\n" + compileFunction0(R"( +local a, b = 1, 2 +return a + b +)"), + R"( +LOADN R0 3 +RETURN R0 1 +)"); + + // side effecting expressions block local elision + CHECK_EQ("\n" + compileFunction0(R"( +local a = g() +return a +)"), + R"( +GETIMPORT R0 1 +CALL R0 0 1 +RETURN R0 1 +)"); + + // ... even if they are not used + CHECK_EQ("\n" + compileFunction0(R"( +local a = 1, g() +return a +)"), + R"( +LOADN R0 1 +GETIMPORT R1 1 +CALL R1 0 1 +RETURN R0 1 +)"); +} + +TEST_CASE("ConstantJumpCompare") +{ + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +local b = obj == 1 +)"), + R"( +GETVARARGS R0 1 +JUMPIFEQK R0 K0 +2 +LOADB R1 0 +1 +LOADB R1 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +local b = 1 == obj +)"), + R"( +GETVARARGS R0 1 +JUMPIFEQK R0 K0 +2 +LOADB R1 0 +1 +LOADB R1 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +local b = "Hello, Sailor!" == obj +)"), + R"( +GETVARARGS R0 1 +JUMPIFEQK R0 K0 +2 +LOADB R1 0 +1 +LOADB R1 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +local b = nil == obj +)"), + R"( +GETVARARGS R0 1 +JUMPIFEQK R0 K0 +2 +LOADB R1 0 +1 +LOADB R1 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +local b = true == obj +)"), + R"( +GETVARARGS R0 1 +JUMPIFEQK R0 K0 +2 +LOADB R1 0 +1 +LOADB R1 1 +RETURN R0 0 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +local b = nil ~= obj +)"), + R"( +GETVARARGS R0 1 +JUMPIFNOTEQK R0 K0 +2 +LOADB R1 0 +1 +LOADB R1 1 +RETURN R0 0 +)"); + + // table literals should not generate IFEQK variants + CHECK_EQ("\n" + compileFunction0(R"( +local obj = ... +local b = obj == {} +)"), + R"( +GETVARARGS R0 1 +NEWTABLE R2 0 0 +JUMPIFEQ R0 R2 +2 +LOADB R1 0 +1 +LOADB R1 1 +RETURN R0 0 +)"); +} + +TEST_CASE("TableConstantStringIndex") +{ + CHECK_EQ("\n" + compileFunction0(R"( +local t = { a = 2 } +return t['a'] +)"), + R"( +DUPTABLE R0 1 +LOADN R1 2 +SETTABLEKS R1 R0 K0 +GETTABLEKS R1 R0 K0 +RETURN R1 1 +)"); + + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t['a'] = 2 +)"), + R"( +NEWTABLE R0 0 0 +LOADN R1 2 +SETTABLEKS R1 R0 K0 +RETURN R0 0 +)"); +} + +TEST_CASE("Coverage") +{ + // basic statement coverage + CHECK_EQ("\n" + compileFunction0Coverage(R"( +print(1) +print(2) +)", + 1), + R"( +2: COVERAGE +2: GETIMPORT R0 1 +2: LOADN R1 1 +2: CALL R0 1 0 +3: COVERAGE +3: GETIMPORT R0 1 +3: LOADN R1 2 +3: CALL R0 1 0 +4: RETURN R0 0 +)"); + + // branching + CHECK_EQ("\n" + compileFunction0Coverage(R"( +if x then + print(1) +else + print(2) +end +)", + 1), + R"( +2: COVERAGE +2: GETIMPORT R0 1 +2: JUMPIFNOT R0 +6 +3: COVERAGE +3: GETIMPORT R0 3 +3: LOADN R1 1 +3: CALL R0 1 0 +7: RETURN R0 0 +5: COVERAGE +5: GETIMPORT R0 3 +5: LOADN R1 2 +5: CALL R0 1 0 +7: RETURN R0 0 +)"); + + // branching with comments + // note that commented lines don't have COVERAGE insns! + CHECK_EQ("\n" + compileFunction0Coverage(R"( +if x then + -- first + print(1) +else + -- second + print(2) +end +)", + 1), + R"( +2: COVERAGE +2: GETIMPORT R0 1 +2: JUMPIFNOT R0 +6 +4: COVERAGE +4: GETIMPORT R0 3 +4: LOADN R1 1 +4: CALL R0 1 0 +9: RETURN R0 0 +7: COVERAGE +7: GETIMPORT R0 3 +7: LOADN R1 2 +7: CALL R0 1 0 +9: RETURN R0 0 +)"); + + // expression coverage for table literals + // note: duplicate COVERAGE instructions are there since we don't deduplicate expr/stat + CHECK_EQ("\n" + compileFunction0Coverage(R"( +local c = ... +local t = { + a = 1, + b = 2, + c = c +} +)", + 2), + R"( +2: COVERAGE +2: COVERAGE +2: GETVARARGS R0 1 +3: COVERAGE +3: COVERAGE +3: DUPTABLE R1 3 +4: COVERAGE +4: COVERAGE +4: LOADN R2 1 +4: SETTABLEKS R2 R1 K0 +5: COVERAGE +5: COVERAGE +5: LOADN R2 2 +5: SETTABLEKS R2 R1 K1 +6: COVERAGE +6: SETTABLEKS R0 R1 K2 +8: RETURN R0 0 +)"); +} + +TEST_CASE("ConstantClosure") +{ + ScopedFastFlag sff("LuauPreloadClosures", true); + + // closures without upvalues are created when bytecode is loaded + CHECK_EQ("\n" + compileFunction(R"( +return function() end +)", + 1), + R"( +DUPCLOSURE R0 K0 +RETURN R0 1 +)"); + + // they can access globals just fine + CHECK_EQ("\n" + compileFunction(R"( +return function() print("hi") end +)", + 1), + R"( +DUPCLOSURE R0 K0 +RETURN R0 1 +)"); + + // if they need upvalues, we can't create them before running the code (but see SharedClosure test) + CHECK_EQ("\n" + compileFunction(R"( +function test() + local print = print + return function() print("hi") end +end +)", + 1), + R"( +GETIMPORT R0 1 +NEWCLOSURE R1 P0 +CAPTURE VAL R0 +RETURN R1 1 +)"); + + if (FFlag::LuauPreloadClosuresFenv) + { + // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion + CHECK_EQ("\n" + compileFunction(R"( +setfenv(1, {}) +return function() print("hi") end +)", + 1), + R"( +GETIMPORT R0 1 +LOADN R1 1 +NEWTABLE R2 0 0 +CALL R0 2 0 +NEWCLOSURE R0 P0 +RETURN R0 1 +)"); + + // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature + CHECK_EQ("\n" + compileFunction(R"( +if false then setfenv(1, {}) end +return function() print("hi") end +)", + 1), + R"( +NEWCLOSURE R0 P0 +RETURN R0 1 +)"); + } +} + +TEST_CASE("SharedClosure") +{ + ScopedFastFlag sff1("LuauPreloadClosures", true); + ScopedFastFlag sff2("LuauPreloadClosuresUpval", true); + + // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level + CHECK_EQ("\n" + compileFunction(R"( +local val = ... + +local function foo() + return function() return val end +end +)", + 1), + R"( +DUPCLOSURE R0 K0 +CAPTURE UPVAL U0 +RETURN R0 1 +)"); + + // ... as long as the values aren't mutated. + CHECK_EQ("\n" + compileFunction(R"( +local val = ... + +local function foo() + return function() return val end +end + +val = 5 +)", + 1), + R"( +NEWCLOSURE R0 P0 +CAPTURE UPVAL U0 +RETURN R0 1 +)"); + + // making the upvalue non-toplevel disables the optimization since it's likely that it will change + CHECK_EQ("\n" + compileFunction(R"( +local function foo(val) + return function() return val end +end +)", + 1), + R"( +NEWCLOSURE R1 P0 +CAPTURE VAL R0 +RETURN R1 1 +)"); + + // the upvalue analysis is transitive through local functions, which allows for code reuse to not defeat the optimization + CHECK_EQ("\n" + compileFunction(R"( +local val = ... + +local function foo() + local function bar() + return val + end + + return function() return bar() end +end +)", + 2), + R"( +DUPCLOSURE R0 K0 +CAPTURE UPVAL U0 +DUPCLOSURE R1 K1 +CAPTURE VAL R0 +RETURN R1 1 +)"); + + // as such, if the upvalue that we reach transitively isn't top-level we fall back to newclosure + CHECK_EQ("\n" + compileFunction(R"( +local function foo(val) + local function bar() + return val + end + + return function() return bar() end +end +)", + 2), + R"( +NEWCLOSURE R1 P0 +CAPTURE VAL R0 +NEWCLOSURE R2 P1 +CAPTURE VAL R1 +RETURN R2 1 +)"); + + // we also allow recursive function captures to share the object, even when it's not top-level + CHECK_EQ("\n" + compileFunction("function test() local function foo() return foo() end end", 1), R"( +DUPCLOSURE R0 K0 +CAPTURE VAL R0 +RETURN R0 0 +)"); + + // multi-level recursive capture where function isn't top-level fails however. + // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler + CHECK_EQ("\n" + compileFunction(R"( +local function foo() + local function bar() + return function() return bar() end + end +end +)", + 1), + R"( +NEWCLOSURE R0 P0 +CAPTURE UPVAL U0 +RETURN R0 1 +)"); + + // top level upvalues inside loops should not be shared -- note that the bytecode below only uses NEWCLOSURE + CHECK_EQ("\n" + compileFunction(R"( +for i=1,10 do + print(function() return i end) +end + +for k,v in pairs(...) do + print(function() return k end) +end + +for i=1,10 do + local j = i + print(function() return j end) +end +)", + 3), + R"( +LOADN R2 1 +LOADN R0 10 +LOADN R1 1 +FORNPREP R0 +6 +GETIMPORT R3 1 +NEWCLOSURE R4 P0 +CAPTURE VAL R2 +CALL R3 1 0 +FORNLOOP R0 -6 +GETIMPORT R0 3 +GETVARARGS R1 -1 +CALL R0 -1 3 +FORGPREP_NEXT R0 +5 +GETIMPORT R5 1 +NEWCLOSURE R6 P1 +CAPTURE VAL R3 +CALL R5 1 0 +FORGLOOP_NEXT R0 -6 +LOADN R2 1 +LOADN R0 10 +LOADN R1 1 +FORNPREP R0 +7 +MOVE R3 R2 +GETIMPORT R4 1 +NEWCLOSURE R5 P2 +CAPTURE VAL R3 +CALL R4 1 0 +FORNLOOP R0 -7 +RETURN R0 0 +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Config.test.cpp b/tests/Config.test.cpp new file mode 100644 index 0000000..e6a7267 --- /dev/null +++ b/tests/Config.test.cpp @@ -0,0 +1,156 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Config.h" +#include "Luau/Frontend.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("ConfigTest"); + +TEST_CASE("language_mode") +{ + Config config; + auto err = parseConfig(R"({"languageMode":"strict"})", config); + REQUIRE(!err); + + CHECK_EQ(int(Luau::Mode::Strict), int(config.mode)); +} + +TEST_CASE("disable_a_lint_rule") +{ + Config config; + auto err = parseConfig(R"( + {"lint": { + "UnknownGlobal": false, + }} + )", + config); + REQUIRE(!err); + + CHECK(!config.enabledLint.isEnabled(LintWarning::Code_UnknownGlobal)); + CHECK(config.enabledLint.isEnabled(LintWarning::Code_DeprecatedGlobal)); +} + +TEST_CASE("report_a_syntax_error") +{ + Config config; + auto err = parseConfig(R"( + {"lint": { + "UnknownGlobal": "oops" + }} + )", + config); + + REQUIRE(err); + CHECK_EQ("In key UnknownGlobal: Bad setting 'oops'. Valid options are true and false", *err); +} + +TEST_CASE("noinfer_is_still_allowed") +{ + Config config; + auto err = parseConfig(R"( {"language": {"mode": "noinfer"}} )", config, true); + REQUIRE(!err); + + CHECK_EQ(int(Luau::Mode::NoCheck), int(config.mode)); +} + +TEST_CASE("lint_warnings_are_ordered") +{ + Config root; + auto err = parseConfig(R"({"lint": {"*": true, "LocalShadow": false}})", root); + REQUIRE(!err); + + Config foo = root; + err = parseConfig(R"({"lint": {"LocalShadow": true, "*": false}})", foo); + REQUIRE(!err); + + CHECK(!root.enabledLint.isEnabled(LintWarning::Code_LocalShadow)); + CHECK(root.enabledLint.isEnabled(LintWarning::Code_LocalUnused)); + + CHECK(!foo.enabledLint.isEnabled(LintWarning::Code_LocalShadow)); +} + +TEST_CASE("comments") +{ + Config config; + auto err = parseConfig(R"( +{ + "lint": { + "*": false, + "SameLineStatement": true, + "FunctionUnused": true, + //"LocalShadow": true, + //"LocalUnused": true, + "ImportUnused": true, + "ImplicitReturn": true + } +} +)", + config); + REQUIRE(!err); + + CHECK(!config.enabledLint.isEnabled(LintWarning::Code_LocalShadow)); + CHECK(config.enabledLint.isEnabled(LintWarning::Code_ImportUnused)); +} + +TEST_CASE("issue_severity") +{ + Config config; + CHECK(!config.lintErrors); + CHECK(config.typeErrors); + + auto err = parseConfig(R"( +{ + "lintErrors": true, + "typeErrors": false, +} +)", + config); + REQUIRE(!err); + + CHECK(config.lintErrors); + CHECK(!config.typeErrors); +} + +TEST_CASE("extra_globals") +{ + Config config; + auto err = parseConfig(R"( +{ + "globals": ["it", "__DEV__"], +} +)", + config); + REQUIRE(!err); + + REQUIRE(config.globals.size() == 2); + CHECK(config.globals[0] == "it"); + CHECK(config.globals[1] == "__DEV__"); +} + +TEST_CASE("lint_rules_compat") +{ + Config config; + auto err = parseConfig(R"( + {"lint": { + "SameLineStatement": "enabled", + "FunctionUnused": "disabled", + "ImportUnused": "fatal", + }} + )", + config, true); + REQUIRE(!err); + + CHECK(config.enabledLint.isEnabled(LintWarning::Code_SameLineStatement)); + CHECK(!config.enabledLint.isEnabled(LintWarning::Code_FunctionUnused)); + CHECK(config.enabledLint.isEnabled(LintWarning::Code_ImportUnused)); + CHECK(config.fatalLint.isEnabled(LintWarning::Code_ImportUnused)); +} + +TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp new file mode 100644 index 0000000..5a697a4 --- /dev/null +++ b/tests/Conformance.test.cpp @@ -0,0 +1,804 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Compiler.h" + +#include "Luau/BuiltinDefinitions.h" +#include "Luau/ModuleResolver.h" +#include "Luau/TypeInfer.h" +#include "Luau/StringUtils.h" +#include "Luau/BytecodeBuilder.h" + +#include "doctest.h" +#include "ScopedFlags.h" + +#include "lua.h" +#include "lualib.h" + +#include +#include + +static int lua_collectgarbage(lua_State* L) +{ + static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; + static const int optsnum[] = { + LUA_GCSTOP, LUA_GCRESTART, LUA_GCCOLLECT, LUA_GCCOUNT, LUA_GCISRUNNING, LUA_GCSTEP, LUA_GCSETGOAL, LUA_GCSETSTEPMUL, LUA_GCSETSTEPSIZE}; + + int o = luaL_checkoption(L, 1, "collect", opts); + int ex = luaL_optinteger(L, 2, 0); + int res = lua_gc(L, optsnum[o], ex); + switch (optsnum[o]) + { + case LUA_GCSTEP: + case LUA_GCISRUNNING: + { + lua_pushboolean(L, res); + return 1; + } + default: + { + lua_pushnumber(L, res); + return 1; + } + } +} + +static int lua_loadstring(lua_State* L) +{ + size_t l = 0; + const char* s = luaL_checklstring(L, 1, &l); + const char* chunkname = luaL_optstring(L, 2, s); + + lua_setsafeenv(L, LUA_ENVIRONINDEX, false); + + std::string bytecode = Luau::compile(std::string(s, l)); + if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + return 1; + + lua_pushnil(L); + lua_insert(L, -2); /* put before error message */ + return 2; /* return nil plus error message */ +} + +static int lua_vector(lua_State* L) +{ + double x = luaL_checknumber(L, 1); + double y = luaL_checknumber(L, 2); + double z = luaL_checknumber(L, 3); + + lua_pushvector(L, float(x), float(y), float(z)); + return 1; +} + +static int lua_vector_dot(lua_State* L) +{ + const float* a = lua_tovector(L, 1); + const float* b = lua_tovector(L, 2); + + if (a && b) + { + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); + return 1; + } + + throw std::runtime_error("invalid arguments to vector:Dot"); +} + +static int lua_vector_index(lua_State* L) +{ + const char* name = luaL_checkstring(L, 2); + + if (const float* v = lua_tovector(L, 1)) + { + if (strcmp(name, "Magnitude") == 0) + { + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); + return 1; + } + + if (strcmp(name, "Dot") == 0) + { + lua_pushcfunction(L, lua_vector_dot, "Dot"); + return 1; + } + } + + throw std::runtime_error(Luau::format("%s is not a valid member of vector", name)); +} + +static int lua_vector_namecall(lua_State* L) +{ + if (const char* str = lua_namecallatom(L, nullptr)) + { + if (strcmp(str, "Dot") == 0) + return lua_vector_dot(L); + } + + throw std::runtime_error(Luau::format("%s is not a valid method of vector", luaL_checkstring(L, 1))); +} + +int lua_silence(lua_State* L) +{ + return 0; +} + +using StateRef = std::unique_ptr; + +static StateRef runConformance( + const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, lua_State* initialLuaState = nullptr) +{ + std::string path = __FILE__; + path.erase(path.find_last_of("\\/")); + path += "/conformance/"; + path += name; + + std::fstream stream(path, std::ios::in | std::ios::binary); + REQUIRE(stream); + + std::string source(std::istreambuf_iterator(stream), {}); + + stream.close(); + + if (!initialLuaState) + initialLuaState = luaL_newstate(); + StateRef globalState(initialLuaState, lua_close); + lua_State* L = globalState.get(); + + luaL_openlibs(L); + + // Register a few global functions for conformance tests + static const luaL_Reg funcs[] = { + {"collectgarbage", lua_collectgarbage}, + {"loadstring", lua_loadstring}, + {"print", lua_silence}, // Disable print() by default; comment this out to enable debug prints in tests + {nullptr, nullptr}, + }; + + lua_pushvalue(L, LUA_GLOBALSINDEX); + luaL_register(L, nullptr, funcs); + lua_pop(L, 1); + + // In some configurations we have a larger C stack consumption which trips some conformance tests +#if defined(LUAU_ENABLE_ASAN) || defined(_NOOPT) || defined(_DEBUG) + lua_pushboolean(L, true); + lua_setglobal(L, "limitedstack"); +#endif + + // Extra test-specific setup + if (setup) + setup(L); + + // Protect core libraries and metatables from modification + luaL_sandbox(L); + + // Create a new writable global table for current thread + luaL_sandboxthread(L); + + // Lua conformance tests treat _G synonymously with getfenv(); for now cater to them + lua_pushvalue(L, LUA_GLOBALSINDEX); + lua_pushvalue(L, LUA_GLOBALSINDEX); + lua_setfield(L, -1, "_G"); + + std::string chunkname = "=" + std::string(name); + + Luau::CompileOptions copts; + copts.debugLevel = 2; // for debugger tests + copts.vectorCtor = "vector"; // for vector tests + + std::string bytecode = Luau::compile(source, copts); + int status = 0; + + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + { + status = lua_resume(L, nullptr, 0); + } + else + { + status = LUA_ERRSYNTAX; + } + + while (yield && (status == LUA_YIELD || status == LUA_BREAK)) + { + yield(L); + status = lua_resume(L, nullptr, 0); + } + + extern void luaC_validate(lua_State * L); // internal function, declared in lgc.h - not exposed via lua.h + luaC_validate(L); + + if (status == 0) + { + REQUIRE(lua_isstring(L, -1)); + CHECK(std::string(lua_tostring(L, -1)) == "OK"); + } + else + { + std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(L, -1); + error += "\nstacktrace:\n"; + error += lua_debugtrace(L); + + FAIL(error); + } + + return globalState; +} + +TEST_SUITE_BEGIN("Conformance"); + +TEST_CASE("Assert") +{ + runConformance("assert.lua"); +} + +TEST_CASE("Basic") +{ + runConformance("basic.lua"); +} + +TEST_CASE("Math") +{ + runConformance("math.lua"); +} + +TEST_CASE("Table") +{ + ScopedFastFlag sff("LuauTableFreeze", true); + + runConformance("nextvar.lua"); +} + +TEST_CASE("PatternMatch") +{ + runConformance("pm.lua"); +} + +TEST_CASE("Sort") +{ + runConformance("sort.lua"); +} + +TEST_CASE("Move") +{ + runConformance("move.lua"); +} + +TEST_CASE("Clear") +{ + runConformance("clear.lua"); +} + +TEST_CASE("Strings") +{ + runConformance("strings.lua"); +} + +TEST_CASE("VarArg") +{ + runConformance("vararg.lua"); +} + +TEST_CASE("Locals") +{ + runConformance("locals.lua"); +} + +TEST_CASE("Literals") +{ + runConformance("literals.lua"); +} + +TEST_CASE("Errors") +{ + runConformance("errors.lua"); +} + +TEST_CASE("Events") +{ + runConformance("events.lua"); +} + +TEST_CASE("Constructs") +{ + runConformance("constructs.lua"); +} + +TEST_CASE("Closure") +{ + runConformance("closure.lua"); +} + +TEST_CASE("Calls") +{ + runConformance("calls.lua"); +} + +TEST_CASE("Attrib") +{ + runConformance("attrib.lua"); +} + +TEST_CASE("GC") +{ + runConformance("gc.lua"); +} + +TEST_CASE("Bitwise") +{ + runConformance("bitwise.lua"); +} + +TEST_CASE("UTF8") +{ + runConformance("utf8.lua"); +} + +TEST_CASE("Coroutine") +{ + runConformance("coroutine.lua"); +} + +TEST_CASE("PCall") +{ + runConformance("pcall.lua", [](lua_State* L) { + lua_pushcfunction(L, [](lua_State* L) -> int { +#if LUA_USE_LONGJMP + luaL_error(L, "oops"); +#else + throw std::runtime_error("oops"); +#endif + }); + lua_setglobal(L, "cxxthrow"); + + lua_pushcfunction(L, [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }); + lua_setglobal(L, "resumeerror"); + }); +} + +TEST_CASE("Pack") +{ + runConformance("tpack.lua"); +} + +TEST_CASE("Vector") +{ + runConformance("vector.lua", [](lua_State* L) { + lua_pushcfunction(L, lua_vector); + lua_setglobal(L, "vector"); + + lua_pushvector(L, 0.0f, 0.0f, 0.0f); + luaL_newmetatable(L, "vector"); + + lua_pushstring(L, "__index"); + lua_pushcfunction(L, lua_vector_index); + lua_settable(L, -3); + + lua_pushstring(L, "__namecall"); + lua_pushcfunction(L, lua_vector_namecall); + lua_settable(L, -3); + + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); + }); +} + +static void populateRTTI(lua_State* L, Luau::TypeId type) +{ + if (auto p = Luau::get(type)) + { + switch (p->type) + { + case Luau::PrimitiveTypeVar::Boolean: + lua_pushstring(L, "boolean"); + break; + + case Luau::PrimitiveTypeVar::NilType: + lua_pushstring(L, "nil"); + break; + + case Luau::PrimitiveTypeVar::Number: + lua_pushstring(L, "number"); + break; + + case Luau::PrimitiveTypeVar::String: + lua_pushstring(L, "string"); + break; + + case Luau::PrimitiveTypeVar::Thread: + lua_pushstring(L, "thread"); + break; + + default: + LUAU_ASSERT(!"Unknown primitive type"); + } + } + else if (auto t = Luau::get(type)) + { + lua_newtable(L); + + for (const auto& [name, prop] : t->props) + { + populateRTTI(L, prop.type); + lua_setfield(L, -2, name.c_str()); + } + } + else if (Luau::get(type)) + { + lua_pushstring(L, "function"); + } + else if (Luau::get(type)) + { + lua_pushstring(L, "any"); + } + else if (auto i = Luau::get(type)) + { + for (const auto& part : i->parts) + LUAU_ASSERT(Luau::get(part)); + + lua_pushstring(L, "function"); + } + else + { + LUAU_ASSERT(!"Unknown type"); + } +} + +TEST_CASE("Types") +{ + runConformance("types.lua", [](lua_State* L) { + Luau::NullModuleResolver moduleResolver; + Luau::InternalErrorReporter iceHandler; + Luau::TypeChecker env(&moduleResolver, &iceHandler); + + Luau::registerBuiltinTypes(env); + Luau::freeze(env.globalTypes); + + lua_newtable(L); + + for (const auto& [name, binding] : env.globalScope->bindings) + { + populateRTTI(L, binding.typeId); + lua_setfield(L, -2, toString(name).c_str()); + } + + lua_setglobal(L, "RTTI"); + }); +} + +TEST_CASE("DateTime") +{ + runConformance("datetime.lua"); +} + +TEST_CASE("Debug") +{ + runConformance("debug.lua"); +} + +TEST_CASE("Debugger") +{ + static int breakhits = 0; + static lua_State* interruptedthread = nullptr; + + breakhits = 0; + interruptedthread = nullptr; + + runConformance( + "debugger.lua", + [](lua_State* L) { + lua_Callbacks* cb = lua_callbacks(L); + + // for breakpoints to work we should make sure debugbreak is installed + cb->debugbreak = [](lua_State* L, lua_Debug* ar) { + breakhits++; + + // for every breakpoint, we break on the first invocation and continue on second + // this allows us to easily step off breakpoints + // (real implementaiton may require singlestepping) + if (breakhits % 2 == 1) + lua_break(L); + }; + + // for resuming off a breakpoint inside a coroutine we need to resume the interrupted coroutine + cb->debuginterrupt = [](lua_State* L, lua_Debug* ar) { + CHECK(interruptedthread == nullptr); + CHECK(ar->userdata); // userdata contains the interrupted thread + + interruptedthread = static_cast(ar->userdata); + }; + + // add breakpoint() function + lua_pushcfunction(L, [](lua_State* L) -> int { + int line = luaL_checkinteger(L, 1); + + lua_Debug ar = {}; + lua_getinfo(L, 1, "f", &ar); + + lua_breakpoint(L, -1, line, true); + return 0; + }); + lua_setglobal(L, "breakpoint"); + }, + [](lua_State* L) { + CHECK(breakhits % 2 == 1); + + lua_checkstack(L, LUA_MINSTACK); + + if (breakhits == 1) + { + // test lua_getargument + int a = lua_getargument(L, 0, 1); + REQUIRE(a); + CHECK(lua_tointeger(L, -1) == 50); + lua_pop(L, 1); + + // test lua_getlocal + const char* l = lua_getlocal(L, 0, 1); + REQUIRE(l); + CHECK(strcmp(l, "b") == 0); + CHECK(lua_tointeger(L, -1) == 50); + lua_pop(L, 1); + + // test lua_getupvalue + lua_Debug ar = {}; + lua_getinfo(L, 0, "f", &ar); + + const char* u = lua_getupvalue(L, -1, 1); + REQUIRE(u); + CHECK(strcmp(u, "a") == 0); + CHECK(lua_tointeger(L, -1) == 5); + lua_pop(L, 2); + } + else if (breakhits == 3) + { + // validate assignment via lua_getlocal + const char* l = lua_getlocal(L, 0, 1); + REQUIRE(l); + CHECK(strcmp(l, "a") == 0); + CHECK(lua_tointeger(L, -1) == 6); + lua_pop(L, 1); + } + else if (breakhits == 5) + { + // validate assignment via lua_getlocal + const char* l = lua_getlocal(L, 1, 1); + REQUIRE(l); + CHECK(strcmp(l, "a") == 0); + CHECK(lua_tointeger(L, -1) == 7); + lua_pop(L, 1); + } + else if (breakhits == 7) + { + // validate assignment via lua_getlocal + const char* l = lua_getlocal(L, 1, 1); + REQUIRE(l); + CHECK(strcmp(l, "a") == 0); + CHECK(lua_tointeger(L, -1) == 8); + lua_pop(L, 1); + } + else if (breakhits == 9) + { + // validate assignment via lua_getlocal + const char* l = lua_getlocal(L, 1, 1); + REQUIRE(l); + CHECK(strcmp(l, "a") == 0); + CHECK(lua_tointeger(L, -1) == 9); + lua_pop(L, 1); + } + + if (interruptedthread) + { + lua_resume(interruptedthread, nullptr, 0); + interruptedthread = nullptr; + } + }); + + CHECK(breakhits == 10); // 2 hits per breakpoint +} + +TEST_CASE("SameHash") +{ + extern unsigned int luaS_hash(const char* str, size_t len); // internal function, declared in lstring.h - not exposed via lua.h + + // To keep VM and compiler separate, we duplicate the hash function definition + // This test validates that the hash function in question returns the same results on basic inputs + // If this is violated, some code may regress in performance due to hash slot misprediction in inline caches + CHECK(luaS_hash("", 0) == Luau::BytecodeBuilder::getStringHash({"", 0})); + CHECK(luaS_hash("lua", 3) == Luau::BytecodeBuilder::getStringHash({"lua", 3})); + CHECK(luaS_hash("luau", 4) == Luau::BytecodeBuilder::getStringHash({"luau", 4})); + CHECK(luaS_hash("luaubytecode", 12) == Luau::BytecodeBuilder::getStringHash({"luaubytecode", 12})); + CHECK(luaS_hash("luaubytecodehash", 16) == Luau::BytecodeBuilder::getStringHash({"luaubytecodehash", 16})); + + // Also hash should work on unaligned source data even when hashing long strings + char buf[128] = {}; + CHECK(luaS_hash(buf + 1, 120) == luaS_hash(buf + 2, 120)); +} + +TEST_CASE("InlineDtor") +{ + static int dtorhits = 0; + + dtorhits = 0; + + { + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + void* u1 = lua_newuserdatadtor(L, 4, [](void* data) { + dtorhits += *(int*)data; + }); + + void* u2 = lua_newuserdatadtor(L, 1, [](void* data) { + dtorhits += *(char*)data; + }); + + *(int*)u1 = 39; + *(char*)u2 = 3; + } + + CHECK(dtorhits == 42); +} + +TEST_CASE("Reference") +{ + static int dtorhits = 0; + + dtorhits = 0; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // note, we push two userdata objects but only pin one of them (the first one) + lua_newuserdatadtor(L, 0, [](void*) { + dtorhits++; + }); + lua_newuserdatadtor(L, 0, [](void*) { + dtorhits++; + }); + + lua_gc(L, LUA_GCCOLLECT, 0); + CHECK(dtorhits == 0); + + int ref = lua_ref(L, -2); + lua_pop(L, 2); + + lua_gc(L, LUA_GCCOLLECT, 0); + CHECK(dtorhits == 1); + + lua_getref(L, ref); + CHECK(lua_isuserdata(L, -1)); + lua_pop(L, 1); + + lua_gc(L, LUA_GCCOLLECT, 0); + CHECK(dtorhits == 1); + + lua_unref(L, ref); + + lua_gc(L, LUA_GCCOLLECT, 0); + CHECK(dtorhits == 2); +} + +TEST_CASE("ApiFunctionCalls") +{ + StateRef globalState = runConformance("apicalls.lua"); + lua_State* L = globalState.get(); + + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_call(L, 2, 1); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_pcall(L, 2, 1, 0); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); +} + +static bool endsWith(const std::string& str, const std::string& suffix) +{ + if (suffix.length() > str.length()) + return false; + + return suffix == std::string_view(str.c_str() + str.length() - suffix.length(), suffix.length()); +} + +#if !LUA_USE_LONGJMP +TEST_CASE("ExceptionObject") +{ + ScopedFastFlag sff("LuauExceptionMessageFix", true); + + struct ExceptionResult + { + bool exceptionGenerated; + std::string description; + }; + + auto captureException = [](lua_State* L, const char* functionToRun) { + try + { + lua_State* threadState = lua_newthread(L); + lua_getfield(threadState, LUA_GLOBALSINDEX, functionToRun); + CHECK(lua_isLfunction(threadState, -1)); + lua_call(threadState, 0, 0); + } + catch (std::exception& e) + { + CHECK(e.what() != nullptr); + return ExceptionResult{true, e.what()}; + } + return ExceptionResult{false, ""}; + }; + + auto reallocFunc = [](lua_State* L, void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { + if (nsize == 0) + { + free(ptr); + return NULL; + } + else if (nsize > 512 * 1024) + { + // For testing purposes return null for large allocations + // so we can generate exceptions related to memory allocation + // failures. + return nullptr; + } + else + { + return realloc(ptr, nsize); + } + }; + + StateRef globalState = runConformance("exceptions.lua", nullptr, nullptr, lua_newstate(reallocFunc, nullptr)); + lua_State* L = globalState.get(); + + { + ExceptionResult result = captureException(L, "infinite_recursion_error"); + CHECK(result.exceptionGenerated); + } + + { + ExceptionResult result = captureException(L, "empty_function"); + CHECK_FALSE(result.exceptionGenerated); + } + + { + ExceptionResult result = captureException(L, "pass_number_to_error"); + CHECK(result.exceptionGenerated); + CHECK(endsWith(result.description, "42")); + } + + { + ExceptionResult result = captureException(L, "pass_string_to_error"); + CHECK(result.exceptionGenerated); + CHECK(endsWith(result.description, "string argument")); + } + + { + ExceptionResult result = captureException(L, "pass_table_to_error"); + CHECK(result.exceptionGenerated); + } + + { + ExceptionResult result = captureException(L, "large_allocation_error"); + CHECK(result.exceptionGenerated); + } +} +#endif + +TEST_CASE("IfElseExpression") +{ + ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + + runConformance("ifelseexpr.lua"); +} + +TEST_SUITE_END(); diff --git a/tests/Error.test.cpp b/tests/Error.test.cpp new file mode 100644 index 0000000..5ba5c11 --- /dev/null +++ b/tests/Error.test.cpp @@ -0,0 +1,16 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Error.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ErrorTests"); + +TEST_CASE("TypeError_code_should_return_nonzero_code") +{ + auto e = TypeError{{{0, 0}, {0, 1}}, UnknownSymbol{"Foo"}}; + CHECK_GE(e.code(), 1000); +} + +TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp new file mode 100644 index 0000000..26bc77f --- /dev/null +++ b/tests/Fixture.cpp @@ -0,0 +1,431 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Fixture.h" + +#include "Luau/AstQuery.h" +#include "Luau/TypeVar.h" +#include "Luau/TypeAttach.h" +#include "Luau/Transpiler.h" + +#include "Luau/BuiltinDefinitions.h" + +#include "doctest.h" + +#include +#include +#include + +static const char* mainModuleName = "MainModule"; + +namespace Luau +{ + +std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const +{ + auto g = expr->as(); + if (!g) + return std::nullopt; + + std::string_view value = g->name.value; + if (value == "game" || value == "Game" || value == "workspace" || value == "Workspace" || value == "script" || value == "Script") + return ModuleName(value); + + return std::nullopt; +} + +ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const +{ + return lhs + "/" + ModuleName(rhs); +} + +std::optional TestFileResolver::getParentModuleName(const ModuleName& name) const +{ + std::string_view view = name; + const size_t lastSeparatorIndex = view.find_last_of('/'); + + if (lastSeparatorIndex != std::string_view::npos) + { + return ModuleName(view.substr(0, lastSeparatorIndex)); + } + + return std::nullopt; +} + +std::string TestFileResolver::getHumanReadableModuleName(const ModuleName& name) const +{ + return name; +} + +std::optional TestFileResolver::getEnvironmentForModule(const ModuleName& name) const +{ + auto it = environments.find(name); + if (it != environments.end()) + return it->second; + + return std::nullopt; +} + +Fixture::Fixture(bool freeze) + : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) + , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true}) + , typeChecker(frontend.typeChecker) +{ + configResolver.defaultConfig.mode = Mode::Strict; + configResolver.defaultConfig.enabledLint.warningMask = ~0ull; + configResolver.defaultConfig.parseOptions.captureComments = true; + + registerBuiltinTypes(frontend.typeChecker); + registerTestTypes(); + Luau::freeze(frontend.typeChecker.globalTypes); + + Luau::setPrintLine([](auto s) {}); +} + +Fixture::~Fixture() +{ + Luau::resetPrintLine(); +} + +UnfrozenFixture::UnfrozenFixture() + : Fixture(false) +{ +} + +AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& parseOptions) +{ + sourceModule.reset(new SourceModule); + + ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, parseOptions); + + sourceModule->name = fromString(mainModuleName); + sourceModule->root = result.root; + sourceModule->mode = parseMode(result.hotcomments); + sourceModule->ignoreLints = LintWarning::parseMask(result.hotcomments); + + if (!result.errors.empty()) + { + // if AST is available, check how lint and typecheck handle error nodes + if (result.root) + { + frontend.lint(*sourceModule); + + typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + } + + throw ParseErrors(result.errors); + } + + return result.root; +} + +CheckResult Fixture::check(Mode mode, std::string source) +{ + configResolver.defaultConfig.mode = mode; + fileResolver.source[mainModuleName] = std::move(source); + + CheckResult result = frontend.check(fromString(mainModuleName)); + + configResolver.defaultConfig.mode = Mode::Strict; + + return result; +} + +CheckResult Fixture::check(const std::string& source) +{ + ModuleName mm = fromString(mainModuleName); + configResolver.defaultConfig.mode = Mode::Strict; + fileResolver.source[mm] = std::move(source); + frontend.markDirty(mm); + + CheckResult result = frontend.check(mm); + + return result; +} + +LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) +{ + ParseOptions parseOptions; + configResolver.defaultConfig.mode = Mode::Nonstrict; + parse(source, parseOptions); + + return frontend.lint(*sourceModule, lintOptions); +} + +LintResult Fixture::lintTyped(const std::string& source, const std::optional& lintOptions) +{ + check(source); + ModuleName mm = fromString(mainModuleName); + + return frontend.lint(mm, lintOptions); +} + +ParseResult Fixture::parseEx(const std::string& source, const ParseOptions& options) +{ + ParseResult result = tryParse(source, options); + if (!result.errors.empty()) + throw ParseErrors(result.errors); + + return result; +} + +ParseResult Fixture::tryParse(const std::string& source, const ParseOptions& parseOptions) +{ + ParseOptions options = parseOptions; + options.allowDeclarationSyntax = true; + + sourceModule.reset(new SourceModule); + ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); + sourceModule->root = result.root; + return result; +} + +ParseResult Fixture::matchParseError(const std::string& source, const std::string& message) +{ + ParseOptions options; + options.allowDeclarationSyntax = true; + + sourceModule.reset(new SourceModule); + ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); + + REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + + CHECK_EQ(result.errors.front().getMessage(), message); + + return result; +} + +ParseResult Fixture::matchParseErrorPrefix(const std::string& source, const std::string& prefix) +{ + ParseOptions options; + options.allowDeclarationSyntax = true; + + sourceModule.reset(new SourceModule); + ParseResult result = Parser::parse(source.c_str(), source.length(), *sourceModule->names, *sourceModule->allocator, options); + + REQUIRE_MESSAGE(!result.errors.empty(), "Expected a parse error in '" << source << "'"); + + const std::string& message = result.errors.front().getMessage(); + CHECK_GE(message.length(), prefix.length()); + CHECK_EQ(prefix, message.substr(0, prefix.size())); + + return result; +} + +ModulePtr Fixture::getMainModule() +{ + return frontend.moduleResolver.getModule(fromString(mainModuleName)); +} + +SourceModule* Fixture::getMainSourceModule() +{ + return frontend.getSourceModule(fromString("MainModule")); +} + +std::optional Fixture::getPrimitiveType(TypeId ty) +{ + REQUIRE(ty != nullptr); + + TypeId aType = follow(ty); + REQUIRE(aType != nullptr); + + const PrimitiveTypeVar* pt = get(aType); + if (pt != nullptr) + return pt->type; + else + return std::nullopt; +} + +std::optional Fixture::getType(const std::string& name) +{ + ModulePtr module = getMainModule(); + REQUIRE(module); + + return lookupName(module->getModuleScope(), name); +} + +TypeId Fixture::requireType(const std::string& name) +{ + std::optional ty = getType(name); + REQUIRE(bool(ty)); + return follow(*ty); +} + +TypeId Fixture::requireType(const ModuleName& moduleName, const std::string& name) +{ + ModulePtr module = frontend.moduleResolver.getModule(moduleName); + REQUIRE(module); + return requireType(module->getModuleScope(), name); +} + +TypeId Fixture::requireType(const ModulePtr& module, const std::string& name) +{ + return requireType(module->getModuleScope(), name); +} + +TypeId Fixture::requireType(const ScopePtr& scope, const std::string& name) +{ + std::optional ty = lookupName(scope, name); + REQUIRE_MESSAGE(ty, "requireType: No type \"" << name << "\""); + return *ty; +} + +std::optional Fixture::findTypeAtPosition(Position position) +{ + ModulePtr module = getMainModule(); + SourceModule* sourceModule = getMainSourceModule(); + return Luau::findTypeAtPosition(*module, *sourceModule, position); +} + +TypeId Fixture::requireTypeAtPosition(Position position) +{ + auto ty = findTypeAtPosition(position); + REQUIRE_MESSAGE(ty, "requireTypeAtPosition: No type at position " << position); + return *ty; +} + +std::optional Fixture::lookupType(const std::string& name) +{ + if (auto typeFun = getMainModule()->getModuleScope()->lookupType(name)) + return typeFun->type; + + return std::nullopt; +} + +std::optional Fixture::lookupImportedType(const std::string& moduleAlias, const std::string& name) +{ + if (auto typeFun = getMainModule()->getModuleScope()->lookupImportedType(moduleAlias, name)) + return typeFun->type; + + return std::nullopt; +} + +std::string Fixture::decorateWithTypes(const std::string& code) +{ + fileResolver.source[mainModuleName] = code; + + Luau::CheckResult typeInfo = frontend.check(mainModuleName); + + SourceModule* sourceModule = frontend.getSourceModule(mainModuleName); + attachTypeData(*sourceModule, *frontend.moduleResolver.getModule(mainModuleName)); + + return transpileWithTypes(*sourceModule->root); +} + +void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) +{ + for (const auto& error : errors) + { + os << std::endl; + os << "Error: " << error << std::endl; + + std::string_view source = fileResolver.source[error.moduleName]; + std::vector lines = Luau::split(source, '\n'); + + if (error.location.begin.line >= lines.size()) + { + os << "\tSource not available?" << std::endl; + return; + } + + std::string_view theLine = lines[error.location.begin.line]; + os << "Line:\t" << theLine << std::endl; + int startCol = error.location.begin.column; + int endCol = error.location.end.line == error.location.begin.line ? error.location.end.column : int(theLine.size()); + + os << '\t' << std::string(startCol, ' ') << std::string(std::max(1, endCol - startCol), '-') << std::endl; + } +} + +void Fixture::registerTestTypes() +{ + addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@luau"); + addGlobalBinding(typeChecker, "workspace", typeChecker.anyType, "@luau"); + addGlobalBinding(typeChecker, "script", typeChecker.anyType, "@luau"); +} + +void Fixture::dumpErrors(const CheckResult& cr) +{ + dumpErrors(std::cout, cr.errors); +} + +void Fixture::dumpErrors(const ModulePtr& module) +{ + dumpErrors(std::cout, module->errors); +} + +void Fixture::dumpErrors(const Module& module) +{ + dumpErrors(std::cout, module.errors); +} + +std::string Fixture::getErrors(const CheckResult& cr) +{ + std::stringstream ss; + dumpErrors(ss, cr.errors); + return ss.str(); +} + +void Fixture::validateErrors(const std::vector& errors) +{ + std::ostringstream oss; + + // This helps us validate that error stringification doesn't crash, using both user-facing and internal test-only representation + // Also we exercise error comparison to make sure it's at least able to compare the error equal to itself + for (const Luau::TypeError& e : errors) + { + oss.clear(); + oss << e; + toString(e); + // CHECK(e == e); TODO: this doesn't work due to union/intersection type vars + } +} + +LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) +{ + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); + return result; +} + +ModuleName fromString(std::string_view name) +{ + return ModuleName(name); +} + +std::string rep(const std::string& s, size_t n) +{ + std::string r; + r.reserve(s.length() * n); + for (size_t i = 0; i < n; ++i) + r += s; + return r; +} + +bool isInArena(TypeId t, const TypeArena& arena) +{ + return arena.typeVars.contains(t); +} + +void dumpErrors(const ModulePtr& module) +{ + for (const auto& error : module->errors) + std::cout << "Error: " << error << std::endl; +} + +void dump(const std::string& name, TypeId ty) +{ + std::cout << name << '\t' << toString(ty, {true}) << std::endl; +} + +std::optional lookupName(ScopePtr scope, const std::string& name) +{ + auto binding = scope->linearSearchForBinding(name); + if (binding) + return binding->typeId; + else + return std::nullopt; +} + +} // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h new file mode 100644 index 0000000..c6294b0 --- /dev/null +++ b/tests/Fixture.h @@ -0,0 +1,205 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Config.h" +#include "Luau/FileResolver.h" +#include "Luau/Frontend.h" +#include "Luau/IostreamHelpers.h" +#include "Luau/Linter.h" +#include "Luau/Location.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" +#include "Luau/ToString.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "IostreamOptional.h" +#include "ScopedFlags.h" + +#include +#include +#include + +#include + +namespace Luau +{ + +struct TestFileResolver + : FileResolver + , ModuleResolver +{ + std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override + { + if (auto name = pathExprToModuleName(currentModuleName, pathExpr)) + return {{*name, false}}; + + return std::nullopt; + } + + const ModulePtr getModule(const ModuleName& moduleName) const override + { + LUAU_ASSERT(false); + return nullptr; + } + + bool moduleExists(const ModuleName& moduleName) const override + { + auto it = source.find(moduleName); + return (it != source.end()); + } + + std::optional readSource(const ModuleName& name) override + { + auto it = source.find(name); + if (it == source.end()) + return std::nullopt; + + SourceCode::Type sourceType = SourceCode::Module; + + auto it2 = sourceTypes.find(name); + if (it2 != sourceTypes.end()) + sourceType = it2->second; + + return SourceCode{it->second, sourceType}; + } + + std::optional fromAstFragment(AstExpr* expr) const override; + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; + std::optional getParentModuleName(const ModuleName& name) const override; + + std::string getHumanReadableModuleName(const ModuleName& name) const override; + + std::optional getEnvironmentForModule(const ModuleName& name) const override; + + std::unordered_map source; + std::unordered_map sourceTypes; + std::unordered_map environments; +}; + +struct TestConfigResolver : ConfigResolver +{ + Config defaultConfig; + std::unordered_map configFiles; + + const Config& getConfig(const ModuleName& name) const override + { + auto it = configFiles.find(name); + if (it != configFiles.end()) + return it->second; + + return defaultConfig; + } +}; + +struct Fixture +{ + explicit Fixture(bool freeze = true); + ~Fixture(); + + // Throws Luau::ParseErrors if the parse fails. + AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {}); + CheckResult check(Mode mode, std::string source); + CheckResult check(const std::string& source); + + LintResult lint(const std::string& source, const std::optional& lintOptions = {}); + LintResult lintTyped(const std::string& source, const std::optional& lintOptions = {}); + + /// Parse with all language extensions enabled + ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); + ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); + ParseResult matchParseError(const std::string& source, const std::string& message); + // Verify a parse error occurs and the parse error message has the specified prefix + ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); + + ModulePtr getMainModule(); + SourceModule* getMainSourceModule(); + + std::optional getPrimitiveType(TypeId ty); + std::optional getType(const std::string& name); + TypeId requireType(const std::string& name); + TypeId requireType(const ModuleName& moduleName, const std::string& name); + TypeId requireType(const ModulePtr& module, const std::string& name); + TypeId requireType(const ScopePtr& scope, const std::string& name); + + std::optional findTypeAtPosition(Position position); + TypeId requireTypeAtPosition(Position position); + + std::optional lookupType(const std::string& name); + std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); + + ScopedFastFlag sff_DebugLuauFreezeArena; + + TestFileResolver fileResolver; + TestConfigResolver configResolver; + std::unique_ptr sourceModule; + Frontend frontend; + TypeChecker& typeChecker; + + std::string decorateWithTypes(const std::string& code); + + void dumpErrors(std::ostream& os, const std::vector& errors); + + void dumpErrors(const CheckResult& cr); + void dumpErrors(const ModulePtr& module); + void dumpErrors(const Module& module); + + void validateErrors(const std::vector& errors); + + std::string getErrors(const CheckResult& cr); + + void registerTestTypes(); + + LoadDefinitionFileResult loadDefinition(const std::string& source); +}; + +// Disables arena freezing for a given test case. +// Do not use this in new tests. If you are running into access violations, you +// are violating Luau's memory model - the fix is not to use UnfrozenFixture. +// Related: CLI-45692 +struct UnfrozenFixture : Fixture +{ + UnfrozenFixture(); +}; + +ModuleName fromString(std::string_view name); + +template +std::optional get(const std::map& map, const Name& name) +{ + auto it = map.find(name); + if (it != map.end()) + return std::optional(it->second); + else + return std::nullopt; +} + +std::string rep(const std::string& s, size_t n); + +bool isInArena(TypeId t, const TypeArena& arena); + +void dumpErrors(const ModulePtr& module); +void dumpErrors(const Module& module); +void dump(const std::string& name, TypeId ty); + +std::optional lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2) + +} // namespace Luau + +#define LUAU_REQUIRE_ERRORS(result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + REQUIRE(!r.errors.empty()); \ + } while (false) + +#define LUAU_REQUIRE_ERROR_COUNT(count, result) \ + do \ + { \ + auto&& r = (result); \ + validateErrors(r.errors); \ + REQUIRE_MESSAGE(count == r.errors.size(), getErrors(r)); \ + } while (false) + +#define LUAU_REQUIRE_NO_ERRORS(result) LUAU_REQUIRE_ERROR_COUNT(0, result) diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp new file mode 100644 index 0000000..3f33a5d --- /dev/null +++ b/tests/Frontend.test.cpp @@ -0,0 +1,965 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Frontend.h" +#include "Luau/Parser.h" +#include "Luau/RequireTracer.h" + +#include "Fixture.h" + +#include "doctest.h" + +#include + +using namespace Luau; + +namespace +{ + +struct NaiveModuleResolver : ModuleResolver +{ + std::optional resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override + { + if (auto name = pathExprToModuleName(currentModuleName, pathExpr)) + return {{*name, false}}; + + return std::nullopt; + } + + const ModulePtr getModule(const ModuleName& moduleName) const override + { + return nullptr; + } + + bool moduleExists(const ModuleName& moduleName) const override + { + return false; + } + + std::string getHumanReadableModuleName(const ModuleName& moduleName) const override + { + return moduleName; + } +}; + +NaiveModuleResolver naiveModuleResolver; + +struct NaiveFileResolver : NullFileResolver +{ + std::optional fromAstFragment(AstExpr* expr) const override + { + AstExprGlobal* g = expr->as(); + if (g && g->name == "Modules") + return "Modules"; + + if (g && g->name == "game") + return "game"; + + return std::nullopt; + } + + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override + { + return lhs + "/" + ModuleName(rhs); + } +}; + +} // namespace + +struct FrontendFixture : Fixture +{ + FrontendFixture() + { + addGlobalBinding(typeChecker, "game", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(typeChecker, "script", frontend.typeChecker.anyType, "@test"); + } +}; + +TEST_SUITE_BEGIN("FrontendTest"); + +TEST_CASE_FIXTURE(FrontendFixture, "find_a_require") +{ + AstStatBlock* program = parse(R"( + local M = require(Modules.Foo.Bar) + )"); + + NaiveFileResolver naiveFileResolver; + + auto res = traceRequires(&naiveFileResolver, program, ""); + CHECK_EQ(1, res.requires.size()); + CHECK_EQ(res.requires[0].first, "Modules/Foo/Bar"); +} + +// It could be argued that this should not work. +TEST_CASE_FIXTURE(FrontendFixture, "find_a_require_inside_a_function") +{ + AstStatBlock* program = parse(R"( + function foo() + local M = require(Modules.Foo.Bar) + end + )"); + + NaiveFileResolver naiveFileResolver; + + auto res = traceRequires(&naiveFileResolver, program, ""); + CHECK_EQ(1, res.requires.size()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "real_source") +{ + AstStatBlock* program = parse(R"( + return function() + local Modules = game:GetService("CoreGui").Gui.Modules + + local Roact = require(Modules.Common.Roact) + local Rodux = require(Modules.Common.Rodux) + + local AppReducer = require(Modules.LuaApp.AppReducer) + local AEAppReducer = require(Modules.LuaApp.Reducers.AEReducers.AEAppReducer) + local AETabList = require(Modules.LuaApp.Components.Avatar.UI.Views.Portrait.AETabList) + local mockServices = require(Modules.LuaApp.TestHelpers.mockServices) + local DeviceOrientationMode = require(Modules.LuaApp.DeviceOrientationMode) + local MockAvatarEditorTheme = require(Modules.LuaApp.TestHelpers.MockAvatarEditorTheming) + local FFlagAvatarEditorEnableThemes = settings():GetFFlag("AvatarEditorEnableThemes2") + end + )"); + + NaiveFileResolver naiveFileResolver; + + auto res = traceRequires(&naiveFileResolver, program, ""); + CHECK_EQ(8, res.requires.size()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_dependent_scripts") +{ + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {b_value = A.hello} + )"; + + frontend.check("game/Gui/Modules/B"); + + ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + REQUIRE(bModule != nullptr); + CHECK(bModule->errors.empty()); + Luau::dumpErrors(bModule); + + auto bExports = first(bModule->getModuleScope()->returnType); + REQUIRE(!!bExports); + + CHECK_EQ("{| b_value: number |}", toString(*bExports)); +} + +TEST_CASE_FIXTURE(FrontendFixture, "automatically_check_cyclically_dependent_scripts") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + require(Modules.C) + return {} + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + do local A = require(Modules.A) end + return {} + )"; + + fileResolver.source["game/Gui/Modules/D"] = R"( + local Modules = game:GetService('Gui').Modules + do local A = require(Modules.A) end + return {} + )"; + + CheckResult result1 = frontend.check("game/Gui/Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(4, result1); + + CHECK_MESSAGE(get(result1.errors[0]), "Should have been a ModuleHasCyclicDependency: " << toString(result1.errors[0])); + + CHECK_MESSAGE(get(result1.errors[1]), "Should have been a ModuleHasCyclicDependency: " << toString(result1.errors[1])); + + CHECK_MESSAGE(get(result1.errors[2]), "Should have been a ModuleHasCyclicDependency: " << toString(result1.errors[2])); + + CheckResult result2 = frontend.check("game/Gui/Modules/D"); + LUAU_REQUIRE_ERROR_COUNT(0, result2); +} + +TEST_CASE_FIXTURE(FrontendFixture, "any_annotation_breaks_cycle") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {hello = B.hello} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) :: any + return {hello = A.hello} + )"; + + CheckResult result = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(FrontendFixture, "nocheck_modules_are_typed") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + --!nocheck + export type Foo = number + return {hello = "hi"} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + --!nonstrict + export type Foo = number + return {hello = "hi"} + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + local B = require(Modules.B) + local five : A.Foo = 5 + )"; + + CheckResult result = frontend.check("game/Gui/Modules/C"); + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr aModule = frontend.moduleResolver.modules["game/Gui/Modules/A"]; + REQUIRE(bool(aModule)); + + std::optional aExports = first(aModule->getModuleScope()->returnType); + REQUIRE(bool(aExports)); + + ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + REQUIRE(bool(bModule)); + + std::optional bExports = first(bModule->getModuleScope()->returnType); + REQUIRE(bool(bExports)); + + CHECK_EQ(toString(*aExports), toString(*bExports)); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_between_check_and_nocheck") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {hello = B.hello} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + --!nocheck + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {hello = A.hello} + )"; + + CheckResult result = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(FrontendFixture, "nocheck_cycle_used_by_checked") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + --!nocheck + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {hello = B.hello} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + --!nocheck + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {hello = A.hello} + )"; + fileResolver.source["game/Gui/Modules/C"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + local B = require(Modules.B) + return {a=A, b=B} + )"; + + CheckResult result = frontend.check("game/Gui/Modules/C"); + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr cModule = frontend.moduleResolver.modules["game/Gui/Modules/C"]; + REQUIRE(bool(cModule)); + + std::optional cExports = first(cModule->getModuleScope()->returnType); + REQUIRE(bool(cExports)); + CHECK_EQ("{| a: any, b: any |}", toString(*cExports)); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_detection_disabled_in_nocheck") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + --!nocheck + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {hello = B.hello} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + --!nocheck + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {hello = A.hello} + )"; + + CheckResult result = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_errors_can_be_fixed") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {hello = B.hello} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {hello = A.hello} + )"; + + CheckResult result1 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(2, result1); + + CHECK_MESSAGE(get(result1.errors[0]), "Should have been a ModuleHasCyclicDependency: " << toString(result1.errors[0])); + + CHECK_MESSAGE(get(result1.errors[1]), "Should have been a ModuleHasCyclicDependency: " << toString(result1.errors[1])); + + fileResolver.source["game/Gui/Modules/B"] = R"( + return {hello = 42} + )"; + frontend.markDirty("game/Gui/Modules/B"); + + CheckResult result2 = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result2); +} + +TEST_CASE_FIXTURE(FrontendFixture, "cycle_error_paths") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local Modules = game:GetService('Gui').Modules + local B = require(Modules.B) + return {hello = B.hello} + )"; + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {hello = A.hello} + )"; + + CheckResult result = frontend.check("game/Gui/Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + + auto ce1 = get(result.errors[0]); + REQUIRE(ce1); + CHECK_EQ(result.errors[0].moduleName, "game/Gui/Modules/B"); + REQUIRE_EQ(ce1->cycle.size(), 2); + CHECK_EQ(ce1->cycle[0], "game/Gui/Modules/A"); + CHECK_EQ(ce1->cycle[1], "game/Gui/Modules/B"); + + auto ce2 = get(result.errors[1]); + REQUIRE(ce2); + CHECK_EQ(result.errors[1].moduleName, "game/Gui/Modules/A"); + REQUIRE_EQ(ce2->cycle.size(), 2); + CHECK_EQ(ce2->cycle[0], "game/Gui/Modules/B"); + CHECK_EQ(ce2->cycle[1], "game/Gui/Modules/A"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "dont_reparse_clean_file_when_linting") +{ + fileResolver.source["Modules/A"] = R"( + local t = {} + + for i=#t,1 do + end + + for i=#t,1,-1 do + end + )"; + + frontend.check("Modules/A"); + + fileResolver.source["Modules/A"] = R"( + -- We have fixed the lint error, but we did not tell the Frontend that the file is changed! + -- Therefore, we expect Frontend to reuse the parse tree. + )"; + + configResolver.defaultConfig.enabledLint.enableWarning(LintWarning::Code_ForRange); + + LintResult lintResult = frontend.lint("Modules/A"); + + CHECK_EQ(1, lintResult.warnings.size()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "dont_recheck_script_that_hasnt_been_marked_dirty") +{ + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {b_value = A.hello} + )"; + + frontend.check("game/Gui/Modules/B"); + + fileResolver.source["game/Gui/Modules/A"] = + "Massively incorrect syntax haha oops! However! The frontend doesn't know that this file needs reparsing!"; + + frontend.check("game/Gui/Modules/B"); + + ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + CHECK(bModule->errors.empty()); + Luau::dumpErrors(bModule); +} + +TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_is_dirty") +{ + fileResolver.source["game/Gui/Modules/A"] = "return {hello=5, world=true}"; + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = game:GetService('Gui').Modules + local A = require(Modules.A) + return {b_value = A.hello} + )"; + + frontend.check("game/Gui/Modules/B"); + + fileResolver.source["game/Gui/Modules/A"] = "return {hello='hi!'}"; + frontend.markDirty("game/Gui/Modules/A"); + + frontend.check("game/Gui/Modules/B"); + + ModulePtr bModule = frontend.moduleResolver.modules["game/Gui/Modules/B"]; + CHECK(bModule->errors.empty()); + Luau::dumpErrors(bModule); + + auto bExports = first(bModule->getModuleScope()->returnType); + REQUIRE(!!bExports); + + CHECK_EQ("{| b_value: string |}", toString(*bExports)); +} + +#if 0 +// Does not work yet. :( +TEST_CASE_FIXTURE(FrontendFixture, "recheck_if_dependent_script_has_a_parse_error") +{ + fileResolver.source["Modules/A"] = "oh no a syntax error"; + fileResolver.source["Modules/B"] = R"( + local Modules = {} + local A = require(Modules.A) + return {} + )"; + + CheckResult result = frontend.check("Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Modules/A", result.errors[0].moduleName); + + CheckResult result2 = frontend.check("Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result2); + CHECK_EQ(result2.errors[0], result.errors[0]); +} +#endif + +TEST_CASE_FIXTURE(FrontendFixture, "produce_errors_for_unchanged_file_with_a_syntax_error") +{ + fileResolver.source["Modules/A"] = "oh no a blatant syntax error!!"; + + CheckResult one = frontend.check("Modules/A"); + CheckResult two = frontend.check("Modules/A"); + + CHECK(!one.errors.empty()); + CHECK(!two.errors.empty()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "produce_errors_for_unchanged_file_with_errors") +{ + fileResolver.source["Modules/A"] = "local p: number = 'oh no a type error'"; + + frontend.check("Modules/A"); + + fileResolver.source["Modules/A"] = "local p = 4 -- We have fixed the problem, but we didn't tell the frontend, so it will not recheck this file!"; + CheckResult secondResult = frontend.check("Modules/A"); + + CHECK_EQ(1, secondResult.errors.size()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "reports_errors_from_multiple_sources") +{ + fileResolver.source["game/Gui/Modules/A"] = R"( + local a: number = 'oh no a type error' + return {a=a} + )"; + + fileResolver.source["game/Gui/Modules/B"] = R"( + local Modules = script.Parent + local A = require(Modules.A) + local b: number = 'another one! This is quite distressing!' + )"; + + CheckResult result = frontend.check("game/Gui/Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ("game/Gui/Modules/A", result.errors[0].moduleName); + CHECK_EQ("game/Gui/Modules/B", result.errors[1].moduleName); +} + +TEST_CASE_FIXTURE(FrontendFixture, "report_require_to_nonexistent_file") +{ + fileResolver.source["Modules/A"] = R"( + local Modules = script + local B = require(Modules.B) + )"; + + CheckResult result = frontend.check("Modules/A"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + std::string s = toString(result.errors[0]); + CHECK_MESSAGE(get(result.errors[0]), "Should have been an UnknownRequire: " << toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file") +{ + fileResolver.source["Modules/A"] = R"( + local Modules = script + local B = require(Modules.B :: any) + )"; + + CheckResult result = frontend.check("Modules/A"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(FrontendFixture, "report_syntax_error_in_required_file") +{ + fileResolver.source["Modules/A"] = "oh no a gross breach of syntax"; + fileResolver.source["Modules/B"] = R"( + local Modules = script.Parent + local A = require(Modules.A) + )"; + + CheckResult result = frontend.check("Modules/B"); + LUAU_REQUIRE_ERRORS(result); + + CHECK_EQ("Modules/A", result.errors[0].moduleName); + + bool b = std::any_of(begin(result.errors), end(result.errors), [](auto&& e) -> bool { + return get(e); + }); + if (!b) + { + CHECK_MESSAGE(false, "Expected a syntax error!"); + dumpErrors(result); + } +} + +TEST_CASE_FIXTURE(FrontendFixture, "re_report_type_error_in_required_file") +{ + fileResolver.source["Modules/A"] = R"( + local n: number = 'five' + return {n=n} + )"; + + fileResolver.source["Modules/B"] = R"( + local Modules = script.Parent + local A = require(Modules.A) + print(A.n) + )"; + + CheckResult result = frontend.check("Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CheckResult result2 = frontend.check("Modules/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result2); + + CHECK_EQ("Modules/A", result.errors[0].moduleName); +} + +TEST_CASE_FIXTURE(FrontendFixture, "accumulate_cached_errors") +{ + fileResolver.source["Modules/A"] = R"( + local n: number = 'five' + return {n=n} + )"; + + fileResolver.source["Modules/B"] = R"( + local Modules = script.Parent + local A = require(Modules.A) + local b: number = 'seven' + print(A, b) + )"; + + CheckResult result1 = frontend.check("Modules/B"); + + LUAU_REQUIRE_ERROR_COUNT(2, result1); + + CHECK_EQ("Modules/A", result1.errors[0].moduleName); + CHECK_EQ("Modules/B", result1.errors[1].moduleName); + + CheckResult result2 = frontend.check("Modules/B"); + + LUAU_REQUIRE_ERROR_COUNT(2, result2); + + CHECK_EQ("Modules/A", result2.errors[0].moduleName); + CHECK_EQ("Modules/B", result2.errors[1].moduleName); +} + +TEST_CASE_FIXTURE(FrontendFixture, "accumulate_cached_errors_in_consistent_order") +{ + fileResolver.source["Modules/A"] = R"( + a = 1 + b = 2 + local Modules = script.Parent + local A = require(Modules.B) + )"; + + fileResolver.source["Modules/B"] = R"( + d = 3 + e = 4 + return {} + )"; + + CheckResult result1 = frontend.check("Modules/A"); + + LUAU_REQUIRE_ERROR_COUNT(4, result1); + + CHECK_EQ("Modules/A", result1.errors[2].moduleName); + CHECK_EQ("Modules/A", result1.errors[3].moduleName); + + CHECK_EQ("Modules/B", result1.errors[0].moduleName); + CHECK_EQ("Modules/B", result1.errors[1].moduleName); + + CheckResult result2 = frontend.check("Modules/A"); + CHECK_EQ(4, result2.errors.size()); + + for (size_t i = 0; i < result1.errors.size(); ++i) + CHECK_EQ(result1.errors[i], result2.errors[i]); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_pruneParentSegments") +{ + CHECK_EQ(std::optional{"Modules/Enum/ButtonState"}, + pathExprToModuleName("", {"Modules", "LuaApp", "DeprecatedDarkTheme", "Parent", "Parent", "Enum", "ButtonState"})); + CHECK_EQ(std::optional{"workspace/Foo/Bar/Baz"}, pathExprToModuleName("workspace/Foo/Quux", {"script", "Parent", "Bar", "Baz"})); + CHECK_EQ(std::nullopt, pathExprToModuleName("", {})); + CHECK_EQ(std::optional{"script"}, pathExprToModuleName("", {"script"})); + CHECK_EQ(std::optional{"script/Parent"}, pathExprToModuleName("", {"script", "Parent"})); + CHECK_EQ(std::optional{"script"}, pathExprToModuleName("", {"script", "Parent", "Parent"})); + CHECK_EQ(std::optional{"script"}, pathExprToModuleName("", {"script", "Test", "Parent"})); + CHECK_EQ(std::optional{"script/Parent"}, pathExprToModuleName("", {"script", "Test", "Parent", "Parent"})); + CHECK_EQ(std::optional{"script/Parent"}, pathExprToModuleName("", {"script", "Test", "Parent", "Test", "Parent", "Parent"})); +} + +TEST_CASE_FIXTURE(FrontendFixture, "test_lint_uses_correct_config") +{ + fileResolver.source["Module/A"] = R"( + local t = {} + + for i=#t,1 do + end + )"; + + configResolver.configFiles["Module/A"].enabledLint.enableWarning(LintWarning::Code_ForRange); + + auto result = frontend.lint("Module/A"); + CHECK_EQ(1, result.warnings.size()); + + configResolver.configFiles["Module/A"].enabledLint.disableWarning(LintWarning::Code_ForRange); + + auto result2 = frontend.lint("Module/A"); + CHECK_EQ(0, result2.warnings.size()); + + LintOptions overrideOptions; + + overrideOptions.enableWarning(LintWarning::Code_ForRange); + auto result3 = frontend.lint("Module/A", overrideOptions); + CHECK_EQ(1, result3.warnings.size()); + + overrideOptions.disableWarning(LintWarning::Code_ForRange); + auto result4 = frontend.lint("Module/A", overrideOptions); + CHECK_EQ(0, result4.warnings.size()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "lintFragment") +{ + LintOptions lintOptions; + lintOptions.enableWarning(LintWarning::Code_ForRange); + + auto [_sourceModule, result] = frontend.lintFragment(R"( + local t = {} + + for i=#t,1 do + end + + for i=#t,1,-1 do + end + )", + lintOptions); + + CHECK_EQ(1, result.warnings.size()); + CHECK_EQ(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "discard_type_graphs") +{ + Frontend fe{&fileResolver, &configResolver, {false}}; + + fileResolver.source["Module/A"] = R"( + local a = {1,2,3,4,5} + )"; + + CheckResult result = fe.check("Module/A"); + + ModulePtr module = fe.moduleResolver.getModule("Module/A"); + + CHECK_EQ(0, module->internalTypes.typeVars.size()); + CHECK_EQ(0, module->internalTypes.typePacks.size()); + CHECK_EQ(0, module->astTypes.size()); +} + +TEST_CASE_FIXTURE(FrontendFixture, "it_should_be_safe_to_stringify_errors_when_full_type_graph_is_discarded") +{ + Frontend fe{&fileResolver, &configResolver, {false}}; + + fileResolver.source["Module/A"] = R"( + --!strict + local a: {Count: number} = {count='five'} + )"; + + CheckResult result = fe.check("Module/A"); + + REQUIRE_EQ(1, result.errors.size()); + + // When this test fails, it is because the TypeIds needed by the error have been deallocated. + // It is thus basically impossible to predict what will happen when this assert is evaluated. + // It could segfault, or you could see weird type names like the empty string or + REQUIRE_EQ( + "Table type 'a' not compatible with type '{| Count: number |}' because the former is missing field 'Count'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(FrontendFixture, "trace_requires_in_nonstrict_mode") +{ + fileResolver.source["Module/A"] = R"( + --!nonstrict + local module = {} + + function module.f(arg: number) + print('f', arg) + end + + return module + )"; + + fileResolver.source["Module/B"] = R"( + --!nonstrict + local A = require(script.Parent.A) + + print(A.g(5)) -- Key 'g' not found + print(A.f('five')) -- Type mismatch number and string + print(A.f(5)) -- OK + )"; + + CheckResult result = frontend.check("Module/B"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(4, result.errors[0].location.begin.line); + CHECK_EQ(5, result.errors[1].location.begin.line); +} + +TEST_CASE_FIXTURE(FrontendFixture, "environments") +{ + ScopePtr testScope = frontend.addEnvironment("test"); + + unfreeze(typeChecker.globalTypes); + loadDefinitionFile(typeChecker, testScope, R"( + export type Foo = number | string + )", + "@test"); + freeze(typeChecker.globalTypes); + + fileResolver.source["A"] = R"( + --!nonstrict + local foo: Foo = 1 + )"; + + fileResolver.source["B"] = R"( + --!nonstrict + local foo: Foo = 1 + )"; + + fileResolver.environments["A"] = "test"; + + CheckResult resultA = frontend.check("A"); + LUAU_REQUIRE_NO_ERRORS(resultA); + + CheckResult resultB = frontend.check("B"); + LUAU_REQUIRE_ERROR_COUNT(1, resultB); +} + +TEST_CASE_FIXTURE(FrontendFixture, "ast_node_at_position") +{ + check(R"( + local t = {} + + function t:aa() end + + t: + )"); + + SourceModule* module = getMainSourceModule(); + Position pos = module->root->location.end; + AstNode* node = findNodeAtPosition(*module, pos); + + REQUIRE(node); + REQUIRE(bool(node->asExpr())); + + ++pos.column; + AstNode* node2 = findNodeAtPosition(*module, pos); + CHECK_EQ(node, node2); +} + +TEST_CASE_FIXTURE(FrontendFixture, "stats_are_not_reset_between_checks") +{ + fileResolver.source["Module/A"] = R"( + --!strict + local B = require(script.Parent.B) + local foo = B.foo + 1 + )"; + + fileResolver.source["Module/B"] = R"( + --!strict + return {foo = 1} + )"; + + CheckResult r1 = frontend.check("Module/A"); + LUAU_REQUIRE_NO_ERRORS(r1); + + Frontend::Stats stats1 = frontend.stats; + CHECK_EQ(2, stats1.files); + + frontend.markDirty("Module/A"); + frontend.markDirty("Module/B"); + + CheckResult r2 = frontend.check("Module/A"); + LUAU_REQUIRE_NO_ERRORS(r2); + Frontend::Stats stats2 = frontend.stats; + + CHECK_EQ(4, stats2.files); +} + +TEST_CASE_FIXTURE(FrontendFixture, "clearStats") +{ + fileResolver.source["Module/A"] = R"( + --!strict + local B = require(script.Parent.B) + local foo = B.foo + 1 + )"; + + fileResolver.source["Module/B"] = R"( + --!strict + return {foo = 1} + )"; + + CheckResult r1 = frontend.check("Module/A"); + LUAU_REQUIRE_NO_ERRORS(r1); + + Frontend::Stats stats1 = frontend.stats; + CHECK_EQ(2, stats1.files); + + frontend.markDirty("Module/A"); + frontend.markDirty("Module/B"); + + frontend.clearStats(); + CheckResult r2 = frontend.check("Module/A"); + LUAU_REQUIRE_NO_ERRORS(r2); + Frontend::Stats stats2 = frontend.stats; + + CHECK_EQ(2, stats2.files); +} + +TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") +{ + ScopedFastFlag sffs("LuauTypeCheckTwice", true); + + fileResolver.source["Module/A"] = R"( + local a = 1 + )"; + + CheckResult result = frontend.check("Module/A"); + + ModulePtr module = frontend.moduleResolver.getModule("Module/A"); + + REQUIRE_EQ(module->astTypes.size(), 1); + auto it = module->astTypes.begin(); + CHECK_EQ(toString(it->second), "number"); +} + +TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") +{ + frontend.options.retainFullTypeGraphs = false; + + fileResolver.source["Module/A"] = R"( +--!nonstrict +local a = {} +a.x = 1 +return a; + )"; + + fileResolver.source["Module/B"] = R"( +--!nonstrict +local a = require(script.Parent.A) +local b = {} +function a:b() end -- this should error, but doesn't +return b + )"; + + fileResolver.source["Module/C"] = R"( +--!nonstrict +local a = require(script.Parent.A) +local b = require(script.Parent.B) +a:b() -- this should error, since A doesn't define a:b() + )"; + + CheckResult resultA = frontend.check("Module/A"); + LUAU_REQUIRE_NO_ERRORS(resultA); + + CheckResult resultB = frontend.check("Module/B"); + // TODO (CLI-45592): this should error, since we shouldn't be adding properties to objects from other modules + LUAU_REQUIRE_NO_ERRORS(resultB); + + CheckResult resultC = frontend.check("Module/C"); + LUAU_REQUIRE_ERRORS(resultC); +} + +// This test does not use TEST_CASE_FIXTURE because we need to set a flag before +// the fixture is constructed. +TEST_CASE("no_use_after_free_with_type_fun_instantiation") +{ + // This flag forces this test to crash if there's a UAF in this code. + ScopedFastFlag sff_DebugLuauFreezeArena("DebugLuauFreezeArena", true); + ScopedFastFlag sff_LuauCloneCorrectlyBeforeMutatingTableType("LuauCloneCorrectlyBeforeMutatingTableType", true); + + FrontendFixture fix; + + fix.fileResolver.source["Module/A"] = R"( +export type Foo = typeof(setmetatable({}, {})) +return false; +)"; + + fix.fileResolver.source["Module/B"] = R"( +local A = require(script.Parent.A) +export type Foo = A.Foo +return false; +)"; + + // We don't care about the result. That we haven't crashed is enough. + fix.frontend.check("Module/B"); +} + +TEST_SUITE_END(); diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h new file mode 100644 index 0000000..9f87489 --- /dev/null +++ b/tests/IostreamOptional.h @@ -0,0 +1,18 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) +{ + return lhs << "none"; +} + +template +std::ostream& operator<<(std::ostream& lhs, const std::optional& t) +{ + if (t) + return lhs << *t; + else + return lhs << "none"; +} diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp new file mode 100644 index 0000000..4a71727 --- /dev/null +++ b/tests/JsonEncoder.test.cpp @@ -0,0 +1,53 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Ast.h" +#include "Luau/JsonEncoder.h" + +#include "doctest.h" + +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("JsonEncoderTests"); + +TEST_CASE("encode_constants") +{ + AstExprConstantNil nil{Location()}; + AstExprConstantBool b{Location(), true}; + AstExprConstantNumber n{Location(), 8.2}; + + CHECK_EQ(R"({"type":"AstExprConstantNil","location":"0,0 - 0,0"})", toJson(&nil)); + CHECK_EQ(R"({"type":"AstExprConstantBool","location":"0,0 - 0,0","value":true})", toJson(&b)); + CHECK_EQ(R"({"type":"AstExprConstantNumber","location":"0,0 - 0,0","value":8.2})", toJson(&n)); +} + +TEST_CASE("basic_escaping") +{ + std::string s = "hello \"world\""; + AstArray theString{s.data(), s.size()}; + AstExprConstantString str{Location(), theString}; + + std::string expected = R"({"type":"AstExprConstantString","location":"0,0 - 0,0","value":"hello \"world\""})"; + CHECK_EQ(expected, toJson(&str)); +} + +TEST_CASE("encode_AstStatBlock") +{ + AstLocal astlocal{AstName{"a_local"}, Location(), nullptr, 0, 0, nullptr}; + AstLocal* astlocalarray[] = {&astlocal}; + + AstArray vars{astlocalarray, 1}; + AstArray values{nullptr, 0}; + AstStatLocal local{Location(), vars, values, std::nullopt}; + AstStat* statArray[] = {&local}; + + AstArray bodyArray{statArray, 1}; + + AstStatBlock block{Location(), bodyArray}; + + CHECK_EQ( + (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":["a_local"],"values":[]}]})"), + toJson(&block)); +} + +TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp new file mode 100644 index 0000000..c8eff39 --- /dev/null +++ b/tests/Linter.test.cpp @@ -0,0 +1,1494 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Linter.h" +#include "Luau/BuiltinDefinitions.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("Linter"); + +TEST_CASE_FIXTURE(Fixture, "CleanCode") +{ + LintResult result = lint(R"( +function fib(n) + return n < 2 and 1 or fib(n-1) + fib(n-2) +end + +return math.max(fib(5), 1) +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") +{ + LintResult result = lint("--!nocheck\nreturn foo"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Unknown global 'foo'"); +} + +TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") +{ + // Normally this would be defined externally, so hack it in for testing + addGlobalBinding(typeChecker, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); + + LintResult result = lintTyped("Wait(5)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Global 'Wait' is deprecated, use 'wait' instead"); +} + +TEST_CASE_FIXTURE(Fixture, "PlaceholderRead") +{ + LintResult result = lint(R"( +local _ = 5 +return _ +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); +} + +TEST_CASE_FIXTURE(Fixture, "PlaceholderWrite") +{ + LintResult result = lint(R"( +local _ = 5 +_ = 6 +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "BuiltinGlobalWrite") +{ + LintResult result = lint(R"( +math = {} + +function assert(x) +end + +assert(5) +)"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Built-in global 'math' is overwritten here; consider using a local or changing the name"); + CHECK_EQ(result.warnings[1].text, "Built-in global 'assert' is overwritten here; consider using a local or changing the name"); +} + +TEST_CASE_FIXTURE(Fixture, "MultilineBlock") +{ + LintResult result = lint(R"( +if true then print(1) print(2) print(3) end +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "MultilineBlockSemicolonsWhitelisted") +{ + LintResult result = lint(R"( +print(1); print(2); print(3) +)"); + + CHECK(result.warnings.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "MultilineBlockMissedSemicolon") +{ + LintResult result = lint(R"( +print(1); print(2) print(3) +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "A new statement is on the same line; add semi-colon on previous statement to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "MultilineBlockLocalDo") +{ + LintResult result = lint(R"( +local _x do + _x = 5 +end +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "ConfusingIndentation") +{ + LintResult result = lint(R"( +print(math.max(1, +2)) +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Statement spans multiple lines; use indentation to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocal") +{ + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +return bar() +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Global 'foo' is only used in the enclosing function 'bar'; consider changing it to local"); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMulti") +{ + LintResult result = lint(R"( +local createFunction = function(configValue) + -- Create an internal convenience function + local function internalLogic() + print(configValue) -- prints passed-in value + end + -- Here, we thought we were creating another internal convenience function + -- that closed over the passed-in configValue, but this is actually being + -- declared at module scope! + function moreInternalLogic() + print(configValue) -- nil!!! + end + return function() + internalLogic() + moreInternalLogic() + return nil + end +end +fnA = createFunction(true) +fnB = createFunction(false) +fnA() -- prints "true", "nil" +fnB() -- prints "false", "nil" +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, + "Global 'moreInternalLogic' is only used in the enclosing function defined at line 2; consider changing it to local"); +} + +TEST_CASE_FIXTURE(Fixture, "LocalShadowLocal") +{ + LintResult result = lint(R"( +local arg = 6 +print(arg) + +local arg = 5 +print(arg) +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Variable 'arg' shadows previous declaration at line 2"); +} + +TEST_CASE_FIXTURE(Fixture, "LocalShadowGlobal") +{ + LintResult result = lint(R"( +local math = math +global = math + +function bar() + local global = math.max(5, 1) + return global +end + +return bar() +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Variable 'global' shadows a global variable used at line 3"); +} + +TEST_CASE_FIXTURE(Fixture, "LocalShadowArgument") +{ + LintResult result = lint(R"( +function bar(a, b) + local a = b + 1 + return a +end + +return bar() +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Variable 'a' shadows previous declaration at line 2"); +} + +TEST_CASE_FIXTURE(Fixture, "LocalUnused") +{ + LintResult result = lint(R"( +local arg = 6 + +local function bar() + local arg = 5 + local blarg = 6 + if arg then + blarg = 42 + end +end + +return bar() +)"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Variable 'arg' is never used; prefix with '_' to silence"); + CHECK_EQ(result.warnings[1].text, "Variable 'blarg' is never used; prefix with '_' to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "ImportUnused") +{ + // Normally this would be defined externally, so hack it in for testing + addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@test"); + + LintResult result = lint(R"( +local Roact = require(game.Packages.Roact) +local _Roact = require(game.Packages.Roact) +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Import 'Roact' is never used; prefix with '_' to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "FunctionUnused") +{ + LintResult result = lint(R"( +function bar() +end + +local function qux() +end + +function foo() +end + +local function _unusedl() +end + +function _unusedg() +end + +return foo() +)"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Function 'bar' is never used; prefix with '_' to silence"); + CHECK_EQ(result.warnings[1].text, "Function 'qux' is never used; prefix with '_' to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeBasic") +{ + LintResult result = lint(R"( +do +return 'ok' +end + +print("hi!") +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 5); + CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeLoopBreak") +{ + LintResult result = lint(R"( +while true do + do break end + print("nope") +end + +print("hi!") +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 3); + CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always breaks)"); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeLoopContinue") +{ + LintResult result = lint(R"( +while true do + do continue end + print("nope") +end + +print("hi!") +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 3); + CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always continues)"); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeIfMerge") +{ + LintResult result = lint(R"( +function foo1(a) + if a then + return 'x' + else + return 'y' + end + return 'z' +end + +function foo2(a) + if a then + return 'x' + end + return 'z' +end + +function foo3(a) + if a then + return 'x' + else + print('y') + end + return 'z' +end + +return { foo1, foo2, foo3 } +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 7); + CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always returns)"); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeErrorReturnSilent") +{ + LintResult result = lint(R"( +function foo1(a) + if a then + error('x') + return 'z' + else + error('y') + end +end + +return foo1 +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeAssertFalseReturnSilent") +{ + LintResult result = lint(R"( +function foo1(a) + if a then + return 'z' + end + + assert(false) +end + +return foo1 +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeErrorReturnNonSilentBranchy") +{ + LintResult result = lint(R"( +function foo1(a) + if a then + error('x') + else + error('y') + end + return 'z' +end + +return foo1 +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 7); + CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeErrorReturnPropagate") +{ + LintResult result = lint(R"( +function foo1(a) + if a then + error('x') + return 'z' + else + error('y') + end + return 'x' +end + +return foo1 +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 8); + CHECK_EQ(result.warnings[0].text, "Unreachable code (previous statement always errors)"); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeLoopWhile") +{ + LintResult result = lint(R"( +function foo1(a) + while a do + return 'z' + end + return 'x' +end + +return foo1 +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "UnreachableCodeLoopRepeat") +{ + LintResult result = lint(R"( +function foo1(a) + repeat + return 'z' + until a + return 'x' +end + +return foo1 +)"); + + CHECK_EQ(result.warnings.size(), + 0); // this is technically a bug, since the repeat body always returns; fixing this bug is a bit more involved than I'd like +} + +TEST_CASE_FIXTURE(Fixture, "UnknownType") +{ + ScopedFastFlag sff{"LuauLinterUnknownTypeVectorAware", true}; + + SourceModule sm; + + unfreeze(typeChecker.globalTypes); + TableTypeVar::Props instanceProps{ + {"ClassName", {typeChecker.anyType}}, + }; + + TableTypeVar instanceTable{instanceProps, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}; + TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); + TypeFun instanceTypeFun{{}, instanceType}; + + ClassTypeVar::Props enumItemProps{ + {"EnumType", {typeChecker.anyType}}, + }; + + ClassTypeVar enumItemClass{"EnumItem", enumItemProps, std::nullopt, std::nullopt, {}, {}}; + TypeId enumItemType = typeChecker.globalTypes.addType(enumItemClass); + TypeFun enumItemTypeFun{{}, enumItemType}; + + ClassTypeVar normalIdClass{"NormalId", {}, enumItemType, std::nullopt, {}, {}}; + TypeId normalIdType = typeChecker.globalTypes.addType(normalIdClass); + TypeFun normalIdTypeFun{{}, normalIdType}; + + // Normally this would be defined externally, so hack it in for testing + addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@test"); + addGlobalBinding(typeChecker, "typeof", typeChecker.anyType, "@test"); + typeChecker.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; + typeChecker.globalScope->exportedTypeBindings["Workspace"] = instanceTypeFun; + typeChecker.globalScope->exportedTypeBindings["RunService"] = instanceTypeFun; + typeChecker.globalScope->exportedTypeBindings["Instance"] = instanceTypeFun; + typeChecker.globalScope->exportedTypeBindings["ColorSequence"] = TypeFun{{}, typeChecker.anyType}; + typeChecker.globalScope->exportedTypeBindings["EnumItem"] = enumItemTypeFun; + typeChecker.globalScope->importedTypeBindings["Enum"] = {{"NormalId", normalIdTypeFun}}; + freeze(typeChecker.globalTypes); + + LintResult result = lint(R"( +local _e01 = game:GetService("Foo") +local _e02 = game:GetService("NormalId") +local _e03 = game:FindService("table") +local _e04 = type(game) == "Part" +local _e05 = type(game) == "NormalId" +local _e06 = typeof(game) == "Bar" +local _e07 = typeof(game) == "Part" +local _e08 = typeof(game) == "vector" +local _e09 = typeof(game) == "NormalId" +local _e10 = game:IsA("ColorSequence") +local _e11 = game:IsA("Enum.NormalId") +local _e12 = game:FindFirstChildWhichIsA("function") + +local _o01 = game:GetService("Workspace") +local _o02 = game:FindService("RunService") +local _o03 = type(game) == "number" +local _o04 = type(game) == "vector" +local _o05 = typeof(game) == "string" +local _o06 = typeof(game) == "Instance" +local _o07 = typeof(game) == "EnumItem" +local _o08 = game:IsA("Part") +local _o09 = game:IsA("NormalId") +local _o10 = game:FindFirstChildWhichIsA("Part") +)"); + + REQUIRE_EQ(result.warnings.size(), 12); + CHECK_EQ(result.warnings[0].location.begin.line, 1); + CHECK_EQ(result.warnings[0].text, "Unknown type 'Foo'"); + CHECK_EQ(result.warnings[1].location.begin.line, 2); + CHECK_EQ(result.warnings[1].text, "Unknown type 'NormalId' (expected class type)"); + CHECK_EQ(result.warnings[2].location.begin.line, 3); + CHECK_EQ(result.warnings[2].text, "Unknown type 'table' (expected class type)"); + CHECK_EQ(result.warnings[3].location.begin.line, 4); + CHECK_EQ(result.warnings[3].text, "Unknown type 'Part' (expected primitive type)"); + CHECK_EQ(result.warnings[4].location.begin.line, 5); + CHECK_EQ(result.warnings[4].text, "Unknown type 'NormalId' (expected primitive type)"); + CHECK_EQ(result.warnings[5].location.begin.line, 6); + CHECK_EQ(result.warnings[5].text, "Unknown type 'Bar'"); + CHECK_EQ(result.warnings[6].location.begin.line, 7); + CHECK_EQ(result.warnings[6].text, "Unknown type 'Part' (expected primitive or userdata type)"); + CHECK_EQ(result.warnings[7].location.begin.line, 8); + CHECK_EQ(result.warnings[7].text, "Unknown type 'vector' (expected primitive or userdata type)"); + CHECK_EQ(result.warnings[8].location.begin.line, 9); + CHECK_EQ(result.warnings[8].text, "Unknown type 'NormalId' (expected primitive or userdata type)"); + CHECK_EQ(result.warnings[9].location.begin.line, 10); + CHECK_EQ(result.warnings[9].text, "Unknown type 'ColorSequence' (expected class or enum type)"); + CHECK_EQ(result.warnings[10].location.begin.line, 11); + CHECK_EQ(result.warnings[10].text, "Unknown type 'Enum.NormalId'"); + CHECK_EQ(result.warnings[11].location.begin.line, 12); + CHECK_EQ(result.warnings[11].text, "Unknown type 'function' (expected class type)"); +} + +TEST_CASE_FIXTURE(Fixture, "ForRangeTable") +{ + LintResult result = lint(R"( +local t = {} + +for i=#t,1 do +end + +for i=#t,1,-1 do +end +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 3); + CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); +} + +TEST_CASE_FIXTURE(Fixture, "ForRangeBackwards") +{ + LintResult result = lint(R"( +for i=8,1 do +end + +for i=8,1,-1 do +end +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 1); + CHECK_EQ(result.warnings[0].text, "For loop should iterate backwards; did you forget to specify -1 as step?"); +} + +TEST_CASE_FIXTURE(Fixture, "ForRangeImprecise") +{ + LintResult result = lint(R"( +for i=1.3,7.5 do +end + +for i=1.3,7.5,1 do +end +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].location.begin.line, 1); + CHECK_EQ(result.warnings[0].text, "For loop ends at 7.3 instead of 7.5; did you forget to specify step?"); +} + +TEST_CASE_FIXTURE(Fixture, "ForRangeZero") +{ + LintResult result = lint(R"( +for i=0,#t do +end + +for i=(0),#t do -- to silence +end + +for i=#t,0 do +end +)"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].location.begin.line, 1); + CHECK_EQ(result.warnings[0].text, "For loop starts at 0, but arrays start at 1"); + CHECK_EQ(result.warnings[1].location.begin.line, 7); + CHECK_EQ(result.warnings[1].text, + "For loop should iterate backwards; did you forget to specify -1 as step? Also consider changing 0 to 1 since arrays start at 1"); +} + +TEST_CASE_FIXTURE(Fixture, "UnbalancedAssignment") +{ + LintResult result = lint(R"( +do +local _a,_b,_c = pcall() +end +do +local _a,_b,_c = pcall(), 5 +end +do +local _a,_b,_c = pcall(), 5, 6 +end +do +local _a,_b,_c = pcall(), 5, 6, 7 +end +do +local _a,_b,_c = pcall(), nil +end +)"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].location.begin.line, 5); + CHECK_EQ(result.warnings[0].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); + CHECK_EQ(result.warnings[1].location.begin.line, 11); + CHECK_EQ(result.warnings[1].text, "Assigning 4 values to 3 variables leaves some values unused"); +} + +TEST_CASE_FIXTURE(Fixture, "ImplicitReturn") +{ + LintResult result = lint(R"( +function f1(a) + if not a then + return 5 + end +end + +function f2(a) + if not a then + return + end +end + +function f3(a) + if not a then + return 5 + else + return + end +end + +function f4(a) + for i in pairs(a) do + if i > 5 then + return i + end + end + + print("element not found") +end + +function f5(a) + for i in pairs(a) do + if i > 5 then + return i + end + end + + error("element not found") +end + +f6 = function(a) + if a == 0 then + return 42 + end +end + +function f7(a) + repeat + return 10 + until a ~= nil +end + +return f1,f2,f3,f4,f5,f6,f7 +)"); + + CHECK_EQ(result.warnings.size(), 3); + CHECK_EQ(result.warnings[0].location.begin.line, 4); + CHECK_EQ(result.warnings[0].text, + "Function 'f1' can implicitly return no values even though there's an explicit return at line 4; add explicit return to silence"); + CHECK_EQ(result.warnings[1].location.begin.line, 28); + CHECK_EQ(result.warnings[1].text, + "Function 'f4' can implicitly return no values even though there's an explicit return at line 25; add explicit return to silence"); + CHECK_EQ(result.warnings[2].location.begin.line, 44); + CHECK_EQ(result.warnings[2].text, + "Function can implicitly return no values even though there's an explicit return at line 44; add explicit return to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "ImplicitReturnInfiniteLoop") +{ + LintResult result = lint(R"( +function f1(a) + while true do + if math.random() > 0.5 then + return 5 + end + end +end + +function f2(a) + repeat + if math.random() > 0.5 then + return 5 + end + until false +end + +function f3(a) + while true do + if math.random() > 0.5 then + return 5 + end + if math.random() < 0.1 then + break + end + end +end + +function f4(a) + repeat + if math.random() > 0.5 then + return 5 + end + if math.random() < 0.1 then + break + end + until false +end + +return f1,f2,f3,f4 +)"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].location.begin.line, 25); + CHECK_EQ(result.warnings[0].text, + "Function 'f3' can implicitly return no values even though there's an explicit return at line 21; add explicit return to silence"); + CHECK_EQ(result.warnings[1].location.begin.line, 36); + CHECK_EQ(result.warnings[1].text, + "Function 'f4' can implicitly return no values even though there's an explicit return at line 32; add explicit return to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") +{ + LintResult result = lint(R"(--!strict +type InputData = { + id: number, + inputType: EnumItem, + inputState: EnumItem, + updated: number, + position: Vector3, + keyCode: EnumItem, + name: string +} +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "BreakFromInfiniteLoopMakesStatementReachable") +{ + LintResult result = lint(R"( +local bar = ... + +repeat + if bar then + break + end + + return 2 +until true + +return 1 +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "IgnoreLintAll") +{ + LintResult result = lint(R"( +--!nolint +return foo +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "IgnoreLintSpecific") +{ + LintResult result = lint(R"( +--!nolint UnknownGlobal +local x = 1 +return foo +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringFormat") +{ + LintResult result = lint(R"( +-- incorrect format strings +string.format("%") +string.format("%??d") +string.format("%Y") + +-- incorrect format strings, self call +local _ = ("%"):format() + +-- correct format strings, just to uh make sure +string.format("hello %d %f", 4, 5) +)"); + + CHECK_EQ(result.warnings.size(), 4); + CHECK_EQ(result.warnings[0].text, "Invalid format string: unfinished format specifier"); + CHECK_EQ(result.warnings[1].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); + CHECK_EQ(result.warnings[2].text, "Invalid format string: invalid format specifier: must be a string format specifier or %"); + CHECK_EQ(result.warnings[3].text, "Invalid format string: unfinished format specifier"); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringPack") +{ + LintResult result = lint(R"( +-- incorrect pack specifiers +string.pack("?") +string.packsize("?") +string.unpack("?") + +-- missing size +string.packsize("bc") + +-- incorrect X alignment +string.packsize("X") +string.packsize("X i") + +-- correct X alignment +string.packsize("Xi") + +-- packsize can't be used with variable sized formats +string.packsize("s") + +-- out of range size specifiers +string.packsize("i0") +string.packsize("i17") + +-- a very very very out of range size specifier +string.packsize("i99999999999999999999") +string.packsize("c99999999999999999999") + +-- correct format specifiers +string.packsize("=!1bbbI3c42") +)"); + + CHECK_EQ(result.warnings.size(), 11); + CHECK_EQ(result.warnings[0].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); + CHECK_EQ(result.warnings[1].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); + CHECK_EQ(result.warnings[2].text, "Invalid pack format: unexpected character; must be a pack specifier or space"); + CHECK_EQ(result.warnings[3].text, "Invalid pack format: fixed-sized string format must specify the size"); + CHECK_EQ(result.warnings[4].text, "Invalid pack format: X must be followed by a size specifier"); + CHECK_EQ(result.warnings[5].text, "Invalid pack format: X must be followed by a size specifier"); + CHECK_EQ(result.warnings[6].text, "Invalid pack format: pack specifier must be fixed-size"); + CHECK_EQ(result.warnings[7].text, "Invalid pack format: integer size must be in range [1,16]"); + CHECK_EQ(result.warnings[8].text, "Invalid pack format: integer size must be in range [1,16]"); + CHECK_EQ(result.warnings[9].text, "Invalid pack format: size specifier is too large"); + CHECK_EQ(result.warnings[10].text, "Invalid pack format: size specifier is too large"); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringMatch") +{ + LintResult result = lint(R"( +local s = ... + +-- incorrect character class specifiers +string.match(s, "%q") +string.gmatch(s, "%q") +string.find(s, "%q") +string.gsub(s, "%q", "") + +-- various errors +string.match(s, "%") +string.match(s, "[%1]") +string.match(s, "%0") +string.match(s, "(%d)%2") +string.match(s, "%bx") +string.match(s, "%foo") +string.match(s, '(%d))') +string.match(s, '(%d') +string.match(s, '[%d') +string.match(s, '%,') + +-- self call - not detected because we don't know the type! +local _ = s:match("%q") + +-- correct patterns +string.match(s, "[A-Z]+(%d)%1") +)"); + + CHECK_EQ(result.warnings.size(), 14); + CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); + CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); + CHECK_EQ(result.warnings[2].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); + CHECK_EQ(result.warnings[3].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); + CHECK_EQ(result.warnings[4].text, "Invalid match pattern: unfinished character class"); + CHECK_EQ(result.warnings[5].text, "Invalid match pattern: sets can not contain capture references"); + CHECK_EQ(result.warnings[6].text, "Invalid match pattern: invalid capture reference, must be 1-9"); + CHECK_EQ(result.warnings[7].text, "Invalid match pattern: invalid capture reference, must refer to a valid capture"); + CHECK_EQ(result.warnings[8].text, "Invalid match pattern: missing brace characters for balanced match"); + CHECK_EQ(result.warnings[9].text, "Invalid match pattern: missing set after a frontier pattern"); + CHECK_EQ(result.warnings[10].text, "Invalid match pattern: unexpected ) without a matching ("); + CHECK_EQ(result.warnings[11].text, "Invalid match pattern: expected ) at the end of the string to close a capture"); + CHECK_EQ(result.warnings[12].text, "Invalid match pattern: expected ] at the end of the string to close a set"); + CHECK_EQ(result.warnings[13].text, "Invalid match pattern: expected a magic character after %"); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringMatchNested") +{ + LintResult result = lint(R"~( +local s = ... + +-- correct reference to nested pattern +string.match(s, "((a)%2)") + +-- incorrect reference to nested pattern (not closed yet) +string.match(s, "((a)%1)") + +-- incorrect reference to nested pattern (index out of range) +string.match(s, "((a)%3)") +)~"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid capture reference, must refer to a closed capture"); + CHECK_EQ(result.warnings[0].location.begin.line, 7); + CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid capture reference, must refer to a valid capture"); + CHECK_EQ(result.warnings[1].location.begin.line, 10); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringMatchSets") +{ + LintResult result = lint(R"~( +local s = ... + +-- fake empty sets (but actually sets that aren't closed) +string.match(s, "[]") +string.match(s, "[^]") + +-- character ranges in sets +string.match(s, "[%a-b]") +string.match(s, "[a-%b]") + +-- invalid escapes +string.match(s, "[%q]") +string.match(s, "[%;]") + +-- capture refs in sets +string.match(s, "[%1]") + +-- valid escapes and - at the end +string.match(s, "[%]x-]") + +-- % escapes itself +string.match(s, "[%%]") + +-- this abomination is a valid pattern due to rules wrt handling empty sets +string.match(s, "[]|'[]") +string.match(s, "[^]|'[]") +)~"); + + CHECK_EQ(result.warnings.size(), 7); + CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); + CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); + CHECK_EQ(result.warnings[2].text, "Invalid match pattern: character range can't include character sets"); + CHECK_EQ(result.warnings[3].text, "Invalid match pattern: character range can't include character sets"); + CHECK_EQ(result.warnings[4].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); + CHECK_EQ(result.warnings[5].text, "Invalid match pattern: expected a magic character after %"); + CHECK_EQ(result.warnings[6].text, "Invalid match pattern: sets can not contain capture references"); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringFindArgs") +{ + LintResult result = lint(R"( +local s = ... + +-- incorrect character class specifier +string.find(s, "%q") + +-- raw string find +string.find(s, "%q", 1, true) +string.find(s, "%q", 1, math.random() < 0.5) + +-- incorrect character class specifier +string.find(s, "%q", 1, false) + +-- missing arguments +string.find() +string.find("foo"); +("foo"):find() +)"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); + CHECK_EQ(result.warnings[0].location.begin.line, 4); + CHECK_EQ(result.warnings[1].text, "Invalid match pattern: invalid character class, must refer to a defined class or its inverse"); + CHECK_EQ(result.warnings[1].location.begin.line, 11); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringReplace") +{ + LintResult result = lint(R"( +local s = ... + +-- incorrect replacements +string.gsub(s, '(%d+)', "%") +string.gsub(s, '(%d+)', "%x") +string.gsub(s, '(%d+)', "%2") +string.gsub(s, '', "%1") + +-- correct replacements +string.gsub(s, '[A-Z]+(%d)', "%0%1") +string.gsub(s, 'foo', "%0") +)"); + + CHECK_EQ(result.warnings.size(), 4); + CHECK_EQ(result.warnings[0].text, "Invalid match replacement: unfinished replacement"); + CHECK_EQ(result.warnings[1].text, "Invalid match replacement: unexpected replacement character; must be a digit or %"); + CHECK_EQ(result.warnings[2].text, "Invalid match replacement: invalid capture index, must refer to pattern capture"); + CHECK_EQ(result.warnings[3].text, "Invalid match replacement: invalid capture index, must refer to pattern capture"); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringDate") +{ + LintResult result = lint(R"( +-- incorrect formats +os.date("%") +os.date("%L") +os.date("%?") + +-- correct formats +os.date("it's %c now") +os.date("!*t") +)"); + + CHECK_EQ(result.warnings.size(), 3); + CHECK_EQ(result.warnings[0].text, "Invalid date format: unfinished replacement"); + CHECK_EQ(result.warnings[1].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); + CHECK_EQ(result.warnings[2].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); +} + +TEST_CASE_FIXTURE(Fixture, "FormatStringTyped") +{ + LintResult result = lintTyped(R"~( +local s: string, nons = ... + +string.match(s, "[]") +s:match("[]") + +-- no warning here since we don't know that it's a string +nons:match("[]") +)~"); + + CHECK_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Invalid match pattern: expected ] at the end of the string to close a set"); + CHECK_EQ(result.warnings[0].location.begin.line, 3); + CHECK_EQ(result.warnings[1].text, "Invalid match pattern: expected ] at the end of the string to close a set"); + CHECK_EQ(result.warnings[1].location.begin.line, 4); +} + +TEST_CASE_FIXTURE(Fixture, "TableLiteral") +{ + LintResult result = lint(R"(-- line 1 +_ = { + first = 1, + second = 2, + first = 3, +} + +_ = { + first = 1, + ["first"] = 2, +} + +_ = { + 1, 2, 3, + [1] = 42 +} + +_ = { + [3] = 42, + 1, 2, 3, +} + +local _: { + first: number, + second: string, + first: boolean +} + +_ = { + 1, 2, 3, + [0] = 42, + [4] = 42, +} + +_ = { + [1] = 1, + [2] = 2, + [1] = 3, +} +)"); + + CHECK_EQ(result.warnings.size(), 6); + CHECK_EQ(result.warnings[0].text, "Table field 'first' is a duplicate; previously defined at line 3"); + CHECK_EQ(result.warnings[1].text, "Table field 'first' is a duplicate; previously defined at line 9"); + CHECK_EQ(result.warnings[2].text, "Table index 1 is a duplicate; previously defined as a list entry"); + CHECK_EQ(result.warnings[3].text, "Table index 3 is a duplicate; previously defined as a list entry"); + CHECK_EQ(result.warnings[4].text, "Table type field 'first' is a duplicate; previously defined at line 24"); + CHECK_EQ(result.warnings[5].text, "Table index 1 is a duplicate; previously defined at line 36"); +} + +TEST_CASE_FIXTURE(Fixture, "ImportOnlyUsedInTypeAnnotation") +{ + LintResult result = lint(R"( + local Foo = require(script.Parent.Foo) + + local x: Foo.Y = 1 + )"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Variable 'x' is never used; prefix with '_' to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "DisableUnknownGlobalWithTypeChecking") +{ + LintResult result = lint(R"( + --!strict + unknownGlobal() + )"); + + REQUIRE_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") +{ + LintResult result = lint(R"( + local exports = {} + export type PathFunction

= (P?) -> string + exports.tokensToFunction = function() end + return exports + )"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") +{ + ScopePtr testScope = frontend.addEnvironment("Test"); + unfreeze(typeChecker.globalTypes); + loadDefinitionFile(frontend.typeChecker, testScope, R"( + declare Foo: number + )", + "@test"); + freeze(typeChecker.globalTypes); + + fileResolver.environments["A"] = "Test"; + + fileResolver.source["A"] = R"( + local _foo: Foo = 123 + -- os.clock comes from the global scope, the parent of this module's environment + local _bar: typeof(os.clock) = os.clock + )"; + + LintResult result = frontend.lint("A"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "DeadLocalsUsed") +{ + LintResult result = lint(R"( +--!nolint LocalShadow +do + local x + for x in pairs({}) do + print(x) + end + print(x) -- x is not initialized +end + +do + local a, b, c = 1, 2 + print(a, b, c) -- c is not initialized +end + +do + local a, b, c = table.unpack({}) + print(a, b, c) -- no warning as we don't know anything about c +end + )"); + + CHECK_EQ(result.warnings.size(), 3); + CHECK_EQ(result.warnings[0].text, "Variable 'x' defined at line 4 is never initialized or assigned; initialize with 'nil' to silence"); + CHECK_EQ(result.warnings[1].text, "Assigning 2 values to 3 variables initializes extra variables with nil; add 'nil' to value list to silence"); + CHECK_EQ(result.warnings[2].text, "Variable 'c' defined at line 12 is never initialized or assigned; initialize with 'nil' to silence"); +} + +TEST_CASE_FIXTURE(Fixture, "LocalFunctionNotDead") +{ + LintResult result = lint(R"( +local foo +function foo() end + )"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "DuplicateGlobalFunction") +{ + LintResult result = lint(R"( + function x() end + + function x() end + + return x + )"); + + REQUIRE_EQ(1, result.warnings.size()); + + const auto& w = result.warnings[0]; + + CHECK_EQ(LintWarning::Code_DuplicateFunction, w.code); + CHECK_EQ("Duplicate function definition: 'x' also defined on line 2", w.text); +} + +TEST_CASE_FIXTURE(Fixture, "DuplicateLocalFunction") +{ + LintOptions options; + options.setDefaults(); + options.enableWarning(LintWarning::Code_DuplicateFunction); + options.enableWarning(LintWarning::Code_LocalShadow); + + LintResult result = lint(R"( + local function x() end + + print(x) + + local function x() end + + return x + )", + options); + + REQUIRE_EQ(1, result.warnings.size()); + + CHECK_EQ(LintWarning::Code_DuplicateFunction, result.warnings[0].code); +} + +TEST_CASE_FIXTURE(Fixture, "DuplicateMethod") +{ + LintResult result = lint(R"( + local T = {} + function T:x() end + + function T:x() end + + return x + )"); + + REQUIRE_EQ(1, result.warnings.size()); + + const auto& w = result.warnings[0]; + + CHECK_EQ(LintWarning::Code_DuplicateFunction, w.code); + CHECK_EQ("Duplicate function definition: 'T.x' also defined on line 3", w.text); +} + +TEST_CASE_FIXTURE(Fixture, "DontTriggerTheWarningIfTheFunctionsAreInDifferentScopes") +{ + LintResult result = lint(R"( + if true then + function c() end + else + function c() end + end + + return c + )"); + + CHECK(result.warnings.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "LintHygieneUAF") +{ + LintResult result = lint(R"( + local Hooty = require(workspace.A) + + local HoHooty = require(workspace.A) + + local h: Hooty.Pointy = ruire(workspace.A) + + local h: H + local h: Hooty.Pointy = ruire(workspace.A) + + local hh: Hooty.Pointy = ruire(workspace.A) + + local h: Hooty.Pointy = ruire(workspace.A) + + linooty.Pointy = ruire(workspace.A) + + local hh: Hooty.Pointy = ruire(workspace.A) + + local h: Hooty.Pointy = ruire(workspace.A) + + linty = ruire(workspace.A) + + local h: Hooty.Pointy = ruire(workspace.A) + + local hh: Hooty.Pointy = ruire(workspace.A) + + local h: Hooty.Pointy = ruire(workspace.A) + + local h: Hooty.Pt + )"); + + CHECK_EQ(result.warnings.size(), 12); +} + +TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") +{ + unfreeze(typeChecker.globalTypes); + TypeId instanceType = typeChecker.globalTypes.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, {}}); + typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + + getMutable(instanceType)->props = { + {"Name", {typeChecker.stringType}}, + {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, + {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, + }; + freeze(typeChecker.globalTypes); + + LintResult result = lintTyped(R"( +return function (i: Instance) + i:Wait(1.0) + print(i.Name) + return i.DataCost +end +)"); + + REQUIRE_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); + CHECK_EQ(result.warnings[1].text, "Member 'Instance.DataCost' is deprecated"); +} + +TEST_CASE_FIXTURE(Fixture, "TableOperations") +{ + LintResult result = lintTyped(R"( +local t = {} +local tt = {} + +table.insert(t, #t, 42) +table.insert(t, (#t), 42) -- silenced + +table.insert(t, #t + 1, 42) +table.insert(t, #tt + 1, 42) -- different table, ok + +table.insert(t, 0, 42) + +table.remove(t, 0) + +table.remove(t, #t-1) + +table.insert(t, string.find("hello", "h")) +)"); + + REQUIRE_EQ(result.warnings.size(), 6); + CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " + "second argument or wrap it in parentheses to silence"); + CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); + CHECK_EQ(result.warnings[2].text, "table.insert uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[3].text, "table.remove uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[4].text, "table.remove will remove the value before the last element, which is likely a bug; consider removing the " + "second argument or wrap it in parentheses to silence"); + CHECK_EQ(result.warnings[5].text, + "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); +} + +TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") +{ + LintResult result = lint(R"( +if true then +elseif false then +elseif true then -- duplicate +end + +if true then +elseif false then +else + if true then -- duplicate + end +end + +_ = true and true +_ = true or true +_ = (true and false) and true +_ = (true and true) and true +_ = (true and true) or true +_ = (true and false) and (42 and false) + +_ = true and true or false -- no warning since this is is a common pattern used as a ternary replacement +)"); + + REQUIRE_EQ(result.warnings.size(), 7); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); + CHECK_EQ(result.warnings[0].location.begin.line + 1, 4); + CHECK_EQ(result.warnings[1].text, "Condition has already been checked on column 5"); + CHECK_EQ(result.warnings[2].text, "Condition has already been checked on column 5"); + CHECK_EQ(result.warnings[3].text, "Condition has already been checked on column 6"); + CHECK_EQ(result.warnings[4].text, "Condition has already been checked on column 6"); + CHECK_EQ(result.warnings[5].text, "Condition has already been checked on column 6"); + CHECK_EQ(result.warnings[6].text, "Condition has already been checked on column 15"); + CHECK_EQ(result.warnings[6].location.begin.line + 1, 19); +} + +TEST_CASE_FIXTURE(Fixture, "DuplicateLocal") +{ + LintResult result = lint(R"( +function foo(a1, a2, a3, a1) +end + +local _, _, _ = ... -- ok! +local a1, a2, a1 = ... -- not ok + +local moo = {} +function moo:bar(self) +end + +return foo, moo, a1, a2 +)"); + + REQUIRE_EQ(result.warnings.size(), 4); + CHECK_EQ(result.warnings[0].text, "Function parameter 'a1' already defined on column 14"); + CHECK_EQ(result.warnings[1].text, "Variable 'a1' is never used; prefix with '_' to silence"); + CHECK_EQ(result.warnings[2].text, "Variable 'a1' already defined on column 7"); + CHECK_EQ(result.warnings[3].text, "Function parameter 'self' already defined implicitly"); +} + +TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp new file mode 100644 index 0000000..1b146ed --- /dev/null +++ b/tests/Module.test.cpp @@ -0,0 +1,267 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Module.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ModuleTests"); + +TEST_CASE_FIXTURE(Fixture, "is_within_comment") +{ + check(R"( + --!strict + local foo = {} + function foo:bar() end + + --[[ + foo: + ]] foo:bar() + + --[[]]--[[]] -- Two distinct comments that have zero characters of space between them. + )"); + + SourceModule* sm = getMainSourceModule(); + + CHECK_EQ(5, sm->commentLocations.size()); + + CHECK(isWithinComment(*sm, Position{1, 15})); + CHECK(isWithinComment(*sm, Position{6, 16})); + CHECK(isWithinComment(*sm, Position{9, 13})); + CHECK(isWithinComment(*sm, Position{9, 14})); + + CHECK(!isWithinComment(*sm, Position{2, 15})); + CHECK(!isWithinComment(*sm, Position{7, 10})); + CHECK(!isWithinComment(*sm, Position{7, 11})); +} + +TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") +{ + TypeArena dest; + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + // numberType is persistent. We leave it as-is. + TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks); + CHECK_EQ(newNumber, typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") +{ + TypeArena dest; + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + // Create a new number type that isn't persistent + unfreeze(typeChecker.globalTypes); + TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); + freeze(typeChecker.globalTypes); + TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks); + + CHECK_NE(newNumber, oldNumber); + CHECK_EQ(*oldNumber, *newNumber); + CHECK_EQ("number", toString(newNumber)); + CHECK_EQ(1, dest.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") +{ + CheckResult result = check(R"( + local Cyclic = {} + function Cyclic.get() + return Cyclic + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + /* The inferred type of Cyclic is {get: () -> Cyclic} + * + * Assert that the return type of get() is the same as the outer table. + */ + + TypeId counterType = requireType("Cyclic"); + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + TypeArena dest; + TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks); + + TableTypeVar* ttv = getMutable(counterCopy); + REQUIRE(ttv != nullptr); + + CHECK_EQ(std::optional{"Cyclic"}, ttv->syntheticName); + + TypeId methodType = ttv->props["get"].type; + REQUIRE(methodType != nullptr); + + const FunctionTypeVar* ftv = get(methodType); + REQUIRE(ftv != nullptr); + + std::optional methodReturnType = first(ftv->retType); + REQUIRE(methodReturnType); + + CHECK_EQ(methodReturnType, counterCopy); + CHECK_EQ(2, dest.typePacks.size()); // one for the function args, and another for its return type + CHECK_EQ(2, dest.typeVars.size()); // One table and one function +} + +TEST_CASE_FIXTURE(Fixture, "builtin_types_point_into_globalTypes_arena") +{ + CheckResult result = check(R"( + return {sign=math.sign} + )"); + dumpErrors(result); + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr module = frontend.moduleResolver.getModule("MainModule"); + std::optional exports = first(module->getModuleScope()->returnType); + REQUIRE(bool(exports)); + + REQUIRE(isInArena(*exports, module->interfaceTypes)); + + TableTypeVar* exportsTable = getMutable(*exports); + REQUIRE(exportsTable != nullptr); + + TypeId signType = exportsTable->props["sign"].type; + REQUIRE(signType != nullptr); + + CHECK(!isInArena(signType, module->interfaceTypes)); + CHECK(isInArena(signType, typeChecker.globalTypes)); +} + +TEST_CASE_FIXTURE(Fixture, "deepClone_union") +{ + TypeArena dest; + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + unfreeze(typeChecker.globalTypes); + TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); + freeze(typeChecker.globalTypes); + TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks); + + CHECK_NE(newUnion, oldUnion); + CHECK_EQ("number | string", toString(newUnion)); + CHECK_EQ(1, dest.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") +{ + TypeArena dest; + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + unfreeze(typeChecker.globalTypes); + TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); + freeze(typeChecker.globalTypes); + TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks); + + CHECK_NE(newIntersection, oldIntersection); + CHECK_EQ("number & string", toString(newIntersection)); + CHECK_EQ(1, dest.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "clone_class") +{ + TypeVar exampleMetaClass{ClassTypeVar{"ExampleClassMeta", + { + {"__add", {typeChecker.anyType}}, + }, + std::nullopt, std::nullopt, {}, {}}}; + TypeVar exampleClass{ClassTypeVar{"ExampleClass", + { + {"PropOne", {typeChecker.numberType}}, + {"PropTwo", {typeChecker.stringType}}, + }, + std::nullopt, &exampleMetaClass, {}, {}}}; + + TypeArena dest; + + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks); + const ClassTypeVar* ctv = get(cloned); + REQUIRE(ctv != nullptr); + + REQUIRE(ctv->metatable); + const ClassTypeVar* metatable = get(*ctv->metatable); + REQUIRE(metatable); + + CHECK_EQ("ExampleClass", ctv->name); + CHECK_EQ("ExampleClassMeta", metatable->name); +} + +TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") +{ + TypeVar freeTy(FreeTypeVar{TypeLevel{}}); + TypePackVar freeTp(FreeTypePack{TypeLevel{}}); + + TypeArena dest; + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + bool encounteredFreeType = false; + TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + CHECK(Luau::get(clonedTy)); + CHECK(encounteredFreeType); + + encounteredFreeType = false; + TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); + CHECK(Luau::get(clonedTp)); + CHECK(encounteredFreeType); +} + +TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") +{ + TypeVar tableTy{TableTypeVar{}}; + TableTypeVar* ttv = getMutable(&tableTy); + ttv->state = TableState::Free; + + TypeArena dest; + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + bool encounteredFreeType = false; + TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + const TableTypeVar* clonedTtv = get(cloned); + CHECK_EQ(clonedTtv->state, TableState::Sealed); + CHECK(encounteredFreeType); +} + +TEST_CASE_FIXTURE(Fixture, "clone_self_property") +{ + fileResolver.source["Module/A"] = R"( + --!nonstrict + local a = {} + function a:foo(x) + return -x; + end + return a; + )"; + + CheckResult result = frontend.check("Module/A"); + LUAU_REQUIRE_NO_ERRORS(result); + + fileResolver.source["Module/B"] = R"( + --!nonstrict + local a = require(script.Parent.A) + return a.foo(5) + )"; + + result = frontend.check("Module/B"); + LUAU_REQUIRE_ERRORS(result); + + CHECK_EQ(toString(result.errors[0]), "This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a " + "dot or pass 1 extra nil to suppress this warning"); +} + +TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp new file mode 100644 index 0000000..f3c76d5 --- /dev/null +++ b/tests/NonstrictMode.test.cpp @@ -0,0 +1,282 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("NonstrictModeTests"); + +TEST_CASE_FIXTURE(Fixture, "infer_nullary_function") +{ + CheckResult result = check(R"( + --!nonstrict + function foo(x, y) end + )"); + + TypeId fooType = requireType("foo"); + REQUIRE(fooType); + + const FunctionTypeVar* ftv = get(fooType); + REQUIRE_MESSAGE(ftv != nullptr, "Expected a function, got " << toString(fooType)); + + auto args = flatten(ftv->argTypes).first; + REQUIRE_EQ(2, args.size()); + REQUIRE_EQ("any", toString(args[0])); + REQUIRE_EQ("any", toString(args[1])); + + auto rets = flatten(ftv->retType).first; + REQUIRE_EQ(0, rets.size()); +} + +TEST_CASE_FIXTURE(Fixture, "infer_the_maximum_number_of_values_the_function_could_return") +{ + CheckResult result = check(R"( + --!nonstrict + function getMinCardCountForWidth(width) + if width < 513 then + return 3 + else + return 8, 'jellybeans' + end + end + )"); + + TypeId t = requireType("getMinCardCountForWidth"); + REQUIRE(t); + + REQUIRE_EQ("(any) -> (...any)", toString(t)); +} + +#if 0 +// Maybe we want this? +TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") +{ + CheckResult result = check(R"( + function foo(x): number return 'hello' end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); +} +#endif + +TEST_CASE_FIXTURE(Fixture, "function_parameters_are_any") +{ + CheckResult result = check(R"( + --!nonstrict + function f(arg) + arg = 9 + arg:concat(4) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inconsistent_return_types_are_ok") +{ + CheckResult result = check(R"( + --!nonstrict + function f() + if 1 then + return 4 + else + return 'hello' + end + return 'one', 'two' + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "locals_are_any_by_default") +{ + CheckResult result = check(R"( + --!nonstrict + local m = 55 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.anyType, *requireType("m")); +} + +TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") +{ + CheckResult result = check(R"( + --!nonstrict + local function f(a, b) + return a + end + + f(5) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") +{ + CheckResult result = check(R"( + --!nonstrict + local T = {} + function T:method() end + function T.staticmethod() end + + T.method() + T:staticmethod() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { + return get(e); + })); + CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { + return get(e); + })); +} + +TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") +{ + CheckResult result = check(R"( + --!nonstrict + local T = {} + function T:method() end + T.method() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = get(result.errors[0]); + REQUIRE(e != nullptr); + + REQUIRE_EQ(1, e->requiredExtraNils); +} + +TEST_CASE_FIXTURE(Fixture, "table_props_are_any") +{ + CheckResult result = check(R"( + --!nonstrict + local T = {} + T.foo = 55 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* ttv = getMutable(requireType("T")); + + REQUIRE(ttv != nullptr); + + TypeId fooProp = ttv->props["foo"].type; + REQUIRE(fooProp != nullptr); + + CHECK_EQ(*fooProp, *typeChecker.anyType); +} + +TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") +{ + CheckResult result = check(R"( + --!nonstrict + local T = { + one = 1, + two = 'two', + three = function() return 3 end + } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* ttv = getMutable(requireType("T")); + REQUIRE_MESSAGE(ttv, "Should be a table: " << toString(requireType("T"))); + + CHECK_EQ(*typeChecker.anyType, *ttv->props["one"].type); + CHECK_EQ(*typeChecker.anyType, *ttv->props["two"].type); + CHECK_MESSAGE(get(ttv->props["three"].type), "Should be a function: " << *ttv->props["three"].type); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_iterator_variables_are_any") +{ + CheckResult result = check(R"( + --!nonstrict + function requires_a_table(arg: {}) end + function requires_a_number(arg: number) end + + local T = {} + for a, b in pairs(T) do + requires_a_table(a) + requires_a_table(b) + requires_a_number(a) + requires_a_number(b) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_dot_insert_and_recursive_calls") +{ + CheckResult result = check(R"( + --!nonstrict + function populateListFromIds(list, normalizedData) + local newList = {} + + for _, value in ipairs(list) do + if type(value) == "table" then + table.insert(newList, populateListFromIds(value, normalizedData)) + else + table.insert(newList, normalizedData[value]) + end + end + + return newList + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "delay_function_does_not_require_its_argument_to_return_anything") +{ + CheckResult result = check(R"( + --!nonstrict + + function delay(ms: number?, cb: () -> ()): () end + + delay(50, function() end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inconsistent_module_return_types_are_ok") +{ + CheckResult result = check(R"( + --!nonstrict + + local FFlag: any + + if FFlag.get('SomeFlag') then + return {foo='bar'} + else + return function(prop) + return 'bar' + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("any", toString(getMainModule()->getModuleScope()->returnType)); +} + +TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp new file mode 100644 index 0000000..cb03a7b --- /dev/null +++ b/tests/Parser.test.cpp @@ -0,0 +1,2522 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +namespace +{ + +struct Counter +{ + static int instanceCount; + + int id; + + Counter() + { + ++instanceCount; + id = instanceCount; + } +}; + +int Counter::instanceCount = 0; + +// TODO: delete this and replace all other use of this function with matchParseError +std::string getParseError(const std::string& code) +{ + Fixture f; + + try + { + f.parse(code); + } + catch (const Luau::ParseErrors& e) + { + // in general, tests check only the first error + return e.getErrors().front().getMessage(); + } + + throw std::runtime_error("Expected a parse error in '" + code + "'"); +} + +} // namespace + +TEST_SUITE_BEGIN("AllocatorTests"); + +TEST_CASE("allocator_can_be_moved") +{ + Counter* c = nullptr; + auto inner = [&]() { + Luau::Allocator allocator; + c = allocator.alloc(); + Luau::Allocator moved{std::move(allocator)}; + return moved; + }; + + Counter::instanceCount = 0; + Luau::Allocator a{inner()}; + + CHECK_EQ(1, c->id); +} + +TEST_CASE("moved_out_Allocator_can_still_be_used") +{ + Luau::Allocator outer; + Luau::Allocator inner{std::move(outer)}; + + int* i = outer.alloc(); + REQUIRE(i != nullptr); + *i = 55; + REQUIRE_EQ(*i, 55); +} + +TEST_CASE("aligns_things") +{ + Luau::Allocator alloc; + + char* one = alloc.alloc(); + double* two = alloc.alloc(); + (void)one; + CHECK_EQ(0, reinterpret_cast(two) & (alignof(double) - 1)); +} + +TEST_CASE("initial_double_is_aligned") +{ + Luau::Allocator alloc; + + double* one = alloc.alloc(); + CHECK_EQ(0, reinterpret_cast(one) & (alignof(double) - 1)); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("LexerTests"); + +TEST_CASE("broken_string_works") +{ + const std::string testInput = "[["; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::Type::BrokenString); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 2))); +} + +TEST_CASE("broken_comment") +{ + const std::string testInput = "--[[ "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::Type::BrokenComment); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 6))); +} + +TEST_CASE("broken_comment_kept") +{ + const std::string testInput = "--[[ "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + CHECK_EQ(lexer.next().type, Lexeme::Type::BrokenComment); +} + +TEST_CASE("comment_skipped") +{ + const std::string testInput = "-- "; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + CHECK_EQ(lexer.next().type, Lexeme::Type::Eof); +} + +TEST_CASE("multilineCommentWithLexemeInAndAfter") +{ + const std::string testInput = "--[[ function \n" + "]] end"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme comment = lexer.next(); + Lexeme end = lexer.next(); + + CHECK_EQ(comment.type, Lexeme::Type::BlockComment); + CHECK_EQ(comment.location, Luau::Location(Luau::Position(0, 0), Luau::Position(1, 2))); + CHECK_EQ(end.type, Lexeme::Type::ReservedEnd); + CHECK_EQ(end.location, Luau::Location(Luau::Position(1, 3), Luau::Position(1, 6))); +} + +TEST_CASE("testBrokenEscapeTolerant") +{ + const std::string testInput = "'\\3729472897292378'"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme item = lexer.next(); + + CHECK_EQ(item.type, Lexeme::QuotedString); + CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, int(testInput.size())))); +} + +TEST_CASE("testBigDelimiters") +{ + const std::string testInput = "--[===[\n" + "\n" + "\n" + "\n" + "]===]"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme item = lexer.next(); + + CHECK_EQ(item.type, Lexeme::Type::BlockComment); + CHECK_EQ(item.location, Luau::Location(Luau::Position(0, 0), Luau::Position(4, 5))); +} + +TEST_CASE("lookahead") +{ + const std::string testInput = "foo --[[ comment ]] bar : nil end"; + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + lexer.setSkipComments(true); + lexer.next(); // must call next() before reading data from lexer at least once + + CHECK_EQ(lexer.current().type, Lexeme::Name); + CHECK_EQ(lexer.current().name, std::string("foo")); + CHECK_EQ(lexer.lookahead().type, Lexeme::Name); + CHECK_EQ(lexer.lookahead().name, std::string("bar")); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::Name); + CHECK_EQ(lexer.current().name, std::string("bar")); + CHECK_EQ(lexer.lookahead().type, ':'); + + lexer.next(); + + CHECK_EQ(lexer.current().type, ':'); + CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedNil); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::ReservedNil); + CHECK_EQ(lexer.lookahead().type, Lexeme::ReservedEnd); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::ReservedEnd); + CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); + + lexer.next(); + + CHECK_EQ(lexer.current().type, Lexeme::Eof); + CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ParserTests"); + +TEST_CASE_FIXTURE(Fixture, "basic_parse") +{ + AstStat* stat = parse("print(\"Hello World!\")"); + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "can_haz_annotations") +{ + AstStatBlock* block = parse("local foo: string = \"Hello Types!\""); + REQUIRE(block != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "local_cannot_have_annotation_with_extensions_disabled") +{ + Luau::ParseOptions options; + options.allowTypeAnnotations = false; + + CHECK_THROWS_AS(parse("local foo: string = \"Hello Types!\"", options), std::exception); +} + +TEST_CASE_FIXTURE(Fixture, "local_with_annotation") +{ + AstStatBlock* block = parse(R"( + local foo: string = "Hello Types!" + )"); + + REQUIRE(block != nullptr); + + REQUIRE(block->body.size > 0); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + + REQUIRE_EQ(1, local->vars.size); + + AstLocal* l = local->vars.data[0]; + REQUIRE(l->annotation != nullptr); + + REQUIRE_EQ(1, local->values.size); +} + +TEST_CASE_FIXTURE(Fixture, "type_names_can_contain_dots") +{ + AstStatBlock* block = parse(R"( + local foo: SomeModule.CoolType + )"); + + REQUIRE(block != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "functions_cannot_have_return_annotations_if_extensions_are_disabled") +{ + Luau::ParseOptions options; + options.allowTypeAnnotations = false; + + CHECK_THROWS_AS(parse("function foo(): number return 55 end", options), std::exception); +} + +TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") +{ + AstStatBlock* block = parse(R"( + function foo(): number return 55 end + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size > 0); + + AstStatFunction* statFunction = block->body.data[0]->as(); + REQUIRE(statFunction != nullptr); + + CHECK_EQ(statFunction->func->returnAnnotation.types.size, 1); + CHECK(statFunction->func->returnAnnotation.tailType == nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "functions_can_have_a_function_type_annotation") +{ + AstStatBlock* block = parse(R"( + function f(): (number) -> nil return nil end + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size > 0); + + AstStatFunction* statFunc = block->body.data[0]->as(); + REQUIRE(statFunc != nullptr); + + AstArray& retTypes = statFunc->func->returnAnnotation.types; + REQUIRE(statFunc->func->hasReturnAnnotation); + CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(retTypes.size == 1); + + AstTypeFunction* funTy = retTypes.data[0]->as(); + REQUIRE(funTy != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "function_return_type_should_disambiguate_from_function_type_and_multiple_returns") +{ + AstStatBlock* block = parse(R"( + function f(): (number, string) return 1, "foo" end + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size > 0); + + AstStatFunction* statFunc = block->body.data[0]->as(); + REQUIRE(statFunc != nullptr); + + AstArray& retTypes = statFunc->func->returnAnnotation.types; + REQUIRE(statFunc->func->hasReturnAnnotation); + CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(retTypes.size == 2); + + AstTypeReference* ty0 = retTypes.data[0]->as(); + REQUIRE(ty0 != nullptr); + REQUIRE(ty0->name == "number"); + + AstTypeReference* ty1 = retTypes.data[1]->as(); + REQUIRE(ty1 != nullptr); + REQUIRE(ty1->name == "string"); +} + +TEST_CASE_FIXTURE(Fixture, "function_return_type_should_parse_as_function_type_annotation_with_no_args") +{ + AstStatBlock* block = parse(R"( + function f(): () -> nil return nil end + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size > 0); + + AstStatFunction* statFunc = block->body.data[0]->as(); + REQUIRE(statFunc != nullptr); + + AstArray& retTypes = statFunc->func->returnAnnotation.types; + REQUIRE(statFunc->func->hasReturnAnnotation); + CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(retTypes.size == 1); + + AstTypeFunction* funTy = retTypes.data[0]->as(); + REQUIRE(funTy != nullptr); + REQUIRE(funTy->argTypes.types.size == 0); + CHECK(funTy->argTypes.tailType == nullptr); + CHECK(funTy->returnTypes.tailType == nullptr); + + AstTypeReference* ty = funTy->returnTypes.types.data[0]->as(); + REQUIRE(ty != nullptr); + REQUIRE(ty->name == "nil"); +} + +TEST_CASE_FIXTURE(Fixture, "annotations_can_be_tables") +{ + AstStatBlock* stat = parse(R"( + local zero: number + local one: {x: number, y: string} + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "tables_should_have_an_indexer_and_keys") +{ + AstStatBlock* stat = parse(R"( + local t: { + [string]: number, + f: () -> nil + } + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "tables_can_have_trailing_separator") +{ + AstStatBlock* stat = parse(R"( + local zero: number + local one: {x: number, y: string, } + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "tables_can_use_semicolons") +{ + AstStatBlock* stat = parse(R"( + local zero: number + local one: {x: number; y: string; } + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "other_places_where_type_annotations_are_allowed") +{ + AstStatBlock* stat = parse(R"( + for i: number = 0, 50 do end + for i: number, s: string in expr() do end + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "nil_is_a_valid_type_name") +{ + AstStatBlock* stat = parse(R"( + local n: nil + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "function_type_annotation") +{ + AstStatBlock* stat = parse(R"( + local f: (number, string) -> nil + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "functions_can_return_multiple_values") +{ + AstStatBlock* stat = parse(R"( + local f: (number) -> (number, number) + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "functions_can_have_0_arguments") +{ + AstStatBlock* stat = parse(R"( + local f: () -> number + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "functions_can_return_0_values") +{ + AstStatBlock* block = parse(R"( + local f: (number) -> () + )"); + + REQUIRE(block != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_two_function_types_if_no_returns") +{ + AstStatBlock* block = parse(R"( + local f: (string) -> () & (number) -> () + )"); + + REQUIRE(block != nullptr); + + AstStatLocal* local = block->body.data[0]->as(); + AstTypeIntersection* annotation = local->vars.data[0]->annotation->as(); + REQUIRE(annotation != nullptr); + CHECK(annotation->types.data[0]->as()); + CHECK(annotation->types.data[1]->as()); +} + +TEST_CASE_FIXTURE(Fixture, "intersection_of_two_function_types_if_two_or_more_returns") +{ + AstStatBlock* block = parse(R"( + local f: (string) -> (string, number) & (number) -> (number, string) + )"); + + REQUIRE(block != nullptr); + + AstStatLocal* local = block->body.data[0]->as(); + AstTypeIntersection* annotation = local->vars.data[0]->annotation->as(); + REQUIRE(annotation != nullptr); + CHECK(annotation->types.data[0]->as()); + CHECK(annotation->types.data[1]->as()); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_is_an_intersection_type_if_led_with_one_parenthesized_type") +{ + AstStatBlock* block = parse(R"( + local f: (string) -> (string) & (number) -> (number) + )"); + + REQUIRE(block != nullptr); + + AstStatLocal* local = block->body.data[0]->as(); + AstTypeFunction* annotation = local->vars.data[0]->annotation->as(); + REQUIRE(annotation != nullptr); + + AstTypeIntersection* returnAnnotation = annotation->returnTypes.types.data[0]->as(); + REQUIRE(returnAnnotation != nullptr); + CHECK(returnAnnotation->types.data[0]->as()); + CHECK(returnAnnotation->types.data[1]->as()); +} + +TEST_CASE_FIXTURE(Fixture, "illegal_type_alias_if_extensions_are_disabled") +{ + Luau::ParseOptions options; + options.allowTypeAnnotations = false; + + CHECK_THROWS_AS(parse("type A = number", options), std::exception); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_to_a_typeof") +{ + AstStatBlock* block = parse(R"( + type A = typeof(1) + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size > 0); + + auto typeAliasStat = block->body.data[0]->as(); + REQUIRE(typeAliasStat != nullptr); + CHECK_EQ(typeAliasStat->location, (Location{{1, 8}, {1, 26}})); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_should_point_to_string") +{ + AstStatBlock* block = parse(R"( + type A = string + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size > 0); + REQUIRE(block->body.data[0]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_should_not_interfere_with_type_function_call_or_assignment") +{ + AstStatBlock* block = parse(R"( + type("a") + type = nil + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size > 0); + + AstStatExpr* stat = block->body.data[0]->as(); + REQUIRE(stat != nullptr); + REQUIRE(stat->expr->as()); + + REQUIRE(block->body.data[1]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_should_work_when_name_is_also_local") +{ + AstStatBlock* block = parse(R"( + local A = nil + type A = string + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size == 2); + REQUIRE(block->body.data[0]->is()); + REQUIRE(block->body.data[1]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_messages") +{ + CHECK_EQ(getParseError(R"( + local a: (number, number) -> (string + )"), + "Expected ')' (to close '(' at line 2), got "); + + CHECK_EQ(getParseError(R"( + local a: (number, number) -> ( + string + )"), + "Expected ')' (to close '(' at line 2), got "); + + CHECK_EQ(getParseError(R"( + local a: (number, number) + )"), + "Expected '->' when parsing function type, got "); + + CHECK_EQ(getParseError(R"( + local a: (number, number + )"), + "Expected ')' (to close '(' at line 2), got "); + + CHECK_EQ(getParseError(R"( + local a: {foo: string, + )"), + "Expected identifier when parsing table field, got "); + + CHECK_EQ(getParseError(R"( + local a: {foo: string + )"), + "Expected '}' (to close '{' at line 2), got "); + + CHECK_EQ(getParseError(R"( + local a: { [string]: number, [number]: string } + )"), + "Cannot have more than one table indexer"); + + ScopedFastFlag sffs1{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauGenericFunctionsParserFix", true}; + ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; + + CHECK_EQ(getParseError(R"( + type T = foo + )"), + "Expected '(' when parsing function parameters, got 'foo'"); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_intersection_and_union_not_allowed") +{ + matchParseError("type A = number & string | boolean", "Mixing union and intersection types is not allowed; consider wrapping in parentheses."); +} + +TEST_CASE_FIXTURE(Fixture, "mixed_intersection_and_union_allowed_when_parenthesized") +{ + try + { + parse("type A = (number & string) | boolean"); + } + catch (const ParseErrors& e) + { + FAIL(e.what()); + } +} + +TEST_CASE_FIXTURE(Fixture, "cannot_write_multiple_values_in_type_groups") +{ + matchParseError("type F = ((string, number))", "Expected '->' when parsing function type, got ')'"); + matchParseError("type F = () -> ((string, number))", "Expected '->' when parsing function type, got ')'"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_error_messages") +{ + CHECK_EQ(getParseError("type 5 = number"), "Expected identifier when parsing type name, got '5'"); + CHECK_EQ(getParseError("type A"), "Expected '=' when parsing type alias, got "); + CHECK_EQ(getParseError("type A<"), "Expected identifier, got "); + CHECK_EQ(getParseError("type A' (to close '<' at column 7), got "); +} + +TEST_CASE_FIXTURE(Fixture, "type_assertion_expression") +{ + (void)parse(R"( + local a = something() :: any + )"); +} + +// The bug that motivated this test was an infinite loop. +// TODO: Set a timer and crash if the timeout is exceeded. +TEST_CASE_FIXTURE(Fixture, "last_line_does_not_have_to_be_blank") +{ + (void)parse("-- print('hello')"); +} + +TEST_CASE_FIXTURE(Fixture, "type_assertion_expression_binds_tightly") +{ + AstStatBlock* stat = parse(R"( + local a = one :: any + two :: any + )"); + + REQUIRE(stat != nullptr); + + AstStatBlock* block = stat->as(); + REQUIRE(block != nullptr); + REQUIRE_EQ(1, block->body.size); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + REQUIRE_EQ(1, local->values.size); + + AstExprBinary* bin = local->values.data[0]->as(); + REQUIRE(bin != nullptr); + + CHECK(nullptr != bin->left->as()); + CHECK(nullptr != bin->right->as()); +} + +TEST_CASE_FIXTURE(Fixture, "mode_is_unset_if_no_hot_comment") +{ + ParseResult result = parseEx("print('Hello World!')"); + CHECK(result.hotcomments.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "sense_hot_comment_on_first_line") +{ + ParseResult result = parseEx(" --!strict "); + std::optional mode = parseMode(result.hotcomments); + REQUIRE(bool(mode)); + CHECK_EQ(int(*mode), int(Mode::Strict)); +} + +TEST_CASE_FIXTURE(Fixture, "stop_if_line_ends_with_hyphen") +{ + CHECK_THROWS_AS(parse(" -"), std::exception); +} + +TEST_CASE_FIXTURE(Fixture, "nonstrict_mode") +{ + ParseResult result = parseEx("--!nonstrict"); + CHECK(result.errors.empty()); + std::optional mode = parseMode(result.hotcomments); + REQUIRE(bool(mode)); + CHECK_EQ(int(*mode), int(Mode::Nonstrict)); +} + +TEST_CASE_FIXTURE(Fixture, "nocheck_mode") +{ + ParseResult result = parseEx("--!nocheck"); + CHECK(result.errors.empty()); + std::optional mode = parseMode(result.hotcomments); + REQUIRE(bool(mode)); + CHECK_EQ(int(*mode), int(Mode::NoCheck)); +} + +TEST_CASE_FIXTURE(Fixture, "vertical_space") +{ + ParseResult result = parseEx("a()\vb()"); + CHECK(result.errors.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_type_name") +{ + CHECK_EQ(getParseError(R"( + local a: Foo.= + )"), + "Expected identifier when parsing field name, got '='"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_decimal") +{ + AstStat* stat = parse("return 1, .5, 1.5, 1e-5, 1.5e-5, 12_345.1_25"); + REQUIRE(stat != nullptr); + + AstStatReturn* str = stat->as()->body.data[0]->as(); + CHECK(str->list.size == 6); + CHECK_EQ(str->list.data[0]->as()->value, 1.0); + CHECK_EQ(str->list.data[1]->as()->value, 0.5); + CHECK_EQ(str->list.data[2]->as()->value, 1.5); + CHECK_EQ(str->list.data[3]->as()->value, 1.0e-5); + CHECK_EQ(str->list.data[4]->as()->value, 1.5e-5); + CHECK_EQ(str->list.data[5]->as()->value, 12345.125); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_hexadecimal") +{ + AstStat* stat = parse("return 0xab, 0XAB05, 0xff_ff"); + REQUIRE(stat != nullptr); + + AstStatReturn* str = stat->as()->body.data[0]->as(); + CHECK(str->list.size == 3); + CHECK_EQ(str->list.data[0]->as()->value, 0xab); + CHECK_EQ(str->list.data[1]->as()->value, 0xAB05); + CHECK_EQ(str->list.data[2]->as()->value, 0xFFFF); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_binary") +{ + AstStat* stat = parse("return 0b1, 0b0, 0b101010"); + REQUIRE(stat != nullptr); + + AstStatReturn* str = stat->as()->body.data[0]->as(); + CHECK(str->list.size == 3); + CHECK_EQ(str->list.data[0]->as()->value, 1); + CHECK_EQ(str->list.data[1]->as()->value, 0); + CHECK_EQ(str->list.data[2]->as()->value, 42); +} + +TEST_CASE_FIXTURE(Fixture, "parse_numbers_error") +{ + CHECK_EQ(getParseError("return 0b123"), "Malformed number"); + CHECK_EQ(getParseError("return 123x"), "Malformed number"); + CHECK_EQ(getParseError("return 0xg"), "Malformed number"); +} + +TEST_CASE_FIXTURE(Fixture, "break_return_not_last_error") +{ + CHECK_EQ(getParseError("return 0 print(5)"), "Expected , got 'print'"); + CHECK_EQ(getParseError("while true do break print(5) end"), "Expected 'end' (to close 'do' at column 12), got 'print'"); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_unicode") +{ + CHECK_EQ(getParseError(R"( + local ☃ = 10 + )"), + "Expected identifier when parsing variable name, got Unicode character U+2603"); +} + +TEST_CASE_FIXTURE(Fixture, "allow_unicode_in_string") +{ + ParseResult result = parseEx("local snowman = \"☃\""); + CHECK(result.errors.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_confusable") +{ + CHECK_EQ(getParseError(R"( + local pi = 3․13 + )"), + "Expected identifier when parsing expression, got Unicode character U+2024 (did you mean '.'?)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_non_utf8_sequence") +{ + const char* expected = "Expected identifier when parsing expression, got invalid UTF-8 sequence"; + + CHECK_EQ(getParseError("local pi = \xFF!"), expected); + CHECK_EQ(getParseError("local pi = \xE2!"), expected); +} + +TEST_CASE_FIXTURE(Fixture, "lex_broken_unicode") +{ + const std::string testInput = std::string("\xFF\xFE☃․"); + + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + Lexeme lexeme = lexer.current(); + + lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::BrokenUnicode); + CHECK_EQ(lexeme.codepoint, 0); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 0), Luau::Position(0, 1))); + + lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::BrokenUnicode); + CHECK_EQ(lexeme.codepoint, 0); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 1), Luau::Position(0, 2))); + + lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::BrokenUnicode); + CHECK_EQ(lexeme.codepoint, 0x2603); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 2), Luau::Position(0, 5))); + + lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::BrokenUnicode); + CHECK_EQ(lexeme.codepoint, 0x2024); + CHECK_EQ(lexeme.location, Luau::Location(Luau::Position(0, 5), Luau::Position(0, 8))); + + lexeme = lexer.next(); + CHECK_EQ(lexeme.type, Lexeme::Eof); +} + +TEST_CASE_FIXTURE(Fixture, "parse_continue") +{ + AstStatBlock* stat = parse(R"( + while true do + continue() + continue = 5 + continue, continue = continue + continue + end + )"); + + REQUIRE(stat != nullptr); + + AstStatBlock* block = stat->as(); + REQUIRE(block != nullptr); + REQUIRE_EQ(1, block->body.size); + + AstStatWhile* wb = block->body.data[0]->as(); + REQUIRE(wb != nullptr); + + AstStatBlock* wblock = wb->body->as(); + REQUIRE(wblock != nullptr); + REQUIRE_EQ(4, wblock->body.size); + + REQUIRE(wblock->body.data[0]->is()); + REQUIRE(wblock->body.data[1]->is()); + REQUIRE(wblock->body.data[2]->is()); + REQUIRE(wblock->body.data[3]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "continue_not_last_error") +{ + CHECK_EQ(getParseError("while true do continue print(5) end"), "Expected 'end' (to close 'do' at column 12), got 'print'"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_export_type") +{ + AstStatBlock* stat = parse(R"( + export() + export = 5 + export, export = export + export type A = number + type A = number + )"); + + REQUIRE(stat != nullptr); + + AstStatBlock* block = stat->as(); + REQUIRE(block != nullptr); + REQUIRE_EQ(5, block->body.size); + + REQUIRE(block->body.data[0]->is()); + REQUIRE(block->body.data[1]->is()); + REQUIRE(block->body.data[2]->is()); + REQUIRE(block->body.data[3]->is()); + REQUIRE(block->body.data[4]->is()); +} + +TEST_CASE_FIXTURE(Fixture, "export_is_an_identifier_only_when_followed_by_type") +{ + try + { + parse(R"( + export function a() end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Incomplete statement: expected assignment or a function call", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "incomplete_statement_error") +{ + CHECK_EQ(getParseError("fiddlesticks"), "Incomplete statement: expected assignment or a function call"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment") +{ + AstStatBlock* block = parse(R"( + a += 5 + )"); + + REQUIRE(block != nullptr); + REQUIRE(block->body.size == 1); + REQUIRE(block->body.data[0]->is()); + REQUIRE(block->body.data[0]->as()->op == AstExprBinary::Add); +} + +TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment_error_call") +{ + try + { + parse(R"( + a() += 5 + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected identifier when parsing expression, got '+='", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment_error_not_lvalue") +{ + try + { + parse(R"( + (a) += 5 + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Assigned expression must be a variable or a field", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment_error_multiple") +{ + try + { + parse(R"( + a, b += 5 + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected '=' when parsing assignment, got '+='", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") +{ + try + { + parse(R"(-- i am line 1 +function BottomUpTree(item, depth) + if depth > 0 then + local i = item + item + depth = depth - 1 + local left, right = BottomUpTree(i-1, depth), BottomUpTree(i, depth) + return { item, left, right } + else + return { item } +end + +function ItemCheck(tree) + if tree[2] then + return tree[1] + ItemCheck(tree[2]) - ItemCheck(tree[3]) + else + return tree[1] + end +end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", + e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_single_line") +{ + try + { + parse(R"(-- i am line 1 +function ItemCheck(tree) + if tree[2] then return tree[1] + ItemCheck(tree[2]) - ItemCheck(tree[3]) else return tree[1] +end + +function BottomUpTree(item, depth) + if depth > 0 then + local i = item + item + depth = depth - 1 + local left, right = BottomUpTree(i-1, depth), BottomUpTree(i, depth) + return { item, left, right } + else + return { item } + end +end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 3?", + e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_local_repeat") +{ + try + { + parse(R"(-- i am line 1 +repeat + print(1) + repeat + print(2) + print(3) +until false + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected 'until' (to close 'repeat' at line 2), got ; did you forget to close 'repeat' at line 4?", + e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_local_function") +{ + try + { + parse(R"(-- i am line 1 +local function BottomUpTree(item, depth) + if depth > 0 then + local i = item + item + depth = depth - 1 + local left, right = BottomUpTree(i-1, depth), BottomUpTree(i, depth) + return { item, left, right } + else + return { item } +end + +local function ItemCheck(tree) + if tree[2] then + return tree[1] + ItemCheck(tree[2]) - ItemCheck(tree[3]) + else + return tree[1] + end +end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", + e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_failsafe_earlier") +{ + try + { + parse(R"(-- i am line 1 +local function ItemCheck(tree) + if tree[2] then + return tree[1] + ItemCheck(tree[2]) - ItemCheck(tree[3]) + else + return tree[1] + end +end + +local function BottomUpTree(item, depth) + if depth > 0 then + local i = item + item + depth = depth - 1 + local left, right = BottomUpTree(i-1, depth), BottomUpTree(i, depth) + return { item, left, right } + else + return { item } + end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected 'end' (to close 'function' at line 10), got ", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection_nested") +{ + try + { + parse(R"(-- i am line 1 +function stringifyTable(t) + local entries = {} + for k, v in pairs(t) do + -- if we find a nested table, convert that recursively + if type(v) == "table" then + v = stringifyTable(v) + else + v = tostring(v) + k = tostring(k) + + -- add another entry to our stringified table + entries[#entries + 1] = ("s = s"):format(k, v) + end + + -- the memory location of the table + local id = tostring(t):sub(8) + + return ("{s}@s"):format(table.concat(entries, ", "), id) +end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected 'end' (to close 'function' at line 2), got ; did you forget to close 'else' at line 8?", + e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_table_literal") +{ + try + { + parse(R"( +function stringifyTable(t) + local foo = (name = t) + return foo +end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ( + "Expected ')' (to close '(' at column 17), got '='; did you mean to use '{' when defining a table?", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_function_call") +{ + try + { + parse(R"( +function stringifyTable(t) + local foo = t:Parse 2 + return foo +end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().front().getLocation().begin.line, 2); + CHECK_EQ("Expected '(', '{' or when parsing function call, got '2'", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_function_call_newline") +{ + try + { + parse(R"( +function stringifyTable(t) + local foo = t:Parse + return foo +end + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().front().getLocation().begin.line, 2); + CHECK_EQ("Expected function call arguments after '('", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") +{ + ScopedFastInt sfis{"LuauRecursionLimit", 20}; + + matchParseError( + "function f(): (((((((((Fail))))))))) end", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + + matchParseError("function f(): () -> () -> () -> () -> () -> () -> () -> () -> () -> () -> () end", + "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); + + matchParseError( + "local t: {a: {b: {c: {d: {e: {f: {}}}}}}}", "Exceeded allowed recursion depth; simplify your type annotation to make the code compile"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") +{ + ScopedFastInt sfis{"LuauRecursionLimit", 10}; + ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; + + matchParseErrorPrefix( + "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " + "then if true then if true then end end end end end end end end end end end end", + "Exceeded allowed recursion depth;"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") +{ + ScopedFastInt sfis{"LuauRecursionLimit", 10}; + ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; + + matchParseErrorPrefix( + "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " + "false then elseif false then elseif false then elseif false then elseif false then end end", + "Exceeded allowed recursion depth;"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") +{ + ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastInt sfis{"LuauRecursionLimit", 10}; + + matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " + "6 elseif true then 7 elseif true then 8 elseif true then 9 elseif true then 10 else 11 end", + "Exceeded allowed recursion depth; simplify your expression to make the code compile"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") +{ + ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastInt sfis{"LuauRecursionLimit", 10}; + + matchParseError( + "function f() return if if if if if if if if if if true then false else true then false else true then false else true then false else true " + "then false else true then false else true then false else true then false else true then false else true then 1 else 2 end", + "Exceeded allowed recursion depth; simplify your expression to make the code compile"); +} + +TEST_CASE_FIXTURE(Fixture, "unparenthesized_function_return_type_list") +{ + matchParseError( + "function foo(): string, number end", "Expected a statement, got ','; did you forget to wrap the list of return types in parentheses?"); + + matchParseError("function foo(): (number) -> string, string", + "Expected a statement, got ','; did you forget to wrap the list of return types in parentheses?"); + + // Will throw if the parse fails + parse(R"( + type Vector3MT = { + __add: (Vector3MT, Vector3MT) -> Vector3MT, + __mul: (Vector3MT, Vector3MT|number) -> Vector3MT + } + )"); +} + +TEST_CASE_FIXTURE(Fixture, "short_array_types") +{ + AstStatBlock* stat = parse(R"( + local n: {string} + )"); + + REQUIRE(stat != nullptr); + AstStatLocal* local = stat->body.data[0]->as(); + AstTypeTable* annotation = local->vars.data[0]->annotation->as(); + REQUIRE(annotation != nullptr); + CHECK(annotation->props.size == 0); + REQUIRE(annotation->indexer); + REQUIRE(annotation->indexer->indexType->is()); + CHECK(annotation->indexer->indexType->as()->name == "number"); + REQUIRE(annotation->indexer->resultType->is()); + CHECK(annotation->indexer->resultType->as()->name == "string"); +} + +TEST_CASE_FIXTURE(Fixture, "short_array_types_must_be_alone") +{ + matchParseError("local n: {string, number}", "Expected '}' (to close '{' at column 10), got ','"); + matchParseError("local n: {[number]: string, number}", "Expected ':' when parsing table field, got '}'"); + matchParseError("local n: {x: string, number}", "Expected ':' when parsing table field, got '}'"); + matchParseError("local n: {x: string, nil}", "Expected identifier when parsing table field, got 'nil'"); +} + +TEST_CASE_FIXTURE(Fixture, "short_array_types_do_not_break_field_names") +{ + AstStatBlock* stat = parse(R"( + local n: {string: number} + )"); + + REQUIRE(stat != nullptr); + AstStatLocal* local = stat->body.data[0]->as(); + AstTypeTable* annotation = local->vars.data[0]->annotation->as(); + REQUIRE(annotation != nullptr); + REQUIRE(annotation->props.size == 1); + CHECK(!annotation->indexer); + REQUIRE(annotation->props.data[0].name == "string"); + REQUIRE(annotation->props.data[0].type->is()); + REQUIRE(annotation->props.data[0].type->as()->name == "number"); +} + +TEST_CASE_FIXTURE(Fixture, "short_array_types_are_not_field_names_when_complex") +{ + matchParseError("local n: {string | number: number}", "Expected '}' (to close '{' at column 10), got ':'"); +} + +TEST_CASE_FIXTURE(Fixture, "nil_can_not_be_a_field_name") +{ + matchParseError("local n: {nil: number}", "Expected '}' (to close '{' at column 10), got ':'"); +} + +TEST_CASE_FIXTURE(Fixture, "string_literal_call") +{ + AstStatBlock* stat = parse("do foo 'bar' end"); + REQUIRE(stat != nullptr); + AstStatBlock* dob = stat->body.data[0]->as(); + AstStatExpr* stc = dob->body.data[0]->as(); + REQUIRE(stc != nullptr); + AstExprCall* ec = stc->expr->as(); + CHECK(ec->args.size == 1); + AstExprConstantString* arg = ec->args.data[0]->as(); + REQUIRE(arg != nullptr); + CHECK(std::string(arg->value.data, arg->value.size) == "bar"); +} + +TEST_CASE_FIXTURE(Fixture, "multiline_strings_newlines") +{ + AstStatBlock* stat = parse("return [=[\nfoo\r\nbar\n\nbaz\n]=]"); + REQUIRE(stat != nullptr); + + AstStatReturn* ret = stat->body.data[0]->as(); + REQUIRE(ret != nullptr); + + AstExprConstantString* str = ret->list.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(std::string(str->value.data, str->value.size) == "foo\nbar\n\nbaz\n"); +} + +TEST_CASE_FIXTURE(Fixture, "string_literals_escape") +{ + AstStatBlock* stat = parse(R"( +return +"foo\n\r", +"foo\0324", +"foo\x204", +"foo\u{20}", +"foo\u{0451}" +)"); + + REQUIRE(stat != nullptr); + + AstStatReturn* ret = stat->body.data[0]->as(); + REQUIRE(ret != nullptr); + CHECK(ret->list.size == 5); + + AstExprConstantString* str; + + str = ret->list.data[0]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foo\n\r"); + + str = ret->list.data[1]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foo 4"); + + str = ret->list.data[2]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foo 4"); + + str = ret->list.data[3]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foo "); + + str = ret->list.data[4]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foo\xd1\x91"); +} + +TEST_CASE_FIXTURE(Fixture, "string_literals_escape_newline") +{ + AstStatBlock* stat = parse("return \"foo\\z\n bar\", \"foo\\\n bar\", \"foo\\\r\nbar\""); + + REQUIRE(stat != nullptr); + + AstStatReturn* ret = stat->body.data[0]->as(); + REQUIRE(ret != nullptr); + CHECK(ret->list.size == 3); + + AstExprConstantString* str; + + str = ret->list.data[0]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foobar"); + + str = ret->list.data[1]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foo\n bar"); + + str = ret->list.data[2]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "foo\nbar"); +} + +TEST_CASE_FIXTURE(Fixture, "string_literals_escapes") +{ + AstStatBlock* stat = parse(R"( +return +"\xAB", +"\u{2024}", +"\121", +"\1x", +"\t", +"\n" +)"); + + REQUIRE(stat != nullptr); + + AstStatReturn* ret = stat->body.data[0]->as(); + REQUIRE(ret != nullptr); + CHECK(ret->list.size == 6); + + AstExprConstantString* str; + + str = ret->list.data[0]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "\xAB"); + + str = ret->list.data[1]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "\xE2\x80\xA4"); + + str = ret->list.data[2]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "\x79"); + + str = ret->list.data[3]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "\x01x"); + + str = ret->list.data[4]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "\t"); + + str = ret->list.data[5]->as(); + REQUIRE(str != nullptr); + CHECK_EQ(std::string(str->value.data, str->value.size), "\n"); +} + +TEST_CASE_FIXTURE(Fixture, "string_literals_escapes_broken") +{ + const char* expected = "String literal contains malformed escape sequence"; + + matchParseError("return \"\\u{\"", expected); + matchParseError("return \"\\u{FO}\"", expected); + matchParseError("return \"\\u{123456789}\"", expected); + matchParseError("return \"\\359\"", expected); + matchParseError("return \"\\xFO\"", expected); + matchParseError("return \"\\xF\"", expected); + matchParseError("return \"\\x\"", expected); +} + +TEST_CASE_FIXTURE(Fixture, "string_literals_broken") +{ + matchParseError("return \"", "Malformed string"); + matchParseError("return \"\\", "Malformed string"); + matchParseError("return \"\r\r", "Malformed string"); +} + +TEST_CASE_FIXTURE(Fixture, "number_literals") +{ + AstStatBlock* stat = parse(R"( +return +1, +1.5, +.5, +12_34_56, +0x1234, + 0b010101 +)"); + + REQUIRE(stat != nullptr); + + AstStatReturn* ret = stat->body.data[0]->as(); + REQUIRE(ret != nullptr); + CHECK(ret->list.size == 6); + + AstExprConstantNumber* num; + + num = ret->list.data[0]->as(); + REQUIRE(num != nullptr); + CHECK_EQ(num->value, 1.0); + + num = ret->list.data[1]->as(); + REQUIRE(num != nullptr); + CHECK_EQ(num->value, 1.5); + + num = ret->list.data[2]->as(); + REQUIRE(num != nullptr); + CHECK_EQ(num->value, 0.5); + + num = ret->list.data[3]->as(); + REQUIRE(num != nullptr); + CHECK_EQ(num->value, 123456); + + num = ret->list.data[4]->as(); + REQUIRE(num != nullptr); + CHECK_EQ(num->value, 0x1234); + + num = ret->list.data[5]->as(); + REQUIRE(num != nullptr); + CHECK_EQ(num->value, 0x15); +} + +TEST_CASE_FIXTURE(Fixture, "end_extent_of_functions_unions_and_intersections") +{ + AstStatBlock* block = parse(R"( + type F = (string) -> string + type G = string | number | boolean + type H = string & number & boolean + print('hello') + )"); + + REQUIRE_EQ(4, block->body.size); + CHECK_EQ((Position{1, 35}), block->body.data[0]->location.end); + CHECK_EQ((Position{2, 42}), block->body.data[1]->location.end); + CHECK_EQ((Position{3, 42}), block->body.data[2]->location.end); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_loop_control") +{ + matchParseError("break", "break statement must be inside a loop"); + matchParseError("repeat local function a() break end until false", "break statement must be inside a loop"); + matchParseError("continue", "continue statement must be inside a loop"); + matchParseError("repeat local function a() continue end until false", "continue statement must be inside a loop"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") +{ + auto result1 = matchParseError(R"( + function add(x, y) return x + y end + add + (4, 7) + )", + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " + "statements"); + + CHECK(result1.errors.size() == 1); + + auto result2 = matchParseError(R"( + function add(x, y) return x + y end + local f = add + (f :: any)['x'] = 2 + )", + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " + "statements"); + + CHECK(result2.errors.size() == 1); + + auto result3 = matchParseError(R"( + local x = {} + function x:add(a, b) return a + b end + x:add + (1, 2) + )", + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " + "statements"); + + CHECK(result3.errors.size() == 1); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") +{ + matchParseError("function add(x, y) return ... end", "Cannot use '...' outside of a vararg function"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_assignment_lvalue") +{ + matchParseError(R"( + local a, b + (2), b = b, a + )", + "Assigned expression must be a variable or a field"); + + matchParseError(R"( + local a, b + a, (3) = b, a + )", + "Assigned expression must be a variable or a field"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_error_type_annotation") +{ + matchParseError("local a : 2 = 2", "Expected type, got '2'"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_declarations") +{ + AstStatBlock* stat = parseEx(R"( + declare foo: number + declare function bar(x: number): string + declare function var(...: any) + )") + .root; + + REQUIRE(stat); + REQUIRE_EQ(stat->body.size, 3); + + AstStatDeclareGlobal* global = stat->body.data[0]->as(); + REQUIRE(global); + CHECK(global->name == "foo"); + CHECK(global->type); + + AstStatDeclareFunction* func = stat->body.data[1]->as(); + REQUIRE(func); + CHECK(func->name == "bar"); + REQUIRE_EQ(func->params.types.size, 1); + REQUIRE_EQ(func->retTypes.types.size, 1); + + AstStatDeclareFunction* varFunc = stat->body.data[2]->as(); + REQUIRE(varFunc); + CHECK(varFunc->name == "var"); + CHECK(varFunc->params.tailType); + + matchParseError("declare function foo(x)", "All declaration parameters must be annotated"); + matchParseError("declare foo", "Expected ':' when parsing global variable declaration, got "); +} + +TEST_CASE_FIXTURE(Fixture, "parse_class_declarations") +{ + AstStatBlock* stat = parseEx(R"( + declare class Foo + prop: number + function method(self, foo: number): string + end + + declare class Bar extends Foo + prop2: string + end + )") + .root; + + REQUIRE_EQ(stat->body.size, 2); + + AstStatDeclareClass* declaredClass = stat->body.data[0]->as(); + REQUIRE(declaredClass); + CHECK(declaredClass->name == "Foo"); + CHECK(!declaredClass->superName); + + REQUIRE_EQ(declaredClass->props.size, 2); + + AstDeclaredClassProp& prop = declaredClass->props.data[0]; + CHECK(prop.name == "prop"); + CHECK(prop.ty->is()); + + AstDeclaredClassProp& method = declaredClass->props.data[1]; + CHECK(method.name == "method"); + CHECK(method.ty->is()); + + AstStatDeclareClass* subclass = stat->body.data[1]->as(); + REQUIRE(subclass); + REQUIRE(subclass->superName); + CHECK(subclass->name == "Bar"); + CHECK(*subclass->superName == "Foo"); + + REQUIRE_EQ(subclass->props.size, 1); + AstDeclaredClassProp& prop2 = subclass->props.data[0]; + CHECK(prop2.name == "prop2"); + CHECK(prop2.ty->is()); +} + +TEST_CASE_FIXTURE(Fixture, "class_method_properties") +{ + const ParseResult p1 = matchParseError(R"( + declare class Foo + -- method's first parameter must be 'self' + function method(foo: number) + function method2(self) + end + )", + "'self' must be present as the unannotated first parameter"); + + REQUIRE_EQ(1, p1.root->body.size); + + AstStatDeclareClass* klass = p1.root->body.data[0]->as(); + REQUIRE(klass != nullptr); + + CHECK_EQ(2, klass->props.size); + + const ParseResult p2 = matchParseError(R"( + declare class Foo + function method(self, foo) + function method2() + end + )", + "All declaration parameters aside from 'self' must be annotated"); + + REQUIRE_EQ(1, p2.root->body.size); + + AstStatDeclareClass* klass2 = p2.root->body.data[0]->as(); + REQUIRE(klass2 != nullptr); + + CHECK_EQ(2, klass2->props.size); +} + +TEST_CASE_FIXTURE(Fixture, "parse_variadics") +{ + //clang-format off + AstStatBlock* stat = parseEx(R"( + function foo(bar, ...: number): ...string + end + + type Foo = (string, number, ...number) -> ...boolean + type Bar = () -> (number, ...boolean) + )") + .root; + //clang-format on + + REQUIRE(stat); + REQUIRE_EQ(stat->body.size, 3); + + AstStatFunction* fn = stat->body.data[0]->as(); + REQUIRE(fn); + CHECK(fn->func->vararg); + CHECK(fn->func->varargAnnotation); + + AstStatTypeAlias* foo = stat->body.data[1]->as(); + REQUIRE(foo); + AstTypeFunction* fnFoo = foo->type->as(); + REQUIRE(fnFoo); + CHECK_EQ(fnFoo->argTypes.types.size, 2); + CHECK(fnFoo->argTypes.tailType); + CHECK_EQ(fnFoo->returnTypes.types.size, 0); + CHECK(fnFoo->returnTypes.tailType); + + AstStatTypeAlias* bar = stat->body.data[2]->as(); + REQUIRE(bar); + AstTypeFunction* fnBar = bar->type->as(); + REQUIRE(fnBar); + CHECK_EQ(fnBar->argTypes.types.size, 0); + CHECK(!fnBar->argTypes.tailType); + CHECK_EQ(fnBar->returnTypes.types.size, 1); + CHECK(fnBar->returnTypes.tailType); +} + +TEST_CASE_FIXTURE(Fixture, "variadics_must_be_last") +{ + matchParseError("function foo(): (...number, string) end", "Expected ')' (to close '(' at column 17), got ','"); + matchParseError("type Foo = (...number, string) -> (...string, number)", "Expected ')' (to close '(' at column 12), got ','"); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_definition_parsing") +{ + AstStatBlock* stat = parseEx(R"( + declare function foo(...: string): ...string + declare class Foo + function a(self, ...: string): ...string + end + )") + .root; + + REQUIRE(stat != nullptr); + + matchParseError("declare function foo(...)", "All declaration parameters must be annotated"); + matchParseError("declare class Foo function a(self, ...) end", "All declaration parameters aside from 'self' must be annotated"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") +{ + // Doesn't need LuauGenericFunctions + ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; + + ParseResult result = parseEx(R"( + function f(...: a...) + end + + type A = (a...) -> b... + )"); + + AstStatBlock* stat = result.root; + REQUIRE(stat != nullptr); + + AstStatFunction* fn = stat->body.data[0]->as(); + REQUIRE(fn != nullptr); + REQUIRE(fn->func->varargAnnotation != nullptr); + + AstTypePackGeneric* annot = fn->func->varargAnnotation->as(); + REQUIRE(annot != nullptr); + CHECK(annot->genericName == "a"); + + AstStatTypeAlias* alias = stat->body.data[1]->as(); + REQUIRE(alias != nullptr); + AstTypeFunction* fnTy = alias->type->as(); + REQUIRE(fnTy != nullptr); + + AstTypePackGeneric* argAnnot = fnTy->argTypes.tailType->as(); + REQUIRE(argAnnot != nullptr); + CHECK(argAnnot->genericName == "a"); + + AstTypePackGeneric* retAnnot = fnTy->returnTypes.tailType->as(); + REQUIRE(retAnnot != nullptr); + CHECK(retAnnot->genericName == "b"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_function_declaration_parsing") +{ + // Doesn't need LuauGenericFunctions + ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; + + ParseResult result = parseEx(R"( + declare function f() + )"); + + AstStatBlock* stat = result.root; + REQUIRE(stat != nullptr); + + AstStatDeclareFunction* decl = stat->body.data[0]->as(); + REQUIRE(decl != nullptr); + REQUIRE_EQ(decl->generics.size, 2); + REQUIRE_EQ(decl->genericPacks.size, 1); +} + +TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") +{ + { + ParseResult result = parseEx("type MyFunc = (a: number, b: string, c: number) -> string"); + + AstStatBlock* stat = result.root; + REQUIRE(stat != nullptr); + + AstStatTypeAlias* decl = stat->body.data[0]->as(); + REQUIRE(decl != nullptr); + AstTypeFunction* func = decl->type->as(); + REQUIRE(func != nullptr); + REQUIRE_EQ(func->argTypes.types.size, 3); + REQUIRE_EQ(func->argNames.size, 3); + REQUIRE(func->argNames.data[2]); + CHECK_EQ(func->argNames.data[2]->first, "c"); + } + + { + ParseResult result = parseEx("type MyFunc = (a: number, string, c: number) -> string"); + + AstStatBlock* stat = result.root; + REQUIRE(stat != nullptr); + + AstStatTypeAlias* decl = stat->body.data[0]->as(); + REQUIRE(decl != nullptr); + AstTypeFunction* func = decl->type->as(); + REQUIRE(func != nullptr); + REQUIRE_EQ(func->argTypes.types.size, 3); + REQUIRE_EQ(func->argNames.size, 3); + REQUIRE(!func->argNames.data[1]); + REQUIRE(func->argNames.data[2]); + CHECK_EQ(func->argNames.data[2]->first, "c"); + } + + { + ParseResult result = parseEx("type MyFunc = (a: number, string, number) -> string"); + + AstStatBlock* stat = result.root; + REQUIRE(stat != nullptr); + + AstStatTypeAlias* decl = stat->body.data[0]->as(); + REQUIRE(decl != nullptr); + AstTypeFunction* func = decl->type->as(); + REQUIRE(func != nullptr); + REQUIRE_EQ(func->argTypes.types.size, 3); + REQUIRE_EQ(func->argNames.size, 3); + REQUIRE(!func->argNames.data[1]); + REQUIRE(!func->argNames.data[2]); + } + + { + ParseResult result = parseEx("type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number) -> string"); + + AstStatBlock* stat = result.root; + REQUIRE(stat != nullptr); + + AstStatTypeAlias* decl = stat->body.data[0]->as(); + REQUIRE(decl != nullptr); + AstTypeFunction* func = decl->type->as(); + REQUIRE(func != nullptr); + REQUIRE_EQ(func->argTypes.types.size, 3); + REQUIRE_EQ(func->argNames.size, 3); + REQUIRE(func->argNames.data[2]); + CHECK_EQ(func->argNames.data[2]->first, "c"); + AstTypeFunction* funcRet = func->returnTypes.types.data[0]->as(); + REQUIRE(funcRet != nullptr); + REQUIRE_EQ(funcRet->argTypes.types.size, 3); + REQUIRE_EQ(funcRet->argNames.size, 3); + REQUIRE(func->argNames.data[2]); + CHECK_EQ(funcRet->argNames.data[2]->first, "f"); + } + + matchParseError("type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number)", + "Expected '->' when parsing function type, got "); + + { + ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; + ScopedFastFlag luauGenericFunctionsParserFix{"LuauGenericFunctionsParserFix", true}; + + matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); + } +} + +TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ParseErrorRecovery"); + +TEST_CASE_FIXTURE(Fixture, "multiple_parse_errors") +{ + try + { + parse(R"( +local a = 3 * ( +return a + +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(2, e.getErrors().size()); + } +} + +// check that we are not skipping tokens that weren't processed at all +TEST_CASE_FIXTURE(Fixture, "statement_error_recovery_expected") +{ + try + { + parse(R"( +function a(a, b) return a + b end +some +a(2, 5) +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(1, e.getErrors().size()); + } +} + +TEST_CASE_FIXTURE(Fixture, "statement_error_recovery_unexpected") +{ + try + { + parse(R"(+)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(1, e.getErrors().size()); + } +} + +TEST_CASE_FIXTURE(Fixture, "extra_token_in_consume") +{ + try + { + parse(R"( +function test + (a, f) return a + f end +return test(2, 3) +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(1, e.getErrors().size()); + CHECK_EQ("Expected '(' when parsing function, got '+'", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "extra_token_in_consume_match") +{ + try + { + parse(R"( +function test(a, f+) return a + f end +return test(2, 3) +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(1, e.getErrors().size()); + CHECK_EQ("Expected ')' (to close '(' at column 14), got '+'", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "extra_token_in_consume_match_end") +{ + try + { + parse(R"( +if true then + return 12 +then +end +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(1, e.getErrors().size()); + CHECK_EQ("Expected 'end' (to close 'then' at line 2), got 'then'", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "extra_table_indexer_recovery") +{ + try + { + parse(R"( +local a : { [string] : number, [number] : string, count: number } +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(1, e.getErrors().size()); + } +} + +TEST_CASE_FIXTURE(Fixture, "recovery_error_limit_1") +{ + ScopedFastInt luauParseErrorLimit("LuauParseErrorLimit", 1); + + try + { + parse("local a = "); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(1, e.getErrors().size()); + CHECK_EQ(e.getErrors().front().getMessage(), e.what()); + } +} + +TEST_CASE_FIXTURE(Fixture, "recovery_error_limit_2") +{ + ScopedFastInt luauParseErrorLimit("LuauParseErrorLimit", 2); + + try + { + parse("escape escape escape"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(3, e.getErrors().size()); + CHECK_EQ("3 parse errors", std::string(e.what())); + CHECK_EQ("Reached error limit (2)", e.getErrors().back().getMessage()); + } +} + +class CountAstNodes : public AstVisitor +{ +public: + bool visit(AstNode* node) override + { + count++; + + return true; + } + + unsigned count = 0; +}; + +TEST_CASE_FIXTURE(Fixture, "recovery_of_parenthesized_expressions") +{ + auto checkAstEquivalence = [this](const char* codeWithErrors, const char* code) { + try + { + parse(codeWithErrors); + } + catch (const Luau::ParseErrors&) + { + } + + CountAstNodes counterWithErrors; + sourceModule->root->visit(&counterWithErrors); + + parse(code); + + CountAstNodes counter; + sourceModule->root->visit(&counter); + + CHECK_EQ(counterWithErrors.count, counter.count); + }; + + auto checkRecovery = [this, checkAstEquivalence](const char* codeWithErrors, const char* code, unsigned expectedErrorCount) { + try + { + parse(codeWithErrors); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(expectedErrorCount, e.getErrors().size()); + checkAstEquivalence(codeWithErrors, code); + } + }; + + checkRecovery("function foo(a, b. c) return a + b end", "function foo(a, b) return a + b end", 1); + checkRecovery("function foo(a, b: { a: number, b: number. c:number }) return a + b end", + "function foo(a, b: { a: number, b: number }) return a + b end", 1); + + checkRecovery("function foo(a, b): (number -> number return a + b end", "function foo(a, b): (number) -> number return a + b end", 1); + checkRecovery("function foo(a, b): (number, number -> number return a + b end", "function foo(a, b): (number) -> number return a + b end", 1); + checkRecovery("function foo(a, b): (number; number) -> number return a + b end", "function foo(a, b): (number) -> number return a + b end", 1); + + checkRecovery("function foo(a, b): (number, number return a + b end", "function foo(a, b): (number, number) end", 1); + checkRecovery("local function foo(a, b): (number, number return a + b end", "local function foo(a, b): (number, number) end", 1); + + // These tests correctly recovered before the changes and we test that new recovery didn't make them worse + // (by skipping more tokens necessary) + checkRecovery("type F = (number, number -> number", "type F = (number, number) -> number", 1); + checkRecovery("function foo(a, b: { a: number, b: number) return a + b end", "function foo(a, b: { a: number, b: number }) return a + b end", 1); + checkRecovery("function foo(a, b: { [number: number}) return a + b end", "function foo(a, b: { [number]: number}) return a + b end", 1); + checkRecovery("local n: (string | number = 2", "local n: (string | number) = 2", 1); + + // Check that we correctly stop at the end of a line + checkRecovery(R"( +function foo(a, b + return a + b +end +)", + "function foo(a, b) return a + b end", 1); +} + +TEST_CASE_FIXTURE(Fixture, "incomplete_method_call") +{ + const std::string_view source = R"( + function howdy() + return game: + end + )"; + + SourceModule sourceModule; + ParseResult result = Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, {}); + + REQUIRE_EQ(1, result.root->body.size); + + AstStatFunction* howdyFunction = result.root->body.data[0]->as(); + REQUIRE(howdyFunction != nullptr); + + AstStatBlock* body = howdyFunction->func->body; + REQUIRE_EQ(1, body->body.size); + + AstStatReturn* ret = body->body.data[0]->as(); + REQUIRE(ret != nullptr); + + REQUIRE_GT(howdyFunction->location.end, body->location.end); +} + +TEST_CASE_FIXTURE(Fixture, "incomplete_method_call_2") +{ + const std::string_view source = R"( + local game = { GetService=function(s) return 'hello' end } + + function a() + game:a + end + )"; + + SourceModule sourceModule; + ParseResult result = Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, {}); + + REQUIRE_EQ(2, result.root->body.size); + + AstStatFunction* howdyFunction = result.root->body.data[1]->as(); + REQUIRE(howdyFunction != nullptr); + + AstStatBlock* body = howdyFunction->func->body; + REQUIRE_EQ(1, body->body.size); + + AstStatError* ret = body->body.data[0]->as(); + REQUIRE(ret != nullptr); + + REQUIRE_GT(howdyFunction->location.end, body->location.end); +} + +TEST_CASE_FIXTURE(Fixture, "incomplete_method_call_still_yields_an_AstExprIndexName") +{ + ParseResult result = tryParse(R"( + game: + )"); + + REQUIRE_EQ(1, result.root->body.size); + + AstStatError* stat = result.root->body.data[0]->as(); + REQUIRE(stat); + + AstExprError* expr = stat->expressions.data[0]->as(); + REQUIRE(expr); + + AstExprIndexName* indexName = expr->expressions.data[0]->as(); + REQUIRE(indexName); +} + +TEST_CASE_FIXTURE(Fixture, "recover_confusables") +{ + // Binary + matchParseError("local a = 4 != 10", "Unexpected '!=', did you mean '~='?"); + matchParseError("local a = true && false", "Unexpected '&&', did you mean 'and'?"); + matchParseError("local a = false || true", "Unexpected '||', did you mean 'or'?"); + + // Unary + matchParseError("local a = !false", "Unexpected '!', did you mean 'not'?"); + + // Check that separate tokens are not considered as a single one + matchParseError("local a = 4 ! = 10", "Expected identifier when parsing expression, got '!'"); + matchParseError("local a = true & & false", "Expected identifier when parsing expression, got '&'"); + matchParseError("local a = false | | true", "Expected identifier when parsing expression, got '|'"); +} + +TEST_CASE_FIXTURE(Fixture, "capture_comments") +{ + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx(R"( + --!strict + + local a = 5 -- comment one + local b = 8 -- comment two + --[[ + Multi line comment + ]] + local c = 'see' + )", + options); + + CHECK(result.errors.empty()); + + CHECK_EQ(4, result.commentLocations.size()); + CHECK_EQ((Location{{1, 8}, {1, 17}}), result.commentLocations[0].location); + CHECK_EQ((Location{{3, 20}, {3, 34}}), result.commentLocations[1].location); + CHECK_EQ((Location{{4, 20}, {4, 34}}), result.commentLocations[2].location); + CHECK_EQ((Location{{5, 8}, {7, 10}}), result.commentLocations[3].location); +} + +TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") +{ + ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; + + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx(R"( + --[[ + )", + options); + + CHECK_EQ(1, result.commentLocations.size()); + CHECK_EQ((Location{{1, 8}, {2, 4}}), result.commentLocations[0].location); +} + +TEST_CASE_FIXTURE(Fixture, "capture_broken_comment") +{ + ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; + + ParseOptions options; + options.captureComments = true; + + ParseResult result = tryParse(R"( + local a = "test" + + --[[broken! + )", + options); + + CHECK_EQ(1, result.commentLocations.size()); + CHECK_EQ((Location{{3, 8}, {4, 4}}), result.commentLocations[0].location); +} + +TEST_CASE_FIXTURE(Fixture, "empty_function_type_error_recovery") +{ + try + { + parse(R"( +type Fn = ( + any, + string | number | () +) -> any +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ("Expected '->' after '()' when parsing function type; did you mean 'nil'?", e.getErrors().front().getMessage()); + } + + // If we have arguments or generics, don't use special case + try + { + parse(R"(type Fn = (any, string | number | (number, number)) -> any)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ("Expected '->' when parsing function type, got ')'", e.getErrors().front().getMessage()); + } + + ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; + + try + { + parse(R"(type Fn = (any, string | number | ()) -> any)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ("Expected '->' when parsing function type, got ')'", e.getErrors().front().getMessage()); + } + + try + { + parse(R"(type Fn = (any, string | number | ()) -> any)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ("Expected '->' when parsing function type, got ')'", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "AstName_comparison") +{ + CHECK(!(AstName() < AstName())); + + AstName one{"one"}; + AstName two{"two"}; + + CHECK_NE((one < two), (two < one)); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") +{ + ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; + + try + { + parse(R"( +local function foo(a: U, ...: T...): (U, ...T) return a, ... end +return foo(1, 2 -- to check for a second error after recovery +)"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const Luau::ParseErrors& e) + { + CHECK_EQ(2, e.getErrors().size()); + CHECK_EQ("Generic types come before generic type packs", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "recover_index_name_keyword") +{ + ParseResult result = tryParse(R"( +local b +local a = b.do + )"); + CHECK_EQ(1, result.errors.size()); + + result = tryParse(R"( +local b +local a = b. +do end + )"); + CHECK_EQ(1, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "recover_self_call_keyword") +{ + ParseResult result = tryParse(R"( +local b +local a = b:do + )"); + CHECK_EQ(2, result.errors.size()); + + result = tryParse(R"( +local b +local a = b: +do end + )"); + CHECK_EQ(2, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "recover_type_index_name_keyword") +{ + ParseResult result = tryParse(R"( +local A +local b : A.do + )"); + CHECK_EQ(1, result.errors.size()); + + result = tryParse(R"( +local A +local b : A.do +do end + )"); + CHECK_EQ(1, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ + ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + + { + AstStat* stat = parse("return if true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + } + + { + AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use "else if" as opposed to elseif + { + AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use an if-else expression as the conditional expression of an if-else expression + { + AstStat* stat = parse("return if if true then false else true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + auto* nestedIfElseExpr = ifElseExpr->condition->as(); + REQUIRE(nestedIfElseExpr != nullptr); + } +} + +TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp new file mode 100644 index 0000000..bb5a93c --- /dev/null +++ b/tests/Predicate.test.cpp @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeInfer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) +{ + Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { + // TODO: normalize here also. + std::unordered_set s; + + if (auto utv = get(follow(a))) + s.insert(begin(utv), end(utv)); + else + s.insert(a); + + if (auto utv = get(follow(b))) + s.insert(begin(utv), end(utv)); + else + s.insert(b); + + std::vector options(s.begin(), s.end()); + return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); + }); +} + +TEST_SUITE_BEGIN("Predicate"); + +TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + RefinementMap m{ + {"b", typeChecker.stringType}, + {"c", typeChecker.numberType}, + }; + + RefinementMap other{ + {"a", typeChecker.stringType}, + {"b", typeChecker.stringType}, + {"c", typeChecker.booleanType}, + }; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count("a")); + REQUIRE(m.count("b")); + REQUIRE(m.count("c")); + + CHECK_EQ("string", toString(m["a"])); + CHECK_EQ("string", toString(m["b"])); + CHECK_EQ("boolean | number", toString(m["c"])); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + RefinementMap m{ + {"a", typeChecker.stringType}, + {"b", typeChecker.stringType}, + {"c", typeChecker.numberType}, + }; + + RefinementMap other{ + {"b", typeChecker.stringType}, + {"c", typeChecker.booleanType}, + }; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count("a")); + REQUIRE(m.count("b")); + REQUIRE(m.count("c")); + + CHECK_EQ("string", toString(m["a"])); + CHECK_EQ("string", toString(m["b"])); + CHECK_EQ("boolean | number", toString(m["c"])); +} + +TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + RefinementMap m{ + {"a", typeChecker.stringType}, + {"b", typeChecker.numberType}, + {"c", typeChecker.booleanType}, + }; + + RefinementMap other{ + {"c", typeChecker.stringType}, + {"d", typeChecker.numberType}, + {"e", typeChecker.booleanType}, + }; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(5, m.size()); + REQUIRE(m.count("a")); + REQUIRE(m.count("b")); + REQUIRE(m.count("c")); + REQUIRE(m.count("d")); + REQUIRE(m.count("e")); + + CHECK_EQ("string", toString(m["a"])); + CHECK_EQ("number", toString(m["b"])); + CHECK_EQ("boolean | string", toString(m["c"])); + CHECK_EQ("number", toString(m["d"])); + CHECK_EQ("boolean", toString(m["e"])); +} + +TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp new file mode 100644 index 0000000..cbd4af2 --- /dev/null +++ b/tests/RequireTracer.test.cpp @@ -0,0 +1,221 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/RequireTracer.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +namespace +{ + +struct RequireTracerFixture +{ + RequireTracerFixture() + : allocator() + , names(allocator) + { + } + + AstStatBlock* parse(std::string_view src) + { + ParseResult result = Parser::parse(src.data(), src.size(), names, allocator, ParseOptions{}); + if (!result.errors.empty()) + { + std::string message; + + for (const auto& error : result.errors) + { + if (!message.empty()) + message += "\n"; + + message += error.what(); + } + + printf("Parse error: %s\n", message.c_str()); + return nullptr; + } + else + return result.root; + } + + Allocator allocator; + AstNameTable names; + + Luau::TestFileResolver fileResolver; +}; + +const std::vector roots = {"game", "Game", "workspace", "Workspace", "script"}; + +} // namespace + +TEST_SUITE_BEGIN("RequireTracerTest"); + +TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") +{ + AstStatBlock* block = parse(R"( + local m = workspace.Foo.Bar.Baz + )"); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + REQUIRE(!result.exprs.empty()); + + AstStatLocal* loc = block->body.data[0]->as(); + REQUIRE(loc); + REQUIRE_EQ(1, loc->vars.size); + REQUIRE_EQ(1, loc->values.size); + + AstExprIndexName* value = loc->values.data[0]->as(); + REQUIRE(value); + REQUIRE(result.exprs.contains(value)); + CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]); + + value = value->expr->as(); + REQUIRE(value); + REQUIRE(result.exprs.contains(value)); + CHECK_EQ("workspace/Foo/Bar", result.exprs[value]); + + value = value->expr->as(); + REQUIRE(value); + REQUIRE(result.exprs.contains(value)); + CHECK_EQ("workspace/Foo", result.exprs[value]); + + AstExprGlobal* workspace = value->expr->as(); + REQUIRE(workspace); + REQUIRE(result.exprs.contains(workspace)); + CHECK_EQ("workspace", result.exprs[workspace]); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") +{ + AstStatBlock* block = parse(R"( + local m = workspace.Foo.Bar.Baz + local n = m.Quux + )"); + + REQUIRE_EQ(2, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[1]->as(); + REQUIRE(local); + REQUIRE_EQ(1, local->vars.size); + + REQUIRE(result.exprs.contains(local->values.data[0])); + CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") +{ + AstStatBlock* block = parse(R"( + local M = require(workspace.Game.Thing, workspace.Something.Else) + )"); + REQUIRE_EQ(1, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + REQUIRE_EQ(1, local->vars.size); + REQUIRE_EQ(1, local->values.size); + + AstExprCall* call = local->values.data[0]->as(); + REQUIRE(call != nullptr); + + REQUIRE_EQ(2, call->args.size); + + CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]); + CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls") +{ + AstStatBlock* block = parse(R"( + local R = game:GetService('ReplicatedStorage').Roact + local Roact = require(R) + )"); + REQUIRE_EQ(2, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + + CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]); + + AstStatLocal* local2 = block->body.data[1]->as(); + REQUIRE(local2 != nullptr); + REQUIRE_EQ(1, local2->values.size); + + AstExprCall* call = local2->values.data[0]->as(); + REQUIRE(call != nullptr); + REQUIRE_EQ(1, call->args.size); + + CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls") +{ + ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true); + + AstStatBlock* block = parse(R"( +local A = require(workspace:WaitForChild('ReplicatedStorage').Content) +local B = require(workspace:FindFirstChild('ReplicatedFirst').Data) + )"); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + REQUIRE_EQ(2, result.requires.size()); + CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first); + CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") +{ + AstStatBlock* block = parse(R"( + local R: typeof(require(workspace.CoolThing).UsefulObject) + )"); + REQUIRE_EQ(1, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + + REQUIRE_EQ(local->vars.size, 1); + + AstType* ann = local->vars.data[0]->annotation; + REQUIRE(ann != nullptr); + + AstTypeTypeof* typeofAnnotation = ann->as(); + REQUIRE(typeofAnnotation != nullptr); + + AstExprIndexName* indexName = typeofAnnotation->expr->as(); + REQUIRE(indexName != nullptr); + REQUIRE_EQ(indexName->index, "UsefulObject"); + + AstExprCall* call = indexName->expr->as(); + REQUIRE(call != nullptr); + REQUIRE_EQ(1, call->args.size); + + CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]); +} + +TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr") +{ + AstStatBlock* block = parse(R"( + local R = game["Test"] + )"); + REQUIRE_EQ(1, block->body.size); + + RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); + + AstStatLocal* local = block->body.data[0]->as(); + REQUIRE(local != nullptr); + + CHECK_EQ("game/Test", result.exprs[local->values.data[0]]); +} + +TEST_SUITE_END(); diff --git a/tests/ScopedFlags.h b/tests/ScopedFlags.h new file mode 100644 index 0000000..9454307 --- /dev/null +++ b/tests/ScopedFlags.h @@ -0,0 +1,59 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +template +struct ScopedFValue +{ +private: + Luau::FValue* value = nullptr; + T oldValue = T(); + +public: + ScopedFValue(const char* name, T newValue) + { + for (Luau::FValue* v = Luau::FValue::list; v; v = v->next) + if (strcmp(v->name, name) == 0) + { + value = v; + oldValue = v->value; + v->value = newValue; + break; + } + + LUAU_ASSERT(value); + } + + ScopedFValue(const ScopedFValue&) = delete; + ScopedFValue& operator=(const ScopedFValue&) = delete; + + ScopedFValue(ScopedFValue&& rhs) + { + value = rhs.value; + oldValue = rhs.oldValue; + + rhs.value = nullptr; + } + + ScopedFValue& operator=(ScopedFValue&& rhs) + { + value = rhs.value; + oldValue = rhs.oldValue; + + rhs.value = nullptr; + + return *this; + } + + ~ScopedFValue() + { + if (value) + value->value = oldValue; + } +}; + +using ScopedFastFlag = ScopedFValue; +using ScopedFastInt = ScopedFValue; diff --git a/tests/StringUtils.test.cpp b/tests/StringUtils.test.cpp new file mode 100644 index 0000000..afef3b0 --- /dev/null +++ b/tests/StringUtils.test.cpp @@ -0,0 +1,109 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/StringUtils.h" + +#include "doctest.h" + +#include + +namespace +{ +using LevenshteinMatrix = std::vector>; + +std::string format(std::string_view a, std::string_view b, size_t expected, size_t actual) +{ + return "Distance of '" + std::string(a) + "' and '" + std::string(b) + "': expected " + std::to_string(expected) + ", got " + + std::to_string(actual); +} + +// Each call to this function is not one test, but instead actually runs tests (A.size() * B.size()) + 2 times. +void compareLevenshtein(LevenshteinMatrix distances, std::string_view a, std::string_view b) +{ + for (size_t x = 0; x <= a.size(); ++x) + { + for (size_t y = 0; y <= b.size(); ++y) + { + std::string_view currentA = a.substr(0, x); + std::string_view currentB = b.substr(0, y); + + size_t actual = Luau::editDistance(currentA, currentB); + size_t expected = distances[x][y]; + CHECK_MESSAGE(actual == expected, format(currentA, currentB, expected, actual)); + } + } +} +} // namespace + +TEST_SUITE_BEGIN("StringUtilsTest"); + +#if 0 +// This unit test is only used to measure how performant the current levenshtein distance algorithm is. +// It is entirely ok to submit this, but keep #if 0. +TEST_CASE("BenchmarkLevenshteinDistance") +{ + // For reference: running this benchmark on a Macbook Pro 16 takes ~250ms. + + int count = 1'000'000; + + // specifically chosen because they: + // - are real words, + // - have common prefix and suffix, and + // - are sufficiently long enough to stress test with + std::string_view a("Intercalate"); + std::string_view b("Interchangeable"); + + auto start = std::chrono::steady_clock::now(); + + for (int i = 0; i < count; ++i) + Luau::editDistance(a, b); + + auto end = std::chrono::steady_clock::now(); + auto time = std::chrono::duration_cast(end - start); + + std::cout << "Running levenshtein distance " << count << " times took " << time.count() << "ms" << std::endl; +} +#endif + +TEST_CASE("LevenshteinDistanceKittenSitting") +{ + LevenshteinMatrix distances{ + {0, 1, 2, 3, 4, 5, 6, 7}, // S I T T I N G + {1, 1, 2, 3, 4, 5, 6, 7}, // K + {2, 2, 1, 2, 3, 4, 5, 6}, // I + {3, 3, 2, 1, 2, 3, 4, 5}, // T + {4, 4, 3, 2, 1, 2, 3, 4}, // T + {5, 5, 4, 3, 2, 2, 3, 4}, // E + {6, 6, 5, 4, 3, 3, 2, 3}, // N + }; + + compareLevenshtein(distances, "kitten", "sitting"); +} + +TEST_CASE("LevenshteinDistanceSaturdaySunday") +{ + LevenshteinMatrix distances{ + {0, 1, 2, 3, 4, 5, 6}, // S U N D A Y + {1, 0, 1, 2, 3, 4, 5}, // S + {2, 1, 1, 2, 3, 3, 4}, // A + {3, 2, 2, 2, 3, 4, 4}, // T + {4, 3, 2, 3, 3, 4, 5}, // U + {5, 4, 3, 3, 4, 4, 5}, // R + {6, 5, 4, 4, 3, 4, 5}, // D + {7, 6, 5, 5, 4, 3, 4}, // A + {8, 7, 6, 6, 5, 4, 3}, // Y + }; + + compareLevenshtein(distances, "saturday", "sunday"); +} + +TEST_CASE("EditDistanceIsAgnosticOfArgumentOrdering") +{ + CHECK_EQ(Luau::editDistance("blox", "block"), Luau::editDistance("block", "blox")); +} + +TEST_CASE("AreWeUsingDistanceWithAdjacentTranspositionsAndNotOptimalStringAlignment") +{ + size_t distance = Luau::editDistance("CA", "ABC"); + CHECK_EQ(distance, 2); +} + +TEST_SUITE_END(); diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp new file mode 100644 index 0000000..44fe3a3 --- /dev/null +++ b/tests/Symbol.test.cpp @@ -0,0 +1,40 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Symbol.h" +#include "Luau/Ast.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("SymbolTests"); + +TEST_CASE("hashing") +{ + std::string s1 = "name"; + std::string s2 = "name"; + + // These two names point to distinct memory areas. + AstName one{s1.data()}; + AstName two{s2.data()}; + + Symbol n1{one}; + Symbol n2{two}; + + CHECK(n1 == n1); + CHECK(n1 == n2); + CHECK(n2 == n2); + + CHECK_EQ(std::hash()(one), std::hash()(one)); + CHECK_EQ(std::hash()(one), std::hash()(two)); + CHECK_EQ(std::hash()(two), std::hash()(two)); + + std::unordered_map theMap; + theMap[AstName{s1.data()}] = 5; + theMap[AstName{s2.data()}] = 1; + + REQUIRE_EQ(1, theMap.size()); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp new file mode 100644 index 0000000..d7d68c4 --- /dev/null +++ b/tests/ToString.test.cpp @@ -0,0 +1,482 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/ToString.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("ToString"); + +TEST_CASE_FIXTURE(Fixture, "primitive") +{ + CheckResult result = check("local a = nil local b = 44 local c = 'lalala' local d = true"); + LUAU_REQUIRE_NO_ERRORS(result); + + // A variable without an annotation and with a nil literal should infer as 'free', not 'nil' + CHECK_NE("nil", toString(requireType("a"))); + + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("string", toString(requireType("c"))); + CHECK_EQ("boolean", toString(requireType("d"))); +} + +TEST_CASE_FIXTURE(Fixture, "bound_types") +{ + CheckResult result = check("local a = 444 local b = a"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "free_types") +{ + CheckResult result = check("local a"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("a", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_table") +{ + TypeVar cyclicTable{TypeVariant(TableTypeVar())}; + TableTypeVar* tableOne = getMutable(&cyclicTable); + tableOne->props["self"] = {&cyclicTable}; + + CHECK_EQ("t1 where t1 = { self: t1 }", toString(&cyclicTable)); +} + +TEST_CASE_FIXTURE(Fixture, "named_table") +{ + TypeVar table{TypeVariant(TableTypeVar())}; + TableTypeVar* t = getMutable(&table); + t->name = "TheTable"; + + CHECK_EQ("TheTable", toString(&table)); +} + +TEST_CASE_FIXTURE(Fixture, "exhaustive_toString_of_cyclic_table") +{ + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + )"); + + std::string a = toString(requireType("a"), {true}); + + CHECK_EQ(std::string::npos, a.find("CYCLE")); + CHECK_EQ(std::string::npos, a.find("TRUNCATED")); + + //clang-format off + CHECK_EQ("t2 where " + "t1 = { __index: t1, __mul: ((t2, number) -> t2) & ((t2, t2) -> t2), new: () -> t2 } ; " + "t2 = { @metatable t1, {| x: number, y: number, z: number |} }", + a); + //clang-format on +} + + +TEST_CASE_FIXTURE(Fixture, "intersection_parenthesized_only_if_needed") +{ + auto utv = TypeVar{UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}}; + auto itv = TypeVar{IntersectionTypeVar{{&utv, typeChecker.booleanType}}}; + + CHECK_EQ(toString(&itv), "(number | string) & boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "union_parenthesized_only_if_needed") +{ + auto itv = TypeVar{IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}}; + auto utv = TypeVar{UnionTypeVar{{&itv, typeChecker.booleanType}}}; + + CHECK_EQ(toString(&utv), "(number & string) | boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_intersections") +{ + auto stringAndNumberPack = TypePackVar{TypePack{{typeChecker.stringType, typeChecker.numberType}}}; + auto numberAndStringPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.stringType}}}; + + auto sn2ns = TypeVar{FunctionTypeVar{&stringAndNumberPack, &numberAndStringPack}}; + auto ns2sn = TypeVar{FunctionTypeVar(typeChecker.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; + + auto utv = TypeVar{UnionTypeVar{{&ns2sn, &sn2ns}}}; + auto itv = TypeVar{IntersectionTypeVar{{&ns2sn, &sn2ns}}}; + + CHECK_EQ(toString(&utv), "((number, string) -> (string, number)) | ((string, number) -> (number, string))"); + CHECK_EQ(toString(&itv), "((number, string) -> (string, number)) & ((string, number) -> (number, string))"); +} + +TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded") +{ + TableTypeVar ttv{}; + for (char c : std::string("abcdefghijklmno")) + ttv.props[std::string(1, c)] = {typeChecker.numberType}; + + TypeVar tv{ttv}; + + ToStringOptions o; + o.exhaustive = false; + o.maxTableLength = 40; + CHECK_EQ(toString(&tv, o), "{ a: number, b: number, c: number, d: number, e: number, ... 10 more ... }"); +} + +TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_is_still_capped_when_exhaustive") +{ + TableTypeVar ttv{}; + for (char c : std::string("abcdefg")) + ttv.props[std::string(1, c)] = {typeChecker.numberType}; + + TypeVar tv{ttv}; + + ToStringOptions o; + o.exhaustive = true; + o.maxTableLength = 40; + CHECK_EQ(toString(&tv, o), "{ a: number, b: number, c: number, d: number, e: number, ... 2 more ... }"); +} + +TEST_CASE_FIXTURE(Fixture, "quit_stringifying_type_when_length_is_exceeded") +{ + CheckResult result = check(R"( + function f0() end + function f1(f) return f or f0 end + function f2(f) return f or f1 end + function f3(f) return f or f2 end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = false; + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); +} + +TEST_CASE_FIXTURE(Fixture, "stringifying_type_is_still_capped_when_exhaustive") +{ + CheckResult result = check(R"( + function f0() end + function f1(f) return f or f0 end + function f2(f) return f or f1 end + function f3(f) return f or f2 end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; + o.maxTypeLength = 40; + CHECK_EQ(toString(requireType("f0"), o), "() -> ()"); + CHECK_EQ(toString(requireType("f1"), o), "(() -> ()) -> () -> ()"); + CHECK_EQ(toString(requireType("f2"), o), "((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); + CHECK_EQ(toString(requireType("f3"), o), "(((() -> ()) -> () -> ()) -> (() -> ()) -> ... "); +} + +TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table_state_braces") +{ + TableTypeVar ttv{TableState::Sealed, TypeLevel{}}; + for (char c : std::string("abcdefghij")) + ttv.props[std::string(1, c)] = {typeChecker.numberType}; + + TypeVar tv{ttv}; + + ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; + CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); +} + +TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_union_type_bails_early") +{ + TypeVar tv{UnionTypeVar{{typeChecker.stringType, typeChecker.numberType}}}; + UnionTypeVar* utv = getMutable(&tv); + utv->options.push_back(&tv); + utv->options.push_back(&tv); + + CHECK_EQ("t1 where t1 = number | string", toString(&tv)); +} + +TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_intersection_type_bails_early") +{ + TypeVar tv{IntersectionTypeVar{}}; + IntersectionTypeVar* itv = getMutable(&tv); + itv->parts.push_back(&tv); + itv->parts.push_back(&tv); + + CHECK_EQ("t1 where t1 = t1 & t1", toString(&tv)); +} + +TEST_CASE_FIXTURE(Fixture, "stringifying_array_uses_array_syntax") +{ + TableTypeVar ttv{TableState::Sealed, TypeLevel{}}; + ttv.indexer = TableIndexer{typeChecker.numberType, typeChecker.stringType}; + + CHECK_EQ("{string}", toString(TypeVar{ttv})); + + ttv.props["A"] = {typeChecker.numberType}; + CHECK_EQ("{| [number]: string, A: number |}", toString(TypeVar{ttv})); + + ttv.props.clear(); + ttv.state = TableState::Unsealed; + CHECK_EQ("{string}", toString(TypeVar{ttv})); +} + + +TEST_CASE_FIXTURE(Fixture, "generic_packs_are_stringified_differently_from_generic_types") +{ + TypePackVar tpv{GenericTypePack{"a"}}; + CHECK_EQ(toString(&tpv), "a..."); + + TypeVar tv{GenericTypeVar{"a"}}; + CHECK_EQ(toString(&tv), "a"); +} + +TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names") +{ + CheckResult result = check("type MyFunc = (a: number, string, c: number) -> string; local a : MyFunc"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.functionTypeArguments = true; + CHECK_EQ("(a: number, string, c: number) -> string", toString(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names_generic") +{ + ScopedFastFlag luauGenericFunctions{"LuauGenericFunctions", true}; + ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; + + CheckResult result = check("local function f(n: number, ...: a...): (a...) return ... end"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.functionTypeArguments = true; + CHECK_EQ("(n: number, a...) -> (a...)", toString(requireType("f"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names_and_self") +{ + CheckResult result = check(R"( +local tbl = {} +tbl.a = 2 +function tbl:foo(b: number, c: number) return (self.a :: number) + b + c end +type Table = typeof(tbl) +type Foo = typeof(tbl.foo) +local u: Foo +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.functionTypeArguments = true; + // Can't guess the name of 'self' to compare name, but at least there should be no assertion + toString(requireType("u"), opts); +} + +TEST_CASE_FIXTURE(Fixture, "generate_friendly_names_for_inferred_generics") +{ + CheckResult result = check(R"( + function id(x) return x end + + function id2(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30) + return a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19, a20, a21, a22, a23, a24, a25, a26, a27, a28, a29, a30 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(a) -> a", toString(requireType("id"))); + + CHECK_EQ("(a, b, c, d, e, f, g, h, i, j, k, l, " + "m, n, o, p, q, r, s, t, u, v, w, x, y, z, a1, b1, c1, d1) -> (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, " + "x, y, z, a1, b1, c1, d1)", + toString(requireType("id2"))); +} + +TEST_CASE_FIXTURE(Fixture, "toStringDetailed") +{ + CheckResult result = check(R"( + function id3(a, b, c) + return a, b, c + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId id3Type = requireType("id3"); + ToStringResult nameData = toStringDetailed(id3Type); + + REQUIRE_EQ(3, nameData.nameMap.typeVars.size()); + REQUIRE_EQ("(a, b, c) -> (a, b, c)", nameData.name); + + ToStringOptions opts; + opts.nameMap = std::move(nameData.nameMap); + + const FunctionTypeVar* ftv = get(follow(id3Type)); + REQUIRE(ftv != nullptr); + + auto params = flatten(ftv->argTypes).first; + REQUIRE_EQ(3, params.size()); + + REQUIRE_EQ("a", toString(params[0], opts)); + REQUIRE_EQ("b", toString(params[1], opts)); + REQUIRE_EQ("c", toString(params[2], opts)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") +{ + ScopedFastFlag sff[] = { + {"LuauGenericFunctions", true}, + }; + + CheckResult result = check(R"( + local base = {} + function base:one() return 1 end + + local child = {} + setmetatable(child, {__index=base}) + function child:two() return 2 end + + local inst = {} + setmetatable(inst, {__index=child}) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId tType = requireType("inst"); + ToStringResult r = toStringDetailed(tType); + CHECK_EQ("{ @metatable {| __index: { @metatable {| __index: base |}, child } |}, inst }", r.name); + CHECK_EQ(0, r.nameMap.typeVars.size()); + + ToStringOptions opts; + opts.nameMap = r.nameMap; + + const MetatableTypeVar* tMeta = get(tType); + REQUIRE(tMeta); + + TableTypeVar* tMeta2 = getMutable(tMeta->metatable); + REQUIRE(tMeta2); + REQUIRE(tMeta2->props.count("__index")); + + const MetatableTypeVar* tMeta3 = get(tMeta2->props["__index"].type); + REQUIRE(tMeta3); + + TableTypeVar* tMeta4 = getMutable(tMeta3->metatable); + REQUIRE(tMeta4); + REQUIRE(tMeta4->props.count("__index")); + + TableTypeVar* tMeta5 = getMutable(tMeta4->props["__index"].type); + REQUIRE(tMeta5); + + TableTypeVar* tMeta6 = getMutable(tMeta3->table); + REQUIRE(tMeta6); + + ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); + opts.nameMap = oneResult.nameMap; + + std::string twoResult = toString(tMeta6->props["two"].type, opts); + + REQUIRE_EQ("(a) -> number", oneResult.name); + REQUIRE_EQ("(b) -> number", twoResult); +} + + +TEST_CASE_FIXTURE(Fixture, "toStringErrorPack") +{ + CheckResult result = check(R"( +local function target(callback: nil) return callback(4, "hello") end + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); +} + +TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") +{ + CheckResult result = check(R"( +function foo(a, b) return a(b) end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("foo")), "((a) -> (b...), a) -> (b...)"); +} + +TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack") +{ + ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true}; + + TypeVar tv1{TableTypeVar{}}; + TableTypeVar* ttv = getMutable(&tv1); + ttv->state = TableState::Sealed; + ttv->props["hello"] = {typeChecker.numberType}; + ttv->props["world"] = {typeChecker.numberType}; + + TypePackVar tpv1{TypePack{{&tv1}}}; + + TypeVar tv2{TableTypeVar{}}; + TableTypeVar* bttv = getMutable(&tv2); + bttv->state = TableState::Free; + bttv->props["hello"] = {typeChecker.numberType}; + bttv->boundTo = &tv1; + + TypePackVar tpv2{TypePack{{&tv2}}}; + + CHECK_EQ("{| hello: number, world: number |}", toString(&tpv1)); + CHECK_EQ("{| hello: number, world: number |}", toString(&tpv2)); +} + +TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type F = ((() -> number)?) -> F? + local function f(p) return f end + local g: F = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("t1 where t1 = ((() -> number)?) -> t1?", toString(requireType("g"))); +} + +TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_intersection") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + function f() return f end + local a: ((number) -> ()) & typeof(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((number) -> ()) & t1 where t1 = () -> t1", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") +{ + ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; + + TypeVar tableTy{TableTypeVar{}}; + TableTypeVar* ttv = getMutable(&tableTy); + ttv->name = "Table"; + ttv->instantiatedTypeParams.push_back(&tableTy); + + CHECK_EQ(toString(tableTy), "Table"); +} + +TEST_SUITE_END(); diff --git a/tests/TopoSort.test.cpp b/tests/TopoSort.test.cpp new file mode 100644 index 0000000..9b99086 --- /dev/null +++ b/tests/TopoSort.test.cpp @@ -0,0 +1,413 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TopoSortStatements.h" +#include "Luau/Transpiler.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +static std::vector toposort(AstStatBlock& block) +{ + std::vector result{block.body.begin(), block.body.end()}; + + Luau::toposort(result); + + return result; +} + +TEST_SUITE_BEGIN("TopoSortTests"); + +TEST_CASE_FIXTURE(Fixture, "sorts") +{ + AstStatBlock* program = parse(R"( + function A() + return B("high five!") + end + + function B(x) + return x + end + )"); + + auto sorted = toposort(*program); + REQUIRE_EQ(2, sorted.size()); + + AstStatBlock* block = program->as(); + REQUIRE(block != nullptr); + REQUIRE_EQ(2, block->body.size); + + // B is sorted ahead of A + CHECK_EQ(block->body.data[1], sorted[0]); + CHECK_EQ(block->body.data[0], sorted[1]); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_dependency_terminates") +{ + AstStatBlock* program = parse(R"( + function A() + return B() + end + + function B() + return A() + end + )"); + + auto sorted = toposort(*program); + REQUIRE_EQ(2, sorted.size()); +} + +TEST_CASE_FIXTURE(Fixture, "doesnt_omit_statements_that_dont_need_sorting") +{ + AstStatBlock* program = parse(R"( + local X = {} + + function A() + return B(5), B("Hi") + end + + local Y = {} + + function B(x) + return x + end + + local Z = B() + )"); + + auto sorted = toposort(*program); + REQUIRE_EQ(5, sorted.size()); + + AstStatBlock* block = program->as(); + REQUIRE(block != nullptr); + REQUIRE_EQ(5, block->body.size); + + AstStat* X = block->body.data[0]; + AstStat* A = block->body.data[1]; + AstStat* Y = block->body.data[2]; + AstStat* B = block->body.data[3]; + AstStat* Z = block->body.data[4]; + + CHECK_EQ(sorted[0], X); + CHECK_EQ(sorted[1], Y); + CHECK_EQ(sorted[2], B); + CHECK_EQ(sorted[3], Z); + CHECK_EQ(sorted[4], A); +} + +TEST_CASE_FIXTURE(Fixture, "slightly_more_complex") +{ + AstStatBlock* program = parse(R"( + local T = {} + + function T:foo() + return T:bar(999), T:bar("hi") + end + + function T:bar(i) + return i + end + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(3, sorted.size()); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[1], program->body.data[2]); + CHECK_EQ(sorted[2], program->body.data[1]); +} + +TEST_CASE_FIXTURE(Fixture, "reorder_functions_after_dependent_assigns") +{ + AstStatBlock* program = parse(R"( + local T = {} -- 0 + + function T.a() -- 1 depends on (2) + T.b() + end + + function T.b() -- 2 depends on (4) + T.c() + end + + function make_function() -- 3 + return function() end + end + + T.c = make_function() -- 4 depends on (3) + + T.a() -- 5 depends on (1) + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(6, sorted.size()); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[1], program->body.data[3]); + CHECK_EQ(sorted[2], program->body.data[4]); + CHECK_EQ(sorted[3], program->body.data[2]); + CHECK_EQ(sorted[4], program->body.data[1]); + CHECK_EQ(sorted[5], program->body.data[5]); +} + +TEST_CASE_FIXTURE(Fixture, "dont_reorder_assigns") +{ + AstStatBlock* program = parse(R"( + local T = {} -- 0 + + function T.a() -- 1 depends on (2) + T.b() + end + + function T.b() -- 2 depends on (5) + T.c() + end + + function make_function() -- 3 + return function() end + end + + T.a() -- 4 depends on (1 -> 2 -> 5), but we cannot reorder it after 5! + + T.c = make_function() -- 5 depends on (3) + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(6, sorted.size()); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[1], program->body.data[3]); + CHECK_EQ(sorted[2], program->body.data[2]); + CHECK_EQ(sorted[3], program->body.data[1]); + CHECK_EQ(sorted[4], program->body.data[4]); + CHECK_EQ(sorted[5], program->body.data[5]); +} + +TEST_CASE_FIXTURE(Fixture, "dont_reorder_function_after_assignment_to_global") +{ + AstStatBlock* program = parse(R"( + local f + + function g() + f() + end + + f = function() end + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(3, sorted.size()); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[1], program->body.data[1]); + CHECK_EQ(sorted[2], program->body.data[2]); +} + +TEST_CASE_FIXTURE(Fixture, "local_functions_need_sorting_too") +{ + AstStatBlock* program = parse(R"( + local a = nil -- 0 + + local function f() -- 1 depends on 4 + a.c = 4 + end + + local function g() -- 2 depends on 1 + f() + end + + a = {} -- 3 + a.c = nil -- 4 + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(5, sorted.size()); + + CHECK_EQ(sorted[0], program->body.data[0]); + CHECK_EQ(sorted[1], program->body.data[3]); + CHECK_EQ(sorted[2], program->body.data[4]); + CHECK_EQ(sorted[3], program->body.data[1]); + CHECK_EQ(sorted[4], program->body.data[2]); +} + +TEST_CASE_FIXTURE(Fixture, "dont_force_checking_until_an_AstExprCall_needs_the_symbol") +{ + AstStatBlock* program = parse(R"( + function A(obj) + C(obj) + end + + local B = A -- It would be an error to force checking of A at this point just because the definition of B is an imperative + + function C(player) + end + + local D = A(nil) -- The real dependency on A is here, where A is invoked. + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(4, sorted.size()); + + auto A = program->body.data[0]; + auto B = program->body.data[1]; + auto C = program->body.data[2]; + auto D = program->body.data[3]; + + CHECK_EQ(sorted[0], C); + CHECK_EQ(sorted[1], A); + CHECK_EQ(sorted[2], B); + CHECK_EQ(sorted[3], D); +} + +TEST_CASE_FIXTURE(Fixture, "dont_reorder_imperatives") +{ + AstStatBlock* program = parse(R"( + local temp = work + work = arr + arr = temp + width = width * 2 + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(4, sorted.size()); +} + +TEST_CASE_FIXTURE(Fixture, "sort_typealias_first") +{ + AstStatBlock* program = parse(R"( + local foo: A = 1 + type A = number + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(2, sorted.size()); + + auto A = program->body.data[0]; + auto B = program->body.data[1]; + + CHECK_EQ(sorted[0], B); + CHECK_EQ(sorted[1], A); +} + +TEST_CASE_FIXTURE(Fixture, "typealias_of_typeof_is_not_sorted") +{ + AstStatBlock* program = parse(R"( + type Foo = typeof(foo) + local function foo(x: number) end + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(2, sorted.size()); + + auto A = program->body.data[0]; + auto B = program->body.data[1]; + + CHECK_EQ(sorted[0], A); + CHECK_EQ(sorted[1], B); +} + +TEST_CASE_FIXTURE(Fixture, "nested_type_annotations_depends_on_later_typealiases") +{ + AstStatBlock* program = parse(R"( + type Foo = A | B + type B = number + type A = string + )"); + + auto sorted = toposort(*program); + + REQUIRE_EQ(3, sorted.size()); + + auto Foo = program->body.data[0]; + auto B = program->body.data[1]; + auto A = program->body.data[2]; + + CHECK_EQ(sorted[0], B); + CHECK_EQ(sorted[1], A); + CHECK_EQ(sorted[2], Foo); +} + +TEST_CASE_FIXTURE(Fixture, "return_comes_last") +{ + CheckResult result = check(R"( +export type Module = { bar: (number) -> boolean, foo: () -> string } + +return function() : Module + local module = {} + + local function confuseCompiler() return module.foo() end + + module.foo = function() return "" end + + function module.bar(x:number) + confuseCompiler() + return true + end + + return module +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "break_comes_last") +{ + AstStatBlock* program = parse(R"( +repeat +local module = {} +local function confuseCompiler() return module.foo() end +module.foo = function() return "" end +break +until true + )"); + + REQUIRE(program->body.size == 1); + + auto repeat = program->body.data[0]->as(); + REQUIRE(repeat); + + REQUIRE(repeat->body->body.size == 4); + + auto sorted = toposort(*repeat->body); + + REQUIRE(sorted.size() == 4); + CHECK_EQ(sorted[3], repeat->body->body.data[3]); +} + +TEST_CASE_FIXTURE(Fixture, "continue_comes_last") +{ + AstStatBlock* program = parse(R"( +repeat +local module = {} +local function confuseCompiler() return module.foo() end +module.foo = function() return "" end +continue +until true + )"); + + REQUIRE(program->body.size == 1); + + auto repeat = program->body.data[0]->as(); + REQUIRE(repeat); + + REQUIRE(repeat->body->body.size == 4); + + auto sorted = toposort(*repeat->body); + + REQUIRE(sorted.size() == 4); + CHECK_EQ(sorted[3], repeat->body->body.data[3]); +} + +TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp new file mode 100644 index 0000000..bfff60f --- /dev/null +++ b/tests/Transpiler.test.cpp @@ -0,0 +1,403 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeAttach.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/Transpiler.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TranspilerTests"); + +TEST_CASE("test_1") +{ + const std::string example = R"( +local function isPortal(element) + if type(element)~='table'then + return false + end + + return element.component==Core.Portal +end +)"; + + CHECK_EQ(example, transpile(example).code); +} + +TEST_CASE("string_literals") +{ + const std::string code = R"( local S='abcdef\n\f\a\020' )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("string_literals_containing_utf8") +{ + const std::string code = R"( local S='lalala こんにちは' )"; // Konichiwa! + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("elseif_chains_indent_sensibly") +{ + const std::string code = R"( + if This then + Once() + elseif That then + Another() + elseif SecondLast then + Third() + else + IfAllElseFails() + end + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("strips_type_annotations") +{ + const std::string code = R"( local s: string= 'hello there' )"; + const std::string expected = R"( local s = 'hello there' )"; + + CHECK_EQ(expected, transpile(code).code); +} + +TEST_CASE("strips_type_assertion_expressions") +{ + const std::string code = R"( local s= some_function() :: any+ something_else() :: number )"; + const std::string expected = R"( local s= some_function() + something_else() )"; + + CHECK_EQ(expected, transpile(code).code); +} + +TEST_CASE("function_taking_ellipsis") +{ + const std::string code = R"( function F(...) end )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("omit_decimal_place_for_integers") +{ + const std::string code = R"( local a=5, 6, 7, 3.141, 1.1290000000000002e+45 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("for_loop") +{ + const std::string one = R"( for i=1,10 do end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string code = R"( for i=5,6,7 do end )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("for_in_loop") +{ + const std::string code = R"( for k, v in ipairs(x)do end )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("while_loop") +{ + const std::string code = R"( while f(x)do print() end )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("repeat_until_loop") +{ + const std::string code = R"( repeat print() until f(x) )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("lambda") +{ + const std::string one = R"( local p=function(o, m, g) return 77 end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local p=function(o, m, g,...) return 77 end )"; + CHECK_EQ(two, transpile(two).code); +} + +TEST_CASE("local_function") +{ + const std::string one = R"( local function p(o, m, g) return 77 end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( local function p(o, m, g,...) return 77 end )"; + CHECK_EQ(two, transpile(two).code); +} + +TEST_CASE("function") +{ + const std::string one = R"( function p(o, m, g) return 77 end )"; + CHECK_EQ(one, transpile(one).code); + + const std::string two = R"( function p(o, m, g,...) return 77 end )"; + CHECK_EQ(two, transpile(two).code); +} + +TEST_CASE("table_literals") +{ + const std::string code = R"( local t={1, 2, 3, foo='bar', baz=99,[5.5]='five point five', 'end'} )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("more_table_literals") +{ + const std::string code = R"( local t={['Content-Type']='text/plain'} )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_preserves_record_vs_general") +{ + const std::string code = R"( local t={['foo']='bar',quux=42} )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_numeric_key") +{ + const std::string code = R"( local t={[5]='five',[6]='six'} )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_with_keyword_key") +{ + const std::string code = R"( local t={['nil']=nil,['true']=true} )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("table_literal_closing_brace_at_correct_position") +{ + const std::string code = R"( + local t={ + eggs='Tasty', + avocado='more like awesomecavo amirite' + } + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("method_calls") +{ + const std::string code = R"( foo.bar.baz:quux() )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("method_definitions") +{ + const std::string code = R"( function foo.bar.baz:quux() end )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("spaces_between_keywords_even_if_it_pushes_the_line_estimation_off") +{ + // Luau::Parser doesn't exactly preserve the string representation of numbers in Lua, so we can find ourselves + // falling out of sync with the original code. We need to push keywords out so that there's at least one space between them. + const std::string code = R"( if math.abs(raySlope) < .01 then return 0 end )"; + const std::string expected = R"( if math.abs(raySlope) < 0.01 then return 0 end)"; + CHECK_EQ(expected, transpile(code).code); +} + +TEST_CASE("numbers") +{ + const std::string code = R"( local a=2510238627 )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("infinity") +{ + const std::string code = R"( local a = 1e500 local b = 1e400 )"; + const std::string expected = R"( local a = 1e500 local b = 1e500 )"; + CHECK_EQ(expected, transpile(code).code); +} + +TEST_CASE("escaped_strings") +{ + const std::string code = R"( local s='\\b\\t\\n\\\\' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("need_a_space_between_number_literals_and_dots") +{ + const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("do_blocks") +{ + const std::string code = R"( + foo() + + do + local bar=baz() + quux() + end + + foo2() + )"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("emit_a_do_block_in_cases_of_potentially_ambiguous_syntax") +{ + const std::string code = R"( + f(); + (g or f)() + )"; + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE("roundtrip_types") +{ + const std::string code = R"( + local s:string='str' + local t:{a:string,b:number,[string]:number} + local fn:(string,string)->(number,number) + local s2:typeof(s)='foo' + local os:string? + local sn:string|number + local it:{x:number}&{y:number} + )"; + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + + ParseOptions options; + + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, options); + REQUIRE(parseResult.errors.empty()); + + CHECK_EQ(code, transpileWithTypes(*parseResult.root)); +} + +TEST_CASE("roundtrip_generic_types") +{ + const std::string code = R"( + export type A = {v:T, next:A} + )"; + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + + ParseOptions options; + + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, options); + REQUIRE(parseResult.errors.empty()); + + CHECK_EQ(code, transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "attach_types") +{ + const std::string code = R"( + local s='str' + local t={a=1,b=false} + local function fn() + return 10 + end + )"; + const std::string expected = R"( + local s:string='str' + local t:{a:number,b:boolean}={a=1,b=false} + local function fn(): number + return 10 + end + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); +} + +TEST_CASE("a_table_key_can_be_the_empty_string") +{ + std::string code = "local T = {[''] = true}"; + + CHECK_EQ(code, transpile(code).code); +} + +// There's a bit of login in the transpiler that always adds a space before a dot if the previous symbol ends in a digit. +// This was surfacing an issue where we might not insert a space after the 'local' keyword. +TEST_CASE("always_emit_a_space_after_local_keyword") +{ + std::string code = "do local aZZZZ = Workspace.P1.Shape local bZZZZ = Enum.PartType.Cylinder end"; + std::string expected = "do local aZZZZ = Workspace.P1 .Shape local bZZZZ= Enum.PartType.Cylinder end"; + + CHECK_EQ(expected, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "types_should_not_be_considered_cyclic_if_they_are_not_recursive") +{ + std::string code = R"( + local common: {foo:string} + + local t = {} + t.x = common + t.y = common + )"; + + std::string expected = R"( + local common: {foo:string} + + local t:{x:{foo:string},y:{foo:string}}={} + t.x = common + t.y = common + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); +} + +TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") +{ + std::string code = R"( + local a = function(a: string, b: number, ...: string): (string, ...number) + end + + local b = function(...: string): ...number + end + + local c = function() + end + )"; + + std::string expected = R"( + local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:...string): (string,...number) + end + + local b:(...string)->(...number)=function(...:...string): ...number + end + + local c:()->()=function(): () + end + )"; + + std::string actual = decorateWithTypes(code); + + CHECK_EQ(expected, decorateWithTypes(code)); +} + +TEST_CASE_FIXTURE(Fixture, "function_type_location") +{ + std::string code = R"( + local function foo(x: number): number + return x + end + local g: (number)->number = foo + )"; + + std::string expected = R"( + local function foo(x: number): number + return x + end + local g: (number)->(number)=foo + )"; + + std::string actual = decorateWithTypes(code); + + CHECK_EQ(expected, actual); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp new file mode 100644 index 0000000..2e40016 --- /dev/null +++ b/tests/TypeInfer.annotations.test.cpp @@ -0,0 +1,728 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("AnnotationTests"); + +TEST_CASE_FIXTURE(Fixture, "check_against_annotations") +{ + CheckResult result = check("local a: number = \"Hello Types!\""); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "check_multi_assign") +{ + CheckResult result = check("local a: number, b: string = \"994\", 888"); + CHECK_EQ(2, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "successful_check") +{ + CheckResult result = check("local a: number, b: string = 994, \"eight eighty eight\""); + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_parameters_can_have_annotations") +{ + CheckResult result = check(R"( + function double(x: number) + return x * 2 + end + + local four = double(2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_parameter_annotations_are_checked") +{ + CheckResult result = check(R"( + function double(x: number) + return x * 2 + end + + local four = double("two") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") +{ + CheckResult result = check(R"( + function fifty(): any + return 55 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId fiftyType = requireType("fifty"); + const FunctionTypeVar* ftv = get(fiftyType); + REQUIRE(ftv != nullptr); + + TypePackId retPack = ftv->retType; + const TypePack* tp = get(retPack); + REQUIRE(tp != nullptr); + + REQUIRE_EQ(1, tp->head.size()); + + REQUIRE_EQ(typeChecker.anyType, tp->head[0]); +} + +TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") +{ + CheckResult result = check(R"( + function foo(): (number, string) + return 1, 2 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "function_return_annotation_should_disambiguate_into_function_type_return_and_checked") +{ + CheckResult result = check(R"( + function foo(): (number, string) -> nil + return function(a: number, b: string): number return 1 end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "function_return_annotation_should_continuously_parse_return_annotation_and_checked") +{ + CheckResult result = check(R"( + function foo(): (number, string) -> (number) -> nil + return function(a: number, b: string): (number) -> nil + return function(a: number): nil + return 1 + end + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "typeof_variable_type_annotation_should_return_its_type") +{ + CheckResult result = check(R"( + local foo = { bar = "baz" } + + type Foo = typeof(foo) + + local foo2: Foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(requireType("foo"), requireType("foo2")); +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_of_value_a_via_typeof_with_assignment") +{ + CheckResult result = check(R"( + local a + local b: typeof(a) = 1 + + a = "foo" + )"); + + CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*typeChecker.numberType, *requireType("b")); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{4, 12}, Position{4, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); +} + +TEST_CASE_FIXTURE(Fixture, "table_annotation") +{ + CheckResult result = check(R"( + local x: {a: number, b: string} + local y = x.a + local z = x.b + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(follow(requireType("y")))); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(follow(requireType("z")))); +} + +TEST_CASE_FIXTURE(Fixture, "function_annotation") +{ + CheckResult result = check(R"( + local f: (number, string) -> number + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + dumpErrors(result); + + TypeId fType = requireType("f"); + const FunctionTypeVar* ftv = get(follow(fType)); + + REQUIRE(ftv != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "function_annotation_with_a_defined_function") +{ + CheckResult result = check(R"( + local f: (number, number) -> string = function(a: number, b: number) return "" end + )"); + + TypeId fType = requireType("f"); + const FunctionTypeVar* ftv = get(follow(fType)); + + REQUIRE(ftv != nullptr); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_assertion_expr") +{ + CheckResult result = check("local a = 55 :: any"); + REQUIRE_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") +{ + CheckResult result = check(R"( + local a = 55 :: any + local b = a :: number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "type_annotations_inside_function_bodies") +{ + CheckResult result = check(R"( + function get_message() + local message = 'That smarts!' :: string + return message + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop_counter_annotation") +{ + CheckResult result1 = check(R"( for i: number = 0, 50 do end )"); + LUAU_REQUIRE_NO_ERRORS(result1); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop_counter_annotation_is_checked") +{ + CheckResult result2 = check(R"( for i: string = 0, 10 do end )"); + CHECK_EQ(1, result2.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_should_alias_to_number") +{ + CheckResult result = check(R"( + type A = number + local a: A = 10 + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_B_should_check_with_another_aliases_until_a_non_aliased_type") +{ + CheckResult result = check(R"( + type A = number + type B = A + local b: B = 10 + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_aliasing_to_number_should_not_check_given_a_string") +{ + CheckResult result = check(R"( + type A = number + local a: A = "fail" + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "self_referential_type_alias") +{ + CheckResult result = check(R"( + type O = { x: number, incr: (O) -> number } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional res = getMainModule()->getModuleScope()->lookupType("O"); + REQUIRE(res); + + TypeId oType = follow(res->type); + const TableTypeVar* oTable = get(oType); + REQUIRE(oTable); + + std::optional incr = get(oTable->props, "incr"); + REQUIRE(incr); + + const FunctionTypeVar* incrFunc = get(incr->type); + REQUIRE(incrFunc); + + std::optional firstArg = first(incrFunc->argTypes); + REQUIRE(firstArg); + + REQUIRE_EQ(follow(*firstArg), oType); +} + +TEST_CASE_FIXTURE(Fixture, "define_generic_type_alias") +{ + CheckResult result = check(R"( + type Array = {[number]: T} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr mainModule = getMainModule(); + + auto it = mainModule->getModuleScope()->privateTypeBindings.find("Array"); + REQUIRE(it != mainModule->getModuleScope()->privateTypeBindings.end()); + + TypeFun& tf = it->second; + CHECK_EQ(1, tf.typeParams.size()); +} + +TEST_CASE_FIXTURE(Fixture, "use_generic_type_alias") +{ + CheckResult result = check(R"( + type Array = {[number]: T} -- 1 + local p: Array = {} -- 2 + p[1] = 5 -- 3 OK + p[2] = 'hello' -- 4 Error. + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(4, result.errors[0].location.begin.line); + CHECK(nullptr != get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "two_type_params") +{ + CheckResult result = check(R"( + type Map = {[K]: V} + local m: Map = {}; + local a = m['foo'] + local b = m[9] -- error here + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(4, result.errors[0].location.begin.line); + + CHECK_EQ(toString(requireType("a")), "number"); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_type_params") +{ + CheckResult result = check(R"( + type Callback = (A) -> (boolean, R) + local a: Callback = function(i) return true, 4 end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(2, result.errors[0].location.begin.line); + + IncorrectGenericParameterCount* igpc = get(result.errors[0]); + CHECK(nullptr != igpc); + + CHECK_EQ(3, igpc->actualParameters); + CHECK_EQ(2, igpc->typeFun.typeParams.size()); + CHECK_EQ("Callback", igpc->name); + + CHECK_EQ("Generic type 'Callback' expects 2 type arguments, but 3 are specified", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_type_param_name") +{ + CheckResult result = check(R"( + type Oopsies = {a: T, b: T} + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + auto dgp = get(result.errors[0]); + REQUIRE(dgp); + CHECK_EQ(dgp->parameterName, "T"); +} + +TEST_CASE_FIXTURE(Fixture, "typeof_expr") +{ + CheckResult result = check(R"( + function id(i) return i end + + local m: typeof(id(77)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("m"))); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") +{ + CheckResult result = check(R"( + type A = B + type B = A + + local aa:A + local bb:B + )"); + + TypeId fType = requireType("aa"); + const ErrorTypeVar* ftv = get(follow(fType)); + REQUIRE(ftv != nullptr); + REQUIRE(!result.errors.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") +{ + CheckResult result = check(R"( + type A = B + type B = C + type C = number + + local aa:A + )"); + + TypeId fType = requireType("aa"); + REQUIRE(follow(fType) == typeChecker.numberType); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") +{ + CheckResult result = check(R"( + export type A = {field: number} + + local n: A = {field = 551} + + return {n=n} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + Module& mod = *getMainModule(); + + const TypeFun& a = mod.getModuleScope()->exportedTypeBindings["A"]; + + CHECK(isInArena(a.type, mod.interfaceTypes)); + CHECK(!isInArena(a.type, typeChecker.globalTypes)); + + std::optional exportsType = first(mod.getModuleScope()->returnType); + REQUIRE(exportsType); + + TableTypeVar* exportsTable = getMutable(*exportsType); + REQUIRE(exportsTable != nullptr); + + TypeId n = exportsTable->props["n"].type; + REQUIRE(n != nullptr); + + CHECK(isInArena(n, mod.interfaceTypes)); +} + +TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") +{ + CheckResult result = check(R"( + export type Array = { [number]: T } + )"); + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + Module& mod = *getMainModule(); + const auto& typeBindings = mod.getModuleScope()->exportedTypeBindings; + + auto it = typeBindings.find("Array"); + REQUIRE(typeBindings.end() != it); + const TypeFun& array = it->second; + + REQUIRE_EQ(1, array.typeParams.size()); + + const TableTypeVar* arrayTable = get(array.type); + REQUIRE(arrayTable != nullptr); + + CHECK_EQ(0, arrayTable->props.size()); + CHECK(arrayTable->indexer); + + CHECK(isInArena(array.type, mod.interfaceTypes)); + CHECK_EQ(array.typeParams[0], arrayTable->indexer->indexResultType); +} + +TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") +{ + CheckResult result = check(R"( + export type Record = { name: string, location: string } + local a: Record = { name="Waldo", location="?????" } + local b: Record = { name="Santa Claus", location="Maui" } -- FIXME + + return {a=a, b=b} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + Module& mod = *getMainModule(); + + TypeId recordType = mod.getModuleScope()->exportedTypeBindings["Record"].type; + + std::optional exportsType = first(mod.getModuleScope()->returnType); + REQUIRE(exportsType); + + TableTypeVar* exportsTable = getMutable(*exportsType); + REQUIRE(exportsTable != nullptr); + + TypeId aType = exportsTable->props["a"].type; + REQUIRE(aType); + + TypeId bType = exportsTable->props["b"].type; + REQUIRE(bType); + + CHECK(isInArena(recordType, mod.interfaceTypes)); + CHECK(isInArena(aType, mod.interfaceTypes)); + CHECK(isInArena(bType, mod.interfaceTypes)); + + CHECK_EQ(recordType, aType); + CHECK_EQ(recordType, bType); +} + +TEST_CASE_FIXTURE(Fixture, "use_type_required_from_another_file") +{ + addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + + fileResolver.source["Modules/Main"] = R"( + --!strict + local Test = require(script.Parent.Thing) + + export type Foo = { [any]: Test.TestType } + + return Test + )"; + + fileResolver.source["Modules/Thing"] = R"( + --!strict + + export type TestType = {bar: boolean} + + return {} + )"; + + CheckResult result = frontend.check("Modules/Main"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_use_nonexported_type") +{ + addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + + fileResolver.source["Modules/Main"] = R"( + --!strict + local Test = require(script.Parent.Thing) + + export type Foo = { [any]: Test.TestType } + + return Test + )"; + + fileResolver.source["Modules/Thing"] = R"( + --!strict + + type TestType = {bar: boolean} + + return {} + )"; + + CheckResult result = frontend.check("Modules/Main"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "builtin_types_are_not_exported") +{ + addGlobalBinding(frontend.typeChecker, "script", frontend.typeChecker.anyType, "@test"); + + fileResolver.source["Modules/Main"] = R"( + --!strict + local Test = require(script.Parent.Thing) + + export type Foo = { [any]: Test.number } + + return Test + )"; + + fileResolver.source["Modules/Thing"] = R"( + --!strict + + return {} + )"; + + CheckResult result = frontend.check("Modules/Main"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +namespace +{ +struct AssertionCatcher +{ + AssertionCatcher() + { + tripped = 0; + oldhook = Luau::assertHandler(); + Luau::assertHandler() = [](const char* expr, const char* file, int line) -> int { + ++tripped; + return 0; + }; + } + + ~AssertionCatcher() + { + Luau::assertHandler() = oldhook; + } + + static int tripped; + Luau::AssertHandler oldhook; +}; + +int AssertionCatcher::tripped; +} // namespace + +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice") +{ + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + + AssertionCatcher ac; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + std::runtime_error); + + LUAU_ASSERT(1 == AssertionCatcher::tripped); +} + +TEST_CASE_FIXTURE(Fixture, "luau_ice_triggers_an_ice_handler") +{ + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + + AssertionCatcher ac; + + bool caught = false; + + frontend.iceHandler.onInternalError = [&](const char*) { + caught = true; + }; + + CHECK_THROWS_AS(check(R"( + local a: _luau_ice = 55 + )"), + std::runtime_error); + + CHECK_EQ(true, caught); + + frontend.iceHandler.onInternalError = {}; +} + +TEST_CASE_FIXTURE(Fixture, "luau_ice_is_not_special_without_the_flag") +{ + ScopedFastFlag sffs{"DebugLuauMagicTypes", false}; + + // We only care that this does not throw + check(R"( + local a: _luau_ice = 55 + )"); +} + +TEST_CASE_FIXTURE(Fixture, "luau_print_is_magic_if_the_flag_is_set") +{ + // Luau::resetPrintLine(); + ScopedFastFlag sffs{"DebugLuauMagicTypes", true}; + + CheckResult result = check(R"( + local a: _luau_print + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "luau_print_is_not_special_without_the_flag") +{ + ScopedFastFlag sffs{"DebugLuauMagicTypes", false}; + + CheckResult result = check(R"( + local a: _luau_print + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_type_fun_should_not_trip_rbxassert") +{ + CheckResult result = check(R"( + type Foo = typeof(function(x) return x end) + local foo: Foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +#if 0 +// This is because, after visiting all nodes in a block, we check if each type alias still points to a FreeTypeVar. +// Doing it that way is wrong, but I also tried to make typeof(x) return a BoundTypeVar, with no luck. +// Not important enough to fix today. +TEST_CASE_FIXTURE(Fixture, "pulling_a_type_from_value_dont_falsely_create_occurs_check_failed") +{ + CheckResult result = check(R"( + function f(x) + type T = typeof(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} +#endif + +TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_union_typevar") +{ + CheckResult result = check(R"( + type T = T | T + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + OccursCheckFailed* ocf = get(result.errors[0]); + REQUIRE(ocf); +} + +TEST_CASE_FIXTURE(Fixture, "occurs_check_on_cyclic_intersection_typevar") +{ + CheckResult result = check(R"( + type T = T & T + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + OccursCheckFailed* ocf = get(result.errors[0]); + REQUIRE(ocf); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp new file mode 100644 index 0000000..e897477 --- /dev/null +++ b/tests/TypeInfer.builtins.test.cpp @@ -0,0 +1,847 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/BuiltinDefinitions.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("BuiltinTests"); + +TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") +{ + CheckResult result = check(R"( + local a00 = math.frexp + local a01 = math.ldexp + local a02 = math.fmod + local a03 = math.modf + local a04 = math.pow + local a05 = math.exp + local a06 = math.floor + local a07 = math.abs + local a08 = math.sqrt + local a09 = math.log + local a10 = math.log10 + local a11 = math.rad + local a12 = math.deg + local a13 = math.sin + local a14 = math.cos + local a15 = math.tan + local a16 = math.sinh + local a17 = math.cosh + local a18 = math.tanh + local a19 = math.atan + local a20 = math.acos + local a21 = math.asin + local a22 = math.atan2 + local a23 = math.ceil + local a24 = math.min + local a25 = math.max + local a26 = math.pi + local a29 = math.huge + local a30 = math.randomseed + local a31 = math.random + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "next_iterator_should_infer_types_and_type_check") +{ + CheckResult result = check(R"( + local a: string, b: number = next({ 1 }) + + local s = "foo" + local t = { [s] = 1 } + local c: string, d: number = next(t) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "pairs_iterator_should_infer_types_and_type_check") +{ + CheckResult result = check(R"( + type Map = { [K]: V } + local map: Map = { ["foo"] = 1, ["bar"] = 2, ["baz"] = 3 } + + local it: (Map, string | nil) -> (string, number), t: Map, i: nil = pairs(map) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "ipairs_iterator_should_infer_types_and_type_check") +{ + CheckResult result = check(R"( + type Map = { [K]: V } + local array: Map = { "foo", "bar", "baz" } + + local it: (Map, number) -> (number, string), t: Map, i: number = ipairs(array) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_dot_remove_optionally_returns_generic") +{ + CheckResult result = check(R"( + local t = { 1 } + local n = table.remove(t, 7) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("n")), "number?"); +} + +TEST_CASE_FIXTURE(Fixture, "table_concat_returns_string") +{ + CheckResult result = check(R"( + local r = table.concat({1,2,3,4}, ",", 2); + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.stringType, *requireType("r")); +} + +TEST_CASE_FIXTURE(Fixture, "sort") +{ + CheckResult result = check(R"( + local t = {1, 2, 3}; + table.sort(t) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "sort_with_predicate") +{ + CheckResult result = check(R"( + --!strict + local t = {1, 2, 3} + local function p(a: number, b: number) return a < b end + table.sort(t, p) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "sort_with_bad_predicate") +{ + CheckResult result = check(R"( + --!strict + local t = {'one', 'two', 'three'} + local function p(a: number, b: number) return a < b end + table.sort(t, p) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "strings_have_methods") +{ + CheckResult result = check(R"LUA( + local s = ("RoactHostChangeEvent(%s)"):format("hello") + )LUA"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "math_max_variatic") +{ + CheckResult result = check(R"( + local n = math.max(1,2,3,4,5,6,7,8,9,0) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.numberType, *requireType("n")); +} + +TEST_CASE_FIXTURE(Fixture, "math_max_checks_for_numbers") +{ + CheckResult result = check(R"( + local n = math.max(1,2,"3") + )"); + + CHECK(!result.errors.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "builtin_tables_sealed") +{ + CheckResult result = check(R"LUA( + local b = bit32 + )LUA"); + TypeId bit32 = requireType("b"); + REQUIRE(bit32 != nullptr); + const TableTypeVar* bit32t = get(bit32); + REQUIRE(bit32t != nullptr); + CHECK_EQ(bit32t->state, TableState::Sealed); +} + +TEST_CASE_FIXTURE(Fixture, "lua_51_exported_globals_all_exist") +{ + // Extracted from lua5.1 + CheckResult result = check(R"( + local v__G = _G + local v_string_sub = string.sub + local v_string_upper = string.upper + local v_string_len = string.len + local v_string_rep = string.rep + local v_string_find = string.find + local v_string_match = string.match + local v_string_char = string.char + local v_string_gmatch = string.gmatch + local v_string_reverse = string.reverse + local v_string_byte = string.byte + local v_string_format = string.format + local v_string_gsub = string.gsub + local v_string_lower = string.lower + + local v_xpcall = xpcall + + --local v_package_loadlib = package.loadlib + --local v_package_loaders_1_ = package.loaders[1] + --local v_package_loaders_2_ = package.loaders[2] + --local v_package_loaders_3_ = package.loaders[3] + --local v_package_loaders_4_ = package.loaders[4] + + local v_tostring = tostring + local v_print = print + + --local v_os_exit = os.exit + --local v_os_setlocale = os.setlocale + local v_os_date = os.date + --local v_os_getenv = os.getenv + local v_os_difftime = os.difftime + --local v_os_remove = os.remove + local v_os_time = os.time + --local v_os_clock = os.clock + --local v_os_tmpname = os.tmpname + --local v_os_rename = os.rename + --local v_os_execute = os.execute + + local v_unpack = unpack + local v_require = require + local v_getfenv = getfenv + local v_setmetatable = setmetatable + local v_next = next + local v_assert = assert + local v_tonumber = tonumber + + --local v_io_lines = io.lines + --local v_io_write = io.write + --local v_io_close = io.close + --local v_io_flush = io.flush + --local v_io_open = io.open + --local v_io_output = io.output + --local v_io_type = io.type + --local v_io_read = io.read + --local v_io_stderr = io.stderr + --local v_io_stdin = io.stdin + --local v_io_input = io.input + --local v_io_stdout = io.stdout + --local v_io_popen = io.popen + --local v_io_tmpfile = io.tmpfile + + local v_rawequal = rawequal + --local v_collectgarbage = collectgarbage + local v_getmetatable = getmetatable + local v_rawset = rawset + + local v_math_log = math.log + local v_math_max = math.max + local v_math_acos = math.acos + local v_math_huge = math.huge + local v_math_ldexp = math.ldexp + local v_math_pi = math.pi + local v_math_cos = math.cos + local v_math_tanh = math.tanh + local v_math_pow = math.pow + local v_math_deg = math.deg + local v_math_tan = math.tan + local v_math_cosh = math.cosh + local v_math_sinh = math.sinh + local v_math_random = math.random + local v_math_randomseed = math.randomseed + local v_math_frexp = math.frexp + local v_math_ceil = math.ceil + local v_math_floor = math.floor + local v_math_rad = math.rad + local v_math_abs = math.abs + local v_math_sqrt = math.sqrt + local v_math_modf = math.modf + local v_math_asin = math.asin + local v_math_min = math.min + --local v_math_mod = math.mod + local v_math_fmod = math.fmod + local v_math_log10 = math.log10 + local v_math_atan2 = math.atan2 + local v_math_exp = math.exp + local v_math_sin = math.sin + local v_math_atan = math.atan + + --local v_debug_getupvalue = debug.getupvalue + --local v_debug_debug = debug.debug + --local v_debug_sethook = debug.sethook + --local v_debug_getmetatable = debug.getmetatable + --local v_debug_gethook = debug.gethook + --local v_debug_setmetatable = debug.setmetatable + --local v_debug_setlocal = debug.setlocal + --local v_debug_traceback = debug.traceback + --local v_debug_setfenv = debug.setfenv + --local v_debug_getinfo = debug.getinfo + --local v_debug_setupvalue = debug.setupvalue + --local v_debug_getlocal = debug.getlocal + --local v_debug_getregistry = debug.getregistry + --local v_debug_getfenv = debug.getfenv + + local v_pcall = pcall + + --local v_table_setn = table.setn + local v_table_insert = table.insert + --local v_table_getn = table.getn + --local v_table_foreachi = table.foreachi + local v_table_maxn = table.maxn + --local v_table_foreach = table.foreach + local v_table_concat = table.concat + local v_table_sort = table.sort + local v_table_remove = table.remove + + local v_newproxy = newproxy + local v_type = type + + local v_coroutine_resume = coroutine.resume + local v_coroutine_yield = coroutine.yield + local v_coroutine_status = coroutine.status + local v_coroutine_wrap = coroutine.wrap + local v_coroutine_create = coroutine.create + local v_coroutine_running = coroutine.running + + local v_select = select + local v_gcinfo = gcinfo + local v_pairs = pairs + local v_rawget = rawget + local v_loadstring = loadstring + local v_ipairs = ipairs + local v__VERSION = _VERSION + --local v_dofile = dofile + local v_setfenv = setfenv + --local v_load = load + local v_error = error + --local v_loadfile = loadfile + )"); + + dumpErrors(result); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "setmetatable_unpacks_arg_types_correctly") +{ + CheckResult result = check(R"( + setmetatable({}, setmetatable({}, {})) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_overload") +{ + CheckResult result = check(R"( + local t = {} + table.insert(t, "foo") + local s = t[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(typeChecker.stringType, requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_corrrectly_infers_type_of_array_3_args_overload") +{ + CheckResult result = check(R"( + local t = {} + table.insert(t, 1, "foo") + local s = t[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireType("s"))); +} + +TEST_CASE_FIXTURE(Fixture, "table_pack") +{ + CheckResult result = check(R"( + local t = table.pack(1, "foo", true) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("{| [number]: boolean | number | string, n: number |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "table_pack_variadic") +{ + CheckResult result = check(R"( +--!strict +function f(): (string, ...number) + return "str", 2, 3, 4 +end + +local t = table.pack(f()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("{| [number]: number | string, n: number |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "table_pack_reduce") +{ + CheckResult result = check(R"( + local t = table.pack(1, 2, true) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("{| [number]: boolean | number, n: number |}", toString(requireType("t"))); + + result = check(R"( + local t = table.pack("a", "b", "c") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("{| [number]: string, n: number |}", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "gcinfo") +{ + CheckResult result = check(R"( + local n = gcinfo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.numberType, *requireType("n")); +} + +TEST_CASE_FIXTURE(Fixture, "getfenv") +{ + LUAU_REQUIRE_NO_ERRORS(check("getfenv(1)")); +} + +TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") +{ + CheckResult result = check(R"( + local n1 = os.time() + local n2 = os.time({ year = 2020, month = 4, day = 20 }) + local n3 = os.time({ year = 2020, month = 4, day = 20, hour = 0, min = 0, sec = 0, isdst = true }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.numberType, *requireType("n1")); + CHECK_EQ(*typeChecker.numberType, *requireType("n2")); + CHECK_EQ(*typeChecker.numberType, *requireType("n3")); +} + +TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +{ + ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; + + CheckResult result = check(R"( + local co = coroutine.create(function() end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.threadType, *requireType("co")); +} + +TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +{ + ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; + + CheckResult result = check(R"( + local function nifty(x, y) + print(x, y) + local z = coroutine.yield(1, 2) + print(z) + return 42 + end + + local co = coroutine.create(nifty) + local x, y = coroutine.resume(co, 1, 2) + local answer = coroutine.resume(co, 3) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +{ + ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; + + CheckResult result = check(R"( + --!nonstrict + local function nifty(x, y) + print(x, y) + local z = coroutine.yield(1, 2) + print(z) + return 42 + end + + local f = coroutine.wrap(nifty) + local x, y = f(1, 2) + local answer = f(3) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") +{ + CheckResult result = check(R"( + local string = string + + setmetatable(string, {}) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto stringType = requireType("string"); + auto ttv = get(stringType); + REQUIRE(ttv); +} + +TEST_CASE_FIXTURE(Fixture, "string_format_arg_types_inference") +{ + CheckResult result = check(R"( + --!strict + function f(a, b, c) + return string.format("%f %d %s", a, b, c) + end + )"); + + CHECK_EQ(0, result.errors.size()); + CHECK_EQ("(number, number, string) -> string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_format_arg_count_mismatch") +{ + CheckResult result = check(R"( + --!strict + string.format("%f %d %s") + string.format("%s", "hi", 42) + string.format("%s", "hi", 42, ...) + string.format("%s", "hi", ...) + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK_EQ(result.errors[0].location.begin.line, 2); + CHECK_EQ(result.errors[1].location.begin.line, 3); + CHECK_EQ(result.errors[2].location.begin.line, 4); +} + +TEST_CASE_FIXTURE(Fixture, "string_format_correctly_ordered_types") +{ + CheckResult result = check(R"( + --!strict + string.format("%s", 123) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(tm->wantedType, typeChecker.stringType); + CHECK_EQ(tm->givenType, typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "xpcall") +{ + CheckResult result = check(R"( + --!strict + local a, b, c = xpcall( + function() return 5, true end, + function(e) return 0, false end + ) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("boolean", toString(requireType("a"))); + REQUIRE_EQ("number", toString(requireType("b"))); + REQUIRE_EQ("boolean", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "see_thru_select") +{ + CheckResult result = check(R"( + local a:number, b:boolean = select(2,"hi", 10, true) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "see_thru_select_count") +{ + CheckResult result = check(R"( + local a = select("#","hi", 10, true) + )"); + + dumpErrors(result); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "select_with_decimal_argument_is_rounded_down") +{ + CheckResult result = check(R"( + local a: number, b: boolean = select(2.9, "foo", 1, true) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// Could be flaky if the fix has regressed. +TEST_CASE_FIXTURE(Fixture, "bad_select_should_not_crash") +{ + CheckResult result = check(R"( + do end + local _ = function(l0,...) + end + local _ = function() + _(_); + _ += select(_()) + end + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "select_way_out_of_range") +{ + CheckResult result = check(R"( + select(5432598430953240958) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "select_slightly_out_of_range") +{ + CheckResult result = check(R"( + select(3, "a", 1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + CHECK_EQ("bad argument #1 to select (index out of range)", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail") +{ + CheckResult result = check(R"( + --!nonstrict + local function f(...) + return ... + end + + local foo, bar, baz, quux = select(1, f("foo", true)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); + CHECK_EQ("any", toString(requireType("baz"))); + CHECK_EQ("any", toString(requireType("quux"))); +} + +TEST_CASE_FIXTURE(Fixture, "select_with_variadic_typepack_tail_and_string_head") +{ + CheckResult result = check(R"( + --!nonstrict + local function f(...) + return ... + end + + local foo, bar, baz, quux = select(1, "foo", f("bar", true)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); + CHECK_EQ("any", toString(requireType("baz"))); + CHECK_EQ("any", toString(requireType("quux"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_format_as_method") +{ + CheckResult result = check("local _ = ('%s'):format(5)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(tm->wantedType, typeChecker.stringType); + CHECK_EQ(tm->givenType, typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") +{ + CheckResult result = check(R"( + local _ = ("%s"):format("%d", "hello") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") +{ + CheckResult result = check(R"( + local _ = ("%s %d").format("%d %s", "A type error", 2) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "debug_traceback_is_crazy") +{ + CheckResult result = check(R"( +local co: thread = ... +-- debug.traceback takes thread?, message?, level? - yes, all optional! +debug.traceback() +debug.traceback(nil, 1) +debug.traceback("msg") +debug.traceback("msg", 1) +debug.traceback(co) +debug.traceback(co, "msg") +debug.traceback(co, "msg", 1) +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "debug_info_is_crazy") +{ + CheckResult result = check(R"( +local co: thread, f: ()->() = ... + +-- debug.info takes thread?, level, options or function, options +debug.info(1, "n") +debug.info(co, 1, "n") +debug.info(f, "n") +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "aliased_string_format") +{ + CheckResult result = check(R"( + local fmt = string.format + local s = fmt("%d", "oops") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_lib_self_noself") +{ + CheckResult result = check(R"( + --!nonstrict + local a1 = string.byte("abcdef", 2) + local a2 = string.find("abcdef", "def") + local a3 = string.gmatch("ab ab", "%a+") + local a4 = string.gsub("abab", "ab", "cd") + local a5 = string.len("abc") + local a6 = string.match("12 ab", "%d+ %a+") + local a7 = string.rep("a", 10) + local a8 = string.sub("abcd", 1, 2) + local a9 = string.split("a,b,c", ",") + local a0 = string.packsize("ff") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "gmatch_definition") +{ + CheckResult result = check(R"_( +local a, b, c = ("hey"):gmatch("(.)(.)(.)")() + +for c in ("hey"):gmatch("(.)") do + print(c:upper()) +end +)_"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "select_on_variadic") +{ + CheckResult result = check(R"( + local function f(): (number, ...(boolean | number)) + return 100, true, 1 + end + + local a, b, c = select(f()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("a"))); + CHECK_EQ("any", toString(requireType("b"))); + CHECK_EQ("any", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_positions") +{ + CheckResult result = check(R"( + ("%s%d%s"):format(1, "hello", true) + )"); + + TypeId stringType = typeChecker.stringType; + TypeId numberType = typeChecker.numberType; + TypeId booleanType = typeChecker.booleanType; + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + CHECK_EQ(Location(Position{1, 26}, Position{1, 27}), result.errors[0].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[0].data); + + CHECK_EQ(Location(Position{1, 29}, Position{1, 36}), result.errors[1].location); + CHECK_EQ(TypeErrorData(TypeMismatch{numberType, stringType}), result.errors[1].data); + + CHECK_EQ(Location(Position{1, 38}, Position{1, 42}), result.errors[2].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); +} + +TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") +{ + ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; + + CheckResult result = check(R"( + local f = math.sin + local function g(x) return math.sin(x) end + f = g + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId fType = requireType("f"); + const FunctionTypeVar* ftv = get(fType); + REQUIRE(fType); + REQUIRE(fType->persistent); + REQUIRE(!ftv->definition); + + TypeId gType = requireType("g"); + const FunctionTypeVar* gtv = get(gType); + REQUIRE(gType); + REQUIRE(!gType->persistent); + REQUIRE(gtv->definition); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp new file mode 100644 index 0000000..6da33a0 --- /dev/null +++ b/tests/TypeInfer.classes.test.cpp @@ -0,0 +1,456 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; +using std::nullopt; + +struct ClassFixture : Fixture +{ + ClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + TypeId numberType = typeChecker.numberType; + + unfreeze(arena); + + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + getMutable(baseClassInstanceType)->props = { + {"BaseMethod", {makeFunction(arena, baseClassInstanceType, {numberType}, {})}}, + {"BaseField", {numberType}}, + }; + + TypeId baseClassType = arena.addType(ClassTypeVar{"BaseClass", {}, nullopt, nullopt, {}, {}}); + getMutable(baseClassType)->props = { + {"StaticMethod", {makeFunction(arena, nullopt, {}, {numberType})}}, + {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, + {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + addGlobalBinding(typeChecker, "BaseClass", baseClassType, "@test"); + + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}}); + + getMutable(childClassInstanceType)->props = { + {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId childClassType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassType, nullopt, {}, {}}); + getMutable(childClassType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + addGlobalBinding(typeChecker, "ChildClass", childClassType, "@test"); + + TypeId grandChildInstanceType = arena.addType(ClassTypeVar{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}}); + + getMutable(grandChildInstanceType)->props = { + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId grandChildType = arena.addType(ClassTypeVar{"GrandChild", {}, baseClassType, nullopt, {}, {}}); + getMutable(grandChildType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; + addGlobalBinding(typeChecker, "GrandChild", childClassType, "@test"); + + TypeId anotherChildInstanceType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}}); + + getMutable(anotherChildInstanceType)->props = { + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, + }; + + TypeId anotherChildType = arena.addType(ClassTypeVar{"AnotherChild", {}, baseClassType, nullopt, {}, {}}); + getMutable(anotherChildType)->props = { + {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; + addGlobalBinding(typeChecker, "AnotherChild", childClassType, "@test"); + + TypeId vector2MetaType = arena.addType(TableTypeVar{}); + + TypeId vector2InstanceType = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, vector2MetaType, {}, {}}); + getMutable(vector2InstanceType)->props = { + {"X", {numberType}}, + {"Y", {numberType}}, + }; + + TypeId vector2Type = arena.addType(ClassTypeVar{"Vector2", {}, nullopt, nullopt, {}, {}}); + getMutable(vector2Type)->props = { + {"New", {makeFunction(arena, nullopt, {numberType, numberType}, {vector2InstanceType})}}, + }; + getMutable(vector2MetaType)->props = { + {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, + }; + typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; + addGlobalBinding(typeChecker, "Vector2", vector2Type, "@test"); + + freeze(arena); + } +}; + +TEST_SUITE_BEGIN("TypeInferClasses"); + +TEST_CASE_FIXTURE(ClassFixture, "call_method_of_a_class") +{ + CheckResult result = check(R"( + local m = BaseClass.StaticMethod() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("number", toString(requireType("m"))); +} + +TEST_CASE_FIXTURE(ClassFixture, "call_method_of_a_child_class") +{ + CheckResult result = check(R"( + local m = ChildClass.StaticMethod() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("number", toString(requireType("m"))); +} + +TEST_CASE_FIXTURE(ClassFixture, "call_instance_method") +{ + CheckResult result = check(R"( + local i = ChildClass.New() + local result = i:Method() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireType("result"))); +} + +TEST_CASE_FIXTURE(ClassFixture, "call_base_method") +{ + CheckResult result = check(R"( + local i = ChildClass.New() + i:BaseMethod(41) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "cannot_call_unknown_method_of_a_class") +{ + CheckResult result = check(R"( + local m = BaseClass.Nope() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(ClassFixture, "cannot_call_method_of_child_on_base_instance") +{ + CheckResult result = check(R"( + local i = BaseClass.New() + i:Method() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(ClassFixture, "we_can_infer_that_a_parameter_must_be_a_particular_class") +{ + CheckResult result = check(R"( + function makeClone(o) + return BaseClass.Clone(o) + end + + local a = makeClone(ChildClass.New()) + )"); + + CHECK_EQ("BaseClass", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(ClassFixture, "we_can_report_when_someone_is_trying_to_use_a_table_rather_than_a_class") +{ + CheckResult result = check(R"( + function makeClone(o) + return BaseClass.Clone(o) + end + + type Oopsies = { BaseMethod: (Oopsies, number) -> ()} + + local oopsies: Oopsies = { + BaseMethod = function (self: Oopsies, i: number) + print('gadzooks!') + end + } + + makeClone(oopsies) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm != nullptr); + + CHECK_EQ("Oopsies", toString(tm->givenType)); + CHECK_EQ("BaseClass", toString(tm->wantedType)); +} + +TEST_CASE_FIXTURE(ClassFixture, "assign_to_prop_of_class") +{ + CheckResult result = check(R"( + local v = Vector2.New(0, 5) + v.X = 55 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class") +{ + CheckResult result = check(R"( + local c = ChildClass.New() + local x = 1 + c.BaseField + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class") +{ + CheckResult result = check(R"( + local c = ChildClass.New() + c.BaseField = 444 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") +{ + ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); + + CheckResult result = check(R"( + local c = ChildClass.New() + local x = 1 + c["BaseField"] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class_using_string") +{ + ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); + + CheckResult result = check(R"( + local c = ChildClass.New() + c["BaseField"] = 444 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "cannot_unify_class_instance_with_primitive") +{ + CheckResult result = check(R"( + local v = Vector2.New(0, 5) + v = 444 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(ClassFixture, "warn_when_prop_almost_matches") +{ + CheckResult result = check(R"( + Vector2.new(0, 0) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = get(result.errors[0]); + REQUIRE(err != nullptr); + + REQUIRE_EQ(1, err->candidates.size()); + CHECK_EQ("New", *err->candidates.begin()); +} + +TEST_CASE_FIXTURE(ClassFixture, "classes_can_have_overloaded_operators") +{ + CheckResult result = check(R"( + local a = Vector2.New(1, 2) + local b = Vector2.New(3, 4) + local c = a + b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector2", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(ClassFixture, "classes_without_overloaded_operators_cannot_be_added") +{ + CheckResult result = check(R"( + local a = BaseClass.New() + local b = BaseClass.New() + local c = a + b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(ClassFixture, "function_arguments_are_covariant") +{ + CheckResult result = check(R"( + function f(b: BaseClass) end + + f(ChildClass.New()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "higher_order_function_arguments_are_contravariant") +{ + CheckResult result = check(R"( + function apply(f: (BaseClass) -> ()) + f(ChildClass.New()) -- 2 + end + + apply(function (c: ChildClass) end) -- 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(ClassFixture, "higher_order_function_return_values_are_covariant") +{ + CheckResult result = check(R"( + function apply(f: () -> BaseClass) + return f() + end + + apply(function () + return ChildClass.New() + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "higher_order_function_return_type_is_not_contravariant") +{ + CheckResult result = check(R"( + function apply(f: () -> BaseClass) + return f() + end + + apply(function () + return ChildClass.New() + end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(ClassFixture, "table_properties_are_invariant") +{ + CheckResult result = check(R"( + function f(a: {foo: BaseClass}) + a.foo = AnotherChild.New() + end + + local t: {foo: ChildClass} + f(t) -- line 6. Breaks soundness. + + function g(t: {foo: ChildClass}) + end + + local t2: {foo: BaseClass} = {foo=BaseClass.New()} + t2.foo = AnotherChild.New() + g(t2) -- line 13. Breaks soundness + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(6, result.errors[0].location.begin.line); + CHECK_EQ(13, result.errors[1].location.begin.line); +} + +TEST_CASE_FIXTURE(ClassFixture, "table_indexers_are_invariant") +{ + CheckResult result = check(R"( + function f(a: {[number]: BaseClass}) + a[1] = AnotherChild.New() + end + + local t: {[number]: ChildClass} + f(t) -- line 6. Breaks soundness. + + function g(t: {[number]: ChildClass}) + end + + local t2: {[number]: BaseClass} = {BaseClass.New()} + t2[1] = AnotherChild.New() + g(t2) -- line 13. Breaks soundness + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(6, result.errors[0].location.begin.line); + CHECK_EQ(13, result.errors[1].location.begin.line); +} + +TEST_CASE_FIXTURE(ClassFixture, "table_class_unification_reports_sane_errors_for_missing_properties") +{ + CheckResult result = check(R"( + function foo(bar) + bar.Y = 1 -- valid + bar.x = 2 -- invalid, wanted 'X' + bar.w = 2 -- invalid + end + + local a: Vector2 + foo(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + REQUIRE_EQ("Key 'w' not found in class 'Vector2'", toString(result.errors[0])); + REQUIRE_EQ("Key 'x' not found in class 'Vector2'. Did you mean 'X'?", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_order") +{ + CheckResult result = check(R"( + local p: BaseClass + local foo: number = p + local foo2: BaseClass = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + REQUIRE_EQ("Type 'BaseClass' could not be converted into 'number'", toString(result.errors[0])); + REQUIRE_EQ("Type 'number' could not be converted into 'BaseClass'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +local b: Vector2? = nil +local a = b.X + b.Z + +b.X = 2 -- real Vector2.X is also read-only + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[0])); + CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[1])); + CHECK_EQ("Key 'Z' not found in class 'Vector2'", toString(result.errors[2])); + CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[3])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp new file mode 100644 index 0000000..41e3e45 --- /dev/null +++ b/tests/TypeInfer.definitions.test.cpp @@ -0,0 +1,300 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("DefinitionTests"); + +TEST_CASE_FIXTURE(Fixture, "definition_file_loading") +{ + loadDefinition(R"( + declare foo: number + export type Asdf = number | string + declare function bar(x: number): string + declare foo2: typeof(foo) + declare function var(...: any): string + )"); + + TypeId globalFooTy = getGlobalBinding(frontend.typeChecker, "foo"); + CHECK_EQ(toString(globalFooTy), "number"); + + std::optional globalAsdfTy = frontend.typeChecker.globalScope->lookupType("Asdf"); + REQUIRE(bool(globalAsdfTy)); + CHECK_EQ(toString(globalAsdfTy->type), "number | string"); + + TypeId globalBarTy = getGlobalBinding(frontend.typeChecker, "bar"); + CHECK_EQ(toString(globalBarTy), "(number) -> string"); + + TypeId globalFoo2Ty = getGlobalBinding(frontend.typeChecker, "foo2"); + CHECK_EQ(toString(globalFoo2Ty), "number"); + + TypeId globalVarTy = getGlobalBinding(frontend.typeChecker, "var"); + + CHECK_EQ(toString(globalVarTy), "(...any) -> string"); + + CheckResult result = check(R"( + local x: number = foo + 1 + local y: string = bar(x) + local z: Asdf = x + z = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_scope") +{ + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult parseFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + declare foo + )", + "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE(!parseFailResult.success); + std::optional fooTy = tryGetGlobalBinding(typeChecker, "foo"); + CHECK(!fooTy.has_value()); + + LoadDefinitionFileResult checkFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + local foo: string = 123 + declare bar: typeof(foo) + )", + "@test"); + + REQUIRE(!checkFailResult.success); + std::optional barTy = tryGetGlobalBinding(typeChecker, "bar"); + CHECK(!barTy.has_value()); +} + +TEST_CASE_FIXTURE(Fixture, "definition_file_classes") +{ + loadDefinition(R"( + declare class Foo + X: number + + function inheritance(self): number + end + + declare class Bar extends Foo + Y: number + + function foo(self, x: number): number + function foo(self, x: string): string + + function __add(self, other: Bar): Bar + end + )"); + + CheckResult result = check(R"( + local x: Bar + local prop: number = x.Y + local inheritedProp: number = x.X + local method: number = x:foo(1) + local method2: string = x:foo("string") + local metamethod: Bar = x + x + local inheritedMethod: number = x:inheritance() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("prop")), "number"); + CHECK_EQ(toString(requireType("inheritedProp")), "number"); + CHECK_EQ(toString(requireType("method")), "number"); + CHECK_EQ(toString(requireType("method2")), "string"); + CHECK_EQ(toString(requireType("metamethod")), "Bar"); + CHECK_EQ(toString(requireType("inheritedMethod")), "number"); +} + +TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") +{ + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + declare class A + X: number + X: string + end + )", + "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE(!result.success); + CHECK_EQ(result.parseResult.errors.size(), 0); + REQUIRE(bool(result.module)); + REQUIRE_EQ(result.module->errors.size(), 1); + GenericError* ge = get(result.module->errors[0]); + REQUIRE(ge); + CHECK_EQ("Cannot overload non-function class member 'X'", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") +{ + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + type NotAClass = {} + + declare class Foo extends NotAClass + end + )", + "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE(!result.success); + CHECK_EQ(result.parseResult.errors.size(), 0); + REQUIRE(bool(result.module)); + REQUIRE_EQ(result.module->errors.size(), 1); + GenericError* ge = get(result.module->errors[0]); + REQUIRE(ge); + CHECK_EQ("Cannot use non-class type 'NotAClass' as a superclass of class 'Foo'", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") +{ + unfreeze(typeChecker.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + declare class Foo extends Bar + end + + declare class Bar extends Foo + end + )", + "@test"); + freeze(typeChecker.globalTypes); + + REQUIRE(!result.success); +} + +TEST_CASE_FIXTURE(Fixture, "declaring_generic_functions") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + loadDefinition(R"( + declare function f(a: a, b: b): string + declare function g(...: a...): b... + declare function h(a: a, b: b): (b, a) + )"); + + CheckResult result = check(R"( + local x = f(1, true) + local y: number, z: string = g("foo", 123) + local w, u = h(1, true) + + local f = f + local g = g + local h = h + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("x")), "string"); + CHECK_EQ(toString(requireType("w")), "boolean"); + CHECK_EQ(toString(requireType("u")), "number"); + CHECK_EQ(toString(requireType("f")), "(a, b) -> string"); + CHECK_EQ(toString(requireType("g")), "(a...) -> (b...)"); + CHECK_EQ(toString(requireType("h")), "(a, b) -> (b, a)"); +} + +TEST_CASE_FIXTURE(Fixture, "class_definition_function_prop") +{ + loadDefinition(R"( + declare class Foo + X: (number) -> string + end + )"); + + CheckResult result = check(R"( + local x: Foo + local prop = x.X + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("prop")), "(number) -> string"); +} + +TEST_CASE_FIXTURE(Fixture, "definition_file_class_function_args") +{ + loadDefinition(R"( + declare class Foo + function foo1(self, x: number): number + function foo2(self, x: number, y: string): number + + y: (a: number, b: string) -> string + end + )"); + + CheckResult result = check(R"( + local x: Foo + local methodRef1 = x.foo1 + local methodRef2 = x.foo2 + local prop = x.y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + ToStringOptions opts; + opts.functionTypeArguments = true; + CHECK_EQ(toString(requireType("methodRef1"), opts), "(self: Foo, x: number) -> number"); + CHECK_EQ(toString(requireType("methodRef2"), opts), "(self: Foo, x: number, y: string) -> number"); + CHECK_EQ(toString(requireType("prop"), opts), "(a: number, b: string) -> string"); +} + +TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") +{ + loadDefinition(R"( + declare x: string + + export type Foo = string | number + + declare class Bar + prop: string + end + + declare y: { + x: number, + } + )"); + + std::optional xBinding = typeChecker.globalScope->linearSearchForBinding("x"); + REQUIRE(bool(xBinding)); + // note: loadDefinition uses the @test package name. + CHECK_EQ(xBinding->documentationSymbol, "@test/global/x"); + + std::optional fooTy = typeChecker.globalScope->lookupType("Foo"); + REQUIRE(bool(fooTy)); + CHECK_EQ(fooTy->type->documentationSymbol, "@test/globaltype/Foo"); + + std::optional barTy = typeChecker.globalScope->lookupType("Bar"); + REQUIRE(bool(barTy)); + CHECK_EQ(barTy->type->documentationSymbol, "@test/globaltype/Bar"); + + ClassTypeVar* barClass = getMutable(barTy->type); + REQUIRE(bool(barClass)); + REQUIRE_EQ(barClass->props.count("prop"), 1); + CHECK_EQ(barClass->props["prop"].documentationSymbol, "@test/globaltype/Bar.prop"); + + std::optional yBinding = typeChecker.globalScope->linearSearchForBinding("y"); + REQUIRE(bool(yBinding)); + CHECK_EQ(yBinding->documentationSymbol, "@test/global/y"); + + TableTypeVar* yTtv = getMutable(yBinding->typeId); + REQUIRE(bool(yTtv)); + REQUIRE_EQ(yTtv->props.count("x"), 1); + CHECK_EQ(yTtv->props["x"].documentationSymbol, "@test/global/y.x"); +} + +TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_types") +{ + loadDefinition(R"( + export type Evil = string + )"); + + std::optional ty = typeChecker.globalScope->lookupType("Evil"); + REQUIRE(bool(ty)); + CHECK_EQ(ty->type->documentationSymbol, std::nullopt); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp new file mode 100644 index 0000000..3a04a18 --- /dev/null +++ b/tests/TypeInfer.generics.test.cpp @@ -0,0 +1,698 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("GenericsTests"); + +TEST_CASE_FIXTURE(Fixture, "check_generic_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + function id(x:a): a + return x + end + local x: string = id("hi") + local y: number = id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + local function id(x:a): a + return x + end + local x: string = id("hi") + local y: number = id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; + ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + function id(...: a...): (a...) return ... end + local x: string, y: boolean = id("hi", true) + local z: number = id(37) + id() + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + function f() end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + local function id(x:a):a return x end + local f: (a)->a = id + local x: string = f("hi") + local y: number = f(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local function id(x) return x end + print("This is bogus") -- TODO: CLI-39916 + local f = id + local x: string = f("hi") + local y: number = f(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local function id(x) return x end + print("This is bogus") -- TODO: CLI-39916 + local f: (number)->number = id + local g: (string)->string = id + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + local t = {} + t.m = function(x: a):a return x end + local x: string = t.m("hi") + local y: number = t.m(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + local t: { m: (number)->number } = { m = function(x:number) return x+1 end } + local function id(x:a):a return x end + t.m = id + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + local function f() + local function id(x:a): a + return x + end + local x: string = id("hi") + local y: number = id(37) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + local function id(x:a):a + local y: string = id("hi") + local z: number = id(37) + return x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + local id2 + local function id1(x:a):a + local y: string = id2("hi") + local z: number = id2(37) + return x + end + function id2(x:a):a + local y: string = id1("hi") + local z: number = id1(37) + return x + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + type T = { id: (a) -> a } + local x: T = { id = function(x:a):a return x end } + local y: string = x.id("hi") + local z: number = x.id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_factories") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + type T = { id: (a) -> a } + type Factory = { build: () -> T } + + local f: Factory = { + build = function(): T + return { + id = function(x:a):a + return x + end + } + end + } + local y: string = f.build().id("hi") + local z: number = f.build().id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "factories_of_generics") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + ScopedFastFlag sffs3{"LuauRankNTypes", true}; + + CheckResult result = check(R"( + type T = { id: (a) -> a } + type Factory = { build: () -> T } + + local f: Factory = { + build = function(): T + return { + id = function(x:a):a + return x + end + } + end + } + local x: T = f.build() + local y: string = x.id("hi") + local z: number = x.id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + function id(x) + return x + end + local x: string = id("hi") + local y: number = id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId idType = requireType("id"); + const FunctionTypeVar* idFun = get(idType); + REQUIRE(idFun); + auto [args, varargs] = flatten(idFun->argTypes); + auto [rets, varrets] = flatten(idFun->retType); + + CHECK_EQ(idFun->generics.size(), 1); + CHECK_EQ(idFun->genericPacks.size(), 0); + CHECK_EQ(args[0], idFun->generics[0]); + CHECK_EQ(rets[0], idFun->generics[0]); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local function id(x) + return x + end + local x: string = id("hi") + local y: number = id(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId idType = requireType("id"); + const FunctionTypeVar* idFun = get(idType); + REQUIRE(idFun); + auto [args, varargs] = flatten(idFun->argTypes); + auto [rets, varrets] = flatten(idFun->retType); + + CHECK_EQ(idFun->generics.size(), 1); + CHECK_EQ(idFun->genericPacks.size(), 0); + CHECK_EQ(args[0], idFun->generics[0]); + CHECK_EQ(rets[0], idFun->generics[0]); +} + +TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local function f() + local function id(x) + return x + end + local x: string = id("hi") + local y: number = id(37) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local x = {} + function x:id(x) return x end + function x:f(): string return self:id("hello") end + function x:g(): number return self:id(37) end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local x = {} + function x:id(x) return x end + function x:f() + local x: string = self:id("hi") + local y: number = self:id(37) + end + )"); + // TODO: Should typecheck but currently errors CLI-39916 + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_property") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauRankNTypes", true}; + CheckResult result = check(R"( + local t = {} + t.m = function(x) return x end + local x: string = t.m("hi") + local y: number = t.m(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + ScopedFastFlag sffs3{"LuauRankNTypes", true}; + CheckResult result = check(R"( + local function f(g: (a)->a) + local x: number = g(37) + local y: string = g("hi") + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + ScopedFastFlag sffs3{"LuauRankNTypes", true}; + CheckResult result = check(R"( + local function f() : (a)->a + local function id(x:a):a return x end + return id + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") +{ + ScopedFastFlag sffs1{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + ScopedFastFlag sffs3{"LuauRankNTypes", true}; + CheckResult result = check(R"( + local function id(x:a):a return x end + local f: (a)->a = id(id) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local function f(y) + -- this will only typecheck if we infer z: any + -- so f: (any)->(any) + local z = y + local function id(x) + z = x -- this assignment is what forces z: any + return x + end + local x: string = id("hi") + local y: number = id(37) + return z + end + -- so this assignment should fail + local b: boolean = f(true) + )"); + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + local function f(y) + local z = y + local function id(x) + z = x + return x + end + local x: string = id("hi") + local y: number = id(37) + end + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") +{ + ScopedFastFlag sffs[] = { + {"LuauGenericFunctions", true}, + {"LuauParseGenericFunctions", true}, + {"LuauRankNTypes", true}, + }; + + CheckResult result = check(R"( + type T = { m: (a) -> T } + function f(t : T) + local x: T = t.m(37) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_unify_bound_types") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + ScopedFastFlag sffs3{"LuauRankNTypes", true}; + + CheckResult result = check(R"( + type F = () -> (a, b) -> a + type G = (b, b) -> b + local f: F = function() + local x + return function(y: a, z: b): a + if not(x) then x = y end + return x + end + end + -- This assignment shouldn't typecheck + -- If it does, it means we instantiated + -- f as () -> (X, b) -> X, then unified X to be b + local g: G = f() + -- Oh dear, if that works then the type system is unsound + local a : string = g("not a number", "hi") + local b : number = g(5, 37) + )"); + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") +{ + // Replaying the classic problem with polymorphism and mutable state in Luau + // See, e.g. Tofte (1990) + // https://www.sciencedirect.com/science/article/pii/089054019090018D. + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + --!strict + -- Our old friend the polymorphic identity function + local function id(x) return x end + local a: string = id("hi") + local b: number = id(37) + + -- This allows (a)->a to be expressed without generic function syntax + type Id = typeof(id) + + -- This function should have type + -- () -> (a) -> a + -- not type + -- () -> (a) -> a + local function ohDear(): Id + local y + function oh(x) + -- Returns the same x every time it's called + if not(y) then y = x end + return y + end + return oh + end + + -- oh dear, f claims to polymorphic which it shouldn't be + local f: Id = ohDear() + + -- the first call sets y + local a: string = f("not a number") + -- so b has value "not a number" at run time + local b: number = f(37) + )"); + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", false}; + CheckResult result = check(R"( + --!strict + local function id(x) return x end + local x: string = id("hi") + local y: number = id(37) + -- This allows (a)->a to be expressed without generic function syntax + type Id = typeof(id) + -- The rank 1 restriction causes this not to typecheck, since it's + -- declared as returning a polytype. + local function returnsId(): Id + return id + end + -- So this won't typecheck + local f: Id = returnsId() + local a: string = f("hi") + local b: number = f(37) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; + CheckResult result = check(R"( + function f(x:a):a return x end + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + function f() end + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + CheckResult result = check(R"( + function f() end + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_generics") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + function f(...: a) end + + type F = (...a) -> ...a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + function f(...: a...): (a...) return ... end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("f")), "(a...) -> (a...)"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; + ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + function f(...: a...): any return (...) end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + function f(...: T...) + return ... + end + + function g(a: T) + return a + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + SwappedGenericTypeParameter* fErr = get(result.errors[0]); + REQUIRE(fErr); + CHECK_EQ(fErr->name, "T"); + CHECK_EQ(fErr->kind, SwappedGenericTypeParameter::Pack); + + SwappedGenericTypeParameter* gErr = get(result.errors[1]); + REQUIRE(gErr); + CHECK_EQ(gErr->name, "T"); + CHECK_EQ(gErr->kind, SwappedGenericTypeParameter::Type); +} + +TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") +{ + ScopedFastFlag sffs{"LuauGenericFunctions", true}; + ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + function f() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + DuplicateGenericParameter* err = get(result.errors[0]); + REQUIRE(err != nullptr); + CHECK_EQ(err->parameterName, "a"); +} + +TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") +{ + ScopedFastFlag sffs1{"LuauGenericFunctions", true}; + + CheckResult result = check(R"( + function f(z) + local o = {} + o.x = o + o.y = {5} + o.z = z + return o + end + local o1 = f(true) + local x1, y1, z1 = o1.x, o1.y, o1.z + local o2 = f("hi") + local x2, y2, z2 = o2.x, o2.y, o2.z + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("x1") != requireType("x2")); + CHECK(requireType("y1") == requireType("y2")); + CHECK(requireType("z1") != requireType("z2")); +} + +TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") +{ + ScopedFastFlag sffs1{"LuauGenericFunctions", true}; + + CheckResult result = check(R"( + function f(x) return {5} end + function g(x, y) return f(x) end + local z1 = f(5) + local z2 = g(true, "hi") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("z1") == requireType("z2")); +} + +TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") +{ + ScopedFastFlag sffs1{"LuauGenericFunctions", true}; + + CheckResult result = check(R"( + type T = { x: {a}, y: {number} } + local o1: T = { x = {true}, y = {5} } + local x1, y1 = o1.x, o1.y + local o2: T = { x = {"hi"}, y = {37} } + local x2, y2 = o2.x, o2.y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(requireType("x1") != requireType("x2")); + CHECK(requireType("y1") == requireType("y2")); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp new file mode 100644 index 0000000..9685f4f --- /dev/null +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -0,0 +1,344 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("IntersectionTypes"); + +TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") +{ + CheckResult result = check(R"( + type A = (number) -> (string) + type B = (string) -> (number) + local f:A & B + local b = f(10) -- b is a string + local c = f("a") -- c is a number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(requireType("b"), typeChecker.stringType); + CHECK_EQ(requireType("c"), typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "table_combines") +{ + CheckResult result = check(R"( + type A={a:number} + type B={b:string} + local c:A & B = {a=10, b="s"} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_combines_missing") +{ + CheckResult result = check(R"( + type A={a:number} + type B={b:string} + local c:A & B = {a=10} + )"); + + REQUIRE(result.errors.size() == 1); +} + +TEST_CASE_FIXTURE(Fixture, "impossible_type") +{ + CheckResult result = check(R"( + local c:number&string = 10 + )"); + + REQUIRE(result.errors.size() == 1); +} + +TEST_CASE_FIXTURE(Fixture, "table_extra_ok") +{ + CheckResult result = check(R"( + type A={a:number} + type B={b:string} + local c:A & B + local d:A = c + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fx_intersection_as_argument") +{ + CheckResult result = check(R"( + type A = (number) -> (string) + type B = (string) -> (number) + type C = (A) -> (number) + local f:A & B + local g:C + local b = g(f) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "fx_union_as_argument_fails") +{ + CheckResult result = check(R"( + type A = (number) -> (string) + type B = (string) -> (number) + type C = (A) -> (number) + local f:A | B + local g:C + local b = g(f) + )"); + + REQUIRE(!result.errors.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "argument_is_intersection") +{ + CheckResult result = check(R"( + type A = (number | boolean) -> number + local f: A + + f(5) + f(true) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "should_still_pick_an_overload_whose_arguments_are_unions") +{ + CheckResult result = check(R"( + type A = (number | boolean) -> number + type B = (string | nil) -> string + local f: A & B + + local a1, a2 = f(1), f(true) + local b1, b2 = f("foo"), f(nil) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*requireType("a1"), *typeChecker.numberType); + CHECK_EQ(*requireType("a2"), *typeChecker.numberType); + + CHECK_EQ(*requireType("b1"), *typeChecker.stringType); + CHECK_EQ(*requireType("b2"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "propagates_name") +{ + const std::string code = R"( + type A={a:number} + type B={b:string} + + local c:A&B + local b = c + )"; + const std::string expected = R"( + type A={a:number} + type B={b:string} + + local c:A&B + local b:A&B=c + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guaranteed_to_exist") +{ + CheckResult result = check(R"( + type A = {x: {y: number}} + type B = {x: {y: number}} + local t: A & B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const IntersectionTypeVar* r = get(requireType("r")); + REQUIRE(r); + + TableTypeVar* a = getMutable(r->parts[0]); + REQUIRE(a); + CHECK_EQ(typeChecker.numberType, a->props["y"].type); + + TableTypeVar* b = getMutable(r->parts[1]); + REQUIRE(b); + CHECK_EQ(typeChecker.numberType, b->props["y"].type); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") +{ + CheckResult result = check(R"( + type A = {x: {y: {z: {thing: string}}}} + type B = {x: {y: {z: {thing: string}}}} + local t: A & B + + local r = t.x.y.z.thing + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.stringType, *requireType("r")); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") +{ + CheckResult result = check(R"( + type A = {x: number} + type B = {x: string} + local t: A & B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number & string", toString(requireType("r"))); // TODO(amccord): This should be an error. +} + +TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_the_property") +{ + CheckResult result = check(R"( + type A = {x: number} + type B = {} + local t: A & B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_type_any") +{ + CheckResult result = check(R"( + type A = {x: number} + type B = {x: any} + local t: A & B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.anyType, *requireType("r")); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_all_parts_missing_the_property") +{ + CheckResult result = check(R"( + type A = {} + type B = {} + + local function f(t: A & B) + local x = t.x + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* up = get(result.errors[0]); + REQUIRE_MESSAGE(up, result.errors[0].data); + CHECK_EQ(up->key, "x"); +} + +TEST_CASE_FIXTURE(Fixture, "table_intersection_write") +{ + CheckResult result = check(R"( + type X = { x: number } + type XY = X & { y: number } + + local a : XY = { x = 1, y = 2 } + a.x = 10 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( + type X = {} + type XY = X & { x: number, y: number } + + local a : XY = { x = 1, y = 2 } + a.x = 10 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( + type X = { x: number } + type Y = { y: number } + type XY = X & Y + + local a : XY = { x = 1, y = 2 } + a.x = 10 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( + type A = { x: {y: number} } + type B = { x: {y: number} } + local t : A & B = { x = { y = 1 } } + + t.x = { y = 4 } + t.x.y = 40 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed") +{ + CheckResult result = check(R"( + type X = { x: number } + type Y = { y: number } + type XY = X & Y + + local a : XY = { x = 1, y = 2 } + a.z = 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); +} + +TEST_CASE_FIXTURE(Fixture, "table_intersection_write_sealed_indirect") +{ + CheckResult result = check(R"( + type X = { x: (number) -> number } + type Y = { y: (string) -> string } + + type XY = X & Y + + local xy : XY = { + x = function(a: number) return -a end, + y = function(a: string) return a .. "b" end + } + function xy.z(a:number) return a * 10 end + function xy:y(a:number) return a * 10 end + function xy:w(a:number) return a * 10 end + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK_EQ(toString(result.errors[0]), "Cannot add property 'z' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[1]), "Cannot add property 'y' to table 'X & Y'"); + CHECK_EQ(toString(result.errors[2]), "Cannot add property 'w' to table 'X & Y'"); +} + +TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") +{ + CheckResult result = check(R"( + local t: {} & {} + setmetatable(t, {}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp new file mode 100644 index 0000000..46496fd --- /dev/null +++ b/tests/TypeInfer.provisional.test.cpp @@ -0,0 +1,588 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" + +#include "Fixture.h" + +#include "doctest.h" + +#include + +LUAU_FASTFLAG(LuauEqConstraint) + +using namespace Luau; + +TEST_SUITE_BEGIN("ProvisionalTests"); + +// These tests check for behavior that differes from the final behavior we'd +// like to have. They serve to document the current state of the typechecker. +// When making future improvements, its very likely these tests will break and +// will need to be replaced. + +/* + * This test falls into a sort of "do as I say" pit of consequences: + * Technically, the type of the type() function is (T) -> string + * + * We thus infer that the argument to f is a free type. + * While we can still learn something about this argument, we can't seem to infer a union for it. + * + * Is this good? Maybe not, but I'm not sure what else we should do. + */ +TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") +{ + const std::string code = R"( + function f(a) + if type(a) == "boolean" then + local a1 = a + elseif a.fn() then + local a2 = a + end + end + )"; + + const std::string expected = R"( + function f(a:{fn:()->(free)}): () + if type(a) == 'boolean'then + local a1:boolean=a + elseif a.fn()then + local a2:{fn:()->(free)}=a + end + end + )"; + CHECK_EQ(expected, decorateWithTypes(code)); +} + +TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") +{ + const std::string code = R"( + local a, b, c = xpcall(function() return 1, "foo" end, function() return "foo", 1 end) + )"; + + const std::string expected = R"( + local a:boolean,b:number,c:string=xpcall(function(): (number,string)return 1,'foo'end,function(): (string,number)return'foo',1 end) + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); +} + +// We had a bug where if you have two type packs that looks like: +// { x, y }, ... +// { x }, ... +// It would infinitely grow the type pack because one WeirdIter is trying to catch up, but can't. +// However, the following snippet is supposed to generate an OccursCheckFailed, but it doesn't. +TEST_CASE_FIXTURE(Fixture, "weirditer_should_not_loop_forever") +{ + // this flag is intentionally here doing nothing to demonstrate that we exit early via case detection + ScopedFastInt sfis{"LuauTypeInferTypePackLoopLimit", 50}; + + CheckResult result = check(R"( + local function toVertexList(vertices, x, y, ...) + if not (x and y) then return vertices end -- no more arguments + vertices[#vertices + 1] = {x = x, y = y} -- set vertex + return toVertexList(vertices, ...) -- recurse + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// This should also generate an OccursCheckFailed error too, like the above toVertexList snippet. +// at least up until we can get Luau to recognize this code as a valid function that iterates over a list of values in the pack. +TEST_CASE_FIXTURE(Fixture, "it_should_be_agnostic_of_actual_size") +{ + CheckResult result = check(R"( + local function f(x, y, ...) + if not y then return x end + return f(x, ...) + end + + f(3, 2, 1, 0) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// Ideally setmetatable's second argument would be an optional free table. +// For now, infer it as just a free table. +TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") +{ + CheckResult result = check(R"( + local a = {} + local b + setmetatable(a, b) + b = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("{- -}", toString(tm->wantedType)); + CHECK_EQ("number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") +{ + CheckResult result = check(R"( + local a: {x: number, y: number, [any]: any} | {y: number} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // :( + // Should be the same as the type of a + REQUIRE_EQ("{| y: number |}", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") +{ + CheckResult result = check(R"( + local a: {y: number} | {x: number, y: number, [any]: any} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // :( + // Should be the same as the type of a + REQUIRE_EQ("{| [any]: any, x: number, y: number |}", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "normal_conditional_expression_has_refinements") +{ + CheckResult result = check(R"( + local foo: {x: number}? = nil + local bar = foo and foo.x -- TODO: Geez. We are inferring the wrong types here. Should be 'number?'. + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Binary and/or return types are straight up wrong. JIRA: CLI-40300 + CHECK_EQ("boolean | number", toString(requireType("bar"))); +} + +// Luau currently doesn't yet know how to allow assignments when the binding was refined. +TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") +{ + ScopedFastFlag sffs2{"LuauGenericFunctions", true}; + ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + type Node = { value: T, child: Node? } + + local function visitor(node: Node, f: (T) -> ()) + local current = node + + while current do + f(current.value) + current = current.child -- TODO: Can't work just yet. It thinks 'current' can never be nil. :( + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'Node?' could not be converted into 'Node'", toString(result.errors[0])); +} + +// Originally from TypeInfer.test.cpp. +// I dont think type checking the metamethod at every site of == is the correct thing to do. +// We should be type checking the metamethod at the call site of setmetatable. +TEST_CASE_FIXTURE(Fixture, "error_on_eq_metamethod_returning_a_type_other_than_boolean") +{ + CheckResult result = check(R"( + local tab = {a = 1} + setmetatable(tab, {__eq = function(a, b): number + return 1 + end}) + local tab2 = tab + + local a = tab2 == tab + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Metamethod '__eq' must return type 'boolean'", ge->message); +} + +// Requires success typing to confidently determine that this expression has no overlap. +TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") +{ + CheckResult result = check(R"( + local a: string | number = "hi" + local b: {x: string}? = {x = "bye"} + + local r1 = a == b + local r2 = b == a + )"); + + if (FFlag::LuauEqConstraint) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'"); + CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'"); + } +} + +// Belongs in TypeInfer.refinements.test.cpp. +// We'll need to not only report an error on `a == b`, but also to refine both operands as `never` in the `==` branch. +TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: string, b: boolean?) + if a == b then + local foo, bar = a, b + else + local foo, bar = a, b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "string"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b +} + +TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) +{ + ScopedFastInt sffi{"LuauTarjanChildLimit", 50}; + ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50}; + + CheckResult result = check(R"LUA( + local Result + Result = setmetatable({}, {}) + Result.__index = Result + function Result.new(okValue) + local self = setmetatable({}, Result) + self:constructor(okValue) + return self + end + function Result:constructor(okValue) + self.okValue = okValue + end + function Result:ok(val) return Result.new(val) end + function Result:a(p0, p1, p2, p3, p4) return Result.new((self.okValue)) or p0 or p1 or p2 or p3 or p4 end + function Result:b(p0, p1, p2, p3, p4) return Result:ok((self.okValue)) or p0 or p1 or p2 or p3 or p4 end + function Result:c(p0, p1, p2, p3, p4) return Result:ok((self.okValue)) or p0 or p1 or p2 or p3 or p4 end + function Result:transpose(a) + return a and self.okValue:z(function(some) + return Result:ok(some) + end) or Result:ok(self.okValue) + end + )LUA"); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& a) { + return nullptr != get(a); + }); + if (it == result.errors.end()) + { + dumpErrors(result); + FAIL("Expected a UnificationTooComplex error"); + } +} + +TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doctest::timeout(1.0)) +{ + ScopedFastInt sffi{"LuauTarjanChildLimit", 400}; + + CheckResult result = check(R"LUA( + --!strict + local TS = _G[script] + local lazyGet = TS.import(script, script.Parent.Parent, "util", "lazyLoad").lazyGet + local unit = TS.import(script, script.Parent.Parent, "util", "Unit").unit + local Iterator + lazyGet("Iterator", function(c) + Iterator = c + end) + local Option + lazyGet("Option", function(c) + Option = c + end) + local Vec + lazyGet("Vec", function(c) + Vec = c + end) + local Result + do + Result = setmetatable({}, { + __tostring = function() + return "Result" + end, + }) + Result.__index = Result + function Result.new(...) + local self = setmetatable({}, Result) + self:constructor(...) + return self + end + function Result:constructor(okValue, errValue) + self.okValue = okValue + self.errValue = errValue + end + function Result:ok(val) + return Result.new(val, nil) + end + function Result:err(val) + return Result.new(nil, val) + end + function Result:fromCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(result.value) or Result:err(Option:wrap(result.error)) + end + function Result:fromVoidCallback(c) + local _0 = c + local _1, _2 = pcall(_0) + local result = _1 and { + success = true, + value = _2, + } or { + success = false, + error = _2, + } + return result.success and Result:ok(unit()) or Result:err(Option:wrap(result.error)) + end + Result.fromPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + return TS.TRY_RETURN, { Result:ok(TS.await(p)) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + Result.fromVoidPromise = TS.async(function(self, p) + local _0, _1 = TS.try(function() + TS.await(p) + return TS.TRY_RETURN, { Result:ok(unit()) } + end, function(e) + return TS.TRY_RETURN, { Result:err(Option:wrap(e)) } + end) + if _0 then + return unpack(_1) + end + end) + function Result:isOk() + return self.okValue ~= nil + end + function Result:isErr() + return self.errValue ~= nil + end + function Result:contains(x) + return self.okValue == x + end + function Result:containsErr(x) + return self.errValue == x + end + function Result:okOption() + return Option:wrap(self.okValue) + end + function Result:errOption() + return Option:wrap(self.errValue) + end + function Result:map(func) + return self:isOk() and Result:ok(func(self.okValue)) or Result:err(self.errValue) + end + function Result:mapOr(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def + end + return _0 + end + function Result:mapOrElse(def, func) + local _0 + if self:isOk() then + _0 = func(self.okValue) + else + _0 = def(self.errValue) + end + return _0 + end + function Result:mapErr(func) + return self:isErr() and Result:err(func(self.errValue)) or Result:ok(self.okValue) + end + Result["and"] = function(self, other) + return self:isErr() and Result:err(self.errValue) or other + end + function Result:andThen(func) + return self:isErr() and Result:err(self.errValue) or func(self.okValue) + end + Result["or"] = function(self, other) + return self:isOk() and Result:ok(self.okValue) or other + end + function Result:orElse(other) + return self:isOk() and Result:ok(self.okValue) or other(self.errValue) + end + function Result:expect(msg) + if self:isOk() then + return self.okValue + else + error(msg) + end + end + function Result:unwrap() + return self:expect("called `Result.unwrap()` on an `Err` value: " .. tostring(self.errValue)) + end + function Result:unwrapOr(def) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = def + end + return _0 + end + function Result:unwrapOrElse(gen) + local _0 + if self:isOk() then + _0 = self.okValue + else + _0 = gen(self.errValue) + end + return _0 + end + function Result:expectErr(msg) + if self:isErr() then + return self.errValue + else + error(msg) + end + end + function Result:unwrapErr() + return self:expectErr("called `Result.unwrapErr()` on an `Ok` value: " .. tostring(self.okValue)) + end + function Result:transpose() + return self:isOk() and self.okValue:map(function(some) + return Result:ok(some) + end) or Option:some(Result:err(self.errValue)) + end + function Result:flatten() + return self:isOk() and Result.new(self.okValue.okValue, self.okValue.errValue) or Result:err(self.errValue) + end + function Result:match(ifOk, ifErr) + local _0 + if self:isOk() then + _0 = ifOk(self.okValue) + else + _0 = ifErr(self.errValue) + end + return _0 + end + function Result:asPtr() + local _0 = (self.okValue) + if _0 == nil then + _0 = (self.errValue) + end + return _0 + end + end + local resultMeta = Result + resultMeta.__eq = function(a, b) + return b:match(function(ok) + return a:contains(ok) + end, function(err) + return a:containsErr(err) + end) + end + resultMeta.__tostring = function(result) + return result:match(function(ok) + return "Result.ok(" .. tostring(ok) .. ")" + end, function(err) + return "Result.err(" .. tostring(err) .. ")" + end) + end + return { + Result = Result, + } + )LUA"); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& a) { + return nullptr != get(a); + }); + if (it == result.errors.end()) + { + dumpErrors(result); + FAIL("Expected a UnificationTooComplex error"); + } +} + +// Should be in TypeInfer.tables.test.cpp +// It's unsound to instantiate tables containing generic methods, +// since mutating properties means table properties should be invariant. +// We currently allow this but we shouldn't! +TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_call_is_unsound") +{ + CheckResult result = check(R"( + --!strict + local t = {} + function t.m(x) return x end + local a : string = t.m("hi") + local b : number = t.m(5) + function f(x : { m : (number)->number }) + x.m = function(x) return 1+x end + end + f(t) -- This shouldn't typecheck + local c : string = t.m("hi") + )"); + + // TODO: this should error! + // This should be fixed by replacing generic tables by generics with type bounds. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") +{ + ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; + ScopedFastFlag luauFollowInTypeFunApply{"LuauFollowInTypeFunApply", true}; + ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; + + // Mutability in type function application right now can create strange recursive types + // TODO: instantiation right now is problematic, it this example should either leave the Table type alone + // or it should rename the type to 'Self' so that the result will be 'Self
' + CheckResult result = check(R"( +type Table = { a: number } +type Self = T +local a: Self
+ )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a")), "Table
"); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp new file mode 100644 index 0000000..f2ba0dd --- /dev/null +++ b/tests/TypeInfer.refinements.test.cpp @@ -0,0 +1,1160 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeInfer.h" + +#include "Fixture.h" + +#include "doctest.h" + +LUAU_FASTFLAG(LuauWeakEqConstraint) +LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) +LUAU_FASTFLAG(LuauOrPredicate) + +using namespace Luau; + +TEST_SUITE_BEGIN("RefinementTest"); + +TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") +{ + CheckResult result = check(R"( + function f(v: string?) + if v then + local s = v + else + local s = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({5, 26}))); +} + +TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint") +{ + CheckResult result = check(R"( + function f(v: string?) + if not v then + local s = v + else + local s = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); +} + +TEST_CASE_FIXTURE(Fixture, "parenthesized_expressions_are_followed_through") +{ + CheckResult result = check(R"( + function f(v: string?) + if (not v) then + local s = v + else + local s = v + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 26}))); +} + +TEST_CASE_FIXTURE(Fixture, "and_constraint") +{ + CheckResult result = check(R"( + function f(a: string?, b: number?) + if a and b then + local x = a + local y = b + else + local x = a + local y = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({4, 26}))); + + CHECK_EQ("string?", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({7, 26}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_and_constraint") +{ + CheckResult result = check(R"( + function f(a: string?, b: number?) + if not (a and b) then + local x = a + local y = b + else + local x = a + local y = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); + + CHECK_EQ("string", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("number", toString(requireTypeAtPosition({7, 26}))); +} + +TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") +{ + CheckResult result = check(R"( + function f(a: string?, b: number?) + if a or b then + local x = a + local y = b + else + local x = a + local y = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); + + if (FFlag::LuauOrPredicate) + { + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); + } +} + +TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") +{ + CheckResult result = check(R"( + function g(a: number?, b: string?) + if (a :: any) and (b :: any) then + local x = a + local y = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); + CHECK_EQ("string", toString(requireTypeAtPosition({4, 26}))); +} + +TEST_CASE_FIXTURE(Fixture, "typeguard_in_if_condition_position") +{ + CheckResult result = check(R"( + function f(s: any) + if type(s) == "number" then + local n = s + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({3, 26}))); +} + +TEST_CASE_FIXTURE(Fixture, "typeguard_in_assert_position") +{ + CheckResult result = check(R"( + local a + assert(type(a) == "number") + local b = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") +{ + CheckResult result = check(R"( + type ActuallyString = string + + do -- Necessary. Otherwise toposort has ActuallyString come after string type alias. + type string = number + local foo: string = 1 + + if type(foo) == "string" then + local bar: ActuallyString = foo + local baz: boolean = foo + end + end + )"); + + if (FFlag::LuauImprovedTypeGuardPredicate2) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0])); + } +} + +TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") +{ + CheckResult result = check(R"( + local function f(x: number) + return x + end + + local function g(x: any) + if type(x) == "string" then + f(x) + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'string' could not be converted into 'number'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") +{ + // This unit test serves as a reminder to not implement this warning until Luau is intelligent enough. + // For instance, getting a value out of the indexer and checking whether the value exists is not an error. + CheckResult result = check(R"( + local t: {string} = {"a", "b", "c"} + local v = t[4] + if not v then + t[4] = "d" + else + print(v) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") +{ + + CheckResult result = check(R"( + local t: {x: number?} = {x = 1} + + if t.x then + local foo: number = t.x + end + + local bar = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") +{ + + CheckResult result = check(R"( + local t: {x: {y: string}?} = {x = {y = "hello!"}} + + if t.x then + print(t.x.y) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_constraints") +{ + CheckResult result = check(R"( + local foo: string? = "hello" + assert(foo) + local bar: string = foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") +{ + + CheckResult result = check(R"( + local t: {x: number?} = {x = nil} + + if t.x then + local u: {x: number} = t + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '{| x: number? |}' could not be converted into '{| x: number |}'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: (string | number)?, b: boolean?) + if a == b then + local foo, bar = a, b + else + local foo, bar = a, b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauWeakEqConstraint) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "(number | string)?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "boolean?"); // a == b + + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "nil"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "nil"); // a == b + + CHECK_EQ(toString(requireTypeAtPosition({5, 33})), "(number | string)?"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b + } +} + +TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: (string | number)?) + if a == 1 then + local foo = a + else + local foo = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauWeakEqConstraint) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number"); // a == 1 + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= 1 + } +} + +TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: (string | number)?) + if "hello" == a then + local foo = a + else + local foo = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauWeakEqConstraint) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "string"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" + } +} + +TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: (string | number)?) + if a ~= nil then + local foo = a + else + local foo = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauWeakEqConstraint) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a == nil + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "number | string"); // a ~= nil + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "nil"); // a == nil + } +} + +TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a, b: string?) + if a == b then + local foo, bar = a, b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauWeakEqConstraint) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string?"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b + } +} + +TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: any, b: {x: number}?) + if a ~= b then + local foo, bar = a, b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + if (FFlag::LuauWeakEqConstraint) + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}?"); // a ~= b + } + else + { + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "any"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "{| x: number |}"); // a ~= b + } +} + +TEST_CASE_FIXTURE(Fixture, "string_not_equal_to_string_or_nil") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local t: {string} = {"hello"} + + local a: string = t[1] + local b: string? = nil + if a ~= b then + local foo, bar = a, b + else + local foo, bar = a, b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireTypeAtPosition({6, 29})), "string"); // a ~= b + CHECK_EQ(toString(requireTypeAtPosition({6, 32})), "string?"); // a ~= b + + if (FFlag::LuauWeakEqConstraint) + { + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string?"); // a == b + } + else + { + // This is technically not wrong, but it's also wrong at the same time. + // The refinement code is none the wiser about the fact we pulled a string out of an array, so it has no choice but to narrow as just string. + CHECK_EQ(toString(requireTypeAtPosition({8, 29})), "string"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({8, 32})), "string"); // a == b + } +} + +TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") +{ + + CheckResult result = check(R"( + local t + local u: {x: number?} = {x = nil} + t = u + + if t.x then + local foo: number = t.x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function f(x) + if type(x) == "vector" then + local foo = x + end + end + )"); + + // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local t = {"hello"} + local v = t[2] + if type(v) == "nil" then + local foo = v + else + local foo = v + end + + if not (type(v) ~= "nil") then + local foo = v + else + local foo = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 24}))); // type(v) == "nil" + CHECK_EQ("string", toString(requireTypeAtPosition({6, 24}))); // type(v) ~= "nil" + + CHECK_EQ("nil", toString(requireTypeAtPosition({10, 24}))); // equivalent to type(v) == "nil" + CHECK_EQ("string", toString(requireTypeAtPosition({12, 24}))); // equivalent to type(v) ~= "nil" +} + +TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function f(x: string | number | boolean) + if type(x) ~= "string" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("boolean | number", toString(requireTypeAtPosition({3, 28}))); // type(x) ~= "string" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) == "string" +} + +TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function f(x: string | {x: number} | {y: boolean}) + if type(x) == "table" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: number |} | {| y: boolean |}", toString(requireTypeAtPosition({3, 28}))); // type(x) == "table" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "table" +} + +TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function weird(x: string | ((number) -> string)) + if type(x) == "function" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(number) -> string", toString(requireTypeAtPosition({3, 28}))); // type(x) == "function" + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" +} + +namespace +{ +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + if (expr.args.size != 1) + return std::nullopt; + + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; + + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; + + unfreeze(typeChecker.globalTypes); + TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); + freeze(typeChecker.globalTypes); + return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; +} + +struct RefinementClassFixture : Fixture +{ + RefinementClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); + getMutable(vec3)->props = { + {"X", Property{typeChecker.numberType}}, + {"Y", Property{typeChecker.numberType}}, + {"Z", Property{typeChecker.numberType}}, + }; + + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); + + TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); + TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); + getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + + getMutable(inst)->props = { + {"Name", Property{typeChecker.stringType}}, + {"IsA", Property{isA}}, + }; + + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); + getMutable(part)->props = { + {"Position", Property{vec3}}, + }; + + typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + freeze(typeChecker.globalTypes); + } +}; +} // namespace + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function f(vec) + local X, Y, Z = vec.X, vec.Y, vec.Z + + if type(vec) == "vector" then + local foo = vec + elseif typeof(vec) == "Instance" then + local foo = vec + else + local foo = vec + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" + + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function f(x: Instance | Vector3) + if typeof(x) == "Vector3" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function f(x: string | number | Instance | Vector3) + if type(x) == "userdata" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") +{ + ScopedFastFlag sffs[] = { + {"LuauImprovedTypeGuardPredicate2", true}, + {"LuauTypeGuardPeelsAwaySubclasses", true}, + }; + + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") +{ + ScopedFastFlag sffs[] = { + {"LuauImprovedTypeGuardPredicate2", true}, + {"LuauTypeGuardPeelsAwaySubclasses", true}, + }; + + CheckResult result = check(R"( + local function f(x: Part | Folder | Instance | string | Vector3 | any) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") +{ + ScopedFastFlag sffs[] = { + {"LuauOrPredicate", true}, + {"LuauImprovedTypeGuardPredicate2", true}, + }; + + CheckResult result = check(R"( + --!nonstrict + + local function f(x) + if typeof(x) == "Instance" and x:IsA("Folder") then + local foo = x + elseif typeof(x) == "table" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") +{ + ScopedFastFlag sffs[] = { + {"LuauOrPredicate", true}, + {"LuauImprovedTypeGuardPredicate2", true}, + {"LuauTypeGuardPeelsAwaySubclasses", true}, + }; + + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) ~= "Instance" or not x:IsA("Part") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + type XYCoord = {x: number} & {y: number} + local function f(t: XYCoord?) + if type(t) == "table" then + local foo = t + else + local foo = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: number |} & {| y: number |}", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) + local function f(g: SomeOverloadedFunction?) + if type(g) == "function" then + local foo = g + else + local foo = g + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("((number) -> string) & ((string) -> number)", toString(requireTypeAtPosition({4, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") +{ + ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; + + CheckResult result = check(R"( + local function f(t: {x: number}) + if type(t) ~= "table" then + local foo = t + error(("Expected a table, got %s"):format(type(t))) + end + + return t.x + 1 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if (not a) or (not b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if not (a and b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number?", toString(requireTypeAtPosition({4, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if (not a) and (not b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + CheckResult result = check(R"( + local function f(a: number?, b: number?) + if not (a or b) then + local foo = a + local bar = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({4, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "either_number_or_string") +{ + ScopedFastFlag sffs[] = { + {"LuauOrPredicate", true}, + {"LuauImprovedTypeGuardPredicate2", true}, + }; + + CheckResult result = check(R"( + local function f(x: any) + if type(x) == "number" or type(x) == "string" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") +{ + ScopedFastFlag sff{"LuauOrPredicate", true}; + + CheckResult result = check(R"( + local function f(t: {x: boolean}?) + if not t or t.x then + local foo = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: boolean |}?", toString(requireTypeAtPosition({3, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") +{ + ScopedFastFlag sffs[] = { + {"LuauOrPredicate", true}, + {"LuauImprovedTypeGuardPredicate2", true}, + }; + + CheckResult result = check(R"( + local a: (number | string)? + assert(a) + local b = a + assert(type(a) == "number") + local c = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number | string", toString(requireTypeAtPosition({3, 18}))); + CHECK_EQ("number", toString(requireTypeAtPosition({5, 18}))); +} + +TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") +{ + ScopedFastFlag sffs[] = { + {"LuauOrPredicate", true}, + {"LuauImprovedTypeGuardPredicate2", true}, + }; + + // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. + CheckResult result = check(R"( + local function f(b: string | { x: string }, a) + assert(type(a) == "string") + assert(type(b) == "string" or type(b) == "table") + + if type(b) == "string" then + local foo = b + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({6, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") +{ + ScopedFastFlag sffs[] = { + {"LuauOrPredicate", true}, + {"LuauImprovedTypeGuardPredicate2", true}, + }; + + CheckResult result = check(R"( + local function f(a: string | number | boolean) + if type(a) ~= "number" and type(a) ~= "string" then + local foo = a + else + local foo = a + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("boolean", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + CheckResult result = check(R"( + function f(v:string?) + return if v then v else tostring(v) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({2, 29}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 45}))); +} + +TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + CheckResult result = check(R"( + function f(v:string?) + return if not v then tostring(v) else v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("nil", toString(requireTypeAtPosition({2, 42}))); + CHECK_EQ("string", toString(requireTypeAtPosition({2, 50}))); +} + +TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + CheckResult result = check(R"( + function returnOne(x) + return 1 + end + + function f(v:any) + return if typeof(v) == "number" then v else returnOne(v) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({6, 49}))); + CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp new file mode 100644 index 0000000..1d1b2fa --- /dev/null +++ b/tests/TypeInfer.tables.test.cpp @@ -0,0 +1,1827 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +#include + +using namespace Luau; + +TEST_SUITE_BEGIN("TableTests"); + +TEST_CASE_FIXTURE(Fixture, "basic") +{ + CheckResult result = check("local t = {foo = \"bar\", baz = 9, quux = nil}"); + LUAU_REQUIRE_NO_ERRORS(result); + + const TableTypeVar* tType = get(requireType("t")); + REQUIRE(tType != nullptr); + + std::optional fooProp = get(tType->props, "foo"); + REQUIRE(bool(fooProp)); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(fooProp->type)); + + std::optional bazProp = get(tType->props, "baz"); + REQUIRE(bool(bazProp)); + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(bazProp->type)); + + std::optional quuxProp = get(tType->props, "quux"); + REQUIRE(bool(quuxProp)); + CHECK_EQ(PrimitiveTypeVar::NilType, getPrimitiveType(quuxProp->type)); +} + +TEST_CASE_FIXTURE(Fixture, "augment_table") +{ + CheckResult result = check("local t = {} t.foo = 'bar'"); + LUAU_REQUIRE_NO_ERRORS(result); + + const TableTypeVar* tType = get(requireType("t")); + REQUIRE(tType != nullptr); + + CHECK(tType->props.find("foo") != tType->props.end()); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") +{ + CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeError& err = result.errors[0]; + CannotExtendTable* error = get(err); + REQUIRE(error != nullptr); + + // TODO: better, more robust comparison of type vars + auto s = toString(error->tableType, ToStringOptions{/*exhaustive*/ true}); + CHECK_EQ(s, "{| prop: number |}"); + CHECK_EQ(error->prop, "foo"); + CHECK_EQ(error->context, CannotExtendTable::Property); + CHECK_EQ(err.location, (Location{Position{0, 24}, Position{0, 29}})); +} + +TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") +{ + CheckResult result = check(R"( + type T = {[number]: number} + function f(arg: T) end + + local B = {} + f(B) + function B:method() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "updating_sealed_table_prop_is_ok") +{ + CheckResult result = check("local t = {prop=999} t.prop = 0"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_unsealed_table_prop") +{ + CheckResult result = check("local t = {} t.prop = 999 t.prop = 'hello'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_change_type_of_table_prop") +{ + CheckResult result = check("local t = {prop=999} t.prop = 'hello'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "function_calls_can_produce_tables") +{ + CheckResult result = check("function get_table() return {prop=999} end get_table().prop = 0"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_calls_produces_sealed_table_given_unsealed_table") +{ + CheckResult result = check(R"( + function f() return {} end + f().foo = 'fail' + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_member_function") +{ + CheckResult result = check("local T = {} function T:foo() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const TableTypeVar* tableType = get(requireType("T")); + REQUIRE(tableType != nullptr); + + std::optional fooProp = get(tableType->props, "foo"); + REQUIRE(bool(fooProp)); + + const FunctionTypeVar* methodType = get(follow(fooProp->type)); + REQUIRE(methodType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "tc_member_function_2") +{ + CheckResult result = check("local T = {U={}} function T.U:foo() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const TableTypeVar* tableType = get(requireType("T")); + REQUIRE(tableType != nullptr); + + std::optional uProp = get(tableType->props, "U"); + REQUIRE(bool(uProp)); + TypeId uType = uProp->type; + + const TableTypeVar* uTable = get(uType); + REQUIRE(uTable != nullptr); + + std::optional fooProp = get(uTable->props, "foo"); + REQUIRE(bool(fooProp)); + + const FunctionTypeVar* methodType = get(follow(fooProp->type)); + REQUIRE(methodType != nullptr); + + std::vector methodArgs = flatten(methodType->argTypes).first; + + REQUIRE_EQ(methodArgs.size(), 1); + + // TODO(rblanckaert): Revist when we can bind self at function creation time + // REQUIRE_EQ(*methodArgs[0], *uType); +} + +TEST_CASE_FIXTURE(Fixture, "call_method") +{ + CheckResult result = check("local T = {} T.x = 0 function T:method() return self.x end local a = T:method()"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "call_method_with_explicit_self_argument") +{ + CheckResult result = check("local T = {} T.x = 0 function T:method() return self.x end local a = T.method(T)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon") +{ + CheckResult result = check(R"( + local T = {} + T.x = 0 + function T:method() + return self.x + end + local a = T.method() + )"); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& e) { + return nullptr != get(e); + }); + REQUIRE(it != result.errors.end()); +} + +TEST_CASE_FIXTURE(Fixture, "used_colon_correctly") +{ + CheckResult result = check(R"( + --!nonstrict + local upVector = {} + function upVector:Dot(lookVector) + return 8 + end + local v = math.abs(upVector:Dot(5)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "used_dot_instead_of_colon_but_correctly") +{ + CheckResult result = check(R"( + local T = {} + T.x = 0 + function T:method(arg1, arg2) + return self.x + end + local a = T.method(T, 6, 7) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "used_colon_instead_of_dot") +{ + CheckResult result = check(R"( + local T = {} + T.x = 0 + function T.method() + return 5 + end + local a = T:method() + )"); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](const TypeError& e) { + return nullptr != get(e); + }); + REQUIRE(it != result.errors.end()); +} + +#if 0 +TEST_CASE_FIXTURE(Fixture, "open_table_unification") +{ + CheckResult result = check(R"( + function foo(o) + print(o.foo) + print(o.bar) + end + + local a = {} + a.foo = 9 + + local b = {} + b.foo = 0 + + if random() then + b = a + end + + b.bar = '99' + + foo(a) + foo(b) + )"); + LUAU_REQUIRE_NO_ERRORS(result); +} +#endif + +TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") +{ + CheckResult result = check(R"( + local a = {} + a.x = 99 + + function a:method() + return self.y + end + a:method() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeError& err = result.errors[0]; + UnknownProperty* error = get(err); + REQUIRE(error != nullptr); + + CHECK_EQ(error->key, "y"); + // TODO(rblanckaert): Revist when we can bind self at function creation time + // CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25})); + + CHECK_EQ(err.location, Location(Position{7, 8}, Position{7, 9})); +} + +TEST_CASE_FIXTURE(Fixture, "open_table_unification_3") +{ + CheckResult result = check(R"( + function id(x) + return x + end + + function foo(o) + id(o.bar) + id(o.baz) + end + )"); + + TypeId fooType = requireType("foo"); + const FunctionTypeVar* fooFn = get(fooType); + REQUIRE(fooFn != nullptr); + + std::vector fooArgs = flatten(fooFn->argTypes).first; + + REQUIRE_EQ(1, fooArgs.size()); + + TypeId arg0 = fooArgs[0]; + const TableTypeVar* arg0Table = get(follow(arg0)); + REQUIRE(arg0Table != nullptr); + + REQUIRE(arg0Table->props.find("bar") != arg0Table->props.end()); + REQUIRE(arg0Table->props.find("baz") != arg0Table->props.end()); +} + +TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") +{ + CheckResult result = check(R"( + function foo(o) + local a = o.x + local b = o.y + return o + end + + foo({x=55, y=nil, w=3.14159}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") +{ + CheckResult result = check(R"( + --!strict + function foo(o) + local a = o.bar + local b = o.baz + end + + foo({bar='bar'}) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* error = get(result.errors[0]); + REQUIRE(error != nullptr); + + CHECK_EQ("baz", error->key); +} + +TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") +{ + CheckResult result = check(R"( + local T = {} + T.bar = 'hello' + function T:method() + local a = self.baz + end + T:method() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeError& err = result.errors[0]; + UnknownProperty* error = get(err); + REQUIRE(error != nullptr); + + // TODO(rblanckaert): Revist when we can bind self at function creation time + /* + CHECK_EQ(err->location, + (Location{ Position{4, 22}, Position{4, 30} }) + ); + */ + + CHECK_EQ(err.location, (Location{Position{6, 8}, Position{6, 9}})); +} + +#if 0 +TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") +{ + CheckResult result = check(R"( + function id(x) + return x + end + + function foo(o) + id(o.x) + id(o.y) + return o + end + + local a = {x=55, y=nil, w=3.14159} + local b = {} + b.x = 1 + b.y = 'hello' + b.z = 'something extra!' + + local q = foo(a) -- line 17 + local w = foo(b) -- line 18 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + for (const auto& e : result.errors) + std::cout << "Error: " << e << std::endl; + + TypeId qType = requireType("q"); + const TableTypeVar* qTable = get(qType); + REQUIRE(qType != nullptr); + + CHECK(qTable->props.find("x") != qTable->props.end()); + CHECK(qTable->props.find("y") != qTable->props.end()); + CHECK(qTable->props.find("z") == qTable->props.end()); + CHECK(qTable->props.find("w") != qTable->props.end()); + + TypeId wType = requireType("w"); + const TableTypeVar* wTable = get(wType); + REQUIRE(wTable != nullptr); + + CHECK(wTable->props.find("x") != wTable->props.end()); + CHECK(wTable->props.find("y") != wTable->props.end()); + CHECK(wTable->props.find("z") != wTable->props.end()); + CHECK(wTable->props.find("w") == wTable->props.end()); +} +#endif + +TEST_CASE_FIXTURE(Fixture, "table_unification_4") +{ + CheckResult result = check(R"( + function foo(o) + if o.prop then + return o + else + return {prop=false} + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") +{ + CheckResult result = check(R"( + function fn(d) + d:Method() + d.prop = true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_array") +{ + CheckResult result = check(R"( + local t = {} + t[1] = 'one' + t[2] = 'two' + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const TableTypeVar* ttv = get(requireType("t")); + REQUIRE(ttv != nullptr); + + REQUIRE(bool(ttv->indexer)); + + CHECK_EQ(*ttv->indexer->indexType, *typeChecker.numberType); + CHECK_EQ(*ttv->indexer->indexResultType, *typeChecker.stringType); +} + +/* This is a bit weird. + * The type of buttonVector[i] is initially free, compared to a string with == + * We can't actually use this to infer that buttonVector is {string}, and we + * also have a rule that forbids comparing unknown types with those that may have + * metatables. + * + * Due to a historical quirk, strings are exempt from this rule. Without this exemption, + * the test code here would fail to typecheck at the use of ==. + */ +TEST_CASE_FIXTURE(Fixture, "infer_array_2") +{ + CheckResult result = check(R"( + local buttonVector = {} + + function createButton( actionName, functionInfoTable ) + local position = nil + for i = 1,#buttonVector do + if buttonVector[i] == "empty" then + position = i + break + end + end + return position + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") +{ + CheckResult result = check(R"( + function swap(p) + local temp = p[0] + p[0] = p[1] + p[1] = temp + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("swap")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(1, argVec.size()); + + const TableTypeVar* ttv = get(follow(argVec[0])); + REQUIRE(ttv != nullptr); + + REQUIRE(bool(ttv->indexer)); + + const TableIndexer& indexer = *ttv->indexer; + + REQUIRE_EQ(indexer.indexType, typeChecker.numberType); + + REQUIRE(nullptr != get(indexer.indexResultType)); +} + +TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") +{ + CheckResult result = check(R"( + function mergesort(arr) + local p = arr[0] + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("mergesort")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(1, argVec.size()); + + const TableTypeVar* argType = get(follow(argVec[0])); + REQUIRE(argType != nullptr); + + std::vector retVec = flatten(ftv->retType).first; + + const TableTypeVar* retType = get(follow(retVec[0])); + REQUIRE(retType != nullptr); + + CHECK_EQ(argType->state, retType->state); + + REQUIRE_EQ(*argVec[0], *retVec[0]); +} + +TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_array_like_table") +{ + CheckResult result = check(R"( + local t = {"one", "two", "three"} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const TableTypeVar* ttv = get(requireType("t")); + REQUIRE(ttv != nullptr); + + REQUIRE(bool(ttv->indexer)); + const TableIndexer& indexer = *ttv->indexer; + + CHECK_EQ(*typeChecker.numberType, *indexer.indexType); + CHECK_EQ(*typeChecker.stringType, *indexer.indexResultType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") +{ + CheckResult result = check(R"( + function Symbol(n) + return { __name=n } + end + + function f() + return { + [Symbol("hello")] = true, + x = 0, + y = 0 + } + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* fType = get(requireType("f")); + REQUIRE(fType != nullptr); + + auto retType_ = first(fType->retType); + REQUIRE(bool(retType_)); + + auto retType = get(follow(*retType_)); + REQUIRE(retType != nullptr); + + CHECK(bool(retType->indexer)); + + const TableIndexer& indexer = *retType->indexer; + CHECK_EQ("{| __name: string |}", toString(indexer.indexType)); +} + +TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_variable_type_and_unifiable") +{ + CheckResult result = check(R"( + local t1: { [string]: string } = {} + local t2 = { "bar" } + + t2 = t1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm != nullptr); + + const TableTypeVar* tTy = get(requireType("t2")); + REQUIRE(tTy != nullptr); + + REQUIRE(tTy->indexer); + CHECK_EQ(*typeChecker.numberType, *tTy->indexer->indexType); + CHECK_EQ(*typeChecker.stringType, *tTy->indexer->indexResultType); +} + +TEST_CASE_FIXTURE(Fixture, "indexer_mismatch") +{ + CheckResult result = check(R"( + local t1: { [string]: string } = {} + local t2: { [number]: number } = {} + + t2 = t1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeId t1 = requireType("t1"); + TypeId t2 = requireType("t2"); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm != nullptr); + CHECK_EQ(tm->wantedType, t2); + CHECK_EQ(tm->givenType, t1); + + CHECK_NE(*t1, *t2); +} + +TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_function_return_type") +{ + CheckResult result = check(R"( + local function f(): { [number]: string } + return {} + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_hand_table_with_indexer") +{ + CheckResult result = check(R"( + local function f(): { [number]: string } return {} end + + local t = {} + t = f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer") +{ + CheckResult result = check(R"( + local t: { a: string, [number]: string } = { a = "foo" } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify") +{ + CheckResult result = check(R"( + local A = { 5, 7, 8 } + local B = { "one", "two", "three" } + + B = A + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); +} + +TEST_CASE_FIXTURE(Fixture, "indexer_on_sealed_table_must_unify_with_free_table") +{ + CheckResult result = check(R"( + local A = { 1, 2, 3 } + function F(t) + t[4] = "hi" + A = t + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_when_indexing_from_a_table_indexer") +{ + CheckResult result = check(R"( + local t: { [number]: string } + local s = t[1] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_possible") +{ + CheckResult result = check(R"( + local t: { a: string, [string]: number } + local a1 = t.a + local a2 = t["a"] + + local b1 = t.b + local b2 = t["b"] + + local some_indirection_variable = "foo" + local c = t[some_indirection_variable] + + local d = t[1] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(*typeChecker.stringType, *requireType("a1")); + CHECK_EQ(*typeChecker.stringType, *requireType("a2")); + + CHECK_EQ(*typeChecker.numberType, *requireType("b1")); + CHECK_EQ(*typeChecker.numberType, *requireType("b2")); + + CHECK_EQ(*typeChecker.numberType, *requireType("c")); + + CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string") +{ + ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); + + CheckResult result = check(R"( + local t: { a: string } + function f(x: string) return t[x] end + local a = f("a") + local b = f("b") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.anyType, *requireType("a")); + CHECK_EQ(*typeChecker.anyType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number") +{ + ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); + + CheckResult result = check(R"( + local t = { a = true } + function f(x: number) return t[x] end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); +} + +TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") +{ + CheckResult result = check(R"( + local t = {} + t["a"] = "foo" + + local a = t.a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("a")); + + TableTypeVar* tableType = getMutable(requireType("t")); + REQUIRE(tableType != nullptr); + REQUIRE(tableType->indexer == std::nullopt); + + TypeId propertyA = tableType->props["a"].type; + REQUIRE(propertyA != nullptr); + CHECK_EQ(*typeChecker.stringType, *propertyA); +} + +TEST_CASE_FIXTURE(Fixture, "oop_indexer_works") +{ + CheckResult result = check(R"( + local clazz = {} + clazz.__index = clazz + + function clazz:speak() + return "hi" + end + + function clazz.new() + return setmetatable({}, clazz) + end + + local me = clazz.new() + local words = me:speak() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("words")); +} + +TEST_CASE_FIXTURE(Fixture, "indexer_table") +{ + CheckResult result = check(R"( + local clazz = {a="hello"} + local instanace = setmetatable({}, {__index=clazz}) + local b = instanace.a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "indexer_fn") +{ + CheckResult result = check(R"( + local instanace = setmetatable({}, {__index=function() return 10 end}) + local b = instanace.somemethodwedonthave + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "meta_add") +{ + // Note: meta_add_inferred and this unit test are currently the same exact thing. + // We'll want to change this one in particular when we add real syntax for metatables. + + CheckResult result = check(R"( + local a = setmetatable({}, {__add = function(l, r) return l end}) + type Vector = typeof(a) + local b:Vector + local c = a + b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(follow(requireType("a")), follow(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "meta_add_inferred") +{ + CheckResult result = check(R"( + local a = {} + setmetatable(a, {__add=function(a,b) return b end} ) + local c = a + a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*requireType("a"), *requireType("c")); +} + +TEST_CASE_FIXTURE(Fixture, "meta_add_both_ways") +{ + CheckResult result = check(R"( + type VectorMt = { __add: (Vector, number) -> Vector } + local vectorMt: VectorMt + type Vector = typeof(setmetatable({}, vectorMt)) + local a: Vector + + local b = a + 2 + local c = 2 + a + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector", toString(requireType("a"))); + CHECK_EQ(*requireType("a"), *requireType("b")); + CHECK_EQ(*requireType("a"), *requireType("c")); +} + +// This test exposed a bug where we let go of the "seen" stack while unifying table types +// As a result, type inference crashed with a stack overflow. +TEST_CASE_FIXTURE(Fixture, "unification_of_unions_in_a_self_referential_type") +{ + CheckResult result = check(R"( + type A = {} + type AMT = { __mul: (A, A | number) -> A } + local a: A + local amt: AMT + setmetatable(a, amt) + + type B = {} + type BMT = { __mul: (B, A | B | number) -> A } + local b: B + local bmt: BMT + setmetatable(b, bmt) + + a = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const MetatableTypeVar* amtv = get(requireType("a")); + REQUIRE(amtv); + CHECK_EQ(amtv->metatable, requireType("amt")); + + const MetatableTypeVar* bmtv = get(requireType("b")); + REQUIRE(bmtv); + CHECK_EQ(bmtv->metatable, requireType("bmt")); +} + +TEST_CASE_FIXTURE(Fixture, "oop_polymorphic") +{ + CheckResult result = check(R"( + local animal = {} + animal.__index = animal + function animal:isAlive() return true end + function animal:speed() return 10 end + + local pelican = {} + setmetatable(pelican, animal) + pelican.__index = pelican + function pelican:movement() return "fly" end + function pelican:speed() return 30 end + + function pelican.new(name) + local s = {} + setmetatable(s, pelican) + s.name = name + return s + end + + local scoops = pelican.new("scoops") + + local alive = scoops:isAlive() + local at = scoops.isAlive + local movement = scoops:movement() + local name = scoops.name + local speed = scoops:speed() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.booleanType, *requireType("alive")); + CHECK_EQ(*typeChecker.stringType, *requireType("movement")); + CHECK_EQ(*typeChecker.stringType, *requireType("name")); + CHECK_EQ(*typeChecker.numberType, *requireType("speed")); +} + +TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") +{ + CheckResult result = check(R"( + type Vector3 = {x: number, y: number} + + local v: Vector3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector3", toString(requireType("v"))); +} + +TEST_CASE_FIXTURE(Fixture, "result_is_always_any_if_lhs_is_any") +{ + CheckResult result = check(R"( + type Vector3MT = { + __add: (Vector3MT, Vector3MT) -> Vector3MT, + __mul: (Vector3MT, Vector3MT|number) -> Vector3MT + } + + local Vector3: {new: (number?, number?, number?) -> Vector3MT} + local Vector3MT: Vector3MT + setmetatable(Vector3, Vector3MT) + + type CFrameMT = { + __mul: (CFrameMT, Vector3MT|CFrameMT) -> Vector3MT|CFrameMT + } + + local CFrame: { + Angles:(number, number, number) -> CFrameMT + } + local CFrameMT: CFrameMT + setmetatable(CFrame, CFrameMT) + + local n: any + local a = (n + Vector3.new(0, 1.5, 0)) * CFrame.Angles(0, math.pi/2, 0) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "result_is_bool_for_equality_operators_if_lhs_is_any") +{ + CheckResult result = check(R"( + local a: any + local b: number + + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("boolean", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "inequality_operators_imply_exactly_matching_types") +{ + CheckResult result = check(R"( + function abs(n) + if n < 0 then + return -n + else + return n + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(number) -> number", toString(requireType("abs"))); +} + +TEST_CASE_FIXTURE(Fixture, "nice_error_when_trying_to_fetch_property_of_boolean") +{ + CheckResult result = check(R"( + local a = true + local b = a.some_prop + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'boolean' does not have key 'some_prop'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_builtin_sealed_table_must_fail") +{ + CheckResult result = check(R"( + function string.m() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_must_fail") +{ + CheckResult result = check(R"( + function string:m() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail") +{ + CheckResult result = check(R"( + local t = {x = 1} + function t.m() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail") +{ + CheckResult result = check(R"( + local t = {x = 1} + function t:m() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +// This unit test could be flaky if the fix has regressed. +TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing") +{ + CheckResult result = check(R"( + -- must be in this specific order, and with (roughly) those exact properties! + type A = {x: number, [any]: any} | {} + local a: A + + function f(t) + t.y = 1 + end + + f(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); +} + +// This unit test could be flaky if the fix has regressed. +TEST_CASE_FIXTURE(Fixture, "passing_compatible_unions_to_a_generic_table_without_crashing") +{ + CheckResult result = check(R"( + type A = {x: number, y: number, [any]: any} | {y: number} + local a: A + + function f(t) + t.y = 1 + end + + f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_function_call") +{ + CheckResult result = check(R"( + local t = {} + function t.Foo() end + + t.fOo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeError te = result.errors[0]; + UnknownPropButFoundLikeProp* error = get(te); + REQUIRE(error); + + TypeId t = requireType("t"); + CHECK_EQ(*t, *error->table); + CHECK_EQ("fOo", error->key); + + auto candidates = error->candidates; + CHECK_EQ(1, candidates.size()); + CHECK(candidates.find("Foo") != candidates.end()); + + CHECK_EQ(toString(te), "Key 'fOo' not found in table 't'. Did you mean 'Foo'?"); +} + +TEST_CASE_FIXTURE(Fixture, "found_like_key_in_table_property_access") +{ + CheckResult result = check(R"( + local t = {X = 1} + + print(t.x) + )"); + + REQUIRE_EQ(result.errors.size(), 1); + + TypeError te = result.errors[0]; + UnknownPropButFoundLikeProp* error = get(te); + REQUIRE(error); + + TypeId t = requireType("t"); + CHECK_EQ(*t, *error->table); + CHECK_EQ("x", error->key); + + auto candidates = error->candidates; + CHECK_EQ(1, candidates.size()); + CHECK(candidates.find("X") != candidates.end()); + + CHECK_EQ(toString(te), "Key 'x' not found in table 't'. Did you mean 'X'?"); +} + +TEST_CASE_FIXTURE(Fixture, "found_multiple_like_keys") +{ + CheckResult result = check(R"( + local t = {Foo = 1, foO = 2} + + print(t.foo) + )"); + + REQUIRE_EQ(result.errors.size(), 1); + + TypeError te = result.errors[0]; + UnknownPropButFoundLikeProp* error = get(te); + REQUIRE(error); + + TypeId t = requireType("t"); + CHECK_EQ(*t, *error->table); + CHECK_EQ("foo", error->key); + + auto candidates = error->candidates; + CHECK_EQ(2, candidates.size()); + CHECK(candidates.find("Foo") != candidates.end()); + CHECK(candidates.find("foO") != candidates.end()); + + CHECK_EQ(toString(te), "Key 'foo' not found in table 't'. Did you mean one of 'Foo', 'foO'?"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_exact_match_keys") +{ + CheckResult result = check(R"( + local t = {} + t.foO = 1 + print(t.Foo) + t.Foo = 2 + )"); + + REQUIRE_EQ(result.errors.size(), 1); + + TypeError te = result.errors[0]; + UnknownPropButFoundLikeProp* error = get(te); + REQUIRE(error); + + TypeId t = requireType("t"); + CHECK_EQ(*t, *error->table); + CHECK_EQ("Foo", error->key); + + auto candidates = error->candidates; + CHECK_EQ(1, candidates.size()); + CHECK(candidates.find("foO") != candidates.end()); + CHECK(candidates.find("Foo") == candidates.end()); + + CHECK_EQ(toString(te), "Key 'Foo' not found in table 't'. Did you mean 'foO'?"); +} + +TEST_CASE_FIXTURE(Fixture, "getmetatable_returns_pointer_to_metatable") +{ + CheckResult result = check(R"( + local t = {x = 1} + local mt = {__index = {y = 2}} + setmetatable(t, mt) + + local returnedMT = getmetatable(t) + )"); + + CHECK_EQ(*requireType("mt"), *requireType("returnedMT")); +} + +TEST_CASE_FIXTURE(Fixture, "metatable_mismatch_should_fail") +{ + CheckResult result = check(R"( + local t1 = {x = 1} + local mt1 = {__index = {y = 2}} + setmetatable(t1, mt1) + + local t2 = {x = 1} + local mt2 = {__index = function() return nil end} + setmetatable(t2, mt2) + + t1 = t2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(*tm->wantedType, *requireType("t1")); + CHECK_EQ(*tm->givenType, *requireType("t2")); +} + +TEST_CASE_FIXTURE(Fixture, "property_lookup_through_tabletypevar_metatable") +{ + CheckResult result = check(R"( + local t = {x = 1} + local mt = {__index = {y = 2}} + setmetatable(t, mt) + + print(t.x) + print(t.y) + print(t.z) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* up = get(result.errors[0]); + REQUIRE_MESSAGE(up, result.errors[0].data); + CHECK_EQ(up->key, "z"); +} + +TEST_CASE_FIXTURE(Fixture, "missing_metatable_for_sealed_tables_do_not_get_inferred") +{ + CheckResult result = check(R"( + local t = {x = 1} + + local a = {x = 1} + local b = {__index = {y = 2}} + setmetatable(a, b) + + t = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeId a = requireType("a"); + TypeId t = requireType("t"); + CHECK_NE(*a, *t); + + TypeError te = result.errors[0]; + TypeMismatch* tm = get(te); + REQUIRE(tm); + CHECK_EQ(tm->wantedType, t); + CHECK_EQ(tm->givenType, a); + + const MetatableTypeVar* aTy = get(a); + REQUIRE(aTy); + + const TableTypeVar* tTy = get(t); + REQUIRE(tTy); +} + +// Could be flaky if the fix has regressed. +TEST_CASE_FIXTURE(Fixture, "right_table_missing_key") +{ + CheckResult result = check(R"( + function _(...) + end + local l7 = not _,function(l0) + _ += _((_) or {function(...) + end,["z"]=_,} or {},(function(l43,...) + end)) + _ += 0 < {} + end + repeat + until _ + local l0 = n4,_((_) or {} or {[30976]=_,},({})) + )"); + + CHECK_GE(result.errors.size(), 0); +} + +// Could be flaky if the fix has regressed. +TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") +{ + CheckResult result = check(R"( + local lt: { [string]: string, a: string } + local rt: {} + + lt = rt + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + MissingProperties* mp = get(result.errors[0]); + REQUIRE(mp); + CHECK_EQ(mp->context, MissingProperties::Missing); + REQUIRE_EQ(1, mp->properties.size()); + CHECK_EQ(mp->properties[0], "a"); + + CHECK_EQ("{| [string]: string, a: string |}", toString(mp->superType)); + CHECK_EQ("{| |}", toString(mp->subType)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") +{ + CheckResult result = check(R"( + type StringToStringMap = { [string]: string } + local rt: StringToStringMap = { ["foo"] = 1 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + ToStringOptions o{/* exhaustive= */ true}; + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{| foo: number |}", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") +{ + CheckResult result = check(R"( + local function foo(a: {[string]: number, a: string}) end + foo({ a = "" }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") +{ + CheckResult result = check(R"( + local function foo(a: {[string]: number, a: string}) end + foo({ a = 1 }) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + ToStringOptions o{/* exhaustive= */ true}; + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("string", toString(tm->wantedType, o)); + CHECK_EQ("number", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") +{ + CheckResult result = check(R"( + local vec3 = {x = 1, y = 2, z = 3} + local vec1 = {x = 1} + + vec3 = vec1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + MissingProperties* mp = get(result.errors[0]); + REQUIRE(mp); + CHECK_EQ(mp->context, MissingProperties::Missing); + REQUIRE_EQ(2, mp->properties.size()); + CHECK_EQ(mp->properties[0], "y"); + CHECK_EQ(mp->properties[1], "z"); + CHECK_EQ("vec3", toString(mp->superType)); + CHECK_EQ("vec1", toString(mp->subType)); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors2") +{ + CheckResult result = check(R"( + type DumbMixedTable = {[number]: number, x: number} + local t: DumbMixedTable = {"fail"} + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + MissingProperties* mp = get(result.errors[1]); + REQUIRE(mp); + CHECK_EQ(mp->context, MissingProperties::Missing); + REQUIRE_EQ(1, mp->properties.size()); + CHECK_EQ(mp->properties[0], "x"); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") +{ + CheckResult result = check(R"( + local vec3 = {x = 1, y = 2, z = 3} + local vec1 = {x = 1} + + vec1 = vec3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + MissingProperties* mp = get(result.errors[0]); + REQUIRE(mp); + CHECK_EQ(mp->context, MissingProperties::Extra); + REQUIRE_EQ(2, mp->properties.size()); + CHECK_EQ(mp->properties[0], "y"); + CHECK_EQ(mp->properties[1], "z"); + CHECK_EQ("vec1", toString(mp->superType)); + CHECK_EQ("vec3", toString(mp->subType)); +} + +TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") +{ + ScopedFastInt sfis{"LuauTableTypeMaximumStringifierLength", 40}; + + CheckResult result = check(R"( + local t + t = {} + t.a = 1 + t.b = 1 + t.c = 1 + t.d = 1 + t.e = 1 + t.f = 1 + + t = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(requireType("t"), tm->wantedType); + CHECK_EQ("number", toString(tm->givenType)); + + CHECK_EQ("Type 'number' could not be converted into '{ a: number, b: number, c: number, d: number, e: number, ... 1 more ... }'", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "ok_to_set_nil_even_on_non_lvalue_base_expr") +{ + CheckResult result = check(R"( + local function f(): { [string]: number } + return { ["foo"] = 1 } + end + + f()["foo"] = nil + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "ok_to_provide_a_subtype_during_construction") +{ + CheckResult result = check(R"( + local a: string | number = 1 + local t = {a, 1} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("{number | string}", toString(requireType("t"), {/*exhaustive*/ true})); +} + +TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_to_an_array_like_table") +{ + CheckResult result = check(R"( + --!strict + local A = {"value"} + A.B = "Hello" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* up = get(result.errors[0]); + REQUIRE(up != nullptr); + + CHECK_EQ("B", up->key); +} + +TEST_CASE_FIXTURE(Fixture, "shorter_array_types_actually_work") +{ + CheckResult result = check(R"( + --!strict + local A: {string | number} + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + CHECK_EQ("{number | string}", toString(requireType("A"))); +} + +TEST_CASE_FIXTURE(Fixture, "only_ascribe_synthetic_names_at_module_scope") +{ + CheckResult result = check(R"( + --!strict + local TopLevel = {} + local foo + + for i = 1, 10 do + local SubScope = { 1, 2, 3 } + foo = SubScope + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + + CHECK_EQ("TopLevel", toString(requireType("TopLevel"))); + CHECK_EQ("{number}", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") +{ + CheckResult result = check(R"( + --!strict + + local function f() + local t = { x = 1 } + + function t.a() end + function t.b() end + + return t + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ("Cannot add property 'a' to table '{| x: number |}'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'b' to table '{| x: number |}'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "builtin_table_names") +{ + CheckResult result = check(R"( + os.h = 2 + string.k = 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ("Cannot add property 'h' to table 'os'", toString(result.errors[0])); + CHECK_EQ("Cannot add property 'k' to table 'string'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "persistent_sealed_table_is_immutable") +{ + CheckResult result = check(R"( + --!nonstrict + function os:bad() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Cannot add property 'bad' to table 'os'", toString(result.errors[0])); + + const TableTypeVar* osType = get(requireType("os")); + REQUIRE(osType != nullptr); + CHECK(osType->props.find("bad") == osType->props.end()); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_list") +{ + CheckResult result = check(R"( +type Table = { + a: number, + b: number? +} + +local Test: {Table} = { + { a = 1 }, + { a = 2, b = 3 } +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_general") +{ + CheckResult result = check(R"( +type Table = { + a: number, + b: number? +} + +local Test: {Table} = { + [2] = { a = 1 }, + [5] = { a = 2, b = 3 } +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_inner_index") +{ + CheckResult result = check(R"( +type Table = { + a: number, + b: number? +} + +local Test: {{Table}} = {{ + { a = 1 }, + { a = 2, b = 3 } +},{ + { a = 3 }, + { a = 4, b = 3 } +}} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_inner_prop") +{ + CheckResult result = check(R"( +type Table = { + a: number, + b: number? +} + +local Test: {{x: Table, y: Table}} = {{ + x = { a = 1 }, + y = { a = 2, b = 3 } +},{ + x = { a = 3 }, + y = { a = 4 } +}} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_union_assignment") +{ + CheckResult result = check(R"( +type Foo = {x: number | string} + +local foos: {Foo} = { + {x = 1234567}, + {x = "hello"}, +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "quantifying_a_bound_var_works") +{ + CheckResult result = check(R"( + local clazz = {} + clazz.__index = clazz + + function clazz:speak() + return "hi" + end + + function clazz.new() + return setmetatable({}, clazz) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + TypeId ty = requireType("clazz"); + TableTypeVar* ttv = getMutable(ty); + REQUIRE(ttv); + Property& prop = ttv->props["new"]; + REQUIRE(prop.type); + const FunctionTypeVar* ftv = get(follow(prop.type)); + REQUIRE(ftv); + const TypePack* res = get(follow(ftv->retType)); + REQUIRE(res); + REQUIRE(res->head.size() == 1); + const MetatableTypeVar* mtv = get(follow(res->head[0])); + REQUIRE(mtv); + ttv = getMutable(follow(mtv->table)); + REQUIRE(ttv); + REQUIRE_EQ(ttv->state, TableState::Sealed); +} + +TEST_CASE_FIXTURE(Fixture, "less_exponential_blowup_please") +{ + CheckResult result = check(R"( + --!strict + + local Foo = setmetatable({}, {}) + Foo.__index = Foo + + function Foo.new() + local self = setmetatable({}, Foo) + return self:constructor() or self + end + function Foo:constructor() end + + function Foo:create() + local foo = Foo.new() + foo:First() + foo:Second() + foo:Third() + return foo + end + function Foo:First() end + function Foo:Second() end + function Foo:Third() end + + local newData = Foo:create() + newData:First() + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call") +{ + CheckResult result = check(R"( +local function foo(l: {{x: number | string}}) end + +foo({ + {x = 1234567}, + {x = "hello"}, +}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_call_tail") +{ + CheckResult result = check(R"( +type Foo = {x: number | string} +local function foo(l: {Foo}, ...: {Foo}) end + +foo({{x = 1234567}, {x = "hello"}}, {{x = 1234567}, {x = "hello"}}, {{x = 1234567}, {x = "hello"}}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "common_table_element_union_in_prop") +{ + CheckResult result = check(R"( +type Foo = {x: number | string} +local t: { a: {Foo}, b: number } = { + a = { + {x = 1234567}, + {x = "hello"}, + }, + b = 5 +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// It's unsound to instantiate tables containing generic methods, +// since mutating properties means table properties should be invariant. +TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound") +{ + CheckResult result = check(R"( + --!strict + local t = {} + function t.m(x) return x end + local a : string = t.m("hi") + local b : number = t.m(5) + local u : { m : (number)->number } = t -- This shouldn't typecheck + u.m = function(x) return 1+x end + local c : string = t.m("hi") + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp new file mode 100644 index 0000000..37333b1 --- /dev/null +++ b/tests/TypeInfer.test.cpp @@ -0,0 +1,5306 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +#include + +LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) +LUAU_FASTFLAG(LuauEqConstraint) + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInfer"); + +TEST_CASE_FIXTURE(Fixture, "tc_hello_world") +{ + CheckResult result = check("local a = 7"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::Number); +} + +TEST_CASE_FIXTURE(Fixture, "tc_propagation") +{ + CheckResult result = check("local a = 7 local b = a"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId bType = requireType("b"); + CHECK_EQ(getPrimitiveType(bType), PrimitiveTypeVar::Number); +} + +TEST_CASE_FIXTURE(Fixture, "tc_error") +{ + CheckResult result = check("local a = 7 local b = 'hi' a = b"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{ + requireType("a"), + requireType("b"), + }})); +} + +TEST_CASE_FIXTURE(Fixture, "tc_error_2") +{ + CheckResult result = check("local a = 7 a = 'hi'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 18}, Position{0, 22}}, TypeMismatch{ + requireType("a"), + typeChecker.stringType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "tc_function") +{ + CheckResult result = check("function five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* fiveType = get(requireType("five")); + REQUIRE(fiveType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value") +{ + CheckResult result = check("local f = nil; f = 'hello world'"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = requireType("f"); + CHECK_EQ(getPrimitiveType(ty), PrimitiveTypeVar::String); +} + +TEST_CASE_FIXTURE(Fixture, "infer_locals_via_assignment_from_its_call_site") +{ + CheckResult result = check(R"( + local a + function f(x) a = x end + f(1) + f("foo") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("number", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") +{ + CheckResult result = check(R"( + --!nocheck + function f(x) + return x + end + -- we get type information even if there's type errors + f(1, 2) + )"); + + CHECK_EQ("(any) -> (...any)", toString(requireType("f"))); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_bodies") +{ + CheckResult result = check("function myFunction() local a = 0 a = true end"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.booleanType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type") +{ + CheckResult result = check("function take_five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* takeFiveType = get(requireType("take_five")); + REQUIRE(takeFiveType != nullptr); + + std::vector retVec = flatten(takeFiveType->retType).first; + REQUIRE(!retVec.empty()); + + REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") +{ + CheckResult result = check("function take_five() return 5 end local five = take_five()"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_call_primitives") +{ + CheckResult result = check("local foo = 5 foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_call_tables") +{ + CheckResult result = check("local foo = {} foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") +{ + CheckResult result = check(R"( + function take_five() + return 5 + end + + take_five().prop = 888 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "expr_statement") +{ + CheckResult result = check("local foo = 5 foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_function") +{ + CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*typeChecker.nilType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") +{ + CheckResult result = check(R"( + function f(...) end + + f(1) + f("foo", 2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") +{ + CheckResult result = check(R"( + local T = {} + function T.f(...) + local result = {} + + for i = 1, select("#", ...) do + local dictionary = select(i, ...) + for key, value in pairs(dictionary) do + result[key] = value + end + end + + return result + end + + return T + )"); + + auto r = first(getMainModule()->getModuleScope()->returnType); + REQUIRE(r); + + TableTypeVar* ttv = getMutable(*r); + REQUIRE(ttv); + + TypeId k = ttv->props["f"].type; + REQUIRE(k); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_loop") +{ + CheckResult result = check(R"( + local q + for i=0, 50, 2 do + q = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("q")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop") +{ + CheckResult result = check(R"( + local n + local s + for i, v in pairs({ "foo" }) do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") +{ + CheckResult result = check(R"( + local n + local s + for i, v in next, { "foo", "bar" } do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") +{ + CheckResult result = check(R"( + local it: any + local a, b + for i, v in it do + a, b = i, v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") +{ + CheckResult result = check(R"( + local foo = "bar" + for i, v in foo do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") +{ + CheckResult result = check(R"( + local function keys(dictionary) + local new = {} + local index = 1 + + for key in pairs(dictionary) do + new[index] = key + index = index + 1 + end + + return new + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") +{ + CheckResult result = check(R"( + local function range(l, h): () -> number + return function() + return l + end + end + + for n: string in range(1, 10) do + print(n) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") +{ + CheckResult result = check(R"( + function f(x) + gobble.prop = x.otherprop + end + + local p + for _, part in i_am_not_defined do + p = part + f(part) + part.thirdprop = false + end + )"); + + CHECK_EQ(2, result.errors.size()); + + TypeId p = requireType("p"); + CHECK_EQ(*p, *typeChecker.errorType); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") +{ + CheckResult result = check(R"( + local bad_iter = 5 + + for a in bad_iter() do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +{ + CheckResult result = check(R"( + local function hasDivisors(value: number, table) + return false + end + + function prime_iter(state, index) + while hasDivisors(index, state) do + index += 1 + end + + state[index] = true + return index + end + + function primes1() + return prime_iter, {} + end + + function primes2() + return prime_iter, {}, "" + end + + function primes3() + return prime_iter, {}, 2 + end + + for p in primes1() do print(p) end -- mismatch in argument count + + for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string + + for p in primes3() do print(p) end -- no errror + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + + TypeMismatch* tm = get(result.errors[1]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") +{ + CheckResult result = check(R"( + function prime_iter(state, index) + return 1 + end + + for p in prime_iter do print(p) end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(typeChecker.anyType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(typeChecker.anyType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(typeChecker.anyType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(typeChecker.anyType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") +{ + CheckResult result = check(R"( + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(typeChecker.errorType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") +{ + CheckResult result = check(R"( + function bar(c) return c end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(typeChecker.errorType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") +{ + CheckResult result = check(R"( + function primes() + return function (state: number) end, 2 + end + + for p, q in primes do + q = "" + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "if_statement") +{ + CheckResult result = check(R"( + local a + local b + + if true then + a = 'hello' + else + b = 999 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("a")); + CHECK_EQ(*typeChecker.numberType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "while_loop") +{ + CheckResult result = check(R"( + local i + while true do + i = 8 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop") +{ + CheckResult result = check(R"( + local i + repeat + i = 'hi' + until true + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + + print(x) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "table_length") +{ + CheckResult result = check(R"( + local t = {} + local s = #t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(nullptr != get(requireType("t"))); + CHECK_EQ(*typeChecker.numberType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "string_length") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = #s + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_index") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = s[4] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + NotATable* nat = get(result.errors[0]); + REQUIRE(nat); + CHECK_EQ("string", toString(nat->ty)); + + CHECK(get(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local l = #this_is_not_defined + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local originalReward = unknown.Parent.Reward:GetChildren()[1] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") +{ + CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.nilType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "dot_on_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local foo = (true).x + foo.x = foo.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function1 = function(Arg1) + end + + someTable.Function1() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function2 = function(Arg1, Arg2) + end + + someTable.Function2() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> number)} + local T: T + + T.method(4) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply("") + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("No overload for function accepts 0 arguments.", ge->message); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply(1, "") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> string)} + local T: T + + local a = T.method(T, 4) + local b = T.method(5) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("string", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_arguments") +{ + CheckResult result = check(R"( + --!nonstrict + + function g(a: number) end + + g() + + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(1, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "any_type_propagates") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo:method("argument") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "can_subscript_any") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo[5] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +// Not strictly correct: metatables permit overriding this +TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") +{ + CheckResult result = check(R"( + local foo: any = {} + local bar = #foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_function") +{ + CheckResult result = check(R"( + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "lambda_form_of_local_function_cannot_be_recursive") +{ + CheckResult result = check(R"( + local f = function() return f() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_local_function") +{ + CheckResult result = check(R"( + local function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// FIXME: This and the above case get handled very differently. It's pretty dumb. +// We really should unify the two code paths, probably by deleting AstStatFunction. +TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") +{ + CheckResult result = check(R"( + local count + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + function f() + return f + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + function f(g) + return f(f) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type F = () -> F? + local function f() + return f + end + + local g: F = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); +} + +// TODO: File a Jira about this +/* +TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") +{ + CheckResult result = check(R"( + function a(x) return 1 end + a(...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); + + TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; + + auto iter = begin(varargPack); + auto endIter = end(varargPack); + + CHECK(iter != endIter); + ++iter; + CHECK(iter == endIter); + + CHECK(!iter.tail()); +} +*/ + +TEST_CASE_FIXTURE(Fixture, "method_depends_on_table") +{ + CheckResult result = check(R"( + -- This catches a bug where x:m didn't count as a use of x + -- so toposort would happily reorder a definition of + -- function x:m before the definition of x. + function g() f() end + local x = {} + function x:m() end + function f() x:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") +{ + CheckResult result = check(R"( + local Get_des + function Get_des(func) + Get_des(func) + end + + local function f(d) + d:IsA("BasePart") + d.Parent:FindFirstChild("Humanoid") + d:IsA("Decal") + end + Get_des(f) + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") +{ + CheckResult result = check(R"( + local d + d:foo() + d:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") +{ + CheckResult result = check(R"( + function foo() + return bar(999), bar("hi") + end + + function bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_table_method") +{ + CheckResult result = check(R"( + local T = {} + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId tType = requireType("T"); + TableTypeVar* tTable = getMutable(tType); + REQUIRE(tTable != nullptr); + + TypeId barType = tTable->props["bar"].type; + REQUIRE(barType != nullptr); + + const FunctionTypeVar* ftv = get(follow(barType)); + REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); + + std::vector args = flatten(ftv->argTypes).first; + TypeId argType = args.at(1); + + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(5) + end + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const TableTypeVar* t = get(requireType("T")); + REQUIRE(t != nullptr); + + std::optional fooProp = get(t->props, "foo"); + REQUIRE(bool(fooProp)); + + const FunctionTypeVar* foo = get(follow(fooProp->type)); + REQUIRE(bool(foo)); + + std::optional ret_ = first(foo->retType); + REQUIRE(bool(ret_)); + TypeId ret = follow(*ret_); + + REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); +} + +TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(999), T:bar("hi") + end + + function T:bar(i) + return i + end + + local a, b = T:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "local_function") +{ + CheckResult result = check(R"( + function f() + return 8 + end + + function g() + local function f() + return 'hello' + end + return f + end + + local h = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = follow(requireType("h")); + + const FunctionTypeVar* ftv = get(h); + REQUIRE(ftv != nullptr); + + std::optional rt = first(ftv->retType); + REQUIRE(bool(rt)); + + TypeId retType = follow(*rt); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); +} + +TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") +{ + CheckResult result = check(R"( + local o + o:method() + + local p + p:method() + + o = p + )"); +} + +/* + * We had a bug in instantiation where the argument types of 'f' and 'g' would be inferred as + * f {+ method: function(): (t2, T3...) +} + * g {+ method: function({+ method: function(): (t2, T3...) +}): (t5, T6...) +} + * + * The type of 'g' is totally wrong as t2 and t5 should be unified, as should T3 with T6. + * + * The correct unification of the argument to 'g' is + * + * {+ method: function(): (t5, T6...) +} + */ +TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") +{ + auto result = check(R"( + function f(o) + o:method() + end + + function g(o) + f(o) + end + )"); + + TypeId g = requireType("g"); + const FunctionTypeVar* gFun = get(g); + REQUIRE(gFun != nullptr); + + auto optionArg = first(gFun->argTypes); + REQUIRE(bool(optionArg)); + + TypeId arg = follow(*optionArg); + const TableTypeVar* argTable = get(arg); + REQUIRE(argTable != nullptr); + + std::optional methodProp = get(argTable->props, "method"); + REQUIRE(bool(methodProp)); + + const FunctionTypeVar* methodFunction = get(methodProp->type); + REQUIRE(methodFunction != nullptr); + + std::optional methodArg = first(methodFunction->argTypes); + REQUIRE(bool(methodArg)); + + REQUIRE_EQ(follow(*methodArg), follow(arg)); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") +{ + CheckResult result = check(R"( + --!strict + type Node = { Parent: Node?; } + local node: Node; + node.Parent = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Node?", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") +{ + CheckResult result = check(R"( + local T = {} + + function T.f(p) + for i, v in pairs(p) do + T.f(v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") +{ + // In this case, we cannot know the element type of the table {}. It could be anything. + // We therefore must initially ascribe a free typevar to iter. + CheckResult result = check(R"( + for iter in pairs({}) do + iter:g().p = true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") +{ + check(R"( + local T = {} + + function T.method(self) + self:method() + end + + function T.method2(self) + self:method() + end + + T:method2() + )"); +} + +TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") +{ + check(R"( + local o + local v = o:i() + + function g(u) + v = u + end + + o:f(g) + o:h() + o:h() + )"); +} + +TEST_CASE_FIXTURE(Fixture, "require") +{ + fileResolver.source["game/A"] = R"( + local function hooty(x: number): string + return "Hi there!" + end + + return {hooty=hooty} + )"; + + fileResolver.source["game/B"] = R"( + local Hooty = require(game.A) + + local h -- free! + local i = Hooty.hooty(h) + )"; + + CheckResult aResult = frontend.check("game/A"); + dumpErrors(aResult); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + dumpErrors(bResult); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr b = frontend.moduleResolver.modules["game/B"]; + + REQUIRE(b != nullptr); + + dumpErrors(bResult); + + std::optional iType = requireType(b, "i"); + REQUIRE_EQ("string", toString(*iType)); + + std::optional hType = requireType(b, "h"); + REQUIRE_EQ("number", toString(*hType)); +} + +TEST_CASE_FIXTURE(Fixture, "require_types") +{ + fileResolver.source["workspace/A"] = R"( + export type Point = {x: number, y: number} + + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + local Hooty = require(workspace.A) + + local h: Hooty.Point + )"; + + CheckResult bResult = frontend.check("workspace/B"); + dumpErrors(bResult); + + ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; + REQUIRE(b != nullptr); + + TypeId hType = requireType(b, "h"); + REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); +} + +TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") +{ + fileResolver.source["game/A"] = R"( + local T = {} + function T.f(...) end + return T + )"; + + fileResolver.source["game/B"] = R"( + local A = require(game.A) + local f = A.f + )"; + + CheckResult result = frontend.check("game/B"); + + ModulePtr bModule = frontend.moduleResolver.getModule("game/B"); + REQUIRE(bModule != nullptr); + + TypeId f = follow(requireType(bModule, "f")); + + const FunctionTypeVar* ftv = get(f); + REQUIRE(ftv); + + auto iter = begin(ftv->argTypes); + auto endIter = end(ftv->argTypes); + + REQUIRE(iter == endIter); + REQUIRE(iter.tail()); + + CHECK(get(*iter.tail())); +} + +TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") +{ + CheckResult result = check(R"( + local f: any + local T = {} + + T.prop = f() + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* ttv = getMutable(requireType("T")); + REQUIRE(ttv); + REQUIRE(ttv->props.count("prop")); + + REQUIRE_EQ("any", toString(ttv->props["prop"].type)); +} + +TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") +{ + CheckResult result = check(R"( + local p: SomeModule.DoesNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); +} + +TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") +{ + const std::string sourceA = R"( + )"; + + const std::string sourceB = R"( + local Hooty = require(script.Parent.A) + )"; + + fileResolver.source["game/Workspace/A"] = sourceA; + fileResolver.source["game/Workspace/B"] = sourceB; + + frontend.check("game/Workspace/A"); + frontend.check("game/Workspace/B"); + + ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; + ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; + + CHECK(aModule->errors.empty()); + REQUIRE_EQ(1, bModule->errors.size()); + CHECK_MESSAGE(get(bModule->errors[0]), "Should be IllegalRequire: " << toString(bModule->errors[0])); + + auto hootyType = requireType(bModule, "Hooty"); + + CHECK_MESSAGE(get(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); +} + +TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") +{ + CheckResult result = check(R"( + local M = require(script.parent.DoesNotMatter) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto ed = get(result.errors[0]); + REQUIRE(ed); + + REQUIRE_EQ("parent", ed->symbol); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") +{ + CheckResult result = check(R"( + local A : any + function A.B() end + A:C() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId aType = requireType("A"); + CHECK_EQ(aType, typeChecker.anyType); +} + +TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") +{ + CheckResult result = check(R"( + local Instance: any + local UDim2: any + + function Create(instanceType) + return function(data) + local obj = Instance.new(instanceType) + for k, v in pairs(data) do + if type(k) == 'number' then + --v.Parent = obj + else + obj[k] = v + end + end + return obj + end + end + + local topbarShadow = Create'ImageLabel'{ + Name = "TopBarShadow"; + Size = UDim2.new(1, 0, 0, 3); + Position = UDim2.new(0, 0, 1, 0); + Image = "rbxasset://textures/ui/TopBar/dropshadow.png"; + BackgroundTransparency = 1; + Active = false; + Visible = false; + }; + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") +{ + CheckResult result = check(R"( + local p = function(x) return x end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + const Luau::FunctionTypeVar* fn = get(requireType("p")); + REQUIRE(fn); + auto ret = first(fn->retType); + REQUIRE(ret); + REQUIRE(get(follow(*ret))); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local c: ((number)->number, number)->number = foo -- no error + c = foo -- no error + local d: ((number)->number, string)->number = foo -- error from arg 2 (string) not being convertable to number from the call a(b) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); + CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local _: (string, string)->number = foo -- string cannot be converted to (string)->number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); + CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "string_method") +{ + CheckResult result = check(R"( + local p = ("tacos"):len() + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(*requireType("p"), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_indirect") +{ + CheckResult result = check(R"( + local s:string + local l = s.lower + local p = l(s) + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(*requireType("p"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_other") +{ + CheckResult result = check(R"( + local s:string + local p = s:match("foo") + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(toString(requireType("p")), "string?"); +} + +TEST_CASE_FIXTURE(Fixture, "weird_case") +{ + CheckResult result = check(R"( + local function f() return 4 end + local d = math.deg(f()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:string|number = s + )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("x")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:number|string = s + local y = x or "s" + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("y")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" or "b" + local x:string = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(*requireType("s"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") +{ + CheckResult result = check(R"( + local s = "a" and 10 + local x:boolean|number = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "boolean | number"); +} + +TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" and true + local x:boolean = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(*requireType("x"), *typeChecker.booleanType); +} + +TEST_CASE_FIXTURE(Fixture, "and_or_ternary") +{ + CheckResult result = check(R"( + local s = (1/2) > 0.5 and "a" or 10 + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") +{ + CheckResult result = check(R"( + local T = {} + function T.new(a: number?, b: number?, c: number?) return 5 end + local m = T.new() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + --!strict + local s + s(s, 'a') + )"); + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + --!strict + function u(t, w) + u(u, t) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +#if 0 +// CLI-29798 +TEST_CASE_FIXTURE(Fixture, "crazy_complexity") +{ + CheckResult result = check(R"( + --!nonstrict + A:A():A():A():A():A():A():A():A():A():A():A() + )"); + + std::cout << "OK! Allocated " << typeChecker.typeVars.size() << " typevars" << std::endl; +} +#endif + +// We had a bug where a cyclic union caused a stack overflow. +// ex type U = number | U +TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") +{ + CheckResult result = check(R"( + --!strict + + function f(a, b) + a:g(b or {}) + a:g(b) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_not_to_supply_enough_retvals") +{ + CheckResult result = check(R"( + function get_two() return 5, 6 end + + local a = get_two() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions2") +{ + CheckResult result = check(R"( + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo() end + + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo(): number + return 1 + end + foo() + + function foo(n: number): number + return 2 + end + foo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("(number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") +{ + CheckResult result = check(R"( + local T = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("T", toString(requireType("T"))); +} + +TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") +{ + CheckResult result = check(R"( + function foo(arr) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const FunctionTypeVar* fooType = get(requireType("foo")); + REQUIRE(fooType); + + std::optional fooArg1 = first(fooType->argTypes); + REQUIRE(fooArg1); + + const TableTypeVar* fooArg1Table = get(*fooArg1); + REQUIRE(fooArg1Table); + + CHECK_EQ(fooArg1Table->state, TableState::Generic); +} + +TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotation") +{ + CheckResult result = check(R"( + local i = 0 + function most_of_the_natural_numbers(): number? + if i < 10 then + i = i + 1 + return i + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); + + std::optional retType = first(functionType->retType); + REQUIRE(retType); + CHECK(get(*retType)); +} + +TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") +{ + CheckResult result = check(R"( + function apply(f, x) + return f(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("apply")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const FunctionTypeVar* fType = get(argVec[0]); + REQUIRE(fType != nullptr); + + std::vector fArgs = flatten(fType->argTypes).first; + + TypeId xType = argVec[1]; + + CHECK_EQ(1, fArgs.size()); + CHECK_EQ(xType, fArgs[0]); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(6, argVec.size()); + + const FunctionTypeVar* fType = get(argVec[0]); + REQUIRE(fType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") +{ + CheckResult result = check(R"( + function swap(p) + local t = p[0] + p[0] = p[1] + p[1] = t + return nil + end + + function swapTwice(p) + swap(p) + swap(p) + return p + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("swapTwice")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(1, argVec.size()); + + const TableTypeVar* argType = get(follow(argVec[0])); + REQUIRE(argType != nullptr); + + CHECK(bool(argType->indexer)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + + function mergesort(arr, comp) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + local width = 1 + while width < #arr do + for i = 1, #arr, 2*width do + bottomupmerge(comp, arr, work, i, math.min(i+width, #arr), math.min(i+2*width-1, #arr)) + end + local temp = work + work = arr + arr = temp + width = width * 2 + end + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + /* + * mergesort takes two arguments: an array of some type T and a function that takes two Ts. + * We must assert that these two types are in fact the same type. + * In other words, comp(arr[x], arr[y]) is well-typed. + */ + + const FunctionTypeVar* ftv = get(requireType("mergesort")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const TableTypeVar* arg0 = get(follow(argVec[0])); + REQUIRE(arg0 != nullptr); + REQUIRE(bool(arg0->indexer)); + + const FunctionTypeVar* arg1 = get(follow(argVec[1])); + REQUIRE(arg1 != nullptr); + REQUIRE_EQ(2, size(arg1->argTypes)); + + std::vector arg1Args = flatten(arg1->argTypes).first; + + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[0]); + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); +} + +TEST_CASE_FIXTURE(Fixture, "error_types_propagate") +{ + CheckResult result = check(R"( + local err = (true).x + local c = err.Parent.Reward.GetChildren + local d = err.Parent.Reward + local e = err.Parent + local f = err + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* err = get(result.errors[0]); + REQUIRE(err != nullptr); + CHECK_EQ("boolean", toString(err->table)); + CHECK_EQ("x", err->key); + + CHECK(nullptr != get(requireType("c"))); + CHECK(nullptr != get(requireType("d"))); + CHECK(nullptr != get(requireType("e"))); + CHECK(nullptr != get(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = unknown.Parent.Reward.GetChildren() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* err = get(result.errors[0]); + REQUIRE(err != nullptr); + + CHECK_EQ("unknown", err->name); + + CHECK(nullptr != get(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = Utility.Create "Foo" {} + )"); + + TypeId aType = requireType("a"); + + REQUIRE_MESSAGE(nullptr != get(aType), "Not an error: " << toString(aType)); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: string) + return a + tonumber(b), a .. b + end + local n, s = add(2,"3") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* functionType = get(requireType("add")); + + std::optional retType = first(functionType->retType); + CHECK_EQ(std::optional(typeChecker.numberType), retType); + CHECK_EQ(requireType("n"), typeChecker.numberType); + CHECK_EQ(requireType("s"), typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") +{ + CheckResult result = check(R"( + local PI=3.1415926535897931 + local SOLAR_MASS=4*PI * PI + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: any) + return a + b + end + local t = add(1,2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") +{ + CheckResult result = check(R"( + local a = 4 + 8 + local b = a + 9 + local s = 'hotdogs' + local t = s .. s + local c = b - a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("string", toString(requireType("t"))); + CHECK_EQ("number", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") +{ + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = a * b + local d = a * 2 + local e = a * 'cabbage' + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK(get(requireType("e"))); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") +{ + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = b * a + local d = 2 * a + local e = 'cabbage' * a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK(get(requireType("e"))); +} + +TEST_CASE_FIXTURE(Fixture, "compare_numbers") +{ + CheckResult result = check(R"( + local a = 441 + local b = 0 + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "compare_strings") +{ + CheckResult result = check(R"( + local a = '441' + local b = '0' + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable") +{ + CheckResult result = check(R"( + local a = {} + local b = {} + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + + REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") +{ + CheckResult result = check(R"( + local M = {} + function M.new() + return setmetatable({}, M) + end + type M = typeof(M.new()) + + local a = M.new() + local b = M.new() + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); + REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + + local a = M.new() + local b = {} + local c = a < b -- line 10 + local d = b < a -- line 11 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + REQUIRE_EQ((Location{{10, 18}, {10, 23}}), result.errors[0].location); + + REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + type M = typeof(M.new()) + + local a = M.new() + local b = {} + local c = a < b -- line 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = get(result.errors[0]); + REQUIRE(err != nullptr); + + // Frail. :| + REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", err->message); +} + +TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators") +{ + CheckResult result = check(R"( + --!nonstrict + + function maybe_a_number(): number? + return 50 + end + + local a = maybe_a_number() < maybe_a_number() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * This test case exposed an oversight in the treatment of free tables. + * Free tables, like free TypeVars, need to record the scope depth where they were created so that + * we do not erroneously let-generalize them when they are used in a nested lambda. + * + * For more information about let-generalization, see + * + * The important idea here is that the return type of Counter.new is a table with some metatable. + * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by + * the generalization process), then it loses the knowledge that its metatable will have an :incr() + * method. + */ +TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") +{ + CheckResult result = check(R"( + local Counter = {} + Counter.__index = Counter + + function Counter.new() + local self = setmetatable({count=0}, Counter) + return self + end + + function Counter:incr() + self.count = 1 + return self.count + end + + local self = Counter.new() + print(self:incr()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* counterType = getMutable(requireType("Counter")); + REQUIRE(counterType); + + const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); + REQUIRE(newType); + + std::optional newRetType = *first(newType->retType); + REQUIRE(newRetType); + + const MetatableTypeVar* newRet = get(follow(*newRetType)); + REQUIRE(newRet); + + const TableTypeVar* newRetMeta = get(newRet->metatable); + REQUIRE(newRetMeta); + + CHECK(newRetMeta->props.count("incr")); + CHECK_EQ(follow(newRet->metatable), follow(requireType("Counter"))); +} + +// TODO: CLI-39624 +TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") +{ + CheckResult result = check(R"( + --!strict + local Option = {} + Option.__index = Option + function Option.Is(obj) + return (type(obj) == "table" and getmetatable(obj) == Option) + end + return Option + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") +{ + const std::string code = R"( + function f(a) + if type(a) == "boolean" then + local a1 = a + elseif a.fn() then + local a2 = a + else + local a3 = a + end + end + )"; + CheckResult result = check(code); + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") +{ + CheckResult result = check(R"( + local function f(x, y) + return x or y + end + + local function dont_crash(x, y) + local z: typeof(f(x, y)) = f(x, y) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "x_or_y_forces_both_x_and_y_to_be_of_same_type_if_either_is_free") +{ + CheckResult result = check(R"( + local function f(x, y) return x or y end + + local x = f(1, 2) + local y = f(3, "foo") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(*requireType("x"), *typeChecker.numberType); + + CHECK_EQ(result.errors[0], (TypeError{Location{{4, 23}, {4, 28}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") +{ + CheckResult result = check(R"( + ("foo") + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(50, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") +{ + CheckResult result = check(R"( + --!strict + function f(U) + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + end + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(100, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") +{ + CheckResult result = check(R"( + --!strict + -- An example of exponential blowup in number of types + -- The problem is that if we define function f(a) return x end + -- then this has type (t)->T where x:T + -- *but* it copies T each time f is applied + -- so { left = f("hi"), right = f(5) } + -- has type { left : T_L, right : T_R } + -- where T_L and T_R are copies of T. + -- x0 : T0 where T0 = {} + local x0 = {} + -- f0 : (t)->T0 + local function f0(a) return x0 end + -- x1 : T1 where T1 = { left : T0_L, right : T0_R } + local x1 = { left = f0("hi"), right = f0(5) } + -- f1 : (t)->T1 + local function f1(a) return x1 end + -- x2 : T2 where T2 = { left : T1_L, right : T1_R } + local x2 = { left = f1("hi"), right = f1(5) } + -- f2 : (t)->T2 + local function f2(a) return x2 end + -- etc etc + local x3 = { left = f2("hi"), right = f2(5) } + local function f3(a) return x3 end + local x4 = { left = f3("hi"), right = f3(5) } + return x4 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr module = getMainModule(); + + // If we're not careful about copying, this ends up with O(2^N) types rather than O(N) + // (in this case 5 vs 31). + CHECK_GE(5, module->interfaceTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + + function newPlayerCharacter() + startGui() -- Unknown symbol 'startGui' + end + + local characterAddedConnection: any + function startGui() + characterAddedConnection = game:GetService("Players").LocalPlayer.CharacterAdded:connect(newPlayerCharacter) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + local x = nil + function f() g() end + -- make sure print(x) doen't get toposorted here, breaking the mutual block + function g() x = f end + print(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: a, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = "lo", i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: b, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_ERRORS(result); + + // We had a UAF in this example caused by not cloning type function arguments + ModulePtr module = frontend.moduleResolver.getModule("MainModule"); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + module->internalTypes.clear(); + module->astTypes.clear(); + + // Make sure the error strings don't include "VALUELESS" + for (auto error : module->errors) + CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); +} + +TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") +{ + // CLI-30902 + CheckResult result = check(R"( + --!strict + + type Foo = { + fooConn: () -> () | nil + } + + local Foo = {} + Foo.__index = Foo + + function Foo.new() + local self: Foo = { + fooConn = nil, + } + setmetatable(self, Foo) + + self.fooConn = function() + self:method() -- Key 'method' not found in table self + end + + return self + end + + function Foo:method() + print("foo") + end + + local foo = Foo.new() + + -- TODO This is the best our current refinement support can offer :( + local bar = foo.fooConn + if bar then bar() end + + -- foo.fooConn() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") +{ + CheckResult result = check(R"( + local a: any + local b + for _, i in pairs(a) do + b = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("b"))); +} + +// In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type +// checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. +TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") +{ +#if defined(LUAU_ENABLE_ASAN) + int limit = 250; +#elif defined(_DEBUG) || defined(_NOOPT) + int limit = 350; +#else + int limit = 600; +#endif + ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; + ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; + ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; + + CHECK_NOTHROW(check("print('Hello!')")); + CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); +} + +TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") +{ +#if defined(LUAU_ENABLE_ASAN) + int limit = 250; +#elif defined(_DEBUG) || defined(_NOOPT) + int limit = 350; +#else + int limit = 600; +#endif + + ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; + ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; + + CheckResult result = check(rep("do ", limit) + "local a = 1" + rep(" end", limit)); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") +{ +#if defined(LUAU_ENABLE_ASAN) + int limit = 250; +#elif defined(_DEBUG) || defined(_NOOPT) + int limit = 350; +#else + int limit = 600; +#endif + ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; + ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; + + CheckResult result = check(R"(("foo"))" + rep(":lower()", limit)); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_basic") +{ + CheckResult result = check(R"( + local s = 10 + s += 20 + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number"); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") +{ + CheckResult result = check(R"( + local s = 10 + s += true + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") +{ + CheckResult result = check(R"( + local s = 'hello' + s += 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__add(a: V2, b: V2): V2 + return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 += v2 + )"); + CHECK_EQ(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__mod(a: V2, b: V2): number + return a.x * b.x + a.y * b.y + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 %= v2 + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + CHECK_EQ(*tm->wantedType, *requireType("v2")); + CHECK_EQ(*tm->givenType, *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "dont_ice_if_a_TypePack_is_an_error") +{ + CheckResult result = check(R"( + --!strict + function f(s) + print(s) + return f + end + + f("foo")("bar") + )"); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") +{ + CheckResult result = check(R"( + --!nonstrict + + function f() + return 114 + end + + return function() + return f():andThen() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") +{ + CheckResult result = check(R"( + function onerror() end + function foo() end + xpcall(foo, onerror) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments") +{ + CheckResult result = check(R"( + local mycb: (number, number) -> () + + function f() end + + mycb = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") +{ + CheckResult result = check(R"( + local a: any + local b = a() + )"); + + REQUIRE_EQ("any", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals") +{ + CheckResult result = check(R"( + --!nonstrict + foo = true + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals2") +{ + CheckResult result = check(R"( + --!nonstrict + foo = function() return 1 end + foo = "now i'm a string!" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> (...any)", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + CHECK_EQ("() -> (...any)", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") +{ + CheckResult result = check(R"( + --!strict + foo = true + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ("foo", us->name); +} + +TEST_CASE_FIXTURE(Fixture, "globals_everywhere") +{ + CheckResult result = check(R"( + --!nonstrict + foo = 1 + + if true then + bar = 2 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfAny") +{ + CheckResult result = check(R"( +local x: any = {} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfSealed") +{ + CheckResult result = check(R"( +local x: {prop: number} = {prop=9999} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") +{ + CheckResult result = check(R"( +local x: number = 9999 +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") +{ + CheckResult result = check(R"( +local x = (true).foo +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +(f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +local x = false +(x and f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") +{ + CheckResult result = check(R"( +local x = {} +x.a = "a" +x[0] = true +x.b = 37 +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") +{ + CheckResult result = check(R"( + do + local a = 1 + end + + print(a) -- oops! + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ(us->name, "a"); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") +{ + CheckResult result = check(R"( + while true do + local a = 1 + end + + print(a) -- oops! + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ(us->name, "a"); +} + +TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indeces") +{ + CheckResult result = check(R"( + local key + for i, e in ipairs({}) do key = i end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("number", toString(requireType("key"))); +} + +TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") +{ + CHECK_NOTHROW(check(R"( + --!nonstrict + f,g = ... + f(g(...))[...] = nil + f,xpcall = ... + local value = g(...)(g(...)) + )")); + + CHECK_EQ("any", toString(requireType("value"))); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + FunctionExitsWithoutReturning* err = get(result.errors[0]); + CHECK(err); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") +{ + CheckResult result = check(R"( + --!strict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + FunctionExitsWithoutReturning* annotatedErr = get(result.errors[0]); + CHECK(annotatedErr); + + FunctionExitsWithoutReturning* inferredErr = get(result.errors[1]); + CHECK(inferredErr); +} + +// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") +// { +// CheckResult result = check(R"( +// function f(a) +// if a.cond then +// return a.method() +// end +// end +// )"); + +// LUAU_REQUIRE_NO_ERRORS(result); + +// CHECK_EQ("A", toString(requireType("f"))); +// } + +TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") +{ + fileResolver.source["Modules/A"] = ""; + fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; + + fileResolver.source["Modules/B"] = R"( + local M = require(script.Parent.A) + )"; + + CheckResult result = frontend.check("Modules/B"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields_errors_spanning_argument") +{ + CheckResult result = check(R"( + function foo(a: number, b: string) end + + foo("Test", 123) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); + + CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ + typeChecker.stringType, + typeChecker.numberType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") +{ + CheckResult result = check(R"( + --!nonstrict + + function Test(a) + return 1, "" + end + + + local tab = {} + table.insert(tab, Test(1)); + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + opts.maxTableLength = 0; + + CHECK_EQ("{any}", toString(requireType("tab"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55 + end + + local a, b = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK(acm->context == CountMismatch::Result); +} + +TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") +{ + CheckResult result = check(R"( + --!strict + + function f(): (number, string) + return 55 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Return); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") +{ + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + local mt = {} + setmetatable(foo, mt) + + mt.__unm = function(val: typeof(foo)): string + return val.value .. "test" + end + + local a = -foo + + local b = 1+-1 + + local bar = { + value = 10 + } + local c = -bar -- disallowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + + GenericError* gen = get(result.errors[0]); + REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); +} + +TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") +{ + CheckResult result = check(R"( + local b = not "string" + local c = not (math.random() > 0.5 and "string" or 7) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("boolean", toString(requireType("b"))); + REQUIRE_EQ("boolean", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") +{ + CheckResult result = check(R"( + --!strict + local a = "1.24" + 123 -- not allowed + + local foo = { + value = 10 + } + + local b = foo + 1 -- not allowed + + local bar = { + value = 1 + } + + local mt = {} + + setmetatable(bar, mt) + + mt.__add = function(a: typeof(bar), b: number): number + return a.value + b + end + + local c = bar + 1 -- allowed + + local d = bar + foo -- not allowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); + REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + + TypeMismatch* tm2 = get(result.errors[2]); + CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm2->givenType, *requireType("foo")); + + GenericError* gen2 = get(result.errors[1]); + REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); +} + +// CLI-29033 +TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") +{ + CheckResult result = check(R"( + function merge(lower, greater) + if lower.y == greater.y then + end + end + )"); + + if (FFlag::LuauEqConstraint) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); + } +} + +TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") +{ + CheckResult result = check(R"( + local x + print((x == true and (x .. "y")) .. 1) + )"); + + if (FFlag::LuauEqConstraint) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0])); + CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1])); + } +} + +TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") +{ + CheckResult result = check(R"( + local x + print("foo" .. x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") +{ + std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; + + std::string src = R"( + function foo(a, b) + )"; + + for (const auto& op : ops) + src += "local _ = a " + op + "b\n"; + + src += "end"; + + CheckResult result = check(src); + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") +{ + CheckResult result = check(R"( + function foo(a, b): number + return 0 + end + + local a: (string)->number = foo + local b: (number, number)->(number, number) = foo + + local c: (string, number)->number = foo -- no error + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + + CHECK_EQ("(string) -> number", toString(tm1->wantedType)); + CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); + + auto tm2 = get(result.errors[1]); + REQUIRE(tm2); + + CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); + CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") +{ + { + Fixture fix; + + // inherit env from parent fixture checker + fix.typeChecker.globalScope = typeChecker.globalScope; + + fix.check(R"( +--!nonstrict +type MT = typeof(setmetatable) +function wtf(arg: {MT}): typeof(table) + arg = wtf(arg) +end +)"); + } + + // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down + // note: it's important for typeck to be destroyed at this point! + { + for (auto& p : typeChecker.globalScope->bindings) + { + toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas + } + } +} + +TEST_CASE_FIXTURE(Fixture, "evil_table_unification") +{ + // this code re-infers the type of _ while processing fields of _, which can cause use-after-free + check(R"( +--!nonstrict +_ = ... +_:table(_,string)[_:gsub(_,...,n0)],_,_:gsub(_,string)[""],_:split(_,...,table)._,n0 = nil +do end +)"); +} + +TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") +{ + check(R"( +--!nonstrict +function _(...):((typeof(not _))&(typeof(not _)))&((typeof(not _))&(typeof(not _))) +_(...)(setfenv,_,not _,"")[_] = nil +end +do end +_(...)(...,setfenv,_):_G() +)"); +} + +TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") +{ + CheckResult result = check(R"( + type Pair = {first: T, second: U} + local a: Pair + local b: Pair + + a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK_EQ("Pair", toString(tm->wantedType)); + CHECK_EQ("Pair", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") +{ + // this has a risk of creating cyclic type packs, causing infinite loops / OOMs + check(R"( +--!nonstrict +_ += _(_,...) +repeat +_ += _(...) +until ... + _ +)"); + + check(R"( +--!nonstrict +_ += _(_(...,...),_(...)) +repeat +until _ +)"); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_follow") +{ + check(R"( +--!nonstrict +l0,table,_,_,_ = ... +_,_,_,_.time(...)._.n0,l0,_ = function(l0) +end,_.__index,(_),_.time(_.n0 or _,...) +for l0=...,_,"" do +end +_ += not _ +do end +)"); + + check(R"( +--!nonstrict +n13,_,table,_,l0,_,_ = ... +_,n0[(_)],_,_._(...)._.n39,l0,_._ = function(l84,...) +end,_.__index,"",_,l0._(nil) +for l0=...,table.n5,_ do +end +_:_(...).n1 /= _ +do +_(_ + _) +do end +end +)"); +} + +TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") +{ + CheckResult result = check(R"( + --!strict + local t = {} + while true and t[1] do + print(t[1].test) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +struct FindFreeTypeVars +{ + bool foundOne = false; + + template + void cycle(ID) + { + } + + template + bool operator()(ID, T) + { + return !foundOne; + } + + template + bool operator()(ID, Unifiable::Free) + { + foundOne = true; + return false; + } +}; + +TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") +{ + CheckResult result = check("local x = setmetatable({})"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") +{ + // This code doesn't pass typechecking. We just care that it doesn't crash. + (void)check(R"( + --!nonstrict + function _:_(...) + end + + repeat + if _ then + else + _ = ... + end + until _ + + for _ in _() do + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") +{ + CheckResult result = check(R"( + type A = number + type A = string -- Redefinition of type 'A', previously defined at line 1 + local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") +{ + CheckResult result = check(R"( + local x = + local a = 7 + )"); + LUAU_REQUIRE_ERRORS(result); + + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::Number); +} + +// Check that type checker knows about error expressions +TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") +{ + CheckResult result = check("function +() local _ = true end"); + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error") +{ + { + CheckResult result = check(R"( + --!strict + local t = { x = 10, y = 20 } + return t. + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + } + + { + CheckResult result = check(R"( + --!strict + export type = number + export type = string + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + } + + { + CheckResult result = check(R"( + --!strict + function string.() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + } + + { + CheckResult result = check(R"( + --!strict + local function () end + local function () end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + } + + { + CheckResult result = check(R"( + --!strict + local dm = {} + function dm.() end + function dm.() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + } +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") +{ + CheckResult result = check(R"( + local a: boolean = true + local b: boolean = false + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") +{ + CheckResult result = check(R"( + local a: number | string = "" + local b: number | string = 1 + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = Table + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Wrapped", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = (Table) -> string + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") +{ + CheckResult result = check(R"( + local foo: any + + print(foo[(true).x]) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* up = get(result.errors[0]); // Should probably be NotATable + REQUIRE(up); + CHECK_EQ("boolean", toString(up->table)); + CHECK_EQ("x", up->key); +} + +TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") +{ + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + while true do + if a then return 10 end + end + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a): number + while true do + if a then break end + return 10 + end + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + repeat + if a then return 10 end + until false + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a, b): number + repeat + if a then break end + + if b then return 10 end + until false + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a: number?): number + repeat + return 10 + until a ~= nil + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } +} + +TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") +{ + CheckResult result = check(R"( + --!strict + local _ + _ += _ and _ or _ and _ or _ and _ + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") +{ + CheckResult result = check(R"( + --!strict + local a: number | (string | boolean) | nil + local b: number = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); +} + +// Check that recursive intersection type doesn't generate an OOM +TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") +{ + CheckResult result = check(R"( + function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any + end + type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) + _(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") +{ + // In non-strict mode, global definition is still allowed + { + CheckResult result = check(R"( + --!nonstrict + a = a + 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In strict mode we no longer generate two errors from lhs + { + CheckResult result = check(R"( + --!strict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In non-strict mode, compound assignment is not a definition, it's a modification + { + CheckResult result = check(R"( + --!nonstrict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") +{ + CheckResult result = check(R"( + local t = {} + for _ in t do + for _ in assert(missing()) do + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") +{ + CheckResult result = check(R"( + local foo: Id = 1 + type Id = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") +{ + CheckResult result = check(R"( + local x: {number|number} = {1, 2, 3} + local y = x[1] - x[2] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") +{ + CheckResult result = check(R"( +--!strict +local T: any +T = {} +T.__index = T +function T.new(...) + local self = {} + setmetatable(self, T) + self:construct(...) + return self +end +function T:construct(index) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") +{ + const std::string code = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb = aa + )"; + + const std::string expected = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type A = () -> (number, B) + type B = () -> (string, A) + local a: A + local b: B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); + CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "generic_param_remap") +{ + const std::string code = R"( + -- An example of a forwarded use of a type that has different type arguments than parameters + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb = aa + )"; + + const std::string expected = R"( + + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + export type Foo = number + type Foo = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "Foo"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") +{ + CheckResult result = check(R"( + foo + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstExprError") +{ + CheckResult result = check(R"( + local a = foo: + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_ice_on_astexprerror") +{ + CheckResult result = check(R"( + local foo = -; + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2") +{ + CheckResult result = check(R"( +--!nonstrict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = 1 or a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ("number?", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") +{ + ScopedFastFlag sffs2{"LuauGenericFunctions", true}; + + CheckResult result = check(R"( + --!strict + local tbl = {} + function tbl:abc(a: number, b: number) + return a + end + tbl:abc(1, 2) -- Line 6 + -- | Column 14 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + TypeId type = requireTypeAtPosition(Position(6, 14)); + CHECK_EQ("(tbl, number, number) -> number", toString(type)); + auto ftv = get(type); + REQUIRE(ftv); + CHECK(ftv->hasSelf); +} + +TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") +{ + CheckResult result = check(R"( + --!strict + function Funky() + local a: number = foo + end + + local foo: string = 'hello' + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = result.errors.front(); + REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") +{ + ScopedFastFlag sffs3{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + type Node = { value: T, child: Node? } + + local function visitor(node: Node?) + local a: Node + + if node then + a = node.child -- Observe the output of the error message. + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = get(result.errors[0]); + CHECK_EQ("Node?", toString(e->givenType)); + CHECK_EQ("Node", toString(e->wantedType)); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") +{ + CheckResult result = check(R"( + type Array = { [number]: T } + type Fiber = { id: number } + type null = {} + + local fiberStack: Array = {} + local index = 0 + + local function f(fiber: Fiber) + local a = fiber ~= fiberStack[index] + local b = fiberStack[index] ~= fiber + end + + return f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") +{ + fileResolver.source["game/A"] = R"( +--!strict +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +--!strict +local tbl = { abc = require(game.A) } +local a : string = "" +a = tbl.abc.def + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") +{ + fileResolver.source["game/A"] = R"( +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +local tbl: string = require(game.A) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +{ + fileResolver.source["workspace/A"] = R"( + export type myvec2 = {x: number, y: number} + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + export type myvec3 = {x: number, y: number, z: number} + return {} + )"; + + fileResolver.source["workspace/C"] = R"( + local Foo, Bar = require(workspace.A), require(workspace.B) + + local a: Foo.myvec2 + local b: Bar.myvec3 + )"; + + CheckResult result = frontend.check("workspace/C"); + LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; + + REQUIRE(m != nullptr); + + std::optional aTypeId = lookupName(m->getModuleScope(), "a"); + REQUIRE(aTypeId); + const Luau::TableTypeVar* aType = get(follow(*aTypeId)); + REQUIRE(aType); + REQUIRE(aType->props.size() == 2); + + std::optional bTypeId = lookupName(m->getModuleScope(), "b"); + REQUIRE(bTypeId); + const Luau::TableTypeVar* bType = get(follow(*bTypeId)); + REQUIRE(bType); + REQUIRE(bType->props.size() == 3); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +{ + CheckResult result = check("type t10 = typeof(table)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); + CHECK_EQ(toString(ty), "table"); + + const TableTypeVar* ttv = get(ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +local c: Cool = { a = 1, b = "s" } +type NotCool = Cool +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +type NotCool = Cool +local c: Cool = { a = 1, b = "s" } +local d: NotCool = { a = 1, b = "s" } +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + ty = requireType("d"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "NotCool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") +{ + CheckResult result = check(R"( +local c = { a = 1, b = "s" } +type Cool = typeof(c) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK_EQ(ttv->name, "Cool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +{ + ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; + + fileResolver.source["game/A"] = R"( +export type X = { a: number, b: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(follow(*ty1), follow(*ty2)); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +{ + ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; + + fileResolver.source["game/A"] = R"( +export type X = { a: T, b: U, C: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); + + bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); + CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); +} + +TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") +{ + CheckResult result = check(R"( +--!nonstrict +local f = {} +function f:foo(a: number, b: number) end + +function bar(...) + f.foo(f, 1, ...) +end + +bar(2) +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") +{ + CheckResult result = check(R"( +local function f(a: typeof(f)) end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") +{ + CheckResult result = check(R"( +--!nonstrict +local l0:any,l61:t0 = _,math +while _ do +_() +end +function _():t0 +end +type t0 = any +)"); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") +{ + CheckResult result = check(R"( +local n = {} +function n:Clone() end + +local m = {} + +function m.a(x) + x:Clone() +end + +function m.b() + m.a(n) +end + +return m +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") +{ + TypeId mathTy = requireType(typeChecker.globalScope, "math"); + REQUIRE(mathTy); + TableTypeVar* ttv = getMutable(mathTy); + REQUIRE(ttv); + const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); + REQUIRE(ftv); + auto original = ftv->level; + + CheckResult result = check("local a = math.frexp"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(ftv->level.level == original.level); + CHECK(ftv->level.subLevel == original.subLevel); +} + +TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") +{ + CheckResult result = check(R"( +local foo = {42} +local bar: number? +local baz = foo[bar] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); +} + +TEST_CASE_FIXTURE(Fixture, "table_simple_call") +{ + CheckResult result = check(R"( +local a = setmetatable({ x = 2 }, { + __call = function(self) + return (self.x :: number) * 2 -- should work without annotation in the future + end +}) +local b = a() +local c = a(2) -- too many arguments + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return function(obj) return true end +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return {a = 1, b = function(obj) return true end} +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "custom_require_global") +{ + CheckResult result = check(R"( +--!nonstrict +require = function(a) end + +local crash = require(game.A) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: string | number, b: boolean | number) + return a == b + end + )"); + + // This doesn't produce any errors but for the wrong reasons. + // This unit test serves as a reminder to not try and unify the operands on `==`/`~=`. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") +{ + CheckResult result = check(R"( + type Foo = {x: string} + local t = {} + setmetatable(t, { + __index = function(x: string): ...Foo + return {x = x} + end + }) + + local foo = t.bar + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; + CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); +} + +TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") +{ + CheckResult result = check(R"( + type ( ... ) ( ) ; + ( ... ) ( - - ... ) ( - ... ) + type = ( ... ) ; + ( ... ) ( ) ( ... ) ; + ( ... ) "" + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") +{ + CheckResult result = check(R"( + function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) + xpcall(_,_,_) + _(_,_,_) + end + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") +{ + CheckResult result = check(R"( + function _(l0:t0): (any, ()->()) + end + + type t0 = t0 | {} + )"); + + CHECK_LE(0, result.errors.size()); + + std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); + REQUIRE(t0); + CHECK(get(t0->type)); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { + return get(err); + }); + CHECK(it != result.errors.end()); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") +{ + CheckResult result = check(R"( + function _(l0:t0): (any, ()->()) + return 0,_ + end + + type t0 = t0 | {} + _(nil) + )"); + + CHECK_LE(0, result.errors.size()); + + std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); + REQUIRE(t0); + CHECK(get(t0->type)); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { + return get(err); + }); + CHECK(it != result.errors.end()); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional2") +{ + CheckResult result = check(R"( + function _(l0:({})|(t0)):((((typeof((xpcall)))|(t96))|(t13))&(t96),()->typeof(...)) + return 0,_ + end + + type t0 = ((typeof((_G)))|(({})|(t0)))|(t0) + _(nil) + + local t: ({})|(t0) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "no_infinite_loop_when_trying_to_unify_uh_this") +{ + CheckResult result = check(R"( + function _(l22,l0):((((boolean)|(t0))|(t0))&(()->(()->(()->()->{},(t0)|(t0)),any))) + return function():t0 + end + end + type t0 = ((typeof(_))|(any))|(typeof(_)) + _() + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") +{ + CheckResult result = check(R"( + local l0,l0 + repeat + type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) + function _(l0):(t0)&(t0) + while nil do + end + end + until _(_)(_)._ + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") +{ + CheckResult result = check(R"( + --!nonstrict + _ += _:n0(xpcall,_) + local l0 + do end + while _ do + function _:_() + _ += _(_._(_:n0(xpcall,_))) + end + end + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") +{ + ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true}; + + fileResolver.source["Module/Backend/Types"] = R"( + export type Fiber = { + return_: Fiber? + } + return {} + )"; + + fileResolver.source["Module/Backend"] = R"( + local Types = require(script.Types) + type Fiber = Types.Fiber + type ReactRenderer = { findFiberByHostInstance: () -> Fiber? } + + local function attach(renderer): () + local function getPrimaryFiber(fiber) + local alternate = fiber.alternate + return fiber + end + + local function getFiberIDForNative() + local fiber = renderer.findFiberByHostInstance() + fiber = fiber.return_ + return getPrimaryFiber(fiber) + end + end + + function culprit(renderer: ReactRenderer): () + attach(renderer) + end + + return culprit + )"; + + CheckResult result = frontend.check("Module/Backend"); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: {Tree} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- this would be an infinite type if we allowed it + type Tree = { data: T, children: {Tree<{T}>} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "record_matching_overload") +{ + ScopedFastFlag sffs("LuauStoreMatchingOverloadFnType", true); + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number) -> number) + local abc: Overload + abc(1) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // AstExprCall is the node that has the overload stored on it. + // findTypeAtPosition will look at the AstExprLocal, but this is not what + // we want to look at. + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(3, 10)); + REQUIRE_GE(ancestry.size(), 2); + AstExpr* parentExpr = ancestry[ancestry.size() - 2]->asExpr(); + REQUIRE(bool(parentExpr)); + REQUIRE(parentExpr->is()); + + ModulePtr module = getMainModule(); + auto it = module->astOverloadResolvedTypes.find(parentExpr); + REQUIRE(it != module->astOverloadResolvedTypes.end()); + CHECK_EQ(toString(it->second), "(number) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +{ + ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true); + + // Simple direct arg to arg propagation + CheckResult result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // An optional funciton is accepted, but since we already provide a function, nil can be ignored + result = check(R"( +type Table = { x: number, y: number } +local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Make sure self calls match correct index + result = check(R"( +type Table = { x: number, y: number } +local x = {} +x.b = {x = 1, y = 2} +function x:f(a: (Table) -> number) return a(self.b) end +x:f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Mix inferred and explicit argument types + result = check(R"( +function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end +f(function(a: number, b, c) return c and a + b or b - a end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Anonymous function has a varyadic pack + result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(...) return select(1, ...).z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Can't accept more arguments than provided + result = check(R"( +function f(a: (a: number, b: number) -> number) return a(1, 2) end +f(function(a, b, c, ...) return a + b end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); + + // Infer from varyadic packs into elements + result = check(R"( +function f(a: (...number) -> number) return a(1, 2) end +f(function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Infer from varyadic packs into varyadic packs + result = check(R"( +type Table = { x: number, y: number } +function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end +f(function(a, ...) local b = ... return b.z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Return type inference + result = check(R"( +type Table = { x: number, y: number } +function f(a: (number) -> Table) return a(4) end +f(function(x) return x * 2 end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + + // Return type doesn't inference 'nil' + result = check(R"( +function f(a: (number) -> nil) return a(4) end +f(function(x) print(x) end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") +{ + ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); + + CheckResult result = check(R"( +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end +return sum(2, 3, function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( +local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end +local a = {1, 2, 3} +local r = map(a, function(a) return a + a > 100 end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{boolean}", toString(requireType("r"))); + + check(R"( +local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end +local a = {1, 2, 3} +local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{| c: number, s: number |}", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") +{ + ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); + + CheckResult result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12(1, function(x) return x + x end) +g12(1, 2, function(x, y) return x + y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12({x=1}, function(x) return {x=-x.x} end) +g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: Forest } + type Forest = {Tree} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- OK because forwarded types are used with their parameters. + type Tree = { data: T, children: Forest } + type Forest = {Tree<{T}>} + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- Not OK because forwarded types are used with different types than their parameters. + type Forest = {Tree<{T}>} + type Tree = { data: T, children: Forest } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") +{ + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") +{ + CheckResult result = check(R"( + function f(x) return x[1] end + -- x has type X? for a free type variable X + local x = f ({}) + type ContainsFree = { this: a, that: typeof(x) } + type ContainsContainsFree = { that: ContainsFree } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") +{ + CheckResult result = check(R"( +local a = {{x=4}, {x=7}, {x=1}} +table.sort(a, function(x, y) return x.x < y.x end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") +{ + CheckResult result = check(R"( +type Table = { x: number, y: number } +local f: (Table) -> number = function(t) return t.x + t.y end + +type TableWithFunc = { x: number, y: number, f: (number, number) -> number } +local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") +{ + ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); + + CheckResult result = check(R"( +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end + +local function sumrec(f: typeof(sum)) + return sum(2, 3, function(a, b) return a + b end) +end + +local b = sumrec(sum) -- ok +local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") +{ + ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); + + CheckResult result = check(R"( +local function f(): {string|number} + return {1, "b", 3} +end + +local function g(): (number, {string|number}) + return 4, {1, "b", 3} +end + +local function h(): ...{string|number} + return {4}, {1, "b", 3}, {"s"} +end + +local function i(): ...{string|number} + return {1, "b", 3}, h() +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") +{ + ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); + + CheckResult result = check(R"( +local function f() + return {4, "b", 3} :: {string|number} +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types") +{ + ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); + + CheckResult result = check(R"( +local a: (number, number) -> number = function(a, b) return a - b end + +a = function(a, b) return a + b end + +local b: {number|string} +local c: {number|string} +b, c = {2, "s"}, {"b", 4} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") +{ + ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); + + CheckResult result = check(R"( +local a = {} +a.x = 2 +a = setmetatable(a, { __call = function(x) end }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "refine_and_or") +{ + ScopedFastFlag sff{"LuauSlightlyMoreFlexibleBinaryPredicates", true}; + + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x or 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("u"))); +} + +TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") +{ + ScopedFastFlag sffs[] = { + {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, + {"LuauExtraNilRecovery", true}, + }; + + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t.x and t or 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); +} + +TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") +{ + ScopedFastFlag sffs[] = { + {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, + {"LuauExtraNilRecovery", true}, + }; + + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x == 5 or t.x == 31337 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("boolean", toString(requireType("u"))); +} + +TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") +{ + ScopedFastFlag luauFollowInTypeFunApply("LuauFollowInTypeFunApply", true); + + CheckResult result = check(R"( +type A = { x: number } +local a: A = { x = 1 } +local b = a +type B = typeof(b) +type X = T +local c: X + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + { + CheckResult result = check(R"(local a = if true then "true" else "false")"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + } +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + { + // Test expression containing elseif + CheckResult result = check(R"( +local a = if false then "a" elseif false then "b" else "c" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + } +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") +{ + ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + + { + CheckResult result = check(R"(local a = if true then "true" else 42)"); + // We currently require both true/false expressions to unify to the same type. However, we do intend to lift + // this restriction in the future. + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + } +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp new file mode 100644 index 0000000..91ac9f0 --- /dev/null +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -0,0 +1,207 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +struct TryUnifyFixture : Fixture +{ + TypeArena arena; + ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; + InternalErrorReporter iceHandler; + Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler}; +}; + +TEST_SUITE_BEGIN("TryUnifyTests"); + +TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") +{ + TypeVar numberOne{TypeVariant{PrimitiveTypeVar{PrimitiveTypeVar::Number}}}; + TypeVar numberTwo = numberOne; + + state.tryUnify(&numberOne, &numberTwo); + + CHECK(state.errors.empty()); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") +{ + TypeVar functionOne{ + TypeVariant{FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + + TypeVar functionTwo{TypeVariant{ + FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; + + state.tryUnify(&functionOne, &functionTwo); + CHECK(state.errors.empty()); + + CHECK_EQ(functionOne, functionTwo); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") +{ + TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; + TypeVar functionOne{ + TypeVariant{FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + + TypeVar functionOneSaved = functionOne; + + TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; + TypeVar functionTwo{ + TypeVariant{FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.stringType}))}}; + + TypeVar functionTwoSaved = functionTwo; + + state.tryUnify(&functionOne, &functionTwo); + CHECK(!state.errors.empty()); + + CHECK_EQ(functionOne, functionOneSaved); + CHECK_EQ(functionTwo, functionTwoSaved); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") +{ + TypeVar tableOne{TypeVariant{ + TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + }}; + + TypeVar tableTwo{TypeVariant{ + TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}}, std::nullopt, globalScope->level, TableState::Unsealed}, + }}; + + CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + + state.tryUnify(&tableOne, &tableTwo); + + CHECK(state.errors.empty()); + + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") +{ + TypeVar tableOne{TypeVariant{ + TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.numberType}}}, std::nullopt, globalScope->level, + TableState::Unsealed}, + }}; + + TypeVar tableTwo{TypeVariant{ + TableTypeVar{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.stringType}}}, std::nullopt, globalScope->level, + TableState::Unsealed}, + }}; + + CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); + + state.tryUnify(&tableOne, &tableTwo); + + CHECK_EQ(1, state.errors.size()); + + state.log.rollback(); + + CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_unified_with_errorType") +{ + CheckResult result = check(R"( + function f(arg: number) end + local a + local b + f(a, b) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeId bType = requireType("b"); + + CHECK_MESSAGE(get(bType), "Should be an error: " << toString(bType)); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") +{ + CheckResult result = check(R"( + --!strict + local function f(v: number) + if v % 2 == 0 then + return true + end + end + + return function() + return (f(1)) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") +{ + TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; + TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; + + state.tryUnify(&variadicPack, &testPack); + CHECK(!state.errors.empty()); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") +{ + TypePackVar variadicPack{VariadicTypePack{typeChecker.booleanType}}; + TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; + TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; + + state.tryUnify(&a, &b); + CHECK(state.errors.empty()); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work") +{ + TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); + TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})}); + + state.tryUnify(variadicPack, errorPack); + REQUIRE_EQ(0, state.errors.size()); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") +{ + ScopedFastFlag sffs2{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + --!strict + local function f(...: T): ...T + return ... + end + + local x: string = f(1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(toString(tm->givenType), "number"); + CHECK_EQ(toString(tm->wantedType), "string"); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") +{ + ScopedFastFlag sffs2("LuauGenericFunctions", true); + + CheckResult result = check(R"( + --!strict + table.insert() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "No overload for function accepts 0 arguments."); + CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp new file mode 100644 index 0000000..5f7f284 --- /dev/null +++ b/tests/TypeInfer.typePacks.cpp @@ -0,0 +1,297 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypePackTests"); + +TEST_CASE_FIXTURE(Fixture, "infer_multi_return") +{ + CheckResult result = check(R"( + function take_two() + return 2, 2 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* takeTwoType = get(requireType("take_two")); + REQUIRE(takeTwoType != nullptr); + + const auto& [returns, tail] = flatten(takeTwoType->retType); + + CHECK_EQ(2, returns.size()); + CHECK_EQ(typeChecker.numberType, returns[0]); + CHECK_EQ(typeChecker.numberType, returns[1]); + + CHECK(!tail); +} + +TEST_CASE_FIXTURE(Fixture, "empty_varargs_should_return_nil_when_not_in_tail_position") +{ + CheckResult result = check(R"( + local a, b = ..., 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "self_and_varargs_should_work") +{ + CheckResult result = check(R"( + local t = {} + function t:f(...) end + t:f(1) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pack") +{ + CheckResult result = check(R"( + function take_two() + return 2, 2 + end + + function take_three() + return 1, take_two() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const FunctionTypeVar* takeOneMoreType = get(requireType("take_three")); + REQUIRE(takeOneMoreType != nullptr); + + const auto& [rets, tail] = flatten(takeOneMoreType->retType); + + REQUIRE_EQ(3, rets.size()); + CHECK_EQ(typeChecker.numberType, rets[0]); + CHECK_EQ(typeChecker.numberType, rets[1]); + CHECK_EQ(typeChecker.numberType, rets[2]); + + CHECK(!tail); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function") +{ + CheckResult result = check(R"( + function apply(f, g, x) + return f(g(x)) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* applyType = get(requireType("apply")); + REQUIRE(applyType != nullptr); + + std::vector applyArgs = flatten(applyType->argTypes).first; + REQUIRE_EQ(3, applyArgs.size()); + + const FunctionTypeVar* fType = get(applyArgs[0]); + REQUIRE(fType != nullptr); + + const FunctionTypeVar* gType = get(applyArgs[1]); + REQUIRE(gType != nullptr); + + std::vector gArgs = flatten(gType->argTypes).first; + REQUIRE_EQ(1, gArgs.size()); + + // function(function(t1, T2...): (t3, T4...), function(t5): (t1, T2...), t5): (t3, T4...) + + REQUIRE_EQ(*gArgs[0], *applyArgs[2]); + REQUIRE_EQ(toString(fType->argTypes), toString(gType->retType)); + REQUIRE_EQ(toString(fType->retType), toString(applyType->retType)); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_should_be_empty_if_nothing_is_returned") +{ + CheckResult result = check(R"( + function f() end + function g() return end + )"); + LUAU_REQUIRE_NO_ERRORS(result); + const FunctionTypeVar* fTy = get(requireType("f")); + REQUIRE(fTy != nullptr); + CHECK_EQ(0, size(fTy->retType)); + const FunctionTypeVar* gTy = get(requireType("g")); + REQUIRE(gTy != nullptr); + CHECK_EQ(0, size(gTy->retType)); +} + +TEST_CASE_FIXTURE(Fixture, "no_return_size_should_be_zero") +{ + CheckResult result = check(R"( + function f(a:any) return a end + function g() return end + function h() end + + g(h()) + f(g(),h()) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* fTy = get(requireType("f")); + REQUIRE(fTy != nullptr); + CHECK_EQ(1, size(follow(fTy->retType))); + + const FunctionTypeVar* gTy = get(requireType("g")); + REQUIRE(gTy != nullptr); + CHECK_EQ(0, size(gTy->retType)); + + const FunctionTypeVar* hTy = get(requireType("h")); + REQUIRE(hTy != nullptr); + CHECK_EQ(0, size(hTy->retType)); +} + +TEST_CASE_FIXTURE(Fixture, "varargs_inference_through_multiple_scopes") +{ + CheckResult result = check(R"( + local function f(...) + do + local a: string = ... + local b: number = ... + end + end + + f("foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "multiple_varargs_inference_are_not_confused") +{ + CheckResult result = check(R"( + local function f(...) + local a: string = ... + + return function(...) + local b: number = ... + end + end + + f("foo", "bar")(1, 2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "parenthesized_varargs_returns_any") +{ + CheckResult result = check(R"( + --!strict + local value + + local function f(...) + value = ... + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("value"))); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_packs") +{ + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypePackId listOfNumbers = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); + TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.stringType}}); + + // clang-format off + addGlobalBinding(typeChecker, "foo", + arena.addType( + FunctionTypeVar{ + listOfNumbers, + arena.addTypePack({typeChecker.numberType}) + } + ), + "@test" + ); + addGlobalBinding(typeChecker, "bar", + arena.addType( + FunctionTypeVar{ + arena.addTypePack({{typeChecker.numberType}, listOfStrings}), + arena.addTypePack({typeChecker.numberType}) + } + ), + "@test" + ); + // clang-format on + + freeze(arena); + + CheckResult result = check(R"( + --!strict + + foo(1, 2, 3, "foo") + bar(1, "foo", "bar", 3) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(result.errors[0], (TypeError{Location(Position{3, 21}, Position{3, 26}), TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + + CHECK_EQ(result.errors[1], (TypeError{Location(Position{4, 29}, Position{4, 30}), TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") +{ + CheckResult result = check(R"( + --!strict + + local function foo(...: number) + end + + foo(1, 2, 3, 4, 5, 6) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); +} + +// CLI-45791 +TEST_CASE_FIXTURE(UnfrozenFixture, "type_pack_hidden_free_tail_infinite_growth") +{ + CheckResult result = check(R"( +--!nonstrict +if _ then + _[function(l0)end],l0 = _ +elseif _ then + return l0(nil) +elseif 1 / l0(nil) then +elseif _ then + return #_,l0() +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") +{ + CheckResult result = check(R"( +local _ = function():((...any)->(...any),()->()) + return function() end, function() end +end +for y in _() do +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp new file mode 100644 index 0000000..ae4d836 --- /dev/null +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -0,0 +1,464 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +LUAU_FASTFLAG(LuauEqConstraint) + +using namespace Luau; + +TEST_SUITE_BEGIN("UnionTypes"); + +TEST_CASE_FIXTURE(Fixture, "return_types_can_be_disjoint") +{ + CheckResult result = check(R"( + local count = 0 + function most_of_the_natural_numbers(): number? + if count < 10 then + count = count + 1 + return count + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* utv = get(requireType("most_of_the_natural_numbers")); + REQUIRE(utv != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "allow_specific_assign") +{ + CheckResult result = check(R"( + local a:number|string = 22 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "allow_more_specific_assign") +{ + CheckResult result = check(R"( + local a:number|string = 22 + local b:number|string|nil = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") +{ + CheckResult result = check(R"( + local a:number = 10 + local b:number|string = 20 + a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign2") +{ + CheckResult result = check(R"( + local a:number? = 10 + local b:number|string? = 20 + a = b + )"); + + REQUIRE_EQ(1, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "optional_arguments") +{ + CheckResult result = check(R"( + function f(a:string, b:string?) + end + f("s") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_arguments_table") +{ + CheckResult result = check(R"( + local a:{a:string, b:string?} + a = {a="ok"} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "optional_arguments_table2") +{ + CheckResult result = check(R"( + local a:{a:string, b:string} + a = {a=""} + )"); + REQUIRE(!result.errors.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "error_takes_optional_arguments") +{ + CheckResult result = check(R"( + error("message") + error("message", 2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_optional_argument_enforces_type") +{ + CheckResult result = check(R"( + error("message", "2") + )"); + + REQUIRE(result.errors.size() == 1); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_property_guaranteed_to_exist") +{ + CheckResult result = check(R"( + type A = {x: number} + type B = {x: number} + local t: A | B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.numberType, *requireType("r")); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_mixed_types") +{ + CheckResult result = check(R"( + type A = {x: number} + type B = {x: string} + local t: A | B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number | string", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_works_at_arbitrary_depth") +{ + CheckResult result = check(R"( + type A = {x: {y: {z: {thing: number}}}} + type B = {x: {y: {z: {thing: string}}}} + local t: A | B + + local r = t.x.y.z.thing + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number | string", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") +{ + CheckResult result = check(R"( + type A = {x: number} + type B = {x: number?} + local t: A | B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") +{ + ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); + + CheckResult result = check(R"( + type A = {x: number} + type B = {} + local t: A | B + + local r = t.x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + MissingUnionProperty* mup = get(result.errors[0]); + REQUIRE(mup); + CHECK_EQ(mup->type, requireType("t")); + REQUIRE(mup->missing.size() == 1); + std::optional bTy = lookupType("B"); + REQUIRE(bTy); + CHECK_EQ(mup->missing[0], *bTy); + CHECK_EQ(mup->key, "x"); + + TypeId r = requireType("r"); + CHECK_MESSAGE(get(r), "Expected error, got " << toString(r)); +} + +TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") +{ + CheckResult result = check(R"( + type A = {x: number} + type B = {x: any} + local t: A | B + + local r = t.x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.anyType, *requireType("r")); +} + +TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") +{ + CheckResult result = check(R"( + type A = number | string | nil + type B = number | nil + type C = number | boolean + + local a: A = 1 + local b: B = nil + local c: C = true + local n = 1 + + local x = a == b + local y = a == n + local z = a == c + )"); + + if (FFlag::LuauEqConstraint) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(*typeChecker.booleanType, *requireType("x")); + CHECK_EQ(*typeChecker.booleanType, *requireType("y")); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("(number | string)?", toString(*tm->wantedType)); + CHECK_EQ("boolean | number", toString(*tm->givenType)); + } +} + +TEST_CASE_FIXTURE(Fixture, "optional_union_members") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +local a = { a = { x = 1, y = 2 }, b = 3 } +type A = typeof(a) +local b: A? = a +local bf = b +local c = bf.a.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_union_functions") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +local a = {} +function a.foo(x:number, y:number) return x + y end +type A = typeof(a) +local b: A? = a +local c = b.foo(1, 2) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_union_methods") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +local a = {} +function a:foo(x:number, y:number) return x + y end +type A = typeof(a) +local b: A? = a +local c = b:foo(1, 2) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_union_follow") +{ + CheckResult result = check(R"( +local y: number? = 2 +local x = y +local function f(a: number, b: typeof(x), c: typeof(x)) return -a end +return f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(1, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +type A = { x: number } +local b: A? = { x = 2 } +local c = b.x +local d = b.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[1])); + CHECK_EQ("Key 'y' not found in table 'A'", toString(result.errors[2])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_index_error") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +type A = {number} +local a: A? = {1, 2, 3} +local b = a[1] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_call_error") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +type A = (number) -> number +local a: A? = function(a) return -a end +local b = a(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '((number) -> number)?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +type A = { x: number } +local a: A? = { x = 2 } +a.x = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); + + result = check(R"( +type A = { x: number } & { y: number } +local a: A? = { x = 2, y = 3 } +a.x = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '({| x: number |} & {| y: number |})?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_length_error") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + + CheckResult result = check(R"( +type A = {number} +local a: A? = {1, 2, 3} +local b = #a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") +{ + ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); + ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); + + CheckResult result = check(R"( +type A = { x: number, y: number } +type B = { x: number, y: number } +type C = { x: number } +type D = { x: number } + +local a: A|B|C|D +local b = a.y + +local c: A|(B|C)?|D +local d = c.y + +local e = a.z + )"); + + LUAU_REQUIRE_ERROR_COUNT(4, result); + CHECK_EQ("Key 'y' is missing from 'C', 'D' in the type 'A | B | C | D'", toString(result.errors[0])); + + CHECK_EQ("Value of type '(A | B | C | D)?' could be nil", toString(result.errors[1])); + CHECK_EQ("Key 'y' is missing from 'C', 'D' in the type 'A | B | C | D'", toString(result.errors[2])); + + CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[3])); +} + +TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") +{ + ScopedFastFlag luauSealedTableUnifyOptionalFix("LuauSealedTableUnifyOptionalFix", true); + + CheckResult result = check(R"( +local x: { x: number } = { x = 3 } +type A = number? +type B = string? +local y: { x: number, y: A | B } +y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( +local x: { x: number } = { x = 3 } + +local a: number? = 2 +local y = {} +y.x = 2 +y.y = a + +y = x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp new file mode 100644 index 0000000..8b05654 --- /dev/null +++ b/tests/TypePack.test.cpp @@ -0,0 +1,201 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +namespace +{ +struct TypePackFixture +{ + TypePackFixture() + { + typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::NilType))); + typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))); + typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::Number))); + typeVars.emplace_back(new TypeVar(PrimitiveTypeVar(PrimitiveTypeVar::String))); + + for (const auto& ptr : typeVars) + types.push_back(ptr.get()); + } + + TypePackId freshTypePack() + { + typePacks.emplace_back(new TypePackVar{Unifiable::Free{{}}}); + return typePacks.back().get(); + } + + TypePackId newTypePack(std::initializer_list types, std::optional tail) + { + typePacks.emplace_back(new TypePackVar{TypePack{types, tail}}); + return typePacks.back().get(); + } + + std::vector> typePacks; + + std::vector> typeVars; + std::vector types; +}; + +} // namespace + +TEST_SUITE_BEGIN("TypePackTests"); + +TEST_CASE_FIXTURE(TypePackFixture, "type_pack_hello") +{ + auto tp = TypePackVar{TypePack{{types[0], types[1]}, std::nullopt}}; + + CHECK(tp == tp); +} + +TEST_CASE_FIXTURE(TypePackFixture, "first_chases_Bound_TypePackVars") +{ + TypeVar nilType{PrimitiveTypeVar{PrimitiveTypeVar::NilType}}; + + auto tp1 = TypePackVar{TypePack{{&nilType}, std::nullopt}}; + + auto tp2 = TypePackVar{BoundTypePack{&tp1}}; + + auto tp3 = TypePackVar{TypePack{{}, &tp2}}; + + CHECK_EQ(first(&tp3), &nilType); +} + +TEST_CASE_FIXTURE(TypePackFixture, "iterate_over_TypePack") +{ + TypePackId typePack = newTypePack({types[0], types[1]}, std::nullopt); + + std::vector res; + for (TypeId t : typePack) + res.push_back(t); + + REQUIRE_EQ(2, res.size()); +} + +TEST_CASE_FIXTURE(TypePackFixture, "iterate_over_TypePack_with_2_links") +{ + auto typePack1 = newTypePack({types[0], types[1]}, std::nullopt); + auto typePack2 = newTypePack({types[0], types[3]}, typePack1); + + std::vector result; + for (TypeId ty : typePack2) + result.push_back(ty); + + REQUIRE_EQ(4, result.size()); + + CHECK_EQ(types[0], result[0]); + CHECK_EQ(types[3], result[1]); + CHECK_EQ(types[0], result[2]); + CHECK_EQ(types[1], result[3]); +} + +TEST_CASE_FIXTURE(TypePackFixture, "get_the_tail") +{ + TypePackId freeTail = freshTypePack(); + TypePackId typePack = newTypePack({types[0]}, freeTail); + + auto it = begin(typePack); + auto endIt = end(typePack); + int count = 0; + while (it != endIt) + ++count, ++it; + REQUIRE_EQ(1, count); + + CHECK(it == end(typePack)); + + REQUIRE_EQ(it.tail(), freeTail); +} + +TEST_CASE_FIXTURE(TypePackFixture, "tail_can_be_nullopt") +{ + TypePackId typePack = newTypePack({types[0], types[0]}, std::nullopt); + + auto it = end(typePack); + REQUIRE_EQ(std::nullopt, it.tail()); +} + +TEST_CASE_FIXTURE(TypePackFixture, "tail_is_end_for_free_TypePack") +{ + TypePackId typePack = freshTypePack(); + + auto it = begin(typePack); + auto endIt = end(typePack); + while (it != endIt) + ++it; + REQUIRE_EQ(typePack, it.tail()); +} + +TEST_CASE_FIXTURE(TypePackFixture, "skip_over_empty_head_typepack_with_tail") +{ + TypePackId tailTP = newTypePack({types[2], types[3]}, std::nullopt); + TypePackId headTP = newTypePack({}, tailTP); + + int count = 0; + for (TypeId ty : headTP) + { + (void)ty; + ++count; + } + + CHECK_EQ(2, count); +} + +TEST_CASE_FIXTURE(TypePackFixture, "skip_over_empty_middle_link") +{ + TypePackId tailTP = newTypePack({types[2], types[3]}, std::nullopt); + TypePackId middleTP = newTypePack({}, tailTP); + TypePackId headTP = newTypePack({types[0], types[1]}, middleTP); + + int count = 0; + for (TypeId ty : headTP) + { + (void)ty; + ++count; + } + + CHECK_EQ(4, count); +} + +TEST_CASE_FIXTURE(TypePackFixture, "follows_Bound_TypePacks") +{ + TypePackId tailTP = newTypePack({types[2], types[3]}, std::nullopt); + TypePackId middleTP = freshTypePack(); + *asMutable(middleTP) = Unifiable::Bound(tailTP); + TypePackId headTP = newTypePack({}, middleTP); + + int count = 0; + for (TypeId ty : headTP) + { + (void)ty; + ++count; + } + + CHECK_EQ(2, count); +} + +TEST_CASE_FIXTURE(TypePackFixture, "post_and_pre_increment") +{ + TypePackId typePack = newTypePack({types[0], types[1], types[2], types[3]}, std::nullopt); + + auto it1 = begin(typePack); + auto it2 = it1++; + auto it3 = ++it2; + + CHECK_EQ(*it2, *it3); +} + +TEST_CASE_FIXTURE(TypePackFixture, "std_distance") +{ + TypePackId typePack = newTypePack({types[0], types[1], types[2], types[3]}, std::nullopt); + + auto b = begin(typePack); + auto e = end(typePack); + CHECK_EQ(4, std::distance(b, e)); +} + +TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp new file mode 100644 index 0000000..98ce9f9 --- /dev/null +++ b/tests/TypeVar.test.cpp @@ -0,0 +1,267 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +LUAU_FASTFLAG(LuauGenericFunctions); + +TEST_SUITE_BEGIN("TypeVarTests"); + +TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") +{ + REQUIRE_EQ(typeChecker.booleanType, typeChecker.booleanType); +} + +TEST_CASE_FIXTURE(Fixture, "bound_type_is_equal_to_that_which_it_is_bound") +{ + TypeVar bound(BoundTypeVar(typeChecker.booleanType)); + REQUIRE_EQ(bound, *typeChecker.booleanType); +} + +TEST_CASE_FIXTURE(Fixture, "equivalent_cyclic_tables_are_equal") +{ + TypeVar cycleOne{TypeVariant(TableTypeVar())}; + TableTypeVar* tableOne = getMutable(&cycleOne); + tableOne->props["self"] = {&cycleOne}; + + TypeVar cycleTwo{TypeVariant(TableTypeVar())}; + TableTypeVar* tableTwo = getMutable(&cycleTwo); + tableTwo->props["self"] = {&cycleTwo}; + + CHECK_EQ(cycleOne, cycleTwo); +} + +TEST_CASE_FIXTURE(Fixture, "different_cyclic_tables_are_not_equal") +{ + TypeVar cycleOne{TypeVariant(TableTypeVar())}; + TableTypeVar* tableOne = getMutable(&cycleOne); + tableOne->props["self"] = {&cycleOne}; + + TypeVar cycleTwo{TypeVariant(TableTypeVar())}; + TableTypeVar* tableTwo = getMutable(&cycleTwo); + tableTwo->props["this"] = {&cycleTwo}; + + CHECK_NE(cycleOne, cycleTwo); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just_one_value") +{ + auto emptyArgumentPack = TypePackVar{TypePack{}}; + auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}}}; + auto returnsTwo = TypeVar(FunctionTypeVar(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + + std::string res = toString(&returnsTwo); + CHECK_EQ("() -> number", res); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just_one_value") +{ + auto emptyArgumentPack = TypePackVar{TypePack{}}; + auto returnPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.numberType}}}; + auto returnsTwo = TypeVar(FunctionTypeVar(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + + std::string res = toString(&returnsTwo); + CHECK_EQ("() -> (number, number)", res); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_free") +{ + auto emptyArgumentPack = TypePackVar{TypePack{}}; + auto free = Unifiable::Free(TypeLevel()); + auto freePack = TypePackVar{TypePackVariant{free}}; + auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}, &freePack}}; + auto returnsTwo = TypeVar(FunctionTypeVar(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + + std::string res = toString(&returnsTwo); + CHECK_EQ(res, "() -> (number, a...)"); +} + +TEST_CASE_FIXTURE(Fixture, "subset_check") +{ + UnionTypeVar super, sub, notSub; + super.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType}; + sub.options = {typeChecker.numberType, typeChecker.stringType}; + notSub.options = {typeChecker.numberType, typeChecker.nilType}; + + CHECK(isSubset(super, sub)); + CHECK(!isSubset(super, notSub)); +} + +TEST_CASE_FIXTURE(Fixture, "iterate_over_UnionTypeVar") +{ + UnionTypeVar utv; + utv.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.anyType}; + + std::vector result; + for (TypeId ty : &utv) + result.push_back(ty); + + CHECK(result == utv.options); +} + +TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypeVars") +{ + TypeVar subunion{UnionTypeVar{}}; + UnionTypeVar* innerUtv = getMutable(&subunion); + innerUtv->options = {typeChecker.numberType, typeChecker.stringType}; + + UnionTypeVar utv; + utv.options = {typeChecker.anyType, &subunion}; + + std::vector result; + for (TypeId ty : &utv) + result.push_back(ty); + + REQUIRE_EQ(result.size(), 3); + CHECK_EQ(result[0], typeChecker.anyType); + CHECK_EQ(result[2], typeChecker.stringType); + CHECK_EQ(result[1], typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypeVars_and_skips_over_them") +{ + TypeVar atv{UnionTypeVar{}}; + UnionTypeVar* utv1 = getMutable(&atv); + + TypeVar btv{UnionTypeVar{}}; + UnionTypeVar* utv2 = getMutable(&btv); + utv2->options.push_back(typeChecker.numberType); + utv2->options.push_back(typeChecker.stringType); + utv2->options.push_back(&atv); + + utv1->options.push_back(&btv); + + std::vector result; + for (TypeId ty : utv2) + result.push_back(ty); + + REQUIRE_EQ(result.size(), 2); + CHECK_EQ(result[0], typeChecker.numberType); + CHECK_EQ(result[1], typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") +{ + TypeVar tv1{UnionTypeVar{{typeChecker.stringType, typeChecker.numberType}}}; + TypeVar tv2{UnionTypeVar{{&tv1, typeChecker.booleanType}}}; + auto utv = get(&tv2); + + std::vector result; + for (TypeId ty : utv) + result.push_back(ty); + + REQUIRE_EQ(result.size(), 3); + CHECK_EQ(result[0], typeChecker.stringType); + CHECK_EQ(result[1], typeChecker.numberType); + CHECK_EQ(result[2], typeChecker.booleanType); +} + +TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_vector_iter_ctor") +{ + TypeVar tv1{UnionTypeVar{{typeChecker.stringType, typeChecker.numberType}}}; + TypeVar tv2{UnionTypeVar{{&tv1, typeChecker.booleanType}}}; + auto utv = get(&tv2); + + std::vector actual(begin(utv), end(utv)); + std::vector expected{typeChecker.stringType, typeChecker.numberType, typeChecker.booleanType}; + CHECK_EQ(actual, expected); +} + +TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") +{ + TypeVar tv{UnionTypeVar{}}; + auto utv = get(&tv); + + std::vector actual(begin(utv), end(utv)); + CHECK(actual.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") +{ + TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; + + TypePackVar tp24{TypePack{{&ftv11}}}; + TypePackVar tp17{TypePack{}}; + + TypeVar ftv23{FunctionTypeVar{&tp24, &tp17}}; + + TypeVar ttvConnection2{TableTypeVar{}}; + TableTypeVar* ttvConnection2_ = getMutable(&ttvConnection2); + ttvConnection2_->instantiatedTypeParams.push_back(&ftv11); + ttvConnection2_->props["f"] = {&ftv23}; + + TypePackVar tp21{TypePack{{&ftv11}}}; + TypePackVar tp20{TypePack{}}; + + TypeVar ftv19{FunctionTypeVar{&tp21, &tp20}}; + + TypeVar ttvSignal{TableTypeVar{}}; + TableTypeVar* ttvSignal_ = getMutable(&ttvSignal); + ttvSignal_->instantiatedTypeParams.push_back(&ftv11); + ttvSignal_->props["f"] = {&ftv19}; + + // Back edge + ttvConnection2_->props["signal"] = {&ttvSignal}; + + TypeVar gtvK2{GenericTypeVar{}}; + TypeVar gtvV2{GenericTypeVar{}}; + + TypeVar ttvTweenResult2{TableTypeVar{}}; + TableTypeVar* ttvTweenResult2_ = getMutable(&ttvTweenResult2); + ttvTweenResult2_->instantiatedTypeParams.push_back(>vK2); + ttvTweenResult2_->instantiatedTypeParams.push_back(>vV2); + + TypePackVar tp13{TypePack{{&ttvTweenResult2}}}; + TypeVar ftv12{FunctionTypeVar{&tp13, &tp17}}; + + TypeVar ttvConnection{TableTypeVar{}}; + TableTypeVar* ttvConnection_ = getMutable(&ttvConnection); + ttvConnection_->instantiatedTypeParams.push_back(&ttvTweenResult2); + ttvConnection_->props["f"] = {&ftv12}; + ttvConnection_->props["signal"] = {&ttvSignal}; + + TypePackVar tp9{TypePack{}}; + TypePackVar tp10{TypePack{{&ttvConnection}}}; + + TypeVar ftv8{FunctionTypeVar{&tp9, &tp10}}; + + TypeVar ttvTween{TableTypeVar{}}; + TableTypeVar* ttvTween_ = getMutable(&ttvTween); + ttvTween_->instantiatedTypeParams.push_back(>vK2); + ttvTween_->instantiatedTypeParams.push_back(>vV2); + ttvTween_->props["f"] = {&ftv8}; + + TypePackVar tp4{TypePack{}}; + TypePackVar tp5{TypePack{{&ttvTween}}}; + + TypeVar ftv3{FunctionTypeVar{&tp4, &tp5}}; + + // Back edge + ttvTweenResult2_->props["f"] = {&ftv3}; + + TypeVar gtvK{GenericTypeVar{}}; + TypeVar gtvV{GenericTypeVar{}}; + + TypeVar ttvTweenResult{TableTypeVar{}}; + TableTypeVar* ttvTweenResult_ = getMutable(&ttvTweenResult); + ttvTweenResult_->instantiatedTypeParams.push_back(>vK); + ttvTweenResult_->instantiatedTypeParams.push_back(>vV); + ttvTweenResult_->props["f"] = {&ftv3}; + + TypeId root = &ttvTweenResult; + + typeChecker.currentModule = std::make_shared(); + + TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); + + CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); +} + +TEST_SUITE_END(); diff --git a/tests/Variant.test.cpp b/tests/Variant.test.cpp new file mode 100644 index 0000000..fcf3787 --- /dev/null +++ b/tests/Variant.test.cpp @@ -0,0 +1,178 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Variant.h" + +#include +#include + +#include "doctest.h" + +using namespace Luau; + +struct Foo +{ + int x = 42; +}; + +TEST_SUITE_BEGIN("Variant"); + +TEST_CASE("DefaultCtor") +{ + Variant v1; + Variant v2; + + REQUIRE(get_if(&v1)); + CHECK(*get_if(&v1) == 0); + CHECK(!get_if(&v1)); + + REQUIRE(get_if(&v2)); + CHECK(get_if(&v2)->x == 42); +} + +TEST_CASE("Create") +{ + Variant v1 = 1; + Variant v2 = Foo{2}; + + Foo f = {3}; + Variant v3 = f; + + REQUIRE(get_if(&v1)); + CHECK(*get_if(&v1) == 1); + + REQUIRE(get_if(&v2)); + CHECK(get_if(&v2)->x == 2); + + REQUIRE(get_if(&v3)); + CHECK(get_if(&v3)->x == 3); +} + +TEST_CASE("NonPOD") +{ + // initialize (copy) + std::string s1 = "hello"; + Variant v1 = s1; + + CHECK(*get_if(&v1) == "hello"); + + // initialize (move) + Variant v2 = std::string("hello"); + + CHECK(*get_if(&v2) == "hello"); + + // move-assign + v2 = std::string("this is a long string that doesn't fit into the small buffer"); + + CHECK(*get_if(&v2) == "this is a long string that doesn't fit into the small buffer"); + + // copy-assign + std::string s2("this is another long string, and this time we're copying it"); + v2 = s2; + + CHECK(*get_if(&v2) == "this is another long string, and this time we're copying it"); + + // copy ctor + Variant v3 = v2; + + CHECK(*get_if(&v2) == "this is another long string, and this time we're copying it"); + CHECK(*get_if(&v3) == "this is another long string, and this time we're copying it"); + + // move ctor + Variant v4 = std::move(v3); + + CHECK(*get_if(&v2) == "this is another long string, and this time we're copying it"); + CHECK(*get_if(&v3) == ""); // moved-from variant has an empty string now + CHECK(*get_if(&v4) == "this is another long string, and this time we're copying it"); +} + +TEST_CASE("Equality") +{ + Variant v1 = std::string("hi"); + Variant v2 = std::string("me"); + Variant v3 = 1; + Variant v4 = 0; + Variant v5; + + CHECK(v1 == v1); + CHECK(v1 != v2); + CHECK(v1 != v3); + CHECK(v3 != v4); + CHECK(v4 == v5); +} + +struct ToStringVisitor +{ + std::string operator()(const std::string& v) + { + return v; + } + + std::string operator()(int v) + { + return std::to_string(v); + } +}; + +struct IncrementVisitor +{ + void operator()(std::string& v) + { + v += "1"; + } + + void operator()(int& v) + { + v += 1; + } +}; + +TEST_CASE("Visit") +{ + Variant v1 = std::string("123"); + Variant v2 = 45; + const Variant& v1c = v1; + const Variant& v2c = v2; + + // void-returning visitor, const variants + std::string r1; + visit( + [&](const auto& v) { + r1 += ToStringVisitor()(v); + }, + v1c); + visit( + [&](const auto& v) { + r1 += ToStringVisitor()(v); + }, + v2c); + CHECK(r1 == "12345"); + + // value-returning visitor, const variants + std::string r2; + r2 += visit(ToStringVisitor(), v1c); + r2 += visit(ToStringVisitor(), v2c); + CHECK(r2 == "12345"); + + // void-returning visitor, mutable variant + visit(IncrementVisitor(), v1); + visit(IncrementVisitor(), v2); + CHECK(visit(ToStringVisitor(), v1) == "1231"); + CHECK(visit(ToStringVisitor(), v2) == "46"); + + // value-returning visitor, mutable variant + std::string r3; + r3 += visit( + [&](auto& v) { + IncrementVisitor()(v); + return ToStringVisitor()(v); + }, + v1); + r3 += visit( + [&](auto& v) { + IncrementVisitor()(v); + return ToStringVisitor()(v); + }, + v2); + CHECK(r3 == "1231147"); +} + +TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua new file mode 100644 index 0000000..5e03b05 --- /dev/null +++ b/tests/conformance/apicalls.lua @@ -0,0 +1,8 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing function calls through API') + +function add(a, b) + return a + b +end + +return('OK') diff --git a/tests/conformance/assert.lua b/tests/conformance/assert.lua new file mode 100644 index 0000000..2ae9fec --- /dev/null +++ b/tests/conformance/assert.lua @@ -0,0 +1,34 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing asserts") -- someone has to + +if pcall(assert, false) or pcall(function() assert(false) end) then + error('catastrophic assertion failure') -- surely error() can't be broken +end + +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub(err:find(": ") + 2, #err) +end + +-- zero-ret calls work +assert(1) +assert(true) + +-- returns first arg +assert(assert(1) == 1) +assert(type(assert({})) == 'table') + +-- fails correctly +assert(ecall(function() assert() end) == "missing argument #1") +assert(ecall(function() assert(nil) end) == "assertion failed!") +assert(ecall(function() assert(false) end) == "assertion failed!") + +-- fails with a message +assert(ecall(function() assert(nil, "epic fail") end) == "epic fail") + +-- returns all arguments for multi-arg calls +assert(select('#', assert(1, 2, 3)) == 3) +assert(table.concat(table.pack(assert(1, 2, 3)), "") == "123") + +return('OK') diff --git a/tests/conformance/attrib.lua b/tests/conformance/attrib.lua new file mode 100644 index 0000000..58ad976 --- /dev/null +++ b/tests/conformance/attrib.lua @@ -0,0 +1,106 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing assignments, logical operators, and constructors") + +local unpack = table.unpack + +local res, res2 = 27 + +a, b = 1, 2+3 +assert(a==1 and b==5) +a={} +function f() return 10, 11, 12 end +a.x, b, a[1] = 1, 2, f() +assert(a.x==1 and b==2 and a[1]==10) +a[f()], b, a[f()+3] = f(), a, 'x' +assert(a[10] == 10 and b == a and a[13] == 'x') + +do + local f = function (n) local x = {}; for i=1,n do x[i]=i end; + return unpack(x) end; + local a,b,c + a,b = 0, f(1) + assert(a == 0 and b == 1) + A,b = 0, f(1) + assert(A == 0 and b == 1) + a,b,c = 0,5,f(4) + assert(a==0 and b==5 and c==1) + a,b,c = 0,5,f(0) + assert(a==0 and b==5 and c==nil) +end + + +a, b, c, d = 1 and nil, 1 or nil, (1 and (nil or 1)), 6 +assert(not a and b and c and d==6) + +d = 20 +a, b, c, d = f() +assert(a==10 and b==11 and c==12 and d==nil) +a,b = f(), 1, 2, 3, f() +assert(a==10 and b==1) + +assert(ab == true) +assert((10 and 2) == 2) +assert((10 or 2) == 10) +assert((10 or assert(nil)) == 10) +assert(not (nil and assert(nil))) +assert((nil or "alo") == "alo") +assert((nil and 10) == nil) +assert((false and 10) == false) +assert((true or 10) == true) +assert((false or 10) == 10) +assert(false ~= nil) +assert(nil ~= false) +assert(not nil == true) +assert(not not nil == false) +assert(not not 1 == true) +assert(not not a == true) +assert(not not (6 or nil) == true) +assert(not not (nil and 56) == false) +assert(not not (nil and true) == false) +print('+') + +a = {} +a[true] = 20 +a[false] = 10 +assert(a[1<2] == 20 and a[1>2] == 10) + +function f(a) return a end + +local a = {} +for i=3000,-3000,-1 do a[i] = i; end +a[10e30] = "alo"; a[true] = 10; a[false] = 20 +assert(a[10e30] == 'alo' and a[not 1] == 20 and a[10<20] == 10) +for i=3000,-3000,-1 do assert(a[i] == i); end +a[print] = assert +a[f] = print +a[a] = a +assert(a[a][a][a][a][print] == assert) +a[print](a[a[f]] == a[print]) +a = nil + +a = {10,9,8,7,6,5,4,3,2; [-3]='a', [f]=print, a='a', b='ab'} +a, a.x, a.y = a, a[-3] +assert(a[1]==10 and a[-3]==a.a and a[f]==print and a.x=='a' and not a.y) +a[1], f(a)[2], b, c = {['alo']=assert}, 10, a[1], a[f], 6, 10, 23, f(a), 2 +a[1].alo(a[2]==10 and b==10 and c==print) + +a[2^31] = 10; a[2^31+1] = 11; a[-2^31] = 12; +a[2^32] = 13; a[-2^32] = 14; a[2^32+1] = 15; a[10^33] = 16; + +assert(a[2^31] == 10 and a[2^31+1] == 11 and a[-2^31] == 12 and + a[2^32] == 13 and a[-2^32] == 14 and a[2^32+1] == 15 and + a[10^33] == 16) + +a = nil + + +do + local a,i,j,b + a = {'a', 'b'}; i=1; j=2; b=a + i, a[i], a, j, a[j], a[i+j] = j, i, i, b, j, i + assert(i == 2 and b[1] == 1 and a == 1 and j == b and b[2] == 2 and + b[3] == 1) +end + +return('OK') diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua new file mode 100644 index 0000000..687fff1 --- /dev/null +++ b/tests/conformance/basic.lua @@ -0,0 +1,879 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing language/library basics") + +function concat(head, ...) + if select('#', ...) == 0 then + return tostring(head) + else + return tostring(head) .. "," .. concat(...) + end +end + +-- constants +assert(tostring(1) == "1") +assert(tostring(-1) == "-1") +assert(tostring(1.125) == "1.125") +assert(tostring(true) == "true") +assert(tostring(nil) == "nil") + +-- empty return +assert(select('#', (function() end)()) == 0) +assert(select('#', (function() return end)()) == 0) + +-- locals +assert((function() local a = 1 return a end)() == 1) +assert((function() local a, b, c = 1, 2, 3 return c end)() == 3) +assert((function() local a, b, c = 1, 2 return c end)() == nil) +assert((function() local a = 1, 2 return a end)() == 1) + +-- function calls +local function foo(a, b) return b end +assert(foo(1) == nil) +assert(foo(1, 2) == 2) +assert(foo(1, 2, 3) == 2) + +-- pcall +assert(concat(pcall(function () end)) == "true") +assert(concat(pcall(function () return nil end)) == "true,nil") +assert(concat(pcall(function () return 1,2,3 end)) == "true,1,2,3") +assert(concat(pcall(function () error("oops") end)) == "false,basic.lua:39: oops") + +-- assignments +assert((function() local a = 1 a = 2 return a end)() == 2) +assert((function() a = 1 a = 2 return a end)() == 2) +assert((function() local a = 1 a, b = 1 return a end)() == 1) +assert((function() local a = 1 a, b = 1 return b end)() == nil) +assert((function() local a = 1 b = 2 a, b = b, a return a end)() == 2) +assert((function() local a = 1 b = 2 a, b = b, a return b end)() == 1) +assert((function() _G.foo = 1 return _G['foo'] end)() == 1) +assert((function() _G['bar'] = 1 return _G.bar end)() == 1) +assert((function() local a = 1 (function () a = 2 end)() return a end)() == 2) + +-- upvalues +assert((function() local a = 1 function foo() return a end return foo() end)() == 1) + +-- check upvalue propagation - foo must have numupvalues=1 +assert((function() local a = 1 function foo() return function() return a end end return foo()() end)() == 1) + +-- check that function args are properly closed over +assert((function() function foo(a) return function () return a end end return foo(1)() end)() == 1) + +-- this checks local aliasing - b & a should share the same local slot, but the capture must return 1 instead of 2 +assert((function() function foo() local f do local a = 1 f = function () return a end end local b = 2 return f end return foo()() end)() == 1) + +-- this checks local mutability - we capture a ref to 1 but must return 2 +assert((function() function foo() local a = 1 local function f() return a end a = 2 return f end return foo()() end)() == 2) + +-- this checks upval mutability - we change the value from a context where it's upval +assert((function() function foo() local a = 1 (function () a = 2 end)() return a end return foo() end)() == 2) + +-- check self capture: does self go into any upvalues? +assert((function() local t = {f=5} function t:get() return (function() return self.f end)() end return t:get() end)() == 5) + +-- check self capture & close: is self copied to upval? +assert((function() function foo() local t = {f=5} function t:get() return function() return self.f end end return t:get() end return foo()() end)() == 5) + +-- if +assert((function() local a = 1 if a then a = 2 end return a end)() == 2) +assert((function() local a if a then a = 2 end return a end)() == nil) + +assert((function() local a = 0 if a then a = 1 else a = 2 end return a end)() == 1) +assert((function() local a if a then a = 1 else a = 2 end return a end)() == 2) + +-- binary ops +assert((function() local a = 1 a = a + 2 return a end)() == 3) +assert((function() local a = 1 a = a - 2 return a end)() == -1) +assert((function() local a = 1 a = a * 2 return a end)() == 2) +assert((function() local a = 1 a = a / 2 return a end)() == 0.5) +assert((function() local a = 5 a = a % 2 return a end)() == 1) +assert((function() local a = 3 a = a ^ 2 return a end)() == 9) + +assert((function() local a = '1' a = a .. '2' return a end)() == "12") +assert((function() local a = '1' a = a .. '2' .. '3' return a end)() == "123") + +assert(concat(pcall(function() return '1' .. nil .. '2' end)):match("^false,.*attempt to concatenate nil with string")) + +assert((function() local a = 1 a = a == 2 return a end)() == false) +assert((function() local a = 1 a = a ~= 2 return a end)() == true) +assert((function() local a = 1 a = a < 2 return a end)() == true) +assert((function() local a = 1 a = a <= 2 return a end)() == true) +assert((function() local a = 1 a = a > 2 return a end)() == false) +assert((function() local a = 1 a = a >= 2 return a end)() == false) + +assert((function() local a = 1 a = a and 2 return a end)() == 2) +assert((function() local a = nil a = a and 2 return a end)() == nil) +assert((function() local a = 1 a = a or 2 return a end)() == 1) +assert((function() local a = nil a = a or 2 return a end)() == 2) + +-- binary arithmetics coerces strings to numbers (sadly) +assert(1 + "2" == 3) +assert(2 * "0xa" == 20) + +-- unary ops +assert((function() local a = true a = not a return a end)() == false) +assert((function() local a = false a = not a return a end)() == true) +assert((function() local a = nil a = not a return a end)() == true) + +assert((function() return #_G end)() == 0) +assert((function() return #{1,2} end)() == 2) +assert((function() return #'g' end)() == 1) + +assert((function() local a = 1 a = -a return a end)() == -1) + +-- while/repeat +assert((function() local a = 10 local b = 1 while a > 1 do b = b * 2 a = a - 1 end return b end)() == 512) +assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 until a == 1 return b end)() == 512) + +assert((function() local a = 10 local b = 1 while true do b = b * 2 a = a - 1 if a == 1 then break end end return b end)() == 512) +assert((function() local a = 10 local b = 1 while true do b = b * 2 a = a - 1 if a == 1 then break else end end return b end)() == 512) +assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 if a == 1 then break end until false return b end)() == 512) +assert((function() local a = 10 local b = 1 repeat b = b * 2 a = a - 1 if a == 1 then break else end until false return b end)() == 512) + +-- this makes sure a - 4 doesn't clobber a (which would happen if the lifetime of locals inside the repeat..until block is contained within +-- the block and ends before the condition is evaluated +assert((function() repeat local a = 5 until a - 4 < 0 or a - 4 >= 0 end)() == nil) + +-- numeric for +-- basic tests with positive/negative step sizes +assert((function() local a = 1 for b=1,9 do a = a * 2 end return a end)() == 512) +assert((function() local a = 1 for b=1,9,2 do a = a * 2 end return a end)() == 32) +assert((function() local a = 1 for b=1,9,-2 do a = a * 2 end return a end)() == 1) +assert((function() local a = 1 for b=9,1,-2 do a = a * 2 end return a end)() == 32) + +-- make sure break works +assert((function() local a = 1 for b=1,9 do a = a * 2 if a == 128 then break end end return a end)() == 128) +assert((function() local a = 1 for b=1,9 do a = a * 2 if a == 128 then break else end end return a end)() == 128) + +-- make sure internal index is protected against modification +assert((function() local a = 1 for b=9,1,-2 do a = a * 2 b = nil end return a end)() == 32) + +-- generic for +-- ipairs +assert((function() local a = '' for k in ipairs({5, 6, 7}) do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in ipairs({5, 6, 7}) do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in ipairs({5, 6, 7}) do a = a .. v end return a end)() == "567") + +-- ipairs with gaps +assert((function() local a = '' for k in ipairs({5, 6, 7, nil, 8}) do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in ipairs({5, 6, 7, nil, 8}) do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in ipairs({5, 6, 7, nil, 8}) do a = a .. v end return a end)() == "567") + +-- manual ipairs/inext +local inext = ipairs({5,6,7}) +assert(concat(inext({5,6,7}, 2)) == "3,7") + +-- pairs on array +assert((function() local a = '' for k in pairs({5, 6, 7}) do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in pairs({5, 6, 7}) do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in pairs({5, 6, 7}) do a = a .. v end return a end)() == "567") + +-- pairs on array with gaps +assert((function() local a = '' for k in pairs({5, 6, 7, nil, 8}) do a = a .. k end return a end)() == "1235") +assert((function() local a = '' for k,v in pairs({5, 6, 7, nil, 8}) do a = a .. k end return a end)() == "1235") +assert((function() local a = '' for k,v in pairs({5, 6, 7, nil, 8}) do a = a .. v end return a end)() == "5678") + +-- pairs on table +assert((function() local a = {} for k in pairs({a=1, b=2, c=3}) do a[k] = 1 end return a.a + a.b + a.c end)() == 3) +assert((function() local a = {} for k,v in pairs({a=1, b=2, c=3}) do a[k] = 1 end return a.a + a.b + a.c end)() == 3) +assert((function() local a = {} for k,v in pairs({a=1, b=2, c=3}) do a[k] = v end return a.a + a.b + a.c end)() == 6) + +-- pairs on mixed array/table + gaps in the array portion +-- note that a,b,c results in a,c,b during traversal since index is based on hash & size +assert((function() local a = {} for k,v in pairs({1, 2, 3, a=5, b=6, c=7}) do a[#a+1] = v end return table.concat(a, ',') end)() == "1,2,3,5,7,6") +assert((function() local a = {} for k,v in pairs({1, 2, 3, nil, 4, a=5, b=6, c=7}) do a[#a+1] = v end return table.concat(a, ',') end)() == "1,2,3,4,5,7,6") + +-- pairs manually +assert((function() local a = '' for k in next,{5, 6, 7} do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in next,{5, 6, 7} do a = a .. k end return a end)() == "123") +assert((function() local a = '' for k,v in next,{5, 6, 7} do a = a .. v end return a end)() == "567") +assert((function() local a = {} for k in next,{a=1, b=2, c=3} do a[k] = 1 end return a.a + a.b + a.c end)() == 3) +assert((function() local a = {} for k,v in next,{a=1, b=2, c=3} do a[k] = 1 end return a.a + a.b + a.c end)() == 3) +assert((function() local a = {} for k,v in next,{a=1, b=2, c=3} do a[k] = v end return a.a + a.b + a.c end)() == 6) + +-- too many vars +assert((function() local a = '' for k,v,p in pairs({a=1, b=2, c=3}) do a = a .. tostring(p) end return a end)() == "nilnilnil") + +-- make sure break works +assert((function() local a = 1 for _ in pairs({1,2,3}) do a = a * 2 if a == 4 then break end end return a end)() == 4) +assert((function() local a = 1 for _ in pairs({1,2,3}) do a = a * 2 if a == 4 then break else end end return a end)() == 4) + +-- make sure internal index is protected against modification +assert((function() local a = 1 for b in pairs({1,2,3}) do a = a * 2 b = nil end return a end)() == 8) + +-- make sure custom iterators work! example is from PIL 7.1 +function list_iter(t) + local i = 0 + local n = table.getn(t) + return function() + i = i + 1 + if i <= n then return t[i] end + end +end + +assert((function() local a = '' for e in list_iter({4,2,1}) do a = a .. e end return a end)() == "421") + +-- make sure multret works in context of pairs() - this is a very painful to handle combination due to complex internal details +assert((function() local function f() return {5,6,7},8,9,0 end local a = '' for k,v in ipairs(f()) do a = a .. v end return a end)() == "567") + +-- table literals +-- basic tests +assert((function() local t = {} return #t end)() == 0) + +assert((function() local t = {1, 2} return #t end)() == 2) +assert((function() local t = {1, 2} return t[1] + t[2] end)() == 3) + +assert((function() local t = {data = 4} return t.data end)() == 4) +assert((function() local t = {[1+2] = 4} return t[3] end)() == 4) + +assert((function() local t = {data = 4, [1+2] = 5} return t.data + t[3] end)() == 9) + +assert((function() local t = {[1] = 1, [2] = 2} return t[1] + t[2] end)() == 3) + +-- since table ctor is chunked in groups of 16, we should be careful with edge cases around that +assert((function() return table.concat({}, ',') end)() == "") +assert((function() return table.concat({1}, ',') end)() == "1") +assert((function() return table.concat({1,2}, ',') end)() == "1,2") +assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, ',') end)() == + "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15") +assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16") +assert((function() return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17") + +-- some scripts rely on exact table traversal order; while it's evil to do so, let's check that it works +assert((function() + local kSelectedBiomes = { + ['Mountains'] = true, + ['Canyons'] = true, + ['Dunes'] = true, + ['Arctic'] = true, + ['Lavaflow'] = true, + ['Hills'] = true, + ['Plains'] = true, + ['Marsh'] = true, + ['Water'] = true, + } + local result = "" + for k in pairs(kSelectedBiomes) do result = result .. k end + return result +end)() == "ArcticDunesCanyonsWaterMountainsHillsLavaflowPlainsMarsh") + +-- multiple returns +-- local= +assert((function() function foo() return 2, 3, 4 end local a, b, c = foo() return ''..a..b..c end)() == "234") +assert((function() function foo() return 2, 3, 4 end local a, b, c = 1, foo() return ''..a..b..c end)() == "123") +assert((function() function foo() return 2 end local a, b, c = 1, foo() return ''..a..b..tostring(c) end)() == "12nil") + +-- assignments +assert((function() function foo() return 2, 3 end a, b, c, d = 1, foo() return ''..a..b..c..tostring(d) end)() == "123nil") +assert((function() function foo() return 2, 3 end local a, b, c, d a, b, c, d = 1, foo() return ''..a..b..c..tostring(d) end)() == "123nil") + +-- varargs +-- local= +assert((function() function foo(...) local a, b, c = ... return a + b + c end return foo(1, 2, 3) end)() == 6) +assert((function() function foo(x, ...) local a, b, c = ... return a + b + c end return foo(1, 2, 3, 4) end)() == 9) + +-- assignments +assert((function() function foo(...) a, b, c = ... return a + b + c end return foo(1, 2, 3) end)() == 6) +assert((function() function foo(x, ...) a, b, c = ... return a + b + c end return foo(1, 2, 3, 4) end)() == 9) + +-- extra nils +assert((function() function foo(...) local a, b, c = ... return tostring(a) .. tostring(b) .. tostring(c) end return foo(1, 2) end)() == "12nil") + +-- varargs + multiple returns +-- return +assert((function() function foo(...) return ... end return concat(foo(1, 2, 3)) end)() == "1,2,3") +assert((function() function foo(...) return ... end return foo() end)() == nil) +assert((function() function foo(a, ...) return a + 10, ... end return concat(foo(1, 2, 3)) end)() == "11,2,3") + +-- call +assert((function() function foo(...) return ... end function bar(...) return foo(...) end return concat(bar(1, 2, 3)) end)() == "1,2,3") +assert((function() function foo(...) return ... end function bar(...) return foo(...) end return bar() end)() == nil) +assert((function() function foo(a, ...) return a + 10, ... end function bar(a, ...) return foo(a * 2, ...) end return concat(bar(1, 2, 3)) end)() == "12,2,3") + +-- manual pack +assert((function() function pack(first, ...) if not first then return {} end local t = pack(...) table.insert(t, 1, first) return t end function foo(...) return pack(...) end return #foo(0, 1, 2) end)() == 3) + +-- multret + table literals +-- basic tests +assert((function() function foo(...) return { ... } end return #(foo()) end)() == 0) +assert((function() function foo(...) return { ... } end return #(foo(1, 2, 3)) end)() == 3) +assert((function() function foo() return 1, 2, 3 end return #({foo()}) end)() == 3) + +-- since table ctor is chunked in groups of 16, we should be careful with edge cases around that +assert((function() function foo() return 1, 2, 3 end return table.concat({foo()}, ',') end)() == "1,2,3") +assert((function() function foo() return 1, 2, 3 end return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, foo()}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,1,2,3") +assert((function() function foo() return 1, 2, 3 end return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, foo()}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,1,2,3") +assert((function() function foo() return 1, 2, 3 end return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, foo()}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,1,2,3") +assert((function() function foo() return 1, 2, 3 end return table.concat({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, foo()}, ',') end)() == "1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,1,2,3") + +-- table access +assert((function() local t = {6, 9, 7} return t[2] end)() == 9) +assert((function() local t = {6, 9, 7} return t[0] end)() == nil) +assert((function() local t = {6, 9, 7} return t[4] end)() == nil) +assert((function() local t = {6, 9, 7} return t[4.5] end)() == nil) +assert((function() local t = {6, 9, 7, [4.5]=11} return t[4.5] end)() == 11) +assert((function() local t = {6, 9, 7, a=11} return t['a'] end)() == 11) +assert((function() local t = {6, 9, 7} setmetatable(t, { __index = function(t,i) return i * 10 end }) return concat(t[2],t[5]) end)() == "9,50") + +assert((function() local t = {6, 9, 7} t[2] = 10 return t[2] end)() == 10) +assert((function() local t = {6, 9, 7} t[0] = 5 return t[0] end)() == 5) +assert((function() local t = {6, 9, 7} t[4] = 10 return t[4] end)() == 10) +assert((function() local t = {6, 9, 7} t[4.5] = 10 return t[4.5] end)() == 10) +assert((function() local t = {6, 9, 7} t['a'] = 11 return t['a'] end)() == 11) +assert((function() local t = {6, 9, 7} setmetatable(t, { __newindex = function(t,i,v) rawset(t, i * 10, v) end }) t[1] = 17 t[5] = 1 return concat(t[1],t[5],t[50]) end)() == "17,nil,1") + +-- and/or +-- rhs is a constant +assert((function() local a = 1 a = a and 2 return a end)() == 2) +assert((function() local a = nil a = a and 2 return a end)() == nil) +assert((function() local a = 1 a = a or 2 return a end)() == 1) +assert((function() local a = nil a = a or 2 return a end)() == 2) + +-- rhs is a local +assert((function() local a = 1 local b = 2 a = a and b return a end)() == 2) +assert((function() local a = nil local b = 2 a = a and b return a end)() == nil) +assert((function() local a = 1 local b = 2 a = a or b return a end)() == 1) +assert((function() local a = nil local b = 2 a = a or b return a end)() == 2) + +-- rhs is a global (prevents optimizations) +assert((function() local a = 1 b = 2 a = a and b return a end)() == 2) +assert((function() local a = nil b = 2 a = a and b return a end)() == nil) +assert((function() local a = 1 b = 2 a = a or b return a end)() == 1) +assert((function() local a = nil b = 2 a = a or b return a end)() == 2) + +-- table access: method calls + fake oop via mt +assert((function() + local Class = {} + Class.__index = Class + + function Class.new() + local self = {} + setmetatable(self, Class) + + self.field = 5 + + return self + end + + function Class:GetField() + return self.field + end + + local object = Class.new() + return object:GetField() +end)() == 5) + +-- table access: evil indexer +assert((function() + local a = {5} + local b = {6} + local mt = { __index = function() return b[1] end } + setmetatable(a, mt) + b = a.hi + return b +end)() == 6) + +-- table access: fast-path tests for array lookup +-- in-bounds array lookup shouldn't call into Lua, but if the element isn't there we'll still call the metatable +assert((function() local a = {9, [1.5] = 7} return a[1], a[2], a[1.5] end)() == 9,nil,7) +assert((function() local a = {9, [1.5] = 7} setmetatable(a, { __index = function() return 5 end }) return concat(a[1],a[2],a[1.5]) end)() == "9,5,7") +assert((function() local a = {9, nil, 11} setmetatable(a, { __index = function() return 5 end }) return concat(a[1],a[2],a[3],a[4]) end)() == "9,5,11,5") + +-- namecall for userdata: technically not officially supported but hard to test in a different way! +-- warning: this test may break at any time as we may decide that we'll only use userdata-namecall on tagged user data objects +assert((function() + local obj = newproxy(true) + getmetatable(obj).__namecall = function(self, arg) return 42 + arg end + return obj:Foo(10) +end)() == 52) +assert((function() + local obj = newproxy(true) + local t = {} + setmetatable(t, { __call = function(self1, self2, arg) return 42 + arg end }) + getmetatable(obj).__namecall = t + return obj:Foo(10) +end)() == 52) + +-- namecall for oop to test fast paths +assert((function() + local Class = {} + Class.__index = Class + + function Class:new(klass, v) -- note, this isn't necessarily common but it exercises additional namecall paths + local self = {value = v} + setmetatable(self, Class) + return self + end + + function Class:get() + return self.value + end + + function Class:set(v) + self.value = v + end + + local n = Class:new(32) + n:set(42) + return n:get() +end)() == 42) + +-- comparison +-- basic types +assert((function() a = nil return concat(a == nil, a ~= nil) end)() == "true,false") +assert((function() a = nil return concat(a == 1, a ~= 1) end)() == "false,true") +assert((function() a = 1 return concat(a == 1, a ~= 1) end)() == "true,false") +assert((function() a = 1 return concat(a == 2, a ~= 2) end)() == "false,true") +assert((function() a = true return concat(a == true, a ~= true) end)() == "true,false") +assert((function() a = true return concat(a == false, a ~= false) end)() == "false,true") +assert((function() a = 'a' return concat(a == 'a', a ~= 'a') end)() == "true,false") +assert((function() a = 'a' return concat(a == 'b', a ~= 'b') end)() == "false,true") + +-- tables, reference equality (no mt) +assert((function() a = {} return concat(a == a, a ~= a) end)() == "true,false") +assert((function() a = {} b = {} return concat(a == b, a ~= b) end)() == "false,true") + +-- tables, reference equality (mt without __eq) +assert((function() a = {} setmetatable(a, {}) return concat(a == a, a ~= a) end)() == "true,false") +assert((function() a = {} b = {} mt = {} setmetatable(a, mt) setmetatable(b, mt) return concat(a == b, a ~= b) end)() == "false,true") + +-- tables, __eq with same mt/different mt but same function/different function +assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r end } setmetatable(a, mt) setmetatable(b, mt) return concat(a == b, a ~= b) end)() == "true,false") +assert((function() a = {} b = {} function eq(l, r) return #l == #r end setmetatable(a, {__eq = eq}) setmetatable(b, {__eq = eq}) return concat(a == b, a ~= b) end)() == "true,false") +assert((function() a = {} b = {} setmetatable(a, {__eq = function(l, r) return #l == #r end}) setmetatable(b, {__eq = function(l, r) return #l == #r end}) return concat(a == b, a ~= b) end)() == "false,true") + +-- userdata, reference equality (no mt) +assert((function() a = newproxy(true) return concat(a == newproxy(true),a ~= newproxy(true)) end)() == "false,true") + +-- rawequal +assert(rawequal(true, 5) == false) +assert(rawequal(nil, nil) == true) +assert(rawequal(true, false) == false) +assert(rawequal(true, true) == true) +assert(rawequal(0, -0) == true) +assert(rawequal(1, 2) == false) +assert(rawequal("a", "a") == true) +assert(rawequal("a", "b") == false) +assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r end } setmetatable(a, mt) setmetatable(b, mt) return concat(a == b, rawequal(a, b)) end)() == "true,false") + +-- metatable ops +local function vec3t(x, y, z) + return setmetatable({ x=x, y=y, z=z}, { + __add = function(l, r) return vec3t(l.x + r.x, l.y + r.y, l.z + r.z) end, + __sub = function(l, r) return vec3t(l.x - r.x, l.y - r.y, l.z - r.z) end, + __mul = function(l, r) return type(r) == "number" and vec3t(l.x * r, l.y * r, l.z * r) or vec3t(l.x * r.x, l.y * r.y, l.z * r.z) end, + __div = function(l, r) return type(r) == "number" and vec3t(l.x / r, l.y / r, l.z / r) or vec3t(l.x / r.x, l.y / r.y, l.z / r.z) end, + __unm = function(v) return vec3t(-v.x, -v.y, -v.z) end, + __tostring = function(v) return string.format("%g, %g, %g", v.x, v.y, v.z) end + }) +end + +-- reg vs reg +assert((function() return tostring(vec3t(1,2,3) + vec3t(4,5,6)) end)() == "5, 7, 9") +assert((function() return tostring(vec3t(1,2,3) - vec3t(4,5,6)) end)() == "-3, -3, -3") +assert((function() return tostring(vec3t(1,2,3) * vec3t(4,5,6)) end)() == "4, 10, 18") +assert((function() return tostring(vec3t(1,2,3) / vec3t(2,4,8)) end)() == "0.5, 0.5, 0.375") + +-- reg vs constant +assert((function() return tostring(vec3t(1,2,3) * 2) end)() == "2, 4, 6") +assert((function() return tostring(vec3t(1,2,3) / 2) end)() == "0.5, 1, 1.5") + +-- unary +assert((function() return tostring(-vec3t(1,2,3)) end)() == "-1, -2, -3") + +-- string comparison +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', 'b')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', 'a')) end)() == "false,true,false,true") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('a', '')) end)() == "false,false,true,true") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('', '\\0')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('abc', 'abd')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0d')) end)() == "true,true,false,false") +assert((function() function cmp(a,b) return ab,a>=b end return concat(cmp('ab\\0c', 'ab\\0')) end)() == "false,false,true,true") + +-- array access +assert((function() local a = {4,5,6} return a[3] end)() == 6) +assert((function() local a = {4,5,nil,6} return a[3] end)() == nil) +assert((function() local a = {4,5,nil,6} setmetatable(a, { __index = function() return 42 end }) return a[4] end)() == 6) +assert((function() local a = {4,5,nil,6} setmetatable(a, { __index = function() return 42 end }) return a[3] end)() == 42) +assert((function() local a = {4,5,6} a[3] = 8 return a[3] end)() == 8) +assert((function() local a = {4,5,nil,6} a[3] = 8 return a[3] end)() == 8) +assert((function() local a = {4,5,nil,6} setmetatable(a, { __newindex = function(t,i) rawset(t,i,42) end }) a[4] = 0 return a[4] end)() == 0) +assert((function() local a = {4,5,nil,6} setmetatable(a, { __newindex = function(t,i) rawset(t,i,42) end }) a[3] = 0 return a[3] end)() == 42) + +-- array index for literal +assert((function() local a = {4, 5, nil, 6} return concat(a[1], a[3], a[4], a[100]) end)() == "4,nil,6,nil") +assert((function() local a = {4, 5, nil, 6} a[1] = 42 a[3] = 0 a[100] = 75 return concat(a[1], a[3], a[75], a[100]) end)() == "42,0,nil,75") + +-- load error +assert((function() return concat(loadstring('hello world')) end)() == "nil,[string \"hello world\"]:1: Incomplete statement: expected assignment or a function call") + +-- many arguments & locals +function f(p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, + p11, p12, p13, p14, p15, p16, p17, p18, p19, p20, + p21, p22, p23, p24, p25, p26, p27, p28, p29, p30, + p31, p32, p33, p34, p35, p36, p37, p38, p39, p40, + p41, p42, p43, p44, p45, p46, p48, p49, p50, ...) + local a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14 +end + +assert(f() == nil) + +-- upvalues & loops (validates timely closing) +assert((function() + local res = {} + + for i=1,5 do + res[#res+1] = (function() return i end) + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +assert((function() + local res = {} + + for i in ipairs({1,2,3,4,5}) do + res[#res+1] =(function() return i end) + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +assert((function() + local res = {} + + local i = 0 + while i <= 5 do + local j = i + res[#res+1] = (function() return j end) + i = i + 1 + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +assert((function() + local res = {} + + local i = 0 + repeat + local j = i + res[#res+1] = (function() return j end) + i = i + 1 + until i > 5 + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +-- upvalues & loops & break! +assert((function() + local res = {} + + for i=1,10 do + res[#res+1] = (function() return i end) + if i == 5 then + break + end + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +assert((function() + local res = {} + + for i in ipairs({1,2,3,4,5,6,7,8,9,10}) do + res[#res+1] =(function() return i end) + if i == 5 then + break + end + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +assert((function() + local res = {} + + local i = 0 + while i < 10 do + local j = i + res[#res+1] = (function() return j end) + if i == 5 then + break + end + i = i + 1 + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +assert((function() + local res = {} + + local i = 0 + repeat + local j = i + res[#res+1] = (function() return j end) + if i == 5 then + break + end + i = i + 1 + until i >= 10 + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 15) + +-- the reason why this test is interesting is that the table created here has arraysize=0 and a single hash element with key = 1.0 +-- ipairs must iterate through that +assert((function() + local arr = { [1] = 42 } + local sum = 0 + for i,v in ipairs(arr) do + sum = sum + v + end + return sum +end)() == 42) + +-- the reason why this test is interesting is it ensures we do correct mutability analysis for locals +local function chainTest(n) + local first = nil + local last = nil + + -- Build chain of n equality constraints + for i = 0, n do + local name = "v" .. i; + if i == 0 then first = name end + if i == n then last = name end + end + + return concat(first, last) +end + +assert(chainTest(100) == "v0,v100") + +-- this validates import fallbacks +assert(pcall(function() return idontexist.a end) == false) + +-- make sure that NaN is preserved by the bytecode compiler +local realnan = tostring(math.abs(0)/math.abs(0)) +assert(tostring(0/0*0) == realnan) +assert(tostring((-1)^(1/2)) == realnan) + +-- make sure that negative zero is preserved by bytecode compiler +assert(tostring(0) == "0") +assert(tostring(-0) == "-0") + +-- test newline handling in long strings +assert((function() + local s1 = [[ +]] + local s2 = [[ + +]] + local s3 = [[ +foo +bar]] + local s4 = [[ +foo +bar +]] + return concat(s1,s2,s3,s4) +end)() == ",\n,foo\nbar,foo\nbar\n") + +-- fastcall +-- positive tests for all simple examples; note that in this case the call is a multret call (nresults=LUA_MULTRET) +assert((function() return math.abs(-5) end)() == 5) +assert((function() local abs = math.abs return abs(-5) end)() == 5) +assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 5) + +-- vararg testing - in this case nparams = LUA_MULTRET and it gets adjusted before execution +assert((function() function foo(...) return math.abs(...) end return foo(-5) end)() == 5) +assert((function() function foo(...) local abs = math.abs return abs(...) end return foo(-5) end)() == 5) +assert((function() local abs = math.abs function foo(...) return abs(...) end return foo(-5) end)() == 5) + +-- NOTE: getfenv breaks fastcalls for the remainder of the source! hence why this is delayed until the end +function testgetfenv() + getfenv() + -- getfenv breaks fastcalls (we assume we can't rely on knowing the semantics), but behavior shouldn't change + assert((function() return math.abs(-5) end)() == 5) + assert((function() local abs = math.abs return abs(-5) end)() == 5) + assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 5) + + -- ... unless you actually reassign the function :D + getfenv().math = { abs = function(n) return n*n end } + assert((function() return math.abs(-5) end)() == 25) + assert((function() local abs = math.abs return abs(-5) end)() == 25) + assert((function() local abs = math.abs function foo() return abs(-5) end return foo() end)() == 25) +end + +-- you need to have enough arguments and arguments of the right type; if you don't, we'll fallback to the regular code. This checks coercions +-- first to make sure all fallback paths work +assert((function() return math.abs('-5') end)() == 5) +assert((function() local abs = math.abs return abs('-5') end)() == 5) +assert((function() local abs = math.abs function foo() return abs('-5') end return foo() end)() == 5) + +-- if you don't have enough arguments or types are wrong, we fall back to the regular execution; this checks that the error generated is actually correct +assert(concat(pcall(function() return math.abs() end)):match("missing argument #1 to 'abs'")) +assert(concat(pcall(function() return math.abs(nil) end)):match("invalid argument #1 to 'abs'")) +assert(concat(pcall(function() return math.abs({}) end)):match("invalid argument #1 to 'abs'")) + +-- very large unpack +assert(select('#', table.unpack({1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1})) == 263) + +-- basic continue in for/while/repeat loops +assert((function() local a = 1 for i=1,8 do a = a + 1 if a < 5 then continue end a = a * 2 end return a end)() == 190) +assert((function() local a = 1 while a < 100 do a = a + 1 if a < 5 then continue end a = a * 2 end return a end)() == 190) +assert((function() local a = 1 repeat a = a + 1 if a < 5 then continue end a = a * 2 until a > 100 return a end)() == 190) + +-- upvalues, loops, continue +assert((function() + local res = {} + + for i=1,10 do + res[#res+1] = (function() return i end) + if i == 5 then + continue + end + i = i * 2 + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 105) + +assert((function() + local res = {} + + for i in ipairs({1,2,3,4,5,6,7,8,9,10}) do + res[#res+1] =(function() return i end) + if i == 5 then + continue + end + i = i * 2 + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 105) + +assert((function() + local res = {} + + local i = 1 + while i <= 10 do + local j = i + res[#res+1] = (function() return j end) + if i == 5 then + i = i + 1 + continue + end + i = i + 1 + j = j * 2 + end + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 105) + +assert((function() + local res = {} + + local i = 1 + repeat + local j = i + res[#res+1] = (function() return j end) + if i == 5 then + i = i + 1 + continue + end + i = i + 1 + j = j * 2 + until i > 10 + + local sum = 0 + for i,f in pairs(res) do sum = sum + f() end + + return sum +end)() == 105) + +-- upvalues: recursive capture +assert((function() local function fact(n) return n < 1 and 1 or n * fact(n-1) end return fact(5) end)() == 120) + +-- basic compound assignment +assert((function() + local a = 1 + b = 2 + local c = { value = 3 } + local d = { 4 } + local e = 3 + local f = 2 + + a += 5 + b -= a + c.value *= 3 + d[1] /= b + e %= 2 + f ^= 4 + + return concat(a,b,c.value,d[1],e,f) +end)() == "6,-4,9,-1,1,16") + +-- compound concat +assert((function() + local a = 'a' + + a ..= 'b' + a ..= 'c' .. 'd' + a ..= 'e' .. 'f' .. a + + return a +end)() == "abcdefabcd") + +-- compound assignment with side effects validates lhs is evaluated once +assert((function() + local res = { 1, 2, 3 } + local count = 0 + + res[(function() count += 1 return count end)()] += 5 + res[(function() count += 1 return count end)()] += 6 + res[(function() count += 1 return count end)()] += 7 + + return table.concat(res, ',') +end)() == "6,8,10") + +-- typeof == type in absence of custom userdata +assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") + +testgetfenv() -- DONT MOVE THIS LINE + +return'OK' diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua new file mode 100644 index 0000000..6efa596 --- /dev/null +++ b/tests/conformance/bitwise.lua @@ -0,0 +1,140 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing bitwise operations") + +assert(bit32.band() == bit32.bnot(0)) +assert(bit32.btest() == true) +assert(bit32.bor() == 0) +assert(bit32.bxor() == 0) + +assert(bit32.band() == bit32.band(0xffffffff)) +assert(bit32.band(1,2) == 0) + + +-- out-of-range numbers +assert(bit32.band(-1) == 0xffffffff) +assert(bit32.band(2^33 - 1) == 0xffffffff) +assert(bit32.band(-2^33 - 1) == 0xffffffff) +assert(bit32.band(2^33 + 1) == 1) +assert(bit32.band(-2^33 + 1) == 1) +assert(bit32.band(-2^40) == 0) +assert(bit32.band(2^40) == 0) +assert(bit32.band(-2^40 - 2) == 0xfffffffe) +assert(bit32.band(2^40 - 4) == 0xfffffffc) + +assert(bit32.lrotate(0, -1) == 0) +assert(bit32.lrotate(0, 7) == 0) +assert(bit32.lrotate(0x12345678, 4) == 0x23456781) +assert(bit32.rrotate(0x12345678, -4) == 0x23456781) +assert(bit32.lrotate(0x12345678, -8) == 0x78123456) +assert(bit32.rrotate(0x12345678, 8) == 0x78123456) +assert(bit32.lrotate(0xaaaaaaaa, 2) == 0xaaaaaaaa) +assert(bit32.lrotate(0xaaaaaaaa, -2) == 0xaaaaaaaa) +for i = -50, 50 do + assert(bit32.lrotate(0x89abcdef, i) == bit32.lrotate(0x89abcdef, i%32)) +end + +assert(bit32.lshift(0x12345678, 4) == 0x23456780) +assert(bit32.lshift(0x12345678, 8) == 0x34567800) +assert(bit32.lshift(0x12345678, -4) == 0x01234567) +assert(bit32.lshift(0x12345678, -8) == 0x00123456) +assert(bit32.lshift(0x12345678, 32) == 0) +assert(bit32.lshift(0x12345678, -32) == 0) +assert(bit32.rshift(0x12345678, 4) == 0x01234567) +assert(bit32.rshift(0x12345678, 8) == 0x00123456) +assert(bit32.rshift(0x12345678, 32) == 0) +assert(bit32.rshift(0x12345678, -32) == 0) +assert(bit32.arshift(0x12345678, 0) == 0x12345678) +assert(bit32.arshift(0x12345678, 1) == 0x12345678 / 2) +assert(bit32.arshift(0x12345678, -1) == 0x12345678 * 2) +assert(bit32.arshift(-1, 1) == 0xffffffff) +assert(bit32.arshift(-1, 24) == 0xffffffff) +assert(bit32.arshift(-1, 32) == 0xffffffff) +assert(bit32.arshift(-1, -1) == (-1 * 2) % 2^32) + +print("+") +-- some special cases +local c = {0, 1, 2, 3, 10, 0x80000000, 0xaaaaaaaa, 0x55555555, + 0xffffffff, 0x7fffffff} + +for _, b in pairs(c) do + assert(bit32.band(b) == b) + assert(bit32.band(b, b) == b) + assert(bit32.btest(b, b) == (b ~= 0)) + assert(bit32.band(b, b, b) == b) + assert(bit32.btest(b, b, b) == (b ~= 0)) + assert(bit32.band(b, bit32.bnot(b)) == 0) + assert(bit32.bor(b, bit32.bnot(b)) == bit32.bnot(0)) + assert(bit32.bor(b) == b) + assert(bit32.bor(b, b) == b) + assert(bit32.bor(b, b, b) == b) + assert(bit32.bxor(b) == b) + assert(bit32.bxor(b, b) == 0) + assert(bit32.bxor(b, 0) == b) + assert(bit32.bnot(b) ~= b) + assert(bit32.bnot(bit32.bnot(b)) == b) + assert(bit32.bnot(b) == 2^32 - 1 - b) + assert(bit32.lrotate(b, 32) == b) + assert(bit32.rrotate(b, 32) == b) + assert(bit32.lshift(bit32.lshift(b, -4), 4) == bit32.band(b, bit32.bnot(0xf))) + assert(bit32.rshift(bit32.rshift(b, 4), -4) == bit32.band(b, bit32.bnot(0xf))) + for i = -40, 40 do + assert(bit32.lshift(b, i) == math.floor((b * 2^i) % 2^32)) + end +end + +assert(not pcall(bit32.band, {})) +assert(not pcall(bit32.bnot, "a")) +assert(not pcall(bit32.lshift, 45)) +assert(not pcall(bit32.lshift, 45, print)) +assert(not pcall(bit32.rshift, 45, print)) + +print("+") + + +-- testing extract/replace + +assert(bit32.extract(0x12345678, 0, 4) == 8) +assert(bit32.extract(0x12345678, 4, 4) == 7) +assert(bit32.extract(0xa0001111, 28, 4) == 0xa) +assert(bit32.extract(0xa0001111, 31, 1) == 1) +assert(bit32.extract(0x50000111, 31, 1) == 0) +assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679) + +assert(not pcall(bit32.extract, 0, -1)) +assert(not pcall(bit32.extract, 0, 32)) +assert(not pcall(bit32.extract, 0, 0, 33)) +assert(not pcall(bit32.extract, 0, 31, 2)) + +assert(bit32.replace(0x12345678, 5, 28, 4) == 0x52345678) +assert(bit32.replace(0x12345678, 0x87654321, 0, 32) == 0x87654321) +assert(bit32.replace(0, 1, 2) == 2^2) +assert(bit32.replace(0, -1, 4) == 2^4) +assert(bit32.replace(-1, 0, 31) == 2^31 - 1) +assert(bit32.replace(-1, 0, 1, 2) == 2^32 - 7) + +--[[ +This test verifies a fix in luauF_replace() where if the 4th +parameter was not a number, but the first three are numbers, it will +cause the Luau math library to crash. +]]-- + +assert(bit32.replace(-1, 0, 1, "2") == 2^32 - 7) + +-- many of the tests above go through fastcall path +-- to make sure the basic implementations are also correct we test some functions with string->number coercions +assert(bit32.lrotate("0x12345678", 4) == 0x23456781) +assert(bit32.rrotate("0x12345678", -4) == 0x23456781) +assert(bit32.arshift("0x12345678", 1) == 0x12345678 / 2) +assert(bit32.arshift("-1", 32) == 0xffffffff) +assert(bit32.bnot("1") == 0xfffffffe) +assert(bit32.band("1", 3) == 1) +assert(bit32.band(1, "3") == 1) +assert(bit32.bor("1", 2) == 3) +assert(bit32.bor(1, "2") == 3) +assert(bit32.bxor("1", 3) == 2) +assert(bit32.bxor(1, "3") == 2) +assert(bit32.btest(1, "3") == true) +assert(bit32.btest("1", 3) == true) + +return('OK') diff --git a/tests/conformance/calls.lua b/tests/conformance/calls.lua new file mode 100644 index 0000000..7f9610a --- /dev/null +++ b/tests/conformance/calls.lua @@ -0,0 +1,229 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing functions and calls") + +-- get the opportunity to test 'type' too ;) + +assert(type(1<2) == 'boolean') +assert(type(true) == 'boolean' and type(false) == 'boolean') +assert(type(nil) == 'nil' and type(-3) == 'number' and type'x' == 'string' and + type{} == 'table' and type(type) == 'function') + +assert(type(assert) == type(print)) +f = nil +function f (x) return a:x (x) end +assert(type(f) == 'function') + +local pack = table.pack +local unpack = table.unpack + +-- testing local-function recursion +fact = false +do + local res = 1 + local function fact (n) + if n==0 then return res + else return n*fact(n-1) + end + end + assert(fact(5) == 120) +end +assert(fact == false) + +-- testing declarations +a = {i = 10} +self = 20 +function a:x (x) return x+self.i end +function a.y (x) return x+self end + +assert(a:x(1)+10 == a.y(1)) + +a.t = {i=-100} +a["t"].x = function (self, a,b) return self.i+a+b end + +assert(a.t:x(2,3) == -95) + +do + local a = {x=0} + function a:add (x) self.x, a.y = self.x+x, 20; return self end + assert(a:add(10):add(20):add(30).x == 60 and a.y == 20) +end + +local a = {b={c={}}} + +function a.b.c.f1 (x) return x+1 end +function a.b.c:f2 (x,y) self[x] = y end +assert(a.b.c.f1(4) == 5) +a.b.c:f2('k', 12); assert(a.b.c.k == 12) + +print('+') + +t = nil -- 'declare' t +function f(a,b,c) local d = 'a'; t={a,b,c,d} end + +f( -- this line change must be valid + 1,2) +assert(t[1] == 1 and t[2] == 2 and t[3] == nil and t[4] == 'a') +f(1,2, -- this one too + 3,4) +assert(t[1] == 1 and t[2] == 2 and t[3] == 3 and t[4] == 'a') + +function fat(x) + if x <= 1 then return 1 + else return x*loadstring("return fat(" .. x-1 .. ")")() + end +end + +assert(loadstring "loadstring 'assert(fat(6)==720)' () ")() +a = loadstring('return fat(5), 3') +a,b = a() +assert(a == 120 and b == 3) +print('+') + +function err_on_n (n) + if n==0 then error(); exit(1); + else err_on_n (n-1); exit(1); + end +end + +do + function dummy (n) + if n > 0 then + assert(not pcall(err_on_n, n)) + dummy(n-1) + end + end +end + +dummy(10) + +function deep (n) + if n>0 then deep(n-1) end +end +deep(10) +deep(200) + +-- testing tail call +function deep (n) if n>0 then return deep(n-1) else return 101 end end +assert(deep(10000) == 101) +a = {} +function a:deep (n) if n>0 then return self:deep(n-1) else return 101 end end +assert(a:deep(10000) == 101) + +print('+') + + +a = nil +(function (x) a=x end)(23) +assert(a == 23 and (function (x) return x*2 end)(20) == 40) + + +local x,y,z,a +a = {}; lim = 2000 +for i=1, lim do a[i]=i end +assert(select(lim, unpack(a)) == lim and select('#', unpack(a)) == lim) +x = unpack(a) +assert(x == 1) +x = {unpack(a)} +assert(table.getn(x) == lim and x[1] == 1 and x[lim] == lim) +x = {unpack(a, lim-2)} +assert(table.getn(x) == 3 and x[1] == lim-2 and x[3] == lim) +x = {unpack(a, 10, 6)} +assert(next(x) == nil) -- no elements +x = {unpack(a, 11, 10)} +assert(next(x) == nil) -- no elements +x,y = unpack(a, 10, 10) +assert(x == 10 and y == nil) +x,y,z = unpack(a, 10, 11) +assert(x == 10 and y == 11 and z == nil) +a,x = unpack{1} +assert(a==1 and x==nil) +a,x = unpack({1,2}, 1, 1) +assert(a==1 and x==nil) + + +-- testing closures + +-- fixed-point operator +Y = function (le) + local function a (f) + return le(function (x) return f(f)(x) end) + end + return a(a) + end + + +-- non-recursive factorial + +F = function (f) + return function (n) + if n == 0 then return 1 + else return n*f(n-1) end + end + end + +fat = Y(F) + +assert(fat(0) == 1 and fat(4) == 24 and Y(F)(5)==5*Y(F)(4)) + +local function g (z) + local function f (a,b,c,d) + return function (x,y) return a+b+c+d+a+x+y+z end + end + return f(z,z+1,z+2,z+3) +end + +f = g(10) +assert(f(9, 16) == 10+11+12+13+10+9+16+10) + +Y, F, f = nil +print('+') + +-- testing multiple returns + +function unlpack (t, i) + i = i or 1 + if (i <= table.getn(t)) then + return t[i], unlpack(t, i+1) + end +end + +function equaltab (t1, t2) + assert(table.getn(t1) == table.getn(t2)) + for i,v1 in ipairs(t1) do + assert(v1 == t2[i]) + end +end + +function f() return 1,2,30,4 end +function ret2 (a,b) return a,b end + +local a,b,c,d = unlpack{1,2,3} +assert(a==1 and b==2 and c==3 and d==nil) +a = {1,2,3,4,false,10,'alo',false,assert} +equaltab(pack(unlpack(a)), a) +equaltab(pack(unlpack(a), -1), {1,-1}) +a,b,c,d = ret2(f()), ret2(f()) +assert(a==1 and b==1 and c==2 and d==nil) +a,b,c,d = unlpack(pack(ret2(f()), ret2(f()))) +assert(a==1 and b==1 and c==2 and d==nil) +a,b,c,d = unlpack(pack(ret2(f()), (ret2(f())))) +assert(a==1 and b==1 and c==nil and d==nil) + +a = ret2{ unlpack{1,2,3}, unlpack{3,2,1}, unlpack{"a", "b"}} +assert(a[1] == 1 and a[2] == 3 and a[3] == "a" and a[4] == "b") + + +-- testing calls with 'incorrect' arguments +rawget({}, "x", 1) +rawset({}, "x", 1, 2) +assert(math.sin(1,2) == math.sin(1)) +table.sort({10,9,8,4,19,23,0,0}, function (a,b) return a 10 or a[i]() ~= x +assert(i == 11 and a[1]() == 1 and a[3]() == 3 and i == 4) + +print'+' + + +-- test for correctly closing upvalues in tail calls of vararg functions +local function t () + local function c(a,b) assert(a=="test" and b=="OK") end + local function v(f, ...) c("test", f() ~= 1 and "FAILED" or "OK") end + local x = 1 + return v(function() return x end) +end +t() + + +-- coroutine tests + +local f + +-- assert(coroutine.running() == nil) + + +-- tests for global environment +local _G = getfenv() + +local function foo (a) + setfenv(0, a) + coroutine.yield(getfenv()) + assert(getfenv(0) == a) + assert(getfenv(1) == _G) + assert(getfenv(loadstring"") == a) + return getfenv() +end + +f = coroutine.wrap(foo) +local a = {} +assert(f(a) == _G) +local a,b = pcall(f) +assert(a and b == _G) + + +-- tests for multiple yield/resume arguments + +local function eqtab (t1, t2) + assert(table.getn(t1) == table.getn(t2)) + for i,v in ipairs(t1) do + assert(t2[i] == v) + end +end + +_G.x = nil -- declare x +function foo (a, ...) + assert(coroutine.running() == f) + assert(coroutine.status(f) == "running") + local arg = {...} + for i=1,table.getn(arg) do + _G.x = {coroutine.yield(unpack(arg[i]))} + end + return unpack(a) +end + +f = coroutine.create(foo) +assert(type(f) == "thread" and coroutine.status(f) == "suspended") +assert(string.find(tostring(f), "thread")) +local s,a,b,c,d +s,a,b,c,d = coroutine.resume(f, {1,2,3}, {}, {1}, {'a', 'b', 'c'}) +assert(s and a == nil and coroutine.status(f) == "suspended") +s,a,b,c,d = coroutine.resume(f) +eqtab(_G.x, {}) +assert(s and a == 1 and b == nil) +s,a,b,c,d = coroutine.resume(f, 1, 2, 3) +eqtab(_G.x, {1, 2, 3}) +assert(s and a == 'a' and b == 'b' and c == 'c' and d == nil) +s,a,b,c,d = coroutine.resume(f, "xuxu") +eqtab(_G.x, {"xuxu"}) +assert(s and a == 1 and b == 2 and c == 3 and d == nil) +assert(coroutine.status(f) == "dead") +s, a = coroutine.resume(f, "xuxu") +assert(not s and string.find(a, "dead") and coroutine.status(f) == "dead") + + +-- yields in tail calls +local function foo (i) return coroutine.yield(i) end +f = coroutine.wrap(function () + for i=1,10 do + assert(foo(i) == _G.x) + end + return 'a' +end) +for i=1,10 do _G.x = i; assert(f(i) == i) end +_G.x = 'xuxu'; assert(f('xuxu') == 'a') + +-- recursive +function pf (n, i) + coroutine.yield(n) + pf(n*i, i+1) +end + +f = coroutine.wrap(pf) +local s=1 +for i=1,10 do + assert(f(1, 1) == s) + s = s*i +end + +-- sieve +function gen (n) + return coroutine.wrap(function () + for i=2,n do coroutine.yield(i) end + end) +end + + +function filter (p, g) + return coroutine.wrap(function () + while 1 do + local n = g() + if n == nil then return end + if n%p ~= 0 then coroutine.yield(n) end + end + end) +end + +local x = gen(100) +local a = {} +while 1 do + local n = x() + if n == nil then break end + table.insert(a, n) + x = filter(n, x) +end + +assert(table.getn(a) == 25 and a[table.getn(a)] == 97) + + +-- errors in coroutines +function foo () + -- assert(debug.getinfo(1).currentline == debug.getinfo(foo).linedefined + 1) + -- assert(debug.getinfo(2).currentline == debug.getinfo(goo).linedefined) + coroutine.yield(3) + error("foo") +end + +local fooerr = "closure.lua:284: foo" + +function goo() foo() end +x = coroutine.wrap(goo) +assert(x() == 3) +local a,b = pcall(x) +assert(not a and b == fooerr) + +x = coroutine.create(goo) +a,b = coroutine.resume(x) +assert(a and b == 3) +a,b = coroutine.resume(x) +assert(not a and b == fooerr and coroutine.status(x) == "dead") +a,b = coroutine.resume(x) +assert(not a and string.find(b, "dead") and coroutine.status(x) == "dead") + + +-- co-routines x for loop +function all (a, n, k) + if k == 0 then coroutine.yield(a) + else + for i=1,n do + a[k] = i + all(a, n, k-1) + end + end +end + +local a = 0 +for t in coroutine.wrap(function () all({}, 5, 4) end) do + a = a+1 +end +assert(a == 5^4) + + +-- access to locals of collected corroutines +local C = {}; setmetatable(C, {__mode = "kv"}) +local x = coroutine.wrap (function () + local a = 10 + local function f () a = a+10; return a end + while true do + a = a+1 + coroutine.yield(f) + end + end) + +C[1] = x; + +local f = x() +assert(f() == 21 and x()() == 32 and x() == f) +x = nil +collectgarbage() +-- assert(C[1] == nil) +assert(f() == 43 and f() == 53) + + +-- old bug: attempt to resume itself + +function co_func (current_co) + assert(coroutine.running() == current_co) + assert(coroutine.resume(current_co) == false) + assert(coroutine.resume(current_co) == false) + return 10 +end + +local co = coroutine.create(co_func) +local a,b = coroutine.resume(co, co) +assert(a == true and b == 10) +assert(coroutine.resume(co, co) == false) +assert(coroutine.resume(co, co) == false) + +-- access to locals of erroneous coroutines +local x = coroutine.create (function () + local a = 10 + _G.f = function () a=a+1; return a end + error('x') + end) + +assert(not coroutine.resume(x)) +-- overwrite previous position of local `a' +assert(not coroutine.resume(x, 1, 1, 1, 1, 1, 1, 1)) +assert(_G.f() == 11) +assert(_G.f() == 12) + + +if not T then + (Message or print)('\a\n >>> testC not active: skipping yield/hook tests <<<\n\a') +else + + local turn + + function fact (t, x) + assert(turn == t) + if x == 0 then return 1 + else return x*fact(t, x-1) + end + end + + local A,B,a,b = 0,0,0,0 + + local x = coroutine.create(function () + T.setyhook("", 2) + A = fact("A", 10) + end) + + local y = coroutine.create(function () + T.setyhook("", 3) + B = fact("B", 11) + end) + + while A==0 or B==0 do + if A==0 then turn = "A"; T.resume(x) end + if B==0 then turn = "B"; T.resume(y) end + end + + assert(B/A == 11) +end + + +-- leaving a pending coroutine open +_X = coroutine.wrap(function () + local a = 10 + local x = function () a = a+1 end + coroutine.yield() + end) + +_X() + + +-- coroutine environments +co = coroutine.create(function () + coroutine.yield(getfenv(0)) + return loadstring("return a")() + end) + +a = {a = 15} +-- debug.setfenv(co, a) +-- assert(debug.getfenv(co) == a) +-- assert(select(2, coroutine.resume(co)) == a) +-- assert(select(2, coroutine.resume(co)) == a.a) + + +return'OK' diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.lua new file mode 100644 index 0000000..16c63b0 --- /dev/null +++ b/tests/conformance/constructs.lua @@ -0,0 +1,240 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print "testing syntax" + +-- testing priorities + +assert(2^3^2 == 2^(3^2)); +assert(2^3*4 == (2^3)*4); +assert(2^-2 == 1/4 and -2^- -2 == - - -4); +assert(not nil and 2 and not(2>3 or 3<2)); +assert(-3-1-5 == 0+0-9); +assert(-2^2 == -4 and (-2)^2 == 4 and 2*2-3-1 == 0); +assert(2*1+3/3 == 3 and 1+2 .. 3*1 == "33"); +assert(not(2+1 > 3*1) and "a".."b" > "a"); + +assert(not ((true or false) and nil)) +assert( true or false and nil) + +local a,b = 1,nil; +assert(-(1 or 2) == -1 and (1 and 2)+(-1.25 or -4) == 0.75); +x = ((b or a)+1 == 2 and (10 or a)+1 == 11); assert(x); +x = (((2<3) or 1) == true and (2<3 and 4) == 4); assert(x); + +x,y=1,2; +assert((x>y) and x or y == 2); +x,y=2,1; +assert((x>y) and x or y == 2); + +assert(1234567890 == tonumber('1234567890') and 1234567890+1 == 1234567891) + + +-- silly loops +repeat until 1; repeat until true; +while false do end; while nil do end; + +do -- test old bug (first name could not be an `upvalue') + local a; function f(x) x={a=1}; x={x=1}; x={G=1} end +end + +function f (i) + if type(i) ~= 'number' then return i,'jojo'; end; + if i > 0 then return i, f(i-1); end; +end + +x = {f(3), f(5), f(10);}; +assert(x[1] == 3 and x[2] == 5 and x[3] == 10 and x[4] == 9 and x[12] == 1); +assert(x[nil] == nil) +x = {f'alo', f'xixi', nil}; +assert(x[1] == 'alo' and x[2] == 'xixi' and x[3] == nil); +x = {f'alo'..'xixi'}; +assert(x[1] == 'aloxixi') +x = {f{}} +assert(x[2] == 'jojo' and type(x[1]) == 'table') + + +local f = function (i) + if i < 10 then return 'a'; + elseif i < 20 then return 'b'; + elseif i < 30 then return 'c'; + end; +end + +assert(f(3) == 'a' and f(12) == 'b' and f(26) == 'c' and f(100) == nil) + +for i=1,1000 do break; end; +n=100; +i=3; +t = {}; +a=nil +while not a do + a=0; for i=1,n do for i=i,1,-1 do a=a+1; t[i]=1; end; end; +end +assert(a == n*(n+1)/2 and i==3); +assert(t[1] and t[n] and not t[0] and not t[n+1]) + +function f(b) + local x = 1; + repeat + local a; + if b==1 then local b=1; x=10; break + elseif b==2 then x=20; break; + elseif b==3 then x=30; + else local a,b,c,d=math.sin(1); x=x+1; + end + until x>=12; + return x; +end; + +assert(f(1) == 10 and f(2) == 20 and f(3) == 30 and f(4)==12) + + +local f = function (i) + if i < 10 then return 'a' + elseif i < 20 then return 'b' + elseif i < 30 then return 'c' + else return 8 + end +end + +assert(f(3) == 'a' and f(12) == 'b' and f(26) == 'c' and f(100) == 8) + +local a, b = nil, 23 +x = {f(100)*2+3 or a, a or b+2} +assert(x[1] == 19 and x[2] == 25) +x = {f=2+3 or a, a = b+2} +assert(x.f == 5 and x.a == 25) + +a={y=1} +x = {a.y} +assert(x[1] == 1) + +function f(i) + while 1 do + if i>0 then i=i-1; + else return; end; + end; +end; + +function g(i) + while 1 do + if i>0 then i=i-1 + else return end + end +end + +f(10); g(10); + +do + function f () return 1,2,3; end + local a, b, c = f(); + assert(a==1 and b==2 and c==3) + a, b, c = (f()); + assert(a==1 and b==nil and c==nil) +end + +local a,b = 3 and f(); +assert(a==1 and b==nil) + +function g() f(); return; end; +assert(g() == nil) +function g() return nil or f() end +a,b = g() +assert(a==1 and b==nil) + +print'+'; + + +f = [[ +return function ( a , b , c , d , e ) + local x = a >= b or c or ( d and e ) or nil + return x +end , { a = 1 , b = 2 >= 1 , } or { 1 }; +]] +f = string.gsub(f, "%s+", "\n"); -- force a SETLINE between opcodes +f,a = loadstring(f)(); +assert(a.a == 1 and a.b) + +function g (a,b,c,d,e) + if not (a>=b or c or d and e or nil) then return 0; else return 1; end; +end + +function h (a,b,c,d,e) + while (a>=b or c or (d and e) or nil) do return 1; end; + return 0; +end; + +assert(f(2,1) == true and g(2,1) == 1 and h(2,1) == 1) +assert(f(1,2,'a') == 'a' and g(1,2,'a') == 1 and h(1,2,'a') == 1) +assert(f(1,2,'a') +~= -- force SETLINE before nil +nil, "") +assert(f(1,2,'a') == 'a' and g(1,2,'a') == 1 and h(1,2,'a') == 1) +assert(f(1,2,nil,1,'x') == 'x' and g(1,2,nil,1,'x') == 1 and + h(1,2,nil,1,'x') == 1) +assert(f(1,2,nil,nil,'x') == nil and g(1,2,nil,nil,'x') == 0 and + h(1,2,nil,nil,'x') == 0) +assert(f(1,2,nil,1,nil) == nil and g(1,2,nil,1,nil) == 0 and + h(1,2,nil,1,nil) == 0) + +assert(1 and 2<3 == true and 2<3 and 'a'<'b' == true) +x = 2<3 and not 3; assert(x==false) +x = 2<1 or (2>1 and 'a'); assert(x=='a') + + +do + local a; if nil then a=1; else a=2; end; -- this nil comes as PUSHNIL 2 + assert(a==2) +end + +function F(a) + -- assert(debug.getinfo(1, "n").name == 'F') + return a,2,3 +end + +a,b = F(1)~=nil; assert(a == true and b == nil); +a,b = F(nil)==nil; assert(a == true and b == nil) + +---------------------------------------------------------------- +-- creates all combinations of +-- [not] ([not] arg op [not] arg) +-- and tests each one +function ID(x) return x end + +function f(t, i) + local b = t.n + local res = (math.floor(i/c)%b)+1 + c = c*b + return t[res] +end + +local arg = {" ( 1 < 2 ) ", " ( 1 >= 2 ) ", " F ( ) ", " nil "; n=4} + +local op = {" and ", " or ", " == ", " ~= "; n=4} + +local neg = {" ", " not "; n=2} + +local i = 0 +repeat + c = 1 + local s = f(neg, i)..'ID('..f(neg, i)..f(arg, i)..f(op, i)..f(neg, i)..'ID('..f(arg, i)..'))' + local s1 = string.gsub(s, 'ID', '') + K,X,NX,WX1,WX2 = nil + s = string.format([[ + local a = %s + local b = not %s + K = b + local xxx; + if %s then X = a else X = b end + if %s then NX = b else NX = a end + while %s do WX1 = a; break end + while %s do WX2 = a; break end + repeat if (%s) then break end; assert(b) until not(%s) + ]], s1, s, s1, s, s1, s, s1, s, s) + assert(loadstring(s))() + assert(X and not NX and not WX1 == K and not WX2 == K) + if i%4000 == 0 then print('+') end + i = i+1 +until i==c + +return'OK' diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua new file mode 100644 index 0000000..73c3833 --- /dev/null +++ b/tests/conformance/coroutine.lua @@ -0,0 +1,322 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print "testing coroutines" + +local f + +local main = coroutine.running() +assert(main == nil) +-- assert(not coroutine.resume(main)) +assert(coroutine.isyieldable()) -- note: we run this in context of a yieldable thread like all other Lua code +--assert(not pcall(coroutine.yield)) + + +-- trivial errors +assert(not pcall(coroutine.resume, 0)) +assert(not pcall(coroutine.status, 0)) + + +-- tests for multiple yield/resume arguments + +local function eqtab (t1, t2) + assert(#t1 == #t2) + for i = 1, #t1 do + local v = t1[i] + assert(t2[i] == v) + end +end + +_G.x = nil -- declare x +function foo (a, ...) + local x = coroutine.running() + assert(x == f) + -- next call should not corrupt coroutine (but must fail, + -- as it attempts to resume the running coroutine) + assert(coroutine.resume(f) == false) + assert(coroutine.status(f) == "running") + local arg = {...} + assert(coroutine.isyieldable()) + for i=1,#arg do + _G.x = {coroutine.yield(table.unpack(arg[i]))} + end + return table.unpack(a) +end + +f = coroutine.create(foo) +assert(type(f) == "thread" and coroutine.status(f) == "suspended") +assert(string.find(tostring(f), "thread")) +local s,a,b,c,d +s,a,b,c,d = coroutine.resume(f, {1,2,3}, {}, {1}, {'a', 'b', 'c'}) +assert(s and a == nil and coroutine.status(f) == "suspended") +s,a,b,c,d = coroutine.resume(f) +eqtab(_G.x, {}) +assert(s and a == 1 and b == nil) +s,a,b,c,d = coroutine.resume(f, 1, 2, 3) +eqtab(_G.x, {1, 2, 3}) +assert(s and a == 'a' and b == 'b' and c == 'c' and d == nil) +s,a,b,c,d = coroutine.resume(f, "xuxu") +eqtab(_G.x, {"xuxu"}) +assert(s and a == 1 and b == 2 and c == 3 and d == nil) +assert(coroutine.status(f) == "dead") +s, a = coroutine.resume(f, "xuxu") +assert(not s and string.find(a, "dead") and coroutine.status(f) == "dead") + + +-- yields in tail calls +local function foo (i) return coroutine.yield(i) end +f = coroutine.wrap(function () + for i=1,10 do + assert(foo(i) == _G.x) + end + return 'a' +end) +for i=1,10 do _G.x = i; assert(f(i) == i) end +_G.x = 'xuxu'; assert(f('xuxu') == 'a') + +-- recursive +function pf (n, i) + coroutine.yield(n) + pf(n*i, i+1) +end + +f = coroutine.wrap(pf) +local s=1 +for i=1,10 do + assert(f(1, 1) == s) + s = s*i +end + +-- sieve +function gen (n) + return coroutine.wrap(function () + for i=2,n do coroutine.yield(i) end + end) +end + + +function filter (p, g) + return coroutine.wrap(function () + while 1 do + local n = g() + if n == nil then return end + if math.fmod(n, p) ~= 0 then coroutine.yield(n) end + end + end) +end + +local x = gen(80) +local a = {} +while 1 do + local n = x() + if n == nil then break end + table.insert(a, n) + x = filter(n, x) +end + +assert(#a == 22 and a[#a] == 79) +x, a = nil + + +-- yielding across C boundaries + +co = coroutine.wrap(function() + assert(not pcall(table.sort,{1,2,3}, coroutine.yield)) + assert(coroutine.isyieldable()) + coroutine.yield(20) + return 30 + end) + +assert(co() == 20) +assert(co() == 30) + +-- unyieldable C call +do + local function f (c) + assert(not coroutine.isyieldable()) + return c .. c + end + + local co = coroutine.wrap(function (c) + assert(coroutine.isyieldable()) + local s = string.gsub("a", ".", f) + return s + end) + assert(co() == "aa") +end + + + +-- errors in coroutines +function foo () + coroutine.yield(3) + error(foo) +end + +function goo() foo() end +x = coroutine.wrap(goo) +assert(x() == 3) +local a,b = pcall(x) +assert(not a and b == foo) + +x = coroutine.create(goo) +a,b = coroutine.resume(x) +assert(a and b == 3) +a,b = coroutine.resume(x) +assert(not a and b == foo and coroutine.status(x) == "dead") +a,b = coroutine.resume(x) +assert(not a and string.find(b, "dead") and coroutine.status(x) == "dead") + + +-- co-routines x for loop +function all (a, n, k) + if k == 0 then coroutine.yield(a) + else + for i=1,n do + a[k] = i + all(a, n, k-1) + end + end +end + +local a = 0 +for t in coroutine.wrap(function () all({}, 5, 4) end) do + a = a+1 +end +assert(a == 5^4) + + +-- access to locals of collected corroutines +local C = {}; setmetatable(C, {__mode = "kv"}) +local x = coroutine.wrap (function () + local a = 10 + local function f () a = a+10; return a end + while true do + a = a+1 + coroutine.yield(f) + end + end) + +C[1] = x; + +local f = x() +assert(f() == 21 and x()() == 32 and x() == f) +x = nil +collectgarbage() +assert(C[1] == undef) +assert(f() == 43 and f() == 53) + + +-- old bug: attempt to resume itself + +function co_func (current_co) + assert(coroutine.running() == current_co) + assert(coroutine.resume(current_co) == false) + coroutine.yield(10, 20) + assert(coroutine.resume(current_co) == false) + coroutine.yield(23) + return 10 +end + +local co = coroutine.create(co_func) +local a,b,c = coroutine.resume(co, co) +assert(a == true and b == 10 and c == 20) +a,b = coroutine.resume(co, co) +assert(a == true and b == 23) +a,b = coroutine.resume(co, co) +assert(a == true and b == 10) +assert(coroutine.resume(co, co) == false) +assert(coroutine.resume(co, co) == false) + + +-- attempt to resume 'normal' coroutine +local co1, co2 +co1 = coroutine.create(function () return co2() end) +co2 = coroutine.wrap(function () + assert(coroutine.status(co1) == 'normal') + assert(not coroutine.resume(co1)) + coroutine.yield(3) + end) + +a,b = coroutine.resume(co1) +assert(a and b == 3) +assert(coroutine.status(co1) == 'dead') + +if not limitedstack then + -- infinite recursion of coroutines + a = function(a) coroutine.wrap(a)(a) end + assert(not pcall(a, a)) + a = nil +end + +-- access to locals of erroneous coroutines +local x = coroutine.create (function () + local a = 10 + _G.f = function () a=a+1; return a end + error('x') + end) + +assert(not coroutine.resume(x)) +-- overwrite previous position of local `a' +assert(not coroutine.resume(x, 1, 1, 1, 1, 1, 1, 1)) +assert(_G.f() == 11) +assert(_G.f() == 12) + +-- leaving a pending coroutine open +_X = coroutine.wrap(function () + local a = 10 + local x = function () a = a+1 end + coroutine.yield() + end) + +_X() + + +if not limitedstack then + -- bug (stack overflow) + local j = 2^9 + local lim = 1000000 -- (C stack limit; assume 32-bit machine) + local t = {lim - 10, lim - 5, lim - 1, lim, lim + 1} + for i = 1, #t do + local j = t[i] + co = coroutine.create(function() + local t = {} + for i = 1, j do t[i] = i end + return table.unpack(t) + end) + local r, msg = coroutine.resume(co) + assert(not r) + end + co = nil +end + + +assert(coroutine.running() == main) + +-- bug in nCcalls +local co = coroutine.wrap(function () + local a = {pcall(pcall,pcall,pcall,pcall,pcall,pcall,pcall,error,"hi")} + return pcall(assert, table.unpack(a)) +end) + +local a = {co()} +assert(a[10] == "hi") + +-- test coroutine with C functions +local co = coroutine.create(coroutine.yield) +assert(coroutine.status(co) == "suspended") +coroutine.resume(co) +assert(coroutine.status(co) == "suspended") +coroutine.resume(co) +assert(coroutine.status(co) == "dead") + +-- test correct handling of coroutine.yield returns for 0-30 values +for i=0,30 do + local T = table.create(i, 42) + local co = coroutine.create(function() coroutine.yield(table.unpack(T)) end) + local T2 = table.pack(coroutine.resume(co)) + assert(T2[1] == true) + assert(1 + #T == #T2) + assert(#T2 == 1 or T2[#T2] == 42) +end + +return'OK' diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua new file mode 100644 index 0000000..21ef60d --- /dev/null +++ b/tests/conformance/datetime.lua @@ -0,0 +1,77 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print "testing datetime library" + +local function checkerr (msg, f, ...) + local stat, err = pcall(f, ...) + assert(not stat and string.find(err, msg, 1, true)) +end + +assert(os.date("") == "") +assert(os.date("!") == "") +local x = string.rep("a", 10000) +assert(os.date(x) == x) +local t = os.time() +D = os.date("*t", t) +assert(os.date(string.rep("%d", 1000), t) == + string.rep(os.date("%d", t), 1000)) +assert(os.date(string.rep("%", 200)) == string.rep("%", 100)) + +local function checkDateTable (t) + local D = os.date("!*t", t) + assert(os.time(D) == t) + local DC = os.date("!%Y %m %d %H %M %S %w %j"):split(" ") + assert(D.year == tonumber(DC[1])) + assert(D.month == tonumber(DC[2])) + assert(D.day == tonumber(DC[3])) + assert(D.hour == tonumber(DC[4])) + assert(D.min == tonumber(DC[5])) + assert(D.sec == tonumber(DC[6])) + assert(D.wday == tonumber(DC[7]) + 1) + assert(D.yday == tonumber(DC[8])) +end + +checkDateTable(os.time()) + +checkerr("invalid conversion specifier", os.date, "%9") +checkerr("invalid conversion specifier", os.date, "%O") +checkerr("invalid conversion specifier", os.date, "%E") +checkerr("invalid conversion specifier", os.date, "%Ea") + +checkerr("missing", os.time, {hour = 12}) -- missing date + +do + local D = os.date("*t") + local t = os.time(D) + if D.isdst == nil then + print("no daylight saving information") + else + assert(type(D.isdst) == 'boolean') + end + D.isdst = nil + local t1 = os.time(D) + assert(t == t1) -- if isdst is absent uses correct default +end + +local D = os.date("*t") +t = os.time(D) +D.year = D.year-1; +local t1 = os.time(D) +-- allow for leap years +assert(math.abs(os.difftime(t,t1)/(24*3600) - 365) < 2) + +-- should not take more than 1 second to execute these two lines +t = os.time() +t1 = os.time(os.date("!*t")) +local diff = os.difftime(t1,t) +assert(0 <= diff and diff <= 1) +diff = os.difftime(t,t1) +assert(-1 <= diff and diff <= 0) + +local t1 = os.time{year=2000, month=10, day=1, hour=23, min=12} +local t2 = os.time{year=2000, month=10, day=1, hour=23, min=10, sec=19} +assert(os.difftime(t1,t2) == 60*2-19) + +assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0) + +return'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua new file mode 100644 index 0000000..ee79a14 --- /dev/null +++ b/tests/conformance/debug.lua @@ -0,0 +1,101 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print "testing debug library" + +-- traceback +function foo(...) + return debug.traceback(...) +end + +function bar() + coroutine.yield() +end + +assert(foo():find("foo") > 0) +assert(foo("hello"):find("hello") > 0) +assert(foo("hello"):find("foo") > 0) +assert(foo("hello", 2):find("hello") > 0) +assert(foo("hello", 2):find("foo") == nil) + +local co = coroutine.create(bar) +coroutine.resume(co) + +assert(debug.traceback(co):find("bar") > 0) +assert(debug.traceback(co, "hello"):find("hello") > 0) +assert(debug.traceback(co, "hello"):find("bar") > 0) +assert(debug.traceback(co, "hello", 2):find("hello") > 0) +assert(debug.traceback(co, "hello", 2):find("bar") == nil) + +-- traceback for the top frame +function halp(key, value) + local t = {} + t[key] = value -- line 30 + return t +end + +local co2 = coroutine.create(halp) +coroutine.resume(co2, 0 / 0, 42) + +assert(debug.traceback(co2) == "debug.lua:31 function halp\n") +assert(debug.info(co2, 0, "l") == 31) + +-- info errors +function qux(...) + local ok, err = pcall(debug.info, ...) + assert(not ok) + return err +end + +assert(qux():find("function or level expected")) +assert(qux(1):find("string expected")) +assert(qux(-1):find("level can't be negative")) +assert(qux(1, "?"):find("invalid option")) +assert(qux(1, "nn"):find("duplicate option")) +assert(qux(co):find("function or level expected")) +assert(qux(co, 1):find("string expected")) + +-- info single-arg returns +function baz(...) + return debug.info(...) +end + +assert(baz(0, "n") == "info") +assert(baz(1, "n") == "baz") +assert(baz(2, "n") == "") -- main/anonymous +assert(baz(3, "n") == nil) +assert(baz(0, "s") == "[C]") +assert(baz(1, "s") == "debug.lua") +assert(baz(0, "l") == -1) +assert(baz(1, "l") > 42) +assert(baz(0, "f") == debug.info) +assert(baz(1, "f") == baz) +assert(baz(0, "a") == 0) +assert(baz(1, "a") == 0) +assert(baz(co, 1, "n") == "bar") +assert(baz(co, 2, "n") == nil) +assert(baz(math.sqrt, "n") == "sqrt") +assert(baz(math.sqrt, "f") == math.sqrt) -- yes this is pointless + +-- info multi-arg returns +function quux(...) + return {debug.info(...)} +end + +assert(#(quux(1, "nlsf")) == 4) +assert(quux(1, "nlsf")[1] == "quux") +assert(quux(1, "nlsf")[2] > 64) +assert(quux(1, "nlsf")[3] == "debug.lua") +assert(quux(1, "nlsf")[4] == quux) + +-- info arity +function quuz(f) + local a, v = debug.info(f, "a") + return tostring(a) .. " " .. tostring(v) +end + +assert(quuz(math.cos) == "0 true") -- C functions are treated as fully variadic +assert(quuz(function() end) == "0 false") +assert(quuz(function(...) end) == "0 true") +assert(quuz(function(a, b) end) == "2 false") +assert(quuz(function(a, b, ...) end) == "2 true") + +return'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua new file mode 100644 index 0000000..5e69fc6 --- /dev/null +++ b/tests/conformance/debugger.lua @@ -0,0 +1,48 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print "testing debugger" -- note, this file can't run in isolation from C tests + +local a = 5 + +function foo(b) + print("in foo", b) + a = 6 +end + +breakpoint(8) + +foo(50) + +breakpoint(16) -- next line +print("here") + +function coro(arg) + print("before coro break") + a = arg + print("after coro break") + return 42 +end + +breakpoint(20) -- break inside coro() + +a = 7 + +local co = coroutine.create(coro) +local _, res = coroutine.resume(co, 8) -- this breaks and resumes! +assert(res == 42) + +local cof = coroutine.wrap(coro) +assert(cof(9) == 42) -- this breaks and resumes! + +function corobad() + print("before coro break") + error("oops") +end + +assert(a == 9) + +breakpoint(38) -- break inside corobad() + +local co = coroutine.create(corobad) +assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! + +return 'OK' diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua new file mode 100644 index 0000000..eded14e --- /dev/null +++ b/tests/conformance/errors.lua @@ -0,0 +1,296 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing errors") + +function doit (s) + local f, msg = loadstring(s) + if f == nil then return msg end + local cond, msg = pcall(f) + return (not cond) and msg +end + + +function checkmessage (prog, msg) + -- assert(string.find(doit(prog), msg, 1, true)) +end + +function checksyntax (prog, extra, token, line) + local msg = doit(prog) + token = string.gsub(token, "(%p)", "%%%1") + local pt = string.format([[^%%[string ".*"%%]:]]) -- only check that an error happened - ignore token/line since Luau error messages differ from Lua substantially + assert(string.find(msg, pt)) + assert(string.find(msg, msg, 1, true)) +end + + +-- test error message with no extra info +assert(doit("error('hi', 0)") == 'hi') + +-- test error message with no info +-- assert(doit("error()") == nil) + + +-- test common errors/errors that crashed in the past +assert(doit("unpack({}, 1, n=2^30)")) +assert(doit("a=math.sin()")) +assert(not doit("tostring(1)") and doit("tostring()")) +assert(doit"tonumber()") +assert(doit"repeat until 1; a") +checksyntax("break label", "", "label", 1) +assert(doit";") +assert(doit"a=1;;") +assert(doit"return;;") +assert(doit"assert(false)") +assert(doit"assert(nil)") +assert(doit"a=math.sin\n(3)") +assert(doit("function a (... , ...) end")) +assert(doit("function a (, ...) end")) + +checksyntax([[ + local a = {4 + +]], "'}' expected (to close '{' at line 1)", "", 3) + + +-- tests for better error messages + +checkmessage("a=1; bbbb=2; a=math.sin(3)+bbbb(3)", "global 'bbbb'") +checkmessage("a=1; local a,bbbb=2,3; a = math.sin(1) and bbbb(3)", + "local 'bbbb'") +checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'") +checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'") +assert(not string.find(doit"a={13}; local bbbb=1; a[bbbb](3)", "'bbbb'")) +checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number") + +aaa = nil +checkmessage("aaa.bbb:ddd(9)", "global 'aaa'") +checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'") +checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'") +checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'") +assert(not doit"local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)") + +checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'") +checkmessage("aaa={}; x=3/aaa", "global 'aaa'") +checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'") +checkmessage("aaa={}; x=-aaa", "global 'aaa'") +assert(not string.find(doit"aaa={}; x=(aaa or aaa)+(aaa and aaa)", "'aaa'")) +assert(not string.find(doit"aaa={}; (aaa or aaa)()", "'aaa'")) + +checkmessage([[aaa=9 +repeat until 3==3 +local x=math.sin(math.cos(3)) +if math.sin(1) == x then return math.sin(1) end -- tail call +local a,b = 1, { + {x='a'..'b'..'c', y='b', z=x}, + {1,2,3,4,5} or 3+3<=3+3, + 3+1>3+1, + {d = x and aaa[x or y]}} +]], "global 'aaa'") + +checkmessage([[ +local x,y = {},1 +if math.sin(1) == 0 then return 3 end -- return +x.a()]], "field 'a'") + +checkmessage([[ +prefix = nil +insert = nil +while 1 do + local a + if nil then break end + insert(prefix, a) +end]], "global 'insert'") + +checkmessage([[ -- tail call + return math.sin("a") +]], "'sin'") + +checkmessage([[collectgarbage("nooption")]], "invalid option") + +checkmessage([[x = print .. "a"]], "concatenate") + +checkmessage("getmetatable(io.stdin).__gc()", "no value") + +print'+' + + +-- testing line error + +function lineerror (s) + local err,msg = pcall(loadstring(s)) + local line = string.match(msg, ":(%d+):") + return line and line+0 +end + +assert(lineerror"local a\n for i=1,'a' do \n print(i) \n end" == 2) +-- assert(lineerror"\n local a \n for k,v in 3 \n do \n print(k) \n end" == 3) +-- assert(lineerror"\n\n for k,v in \n 3 \n do \n print(k) \n end" == 4) +assert(lineerror"function a.x.y ()\na=a+1\nend" == 1) + +local p = [[ +function g() f() end +function f(x) error('a', X) end +g() +]] +X=3;assert(lineerror(p) == 3) +X=0;assert(lineerror(p) == nil) +X=1;assert(lineerror(p) == 2) +X=2;assert(lineerror(p) == 1) + +lineerror = nil + +if not limitedstack then + C = 0 + -- local l = debug.getinfo(1, "l").currentline + function y () C=C+1; y() end + + local function checkstackmessage (m) + return (string.find(m, "^.-:%d+: stack overflow")) + end + + assert(checkstackmessage(doit('y()'))) + assert(checkstackmessage(doit('y()'))) + assert(checkstackmessage(doit('y()'))) + + -- teste de linhas em erro + C = 0 + local l1 + local function g() + -- l1 = debug.getinfo(1, "l").currentline + y() + end + local _, stackmsg = xpcall(g, debug.traceback) + local stack = {} + for line in string.gmatch(stackmsg, "[^\n]*") do + local curr = string.match(line, ":(%d+):") + if curr then table.insert(stack, tonumber(curr)) end + end +end + +--[[ +local i=1 +while stack[i] ~= l1 do + assert(stack[i] == l) + i = i+1 +end +assert(i > 15) +]]-- + +-- error in error handling +local res, msg = xpcall(error, error) +assert(not res and type(msg) == 'string') + +local function f (x) + if x==0 then error('a\n') + else + local aux = function () return f(x-1) end + local a,b = xpcall(aux, aux) + return a,b + end +end + +if not limitedstack then + f(3) +end + +-- non string messages +function f() error{msg='x'} end +res, msg = xpcall(f, function (r) return {msg=r.msg..'y'} end) +assert(msg.msg == 'xy') + +print('+') +checksyntax("syntax error", "", "error", 1) +checksyntax("1.000", "", "1.000", 1) +checksyntax("[[a]]", "", "[[a]]", 1) +checksyntax("'aa'", "", "'aa'", 1) + +-- test 255 as first char in a chunk +checksyntax("\255a = 1", "", "\255", 1) + +doit('I = loadstring("a=9+"); a=3') +assert(a==3 and I == nil) +print('+') + +lim = 1000 +if rawget(_G, "_soft") then lim = 100 end +for i=1,lim do + doit('a = ') + doit('a = 4+nil') +end + + +-- testing syntax limits +local function testrep (init, rep) + local s = "local a; "..init .. string.rep(rep, 300) + local a,b = loadstring(s) + assert(not a) -- and string.find(b, "syntax levels")) +end +testrep("a=", "{") +testrep("a=", "(") +testrep("", "a(") +testrep("", "do ") +testrep("", "while a do ") +testrep("", "if a then else ") +testrep("", "function foo () ") +testrep("a=", "a..") +testrep("a=", "a^") + + +-- testing other limits +-- upvalues +local s = "function foo ()\n local " +for j = 1,70 do + s = s.."a"..j..", " +end +s = s.."b\n" +for j = 1,70 do + s = s.."function foo"..j.." ()\n a"..j.."=3\n" +end +local a,b = loadstring(s) +assert(not a) +-- assert(string.find(b, "line 3")) + +-- local variables +s = "\nfunction foo ()\n local " +for j = 1,300 do + s = s.."a"..j..", " +end +s = s.."b\n" +local a,b = loadstring(s) +assert(not a) +--assert(string.find(b, "line 2")) + +-- Test for CLI-28786 +-- The xpcall is intentially going to cause an exception +-- followed by a forced exception in the error handler. +-- If the secondary handler isn't trapped, it will cause +-- the unit test to fail. If the xpcall captures the +-- second fault, it's a success. + +a, b = xpcall( + function() + return game[{}] + end, + function() + return game.CoreGui.Name + end) +assert(not a) +print(b) + +coroutine.wrap(function() + assert(not pcall(debug.getinfo, coroutine.running(), 0, ">")) +end)() + +-- arith errors +function ecall(fn, ...) + local ok, err = pcall(fn, ...) + assert(not ok) + return err:sub(err:find(": ") + 2, #err) +end + +assert(ecall(function() return nil + 5 end) == "attempt to perform arithmetic (add) on nil and number") +assert(ecall(function() return "a" + "b" end) == "attempt to perform arithmetic (add) on string") +assert(ecall(function() return 1 > nil end) == "attempt to compare nil < number") -- note reversed order (by design) +assert(ecall(function() return "a" <= 5 end) == "attempt to compare string <= number") + +return('OK') diff --git a/tests/conformance/events.lua b/tests/conformance/events.lua new file mode 100644 index 0000000..32e090a --- /dev/null +++ b/tests/conformance/events.lua @@ -0,0 +1,389 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing metatables') + +local unpack = table.unpack + +X = 20; B = 30 + +local _G = getfenv() +setfenv(1, setmetatable({}, {__index=_G})) + +collectgarbage() + +X = X+10 +assert(X == 30 and _G.X == 20) +B = false +assert(B == false) +B = nil +assert(B == 30) + +assert(getmetatable{} == nil) +assert(getmetatable(4) == nil) +assert(getmetatable(nil) == nil) +a={}; setmetatable(a, {__metatable = "xuxu", + __tostring=function(x) return x.name end}) +assert(getmetatable(a) == "xuxu") +local res,err = pcall(tostring, a) +assert(not res and err == "'__tostring' must return a string") +-- cannot change a protected metatable +assert(pcall(setmetatable, a, {}) == false) +a.name = "gororoba" +assert(tostring(a) == "gororoba") + +local a, t = {10,20,30; x="10", y="20"}, {} +assert(setmetatable(a,t) == a) +assert(getmetatable(a) == t) +assert(setmetatable(a,nil) == a) +assert(getmetatable(a) == nil) +assert(setmetatable(a,t) == a) + + +function f (t, i, e) + assert(not e) + local p = rawget(t, "parent") + return (p and p[i]+3), "dummy return" +end + +t.__index = f + +a.parent = {z=25, x=12, [4] = 24} +assert(a[1] == 10 and a.z == 28 and a[4] == 27 and a.x == "10") + +collectgarbage() + +a = setmetatable({}, t) +function f(t, i, v) rawset(t, i, v-3) end +t.__newindex = f +a[1] = 30; a.x = "101"; a[5] = 200 +assert(a[1] == 27 and a.x == 98 and a[5] == 197) + + +local c = {} +a = setmetatable({}, t) +t.__newindex = c +a[1] = 10; a[2] = 20; a[3] = 90 +assert(c[1] == 10 and c[2] == 20 and c[3] == 90) + + +do + local a; + a = setmetatable({}, {__index = setmetatable({}, + {__index = setmetatable({}, + {__index = function (_,n) return a[n-3]+4, "lixo" end})})}) + a[0] = 20 + for i=0,10 do + assert(a[i*3] == 20 + i*4) + end +end + + +do -- newindex + local foi + local a = {} + for i=1,10 do a[i] = 0; a['a'..i] = 0; end + setmetatable(a, {__newindex = function (t,k,v) foi=true; rawset(t,k,v) end}) + foi = false; a[1]=0; assert(not foi) + foi = false; a['a1']=0; assert(not foi) + foi = false; a['a11']=0; assert(foi) + foi = false; a[11]=0; assert(foi) + foi = false; a[1]=nil; assert(not foi) + foi = false; a[1]=nil; assert(foi) +end + + +function f (t, ...) return t, {...} end +t.__call = f + +do + local x,y = a(unpack{'a', 1}) + assert(x==a and y[1]=='a' and y[2]==1 and y[3]==nil) + x,y = a() + assert(x==a and y[1]==nil) +end + + +local b = setmetatable({}, t) +setmetatable(b,t) + +function f(op) + return function (...) cap = {[0] = op, ...} ; return (...) end +end +t.__add = f("add") +t.__sub = f("sub") +t.__mul = f("mul") +t.__div = f("div") +t.__mod = f("mod") +t.__unm = f("unm") +t.__pow = f("pow") + +assert(b+5 == b) +assert(cap[0] == "add" and cap[1] == b and cap[2] == 5 and cap[3]==nil) +assert(b+'5' == b) +assert(cap[0] == "add" and cap[1] == b and cap[2] == '5' and cap[3]==nil) +assert(5+b == 5) +assert(cap[0] == "add" and cap[1] == 5 and cap[2] == b and cap[3]==nil) +assert('5'+b == '5') +assert(cap[0] == "add" and cap[1] == '5' and cap[2] == b and cap[3]==nil) +b=b-3; assert(getmetatable(b) == t) +assert(5-a == 5) +assert(cap[0] == "sub" and cap[1] == 5 and cap[2] == a and cap[3]==nil) +assert('5'-a == '5') +assert(cap[0] == "sub" and cap[1] == '5' and cap[2] == a and cap[3]==nil) +assert(a*a == a) +assert(cap[0] == "mul" and cap[1] == a and cap[2] == a and cap[3]==nil) +assert(a/0 == a) +assert(cap[0] == "div" and cap[1] == a and cap[2] == 0 and cap[3]==nil) +assert(a%2 == a) +assert(cap[0] == "mod" and cap[1] == a and cap[2] == 2 and cap[3]==nil) +assert(-a == a) +assert(cap[0] == "unm" and cap[1] == a) +assert(a^4 == a) +assert(cap[0] == "pow" and cap[1] == a and cap[2] == 4 and cap[3]==nil) +assert(a^'4' == a) +assert(cap[0] == "pow" and cap[1] == a and cap[2] == '4' and cap[3]==nil) +assert(4^a == 4) +assert(cap[0] == "pow" and cap[1] == 4 and cap[2] == a and cap[3]==nil) +assert('4'^a == '4') +assert(cap[0] == "pow" and cap[1] == '4' and cap[2] == a and cap[3]==nil) + + +t = {} +t.__lt = function (a,b,c) + collectgarbage() + assert(c == nil) + if type(a) == 'table' then a = a.x end + if type(b) == 'table' then b = b.x end + return aOp(1)) and not(Op(1)>Op(2)) and (Op(2)>Op(1))) + assert(not(Op('a')>Op('a')) and not(Op('a')>Op('b')) and (Op('b')>Op('a'))) + assert((Op(1)>=Op(1)) and not(Op(1)>=Op(2)) and (Op(2)>=Op(1))) + assert((Op('a')>=Op('a')) and not(Op('a')>=Op('b')) and (Op('b')>=Op('a'))) +end + +test() + +t.__le = function (a,b,c) + assert(c == nil) + if type(a) == 'table' then a = a.x end + if type(b) == 'table' then b = b.x end + return a<=b, "dummy" +end + +test() -- retest comparisons, now using both `lt' and `le' + + +-- test `partial order' + +local function Set(x) + local y = {} + for _,k in pairs(x) do y[k] = 1 end + return setmetatable(y, t) +end + +t.__lt = function (a,b) + for k in pairs(a) do + if not b[k] then return false end + b[k] = nil + end + return next(b) ~= nil +end + +t.__le = nil + +assert(Set{1,2,3} < Set{1,2,3,4}) +assert(not(Set{1,2,3,4} < Set{1,2,3,4})) +assert((Set{1,2,3,4} <= Set{1,2,3,4})) +assert((Set{1,2,3,4} >= Set{1,2,3,4})) +assert((Set{1,3} <= Set{3,5})) -- wrong!! model needs a `le' method ;-) + +t.__le = function (a,b) + for k in pairs(a) do + if not b[k] then return false end + end + return true +end + +assert(not (Set{1,3} <= Set{3,5})) -- now its OK! +assert(not(Set{1,3} <= Set{3,5})) +assert(not(Set{1,3} >= Set{3,5})) + +t.__eq = function (a,b) + for k in pairs(a) do + if not b[k] then return false end + b[k] = nil + end + return next(b) == nil +end + +local s = Set{1,3,5} +assert(s == Set{3,5,1}) +assert(not rawequal(s, Set{3,5,1})) +assert(rawequal(s, s)) +assert(Set{1,3,5,1} == Set{3,5,1}) +assert(Set{1,3,5} ~= Set{3,5,1,6}) +t[Set{1,3,5}] = 1 +assert(t[Set{1,3,5}] == nil) -- `__eq' is not valid for table accesses + + +t.__concat = function (a,b,c) + assert(c == nil) + if type(a) == 'table' then a = a.val end + if type(b) == 'table' then b = b.val end + if A then return a..b + else + return setmetatable({val=a..b}, t) + end +end + +c = {val="c"}; setmetatable(c, t) +d = {val="d"}; setmetatable(d, t) + +A = true +assert(c..d == 'cd') +assert(0 .."a".."b"..c..d.."e".."f"..(5+3).."g" == "0abcdef8g") + +A = false +x = c..d +assert(getmetatable(x) == t and x.val == 'cd') +x = 0 .."a".."b"..c..d.."e".."f".."g" +assert(x.val == "0abcdefg") + + +-- test comparison compatibilities +local t1, t2, c, d +t1 = {}; c = {}; setmetatable(c, t1) +d = {} +t1.__eq = function () return true end +t1.__lt = function () return true end +assert(c ~= d and not pcall(function () return c < d end)) +setmetatable(d, t1) +assert(c == d and c < d and not(d <= c)) +t2 = {} +t2.__eq = t1.__eq +t2.__lt = t1.__lt +setmetatable(d, t2) +assert(c == d and c < d and not(d <= c)) + + + +-- test for several levels of calls +local i +local tt = { + __call = function (t, ...) + i = i+1 + if t.f then return t.f(...) + else return {...} + end + end +} + +local a = setmetatable({}, tt) +local b = setmetatable({f=a}, tt) +local c = setmetatable({f=b}, tt) + +i = 0 +x = c(3,4,5) +assert(i == 3 and x[1] == 3 and x[3] == 5) + + +assert(_G.X == 20) +assert(_G == getfenv(0)) + +print'+' + +local _g = _G +setfenv(1, setmetatable({}, {__index=function (_,k) return _g[k] end})) + +-- testing proxies +assert(getmetatable(newproxy()) == nil) +assert(getmetatable(newproxy(false)) == nil) +assert(getmetatable(newproxy(nil)) == nil) + +local u = newproxy(true) + +getmetatable(u).__newindex = function (u,k,v) + getmetatable(u)[k] = v +end + +getmetatable(u).__index = function (u,k) + return getmetatable(u)[k] +end + +for i=1,10 do u[i] = i end +for i=1,10 do assert(u[i] == i) end + +-- local k = newproxy(u) +-- assert(getmetatable(k) == getmetatable(u)) + + +a = {} +rawset(a, "x", 1, 2, 3) +assert(a.x == 1 and rawget(a, "x", 3) == 1) + +print '+' + +--[[ +-- testing metatables for basic types +mt = {} +debug.setmetatable(10, mt) +assert(getmetatable(-2) == mt) +mt.__index = function (a,b) return a+b end +assert((10)[3] == 13) +assert((10)["3"] == 13) +debug.setmetatable(23, nil) +assert(getmetatable(-2) == nil) + +debug.setmetatable(true, mt) +assert(getmetatable(false) == mt) +mt.__index = function (a,b) return a or b end +assert((true)[false] == true) +assert((false)[false] == false) +debug.setmetatable(false, nil) +assert(getmetatable(true) == nil) + +debug.setmetatable(nil, mt) +assert(getmetatable(nil) == mt) +mt.__add = function (a,b) return (a or 0) + (b or 0) end +assert(10 + nil == 10) +assert(nil + 23 == 23) +assert(nil + nil == 0) +debug.setmetatable(nil, nil) +assert(getmetatable(nil) == nil) + +debug.setmetatable(nil, {}) +]]-- + +do + -- == must not do ref equality for tables and userdata in presence of __eq + local t = {} + local u = newproxy(true) + + -- print() returns nil which is converted to false + setmetatable(t, { __eq = print }) + getmetatable(u).__eq = print + + assert(t ~= t) + assert(u ~= u) +end + +do + -- verify that internal mt flags are set correctly after two table assignments + local mt = { + {}, -- mixed metatable + __index = {X = true}, + } + local t = setmetatable({}, mt) + assert(t.X) -- fails if table flags are set incorrectly +end + +return 'OK' diff --git a/tests/conformance/exceptions.lua b/tests/conformance/exceptions.lua new file mode 100644 index 0000000..e07029f --- /dev/null +++ b/tests/conformance/exceptions.lua @@ -0,0 +1,35 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing lua_exception') + +-- Verify that no exception is generated +function empty_function() +end + +function pass_number_to_error() + -- Verify the error value of 42 is part of the exception's string. + error(42) +end + +function pass_string_to_error() + -- Verify the error value of "string argument" is part of the exception's string. + error("string argument") +end + +function pass_table_to_error() + -- Pass a table to `error`. A table is used since it is won't be + -- convertable to a string using `lua_tostring`. + error({field="value"}) +end + +function infinite_recursion_error() + -- Generate a stack overflow error + infinite_recursion_error() +end + +function large_allocation_error() + -- Create a table that will require more memory than the test's memory + -- allocator will allow. + table.create(1000000) +end + +return('OK') diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua new file mode 100644 index 0000000..fd4b4de --- /dev/null +++ b/tests/conformance/gc.lua @@ -0,0 +1,294 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing garbage collection') + +collectgarbage() + +_G["while"] = 234 + +limit = 5000 + + + +contCreate = 0 + +print('tables') +while contCreate <= limit do + local a = {}; a = nil + contCreate = contCreate+1 +end + +a = "a" + +contCreate = 0 +print('strings') +while contCreate <= limit do + a = contCreate .. "b"; + a = string.gsub(a, '(%d%d*)', string.upper) + a = "a" + contCreate = contCreate+1 +end + + +contCreate = 0 + +a = {} + +print('functions') +function a:test () + while contCreate <= limit do + loadstring(string.format("function temp(a) return 'a%d' end", contCreate))() + assert(temp() == string.format('a%d', contCreate)) + contCreate = contCreate+1 + end +end + +a:test() + +-- collection of functions without locals, globals, etc. +do local f = function () end end + + +print('long strings') +x = "01234567890123456789012345678901234567890123456789012345678901234567890123456789" +assert(string.len(x)==80) +s = '' +n = 0 +k = 300 +while n < k do s = s..x; n=n+1; j=tostring(n) end +assert(string.len(s) == k*80) +s = string.sub(s, 1, 20000) +s, i = string.gsub(s, '(%d%d%d%d)', math.sin) +assert(i==20000/4) +s = nil +x = nil + +assert(_G["while"] == 234) + + +local bytes = gcinfo() +while 1 do + local nbytes = gcinfo() + if nbytes < bytes then break end -- run until gc + bytes = nbytes + a = {} +end + + +local function dosteps (siz) + collectgarbage() + collectgarbage"stop" + local a = {} + for i=1,100 do a[i] = {{}}; local b = {} end + local x = gcinfo() + local i = 0 + repeat + i = i+1 + until collectgarbage("step", siz) + assert(gcinfo() < x) + return i +end + +assert(dosteps(0) > 10) +assert(dosteps(6) < dosteps(2)) +assert(dosteps(10000) == 1) +-- assert(collectgarbage("step", 1000000) == true) +-- assert(collectgarbage("step", 1000000)) + + +do + local x = gcinfo() + collectgarbage() + collectgarbage"stop" + repeat + local a = {} + until gcinfo() > 1000 + collectgarbage"restart" + repeat + local a = {} + until gcinfo() < 1000 +end + +lim = 15 +a = {} +-- fill a with `collectable' indices +for i=1,lim do a[{}] = i end +b = {} +for k,v in pairs(a) do b[k]=v end +-- remove all indices and collect them +for n in pairs(b) do + a[n] = nil + assert(type(n) == 'table' and next(n) == nil) + collectgarbage() +end +b = nil +collectgarbage() +for n in pairs(a) do error'cannot be here' end +for i=1,lim do a[i] = i end +for i=1,lim do assert(a[i] == i) end + + +print('weak tables') +a = {}; setmetatable(a, {__mode = 'k'}); +-- fill a with some `collectable' indices +for i=1,lim do a[{}] = i end +-- and some non-collectable ones +for i=1,lim do local t={}; a[t]=t end +for i=1,lim do a[i] = i end +for i=1,lim do local s=string.rep('@', i); a[s] = s..'#' end +collectgarbage() +local i = 0 +for k,v in pairs(a) do assert(k==v or k..'#'==v); i=i+1 end +assert(i == 3*lim) + +a = {}; setmetatable(a, {__mode = 'v'}); +a[1] = string.rep('b', 21) +collectgarbage() +assert(a[1]) -- strings are *values* +a[1] = nil +-- fill a with some `collectable' values (in both parts of the table) +for i=1,lim do a[i] = {} end +for i=1,lim do a[i..'x'] = {} end +-- and some non-collectable ones +for i=1,lim do local t={}; a[t]=t end +for i=1,lim do a[i+lim]=i..'x' end +collectgarbage() +local i = 0 +for k,v in pairs(a) do assert(k==v or k-lim..'x' == v); i=i+1 end +assert(i == 2*lim) + +a = {}; setmetatable(a, {__mode = 'vk'}); +local x, y, z = {}, {}, {} +-- keep only some items +a[1], a[2], a[3] = x, y, z +a[string.rep('$', 11)] = string.rep('$', 11) +-- fill a with some `collectable' values +for i=4,lim do a[i] = {} end +for i=1,lim do a[{}] = i end +for i=1,lim do local t={}; a[t]=t end +collectgarbage() +assert(next(a) ~= nil) +local i = 0 +for k,v in pairs(a) do + assert((k == 1 and v == x) or + (k == 2 and v == y) or + (k == 3 and v == z) or k==v); + i = i+1 +end +assert(i == 4) +x,y,z=nil +collectgarbage() +assert(next(a) == string.rep('$', 11)) + + +-- testing userdata +collectgarbage("stop") -- stop collection +local u = newproxy(true) +local s = 0 +local a = {[u] = 0}; setmetatable(a, {__mode = 'vk'}) +for i=1,10 do a[newproxy(true)] = i end +-- for k in pairs(a) do assert(getmetatable(k) == getmetatable(u)) end +local a1 = {}; for k,v in pairs(a) do a1[k] = v end +for k,v in pairs(a1) do a[v] = k end +for i =1,10 do assert(a[i]) end +getmetatable(u).a = a1 +getmetatable(u).u = u +do + local u = u + getmetatable(u).__gc = function (o) + assert(a[o] == 10-s) + assert(a[10-s] == nil) -- udata already removed from weak table + assert(getmetatable(o) == getmetatable(u)) + assert(getmetatable(o).a[o] == 10-s) + s=s+1 + end +end +a1, u = nil +assert(next(a) ~= nil) +collectgarbage() +-- assert(s==11) +collectgarbage() +assert(next(a) == nil) -- finalized keys are removed in two cycles + + +-- __gc x weak tables +local u = newproxy(true) +setmetatable(getmetatable(u), {__mode = "v"}) +getmetatable(u).__gc = function (o) os.exit(1) end -- cannot happen +collectgarbage() + +local u = newproxy(true) +local m = getmetatable(u) +m.x = {[{0}] = 1; [0] = {1}}; setmetatable(m.x, {__mode = "kv"}); +m.__gc = function (o) + assert(next(getmetatable(o).x) == nil) + m = 10 +end +u, m = nil +collectgarbage() +-- assert(m==10) + + +-- errors during collection +u = newproxy(true) +getmetatable(u).__gc = function () error "!!!" end +u = nil +-- assert(not pcall(collectgarbage)) +collectgarbage() + + +if not rawget(_G, "_soft") then + print("deep structures") + local a = {} + for i = 1,200000 do + a = {next = a} + end + collectgarbage() +end + +-- create many threads with self-references and open upvalues +local thread_id = 0 +local threads = {} + +function fn(thread) + local x = {} + threads[thread_id] = function() + thread = x + end + coroutine.yield() +end + +while thread_id < 1000 do + local thread = coroutine.create(fn) + coroutine.resume(thread, thread) + thread_id = thread_id + 1 +end + + + +-- create a userdata to be collected when state is closed +do + local newproxy,assert,type,print,getmetatable = + newproxy,assert,type,print,getmetatable + local u = newproxy(true) + local tt = getmetatable(u) + ___Glob = {u} -- avoid udata being collected before program end + tt.__gc = function (o) + assert(getmetatable(o) == tt) + -- create new objects during GC + local a = 'xuxu'..(10+3)..'joao', {} + ___Glob = o -- ressurect object! + newproxy(o) -- creates a new one with same metatable + print(">>> closing state " .. "<<<\n") + end +end + +-- create several udata to raise errors when collected while closing state +do + local u = newproxy(true) + getmetatable(u).__gc = function (o) return o + 1 end + table.insert(___Glob, u) -- preserve udata until the end + for i = 1,10 do table.insert(___Glob, newproxy(true)) end +end + +return('OK') diff --git a/tests/conformance/ifelseexpr.lua b/tests/conformance/ifelseexpr.lua new file mode 100644 index 0000000..da40b79 --- /dev/null +++ b/tests/conformance/ifelseexpr.lua @@ -0,0 +1,80 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing if-else expressions") + +function True() + return true +end + +function False() + return false +end + +function EvalElseifChain(condition1, condition2, condition3) + return if condition1 then 10 elseif condition2 then 20 elseif condition3 then 30 else 40 +end + +function EvalElse_IfChain(condition1, condition2, condition3) + return if condition1 then 10 else if condition2 then 20 else if condition3 then 30 else 40 +end + +function CheckForConditionalEvaluation(condition) + local counter = 0 + + local function AddToCounter(count) + counter += count + return counter + end + + local result = if condition then AddToCounter(7) else AddToCounter(17) + if condition then + assert(result == 7) + else + assert(result == 17) + end + -- ensure the counter value matches the result of the clause that was evaluated + assert(counter == result) +end + +-- Test expression using only constants +assert(if true then true else false) +assert(if false then false else true) +assert(if nil then false else true) +assert((7 + if true then 10 else 20) == 17) + +-- Test expression using non-constant condition +assert(if True() then true else false) +assert(if False() then false else true) + +-- Test evaluation of a "chain" of if/elseif/else in an expression +assert(EvalElseifChain(false, false, false) == 40) +assert(EvalElseifChain(false, false, true) == 30) +assert(EvalElseifChain(false, true, false) == 20) +assert(EvalElseifChain(false, true, true) == 20) +assert(EvalElseifChain(true, false, false) == 10) +assert(EvalElseifChain(true, false, true) == 10) +assert(EvalElseifChain(true, true, false) == 10) +assert(EvalElseifChain(true, true, true) == 10) + +-- Test evaluation of a "chain" of if/"else if"/else in an expression +assert(EvalElse_IfChain(false, false, false) == 40) +assert(EvalElse_IfChain(false, false, true) == 30) +assert(EvalElse_IfChain(false, true, false) == 20) +assert(EvalElse_IfChain(false, true, true) == 20) +assert(EvalElse_IfChain(true, false, false) == 10) +assert(EvalElse_IfChain(true, false, true) == 10) +assert(EvalElse_IfChain(true, true, false) == 10) +assert(EvalElse_IfChain(true, true, true) == 10) + +-- Test nesting of if-else expressions inside the condition of an if-else expression +assert((if (if True() then false else true) then 10 else 20) == 20) +assert((if if True() then false else true then 10 else 20) == 20) + + +-- Ensure that if/else expressions are conditionally evaluated +-- i.e. verify the evaluated expression doesn't evaluate the true and false expressions and +-- merely select the proper value. +CheckForConditionalEvaluation(true) +CheckForConditionalEvaluation(false) + +return('OK') diff --git a/tests/conformance/literals.lua b/tests/conformance/literals.lua new file mode 100644 index 0000000..da4083d --- /dev/null +++ b/tests/conformance/literals.lua @@ -0,0 +1,180 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing scanner') + +local function dostring (x) return assert(loadstring(x))() end + +-- Luau doesn't support unescaped NUL literals inside Lua source code +-- dostring("x = 'a\0a'") +-- assert(x == 'a\0a' and string.len(x) == 3) + +-- escape sequences +assert('\n\"\'\\' == [[ + +"'\]]) + +assert(string.find("\a\b\f\n\r\t\v", "^%c%c%c%c%c%c%c$")) + +-- assume ASCII just for tests: +assert("\09912" == 'c12') +assert("\99ab" == 'cab') +assert("\099" == '\99') +assert("\099\n" == 'c\10') +assert('\0\0\0alo' == '\0' .. '\0\0' .. 'alo') + +assert(010 .. 020 .. -030 == "1020-30") + +-- long variable names + +var = string.rep('a', 15000) +prog = string.format("%s = 5", var) +dostring(prog) +assert(dostring("return " .. var) == 5) +var = nil +print('+') + +-- escapes -- +assert("\n\t" == [[ + + ]]) +assert([[ + + $debug]] == "\n $debug") +assert([[ [ ]] ~= [[ ] ]]) +-- long strings -- +b = "001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789" +assert(string.len(b) == 960) +prog = [=[ +print('+') + +a1 = [["isto e' um string com vrias 'aspas'"]] +a2 = "'aspas'" + +assert(string.find(a1, a2) == 31) +print('+') + +a1 = [==[temp = [[um valor qualquer]]; ]==] +assert(loadstring(a1))() +assert(temp == 'um valor qualquer') +-- long strings -- +b = "001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789001234567890123456789012345678901234567891234567890123456789012345678901234567890012345678901234567890123456789012345678912345678901234567890123456789012345678900123456789012345678901234567890123456789123456789012345678901234567890123456789" +assert(string.len(b) == 960) +print('+') + +a = [[00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +00123456789012345678901234567890123456789123456789012345678901234567890123456789 +]] +assert(string.len(a) == 1863) +assert(string.sub(a, 1, 40) == string.sub(b, 1, 40)) +x = 1 +]=] + +print('+') +x = nil +dostring(prog) +assert(x) + +prog = nil +a = nil +b = nil + + +-- testing line ends +prog = [[ +a = 1 -- a comment +b = 2 + + +x = [=[ +hi +]=] +y = "\ +hello\r\n\ +" +return debug.getinfo(1).currentline +]] + +for _, n in pairs{"\n", "\r", "\n\r", "\r\n"} do + local prog, nn = string.gsub(prog, "\n", n) + -- assert(dostring(prog) == nn) + -- assert(_G.x == "hi\n" and _G.y == "\nhello\r\n\n") +end + + +-- testing comments and strings with long brackets +a = [==[]=]==] +assert(a == "]=") + +a = [==[[===[[=[]]=][====[]]===]===]==] +assert(a == "[===[[=[]]=][====[]]===]===") + +a = [====[[===[[=[]]=][====[]]===]===]====] +assert(a == "[===[[=[]]=][====[]]===]===") + +a = [=[]]]]]]]]]=] +assert(a == "]]]]]]]]") + + +--[===[ +x y z [==[ blu foo +]== +] +]=]==] +error error]=]===] + +-- generate all strings of four of these chars +local x = {"=", "[", "]", "\n"} +local len = 4 +local function gen (c, n) + if n==0 then coroutine.yield(c) + else + for _, a in pairs(x) do + gen(c..a, n-1) + end + end +end + +for s in coroutine.wrap(function () gen("", len) end) do + assert(s == loadstring("return [====[\n"..s.."]====]")()) +end + + +--[[ +-- testing decimal point locale +if os.setlocale("pt_BR") or os.setlocale("ptb") then + assert(tonumber("3,4") == 3.4 and tonumber"3.4" == nil) + assert(assert(loadstring("return 3.4"))() == 3.4) + assert(assert(loadstring("return .4,3"))() == .4) + assert(assert(loadstring("return 4."))() == 4.) + assert(assert(loadstring("return 4.+.5"))() == 4.5) + local a,b = loadstring("return 4.5.") + assert(string.find(b, "'4%.5%.'")) + assert(os.setlocale("C")) +else + (Message or print)( + '\a\n >>> pt_BR locale not available: skipping decimal point tests <<<\n\a') +end +]]-- + +return('OK') diff --git a/tests/conformance/locals.lua b/tests/conformance/locals.lua new file mode 100644 index 0000000..cbe5f92 --- /dev/null +++ b/tests/conformance/locals.lua @@ -0,0 +1,127 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing local variables plus some extra stuff') + +do + local i = 10 + do local i = 100; assert(i==100) end + do local i = 1000; assert(i==1000) end + assert(i == 10) + if i ~= 10 then + local i = 20 + else + local i = 30 + assert(i == 30) + end +end + + + +f = nil + +local f +x = 1 + +a = nil +loadstring('local a = {}')() +assert(type(a) ~= 'table') + +function f (a) + local _1, _2, _3, _4, _5 + local _6, _7, _8, _9, _10 + local x = 3 + local b = a + local c,d = a,b + if (d == b) then + local x = 'q' + x = b + assert(x == 2) + else + assert(nil) + end + assert(x == 3) + local f = 10 +end + +local b=10 +local a; repeat local b; a,b=1,2; assert(a+1==b); until a+b==3 + + +assert(x == 1) + +f(2) +assert(type(f) == 'function') + + +-- testing globals ;-) +do + local f = {} + local _G = getfenv() + for i=1,10 do f[i] = function (x) A=A+1; return A, _G.getfenv(x) end end + A=10; assert(f[1]() == 11) + for i=1,10 do assert(setfenv(f[i], {A=i}) == f[i]) end + assert(f[3]() == 4 and A == 11) + local a,b = f[8](1) + assert(b.A == 9) + a,b = f[8](0) + assert(b.A == 11) -- `real' global + local g + local function f () assert(setfenv(2, {a='10'}) == g) end + g = function () f(); _G.assert(_G.getfenv(1).a == '10') end + g(); assert(getfenv(g).a == '10') +end + +-- test for global table of loaded chunks +local function foo (s) + return loadstring(s) +end + +assert(getfenv(foo("")) == getfenv()) +local a = {loadstring = loadstring} +setfenv(foo, a) +assert(getfenv(foo("")) == getfenv()) +setfenv(0, a) -- change global environment +assert(getfenv(foo("")) == a) +setfenv(0, getfenv()) + + +-- testing limits for special instructions + +local a +local p = 4 +for i=2,31 do + for j=-3,3 do + assert(loadstring(string.format([[local a=%s;a=a+ + %s; + assert(a + ==2^%s)]], j, p-j, i))) () + assert(loadstring(string.format([[local a=%s; + a=a-%s; + assert(a==-2^%s)]], -j, p-j, i))) () + assert(loadstring(string.format([[local a,b=0,%s; + a=b-%s; + assert(a==-2^%s)]], -j, p-j, i))) () + end + p =2*p +end + +print'+' + + +if rawget(_G, "querytab") then + -- testing clearing of dead elements from tables + collectgarbage("stop") -- stop GC + local a = {[{}] = 4, [3] = 0, alo = 1, + a1234567890123456789012345678901234567890 = 10} + + local t = querytab(a) + + for k,_ in pairs(a) do a[k] = nil end + collectgarbage() -- restore GC and collect dead fiels in `a' + for i=0,t-1 do + local k = querytab(a, i) + assert(k == nil or type(k) == 'number' or k == 'alo') + end +end + +return('OK') diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua new file mode 100644 index 0000000..5e8b939 --- /dev/null +++ b/tests/conformance/math.lua @@ -0,0 +1,296 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing numbers and math lib") + +do + local a,b,c = "2", " 3e0 ", " 10 " + assert(a+b == 5 and -b == -3 and b+"2" == 5 and "10"-c == 0) + assert(type(a) == 'string' and type(b) == 'string' and type(c) == 'string') + assert(a == "2" and b == " 3e0 " and c == " 10 " and -c == -" 10 ") + assert(c%a == 0 and a^b == 8) +end + + +do + local a,b = math.modf(3.5) + assert(a == 3 and b == 0.5) + assert(math.huge > 10e30) + assert(-math.huge < -10e30) +end + +function f(...) + if select('#', ...) == 1 then + return (...) + else + return "***" + end +end + +assert(tonumber{} == nil) +assert(tonumber'+0.01' == 1/100 and tonumber'+.01' == 0.01 and + tonumber'.01' == 0.01 and tonumber'-1.' == -1 and + tonumber'+1.' == 1) +assert(tonumber'+ 0.01' == nil and tonumber'+.e1' == nil and + tonumber'1e' == nil and tonumber'1.0e+' == nil and + tonumber'.' == nil) +assert(tonumber('-12') == -10-2) +assert(tonumber('-1.2e2') == - - -120) +assert(f(tonumber('1 a')) == nil) +assert(f(tonumber('e1')) == nil) +assert(f(tonumber('e 1')) == nil) +assert(f(tonumber(' 3.4.5 ')) == nil) +assert(f(tonumber('')) == nil) +assert(f(tonumber('', 8)) == nil) +assert(f(tonumber(' ')) == nil) +assert(f(tonumber(' ', 9)) == nil) +assert(f(tonumber('99', 8)) == nil) +assert(tonumber(' 1010 ', 2) == 10) +assert(tonumber('10', 36) == 36) +--assert(tonumber('\n -10 \n', 36) == -36) +--assert(tonumber('-fFfa', 16) == -(10+(16*(15+(16*(15+(16*15))))))) +assert(tonumber('fFfa', 15) == nil) +--assert(tonumber(string.rep('1', 42), 2) + 1 == 2^42) +assert(tonumber(string.rep('1', 32), 2) + 1 == 2^32) +--assert(tonumber('-fffffFFFFF', 16)-1 == -2^40) +assert(tonumber('ffffFFFF', 16)+1 == 2^32) + +assert(1.1 == 1.+.1) +assert(100.0 == 1E2 and .01 == 1e-2) +assert(1111111111111111-1111111111111110== 1000.00e-03) +-- 1234567890123456 +assert(1.1 == '1.'+'.1') +assert('1111111111111111'-'1111111111111110' == tonumber" +0.001e+3 \n\t") + +function eq (a,b,limit) + if not limit then limit = 10E-10 end + return math.abs(a-b) <= limit +end + +assert(0.1e-30 > 0.9E-31 and 0.9E30 < 0.1e31) + +assert(0.123456 > 0.123455) + +assert(tonumber('+1.23E30') == 1.23*10^30) + +-- testing order operators +assert(not(1<1) and (1<2) and not(2<1)) +assert(not('a'<'a') and ('a'<'b') and not('b'<'a')) +assert((1<=1) and (1<=2) and not(2<=1)) +assert(('a'<='a') and ('a'<='b') and not('b'<='a')) +assert(not(1>1) and not(1>2) and (2>1)) +assert(not('a'>'a') and not('a'>'b') and ('b'>'a')) +assert((1>=1) and not(1>=2) and (2>=1)) +assert(('a'>='a') and not('a'>='b') and ('b'>='a')) + +-- testing mod operator +assert(-4%3 == 2) +assert(4%-3 == -2) +assert(math.pi - math.pi % 1 == 3) +assert(math.pi - math.pi % 0.001 == 3.141) + +do + local a = 3 % 0; + assert(a ~= a) -- Expect NaN + assert(((2^53+1) % 2) == 0) + assert((1234 % (2^53+1)) == 1234) +end + +local function testbit(a, n) + return a/2^n % 2 >= 1 +end + +assert(eq(math.sin(-9.8)^2 + math.cos(-9.8)^2, 1)) +assert(eq(math.tan(math.pi/4), 1)) +assert(eq(math.sin(math.pi/2), 1) and eq(math.cos(math.pi/2), 0)) +assert(eq(math.atan(1), math.pi/4) and eq(math.acos(0), math.pi/2) and + eq(math.asin(1), math.pi/2)) +assert(eq(math.deg(math.pi/2), 90) and eq(math.rad(90), math.pi/2)) +assert(math.abs(-10) == 10) +assert(eq(math.atan2(1,0), math.pi/2)) +assert(math.ceil(4.5) == 5.0) +assert(math.floor(4.5) == 4.0) +assert(10 % 3 == 1) +assert(eq(math.sqrt(10)^2, 10)) +assert(eq(math.log10(2), math.log(2)/math.log(10))) +assert(eq(math.log(2, 2), 1)) +assert(eq(math.log(9, 3), 2)) +assert(eq(math.log(100, 10), 2)) +assert(eq(math.exp(0), 1)) +assert(eq(math.sin(10), math.sin(10%(2*math.pi)))) +local v,e = math.frexp(math.pi) +assert(eq(math.ldexp(v,e), math.pi)) + +assert(eq(math.tanh(3.5), math.sinh(3.5)/math.cosh(3.5))) + +assert(tonumber(' 1.3e-2 ') == 1.3e-2) +assert(tonumber(' -1.00000000000001 ') == -1.00000000000001) + +-- testing constant limits +-- 2^23 = 8388608 +assert(8388609 + -8388609 == 0) +assert(8388608 + -8388608 == 0) +assert(8388607 + -8388607 == 0) + +if rawget(_G, "_soft") then return end + +f = "a = {" +i = 1 +repeat + f = f .. "{" .. math.sin(i) .. ", " .. math.cos(i) .. ", " .. (i/3) .. "},\n" + i=i+1 +until i > 1000 +f = f .. "}" +assert(loadstring(f))() + +assert(eq(a[300][1], math.sin(300))) +assert(eq(a[600][1], math.sin(600))) +assert(eq(a[500][2], math.cos(500))) +assert(eq(a[800][2], math.cos(800))) +assert(eq(a[200][3], 200/3)) +assert(eq(a[1000][3], 1000/3, 0.001)) +print('+') + +do -- testing NaN + local NaN = 10e500 - 10e400 + assert(NaN ~= NaN) + assert(not (NaN < NaN)) + assert(not (NaN <= NaN)) + assert(not (NaN > NaN)) + assert(not (NaN >= NaN)) + assert(not (0 < NaN)) + assert(not (NaN < 0)) + local a = {} + assert(not pcall(function () a[NaN] = 1 end)) + assert(a[NaN] == nil) + a[1] = 1 + assert(not pcall(function () a[NaN] = 1 end)) + assert(a[NaN] == nil) +end + +-- require "checktable" +-- stat(a) + +a = nil + +-- testing implicit convertions + +local a,b = '10', '20' +assert(a*b == 200 and a+b == 30 and a-b == -10 and a/b == 0.5 and -b == -20) +assert(a == '10' and b == '20') + + +math.randomseed(0) + +local i = 0 +local Max = 0 +local Min = 2 +repeat + local t = math.random() + Max = math.max(Max, t) + Min = math.min(Min, t) + i=i+1 + flag = eq(Max, 1, 0.001) and eq(Min, 0, 0.001) +until flag or i>10000 +assert(0 <= Min and Max<1) +assert(flag); + +for i=1,10 do + local t = math.random(5) + assert(1 <= t and t <= 5) +end + +i = 0 +Max = -200 +Min = 200 +repeat + local t = math.random(-10,0) + Max = math.max(Max, t) + Min = math.min(Min, t) + i=i+1 + flag = (Max == 0 and Min == -10) +until flag or i>10000 +assert(-10 <= Min and Max<=0) +assert(flag); + +assert(select(2, pcall(math.random, 1, 2, 3)):match("wrong number of arguments")) + +-- noise +assert(math.noise(0.5) == 0) +assert(math.noise(0.5, 0.5) == -0.25) +assert(math.noise(0.5, 0.5, -0.5) == 0.125) + +local inf = math.huge * 2 +local nan = 0 / 0 + +-- sign +assert(math.sign(0) == 0) +assert(math.sign(42) == 1) +assert(math.sign(-42) == -1) +assert(math.sign(inf) == 1) +assert(math.sign(-inf) == -1) +assert(math.sign(nan) == 0) + +-- clamp +assert(math.clamp(-1, 0, 1) == 0) +assert(math.clamp(0.5, 0, 1) == 0.5) +assert(math.clamp(2, 0, 1) == 1) +assert(math.clamp(4, 0, 0) == 0) + +-- round +assert(math.round(0) == 0) +assert(math.round(0.4) == 0) +assert(math.round(0.5) == 1) +assert(math.round(3.5) == 4) +assert(math.round(-0.4) == 0) +assert(math.round(-0.5) == -1) +assert(math.round(-3.5) == -4) +assert(math.round(math.huge) == math.huge) + +-- fmod +assert(math.fmod(3, 2) == 1) +assert(math.fmod(-3, 2) == -1) +assert(math.fmod(3, -2) == 1) +assert(math.fmod(-3, -2) == -1) + +-- most of the tests above go through fastcall path +-- to make sure the basic implementations are also correct we test these functions with string->number coercions +assert(math.abs("-4") == 4) +assert(math.acos("1") == 0) +assert(math.asin("0") == 0) +assert(math.atan2("0", "0") == 0) +assert(math.atan("0") == 0) +assert(math.ceil("1.5") == 2) +assert(math.cosh("0") == 1) +assert(math.cos("0") == 1) +assert(math.deg("0") == 0) +assert(math.exp("0") == 1) +assert(math.floor("1.5") == 1) +assert(math.fmod("1.5", 1) == 0.5) +local v,e = math.frexp("1.5") +assert(v == 0.75 and e == 1) +assert(math.ldexp("0.75", 1) == 1.5) +assert(math.log10("10") == 1) +assert(math.log("0") == -inf) +assert(math.log("8", 2) == 3) +assert(math.log("10", 10) == 1) +assert(math.log("9", 3) == 2) +assert(math.max("1", 2) == 2) +assert(math.max(2, "1") == 2) +assert(math.min("1", 2) == 1) +assert(math.min(2, "1") == 1) +local v,f = math.modf("1.5") +assert(v == 1 and f == 0.5) +assert(math.pow("2", 2) == 4) +assert(math.rad("0") == 0) +assert(math.sinh("0") == 0) +assert(math.sin("0") == 0) +assert(math.sqrt("4") == 2) +assert(math.tanh("0") == 0) +assert(math.tan("0") == 0) +assert(math.clamp("0", 2, 3) == 2) +assert(math.sign("2") == 1) +assert(math.sign("-2") == -1) +assert(math.sign("0") == 0) +assert(math.round("1.8") == 2) + +return('OK') diff --git a/tests/conformance/move.lua b/tests/conformance/move.lua new file mode 100644 index 0000000..3f28b4b --- /dev/null +++ b/tests/conformance/move.lua @@ -0,0 +1,77 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +local function checkerror (msg, f, ...) + local s, err = pcall(f, ...) + assert(not s and string.find(err, msg)) +end + +print("testing move") + +local maxI = 2147483647 +local minI = -2147483648 + +-- testing move +do + + checkerror("table expected", table.move, 1, 2, 3, 4) + checkerror("table expected", table.move, {}, 2, 3, 4, "foo") + + local function eqT (a, b) + for k, v in pairs(a) do assert(b[k] == v) end + for k, v in pairs(b) do assert(a[k] == v) end + end + + local a = table.move({10,20,30}, 1, 3, 2) -- move forward + eqT(a, {10,10,20,30}) + + -- move forward with overlap of 1 + a = table.move({10, 20, 30}, 1, 3, 3) + eqT(a, {10, 20, 10, 20, 30}) + + -- moving to the same table (not being explicit about it) + a = {10, 20, 30, 40} + table.move(a, 1, 4, 2, a) + eqT(a, {10, 10, 20, 30, 40}) + + a = table.move({10,20,30}, 2, 3, 1) -- move backward + eqT(a, {20,30,30}) + + a = {} -- move to new table + assert(table.move({10,20,30}, 1, 3, 1, a) == a) + eqT(a, {10,20,30}) + + a = {} + assert(table.move({10,20,30}, 1, 0, 3, a) == a) -- empty move (no move) + eqT(a, {}) + + a = table.move({10,20,30}, 1, 10, 1) -- move to the same place + eqT(a, {10,20,30}) + + -- moving on the fringes + a = table.move({[maxI - 2] = 1, [maxI - 1] = 2, [maxI] = 3}, + maxI - 2, maxI, -10, {}) + eqT(a, {[-10] = 1, [-9] = 2, [-8] = 3}) + + a = table.move({[minI] = 1, [minI + 1] = 2, [minI + 2] = 3}, + minI, minI + 2, -10, {}) + eqT(a, {[-10] = 1, [-9] = 2, [-8] = 3}) + + a = table.move({45}, 1, 1, maxI) + eqT(a, {45, [maxI] = 45}) + + a = table.move({[maxI] = 100}, maxI, maxI, minI) + eqT(a, {[minI] = 100, [maxI] = 100}) + + a = table.move({[minI] = 100}, minI, minI, maxI) + eqT(a, {[minI] = 100, [maxI] = 100}) +end + +checkerror("too many", table.move, {}, 0, maxI, 1) +checkerror("too many", table.move, {}, -1, maxI - 1, 1) +checkerror("too many", table.move, {}, minI, -1, 1) +checkerror("too many", table.move, {}, minI, maxI, 1) +checkerror("wrap around", table.move, {}, 1, maxI, 2) +checkerror("wrap around", table.move, {}, 1, 2, maxI) +checkerror("wrap around", table.move, {}, minI, -2, 2) + +return"OK" diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua new file mode 100644 index 0000000..7f9b759 --- /dev/null +++ b/tests/conformance/nextvar.lua @@ -0,0 +1,515 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing tables, next, and for') + +local unpack = table.unpack + +-- testing table.insert return value +assert(select('#', table.insert({}, 42)) == 0) + +local a = {} + +-- make sure table has lots of space in hash part +for i=1,100 do a[i.."+"] = true end +for i=1,100 do a[i.."+"] = nil end +-- fill hash part with numeric indices testing size operator +for i=1,100 do + a[i] = true + assert(#a == i) +end + + +if T then +-- testing table sizes + +local l2 = math.log(2) +local function log2 (x) return math.log(x)/l2 end + +local function mp2 (n) -- minimum power of 2 >= n + local mp = 2^math.ceil(log2(n)) + assert(n == 0 or (mp/2 < n and n <= mp)) + return mp +end + +local function fb (n) + local r, nn = T.int2fb(n) + assert(r < 256) + return nn +end + +-- test fb function +local a = 1 +local lim = 2^30 +while a < lim do + local n = fb(a) + assert(a <= n and n <= a*1.125) + a = math.ceil(a*1.3) +end + + +local function check (t, na, nh) + local a, h = T.querytab(t) + if a ~= na or h ~= nh then + print(na, nh, a, h) + assert(nil) + end +end + +-- testing constructor sizes +local lim = 40 +local s = 'return {' +for i=1,lim do + s = s..i..',' + local s = s + for k=0,lim do + local t = loadstring(s..'}')() + assert(#t == i) + check(t, fb(i), mp2(k)) + s = string.format('%sa%d=%d,', s, k, k) + end +end + + +-- tests with unknown number of elements +local a = {} +for i=1,lim do a[i] = i end -- build auxiliary table +for k=0,lim do + local a = {unpack(a,1,k)} + assert(#a == k) + check(a, k, 0) + a = {1,2,3,unpack(a,1,k)} + check(a, k+3, 0) + assert(#a == k + 3) +end + + +print'+' + +-- testing tables dynamically built +local lim = 130 +local a = {}; a[2] = 1; check(a, 0, 1) +a = {}; a[0] = 1; check(a, 0, 1); a[2] = 1; check(a, 0, 2) +a = {}; a[0] = 1; a[1] = 1; check(a, 1, 1) +a = {} +for i = 1,lim do + a[i] = 1 + assert(#a == i) + check(a, mp2(i), 0) +end + +a = {} +for i = 1,lim do + a['a'..i] = 1 + assert(#a == 0) + check(a, 0, mp2(i)) +end + +a = {} +for i=1,16 do a[i] = i end +check(a, 16, 0) +for i=1,11 do a[i] = nil end +for i=30,40 do a[i] = nil end -- force a rehash (?) +check(a, 0, 8) +a[10] = 1 +for i=30,40 do a[i] = nil end -- force a rehash (?) +check(a, 0, 8) +for i=1,14 do a[i] = nil end +for i=30,50 do a[i] = nil end -- force a rehash (?) +check(a, 0, 4) + +-- reverse filling +for i=1,lim do + local a = {} + for i=i,1,-1 do a[i] = i end -- fill in reverse + check(a, mp2(i), 0) +end + +-- size tests for vararg +lim = 35 +function foo (n, ...) + local arg = {...} + check(arg, n, 0) + assert(select('#', ...) == n) + arg[n+1] = true + check(arg, mp2(n+1), 0) + arg.x = true + check(arg, mp2(n+1), 1) +end +local a = {} +for i=1,lim do a[i] = true; foo(i, unpack(a)) end + +end + + +-- test size operation on empty tables +assert(#{} == 0) +assert(#{nil} == 0) +assert(#{nil, nil} == 0) +assert(#{nil, nil, nil} == 0) +assert(#{nil, nil, nil, nil} == 0) +print'+' + + +local nofind = {} + +a,b,c = 1,2,3 +a,b,c = nil + +local function find (name) + local n,v + while 1 do + n,v = next(_G, n) + if not n then return nofind end + assert(v ~= nil) + if n == name then return v end + end +end + +local function find1 (name) + for n,v in pairs(_G) do + if n==name then return v end + end + return nil -- not found +end + +do -- create 10000 new global variables + for i=1,10000 do _G[i] = i end +end + + +a = {x=90, y=8, z=23} +assert(table.foreach(a, function(i,v) if i=='x' then return v end end) == 90) +assert(table.foreach(a, function(i,v) if i=='a' then return v end end) == nil) +table.foreach({}, error) + +table.foreachi({x=10, y=20}, error) +local a = {n = 1} +table.foreachi({n=3}, function (i, v) + assert(a.n == i and not v) + a.n=a.n+1 +end) +a = {10,20,30,nil,50} +table.foreachi(a, function (i,v) assert(a[i] == v) end) +assert(table.foreachi({'a', 'b', 'c'}, function (i,v) + if i==2 then return v end + end) == 'b') + + +-- assert(print==find("print") and print == find1("print")) +-- assert(_G["print"]==find("print")) +-- assert(assert==find1("assert")) +assert(nofind==find("return")) +assert(not find1("return")) +_G["ret" .. "urn"] = nil +assert(nofind==find("return")) +_G["xxx"] = 1 +-- assert(xxx==find("xxx")) +print('+') + +a = {} +for i=0,10000 do + if i % 10 ~= 0 then + a['x'..i] = i + end +end + +n = {n=0} +for i,v in pairs(a) do + n.n = n.n+1 + assert(i and v and a[i] == v) +end +assert(n.n == 9000) +a = nil + +-- remove those 10000 new global variables +for i=1,10000 do _G[i] = nil end + +do -- clear global table + local a = {} + local preserve = {io = 1, string = 1, debug = 1, os = 1, + coroutine = 1, table = 1, math = 1} + for n,v in pairs(_G) do a[n]=v end + for n,v in pairs(a) do + if not preserve[n] and type(v) ~= "function" and + not string.find(n, "^[%u_]") then + _G[n] = nil + end + collectgarbage() + end +end + +local function foo () + local getfenv, setfenv, assert, next = + getfenv, setfenv, assert, next + local n = {gl1=3} + setfenv(foo, n) + assert(getfenv(foo) == getfenv(1)) + assert(getfenv(foo) == n) + assert(print == nil and gl1 == 3) + gl1 = nil + gl = 1 + assert(n.gl == 1 and next(n, 'gl') == nil) +end +foo() + +print'+' + +local function checknext (a) + local b = {} + table.foreach(a, function (k,v) b[k] = v end) + for k,v in pairs(b) do assert(a[k] == v) end + for k,v in pairs(a) do assert(b[k] == v) end + b = {} + do local k,v = next(a); while k do b[k] = v; k,v = next(a,k) end end + for k,v in pairs(b) do assert(a[k] == v) end + for k,v in pairs(a) do assert(b[k] == v) end +end + +checknext{1,x=1,y=2,z=3} +checknext{1,2,x=1,y=2,z=3} +checknext{1,2,3,x=1,y=2,z=3} +checknext{1,2,3,4,x=1,y=2,z=3} +checknext{1,2,3,4,5,x=1,y=2,z=3} + +assert(table.getn{} == 0) +assert(table.getn{[-1] = 2} == 0) +assert(table.getn{1,2,3,nil,nil} == 3) +for i=0,40 do + local a = {} + for j=1,i do a[j]=j end + assert(table.getn(a) == i) +end + + +assert(table.maxn{} == 0) +assert(table.maxn{["1000"] = true} == 0) +assert(table.maxn{["1000"] = true, [24.5] = 3} == 24.5) +assert(table.maxn{[1000] = true} == 1000) +assert(table.maxn{[10] = true, [100*math.pi] = print} == 100*math.pi) + + +-- int overflow +a = {} +for i=0,50 do a[math.pow(2,i)] = true end +assert(a[table.getn(a)]) + +print("+") + + +-- erasing values +local t = {[{1}] = 1, [{2}] = 2, [string.rep("x ", 4)] = 3, + [100.3] = 4, [4] = 5} + +local n = 0 +for k, v in pairs( t ) do + n = n+1 + assert(t[k] == v) + t[k] = nil + collectgarbage() + assert(t[k] == nil) +end +assert(n == 5) + + +local function test (a) + table.insert(a, 10); table.insert(a, 2, 20); + table.insert(a, 1, -1); table.insert(a, 40); + table.insert(a, table.getn(a)+1, 50) + table.insert(a, 2, -2) + assert(table.remove(a,1) == -1) + assert(table.remove(a,1) == -2) + assert(table.remove(a,1) == 10) + assert(table.remove(a,1) == 20) + assert(table.remove(a,1) == 40) + assert(table.remove(a,1) == 50) + assert(table.remove(a,1) == nil) +end + +a = {n=0, [-7] = "ban"} +test(a) +assert(a.n == 0 and a[-7] == "ban") + +a = {[-7] = "ban"}; +test(a) +assert(a.n == nil and table.getn(a) == 0 and a[-7] == "ban") + + +table.insert(a, 1, 10); table.insert(a, 1, 20); table.insert(a, 1, -1) +assert(table.remove(a) == 10) +assert(table.remove(a) == 20) +assert(table.remove(a) == -1) + +a = {'c', 'd'} +table.insert(a, 3, 'a') +table.insert(a, 'b') +assert(table.remove(a, 1) == 'c') +assert(table.remove(a, 1) == 'd') +assert(table.remove(a, 1) == 'a') +assert(table.remove(a, 1) == 'b') +assert(table.getn(a) == 0 and a.n == nil) +print("+") + +-- out of range insertion +a = {1, 2, 3} +table.insert(a, 0, 0) +assert(a[0] == 0 and table.concat(a) == "123") +table.insert(a, 10, 10) +assert(a[0] == 0 and table.concat(a) == "123" and table.maxn(a) == 10 and a[10] == 10) +table.insert(a, -10^9, 42) +assert(a[0] == 0 and table.concat(a) == "123" and table.maxn(a) == 10 and a[10] == 10 and a[-10^9] == 42) +table.insert(a, 0 / 0, 42) -- platform-dependent behavior atm so hard to validate + +a = {} +for i=1,1000 do + a[i] = i; a[i-1] = nil +end +assert(next(a,nil) == 1000 and next(a,1000) == nil) + +assert(next({}) == nil) +assert(next({}, nil) == nil) + +for a,b in pairs{} do error"not here" end +for i=1,0 do error'not here' end +for i=0,1,-1 do error'not here' end +a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) +a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) + +a = 0; for i=0, 1, 0.1 do a=a+1 end; assert(a==11) +-- precision problems +--a = 0; for i=1, 0, -0.01 do a=a+1 end; assert(a==101) +a = 0; for i=0, 0.999999999, 0.1 do a=a+1 end; assert(a==10) +a = 0; for i=1, 1, 1 do a=a+1 end; assert(a==1) +a = 0; for i=1e10, 1e10, -1 do a=a+1 end; assert(a==1) +a = 0; for i=1, 0.99999, 1 do a=a+1 end; assert(a==0) +a = 0; for i=99999, 1e5, -1 do a=a+1 end; assert(a==0) +a = 0; for i=1, 0.99999, -1 do a=a+1 end; assert(a==1) + +-- conversion +a = 0; for i="10","1","-2" do a=a+1 end; assert(a==5) + + +collectgarbage() + + +-- testing generic 'for' + +local function f (n, p) + local t = {}; for i=1,p do t[i] = i*10 end + return function (_,n) + if n > 0 then + n = n-1 + return n, unpack(t) + end + end, nil, n +end + +local x = 0 +for n,a,b,c,d in f(5,3) do + x = x+1 + assert(a == 10 and b == 20 and c == 30 and d == nil) +end +assert(x == 5) + +-- testing table.create and table.find +do + local t = table.create(5) + assert(#t == 0) -- filled with nil! + t[5] = 5 + assert(#t == 5) -- magic + + local t2 = table.create(5, "nice") + assert(table.concat(t2,"!") == "nice!nice!nice!nice!nice") + + assert(table.find(t2, "nice") == 1) + assert(table.find(t2, "nice", 5) == 5) + assert(table.find(t2, "nice", 6) == nil) + + assert(table.find({false, true}, true) == 2) + + -- make sure table.find checks the hash portion as well by constructing a table literal that forces the value into the hash part + assert(table.find({[(1)] = true}, true) == 1) +end + +-- test indexing with strings that have zeroes embedded in them +do + local t = {} + t['start\0end'] = 1 + t['start'] = 2 + assert(t['start\0end'] == 1) + assert(t['start'] == 2) +end + +-- test table freezing +do + local t = {} + assert(not table.isfrozen(t)) + t[1] = 1 + t.a = 2 + + -- basic freeze test to validate invariants + assert(table.freeze(t) == t) + assert(table.isfrozen(t)) + assert(not pcall(rawset, t, 1, 2)) + assert(not pcall(rawset, t, "a", 2)) + assert(not pcall(function() t.a = 3 end)) + assert(not pcall(function() t[1] = 3 end)) + assert(not pcall(setmetatable, t, {})) + assert(not pcall(table.freeze, t)) -- already frozen + + -- can't freeze tables with protected metatable + local t = {} + setmetatable(t, { __metatable = "nope" }) + assert(not pcall(table.freeze, t)) + assert(not table.isfrozen(t)) + + -- note that it's valid to freeze a table with a metatable and protect it later, freeze doesn't freeze metatable automatically + local mt = {} + local t = setmetatable({}, mt) + table.freeze(t) + mt.__metatable = "nope" + assert(table.isfrozen(t)) + assert(getmetatable(t) == "nope") +end + +-- test #t +do + local t = table.create(10, 1) + assert(#t == 10) + t[5] = nil + assert(#t == 10) + t[10] = nil + assert(#t == 9) + t[9] = nil + t[8] = nil + assert(#t == 7) + + t = table.create(10) + assert(#t == 0) + t[1] = 1 + assert(#t == 1) + t[2] = 1 + assert(#t == 2) + t[3] = 1 + t[4] = 1 + assert(#t == 4) + + t = table.create(10) + assert(#t == 0) + table.insert(t, 1) + assert(#t == 1) + table.insert(t, 1) + assert(#t == 2) + table.insert(t, 1) + table.insert(t, 1) + assert(#t == 4) + + t = table.create(10, 1) + assert(#t == 10) + table.remove(t) + assert(#t == 9) + table.remove(t) + table.remove(t) + assert(#t == 7) +end + +return"OK" diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua new file mode 100644 index 0000000..a2072d2 --- /dev/null +++ b/tests/conformance/pcall.lua @@ -0,0 +1,147 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing pcall") + +function checkresults(e, ...) + local t = table.pack(...) + assert(t.n == #e) + for i=1,t.n do + assert(t[i] == e[i]) + end +end + +function checkerror(...) + local t = table.pack(...) + assert(t.n == 2) + assert(t[1] == false) + assert(type(t[2]) == "string") +end + +function corun(f) + local co = coroutine.create(f) + local res = {} + while coroutine.status(co) == "suspended" do + res = {coroutine.resume(co)} + end + assert(coroutine.status(co) == "dead") + return table.unpack(res) +end + +function colog(f) + local co = coroutine.create(f) + local res = {} + while coroutine.status(co) == "suspended" do + local run = {coroutine.resume(co)} + if run[1] then + table.insert(res, coroutine.status(co) == "suspended" and "yield" or "return"); + else + table.insert(res, "error"); + end + table.move(run, 2, #run, 1 + #res, res) -- equivalent to table.append(res, run) + print(coroutine.status(co), table.unpack(res)) + end + assert(coroutine.status(co) == "dead") + return table.unpack(res) +end + +-- basic behavior tests - no error/yielding, just checking argument passing +checkresults({ true, 42 }, pcall(function() return 42 end)) +checkresults({ true, 1, 2, 42 }, pcall(function(a, b) return a, b, 42 end, 1, 2)) +checkresults({ true, 2 }, pcall(function(...) return select('#', ...) end, 1, 2)) + +-- the argument could be a C function or a callable +checkresults({ true, 42 }, pcall(math.abs, -42)) +checkresults({ true, 42 }, pcall(setmetatable({}, { __call = function(self, arg) return math.abs(arg) end }), -42)) + +-- basic error tests - including interpreter errors and errors generated by C APIs +checkerror(pcall(function() local a = nil / 5 end)) +checkerror(pcall(function() select(-100) end)) + +if not limitedstack then + -- complex error tests - stack overflow, and stack overflow through pcall + function stackinfinite() return stackinfinite() end + checkerror(pcall(stackinfinite)) + + function stackover() return pcall(stackover) end + local res = {pcall(stackover)} + assert(#res == 200) +end + +-- yield tests +checkresults({ "yield", "return", true, 42 }, colog(function() return pcall(function() coroutine.yield() return 42 end) end)) +checkresults({ "yield", 1, "return", true, 42 }, colog(function() return pcall(function() coroutine.yield(1) return 42 end) end)) +checkresults({ "yield", 1, 2, 3, "return", true, 42 }, colog(function() return pcall(function() coroutine.yield(1, 2, 3) return 42 end) end)) +checkresults({ "yield", 1, "yield", 2, "yield", 3, "return", true, 42 }, colog(function() return pcall(function() for i=1,3 do coroutine.yield(i) end return 42 end) end)) +checkresults({ "yield", "return", true, 1, 2, 3}, colog(function() return pcall(function() coroutine.yield() return 1, 2, 3 end) end)) + +-- recursive yield tests +checkresults({ "yield", 1, "yield", 2, "return", true, true, 3}, colog(function() return pcall(function() coroutine.yield(1) return pcall(function() coroutine.yield(2) return 3 end) end) end)) + +-- error after yield tests +checkresults({ "yield", "return", false, "pcall.lua:80: foo" }, colog(function() return pcall(function() coroutine.yield() error("foo") end) end)) +checkresults({ "yield", "yield", "return", true, false, "pcall.lua:81: foo" }, colog(function() return pcall(function() coroutine.yield() return pcall(function() coroutine.yield() error("foo") end) end) end)) +checkresults({ "yield", "yield", "return", false, "pcall.lua:82: bar" }, colog(function() return pcall(function() coroutine.yield() pcall(function() coroutine.yield() error("foo") end) error("bar") end) end)) + +-- returning lots of results (past MINSTACK limits) +local res = {pcall(function() return table.unpack(table.create(100, 'a')) end)} +assert(#res == 101 and res[1] == true and res[2] == 'a' and res[101] == 'a') + +local res = {corun(function() return pcall(function() coroutine.yield() return table.unpack(table.create(100, 'a')) end) end)} +assert(#res == 102 and res[1] == true and res[2] == true and res[3] == 'a' and res[102] == 'a') + +-- pcall a C function after yield; resume gets multiple C entries this way +checkresults({ "yield", 1, 2, 3, "return", true }, colog(function() return pcall(coroutine.yield, 1, 2, 3) end)) +checkresults({ "yield", 1, 2, 3, "return", true, true, true }, colog(function() return pcall(pcall, pcall, coroutine.yield, 1, 2, 3) end)) +checkresults({ "yield", "return", true, true, true, 42 }, colog(function() return pcall(pcall, pcall, function() coroutine.yield() return 42 end) end)) + +-- xpcall basic tests, including yielding; xpcall uses the same infra as pcall so the main testing opportunity is for error handling +checkresults({ true, 42 }, xpcall(function() return 42 end, error)) +checkresults({ true, 1, 2, 42 }, xpcall(function(a, b) return a, b, 42 end, error, 1, 2)) +checkresults({ true, 2 }, xpcall(function(...) return select('#', ...) end, error, 1, 2)) +checkresults({ "yield", "return", true, 42 }, colog(function() return xpcall(function() coroutine.yield() return 42 end, error) end)) + +-- xpcall immediate error handling +checkresults({ false, "pcall.lua:103: foo" }, xpcall(function() error("foo") end, function(err) return err end)) +checkresults({ false, "bar" }, xpcall(function() error("foo") end, function(err) return "bar" end)) +checkresults({ false, 1 }, xpcall(function() error("foo") end, function(err) return 1, 2 end)) +checkresults({ false, "pcall.lua:106: foo\npcall.lua:106\npcall.lua:106\n" }, xpcall(function() error("foo") end, debug.traceback)) +checkresults({ false, "error in error handling" }, xpcall(function() error("foo") end, function(err) error("bar") end)) + +-- xpcall error handling after yields +checkresults({ "yield", "return", false, "pcall.lua:110: foo" }, colog(function() return xpcall(function() coroutine.yield() error("foo") end, function(err) return err end) end)) +checkresults({ "yield", "return", false, "pcall.lua:111: foo\npcall.lua:111\npcall.lua:111\n" }, colog(function() return xpcall(function() coroutine.yield() error("foo") end, debug.traceback) end)) + +-- xpcall error handling during error handling inside xpcall after yields +checkresults({ "yield", "return", true, false, "error in error handling" }, colog(function() return xpcall(function() return xpcall(function() coroutine.yield() error("foo") end, function(err) error("bar") end) end, error) end)) + +-- xpcall + pcall + yield +checkresults({"yield", 42, "return", true, true, true}, colog(function() return xpcall(pcall, function (...) return ... end, function() return pcall(function() coroutine.yield(42) end) end) end)) + +-- xpcall error +checkresults({ false, "missing argument #2 to 'xpcall' (function expected)" }, pcall(xpcall, function() return 42 end)) +checkresults({ false, "invalid argument #2 to 'xpcall' (function expected, got boolean)" }, pcall(xpcall, function() return 42 end, true)) + +-- stack overflow during coroutine resumption +function weird() +coroutine.yield(weird) +weird() +end + +checkresults({ false, "pcall.lua:129: cannot resume dead coroutine" }, pcall(function() for _ in coroutine.wrap(pcall), weird do end end)) + +-- c++ exception +checkresults({ false, "oops" }, pcall(cxxthrow)) + +-- resumeerror +local co = coroutine.create(function() + local ok, err = pcall(function() + coroutine.yield() + end) + coroutine.yield() + return ok, err +end) + +coroutine.resume(co) +resumeerror(co, "fail") +checkresults({ true, false, "fail" }, coroutine.resume(co)) + +return'OK' diff --git a/tests/conformance/pm.lua b/tests/conformance/pm.lua new file mode 100644 index 0000000..9a11396 --- /dev/null +++ b/tests/conformance/pm.lua @@ -0,0 +1,317 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing pattern matching') + +function f(s, p) + local i,e = string.find(s, p) + if i then return string.sub(s, i, e) end +end + +function f1(s, p) + p = string.gsub(p, "%%([0-9])", function (s) return "%" .. (s+1) end) + p = string.gsub(p, "^(^?)", "%1()", 1) + p = string.gsub(p, "($?)$", "()%1", 1) + local t = {string.match(s, p)} + return string.sub(s, t[1], t[#t] - 1) +end + +a,b = string.find('', '') -- empty patterns are tricky +assert(a == 1 and b == 0); +a,b = string.find('alo', '') +assert(a == 1 and b == 0) +a,b = string.find('a\0o a\0o a\0o', 'a', 1) -- first position +assert(a == 1 and b == 1) +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the midle +assert(a == 5 and b == 7) +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the midle +assert(a == 9 and b == 11) +a,b = string.find('a\0a\0a\0a\0\0ab', '\0ab', 2); -- finds at the end +assert(a == 9 and b == 11); +a,b = string.find('a\0a\0a\0a\0\0ab', 'b') -- last position +assert(a == 11 and b == 11) +assert(string.find('a\0a\0a\0a\0\0ab', 'b\0') == nil) -- check ending +assert(string.find('', '\0') == nil) +assert(string.find('alo123alo', '12') == 4) +assert(string.find('alo123alo', '^12') == nil) + +assert(f('aloALO', '%l*') == 'alo') +assert(f('aLo_ALO', '%a*') == 'aLo') + +assert(f(" \n\r*&\n\r xuxu \n\n", "%g%g%g+") == "xuxu") + +assert(f('aaab', 'a*') == 'aaa'); +assert(f('aaa', '^.*$') == 'aaa'); +assert(f('aaa', 'b*') == ''); +assert(f('aaa', 'ab*a') == 'aa') +assert(f('aba', 'ab*a') == 'aba') +assert(f('aaab', 'a+') == 'aaa') +assert(f('aaa', '^.+$') == 'aaa') +assert(f('aaa', 'b+') == nil) +assert(f('aaa', 'ab+a') == nil) +assert(f('aba', 'ab+a') == 'aba') +assert(f('a$a', '.$') == 'a') +assert(f('a$a', '.%$') == 'a$') +assert(f('a$a', '.$.') == 'a$a') +assert(f('a$a', '$$') == nil) +assert(f('a$b', 'a$') == nil) +assert(f('a$a', '$') == '') +assert(f('', 'b*') == '') +assert(f('aaa', 'bb*') == nil) +assert(f('aaab', 'a-') == '') +assert(f('aaa', '^.-$') == 'aaa') +assert(f('aabaaabaaabaaaba', 'b.*b') == 'baaabaaabaaab') +assert(f('aabaaabaaabaaaba', 'b.-b') == 'baaab') +assert(f('alo xo', '.o$') == 'xo') +assert(f(' \n isto assim', '%S%S*') == 'isto') +assert(f(' \n isto assim', '%S*$') == 'assim') +assert(f(' \n isto assim', '[a-z]*$') == 'assim') +assert(f('um caracter ? extra', '[^%sa-z]') == '?') +assert(f('', 'a?') == '') +assert(f('', '?') == '') +assert(f('bl', '?b?l?') == 'bl') +assert(f(' bl', '?b?l?') == '') +assert(f('aa', '^aa?a?a') == 'aa') +assert(f(']]]b', '[^]]') == '') +assert(f("0alo alo", "%x*") == "0a") +assert(f("alo alo", "%C+") == "alo alo") +print('+') + +assert(f1('alo alx 123 b\0o b\0o', '(..*) %1') == "b\0o b\0o") +assert(f1('axz123= 4= 4 34', '(.+)=(.*)=%2 %1') == '3= 4= 4 3') +assert(f1('=======', '^(=*)=%1$') == '=======') +assert(string.match('==========', '^([=]*)=%1$') == nil) + +local function range (i, j) + if i <= j then + return i, range(i+1, j) + end +end + +local abc = string.char(range(0, 255)); + +assert(string.len(abc) == 256) + +function strset (p) + local res = {s=''} + string.gsub(abc, p, function (c) res.s = res.s .. c end) + return res.s +end; + +assert(string.len(strset('[\200-\210]')) == 11) + +assert(strset('[a-z]') == "abcdefghijklmnopqrstuvwxyz") +assert(strset('[a-z%d]') == strset('[%da-uu-z]')) +assert(strset('[a-]') == "-a") +assert(strset('[^%W]') == strset('[%w]')) +assert(strset('[]%%]') == '%]') +assert(strset('[a%-z]') == '-az') +assert(strset('[%^%[%-a%]%-b]') == '-[]^ab') +assert(strset('%Z') == strset('[\1-\255]')) +assert(strset('.') == strset('[\1-\255%z]')) +print('+'); + +assert(string.match("alo xyzK", "(%w+)K") == "xyz") +assert(string.match("254 K", "(%d*)K") == "") +assert(string.match("alo ", "(%w*)$") == "") +assert(string.match("alo ", "(%w+)$") == nil) +assert(string.find("(lo)", "%(") == 1) +local a, b, c, d, e = string.match("lo alo", "^(((.).).* (%w*))$") +assert(a == 'lo alo' and b == 'l' and c == '' and d == 'alo' and e == nil) +a, b, c, d = string.match('0123456789', '(.+(.?)())') +assert(a == '0123456789' and b == '' and c == 11 and d == nil) +print('+') + +assert(string.gsub('lo lo', '', 'x') == 'xlo xlo') +assert(string.gsub('alo lo ', ' +$', '') == 'alo lo') -- trim +assert(string.gsub(' alo alo ', '^%s*(.-)%s*$', '%1') == 'alo alo') -- double trim +assert(string.gsub('alo alo \n 123\n ', '%s+', ' ') == 'alo alo 123 ') +t = "ab d" +a, b = string.gsub(t, '(.)', '%1@') +assert('@'..a == string.gsub(t, '', '@') and b == 5) +a, b = string.gsub('abd', '(.)', '%0@', 2) +assert(a == 'a@b@d' and b == 2) +assert(string.gsub('alo alo', '()[al]', '%1') == '12o 56o') +assert(string.gsub("abc=xyz", "(%w*)(%p)(%w+)", "%3%2%1-%0") == + "xyz=abc-abc=xyz") +assert(string.gsub("abc", "%w", "%1%0") == "aabbcc") +assert(string.gsub("abc", "%w+", "%0%1") == "abcabc") +assert(string.gsub('', '$', '\0') == '\0') +assert(string.gsub('', '^', 'r') == 'r') +assert(string.gsub('', '$', 'r') == 'r') +print('+') + +assert(string.gsub("um (dois) tres (quatro)", "(%(%w+%))", string.upper) == + "um (DOIS) tres (QUATRO)") + +do + local function setglobal (n,v) rawset(_G, n, v) end + string.gsub("a=roberto,roberto=a", "(%w+)=(%w%w*)", setglobal) + assert(_G.a=="roberto" and _G.roberto=="a") +end + +function f(a,b) return string.gsub(a,'.',b) end +assert(string.gsub("trocar tudo em |teste|b| |beleza|al|", "|([^|]*)|([^|]*)|", f) == + "trocar tudo em bbbbb alalalalalal") + +local function dostring (s) return loadstring(s)() or "" end +assert(string.gsub("alo $a=1$ novamente $return a$", "$([^$]*)%$", dostring) == + "alo novamente 1") + +x = string.gsub("$x=string.gsub('alo', '.', string.upper)$ assim vai para $return x$", + "$([^$]*)%$", dostring) +assert(x == ' assim vai para ALO') + +t = {} +s = 'a alo jose joao' +r = string.gsub(s, '()(%w+)()', function (a,w,b) + assert(string.len(w) == b-a); + t[a] = b-a; + end) +assert(s == r and t[1] == 1 and t[3] == 3 and t[7] == 4 and t[13] == 4) + + +function isbalanced (s) + return string.find(string.gsub(s, "%b()", ""), "[()]") == nil +end + +assert(isbalanced("(9 ((8))(\0) 7) \0\0 a b ()(c)() a")) +assert(not isbalanced("(9 ((8) 7) a b (\0 c) a")) +assert(string.gsub("alo 'oi' alo", "%b''", '"') == 'alo " alo') + + +local t = {"apple", "orange", "lime"; n=0} +assert(string.gsub("x and x and x", "x", function () t.n=t.n+1; return t[t.n] end) + == "apple and orange and lime") + +t = {n=0} +string.gsub("first second word", "%w%w*", function (w) t.n=t.n+1; t[t.n] = w end) +assert(t[1] == "first" and t[2] == "second" and t[3] == "word" and t.n == 3) + +t = {n=0} +assert(string.gsub("first second word", "%w+", + function (w) t.n=t.n+1; t[t.n] = w end, 2) == "first second word") +assert(t[1] == "first" and t[2] == "second" and t[3] == nil) + +assert(not pcall(string.gsub, "alo", "(.", print)) +assert(not pcall(string.gsub, "alo", ".)", print)) +assert(not pcall(string.gsub, "alo", "(.", {})) +assert(not pcall(string.gsub, "alo", "(.)", "%2")) +assert(not pcall(string.gsub, "alo", "(%1)", "a")) +assert(not pcall(string.gsub, "alo", "(%0)", "a")) + +-- big strings +local a = string.rep('a', 300000) +assert(string.find(a, '^a*.?$')) +assert(not string.find(a, '^a*.?b$')) +assert(string.find(a, '^a-.?$')) + +-- deep nest of gsubs +function rev (s) + return string.gsub(s, "(.)(.+)", function (c,s1) return rev(s1)..c end) +end + +local x = string.rep('012345', 10) +assert(rev(rev(x)) == x) + + +-- gsub with tables +assert(string.gsub("alo alo", ".", {}) == "alo alo") +assert(string.gsub("alo alo", "(.)", {a="AA", l=""}) == "AAo AAo") +assert(string.gsub("alo alo", "(.).", {a="AA", l="K"}) == "AAo AAo") +assert(string.gsub("alo alo", "((.)(.?))", {al="AA", o=false}) == "AAo AAo") + +assert(string.gsub("alo alo", "().", {2,5,6}) == "256 alo") + +t = {}; setmetatable(t, {__index = function (t,s) return string.upper(s) end}) +assert(string.gsub("a alo b hi", "%w%w+", t) == "a ALO b HI") + + +-- tests for gmatch +-- assert(string.gfind == string.gmatch) +local a = 0 +for i in string.gmatch('abcde', '()') do assert(i == a+1); a=i end +assert(a==6) + +t = {n=0} +for w in string.gmatch("first second word", "%w+") do + t.n=t.n+1; t[t.n] = w +end +assert(t[1] == "first" and t[2] == "second" and t[3] == "word") + +t = {3, 6, 9} +for i in string.gmatch ("xuxx uu ppar r", "()(.)%2") do + assert(i == table.remove(t, 1)) +end +assert(table.getn(t) == 0) + +t = {} +for i,j in string.gmatch("13 14 10 = 11, 15= 16, 22=23", "(%d+)%s*=%s*(%d+)") do + t[i] = j +end +a = 0 +for k,v in pairs(t) do assert(k+1 == v+0); a=a+1 end +assert(a == 3) + + +-- tests for `%f' (`frontiers') + +assert(string.gsub("aaa aa a aaa a", "%f[%w]a", "x") == "xaa xa x xaa x") +assert(string.gsub("[[]] [][] [[[[", "%f[[].", "x") == "x[]] x]x] x[[[") +assert(string.gsub("01abc45de3", "%f[%d]", ".") == ".01abc.45de.3") +assert(string.gsub("01abc45 de3x", "%f[%D]%w", ".") == "01.bc45 de3.") +assert(string.gsub("function", "%f[\1-\255]%w", ".") == ".unction") +assert(string.gsub("function", "%f[^\1-\255]", ".") == "function.") + +assert(string.find("a", "%f[a]") == 1) +assert(string.find("a", "%f[^%z]") == 1) +assert(string.find("a", "%f[^%l]") == 2) +assert(string.find("aba", "%f[a%z]") == 3) +assert(string.find("aba", "%f[%z]") == 4) +assert(not string.find("aba", "%f[%l%z]")) +assert(not string.find("aba", "%f[^%l%z]")) + +local i, e = string.find(" alo aalo allo", "%f[%S].-%f[%s].-%f[%S]") +assert(i == 2 and e == 5) +local k = string.match(" alo aalo allo", "%f[%S](.-%f[%s].-%f[%S])") +assert(k == 'alo ') + +local a = {1, 5, 9, 14, 17,} +for k in string.gmatch("alo alo th02 is 1hat", "()%f[%w%d]") do + assert(table.remove(a, 1) == k) +end +assert(#a == 0) + + +-- malformed patterns +local function malform (p, m) + m = m or "malformed" + local r, msg = pcall(string.find, "a", p) + assert(not r and string.find(msg, m)) +end + +malform("(.", "unfinished capture") +malform(".)", "invalid pattern capture") +malform("[a") +malform("[]") +malform("[^]") +malform("[a%]") +malform("[a%") +malform("%b") +malform("%ba") +malform("%") +malform("%f", "missing") + +-- \0 in patterns +assert(string.match("ab\0\1\2c", "[\0-\2]+") == "\0\1\2") +assert(string.match("ab\0\1\2c", "[\0-\0]+") == "\0") +assert(string.find("b$a", "$\0?") == 2) +assert(string.find("abc\0efg", "%\0") == 4) +assert(string.match("abc\0efg\0\1e\1g", "%b\0\1") == "\0efg\0\1e\1") +assert(string.match("abc\0\0\0", "%\0+") == "\0\0\0") +assert(string.match("abc\0\0\0", "%\0%\0?") == "\0\0") + +-- magic char after \0 +assert(string.find("abc\0\0","\0.") == 4) +assert(string.find("abcx\0\0abc\0abc","x\0\0abc\0a.") == 4) + +return('OK') diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.lua new file mode 100644 index 0000000..95940e1 --- /dev/null +++ b/tests/conformance/sort.lua @@ -0,0 +1,75 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print"testing sort" + +function check (a, f) + f = f or function (x,y) return x 'alo\0alo\0') +assert('alo' < 'alo\0') +assert('alo\0' > 'alo') +assert('\0' < '\1') +assert('\0\0' < '\0\1') +assert('\1\0a\0a' <= '\1\0a\0a') +assert(not ('\1\0a\0b' <= '\1\0a\0a')) +assert('\0\0\0' < '\0\0\0\0') +assert(not('\0\0\0\0' < '\0\0\0')) +assert('\0\0\0' <= '\0\0\0\0') +assert(not('\0\0\0\0' <= '\0\0\0')) +assert('\0\0\0' <= '\0\0\0') +assert('\0\0\0' >= '\0\0\0') +assert(not ('\0\0b' < '\0\0a\0')) +print('+') + +assert(string.sub("123456789",2,4) == "234") +assert(string.sub("123456789",7) == "789") +assert(string.sub("123456789",7,6) == "") +assert(string.sub("123456789",7,7) == "7") +assert(string.sub("123456789",0,0) == "") +assert(string.sub("123456789",-10,10) == "123456789") +assert(string.sub("123456789",1,9) == "123456789") +assert(string.sub("123456789",-10,-20) == "") +assert(string.sub("123456789",-1) == "9") +assert(string.sub("123456789",-4) == "6789") +assert(string.sub("123456789",-6, -4) == "456") +assert(string.sub("\000123456789",3,5) == "234") +assert(("\000123456789"):sub(8) == "789") +print('+') + +assert(string.find("123456789", "345") == 3) +a,b = string.find("123456789", "345") +assert(string.sub("123456789", a, b) == "345") +assert(string.find("1234567890123456789", "345", 3) == 3) +assert(string.find("1234567890123456789", "345", 4) == 13) +assert(string.find("1234567890123456789", "346", 4) == nil) +assert(string.find("1234567890123456789", ".45", -9) == 13) +assert(string.find("abcdefg", "\0", 5, 1) == nil) +assert(string.find("", "") == 1) +assert(string.find('', 'aaa', 1) == nil) +assert(('alo(.)alo'):find('(.)', 1, 1) == 4) +assert(string.find('', '1', 2) == nil) +print('+') + +assert(string.len("") == 0) +assert(string.len("\0\0\0") == 3) +assert(string.len("1234567890") == 10) +assert(string.len(123) == 3) + +assert(#"" == 0) +assert(#"\0\0\0" == 3) +assert(#"1234567890" == 10) + +assert(string.byte("a") == 97) +assert(string.byte("") > 127) +assert(string.byte(string.char(255)) == 255) +assert(string.byte(string.char(0)) == 0) +assert(string.byte("\0") == 0) +assert(string.byte("\0\0alo\0x", -1) == string.byte('x')) +assert(string.byte("ba", 2) == 97) +assert(string.byte("\n\n", 2, -1) == 10) +assert(string.byte("\n\n", 2, 2) == 10) +assert(string.byte("") == nil) +assert(string.byte("hi", -3) == nil) +assert(string.byte("hi", 3) == nil) +assert(string.byte("hi", 9, 10) == nil) +assert(string.byte("hi", 2, 1) == nil) +assert(string.char() == "") +assert(string.char(0, 255, 0) == "\0\255\0") +assert(string.char(0, string.byte(""), 0) == "\0\0") +assert(string.char(string.byte("l\0u", 1, -1)) == "l\0u") +assert(string.char(string.byte("l\0u", 1, 0)) == "") +assert(string.char(string.byte("l\0u", -10, 100)) == "l\0u") +assert(pcall(function() return string.char(256) end) == false) +assert(pcall(function() return string.char(-1) end) == false) +print('+') + +assert(string.upper("ab\0c") == "AB\0C") +assert(string.lower("\0ABCc%$") == "\0abcc%$") +assert(string.rep('teste', 0) == '') +assert(string.rep('ts\00t', 2) == 'ts\0tts\000t') +assert(string.rep('', 10) == '') + +assert(string.reverse"" == "") +assert(string.reverse"\0\1\2\3" == "\3\2\1\0") +assert(string.reverse"\0001234" == "4321\0") + +for i=0,30 do assert(string.len(string.rep('a', i)) == i) end + +assert(type(tostring(nil)) == 'string') +assert(type(tostring(12)) == 'string') +assert(''..12 == '12' and type(12 .. '') == 'string') +assert(string.find(tostring{}, 'table:')) +assert(string.find(tostring(print), 'function:')) +assert(tostring(1234567890123) == '1234567890123') +assert(#tostring('\0') == 1) +assert(tostring(true) == "true") +assert(tostring(false) == "false") +print('+') + +x = '"lo"\n\\' +assert(string.format('%q%s', x, x) == '"\\"lo\\"\\\n\\\\""lo"\n\\') +assert(string.format('%q', "\0") == [["\000"]]) +assert(string.format('%q', "\r") == [["\r"]]) +assert(string.format("\0%c\0%c%x\0", string.byte(""), string.byte("b"), 140) == + "\0\0b8c\0") +assert(string.format('') == "") +assert(string.format("%c",34)..string.format("%c",48)..string.format("%c",90)..string.format("%c",100) == + string.format("%c%c%c%c", 34, 48, 90, 100)) +assert(string.format("%s\0 is not \0%s", 'not be', 'be') == 'not be\0 is not \0be') +assert(string.format("%%%d %010d", 10, 23) == "%10 0000000023") +assert(tonumber(string.format("%f", 10.3)) == 10.3) +x = string.format('"%-50s"', 'a') +assert(#x == 52) +assert(string.sub(x, 1, 4) == '"a ') + +assert(string.format("-%.20s.20s", string.rep("%", 2000)) == "-"..string.rep("%", 20)..".20s") +assert(string.format('"-%20s.20s"', string.rep("%", 2000)) == + string.format("%q", "-"..string.rep("%", 2000)..".20s")) + + +-- longest number that can be formated +assert(string.len(string.format('%99.99f', -1e308)) >= 100) + +assert(loadstring("return 1\n--comentrio sem EOL no final")() == 1) + + +assert(table.concat{} == "") +assert(table.concat({}, 'x') == "") +assert(table.concat({'\0', '\0\1', '\0\1\2'}, '.\0.') == "\0.\0.\0\1.\0.\0\1\2") +local a = {}; for i=1,3000 do a[i] = "xuxu" end +assert(table.concat(a, "123").."123" == string.rep("xuxu123", 3000)) +assert(table.concat(a, "b", 20, 20) == "xuxu") +assert(table.concat(a, "", 20, 21) == "xuxuxuxu") +assert(table.concat(a, "", 22, 21) == "") +assert(table.concat(a, "3", 2999) == "xuxu3xuxu") + +a = {"a","b","c"} +assert(table.concat(a, ",", 1, 0) == "") +assert(table.concat(a, ",", 1, 1) == "a") +assert(table.concat(a, ",", 1, 2) == "a,b") +assert(table.concat(a, ",", 2) == "b,c") +assert(table.concat(a, ",", 3) == "c") +assert(table.concat(a, ",", 4) == "") + +--[[ +local locales = { "ptb", "ISO-8859-1", "pt_BR" } +local function trylocale (w) + for _, l in ipairs(locales) do + if os.setlocale(l, w) then return true end + end + return false +end + +if not trylocale("collate") then + print("locale not supported") +else + assert("alo" < "lo" and "lo" < "amo") +end + +if not trylocale("ctype") then + print("locale not supported") +else + assert(string.gsub("", "%a", "x") == "xxxxx") + assert(string.gsub("", "%l", "x") == "xx") + assert(string.gsub("", "%u", "x") == "xx") + assert(string.upper"{xuxu}o" == "{XUXU}O") +end + +os.setlocale("C") +assert(os.setlocale() == 'C') +assert(os.setlocale(nil, "numeric") == 'C') +]]-- + +return('OK') + + diff --git a/tests/conformance/tpack.lua b/tests/conformance/tpack.lua new file mode 100644 index 0000000..835bf56 --- /dev/null +++ b/tests/conformance/tpack.lua @@ -0,0 +1,327 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +local pack = string.pack +local packsize = string.packsize +local unpack = string.unpack + +print "testing pack/unpack" + +-- maximum size for integers +local NB = 16 + +local sizeshort = packsize("h") +local sizeint = packsize("i") +local sizelong = packsize("l") +local sizesize_t = packsize("T") +local sizeLI = packsize("j") +local sizeMI = packsize("l") +local sizefloat = packsize("f") +local sizedouble = packsize("d") +local sizenumber = packsize("n") +local little = (pack("i2", 1) == "\1\0") +local align = packsize("!xXi16") + +assert(1 <= sizeshort and sizeshort <= sizeint and sizeint <= sizelong and + sizefloat <= sizedouble) + +print("platform:") +print(string.format( + "\tshort %d, int %d, long %d, size_t %d, float %d, double %d,\n\z + \tlua Integer %d, lua Number %d", + sizeshort, sizeint, sizelong, sizesize_t, sizefloat, sizedouble, + sizeLI, sizenumber)) +print("\t" .. (little and "little" or "big") .. " endian") +print("\talignment: " .. align) + + +-- check errors in arguments +function checkerror (msg, f, ...) + local status, err = pcall(f, ...) + -- print(status, err, msg) + assert(not status and string.find(err, msg)) +end + +-- minimum behavior for integer formats +assert(unpack("B", pack("B", 0xff)) == 0xff) +assert(unpack("b", pack("b", 0x7f)) == 0x7f) +assert(unpack("b", pack("b", -0x80)) == -0x80) + +assert(unpack("H", pack("H", 0xffff)) == 0xffff) +assert(unpack("h", pack("h", 0x7fff)) == 0x7fff) +assert(unpack("h", pack("h", -0x8000)) == -0x8000) + +assert(unpack("L", pack("L", 0xffffffff)) == 0xffffffff) +assert(unpack("l", pack("l", 0x7fffffff)) == 0x7fffffff) +assert(unpack("l", pack("l", -0x80000000)) == -0x80000000) + +assert(unpack("J", pack("J", 0xffffffff)) == 0xffffffff) +assert(unpack("j", pack("j", 0x7fffffff)) == 0x7fffffff) +assert(unpack("j", pack("j", -0x80000000)) == -0x80000000) + +for i = 1, NB do + -- small numbers with signal extension ("\xFF...") + local s = string.rep("\xff", i) + assert(pack("i" .. i, -1) == s) + assert(packsize("i" .. i) == #s) + assert(unpack("i" .. i, s) == -1) + + -- small unsigned number ("\0...\xAA") + s = "\xAA" .. string.rep("\0", i - 1) + assert(pack("I" .. i, 0xAA) == s:reverse()) + assert(unpack(">I" .. i, s:reverse()) == 0xAA) +end + +do + local lnum = 0x060504030201 -- 48-bit + local s = pack("i" .. i, ("\xFF"):rep(i - sizeMI) .. s:reverse()) == -lnum) + assert(unpack("i" .. i, "\1" .. ("\x00"):rep(i - 1)) + end +end + +for i = 1, 4 do + local lstr = "\1\2\3\4" + local lnum = 0x04030201 + local n = bit32.band(lnum, bit32.bnot(bit32.lshift(-1, i * 8))) + local s = string.sub(lstr, 1, i) + assert(pack("i" .. i, n) == s:reverse()) + assert(unpack(">i" .. i, s:reverse()) == n) +end + +-- sign extension +do + local u = 0xf0 + for i = 1, sizeLI - 1 do + assert(unpack("I"..i, "\xf0"..("\xff"):rep(i - 1)) == u) + u = u * 256 + 0xff + end +end + +-- mixed endianness +do + assert(pack(">i2 i2", "\10\0\0\20") + assert(a == 10 and b == 20) + assert(pack("=i4", 2001) == pack("i4", 2001)) +end + +print("testing invalid formats") + +checkerror("out of limits", pack, "i0", 0) +checkerror("out of limits", pack, "i" .. NB + 1, 0) +checkerror("out of limits", pack, "!" .. NB + 1, 0) +checkerror("%(17%) out of limits %[1,16%]", pack, "Xi" .. NB + 1) +checkerror("invalid format option 'r'", pack, "i3r", 0) +checkerror("16%-byte integer", unpack, "i16", string.rep('\3', 16)) +checkerror("not power of 2", pack, "!4i3", 0); +checkerror("missing size", pack, "c", "") +checkerror("variable%-length format", packsize, "s") +checkerror("variable%-length format", packsize, "z") + +if packsize("i") == 4 then + -- result would be 2^31 (2^3 repetitions of 2^28 strings) + local s = string.rep("c268435456", 2^3) + checkerror("too large", packsize, s) + -- one less would be OK... except that we limit string sizes to 1GB + s = string.rep("c268435456", 2^3 - 1) .. "c268435453" + checkerror("too large", packsize, s) + -- but 1GB is reachable + assert(packsize("c1073741824") == 2^30) +end + +-- overflow in packing +for i = 1, 3 do + local umax = bit32.lshift(1, i * 8) - 1 + local max = bit32.rshift(umax, 1) + local min = -max-1 + checkerror("overflow", pack, "I" .. i, umax + 1) + + checkerror("overflow", pack, ">i" .. i, umax) + checkerror("overflow", pack, ">i" .. i, max + 1) + checkerror("overflow", pack, "i" .. i, pack(">i" .. i, max)) == max) + assert(unpack("I" .. i, pack(">I" .. i, umax)) == umax) +end + +if little then + assert(pack("f", 24) == pack("f", 24)) +end + +print "testing pack/unpack of floating-point numbers" + +for _, n in ipairs{0, -1.1, 1.9, 1/0, -1/0, 1e20, -1e20, 0.1, 2000.7} do + assert(unpack("n", pack("n", n)) == n) + assert(unpack("n", pack(">n", n)) == n) + assert(pack("f", n):reverse()) + assert(pack(">d", n) == pack("f", pack(">f", n)) == n) + assert(unpack("d", pack(">d", n)) == n) +end + +print "testing pack/unpack of strings" +do + local s = string.rep("abc", 1000) + assert(pack("zB", s, 247) == s .. "\0\xF7") + local s1, b = unpack("zB", s .. "\0\xF9") + assert(b == 249 and s1 == s) + s1 = pack("s", s) + assert(unpack("s", s1) == s) + + checkerror("does not fit", pack, "s1", s) + + checkerror("contains zeros", pack, "z", "alo\0"); + + checkerror("unfinished string", unpack, "zc10000000", "alo") + + for i = 2, NB do + local s1 = pack("s" .. i, s) + assert(unpack("s" .. i, s1) == s and #s1 == #s + i) + end +end + +do + local x = pack("s", "alo") + checkerror("too short", unpack, "s", x:sub(1, -2)) + checkerror("too short", unpack, "c5", "abcd") + checkerror("out of limits", pack, "s100", "alo") +end + +do + assert(pack("c0", "") == "") + assert(packsize("c0") == 0) + assert(unpack("c0", "") == "") + assert(pack("!4 c6", "abcdef") == "abcdef") + assert(pack("c3", "123") == "123") + assert(pack("c0", "") == "") + assert(pack("c8", "123456") == "123456\0\0") + assert(pack("c88", "") == string.rep("\0", 88)) + assert(pack("c188", "ab") == "ab" .. string.rep("\0", 188 - 2)) + local a, b, c = unpack("!4 z c3", "abcdefghi\0xyz") + assert(a == "abcdefghi" and b == "xyz" and c == 14) + checkerror("longer than", pack, "c3", "1234") +end + + +-- testing multiple types and sequence +do + local x = pack("!8 b Xh i4 i8 c1 Xi8", -12, 100, 200, "\xEC") + assert(#x == packsize(">!8 b Xh i4 i8 c1 Xi8")) + assert(x == "\xf4" .. "\0\0\0" .. + "\0\0\0\100" .. + "\0\0\0\0\0\0\0\xC8" .. + "\xEC" .. "\0\0\0\0\0\0\0") + local a, b, c, d, pos = unpack(">!8 c1 Xh i4 i8 b Xi8 XI XH", x) + assert(a == "\xF4" and b == 100 and c == 200 and d == -20 and (pos - 1) == #x) + + x = pack(">!4 c3 c4 c2 z i4 c5 c2 Xi4", + "abc", "abcd", "xz", "hello", 5, "world", "xy") + assert(x == "abcabcdxzhello\0\0\0\0\0\5worldxy\0") + local a, b, c, d, e, f, g, pos = unpack(">!4 c3 c4 c2 z i4 c5 c2 Xh Xi4", x) + assert(a == "abc" and b == "abcd" and c == "xz" and d == "hello" and + e == 5 and f == "world" and g == "xy" and (pos - 1) % 4 == 0) + + x = pack(" b b Xd b Xb x", 1, 2, 3) + assert(packsize(" b b Xd b Xb x") == 4) + assert(x == "\1\2\3\0") + a, b, c, pos = unpack("bbXdb", x) + assert(a == 1 and b == 2 and c == 3 and pos == #x) + + -- only alignment + assert(packsize("!8 xXi8") == 8) + local pos = unpack("!8 xXi8", "0123456701234567"); assert(pos == 9) + assert(packsize("!8 xXi2") == 2) + local pos = unpack("!8 xXi2", "0123456701234567"); assert(pos == 3) + assert(packsize("!2 xXi2") == 2) + local pos = unpack("!2 xXi2", "0123456701234567"); assert(pos == 3) + assert(packsize("!2 xXi8") == 2) + local pos = unpack("!2 xXi8", "0123456701234567"); assert(pos == 3) + assert(packsize("!16 xXi16") == 16) + local pos = unpack("!16 xXi16", "0123456701234567"); assert(pos == 17) + + checkerror("invalid next option", pack, "X") + checkerror("invalid next option", unpack, "XXi", "") + checkerror("invalid next option", unpack, "X i", "") + checkerror("invalid next option", pack, "Xc1") +end + +do -- testing initial position + local x = pack("i4i4i4i4", 1, 2, 3, 4) + for pos = 1, 16, 4 do + local i, p = unpack("i4", x, pos) + assert(i == math.floor(pos/4) + 1 and p == pos + 4) + end + + -- with alignment + for pos = 0, 12 do -- will always round position to power of 2 + local i, p = unpack("!4 i4", x, pos + 1) + assert(i == math.floor((pos + 3)/4) + 1 and p == i*4 + 1) + end + + -- negative indices + local i, p = unpack("!4 i4", x, -4) + assert(i == 4 and p == 17) + local i, p = unpack("!4 i4", x, -7) + assert(i == 4 and p == 17) + local i, p = unpack("!4 i4", x, -#x) + assert(i == 1 and p == 5) + + -- limits + for i = 1, #x + 1 do + assert(unpack("c0", x, i) == "") + end + checkerror("out of string", unpack, "c0", x, #x + 2) + +end + +do -- testing out of range values + checkerror("out of limits", unpack, "i17", "") + checkerror("out of limits", unpack, "i987654321", "") + checkerror("too large", unpack, "i9876543210", "") + checkerror("too large", unpack, "c9876543210", "") + checkerror("too large", packsize, "c1" .. string.rep("0", 40)) + checkerror("missing size", unpack, "c-2", "") +end + +return "OK" diff --git a/tests/conformance/types.lua b/tests/conformance/types.lua new file mode 100644 index 0000000..cdddcee --- /dev/null +++ b/tests/conformance/types.lua @@ -0,0 +1,56 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print "testing builtin types" + +local ignore = +{ + -- these are permanently ignored, as they are only exposed in tests + "_G.limitedstack", + "_G.RTTI", + "_G.collectgarbage", + + -- what follows is a set of mismatches that hopefully eventually will go down to 0 + "_G.require", -- need to move to Roblox type defs + "_G.utf8.nfcnormalize", -- need to move to Roblox type defs + "_G.utf8.nfdnormalize", -- need to move to Roblox type defs + "_G.utf8.graphemes", -- need to move to Roblox type defs +} + +function verify(real, rtti, path) + if table.find(ignore, path) then + return + end + + if real and rtti then + if type(real) == "table" then + assert(type(rtti) == "table", path .. " is not a table in type information") + + local keys = {} + + for k, v in pairs(real) do + keys[k] = 1 + end + + for k, v in pairs(rtti) do + keys[k] = 1 + end + + for k, v in pairs(keys) do + if k ~= "_G" then + verify(real[k], rtti[k], path .. '.' .. k) + end + end + else + assert(type(real) == rtti, path .. " has inconsistent types (" .. type(real) .. " vs " .. rtti .. ")") + end + else + if not rtti then + assert(false, path .. " missing from type information") + else + assert(false, path .. " present in type information but absent from VM") + end + end +end + +verify(getmetatable(_G).__index, RTTI, "_G") + +return 'OK' diff --git a/tests/conformance/utf8.lua b/tests/conformance/utf8.lua new file mode 100644 index 0000000..024cb16 --- /dev/null +++ b/tests/conformance/utf8.lua @@ -0,0 +1,208 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print "testing UTF-8 library" + +local function checkerror (msg, f, ...) + local s, err = pcall(f, ...) + assert(not s and string.find(err, msg)) +end + + +local function len (s) + return #string.gsub(s, "[\x80-\xBF]", "") +end + + +local justone = "^" .. utf8.charpattern .. "$" + +assert(not utf8.offset("alo", 5)) +assert(not utf8.offset("alo", -4)) + +-- 'check' makes several tests over the validity of string 's'. +-- 't' is the list of codepoints of 's'. +local function check (s, t, nonstrict) + local l = utf8.len(s, 1, -1, nonstrict) + assert(#t == l and len(s) == l) + assert(utf8.char(table.unpack(t)) == s) -- 't' and 's' are equivalent + + assert(utf8.offset(s, 0) == 1) + + -- creates new table with all codepoints of 's' + local t1 = {utf8.codepoint(s, 1, -1, nonstrict)} + assert(#t == #t1) + for i = 1, #t do assert(t[i] == t1[i]) end -- 't' is equal to 't1' + + for i = 1, l do -- for all codepoints + local pi = utf8.offset(s, i) -- position of i-th char + local pi1 = utf8.offset(s, 2, pi) -- position of next char + assert(string.find(string.sub(s, pi, pi1 - 1), justone)) + assert(utf8.offset(s, -1, pi1) == pi) + assert(utf8.offset(s, i - l - 1) == pi) + assert(pi1 - pi == #utf8.char(utf8.codepoint(s, pi, pi, nonstrict))) + for j = pi, pi1 - 1 do + assert(utf8.offset(s, 0, j) == pi) + end + for j = pi + 1, pi1 - 1 do + assert(not utf8.len(s, j)) + end + assert(utf8.len(s, pi, pi, nonstrict) == 1) + assert(utf8.len(s, pi, pi1 - 1, nonstrict) == 1) + assert(utf8.len(s, pi, -1, nonstrict) == l - i + 1) + assert(utf8.len(s, pi1, -1, nonstrict) == l - i) + assert(utf8.len(s, 1, pi, nonstrict) == i) + end + + local i = 0 + for p, c in utf8.codes(s, nonstrict) do + i = i + 1 + assert(c == t[i] and p == utf8.offset(s, i)) + assert(utf8.codepoint(s, p, p, nonstrict) == c) + end + assert(i == #t) + + i = 0 + for c in string.gmatch(s, utf8.charpattern) do + i = i + 1 + assert(c == utf8.char(t[i])) + end + assert(i == #t) + + for i = 1, l do + assert(utf8.offset(s, i) == utf8.offset(s, i - l - 1, #s + 1)) + end + +end + + +do -- error indication in utf8.len + local function check (s, p) + local a, b = utf8.len(s) + assert(not a and b == p) + end + check("abc\xE3def", 4) + check("汉字\x80", #("汉字") + 1) + check("\xF4\x9F\xBF", 1) + check("\xF4\x9F\xBF\xBF", 1) +end + +-- errors in utf8.codes +do + local function errorcodes (s) + checkerror("invalid UTF%-8 code", + function () + for c in utf8.codes(s) do assert(c) end + end) + end + errorcodes("ab\xff") + -- errorcodes("\u{110000}") +end + +-- error in initial position for offset +checkerror("position out of range", utf8.offset, "abc", 1, 5) +checkerror("position out of range", utf8.offset, "abc", 1, -4) +checkerror("position out of range", utf8.offset, "", 1, 2) +checkerror("position out of range", utf8.offset, "", 1, -1) +checkerror("continuation byte", utf8.offset, "𦧺", 1, 2) +checkerror("continuation byte", utf8.offset, "𦧺", 1, 2) +checkerror("continuation byte", utf8.offset, "\x80", 1) + +-- error in indices for len +checkerror("out of string", utf8.len, "abc", 0, 2) +checkerror("out of string", utf8.len, "abc", 1, 4) + + +local s = "hello World" +local t = {string.byte(s, 1, -1)} +for i = 1, utf8.len(s) do assert(t[i] == string.byte(s, i)) end +check(s, t) + +check("汉字/漢字", {27721, 23383, 47, 28450, 23383,}) + +do + local s = "áéí\128" + local t = {utf8.codepoint(s,1,#s - 1)} + assert(#t == 3 and t[1] == 225 and t[2] == 233 and t[3] == 237) + checkerror("invalid UTF%-8 code", utf8.codepoint, s, 1, #s) + checkerror("out of range", utf8.codepoint, s, #s + 1) + t = {utf8.codepoint(s, 4, 3)} + assert(#t == 0) + checkerror("out of range", utf8.codepoint, s, -(#s + 1), 1) + checkerror("out of range", utf8.codepoint, s, 1, #s + 1) + -- surrogates + assert(utf8.codepoint("\u{D7FF}") == 0xD800 - 1) + assert(utf8.codepoint("\u{E000}") == 0xDFFF + 1) + assert(utf8.codepoint("\u{D800}", 1, 1, true) == 0xD800) + assert(utf8.codepoint("\u{DFFF}", 1, 1, true) == 0xDFFF) + -- assert(utf8.codepoint("\u{7FFFFFFF}", 1, 1, true) == 0x7FFFFFFF) +end + +assert(utf8.char() == "") +assert(utf8.char(0, 97, 98, 99, 1) == "\0abc\1") + +assert(utf8.codepoint(utf8.char(0x10FFFF)) == 0x10FFFF) +-- assert(utf8.codepoint(utf8.char(0x7FFFFFFF), 1, 1, true) == 2147483647) + +checkerror("value out of range", utf8.char, 0x7FFFFFFF + 1) +checkerror("value out of range", utf8.char, -1) + +local function invalid (s) + checkerror("invalid UTF%-8 code", utf8.codepoint, s) + assert(not utf8.len(s)) +end + +-- UTF-8 representation for 0x11ffff (value out of valid range) +invalid("\xF4\x9F\xBF\xBF") + +-- surrogates +-- invalid("\u{D800}") +-- invalid("\u{DFFF}") + +-- overlong sequences +invalid("\xC0\x80") -- zero +invalid("\xC1\xBF") -- 0x7F (should be coded in 1 byte) +invalid("\xE0\x9F\xBF") -- 0x7FF (should be coded in 2 bytes) +invalid("\xF0\x8F\xBF\xBF") -- 0xFFFF (should be coded in 3 bytes) + + +-- invalid bytes +invalid("\x80") -- continuation byte +invalid("\xBF") -- continuation byte +invalid("\xFE") -- invalid byte +invalid("\xFF") -- invalid byte + + +-- empty string +check("", {}) + +-- minimum and maximum values for each sequence size +s = "\0 \x7F\z + \xC2\x80 \xDF\xBF\z + \xE0\xA0\x80 \xEF\xBF\xBF\z + \xF0\x90\x80\x80 \xF4\x8F\xBF\xBF" +s = string.gsub(s, " ", "") +check(s, {0,0x7F, 0x80,0x7FF, 0x800,0xFFFF, 0x10000,0x10FFFF}) + +x = "日本語a-4\0éó" +check(x, {26085, 26412, 35486, 97, 45, 52, 0, 233, 243}) + + +-- Supplementary Characters +check("𣲷𠜎𠱓𡁻𠵼ab𠺢", + {0x23CB7, 0x2070E, 0x20C53, 0x2107B, 0x20D7C, 0x61, 0x62, 0x20EA2,}) + +check("𨳊𩶘𦧺𨳒𥄫𤓓\xF4\x8F\xBF\xBF", + {0x28CCA, 0x29D98, 0x269FA, 0x28CD2, 0x2512B, 0x244D3, 0x10ffff}) + + +local i = 0 +for p, c in string.gmatch(x, "()(" .. utf8.charpattern .. ")") do + i = i + 1 + assert(utf8.offset(x, i) == p) + assert(utf8.len(x, p) == utf8.len(x) - i + 1) + assert(utf8.len(c) == 1) + for j = 1, #c - 1 do + assert(utf8.offset(x, 0, p + j - 1) == p) + end +end + +return'OK' diff --git a/tests/conformance/vararg.lua b/tests/conformance/vararg.lua new file mode 100644 index 0000000..5aaa422 --- /dev/null +++ b/tests/conformance/vararg.lua @@ -0,0 +1,137 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print('testing vararg') + +local unpack = table.unpack + +_G.arg = nil + +function f(a, ...) + local arg = { n=select('#',...), ... } + if a then + assert(#arg == #a) + else + assert(#arg == 0) + end + for i=1,#arg do assert(a[i]==arg[i]) end + return arg.n +end + +function c12 (...) + local x = {...}; x.n = table.getn(x) + local res = (x.n==2 and x[1] == 1 and x[2] == 2) + if res then res = 55 end + return res, 2 +end + +function vararg (...) return { n=select('#',...), ... } end + +local call = function (f, args) return f(unpack(args, 1, args.n)) end + +assert(f() == 0) +assert(f({1,2,3}, 1, 2, 3) == 3) +assert(f({"alo", nil, 45, f, nil}, "alo", nil, 45, f, nil) == 5) + +assert(c12(1,2)==55) +a,b = assert(call(c12, {1,2})) +assert(a == 55 and b == 2) +a = call(c12, {1,2;n=2}) +assert(a == 55 and b == 2) +a = call(c12, {1,2;n=1}) +assert(not a) +assert(c12(1,2,3) == false) +local G = {foo=1,bar=2,foobar=3} +local a = vararg(call(next, {G,nil;n=2})) +local b,c = next(G) +assert(a[1] == b and a[2] == c and a.n == 2) +a = vararg(call(call, {c12, {1,2}})) +assert(a.n == 2 and a[1] == 55 and a[2] == 2) +a = call(print, {'+'}) +assert(a == nil) + +local t = {1, 10} +function t:f (...) + local arg = { n=select('#',...), ... } + return self[arg[1]]+arg.n +end +assert(t:f(1,4) == 3 and t:f(2) == 11) +print('+') + +lim = 20 +local i, a = 1, {} +while i <= lim do a[i] = i+0.3; i=i+1 end + +function f(a, b, c, d, ...) + local more = {...} + assert(a == 1.3 and more[1] == 5.3 and + more[lim-4] == lim+0.3 and not more[lim-3]) +end + +function g(a,b,c) + assert(a == 1.3 and b == 2.3 and c == 3.3) +end + +call(f, a) +call(g, a) + +a = {} +i = 1 +while i <= lim do a[i] = i; i=i+1 end +assert(call(math.max, a) == lim) + +print("+") + + +-- new-style varargs + +function oneless (a, ...) return ... end + +function f (n, a, ...) + local b + assert(arg == nil) + if n == 0 then + local b, c, d = ... + return a, b, c, d, oneless(oneless(oneless(...))) + else + n, b, a = n-1, ..., a + assert(b == ...) + return f(n, a, ...) + end +end + +a,b,c,d,e = assert(f(10,5,4,3,2,1)) +assert(a==5 and b==4 and c==3 and d==2 and e==1) + +a,b,c,d,e = f(4) +assert(a==nil and b==nil and c==nil and d==nil and e==nil) + + +-- varargs for main chunks +f = loadstring[[ return {...} ]] +x = f(2,3) +assert(x[1] == 2 and x[2] == 3 and x[3] == nil) + + +f = loadstring[[ + local x = {...} + for i=1,select('#', ...) do assert(x[i] == select(i, ...)) end + assert(x[select('#', ...)+1] == nil) + return true +]] + +assert(f("a", "b", nil, {}, assert)) +assert(f()) + +a = {select(3, unpack{10,20,30,40})} +assert(table.getn(a) == 2 and a[1] == 30 and a[2] == 40) +a = {select(1)} +assert(next(a) == nil) +a = {select(-1, 3, 5, 7)} +assert(a[1] == 7 and a[2] == nil) +a = {select(-2, 3, 5, 7)} +assert(a[1] == 5 and a[2] == 7 and a[3] == nil) +pcall(select, 10000) +pcall(select, -10000) + +return('OK') + diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua new file mode 100644 index 0000000..620f646 --- /dev/null +++ b/tests/conformance/vector.lua @@ -0,0 +1,74 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print('testing vectors') + +-- equality +assert(vector(1, 2, 3) == vector(1, 2, 3)) +assert(vector(0, 1, 2) == vector(-0, 1, 2)) +assert(vector(1, 2, 3) ~= vector(1, 2, 4)) + +-- rawequal +assert(rawequal(vector(1, 2, 3), vector(1, 2, 3))) +assert(rawequal(vector(0, 1, 2), vector(-0, 1, 2))) +assert(not rawequal(vector(1, 2, 3), vector(1, 2, 4))) + +-- type & tostring +assert(type(vector(1, 2, 3)) == "vector") +assert(tostring(vector(1, 2, 3)) == "1, 2, 3") +assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") + +local t = {} + +-- basic table access +t[vector(1, 2, 3)] = 42 +assert(t[vector(1, 2, 3)] == 42) +assert(t[vector(1, 2, 4)] == nil) + +-- negative zero should hash the same as zero +assert(t[vector(0, 0, 0)] == nil) +t[vector(0, 0, 0)] = "hello" +assert(t[vector(0, 0, 0)] == "hello") +assert(t[vector(0, -0, 0)] == "hello") + +-- test arithmetic instructions +assert(vector(1, 2, 4) + vector(8, 16, 24) == vector(9, 18, 28)); +assert(vector(1, 2, 4) - vector(8, 16, 24) == vector(-7, -14, -20)); + +local val = 1/'8' + +assert(vector(1, 2, 4) * vector(8, 16, 24) == vector(8, 32, 96)); +assert(vector(1, 2, 4) * 8 == vector(8, 16, 32)); +assert(vector(1, 2, 4) * (1 / val) == vector(8, 16, 32)); +assert(8 * vector(8, 16, 24) == vector(64, 128, 192)); +assert(vector(1, 2, 4) * '8' == vector(8, 16, 32)); +assert('8' * vector(8, 16, 24) == vector(64, 128, 192)); + +assert(vector(1, 2, 4) / vector(8, 16, 24) == vector(1/8, 2/16, 4/24)); +assert(vector(1, 2, 4) / 8 == vector(1/8, 1/4, 1/2)); +assert(vector(1, 2, 4) / (1 / val) == vector(1/8, 2/8, 4/8)); +assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); +assert(vector(1, 2, 4) / '8' == vector(1/8, 1/4, 1/2)); +assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); + +assert(-vector(1, 2, 4) == vector(-1, -2, -4)); + +-- test NaN comparison +local nanv = vector(0/0, 0/0, 0/0) +assert(nanv ~= nanv); + +-- __index +assert(vector(1, 2, 2).Magnitude == 3) +assert(vector(0, 0, 0)['Dot'](vector(1, 2, 4), vector(5, 6, 7)) == 45) + +-- __namecall +assert(vector(1, 2, 4):Dot(vector(5, 6, 7)) == 45) + +-- can't use vector with NaN components as table key +assert(pcall(function() local t = {} t[vector(0/0, 2, 3)] = 1 end) == false) +assert(pcall(function() local t = {} t[vector(1, 0/0, 3)] = 1 end) == false) +assert(pcall(function() local t = {} t[vector(1, 2, 0/0)] = 1 end) == false) +assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == false) + +-- make sure we cover both builtin and C impl +assert(vector(1, 2, 4) == vector("1", "2", "4")) + +return 'OK' diff --git a/tests/main.cpp b/tests/main.cpp new file mode 100644 index 0000000..ed17070 --- /dev/null +++ b/tests/main.cpp @@ -0,0 +1,267 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" + +#define DOCTEST_CONFIG_IMPLEMENT +#include "doctest.h" + +#ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include // IsDebuggerPresent +#endif + +#ifdef __APPLE__ +#include +#include +#endif + +#include + +static bool skipFastFlag(const char* flagName) +{ + if (strncmp(flagName, "Test", 4) == 0) + return true; + + if (strncmp(flagName, "Debug", 5) == 0) + return true; + + return false; +} + +static bool debuggerPresent() +{ +#if defined(_WIN32) + return 0 != IsDebuggerPresent(); +#elif defined(__APPLE__) + // ask sysctl information about a specific process ID + int mib[4] = {CTL_KERN, KERN_PROC, KERN_PROC_PID, getpid()}; + kinfo_proc info = {}; + size_t size = sizeof(info); + int ret = sysctl(mib, sizeof(mib) / sizeof(*mib), &info, &size, nullptr, 0); + // debugger is attached if the P_TRACED flag is set + return ret == 0 && (info.kp_proc.p_flag & P_TRACED) != 0; +#else + return false; // assume no debugger is attached. +#endif +} + +static int assertionHandler(const char* expr, const char* file, int line) +{ + if (debuggerPresent()) + LUAU_DEBUGBREAK(); + + ADD_FAIL_AT(file, line, "Assertion failed: ", expr); + return 1; +} + +struct BoostLikeReporter : doctest::IReporter +{ + const doctest::TestCaseData* currentTest = nullptr; + + BoostLikeReporter(const doctest::ContextOptions& in) {} + + // called when a query should be reported (listing test cases, printing the version, etc.) + void report_query(const doctest::QueryData& qd) override + { + for (unsigned int i = 0; i < qd.num_data; ++i) + { + const doctest::TestCaseData& tc = *qd.data[i]; + + fprintf(stderr, "%s/%s\n", tc.m_test_suite, tc.m_name); + } + + fprintf(stderr, "Found %d tests.\n", int(qd.num_data)); + } + + // called when the whole test run starts/ends + void test_run_start() override {} + + void test_run_end(const doctest::TestRunStats& ts) override {} + + // called when a test case is started (safe to cache a pointer to the input) + void test_case_start(const doctest::TestCaseData& tc) override + { + currentTest = &tc; + + printf("Entering test suite \"%s\"\n", tc.m_test_suite); + printf("Entering test case \"%s\"\n", tc.m_name); + } + + // called when a test case has ended + void test_case_end(const doctest::CurrentTestCaseStats& tc) override + { + LUAU_ASSERT(currentTest); + + printf("Leaving test case \"%s\"\n", currentTest->m_name); + printf("Leaving test suite \"%s\"\n", currentTest->m_test_suite); + + currentTest = nullptr; + } + + // called when an exception is thrown from the test case (or it crashes) + void test_case_exception(const doctest::TestCaseException& e) override + { + LUAU_ASSERT(currentTest); + + printf("%s(%d): FATAL: Unhandled exception %s\n", currentTest->m_file.c_str(), currentTest->m_line, e.error_string.c_str()); + } + + // called whenever a subcase is entered/exited (noop) + void test_case_reenter(const doctest::TestCaseData&) override {} + void subcase_start(const doctest::SubcaseSignature&) override {} + void subcase_end() override {} + + void log_assert(const doctest::AssertData& ad) override + { + if (!ad.m_failed) + return; + + if (ad.m_decomp.size()) + printf("%s(%d): ERROR: %s (%s)\n", ad.m_file, ad.m_line, ad.m_expr, ad.m_decomp.c_str()); + else + printf("%s(%d): ERROR: %s\n", ad.m_file, ad.m_line, ad.m_expr); + } + + void log_message(const doctest::MessageData& md) override + { // + printf("%s(%d): ERROR: %s\n", md.m_file, md.m_line, md.m_string.c_str()); + } + + // called when a test case is skipped either because it doesn't pass the filters, has a skip decorator + // or isn't in the execution range (between first and last) (safe to cache a pointer to the input) + void test_case_skipped(const doctest::TestCaseData&) override {} +}; + +template +using FValueResult = std::pair; + +static FValueResult> parseFValueHelper(std::string_view view) +{ + size_t equalpos = view.find('='); + if (equalpos == std::string_view::npos) + return {std::string{view}, std::nullopt}; + + std::string name{view.substr(0, equalpos)}; + view.remove_prefix(equalpos + 1); + return {name, std::string{view}}; +} + +static FValueResult parseFInt(std::string_view view) +{ + // If this was an FInt but there were no provided value, we should be noisy about that. + // std::stoi already throws an exception on invalid conversion. + auto [name, value] = parseFValueHelper(view); + if (!value) + throw std::runtime_error("Expected a value associated with " + name); + + return {name, std::stoi(*value)}; +} + +static FValueResult parseFFlag(std::string_view view) +{ + // If we have a flag name but there's no provided value, we default to true. + auto [name, value] = parseFValueHelper(view); + bool state = value ? *value == "true" : true; + if (value && value != "true" && value != "false") + std::cerr << "Ignored '" << name << "' because '" << *value << "' is not a valid FFlag state." << std::endl; + + return {name, state}; +} + +template +static void setFastValue(const std::string& name, T value) +{ + for (Luau::FValue* fvalue = Luau::FValue::list; fvalue; fvalue = fvalue->next) + if (fvalue->name == name) + fvalue->value = value; +} + +static void setFastFlags(const std::vector& flags) +{ + for (const doctest::String& flag : flags) + { + std::string_view view = flag.c_str(); + if (view == "true" || view == "false") + { + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (!skipFastFlag(flag->name)) + flag->value = view == "true"; + } + + continue; + } + + if (view.size() >= 2 && view[0] == 'D' && view[1] == 'F') + view.remove_prefix(1); + + if (view.substr(0, 4) == "FInt") + { + auto [name, value] = parseFInt(view.substr(4)); + setFastValue(name, value); + } + else + { + // We want to prevent the footgun where '--fflags=LuauSomeFlag' is ignored. We'll assume that this was declared as FFlag. + auto [name, value] = parseFFlag(view.substr(0, 5) == "FFlag" ? view.substr(5) : view); + setFastValue(name, value); + } + } +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + doctest::registerReporter("boost", 0, true); + + doctest::Context context; + context.setOption("no-version", true); + context.applyCommandLine(argc, argv); + + if (doctest::parseFlag(argc, argv, "--list-fflags")) + { + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + { + if (skipFastFlag(flag->name)) + continue; + + if (flag->dynamic) + std::cout << 'D'; + std::cout << "FFlag" << flag->name << std::endl; + } + + return 0; + } + + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) + setFastFlags(flags); + + if (doctest::parseFlag(argc, argv, "--list_content")) + { + const char* ltc[] = {argv[0], "--list-test-cases"}; + context.applyCommandLine(2, ltc); + } + + doctest::String filter; + if (doctest::parseOption(argc, argv, "--run_test", &filter) && filter[0] == '=') + { + const char* f = filter.c_str() + 1; + const char* s = strchr(f, '/'); + + if (s) + { + context.addFilter("test-suite", std::string(f, s).c_str()); + context.addFilter("test-case", s + 1); + } + else + { + context.addFilter("test-suite", f); + } + } + + return context.run(); +} + + diff --git a/tools/gdb-printers.py b/tools/gdb-printers.py new file mode 100644 index 0000000..c711c5e --- /dev/null +++ b/tools/gdb-printers.py @@ -0,0 +1,19 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +class VariantPrinter: + def __init__(self, val): + self.val = val + + def to_string(self): + typeId = int(self.val['typeId']) + type = self.val.type.template_argument(typeId) + value = self.val['storage'].reinterpret_cast(type.pointer()).dereference() + return type.name + " [" + str(value) + "]" + +def match_printer(val): + type = val.type.strip_typedefs() + if type.name and type.name.startswith('Luau::Variant<'): + return VariantPrinter(val) + return None + +gdb.pretty_printers.append(match_printer) diff --git a/tools/heapgraph.py b/tools/heapgraph.py new file mode 100644 index 0000000..106db54 --- /dev/null +++ b/tools/heapgraph.py @@ -0,0 +1,182 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given two heap snapshots (A & B), this tool performs reachability analysis on new objects allocated in B +# This is useful to find memory leaks - reachability analysis answers the question "why is this set of objects not freed" +# This tool can also be ran with just one snapshot, in which case it displays all allocated objects +# The result of analysis is a .svg file which can be viewed in a browser +# To generate these dumps, use luaC_dump, ideally preceded by luaC_fullgc + +import json +import sys +import svg + +class Node(svg.Node): + def __init__(self): + svg.Node.__init__(self) + self.size = 0 + self.count = 0 + # data for memory category filtering + self.objects = [] + self.categories = set() + + def text(self): + return self.name + + def title(self): + return self.name + + def details(self, root): + return "{} ({:,} bytes, {:.1%}); self: {:,} bytes in {:,} objects".format(self.name, self.width, self.width / root.width, self.size, self.count) + +# load files +if len(sys.argv) == 2: + dumpold = None + with open(sys.argv[1]) as f: + dump = json.load(f) +else: + with open(sys.argv[1]) as f: + dumpold = json.load(f) + with open(sys.argv[2]) as f: + dump = json.load(f) + +# reachability analysis: how much of the heap is reachable from roots? +visited = set() +queue = [] +offset = 0 +root = Node() + +for name, addr in dump["roots"].items(): + queue.append((addr, root.child(name))) + +while offset < len(queue): + addr, node = queue[offset] + offset += 1 + if addr in visited: + continue + + visited.add(addr) + obj = dump["objects"][addr] + + if not dumpold or not addr in dumpold["objects"]: + node.count += 1 + node.size += obj["size"] + node.objects.append(obj) + + if obj["type"] == "table": + pairs = obj.get("pairs", []) + + for i in range(0, len(pairs), 2): + key = pairs[i+0] + val = pairs[i+1] + if key and val and dump["objects"][key]["type"] == "string": + queue.append((key, node)) + queue.append((val, node.child(dump["objects"][key]["data"]))) + else: + if key: + queue.append((key, node)) + if val: + queue.append((val, node)) + + for a in obj.get("array", []): + queue.append((a, node)) + if "metatable" in obj: + queue.append((obj["metatable"], node.child("__meta"))) + elif obj["type"] == "function": + queue.append((obj["env"], node.child("__env"))) + + source = "" + if "proto" in obj: + proto = dump["objects"][obj["proto"]] + if "source" in proto: + source = proto["source"] + + if "proto" in obj: + queue.append((obj["proto"], node.child("__proto"))) + for a in obj.get("upvalues", []): + queue.append((a, node.child(source))) + elif obj["type"] == "userdata": + if "metatable" in obj: + queue.append((obj["metatable"], node.child("__meta"))) + elif obj["type"] == "thread": + queue.append((obj["env"], node.child("__env"))) + for a in obj.get("stack", []): + queue.append((a, node.child("__stack"))) + elif obj["type"] == "proto": + for a in obj.get("constants", []): + queue.append((a, node)) + for a in obj.get("protos", []): + queue.append((a, node)) + elif obj["type"] == "upvalue": + if "object" in obj: + queue.append((obj["object"], node)) + +def annotateContainedCategories(node): + for obj in node.objects: + node.categories.add(obj["cat"]) + + for child in node.children.values(): + annotateContainedCategories(child) + + for cat in child.categories: + node.categories.add(cat) + +def filteredTreeForCategory(node, category): + children = {} + + for c in node.children.values(): + if category in c.categories: + filtered = filteredTreeForCategory(c, category) + + if filtered: + children[filtered.name] = filtered + + if len(children): + result = Node() + result.name = node.name + + # re-count the objects with the correct category that we have + for obj in node.objects: + if obj["cat"] == category: + result.count += 1 + result.size += obj["size"] + + result.children = children + return result + else: + result = Node() + result.name = node.name + + # re-count the objects with the correct category that we have + for obj in node.objects: + if obj["cat"] == category: + result.count += 1 + result.size += obj["size"] + + if result.count != 0: + return result + + return None + +def splitIntoCategories(root): + result = Node() + + for i in range(0, 256): + filtered = filteredTreeForCategory(root, i) + + if filtered: + name = dump["stats"]["categories"][str(i)]["name"] + + filtered.name = name + result.children[name] = filtered + + return result + +# temporarily disabled because it makes FG harder to read, maybe this should be a separate command line option? +if dump["stats"].get("categories") and False: + annotateContainedCategories(root) + + root = splitIntoCategories(root) + +svg.layout(root, lambda n: n.size) +svg.display(root, "Memory Graph", "cold") diff --git a/tools/heapstat.py b/tools/heapstat.py new file mode 100644 index 0000000..838521b --- /dev/null +++ b/tools/heapstat.py @@ -0,0 +1,59 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a heap snapshot, this tool gathers basic statistics about the allocated objects +# To generate a snapshot, use luaC_dump, ideally preceded by luaC_fullgc + +import json +import sys +from collections import defaultdict + +def updatesize(d, k, s): + oc, os = d.get(k, (0, 0)) + d[k] = (oc + 1, os + s) + +def sortedsize(p): + return sorted(p, key = lambda s: s[1][1], reverse = True) + +with open(sys.argv[1]) as f: + dump = json.load(f) + heap = dump["objects"] + +type_addr = next((addr for addr,obj in heap.items() if obj["type"] == "string" and obj["data"] == "__type"), None) + +size_type = {} +size_udata = {} +size_category = {} + +for addr, obj in heap.items(): + updatesize(size_type, obj["type"], obj["size"]) + + if obj.get("cat") != None: + updatesize(size_category, str(obj["cat"]), obj["size"]) + + if obj["type"] == "userdata" and "metatable" in obj: + metatable = heap[obj["metatable"]] + pairs = metatable.get("pairs", []) + typemt = "unknown" + for i in range(0, len(pairs), 2): + if type_addr and pairs[i] == type_addr and pairs[i + 1] and heap[pairs[i + 1]]["type"] == "string": + typemt = heap[pairs[i + 1]]["data"] + updatesize(size_udata, typemt, obj["size"]) + +print("objects by type:") +for type, (count, size) in sortedsize(size_type.items()): + print(type.ljust(10), str(size).rjust(8), "bytes", str(count).rjust(5), "objects") + +print() + +print("userdata by __type:") +for type, (count, size) in sortedsize(size_udata.items()): + print(type.ljust(20), str(size).rjust(8), "bytes", str(count).rjust(5), "objects") + +if len(size_category) != 0: + print() + + print("objects by category:") + for type, (count, size) in sortedsize(size_category.items()): + name = dump["stats"]["categories"][type]["name"] + print(name.ljust(30), str(size).rjust(8), "bytes", str(count).rjust(5), "objects") diff --git a/tools/perfgraph.py b/tools/perfgraph.py new file mode 100644 index 0000000..95baef9 --- /dev/null +++ b/tools/perfgraph.py @@ -0,0 +1,52 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a profile dump, this tool generates a flame graph based on the stacks listed in the profile +# The result of analysis is a .svg file which can be viewed in a browser + +import sys +import svg + +class Node(svg.Node): + def __init__(self): + svg.Node.__init__(self) + self.function = "" + self.source = "" + self.line = 0 + self.ticks = 0 + + def text(self): + return self.function + + def title(self): + if self.line > 0: + return "{}\n{}:{}".format(self.function, self.source, self.line) + else: + return self.function + + def details(self, root): + return "Function: {} [{}:{}] ({:,} usec, {:.1%}); self: {:,} usec".format(self.function, self.source, self.line, self.width, self.width / root.width, self.ticks) + +with open(sys.argv[1]) as f: + dump = f.readlines() + +root = Node() + +for l in dump: + ticks, stack = l.strip().split(" ", 1) + node = root + + for f in reversed(stack.split(";")): + source, function, line = f.split(",") + + child = node.child(f) + child.function = function + child.source = source + child.line = int(line) if len(line) > 0 else 0 + + node = child + + node.ticks += int(ticks) + +svg.layout(root, lambda n: n.ticks) +svg.display(root, "Flame Graph", "hot", flip = True) diff --git a/tools/svg.py b/tools/svg.py new file mode 100644 index 0000000..3b3bb28 --- /dev/null +++ b/tools/svg.py @@ -0,0 +1,498 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +class Node: + def __init__(self): + self.name = "" + self.children = {} + # computed + self.depth = 0 + self.width = 0 + self.offset = 0 + + def child(self, name): + node = self.children.get(name) + if not node: + node = self.__class__() + node.name = name + self.children[name] = node + return node + + def subtree(self): + result = [self] + offset = 0 + + while offset < len(result): + p = result[offset] + offset += 1 + for c in p.children.values(): + result.append(c) + + return result + +def escape(s): + return s.replace("&", "&").replace("<", "<").replace(">", ">") + +def layout(root, widthcb): + for n in reversed(root.subtree()): + # propagate width to the parent + n.width = widthcb(n) + for c in n.children.values(): + n.width += c.width + + # compute offset from parent for every child in width order (layout order) + offset = 0 + for c in sorted(n.children.values(), key = lambda x: x.width, reverse = True): + c.offset = offset + offset += c.width + + for n in root.subtree(): + for c in n.children.values(): + c.depth = n.depth + 1 + c.offset += n.offset + +# svg template (stolen from framegraph.pl) +template = r""" + + + + + + + + + + + + +$title +Reset Zoom +Search +ic + + + +""" + +def namehash(s): + # FNV-1a + hval = 0x811c9dc5 + for ch in s: + hval = hval ^ ord(ch) + hval = hval * 0x01000193 + hval = hval % (2 ** 32) + return (hval % 31337) / 31337.0 + +def display(root, title, colors, flip = False): + if colors == "cold": + gradient_start = "#eef2ee" + gradient_end = "#e0ffe0" + else: + gradient_start = "#eeeeee" + gradient_end = "#eeeeb0" + + maxdepth = 0 + for n in root.subtree(): + maxdepth = max(maxdepth, n.depth) + + svgheight = maxdepth * 16 + 3 * 16 + 2 * 16 + + print(template + .replace("$title", title) + .replace("$gradient-start", gradient_start) + .replace("$gradient-end", gradient_end) + .replace("$height", str(svgheight)) + .replace("$status", str(svgheight - 16 + 3)) + .replace("$flip", str(int(flip))) + ) + + framewidth = 1200 - 20 + + for n in root.subtree(): + if n.width / root.width * framewidth < 0.1: + continue + + x = 10 + n.offset / root.width * framewidth + y = (maxdepth - 1 - n.depth if flip else n.depth) * 16 + 3 * 16 + width = n.width / root.width * framewidth + height = 15 + + if colors == "cold": + fillr = 0 + fillg = int(190 + 50 * namehash(n.name)) + fillb = int(210 * namehash(n.name[::-1])) + else: + fillr = int(205 + 50 * namehash(n.name)) + fillg = int(230 * namehash(n.name[::-1])) + fillb = int(55 * namehash(n.name[::-2])) + + fill = "rgb({},{},{})".format(fillr, fillg, fillb) + chars = width / (12 * 0.59) + + text = n.text() + + if chars >= 3: + if chars < len(text): + text = text[:int(chars-2)] + ".." + else: + text = "" + + print("") + print("{}".format(escape(n.title()))) + print("
{}
".format(escape(n.details(root)))) + print("".format(x, y, width, height, fill)) + print("{}".format(x + 3, y + 10.5, escape(text))) + print("{}".format(escape(n.text()))) + print("
") + + print("
\n
\n")