diff --git a/Cargo.toml b/Cargo.toml
index 8ea5526..10d73e4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "luau0-src"
-version = "0.3.0+luau526"
+version = "0.3.1+luau529"
 authors = ["Aleksandr Orlenko <zxteam@pm.me>"]
 edition = "2018"
 repository = "https://github.com/khvzak/luau-src-rs"
diff --git a/luau/Ast/include/Luau/TimeTrace.h b/luau/Ast/include/Luau/TimeTrace.h
index 9f7b2bd..be28282 100644
--- a/luau/Ast/include/Luau/TimeTrace.h
+++ b/luau/Ast/include/Luau/TimeTrace.h
@@ -1,7 +1,7 @@
 // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
 #pragma once
 
-#include "Common.h"
+#include "Luau/Common.h"
 
 #include <stdint.h>
diff --git a/luau/Ast/src/Parser.cpp b/luau/Ast/src/Parser.cpp
index 91f5cd2..eaf1991 100644
--- a/luau/Ast/src/Parser.cpp
+++ b/luau/Ast/src/Parser.cpp
@@ -10,7 +10,8 @@
 // See docs/SyntaxChanges.md for an explanation.
 LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
 LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
-LUAU_FASTFLAGVARIABLE(LuauParseLocationIgnoreCommentSkipInCapture, false)
+
+LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false)
 
 namespace Luau
 {
@@ -1590,6 +1591,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
     {
         return parseFunctionTypeAnnotation(allowPack);
     }
+    else if (FFlag::LuauParserFunctionKeywordAsTypeHelp && lexer.current().type == Lexeme::ReservedFunction)
+    {
+        Location location = lexer.current().location;
+
+        nextLexeme();
+
+        return {reportTypeAnnotationError(location, {}, /*isMissing*/ false,
+                    "Using 'function' as a type annotation is not supported, consider replacing with a function type annotation e.g. '(...any) -> "
+                    "...any'"),
+            {}};
+    }
     else
     {
         Location location = lexer.current().location;
@@ -2821,7 +2833,7 @@ void Parser::nextLexeme()
                 hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)});
             }
 
-            type = lexer.next(/* skipComments= */ false, !FFlag::LuauParseLocationIgnoreCommentSkipInCapture).type;
+            type = lexer.next(/* skipComments= */ false, /* updatePrevLocation= */ false).type;
         }
     }
     else
diff --git a/luau/Compiler/include/Luau/Bytecode.h b/luau/Common/include/Luau/Bytecode.h
similarity index 100%
rename from luau/Compiler/include/Luau/Bytecode.h
rename to luau/Common/include/Luau/Bytecode.h
diff --git a/luau/Ast/include/Luau/Common.h b/luau/Common/include/Luau/Common.h
similarity index 100%
rename from luau/Ast/include/Luau/Common.h
rename to luau/Common/include/Luau/Common.h
diff --git a/luau/Compiler/include/Luau/BytecodeBuilder.h b/luau/Compiler/include/Luau/BytecodeBuilder.h
index b00440a..dbe5429 100644
--- a/luau/Compiler/include/Luau/BytecodeBuilder.h
+++ b/luau/Compiler/include/Luau/BytecodeBuilder.h
@@ -224,6 +224,7 @@ private:
     DenseHashMap<ConstantKey, int32_t, ConstantKeyHash> constantMap;
     DenseHashMap<TableShape, int32_t, TableShapeHash> tableShapeMap;
+    DenseHashMap<uint32_t, int16_t> protoMap;
 
     int debugLine = 0;
@@ -246,7 +247,7 @@ private:
     void validate() const;
 
     std::string dumpCurrentFunction() const;
-    const uint32_t* dumpInstruction(const uint32_t* opcode, std::string& output) const;
+    void dumpInstruction(const uint32_t* opcode, std::string& output, int targetLabel) const;
 
     void writeFunction(std::string& ss, uint32_t id) const;
     void writeLineInfo(std::string& ss) const;
diff --git a/luau/Compiler/src/BytecodeBuilder.cpp b/luau/Compiler/src/BytecodeBuilder.cpp
index fb70392..3aa12d9 100644
--- a/luau/Compiler/src/BytecodeBuilder.cpp
+++ b/luau/Compiler/src/BytecodeBuilder.cpp
@@ -6,6 +6,8 @@
 #include <algorithm>
 #include <string.h>
 
+LUAU_FASTFLAG(LuauCompileNestedClosureO2)
+
 namespace Luau
 {
@@ -128,6 +130,20 @@ inline bool isSkipC(LuauOpcode op)
     }
 }
 
+static int getJumpTarget(uint32_t insn, uint32_t pc)
+{
+    LuauOpcode op = LuauOpcode(LUAU_INSN_OP(insn));
+
+    if (isJumpD(op))
+        return int(pc + LUAU_INSN_D(insn) + 1);
+    else if (isSkipC(op) && LUAU_INSN_C(insn))
+        return int(pc + LUAU_INSN_C(insn) + 1);
+    else if (op == LOP_JUMPX)
+        return int(pc + LUAU_INSN_E(insn) + 1);
+    else
+        return -1;
+}
+
 bool BytecodeBuilder::StringRef::operator==(const StringRef& other) const
 {
     return (data && other.data) ? (length == other.length && memcmp(data, other.data, length) == 0) : (data == other.data);
@@ -181,6 +197,7 @@ size_t BytecodeBuilder::TableShapeHash::operator()(const TableShape& v) const
 BytecodeBuilder::BytecodeBuilder(BytecodeEncoder* encoder)
     : constantMap({Constant::Type_Nil, ~0ull})
     , tableShapeMap(TableShape())
+    , protoMap(~0u)
     , stringTable({nullptr, 0})
     , encoder(encoder)
 {
@@ -250,6 +267,7 @@ void BytecodeBuilder::endFunction(uint8_t maxstacksize, uint8_t numupvalues)
 
     constantMap.clear();
     tableShapeMap.clear();
+    protoMap.clear();
 
     debugRemarks.clear();
     debugRemarkBuffer.clear();
@@ -372,11 +390,17 @@ int32_t BytecodeBuilder::addConstantClosure(uint32_t fid)
 
 int16_t BytecodeBuilder::addChildFunction(uint32_t fid)
 {
+    if (FFlag::LuauCompileNestedClosureO2)
+        if (int16_t* cache = protoMap.find(fid))
+            return *cache;
+
     uint32_t id = uint32_t(protos.size());
 
     if (id >= kMaxClosureCount)
         return -1;
 
+    if (FFlag::LuauCompileNestedClosureO2)
+        protoMap[fid] = int16_t(id);
     protos.push_back(fid);
 
     return int16_t(id);
@@ -1398,7 +1422,7 @@ void BytecodeBuilder::validate() const
 }
 #endif
 
-const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result) const
+void BytecodeBuilder::dumpInstruction(const uint32_t* code, std::string& result, int targetLabel) const
 {
     uint32_t insn = *code++;
@@ -1493,39 +1517,39 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
         break;
 
     case LOP_JUMP:
-        formatAppend(result, "JUMP %+d\n", LUAU_INSN_D(insn));
+        formatAppend(result, "JUMP L%d\n", targetLabel);
         break;
 
     case LOP_JUMPIF:
-        formatAppend(result, "JUMPIF R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIF R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_JUMPIFNOT:
-        formatAppend(result, "JUMPIFNOT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFNOT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
        break;
 
     case LOP_JUMPIFEQ:
-        formatAppend(result, "JUMPIFEQ R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFEQ R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     case LOP_JUMPIFLE:
-        formatAppend(result, "JUMPIFLE R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFLE R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     case LOP_JUMPIFLT:
-        formatAppend(result, "JUMPIFLT R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFLT R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     case LOP_JUMPIFNOTEQ:
-        formatAppend(result, "JUMPIFNOTEQ R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFNOTEQ R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     case LOP_JUMPIFNOTLE:
-        formatAppend(result, "JUMPIFNOTLE R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFNOTLE R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     case LOP_JUMPIFNOTLT:
-        formatAppend(result, "JUMPIFNOTLT R%d R%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFNOTLT R%d R%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     case LOP_ADD:
@@ -1621,35 +1645,35 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
         break;
 
     case LOP_FORNPREP:
-        formatAppend(result, "FORNPREP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "FORNPREP R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_FORNLOOP:
-        formatAppend(result, "FORNLOOP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "FORNLOOP R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_FORGPREP:
-        formatAppend(result, "FORGPREP R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "FORGPREP R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_FORGLOOP:
-        formatAppend(result, "FORGLOOP R%d %+d %d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn), *code++);
+        formatAppend(result, "FORGLOOP R%d L%d %d\n", LUAU_INSN_A(insn), targetLabel, *code++);
         break;
 
     case LOP_FORGPREP_INEXT:
-        formatAppend(result, "FORGPREP_INEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "FORGPREP_INEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_FORGLOOP_INEXT:
-        formatAppend(result, "FORGLOOP_INEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "FORGLOOP_INEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_FORGPREP_NEXT:
-        formatAppend(result, "FORGPREP_NEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "FORGPREP_NEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_FORGLOOP_NEXT:
-        formatAppend(result, "FORGLOOP_NEXT R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_D(insn));
+        formatAppend(result, "FORGLOOP_NEXT R%d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_GETVARARGS:
@@ -1665,7 +1689,7 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
         break;
 
     case LOP_JUMPBACK:
-        formatAppend(result, "JUMPBACK %+d\n", LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPBACK L%d\n", targetLabel);
         break;
 
     case LOP_LOADKX:
@@ -1673,26 +1697,26 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
         break;
 
     case LOP_JUMPX:
-        formatAppend(result, "JUMPX %+d\n", LUAU_INSN_E(insn));
+        formatAppend(result, "JUMPX L%d\n", targetLabel);
         break;
 
     case LOP_FASTCALL:
-        formatAppend(result, "FASTCALL %d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_C(insn));
+        formatAppend(result, "FASTCALL %d L%d\n", LUAU_INSN_A(insn), targetLabel);
         break;
 
     case LOP_FASTCALL1:
-        formatAppend(result, "FASTCALL1 %d R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), LUAU_INSN_C(insn));
+        formatAppend(result, "FASTCALL1 %d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), targetLabel);
         break;
 
     case LOP_FASTCALL2:
     {
         uint32_t aux = *code++;
-        formatAppend(result, "FASTCALL2 %d R%d R%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, LUAU_INSN_C(insn));
+        formatAppend(result, "FASTCALL2 %d R%d R%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, targetLabel);
         break;
     }
 
     case LOP_FASTCALL2K:
     {
         uint32_t aux = *code++;
-        formatAppend(result, "FASTCALL2K %d R%d K%d %+d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, LUAU_INSN_C(insn));
+        formatAppend(result, "FASTCALL2K %d R%d K%d L%d\n", LUAU_INSN_A(insn), LUAU_INSN_B(insn), aux, targetLabel);
         break;
     }
 
@@ -1702,23 +1726,24 @@ const uint32_t* BytecodeBuilder::dumpInstruction(const uint32_t* code, std::stri
 
     case LOP_CAPTURE:
         formatAppend(result, "CAPTURE %s %c%d\n",
-            LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL" : LUAU_INSN_A(insn) == LCT_REF ? "REF" : LUAU_INSN_A(insn) == LCT_VAL ? "VAL" : "",
+            LUAU_INSN_A(insn) == LCT_UPVAL ? "UPVAL"
+            : LUAU_INSN_A(insn) == LCT_REF ? "REF"
+            : LUAU_INSN_A(insn) == LCT_VAL ? "VAL"
+                                           : "",
             LUAU_INSN_A(insn) == LCT_UPVAL ? 'U' : 'R', LUAU_INSN_B(insn));
         break;
 
     case LOP_JUMPIFEQK:
-        formatAppend(result, "JUMPIFEQK R%d K%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     case LOP_JUMPIFNOTEQK:
-        formatAppend(result, "JUMPIFNOTEQK R%d K%d %+d\n", LUAU_INSN_A(insn), *code++, LUAU_INSN_D(insn));
+        formatAppend(result, "JUMPIFNOTEQK R%d K%d L%d\n", LUAU_INSN_A(insn), *code++, targetLabel);
         break;
 
     default:
         LUAU_ASSERT(!"Unsupported opcode");
     }
-
-    return code;
 }
 
 std::string BytecodeBuilder::dumpCurrentFunction() const
@@ -1726,9 +1751,6 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
     if ((dumpFlags & Dump_Code) == 0)
         return std::string();
 
-    const uint32_t* code = insns.data();
-    const uint32_t* codeEnd = insns.data() + insns.size();
-
     int lastLine = -1;
     size_t nextRemark = 0;
@@ -1750,21 +1772,45 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
         }
     }
 
-    while (code != codeEnd)
+    std::vector<int> labels(insns.size(), -1);
+
+    // annotate valid jump targets with 0
+    for (size_t i = 0; i < insns.size();)
     {
+        int target = getJumpTarget(insns[i], uint32_t(i));
+
+        if (target >= 0)
+        {
+            LUAU_ASSERT(size_t(target) < insns.size());
+            labels[target] = 0;
+        }
+
+        i += getOpLength(LuauOpcode(LUAU_INSN_OP(insns[i])));
+        LUAU_ASSERT(i <= insns.size());
+    }
+
+    int nextLabel = 0;
+
+    // compute label ids (sequential integers for all jump targets)
+    for (size_t i = 0; i < labels.size(); ++i)
+        if (labels[i] == 0)
+            labels[i] = nextLabel++;
+
+    for (size_t i = 0; i < insns.size();)
+    {
+        const uint32_t* code = &insns[i];
         uint8_t op = LUAU_INSN_OP(*code);
-        uint32_t pc = uint32_t(code - insns.data());
 
         if (op == LOP_PREPVARARGS)
        {
             // Don't emit function header in bytecode - it's used for call dispatching and doesn't contain "interesting" information
-            code++;
+            i++;
             continue;
         }
 
         if (dumpFlags & Dump_Remarks)
         {
-            while (nextRemark < debugRemarks.size() && debugRemarks[nextRemark].first == pc)
+            while (nextRemark < debugRemarks.size() && debugRemarks[nextRemark].first == i)
             {
                 formatAppend(result, "REMARK %s\n", debugRemarkBuffer.c_str() + debugRemarks[nextRemark].second);
                 nextRemark++;
@@ -1773,7 +1819,7 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
 
         if (dumpFlags & Dump_Source)
         {
-            int line = lines[pc];
+            int line = lines[i];
 
             if (line > 0 && line != lastLine)
             {
@@ -1784,11 +1830,17 @@ std::string BytecodeBuilder::dumpCurrentFunction() const
         }
 
         if (dumpFlags & Dump_Lines)
-        {
-            formatAppend(result, "%d: ", lines[pc]);
-        }
+            formatAppend(result, "%d: ", lines[i]);
 
-        code = dumpInstruction(code, result);
+        if (labels[i] != -1)
+            formatAppend(result, "L%d: ", labels[i]);
+
+        int target = getJumpTarget(*code, uint32_t(i));
+
+        dumpInstruction(code, result, target >= 0 ? labels[target] : -1);
+
+        i += getOpLength(LuauOpcode(op));
+        LUAU_ASSERT(i <= insns.size());
     }
 
     return result;
diff --git a/luau/Compiler/src/Compiler.cpp b/luau/Compiler/src/Compiler.cpp
index 4fe2622..eea56c6 100644
--- a/luau/Compiler/src/Compiler.cpp
+++ b/luau/Compiler/src/Compiler.cpp
@@ -15,12 +15,8 @@
 #include <algorithm>
 #include <bitset>
 #include <math.h>
-#include <limits.h>
-
-LUAU_FASTFLAGVARIABLE(LuauCompileSupportInlining, false)
 
 LUAU_FASTFLAGVARIABLE(LuauCompileIter, false)
-LUAU_FASTFLAGVARIABLE(LuauCompileIterNoReserve, false)
 LUAU_FASTFLAGVARIABLE(LuauCompileIterNoPairs, false)
 
 LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25)
@@ -30,6 +26,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25)
 LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
 LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
 
+LUAU_FASTFLAGVARIABLE(LuauCompileNestedClosureO2, false)
+
 namespace Luau
 {
@@ -100,13 +98,11 @@ struct Compiler
         upvals.reserve(16);
     }
 
-    uint8_t getLocal(AstLocal* local)
+    int getLocalReg(AstLocal* local)
     {
         Local* l = locals.find(local);
 
-        LUAU_ASSERT(l);
-        LUAU_ASSERT(l->allocated);
-
-        return l->reg;
+        return l && l->allocated ? l->reg : -1;
     }
 
     uint8_t getUpval(AstLocal* local)
@@ -159,41 +155,38 @@ struct Compiler
     AstExprFunction* getFunctionExpr(AstExpr* node)
     {
-        if (AstExprLocal* le = node->as<AstExprLocal>())
+        if (AstExprLocal* expr = node->as<AstExprLocal>())
         {
-            Variable* lv = variables.find(le->local);
+            Variable* lv = variables.find(expr->local);
 
             if (!lv || lv->written || !lv->init)
                 return nullptr;
 
             return getFunctionExpr(lv->init);
         }
-        else if (AstExprGroup* ge = node->as<AstExprGroup>())
-            return getFunctionExpr(ge->expr);
+        else if (AstExprGroup* expr = node->as<AstExprGroup>())
+            return getFunctionExpr(expr->expr);
+        else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
+            return getFunctionExpr(expr->expr);
         else
             return node->as<AstExprFunction>();
     }
 
     bool canInlineFunctionBody(AstStat* stat)
     {
+        if (FFlag::LuauCompileNestedClosureO2)
+            return true; // TODO: remove this function
+
         struct CanInlineVisitor : AstVisitor
         {
             bool result = true;
 
-            bool visit(AstExpr* node) override
+            bool visit(AstExprFunction* node) override
             {
-                // nested functions may capture function arguments, and our upval handling doesn't handle elided variables (constant)
-                // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues
-                // TODO: additionally we would need to change upvalue handling in compileExprFunction to handle upvalue->local migration
-                result = result && !node->is<AstExprFunction>();
-                return result;
-            }
+                result = false;
 
-            bool visit(AstStat* node) override
-            {
-                // loops may need to be unrolled which can result in cost amplification
-                result = result && !node->is<AstStatFor>();
-                return result;
+                // short-circuit to avoid analyzing nested closure bodies
+                return false;
             }
         };
@@ -275,8 +268,7 @@ struct Compiler
         f.upvals = upvals;
 
         // record information for inlining
-        if (FFlag::LuauCompileSupportInlining && options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) &&
-            !getfenvUsed && !setfenvUsed)
+        if (options.optimizationLevel >= 2 && !func->vararg && canInlineFunctionBody(func->body) && !getfenvUsed && !setfenvUsed)
         {
             f.canInline = true;
             f.stackSize = stackSize;
@@ -346,8 +338,8 @@ struct Compiler
 
         uint8_t argreg;
 
-        if (isExprLocalReg(arg))
-            argreg = getLocal(arg->as<AstExprLocal>()->local);
+        if (int reg = getExprLocalReg(arg); reg >= 0)
+            argreg = uint8_t(reg);
         else
         {
             argreg = uint8_t(regs + 1);
@@ -403,8 +395,8 @@ struct Compiler
             }
         }
 
-        if (isExprLocalReg(expr->args.data[i]))
-            args[i] = getLocal(expr->args.data[i]->as<AstExprLocal>()->local);
+        if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0)
+            args[i] = uint8_t(reg);
         else
         {
             args[i] = uint8_t(regs + 1 + i);
@@ -489,19 +481,19 @@ struct Compiler
             return false;
         }
 
-        // TODO: we can compile functions with mismatching arity at call site but it's more annoying
-        if (func->args.size != expr->args.size)
-        {
-            bytecode.addDebugRemark("inlining failed: argument count mismatch (expected %d, got %d)", int(func->args.size), int(expr->args.size));
-            return false;
-        }
-
-        // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining
+        // compute constant bitvector for all arguments to feed the cost model
         bool varc[8] = {};
-        for (size_t i = 0; i < expr->args.size && i < 8; ++i)
+        for (size_t i = 0; i < func->args.size && i < expr->args.size && i < 8; ++i)
             varc[i] = isConstant(expr->args.data[i]);
 
-        int inlinedCost = computeCost(fi->costModel, varc, std::min(int(expr->args.size), 8));
+        // if the last argument only returns a single value, all following arguments are nil
+        if (expr->args.size != 0 &&
+            !(expr->args.data[expr->args.size - 1]->is<AstExprCall>() || expr->args.data[expr->args.size - 1]->is<AstExprVarargs>()))
+            for (size_t i = expr->args.size; i < func->args.size && i < 8; ++i)
+                varc[i] = true;
+
+        // we use a dynamic cost threshold that's based on the fixed limit boosted by the cost advantage we gain due to inlining
+        int inlinedCost = computeCost(fi->costModel, varc, std::min(int(func->args.size), 8));
         int baselineCost = computeCost(fi->costModel, nullptr, 0) + 3;
         int inlineProfit = (inlinedCost == 0) ? thresholdMaxBoost : std::min(thresholdMaxBoost, 100 * baselineCost / inlinedCost);
@@ -533,15 +525,44 @@ struct Compiler
         for (size_t i = 0; i < func->args.size; ++i)
         {
             AstLocal* var = func->args.data[i];
-            AstExpr* arg = expr->args.data[i];
+            AstExpr* arg = i < expr->args.size ? expr->args.data[i] : nullptr;
 
-            if (Variable* vv = variables.find(var); vv && vv->written)
+            if (i + 1 == expr->args.size && func->args.size > expr->args.size && (arg->is<AstExprCall>() || arg->is<AstExprVarargs>()))
+            {
+                // if the last argument can return multiple values, we need to compute all of them into the remaining arguments
+                unsigned int tail = unsigned(func->args.size - expr->args.size) + 1;
+                uint8_t reg = allocReg(arg, tail);
+
+                if (AstExprCall* expr = arg->as<AstExprCall>())
+                    compileExprCall(expr, reg, tail, /* targetTop= */ true);
+                else if (AstExprVarargs* expr = arg->as<AstExprVarargs>())
+                    compileExprVarargs(expr, reg, tail);
+                else
+                    LUAU_ASSERT(!"Unexpected expression type");
+
+                for (size_t j = i; j < func->args.size; ++j)
+                    pushLocal(func->args.data[j], uint8_t(reg + (j - i)));
+
+                // all remaining function arguments have been allocated and assigned to
+                break;
+            }
+            else if (Variable* vv = variables.find(var); vv && vv->written)
             {
                 // if the argument is mutated, we need to allocate a fresh register even if it's a constant
                 uint8_t reg = allocReg(arg, 1);
-                compileExprTemp(arg, reg);
+
+                if (arg)
+                    compileExprTemp(arg, reg);
+                else
+                    bytecode.emitABC(LOP_LOADNIL, reg, 0, 0);
+
                 pushLocal(var, reg);
             }
+            else if (arg == nullptr)
+            {
+                // since the argument is not mutated, we can simply fold the value into the expressions that need it
+                locstants[var] = {Constant::Type_Nil};
+            }
             else if (const Constant* cv = constants.find(arg); cv && cv->type != Constant::Type_Unknown)
             {
                 // since the argument is not mutated, we can simply fold the value into the expressions that need it
@@ -553,20 +574,26 @@ struct Compiler
                 Variable* lv = le ? variables.find(le->local) : nullptr;
 
                 // if the argument is a local that isn't mutated, we will simply reuse the existing register
-                if (isExprLocalReg(arg) && (!lv || !lv->written))
+                if (int reg = le ? getExprLocalReg(le) : -1; reg >= 0 && (!lv || !lv->written))
                 {
-                    uint8_t reg = getLocal(le->local);
-                    pushLocal(var, reg);
+                    pushLocal(var, uint8_t(reg));
                 }
                 else
                 {
-                    uint8_t reg = allocReg(arg, 1);
-                    compileExprTemp(arg, reg);
-                    pushLocal(var, reg);
+                    uint8_t temp = allocReg(arg, 1);
+                    compileExprTemp(arg, temp);
+                    pushLocal(var, temp);
                 }
             }
         }
 
+        // evaluate extra expressions for side effects
+        for (size_t i = func->args.size; i < expr->args.size; ++i)
+        {
+            RegScope rsi(this);
+            compileExprAuto(expr->args.data[i], rsi);
+        }
+
         // fold constant values updated above into expressions in the function body
         foldConstants(constants, variables, locstants, func->body);
@@ -627,8 +654,16 @@ struct Compiler
                     FInt::LuauCompileInlineThresholdMaxBoost, FInt::LuauCompileInlineDepth))
                 return;
 
-            if (fi && !fi->canInline)
-                bytecode.addDebugRemark("inlining failed: complex constructs in function body");
+            // add a debug remark for cases when we didn't even call tryCompileInlinedCall
+            if (func && !(fi && fi->canInline))
+            {
+                if (func->vararg)
+                    bytecode.addDebugRemark("inlining failed: function is variadic");
+                else if (!fi)
+                    bytecode.addDebugRemark("inlining failed: can't inline recursive calls");
+                else if (getfenvUsed || setfenvUsed)
+                    bytecode.addDebugRemark("inlining failed: module uses getfenv/setfenv");
+            }
         }
 
         RegScope rs(this);
@@ -672,9 +707,9 @@ struct Compiler
             LUAU_ASSERT(fi);
 
             // Optimization: use local register directly in NAMECALL if possible
-            if (isExprLocalReg(fi->expr))
+            if (int reg = getExprLocalReg(fi->expr); reg >= 0)
            {
-                selfreg = getLocal(fi->expr->as<AstExprLocal>()->local);
+                selfreg = uint8_t(reg);
             }
             else
             {
@@ -780,6 +815,8 @@ struct Compiler
 
     void compileExprFunction(AstExprFunction* expr, uint8_t target)
     {
+        RegScope rs(this);
+
         const Function* f = functions.find(expr);
         LUAU_ASSERT(f);
@@ -790,6 +827,67 @@ struct Compiler
         if (pid < 0)
             CompileError::raise(expr->location, "Exceeded closure limit; simplify the code to compile");
 
+        if (FFlag::LuauCompileNestedClosureO2)
+        {
+            captures.clear();
+            captures.reserve(f->upvals.size());
+
+            for (AstLocal* uv : f->upvals)
+            {
+                LUAU_ASSERT(uv->functionDepth < expr->functionDepth);
+
+                if (int reg = getLocalReg(uv); reg >= 0)
+                {
+                    // note: we can't check if uv is an upvalue in the current frame because inlining can migrate from upvalues to locals
+                    Variable* ul = variables.find(uv);
+                    bool immutable = !ul || !ul->written;
+
+                    captures.push_back({immutable ? LCT_VAL : LCT_REF, uint8_t(reg)});
+                }
+                else if (const Constant* uc = locstants.find(uv); uc && uc->type != Constant::Type_Unknown)
+                {
+                    // inlining can result in an upvalue capture of a constant, in which case we can't capture without a temporary register
+                    uint8_t reg = allocReg(expr, 1);
+                    compileExprConstant(expr, uc, reg);
+
+                    captures.push_back({LCT_VAL, reg});
+                }
+                else
+                {
+                    LUAU_ASSERT(uv->functionDepth < expr->functionDepth - 1);
+
+                    // get upvalue from parent frame
+                    // note: this will add uv to the current upvalue list if necessary
+                    uint8_t uid = getUpval(uv);
+
+                    captures.push_back({LCT_UPVAL, uid});
+                }
+            }
+
+            // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure
+            // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it
+            // is used)
+            int16_t shared = -1;
+
+            if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed)
+            {
+                int32_t cid = bytecode.addConstantClosure(f->id);
+
+                if (cid >= 0 && cid < 32768)
+                    shared = int16_t(cid);
+            }
+
+            if (shared >= 0)
+                bytecode.emitAD(LOP_DUPCLOSURE, target, shared);
+            else
+                bytecode.emitAD(LOP_NEWCLOSURE, target, pid);
+
+            for (const Capture& c : captures)
+                bytecode.emitABC(LOP_CAPTURE, uint8_t(c.type), c.data, 0);
+
+            return;
+        }
+
         bool shared = false;
 
         // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure
@@ -819,9 +917,10 @@ struct Compiler
             if (uv->functionDepth == expr->functionDepth - 1)
             {
                 // get local variable
-                uint8_t reg = getLocal(uv);
+                int reg = getLocalReg(uv);
+                LUAU_ASSERT(reg >= 0);
 
-                bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), reg, 0);
+                bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), uint8_t(reg), 0);
             }
             else
             {
@@ -926,6 +1025,13 @@ struct Compiler
         return cv && cv->type != Constant::Type_Unknown && !cv->isTruthful();
     }
 
+    Constant getConstant(AstExpr* node)
+    {
+        const Constant* cv = constants.find(node);
+
+        return cv ? *cv : Constant{Constant::Type_Unknown};
+    }
+
     size_t compileCompareJump(AstExprBinary* expr, bool not_ = false)
     {
         RegScope rs(this);
@@ -1036,9 +1142,7 @@ struct Compiler
     void compileConditionValue(AstExpr* node, const uint8_t* target, std::vector<size_t>& skipJump, bool onlyTruth)
     {
         // Optimization: we don't need to compute constant values
-        const Constant* cv = constants.find(node);
-
-        if (cv && cv->type != Constant::Type_Unknown)
+        if (const Constant* cv = constants.find(node); cv && cv->type != Constant::Type_Unknown)
         {
             // note that we only need to compute the value if it's truthy; otherwise we can fall through
             if (cv->isTruthful() == onlyTruth)
@@ -1196,9 +1300,7 @@ struct Compiler
         RegScope rs(this);
 
         // Optimization: when left hand side is a constant, we can emit left hand side or right hand side
-        const Constant* cl = constants.find(expr->left);
-
-        if (cl && cl->type != Constant::Type_Unknown)
+        if (const Constant* cl = constants.find(expr->left); cl && cl->type != Constant::Type_Unknown)
         {
             compileExpr(and_ == cl->isTruthful() ? expr->right : expr->left, target, targetTemp);
             return;
@@ -1208,10 +1310,10 @@ struct Compiler
         if (!isConditionFast(expr->left))
         {
             // Optimization: when right hand side is a local variable, we can use AND/OR
-            if (isExprLocalReg(expr->right))
+            if (int reg = getExprLocalReg(expr->right); reg >= 0)
             {
                 uint8_t lr = compileExprAuto(expr->left, rs);
-                uint8_t rr = getLocal(expr->right->as<AstExprLocal>()->local);
+                uint8_t rr = uint8_t(reg);
 
                 bytecode.emitABC(and_ ? LOP_AND : LOP_OR, target, lr, rr);
                 return;
@@ -1630,13 +1732,11 @@ struct Compiler
     {
         RegScope rs(this);
 
-        // note: cv may be invalidated by compileExpr* so we stop using it before calling compile recursively
-        const Constant* cv = constants.find(expr->index);
+        Constant cv = getConstant(expr->index);
 
-        if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 &&
-            double(int(cv->valueNumber)) == cv->valueNumber)
+        if (cv.type == Constant::Type_Number && cv.valueNumber >= 1 && cv.valueNumber <= 256 && double(int(cv.valueNumber)) == cv.valueNumber)
         {
-            uint8_t i = uint8_t(int(cv->valueNumber) - 1);
+            uint8_t i = uint8_t(int(cv.valueNumber) - 1);
 
             uint8_t rt = compileExprAuto(expr->expr, rs);
@@ -1644,9 +1744,9 @@ struct Compiler
 
             bytecode.emitABC(LOP_GETTABLEN, target, rt, i);
         }
-        else if (cv && cv->type == Constant::Type_String)
+        else if (cv.type == Constant::Type_String)
         {
-            BytecodeBuilder::StringRef iname = sref(cv->getString());
+            BytecodeBuilder::StringRef iname = sref(cv.getString());
             int32_t cid = bytecode.addConstantString(iname);
             if (cid < 0)
                 CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
@@ -1759,13 +1859,10 @@ struct Compiler
         }
 
         // Optimization: if expression has a constant value, we can emit it directly
-        if (const Constant* cv = constants.find(node))
+        if (const Constant* cv = constants.find(node); cv && cv->type != Constant::Type_Unknown)
         {
-            if (cv->type != Constant::Type_Unknown)
-            {
-                compileExprConstant(node, cv, target);
-                return;
-            }
+            compileExprConstant(node, cv, target);
+            return;
         }
 
         if (AstExprGroup* expr = node->as<AstExprGroup>())
@@ -1798,19 +1895,18 @@ struct Compiler
         }
         else if (AstExprLocal* expr = node->as<AstExprLocal>())
         {
-            if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue)
+            // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining
+            if (int reg = getExprLocalReg(expr); reg >= 0)
+            {
+                bytecode.emitABC(LOP_MOVE, target, uint8_t(reg), 0);
+            }
+            else
             {
                 LUAU_ASSERT(expr->upvalue);
                 uint8_t uid = getUpval(expr->local);
 
                 bytecode.emitABC(LOP_GETUPVAL, target, uid, 0);
             }
-            else
-            {
-                uint8_t reg = getLocal(expr->local);
-
-                bytecode.emitABC(LOP_MOVE, target, reg, 0);
-            }
         }
         else if (AstExprGlobal* expr = node->as<AstExprGlobal>())
         {
@@ -1874,8 +1970,8 @@ struct Compiler
     uint8_t compileExprAuto(AstExpr* node, RegScope&)
     {
         // Optimization: directly return locals instead of copying them to a temporary
-        if (isExprLocalReg(node))
-            return getLocal(node->as<AstExprLocal>()->local);
+        if (int reg = getExprLocalReg(node); reg >= 0)
+            return uint8_t(reg);
 
         // note: the register is owned by the parent scope
         uint8_t reg = allocReg(node, 1);
@@ -1905,7 +2001,7 @@ struct Compiler
             for (size_t i = 0; i < targetCount; ++i)
                 compileExprTemp(list.data[i], uint8_t(target + i));
 
-            // compute expressions with values that go nowhere; this is required to run side-effecting code if any
+            // evaluate extra expressions for side effects
             for (size_t i = targetCount; i < list.size; ++i)
             {
                 RegScope rsi(this);
@@ -1965,23 +2061,22 @@ struct Compiler
 
     LValue compileLValueIndex(uint8_t reg, AstExpr* index, RegScope& rs)
     {
-        const Constant* cv = constants.find(index);
+        Constant cv = getConstant(index);
 
-        if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 &&
-            double(int(cv->valueNumber)) == cv->valueNumber)
+        if (cv.type == Constant::Type_Number && cv.valueNumber >= 1 && cv.valueNumber <= 256 && double(int(cv.valueNumber)) == cv.valueNumber)
         {
             LValue result = {LValue::Kind_IndexNumber};
             result.reg = reg;
-            result.number = uint8_t(int(cv->valueNumber) - 1);
+            result.number = uint8_t(int(cv.valueNumber) - 1);
             result.location = index->location;
 
             return result;
         }
-        else if (cv && cv->type == Constant::Type_String)
+        else if (cv.type == Constant::Type_String)
         {
             LValue result = {LValue::Kind_IndexName};
             result.reg = reg;
-            result.name = sref(cv->getString());
+            result.name = sref(cv.getString());
             result.location = index->location;
 
             return result;
@@ -2003,20 +2098,21 @@ struct Compiler
 
         if (AstExprLocal* expr = node->as<AstExprLocal>())
         {
-            if (FFlag::LuauCompileSupportInlining ? !isExprLocalReg(expr) : expr->upvalue)
+            // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining
+            if (int reg = getExprLocalReg(expr); reg >= 0)
             {
-                LUAU_ASSERT(expr->upvalue);
-
-                LValue result = {LValue::Kind_Upvalue};
-                result.upval = getUpval(expr->local);
+                LValue result = {LValue::Kind_Local};
+                result.reg = uint8_t(reg);
                 result.location = node->location;
 
                 return result;
             }
             else
             {
-                LValue result = {LValue::Kind_Local};
-                result.reg = getLocal(expr->local);
+                LUAU_ASSERT(expr->upvalue);
+
+                LValue result = {LValue::Kind_Upvalue};
+                result.upval = getUpval(expr->local);
                 result.location = node->location;
 
                 return result;
@@ -2110,15 +2206,21 @@ struct Compiler
         compileLValueUse(lv, source, /* set= */ true);
     }
 
-    bool isExprLocalReg(AstExpr* expr)
+    int getExprLocalReg(AstExpr* node)
     {
-        AstExprLocal* le = expr->as<AstExprLocal>();
-        if (!le || (!FFlag::LuauCompileSupportInlining && le->upvalue))
-            return false;
+        if (AstExprLocal* expr = node->as<AstExprLocal>())
+        {
+            // note: this can't check expr->upvalue because upvalues may be upgraded to locals during inlining
+            Local* l = locals.find(expr->local);
 
-        Local* l = locals.find(le->local);
-
-        return l && l->allocated;
+            return l && l->allocated ? l->reg : -1;
+        }
+        else if (AstExprGroup* expr = node->as<AstExprGroup>())
+            return getExprLocalReg(expr->expr);
+        else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
+            return getExprLocalReg(expr->expr);
+        else
+            return -1;
     }
 
     bool isStatBreak(AstStat* node)
@@ -2342,17 +2444,25 @@ struct Compiler
         RegScope rs(this);
 
         uint8_t temp = 0;
+        bool consecutive = false;
         bool multRet = false;
 
-        // Optimization: return local value directly instead of copying it into a temporary
-        if (stat->list.size == 1 && isExprLocalReg(stat->list.data[0]))
+        // Optimization: return locals directly instead of copying them into a temporary
+        // this is very important for a single return value and occasionally effective for multiple values
+        if (int reg = stat->list.size > 0 ? getExprLocalReg(stat->list.data[0]) : -1; reg >= 0)
         {
-            AstExprLocal* le = stat->list.data[0]->as<AstExprLocal>();
-            LUAU_ASSERT(le);
+            temp = uint8_t(reg);
+            consecutive = true;
 
-            temp = getLocal(le->local);
+            for (size_t i = 1; i < stat->list.size; ++i)
+                if (getExprLocalReg(stat->list.data[i]) != int(temp + i))
+                {
+                    consecutive = false;
+                    break;
+                }
         }
-        else if (stat->list.size > 0)
+
+        if (!consecutive && stat->list.size > 0)
         {
             temp = allocReg(stat, unsigned(stat->list.size));
@@ -2401,41 +2511,21 @@ struct Compiler
             pushLocal(stat->vars.data[i], uint8_t(vars + i));
     }
 
-    int getConstantShort(AstExpr* expr)
-    {
-        const Constant* c = constants.find(expr);
-
-        if (c && c->type == Constant::Type_Number)
-        {
-            double n = c->valueNumber;
-
-            if (n >= -32767 && n <= 32767 && double(int(n)) == n)
-                return int(n);
-        }
-
-        return INT_MIN;
-    }
-
     bool canUnrollForBody(AstStatFor* stat)
     {
+        if (FFlag::LuauCompileNestedClosureO2)
+            return true; // TODO: remove this function
+
         struct CanUnrollVisitor : AstVisitor
        {
             bool result = true;
 
-            bool visit(AstExpr* node) override
+            bool visit(AstExprFunction* node) override
            {
-                // functions may capture loop variable, and our upval handling doesn't handle elided variables (constant)
-                // TODO: we could remove this case if we changed function compilation to create temporary locals for constant upvalues
-                result = result && !node->is<AstExprFunction>();
-                return result;
-            }
+                result = false;
 
-            bool visit(AstStat* node) override
-            {
-                // while we can easily unroll nested loops, our cost model doesn't take unrolling into account so this can result in code explosion
-                // we also avoid continue/break since they introduce control flow across iterations
-                result = result && !node->is<AstStatFor>() && !node->is<AstStatContinue>() && !node->is<AstStatBreak>();
-                return result;
+                // short-circuit to avoid analyzing nested closure bodies
+                return false;
             }
         };
@@ -2447,17 +2537,29 @@ struct Compiler
 
     bool tryCompileUnrolledFor(AstStatFor* stat, int thresholdBase, int thresholdMaxBoost)
     {
-        int from = getConstantShort(stat->from);
-        int to = getConstantShort(stat->to);
-        int step = stat->step ? getConstantShort(stat->step) : 1;
+        Constant one = {Constant::Type_Number};
+        one.valueNumber = 1.0;
 
-        // check that limits are reasonably small and trip count can be computed
-        if (from == INT_MIN || to == INT_MIN || step == INT_MIN || step == 0 || (step < 0 && to > from) || (step > 0 && to < from))
+        Constant fromc = getConstant(stat->from);
+        Constant toc = getConstant(stat->to);
+        Constant stepc = stat->step ? getConstant(stat->step) : one;
+
+        int tripCount = (fromc.type == Constant::Type_Number && toc.type == Constant::Type_Number && stepc.type == Constant::Type_Number)
+                            ? getTripCount(fromc.valueNumber, toc.valueNumber, stepc.valueNumber)
+                            : -1;
+
+        if (tripCount < 0)
         {
             bytecode.addDebugRemark("loop unroll failed: invalid iteration count");
             return false;
         }
 
+        if (tripCount > thresholdBase)
+        {
+            bytecode.addDebugRemark("loop unroll failed: too many iterations (%d)", tripCount);
+            return false;
+        }
+
         if (!canUnrollForBody(stat))
         {
             bytecode.addDebugRemark("loop unroll failed: unsupported loop body");
@@ -2470,14 +2572,6 @@ struct Compiler
             return false;
         }
 
-        int tripCount = (to - from) / step + 1;
-
-        if (tripCount > thresholdBase)
-        {
-            bytecode.addDebugRemark("loop unroll failed: too many iterations (%d)", tripCount);
-            return false;
-        }
-
         AstLocal* var = stat->var;
         uint64_t costModel = modelCost(stat->body, &var, 1);
@@ -2498,23 +2592,54 @@ struct Compiler
 
         bytecode.addDebugRemark("loop unroll succeeded (iterations %d, cost %d, profit %.2fx)", tripCount, unrolledCost, double(unrollProfit) / 100);
 
-        for (int i = from; step > 0 ? i <= to : i >= to; i += step)
+        compileUnrolledFor(stat, tripCount, fromc.valueNumber, stepc.valueNumber);
+        return true;
+    }
+
+    void compileUnrolledFor(AstStatFor* stat, int tripCount, double from, double step)
+    {
+        AstLocal* var = stat->var;
+
+        size_t oldLocals = localStack.size();
+        size_t oldJumps = loopJumps.size();
+
+        loops.push_back({oldLocals, nullptr});
+
+        for (int iv = 0; iv < tripCount; ++iv)
         {
             // we need to re-fold constants in the loop body with the new value; this reuses computed constant values elsewhere in the tree
             locstants[var].type = Constant::Type_Number;
-            locstants[var].valueNumber = i;
+            locstants[var].valueNumber = from + iv * step;
 
             foldConstants(constants, variables, locstants, stat);
 
+            size_t iterJumps = loopJumps.size();
+
             compileStat(stat->body);
+
+            // all continue jumps need to go to the next iteration
+            size_t contLabel = bytecode.emitLabel();
+
+            for (size_t i = iterJumps; i < loopJumps.size(); ++i)
+                if (loopJumps[i].type == LoopJump::Continue)
+                    patchJump(stat, loopJumps[i].label, contLabel);
         }
 
+        // all break jumps need to go past the loop
+        size_t endLabel = bytecode.emitLabel();
+
+        for (size_t i = oldJumps; i < loopJumps.size(); ++i)
+            if (loopJumps[i].type == LoopJump::Break)
+                patchJump(stat, loopJumps[i].label, endLabel);
+
+        loopJumps.resize(oldJumps);
+
+        loops.pop_back();
+
         // clean up fold state in case we need to recompile - normally we compile the loop body once, but due to inlining we may need to do it again
         locstants[var].type = Constant::Type_Unknown;
 
         foldConstants(constants, variables, locstants, stat);
-
-        return true;
     }
 
     void compileStatFor(AstStatFor* stat)
@@ -2601,16 +2726,6 @@ struct Compiler
         // this puts initial values of (generator, state, index) into the loop registers
         compileExprListTemp(stat->values, regs, 3, /* targetTop= */ true);
 
-        // we don't need this because the extra stack space is just for calling the function with a loop protocol which is similar to calling
-        // metamethods - it should fit into the extra stack reservation
-        if (!FFlag::LuauCompileIterNoReserve)
-        {
-            // for the general case, we will execute a CALL for every iteration that needs to evaluate "variables... = generator(state, index)"
-            // this requires at least extra 3 stack slots after index
-            // note that these stack slots overlap with the variables so we only need to reserve them to make sure stack frame is large enough
-            reserveReg(stat, 3);
-        }
-
         // note that we reserve at least 2 variables; this allows our fast path to assume that we need 2 variables instead of 1 or 2
         uint8_t vars = allocReg(stat, std::max(unsigned(stat->vars.size), 2u));
         LUAU_ASSERT(vars == regs + 3);
@@ -2858,12 +2973,9 @@ struct Compiler
     void compileStatFunction(AstStatFunction* stat)
     {
         // Optimization: compile value expresion directly into target local register
-        if (isExprLocalReg(stat->name))
+        if (int reg = getExprLocalReg(stat->name); reg >= 0)
        {
-            AstExprLocal* le = stat->name->as<AstExprLocal>();
-            LUAU_ASSERT(le);
-
-            compileExpr(stat->func, getLocal(le->local));
+            compileExpr(stat->func, uint8_t(reg));
             return;
         }
@@ -3383,6 +3495,12 @@ struct Compiler
         std::vector<size_t> returnJumps;
     };
 
+    struct Capture
+    {
+        LuauCaptureType type;
+        uint8_t data;
+    };
+
     BytecodeBuilder& bytecode;
 
     CompileOptions options;
@@ -3406,6 +3524,7 @@ struct Compiler
    std::vector<LoopJump> loopJumps;
    std::vector<Loop> loops;
     std::vector<InlineFrame> inlineFrames;
+    std::vector<Capture> captures;
 };
 
 void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options)
 {
@@ -3449,6 +3568,9 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName
         /* self= */ nullptr, AstArray<AstLocal*>(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName());
 
     uint32_t mainid = compiler.compileFunction(&main);
+
+    const Compiler::Function* mainf = compiler.functions.find(&main);
+    LUAU_ASSERT(mainf && mainf->upvals.empty());
+
     bytecode.setMainFunction(mainid);
     bytecode.finalize();
 }
diff --git a/luau/Compiler/src/ConstantFolding.cpp b/luau/Compiler/src/ConstantFolding.cpp
index 52ece73..a62beeb 100644
--- a/luau/Compiler/src/ConstantFolding.cpp
+++ b/luau/Compiler/src/ConstantFolding.cpp
@@ -3,8 +3,6 @@
 #include <math.h>
 
-LUAU_FASTFLAG(LuauCompileSupportInlining)
-
 namespace Luau
 {
 namespace Compile
@@ -195,12 +193,16 @@ struct ConstantVisitor : AstVisitor
     DenseHashMap<AstLocal*, Variable>& variables;
     DenseHashMap<AstLocal*, Constant>& locals;
 
+    bool wasEmpty = false;
+
     ConstantVisitor(
         DenseHashMap<AstExpr*, Constant>& constants, DenseHashMap<AstLocal*, Variable>& variables, DenseHashMap<AstLocal*, Constant>& locals)
         : constants(constants)
         , variables(variables)
        , locals(locals)
     {
+        // since we do a single pass over the tree, if the initial state was empty we don't need to clear out old entries
+        wasEmpty = constants.empty() && locals.empty();
     }
 
     Constant analyze(AstExpr* node)
@@ -326,7 +328,7 @@ struct ConstantVisitor : AstVisitor
     {
         if (value.type != Constant::Type_Unknown)
             map[key] = value;
-        else if (!FFlag::LuauCompileSupportInlining)
+        else if (wasEmpty)
             ;
         else if (Constant* old = map.find(key))
             old->type = Constant::Type_Unknown;
diff --git a/luau/Compiler/src/CostModel.cpp b/luau/Compiler/src/CostModel.cpp
index 9afd09f..5608cd8 100644
--- a/luau/Compiler/src/CostModel.cpp
+++ b/luau/Compiler/src/CostModel.cpp
@@ -4,6 +4,8 @@
 #include "Luau/Common.h"
 #include "Luau/DenseHash.h"
 
+#include <limits.h>
+
 namespace Luau
 {
 namespace Compile
@@ -11,10 +13,49 @@ namespace Compile
 
 inline uint64_t parallelAddSat(uint64_t x, uint64_t y)
 {
-    uint64_t s = x + y;
-    uint64_t m = s & 0x8080808080808080ull; // saturation mask
+    uint64_t r = x + y;
+    uint64_t s = r & 0x8080808080808080ull; // saturation mask
 
-    return (s ^ m) | (m - (m >> 7));
+    return (r ^ s) | (s - (s >> 7));
+}
+
+static uint64_t parallelMulSat(uint64_t a, int b)
+{
+    int bs = (b < 127) ? b : 127;
+
+    // multiply every other value by b, yielding 14-bit products
+    uint64_t l = bs * ((a >> 0) & 0x007f007f007f007full);
+    uint64_t h = bs * ((a >> 8) & 0x007f007f007f007full);
+
+    // each product is 14-bit, so adding 32768-128 sets high bit iff the sum is 128 or larger without an overflow
+    uint64_t ls = l + 0x7f807f807f807f80ull;
+    uint64_t hs = h + 0x7f807f807f807f80ull;
+
+    // we now merge saturation bits as well as low 7-bits of each product into one
+    uint64_t s = (hs & 0x8000800080008000ull) | ((ls & 0x8000800080008000ull) >> 8);
+    uint64_t r = ((h & 0x007f007f007f007full) << 8) | (l & 0x007f007f007f007full);
+
+    // the low bits are now correct for values that didn't saturate, and we simply need to mask them if high bit is 1
+    return r | (s - (s >> 7));
+}
+
+inline bool getNumber(AstExpr* node, double& result)
+{
+    // since constant model doesn't use constant folding atm, we perform the basic extraction that's sufficient to handle positive/negative literals
+    if (AstExprConstantNumber* ne = node->as<AstExprConstantNumber>())
+    {
+        result = ne->value;
+        return true;
+    }
+
+    if (AstExprUnary* ue = node->as<AstExprUnary>(); ue && ue->op == AstExprUnary::Minus)
+        if (AstExprConstantNumber* ne = ue->expr->as<AstExprConstantNumber>())
+        {
+            result = -ne->value;
+            return true;
+        }
+
+    return false;
 }
 
 struct Cost
@@ -46,6 +87,13 @@ struct Cost
         return *this;
     }
 
+    Cost operator*(int other) const
+    {
+        Cost result;
+        result.model = parallelMulSat(model, other);
+        return result;
+    }
+
     static Cost fold(const Cost& x, const Cost& y)
     {
         uint64_t newmodel = parallelAddSat(x.model, y.model);
@@ -173,6 +221,16 @@ struct CostVisitor : AstVisitor
             *i = 0;
     }
 
+    void loop(AstStatBlock* body, Cost iterCost, int factor = 3)
+    {
+        Cost before = result;
+
+        result = Cost();
+        body->visit(this);
+
+        result = before + (result + iterCost) * factor;
+    }
+
     bool visit(AstExpr* node) override
     {
         // note: we short-circuit the visitor traversal through any expression trees by returning false
@@ -182,12 +240,52 @@ struct CostVisitor : AstVisitor
         return false;
     }
 
+    bool visit(AstStatFor* node) override
+    {
+        result += model(node->from);
+        result += model(node->to);
+
+        if (node->step)
+            result += model(node->step);
+
+        int tripCount = -1;
+        double from, to, step = 1;
+        if (getNumber(node->from, from) && getNumber(node->to, to) && (!node->step || getNumber(node->step, step)))
+            tripCount = getTripCount(from, to, step);
+
+        loop(node->body, 1, tripCount < 0 ? 3 : tripCount);
+        return false;
+    }
+
+    bool visit(AstStatForIn* node) override
+    {
+        for (size_t i = 0; i < node->values.size; ++i)
+            result += model(node->values.data[i]);
+
+        loop(node->body, 1);
+        return false;
+    }
+
+    bool visit(AstStatWhile* node) override
+    {
+        Cost condition = model(node->condition);
+
+        loop(node->body, condition);
+        return false;
+    }
+
+    bool visit(AstStatRepeat* node) override
+    {
+        Cost condition = model(node->condition);
+
+        loop(node->body, condition);
+        return false;
+    }
+
     bool visit(AstStat* node) override
     {
         if (node->is<AstStatIf>())
             result += 2;
-        else if (node->is<AstStatFor>() || node->is<AstStatForIn>() || node->is<AstStatWhile>() || node->is<AstStatRepeat>())
-            result += 2;
         else if (node->is<AstStatBreak>() || node->is<AstStatContinue>())
             result += 1;
 
@@ -254,5 +352,21 @@ int computeCost(uint64_t model, const bool* varsConst, size_t varCount)
     return cost;
 }
 
+int getTripCount(double from, double to, double step)
+{
+    // we compute trip count in integers because that way we know that the loop math (repeated addition) is precise
+    int fromi = (from >= -32767 && from <= 32767 && double(int(from)) == from) ? int(from) : INT_MIN;
+    int toi = (to >= -32767 && to <= 32767 && double(int(to)) == to) ? int(to) : INT_MIN;
+    int stepi = (step >= -32767 && step <= 32767 && double(int(step)) == step) ? int(step) : INT_MIN;
+
+    if (fromi == INT_MIN || toi == INT_MIN || stepi == INT_MIN || stepi == 0)
+        return -1;
+
+    if ((stepi < 0 && toi > fromi) || (stepi > 0 && toi < fromi))
+        return 0;
+
+    return (toi - fromi) / stepi + 1;
+}
+
 } // namespace Compile
 } // namespace Luau
diff --git a/luau/Compiler/src/CostModel.h b/luau/Compiler/src/CostModel.h
index c27861e..17defaf 100644
--- a/luau/Compiler/src/CostModel.h
+++ b/luau/Compiler/src/CostModel.h
@@ -14,5 +14,8 @@ uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount);
 // cost is computed as B - sum(Di * Ci), where B is baseline cost, Di is the discount for each variable and Ci is 1 when variable #i is constant
 int computeCost(uint64_t model, const bool* varsConst, size_t varCount);
 
+// get loop trip count or -1 if we can't compute it precisely
+int getTripCount(double from, double to, double step);
+
 } // namespace Compile
 } // namespace Luau
diff --git a/luau/VM/include/lua.h b/luau/VM/include/lua.h
index c3ebadb..7f9647c 100644
--- a/luau/VM/include/lua.h
+++ b/luau/VM/include/lua.h
@@ -148,6 +148,7 @@ LUA_API const char* lua_tostringatom(lua_State* L, int idx, int* atom);
 LUA_API const char* lua_namecallatom(lua_State* L, int* atom);
 LUA_API int lua_objlen(lua_State* L, int idx);
 LUA_API lua_CFunction lua_tocfunction(lua_State* L, int idx);
+LUA_API void* lua_tolightuserdata(lua_State* L, int idx);
 LUA_API void* lua_touserdata(lua_State* L, int idx);
 LUA_API void* lua_touserdatatagged(lua_State* L, int idx, int tag);
 LUA_API int lua_userdatatag(lua_State* L, int idx);
diff --git a/luau/VM/src/lapi.cpp b/luau/VM/src/lapi.cpp
index f8baefa..f3be64b 100644
--- a/luau/VM/src/lapi.cpp
+++ b/luau/VM/src/lapi.cpp
@@ -478,18 +478,21 @@ lua_CFunction lua_tocfunction(lua_State* L, int idx)
     return (!iscfunction(o)) ? NULL : cast_to(lua_CFunction, clvalue(o)->c.f);
 }
 
+void* lua_tolightuserdata(lua_State* L, int idx)
+{
+    StkId o = index2addr(L, idx);
+    return (!ttislightuserdata(o)) ? NULL : pvalue(o);
+}
+
 void* lua_touserdata(lua_State* L, int idx)
 {
     StkId o = index2addr(L, idx);
-    switch (ttype(o))
-    {
-    case LUA_TUSERDATA:
-        return uvalue(o)->data;
-    case LUA_TLIGHTUSERDATA:
-        return pvalue(o);
-    default:
-        return NULL;
-    }
+    if (ttisuserdata(o))
+        return uvalue(o)->data;
+    else if (ttislightuserdata(o))
+        return pvalue(o);
+    else
+        return NULL;
 }
 
 void* lua_touserdatatagged(lua_State* L, int idx, int tag)
@@ -524,8 +527,9 @@ const void* lua_topointer(lua_State* L, int idx)
     case LUA_TTHREAD:
         return thvalue(o);
     case LUA_TUSERDATA:
+        return uvalue(o)->data;
     case LUA_TLIGHTUSERDATA:
-        return lua_touserdata(L, idx);
+        return pvalue(o);
     default:
         return NULL;
     }
diff --git a/luau/VM/src/lbuiltins.cpp b/luau/VM/src/lbuiltins.cpp
index 6014919..cc6e560 100644
--- a/luau/VM/src/lbuiltins.cpp
+++ b/luau/VM/src/lbuiltins.cpp
@@ -15,8 +15,6 @@
 #include <intrin.h>
 #endif
 
-LUAU_FASTFLAGVARIABLE(LuauFixBuiltinsStackLimit, false)
-
 // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM
 // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack.
 // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path
@@ -1005,7 +1003,7 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St
     else if (nparams == 3 && ttisnumber(args) && ttisnumber(args + 1) && nvalue(args) == 1.0)
         n = int(nvalue(args + 1));
 
-    if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n && (!FFlag::LuauFixBuiltinsStackLimit || n + nparams <= LUAI_MAXCSTACK))
+    if (n >= 0 && n <= t->sizearray && cast_int(L->stack_last - res) >= n && n + nparams <= LUAI_MAXCSTACK)
     {
         TValue* array = t->array;
         for (int i = 0; i < n; ++i)
diff --git a/luau/VM/src/lbytecode.h b/luau/VM/src/lbytecode.h
index c4d250d..da2a611 100644
--- a/luau/VM/src/lbytecode.h
+++ b/luau/VM/src/lbytecode.h
@@ -3,7 +3,4 @@
 #pragma once
 
 // This is a forwarding header for Luau bytecode definition
-// Luau consists of several components, including compiler (Ast, Compiler) and VM (virtual machine)
-// These components are fully independent, but they both need the bytecode format defined in this header
-// so it needs to be shared.
-#include "../../Compiler/include/Luau/Bytecode.h"
+#include "Luau/Bytecode.h"
diff --git a/luau/VM/src/lcommon.h b/luau/VM/src/lcommon.h
index adbd81f..ac79cd9 100644
--- a/luau/VM/src/lcommon.h
+++ b/luau/VM/src/lcommon.h
@@ -7,11 +7,7 @@
 
 #include "luaconf.h"
 
-// This is a forwarding header for Luau common definition (assertions, flags)
-// Luau consists of several components, including compiler (Ast, Compiler) and VM (virtual machine)
-// These components are fully independent, but they need a common set of utilities defined in this header
-// so it needs to be shared.
-#include "../../Ast/include/Luau/Common.h"
+#include "Luau/Common.h"
 
 typedef LUAI_USER_ALIGNMENT_T L_Umaxalign;
diff --git a/luau/VM/src/ldo.cpp b/luau/VM/src/ldo.cpp
index c133a59..4cab746 100644
--- a/luau/VM/src/ldo.cpp
+++ b/luau/VM/src/ldo.cpp
@@ -213,6 +213,14 @@ CallInfo* luaD_growCI(lua_State* L)
     return ++L->ci;
 }
 
+void luaD_checkCstack(lua_State *L)
+{
+    if (L->nCcalls == LUAI_MAXCCALLS)
+        luaG_runerror(L, "C stack overflow");
+    else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3)))
+        luaD_throw(L, LUA_ERRERR); /* error while handling stack error */
+}
+
 /*
 ** Call a function (C or Lua). The function to be called is at *func.
 ** The arguments are on the stack, right after the function.
@@ -222,12 +230,8 @@ CallInfo* luaD_growCI(lua_State* L)
 void luaD_call(lua_State* L, StkId func, int nResults)
 {
     if (++L->nCcalls >= LUAI_MAXCCALLS)
-    {
-        if (L->nCcalls == LUAI_MAXCCALLS)
-            luaG_runerror(L, "C stack overflow");
-        else if (L->nCcalls >= (LUAI_MAXCCALLS + (LUAI_MAXCCALLS >> 3)))
-            luaD_throw(L, LUA_ERRERR); /* error while handing stack error */
-    }
+        luaD_checkCstack(L);
+
     if (luau_precall(L, func, nResults) == PCRLUA)
     {                                        /* is a Lua function? */
         L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */
@@ -241,6 +245,7 @@ void luaD_call(lua_State* L, StkId func, int nResults)
         if (!oldactive)
             resetbit(L->stackstate, THREAD_ACTIVEBIT);
     }
+
     L->nCcalls--;
     luaC_checkGC(L);
 }
diff --git a/luau/VM/src/ldo.h b/luau/VM/src/ldo.h
index 6e16e6f..5e9472b 100644
--- a/luau/VM/src/ldo.h
+++ b/luau/VM/src/ldo.h
@@ -49,6 +49,7 @@ LUAI_FUNC int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t oldtop, pt
 LUAI_FUNC void luaD_reallocCI(lua_State* L, int newsize);
 LUAI_FUNC void luaD_reallocstack(lua_State* L, int newsize);
 LUAI_FUNC void luaD_growstack(lua_State* L, int n);
+LUAI_FUNC void luaD_checkCstack(lua_State* L);
 LUAI_FUNC l_noret luaD_throw(lua_State* L, int errcode);
 LUAI_FUNC int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud);
diff --git a/luau/VM/src/ltablib.cpp b/luau/VM/src/ltablib.cpp
index 9c1f387..27187c6 100644
--- a/luau/VM/src/ltablib.cpp
+++ b/luau/VM/src/ltablib.cpp
@@ -10,10 +10,6 @@
 #include "ldebug.h"
 #include "lvm.h"
 
-LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry2, false)
-
-void (*lua_table_move_telemetry)(lua_State* L, int f, int e, int t, int nf, int nt);
-
 static int foreachi(lua_State* L)
 {
     luaL_checktype(L, 1, LUA_TTABLE);
@@ -199,29 +195,6 @@ static int tmove(lua_State* L)
     int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */
     luaL_checktype(L, tt, LUA_TTABLE);
 
-    void (*telemetrycb)(lua_State * L, int f, int e, int t, int nf, int nt) = lua_table_move_telemetry;
-
-    if (DFFlag::LuauTableMoveTelemetry2 && telemetrycb && e >= f)
-    {
-        int nf = lua_objlen(L, 1);
-        int nt = lua_objlen(L, tt);
-
-        bool report = false;
-
-        // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves)
-        if (!(f == 1 || (f >= 1 && f <= nf)))
-            report = true;
-        if (!(e == nf || (e >= 1 && e <= nf)))
-            report = true;
-
-        // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats)
-        if (!(t == nt + 1 || (t >= 1 && t <= nt)))
-            report = true;
-
-        if (report)
-            telemetrycb(L, f, e, t, nf, nt);
-    }
-
     if (e >= f)
     { /* otherwise, nothing to move */
         luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move");
diff --git a/luau/VM/src/lvmexecute.cpp b/luau/VM/src/lvmexecute.cpp
index 3c7c276..f9fd657 100644
--- a/luau/VM/src/lvmexecute.cpp
+++ b/luau/VM/src/lvmexecute.cpp
@@ -17,9 +17,6 @@
 #include <string.h>
 
 LUAU_FASTFLAGVARIABLE(LuauIter, false)
-LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauIterCallTelemetry, false)
-
-void (*lua_iter_call_telemetry)(lua_State* L);
 
 // Disable c99-designator to avoid the warning in CGOTO dispatch table
 #ifdef __clang__
@@ -157,17 +154,6 @@ LUAU_NOINLINE static bool luau_loopFORG(lua_State* L, int a, int c)
     StkId ra = &L->base[a];
     LUAU_ASSERT(ra + 3 <= L->top);
 
-    if (DFFlag::LuauIterCallTelemetry)
-    {
-        /* TODO: we might be able to stop supporting this depending on whether it's used in practice */
-        void (*telemetrycb)(lua_State* L) = lua_iter_call_telemetry;
-
-        if (telemetrycb && ttistable(ra) && fasttm(L, hvalue(ra)->metatable, TM_CALL))
-            telemetrycb(L);
-        if (telemetrycb && ttisuserdata(ra) && fasttm(L, uvalue(ra)->metatable, TM_CALL))
-            telemetrycb(L);
-    }
-
     setobjs2s(L, ra + 3 + 2, ra + 2);
     setobjs2s(L, ra + 3 + 1, ra + 1);
     setobjs2s(L, ra + 3, ra);
@@ -195,7 +181,7 @@ LUAU_NOINLINE static void luau_callTM(lua_State* L, int nparams, int res)
     ++L->nCcalls;
 
     if (L->nCcalls >= LUAI_MAXCCALLS)
-        luaG_runerror(L, "C stack overflow");
+        luaD_checkCstack(L);
 
     luaD_checkstack(L, LUA_MINSTACK);
@@ -708,7 +694,7 @@ static void luau_execute(lua_State* L)
                 }
                 else
                 {
-                    // slow-path, may invoke Lua calls via __index metamethod
+                    // slow-path, may invoke Lua calls via __newindex metamethod
                     L->cachedslot = slot;
                     VM_PROTECT(luaV_settable(L, rb, kv, ra));
                     // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++
@@ -718,7 +704,7 @@ static void luau_execute(lua_State* L)
             }
             else
             {
-                // fast-path: user data with C __index TM
+                // fast-path: user data with C __newindex TM
                 const TValue* fn = 0;
                 if (ttisuserdata(rb) && (fn = fasttm(L, uvalue(rb)->metatable, TM_NEWINDEX)) && ttisfunction(fn) && clvalue(fn)->isC)
                 {
@@ -739,7 +725,7 @@ static void luau_execute(lua_State* L)
                 }
                 else
                 {
-                    // slow-path, may invoke Lua calls via __index metamethod
+                    // slow-path, may invoke Lua calls via __newindex metamethod
                    VM_PROTECT(luaV_settable(L, rb, kv, ra));
                     VM_NEXT();
                 }
@@ -2372,9 +2358,8 @@ static void luau_execute(lua_State* L)
             // fast-path: ipairs/inext
             if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0)
             {
-                if (FFlag::LuauIter)
-                    setnilvalue(ra);
-
+                setnilvalue(ra);
+
                 /* ra+1 is already the table */
                 setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
             }
             else if (FFlag::LuauIter && !ttisfunction(ra))
@@ -2394,7 +2379,7 @@ static void luau_execute(lua_State* L)
             StkId ra = VM_REG(LUAU_INSN_A(insn));
 
             // fast-path: ipairs/inext
-            if (ttistable(ra + 1) && ttislightuserdata(ra + 2))
+            if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2))
             {
                 Table* h = hvalue(ra + 1);
                 int index = int(reinterpret_cast<uintptr_t>(pvalue(ra + 2)));
@@ -2445,9 +2430,8 @@ static void luau_execute(lua_State* L)
             // fast-path: pairs/next
             if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2))
             {
-                if (FFlag::LuauIter)
-                    setnilvalue(ra);
-
+                setnilvalue(ra);
+
                 /* ra+1 is already the table */
                 setpvalue(ra + 2, reinterpret_cast<void*>(uintptr_t(0)));
             }
             else if (FFlag::LuauIter && !ttisfunction(ra))
@@ -2467,7 +2451,7 @@ static void luau_execute(lua_State* L)
             StkId ra = VM_REG(LUAU_INSN_A(insn));
 
             // fast-path: pairs/next
-            if (ttistable(ra + 1) && ttislightuserdata(ra + 2))
+            if (ttisnil(ra) && ttistable(ra + 1) && ttislightuserdata(ra + 2))
             {
                 Table* h = hvalue(ra + 1);
                 int index = int(reinterpret_cast<uintptr_t>(pvalue(ra + 2)));
diff --git a/src/lib.rs b/src/lib.rs
index 0bb08ad..c6282e4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -48,6 +48,7 @@ impl Build {
         let include_dir = out_dir.join("include");
 
         let source_dir_base = Path::new(env!("CARGO_MANIFEST_DIR"));
+        let common_include_dir = source_dir_base.join("luau").join("Common").join("include");
         let ast_source_dir = source_dir_base.join("luau").join("Ast").join("src");
         let ast_include_dir = source_dir_base.join("luau").join("Ast").join("include");
         let compiler_source_dir = source_dir_base.join("luau").join("Compiler").join("src");
@@ -90,6 +91,7 @@ impl Build {
         config
             .clone()
             .include(&ast_include_dir)
+            .include(&common_include_dir)
            .add_files_by_ext(&ast_source_dir, "cpp")
             .out_dir(&lib_dir)
             .compile(ast_lib_name);
@@ -100,6 +102,7 @@ impl Build {
             .clone()
             .include(&compiler_include_dir)
             .include(&ast_include_dir)
+            .include(&common_include_dir)
             .define("LUACODE_API", "extern \"C\"")
             .add_files_by_ext(&compiler_source_dir, "cpp")
             .out_dir(&lib_dir)
@@ -110,6 +113,7 @@ impl Build {
         config
             .clone()
             .include(&vm_include_dir)
+            .include(&common_include_dir)
             .define("LUA_API", "extern \"C\"")
             // .define("LUA_USE_LONGJMP", "1")
             .add_files_by_ext(&vm_source_dir, "cpp")