From 3008da98df9bcd7ea8da70db007c0c0aa6359ba4 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 25 Aug 2022 13:55:08 -0700 Subject: [PATCH] Sync to upstream/release/542 --- Analysis/include/Luau/Constraint.h | 2 +- Analysis/include/Luau/ConstraintSolver.h | 2 + .../include/Luau/ConstraintSolverLogger.h | 2 +- Analysis/include/Luau/Frontend.h | 2 +- Analysis/include/Luau/ToString.h | 63 ++++- Analysis/include/Luau/TxnLog.h | 2 + Analysis/include/Luau/TypeInfer.h | 1 + Analysis/src/AstJsonEncoder.cpp | 14 + Analysis/src/ConstraintGraphBuilder.cpp | 11 +- Analysis/src/ConstraintSolver.cpp | 119 ++++++-- Analysis/src/ConstraintSolverLogger.cpp | 36 ++- Analysis/src/Frontend.cpp | 12 +- Analysis/src/Linter.cpp | 38 ++- Analysis/src/ToDot.cpp | 4 +- Analysis/src/ToString.cpp | 149 +++++----- Analysis/src/Transpiler.cpp | 22 ++ Analysis/src/TxnLog.cpp | 12 + Analysis/src/TypeInfer.cpp | 10 + Analysis/src/TypeVar.cpp | 17 +- Ast/include/Luau/Ast.h | 26 +- Ast/include/Luau/Lexer.h | 21 ++ Ast/include/Luau/Parser.h | 17 +- Ast/include/Luau/StringUtils.h | 2 +- Ast/src/Ast.cpp | 26 +- Ast/src/Lexer.cpp | 182 ++++++++++-- Ast/src/Parser.cpp | 179 +++++++++--- Ast/src/StringUtils.cpp | 10 +- Common/include/Luau/ExperimentalFlags.h | 1 + Compiler/src/Compiler.cpp | 94 ++++++- Compiler/src/ConstantFolding.cpp | 5 + Compiler/src/CostModel.cpp | 10 + Makefile | 8 +- VM/src/lapi.cpp | 69 ++--- VM/src/ldebug.cpp | 40 +-- VM/src/ldo.cpp | 8 +- VM/src/lfunc.cpp | 127 ++++++--- VM/src/lfunc.h | 1 + VM/src/lgc.cpp | 259 +++++++++++++++--- VM/src/lgc.h | 31 ++- VM/src/lgcdebug.cpp | 16 +- VM/src/lobject.h | 15 +- VM/src/lstate.cpp | 13 +- VM/src/lstate.h | 2 +- VM/src/lstring.cpp | 74 +++-- VM/src/lvmexecute.cpp | 15 +- VM/src/lvmload.cpp | 2 +- VM/src/lvmutils.cpp | 46 +--- bench/bench.py | 14 +- tests/AstJsonEncoder.test.cpp | 12 + tests/AstQuery.test.cpp | 76 +++++ tests/Autocomplete.test.cpp | 9 + tests/Compiler.test.cpp | 55 +++- tests/Conformance.test.cpp | 12 +- 
tests/Fixture.cpp | 37 +++ tests/Fixture.h | 70 +++++ tests/Lexer.test.cpp | 86 ++++++ tests/Linter.test.cpp | 38 ++- tests/Parser.test.cpp | 140 ++++++++++ tests/ToString.test.cpp | 52 ++-- tests/Transpiler.test.cpp | 19 ++ tests/TypeInfer.builtins.test.cpp | 26 +- tests/TypeInfer.provisional.test.cpp | 65 +++++ tests/TypeInfer.test.cpp | 36 +++ tests/conformance/gc.lua | 60 ++-- tests/conformance/stringinterp.lua | 59 ++++ tools/faillist.txt | 3 +- tools/natvis/VM.natvis | 2 +- tools/perfgraph.py | 76 ++++- tools/test_dcr.py | 31 ++- 69 files changed, 2284 insertions(+), 511 deletions(-) create mode 100644 tests/conformance/stringinterp.lua diff --git a/Analysis/include/Luau/Constraint.h b/Analysis/include/Luau/Constraint.h index ce90f2c..e9f04e7 100644 --- a/Analysis/include/Luau/Constraint.h +++ b/Analysis/include/Luau/Constraint.h @@ -35,7 +35,7 @@ struct PackSubtypeConstraint TypePackId superPack; }; -// subType ~ gen superType +// generalizedType ~ gen sourceType struct GeneralizationConstraint { TypeId generalizedType; diff --git a/Analysis/include/Luau/ConstraintSolver.h b/Analysis/include/Luau/ConstraintSolver.h index 661d120..a270ec9 100644 --- a/Analysis/include/Luau/ConstraintSolver.h +++ b/Analysis/include/Luau/ConstraintSolver.h @@ -100,6 +100,8 @@ struct ConstraintSolver void unblock(NotNull progressed); void unblock(TypeId progressed); void unblock(TypePackId progressed); + void unblock(const std::vector& types); + void unblock(const std::vector& packs); /** * @returns true if the TypeId is in a blocked state. 
diff --git a/Analysis/include/Luau/ConstraintSolverLogger.h b/Analysis/include/Luau/ConstraintSolverLogger.h index fe2177c..55170c4 100644 --- a/Analysis/include/Luau/ConstraintSolverLogger.h +++ b/Analysis/include/Luau/ConstraintSolverLogger.h @@ -16,7 +16,7 @@ struct ConstraintSolverLogger { std::string compileOutput(); void captureBoundarySnapshot(const Scope* rootScope, std::vector>& unsolvedConstraints); - void prepareStepSnapshot(const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints); + void prepareStepSnapshot(const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints, bool force); void commitPreparedStepSnapshot(); private: diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 82df493..f8da327 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -166,7 +166,7 @@ private: static LintResult classifyLints(const std::vector& warnings, const Config& config); - ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config); + ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete = false); std::unordered_map environments; std::unordered_map> builtinDefinitions; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index a50fef7..eabbc2b 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -33,7 +33,8 @@ struct ToStringOptions bool indent = false; size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); - std::optional nameMap; + ToStringNameMap nameMap; + std::optional DEPRECATED_nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be 
overridden }; @@ -41,7 +42,7 @@ struct ToStringOptions struct ToStringResult { std::string name; - ToStringNameMap nameMap; + ToStringNameMap DEPRECATED_nameMap; bool invalid = false; bool error = false; @@ -49,12 +50,24 @@ struct ToStringResult bool truncated = false; }; -ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts = {}); -ToStringResult toStringDetailed(TypePackId ty, const ToStringOptions& opts = {}); +ToStringResult toStringDetailed(TypeId ty, ToStringOptions& opts); +ToStringResult toStringDetailed(TypePackId ty, ToStringOptions& opts); -std::string toString(TypeId ty, const ToStringOptions& opts); -std::string toString(TypePackId ty, const ToStringOptions& opts); -std::string toString(const Constraint& c, ToStringOptions& opts); +std::string toString(TypeId ty, ToStringOptions& opts); +std::string toString(TypePackId ty, ToStringOptions& opts); + +// These overloads are selected when a temporary ToStringOptions is passed. (eg +// via an initializer list) +inline std::string toString(TypePackId ty, ToStringOptions&& opts) +{ + // Delegate to the overload (TypePackId, ToStringOptions&) + return toString(ty, opts); +} +inline std::string toString(TypeId ty, ToStringOptions&& opts) +{ + // Delegate to the overload (TypeId, ToStringOptions&) + return toString(ty, opts); +} // These are offered as overloads rather than a default parameter so that they can be easily invoked from within the MSVC debugger. // You can use them in watch expressions! 
@@ -66,16 +79,42 @@ inline std::string toString(TypePackId ty) { return toString(ty, ToStringOptions{}); } -inline std::string toString(const Constraint& c) + +std::string toString(const Constraint& c, ToStringOptions& opts); + +inline std::string toString(const Constraint& c, ToStringOptions&& opts) { - ToStringOptions opts; return toString(c, opts); } -std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); -std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); +inline std::string toString(const Constraint& c) +{ + return toString(c, ToStringOptions{}); +} -std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, const ToStringOptions& opts = {}); + +std::string toString(const TypeVar& tv, ToStringOptions& opts); +std::string toString(const TypePackVar& tp, ToStringOptions& opts); + +inline std::string toString(const TypeVar& tv) +{ + ToStringOptions opts; + return toString(tv, opts); +} + +inline std::string toString(const TypePackVar& tp) +{ + ToStringOptions opts; + return toString(tp, opts); +} + +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts); + +inline std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv) +{ + ToStringOptions opts; + return toStringNamedFunction(funcName, ftv, opts); +} // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index cd115e3..016cc92 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -263,6 +263,8 @@ struct TxnLog return Luau::get_if(&ty->ty) != nullptr; } + std::pair, std::vector> getChanges() const; + private: // unique_ptr is used to give 
us stable pointers across insertions into the // map. Otherwise, it would be really easy to accidentally invalidate the diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 80f9085..e253edd 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -107,6 +107,7 @@ struct TypeChecker WithPredicate checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprError& expr); WithPredicate checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); + WithPredicate checkExpr(const ScopePtr& scope, const AstExprInterpString& expr); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); diff --git a/Analysis/src/AstJsonEncoder.cpp b/Analysis/src/AstJsonEncoder.cpp index 2897875..8d58903 100644 --- a/Analysis/src/AstJsonEncoder.cpp +++ b/Analysis/src/AstJsonEncoder.cpp @@ -445,6 +445,14 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(class AstExprInterpString* node) + { + writeNode(node, "AstExprInterpString", [&]() { + PROP(strings); + PROP(expressions); + }); + } + void write(class AstExprTable* node) { writeNode(node, "AstExprTable", [&]() { @@ -888,6 +896,12 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstExprInterpString* node) override + { + write(node); + return false; + } + bool visit(class AstExprLocal* node) override { write(node); diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index c1e54df..8f99474 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -210,7 +210,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) for (size_t i = 0; i < local->values.size; ++i) { - if (local->values.data[i]->is()) + AstExpr* value = 
local->values.data[i]; + if (value->is()) { // HACK: we leave nil-initialized things floating under the assumption that they will later be populated. // See the test TypeInfer/infer_locals_with_nil_value. @@ -218,7 +219,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } else if (i == local->values.size - 1) { - TypePackId exprPack = checkPack(scope, local->values.data[i]); + TypePackId exprPack = checkPack(scope, value); if (i < local->vars.size) { @@ -229,7 +230,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) } else { - TypeId exprType = check(scope, local->values.data[i]); + TypeId exprType = check(scope, value); if (i < varTypes.size()) addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); } @@ -1107,9 +1108,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b if (topLevel) { - addConstraint(scope, TypeAliasExpansionConstraint{ - /* target */ result, - }); + addConstraint(scope, TypeAliasExpansionConstraint{ /* target */ result }); } } } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 6c6d272..b2b1d47 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -11,6 +11,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauLogSolver, false); LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); +LUAU_FASTFLAG(LuauFixNameMaps) namespace Luau { @@ -19,9 +20,17 @@ namespace Luau { for (const auto& [k, v] : scope->bindings) { - auto d = toStringDetailed(v.typeId, opts); - opts.nameMap = d.nameMap; - printf("\t%s : %s\n", k.c_str(), d.name.c_str()); + if (FFlag::LuauFixNameMaps) + { + auto d = toString(v.typeId, opts); + printf("\t%s : %s\n", k.c_str(), d.c_str()); + } + else + { + auto d = toStringDetailed(v.typeId, opts); + opts.DEPRECATED_nameMap = d.DEPRECATED_nameMap; + printf("\t%s : %s\n", k.c_str(), d.name.c_str()); + } } for (NotNull child : scope->children) @@ -212,12 +221,22 @@ void 
dump(NotNull rootScope, ToStringOptions& opts) void dump(ConstraintSolver* cs, ToStringOptions& opts) { printf("constraints:\n"); - for (const Constraint* c : cs->unsolvedConstraints) + for (NotNull c : cs->unsolvedConstraints) { - printf("\t%s\n", toString(*c, opts).c_str()); + auto it = cs->blockedConstraints.find(c); + int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); + printf("\t%d\t%s\n", blockCount, toString(*c, opts).c_str()); - for (const Constraint* dep : c->dependencies) - printf("\t\t%s\n", toString(*dep, opts).c_str()); + for (NotNull dep : c->dependencies) + { + auto unsolvedIter = std::find(begin(cs->unsolvedConstraints), end(cs->unsolvedConstraints), dep); + if (unsolvedIter == cs->unsolvedConstraints.end()) + continue; + + auto it = cs->blockedConstraints.find(dep); + int blockCount = it == cs->blockedConstraints.end() ? 0 : int(it->second); + printf("\t%d\t\t%s\n", blockCount, toString(*dep, opts).c_str()); + } } } @@ -273,7 +292,7 @@ void ConstraintSolver::run() if (FFlag::DebugLuauLogSolverToJson) { - logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints); + logger.prepareStepSnapshot(rootScope, c, unsolvedConstraints, force); } bool success = tryDispatch(c, force); @@ -282,6 +301,7 @@ void ConstraintSolver::run() if (success) { + unblock(c); unsolvedConstraints.erase(unsolvedConstraints.begin() + i); if (FFlag::DebugLuauLogSolverToJson) @@ -375,18 +395,12 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNullscope); - unblock(c.subType); - unblock(c.superType); - return true; } bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull constraint, bool force) { unify(c.subPack, c.superPack, constraint->scope); - unblock(c.subPack); - unblock(c.superPack); - return true; } @@ -395,13 +409,12 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullty.emplace(c.sourceType); - else - unify(c.generalizedType, c.sourceType, constraint->scope); - TypeId 
generalized = quantify(arena, c.sourceType, constraint->scope); - *asMutable(c.sourceType) = *generalized; + + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(generalized); + else + unify(c.generalizedType, generalized, constraint->scope); unblock(c.generalizedType); unblock(c.sourceType); @@ -455,23 +468,44 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull + * + * This constraint is the one that is meant to unblock A, so it doesn't + * make any sense to stop and wait for someone else to do it. + */ + if (leftType != resultType && rightType != resultType) + { + block(c.leftType, constraint); + block(c.rightType, constraint); + return false; + } } if (isNumber(leftType)) { unify(leftType, rightType, constraint->scope); - asMutable(c.resultType)->ty.emplace(leftType); + asMutable(resultType)->ty.emplace(leftType); return true; } - if (get(leftType) && !force) - return block(leftType, constraint); + if (!force) + { + if (get(leftType)) + return block(leftType, constraint); + } + + if (isBlocked(leftType)) + { + asMutable(resultType)->ty.emplace(getSingletonTypes().errorRecoveryType()); + // reportError(constraint->location, CannotInferBinaryOperation{c.op, std::nullopt, CannotInferBinaryOperation::Operation}); + return true; + } // TODO metatables, classes @@ -706,17 +740,23 @@ void ConstraintSolver::block_(BlockedConstraintId target, NotNull target, NotNull constraint) { + if (FFlag::DebugLuauLogSolver) + printf("block Constraint %s on\t%s\n", toString(*target).c_str(), toString(*constraint).c_str()); block_(target, constraint); } bool ConstraintSolver::block(TypeId target, NotNull constraint) { + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", toString(target).c_str(), toString(*constraint).c_str()); block_(target, constraint); return false; } bool ConstraintSolver::block(TypePackId target, NotNull constraint) { + if (FFlag::DebugLuauLogSolver) + printf("block TypeId %s on\t%s\n", 
toString(target).c_str(), toString(*constraint).c_str()); block_(target, constraint); return false; } @@ -731,6 +771,9 @@ void ConstraintSolver::unblock_(BlockedConstraintId progressed) for (NotNull unblockedConstraint : it->second) { auto& count = blockedConstraints[unblockedConstraint]; + if (FFlag::DebugLuauLogSolver) + printf("Unblocking count=%d\t%s\n", int(count), toString(*unblockedConstraint).c_str()); + // This assertion being hit indicates that `blocked` and // `blockedConstraints` desynchronized at some point. This is problematic // because we rely on this count being correct to skip over blocked @@ -757,6 +800,18 @@ void ConstraintSolver::unblock(TypePackId progressed) return unblock_(progressed); } +void ConstraintSolver::unblock(const std::vector& types) +{ + for (TypeId t : types) + unblock(t); +} + +void ConstraintSolver::unblock(const std::vector& packs) +{ + for (TypePackId t : packs) + unblock(t); +} + bool ConstraintSolver::isBlocked(TypeId ty) { return nullptr != get(follow(ty)) || nullptr != get(follow(ty)); @@ -774,7 +829,13 @@ void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull sc Unifier u{arena, Mode::Strict, scope, Location{}, Covariant, sharedState}; u.tryUnify(subType, superType); + + const auto [changedTypes, changedPacks] = u.log.getChanges(); + u.log.commit(); + + unblock(changedTypes); + unblock(changedPacks); } void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) @@ -783,7 +844,13 @@ void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull scope) diff --git a/Analysis/src/ConstraintSolverLogger.cpp b/Analysis/src/ConstraintSolverLogger.cpp index adb9c54..097ceee 100644 --- a/Analysis/src/ConstraintSolverLogger.cpp +++ b/Analysis/src/ConstraintSolverLogger.cpp @@ -4,6 +4,8 @@ #include "Luau/JsonEmitter.h" +LUAU_FASTFLAG(LuauFixNameMaps); + namespace Luau { @@ -17,9 +19,14 @@ static void dumpScopeAndChildren(const Scope* scope, Json::JsonEmitter& emitter, for 
(const auto& [name, binding] : scope->bindings) { - ToStringResult result = toStringDetailed(binding.typeId, opts); - opts.nameMap = std::move(result.nameMap); - o.writePair(name.c_str(), result.name); + if (FFlag::LuauFixNameMaps) + o.writePair(name.c_str(), toString(binding.typeId, opts)); + else + { + ToStringResult result = toStringDetailed(binding.typeId, opts); + opts.DEPRECATED_nameMap = std::move(result.DEPRECATED_nameMap); + o.writePair(name.c_str(), result.name); + } } o.finish(); @@ -30,6 +37,7 @@ static void dumpScopeAndChildren(const Scope* scope, Json::JsonEmitter& emitter, Json::ArrayEmitter a = emitter.writeArray(); for (const Scope* child : scope->children) { + emitter.writeComma(); dumpScopeAndChildren(child, emitter, opts); } @@ -39,7 +47,8 @@ static void dumpScopeAndChildren(const Scope* scope, Json::JsonEmitter& emitter, static std::string dumpConstraintsToDot(std::vector>& constraints, ToStringOptions& opts) { - std::string result = "digraph Constraints {\\n"; + std::string result = "digraph Constraints {\n"; + result += "rankdir=LR\n"; std::unordered_set> contained; for (NotNull c : constraints) @@ -49,11 +58,19 @@ static std::string dumpConstraintsToDot(std::vector>& for (NotNull c : constraints) { + std::string shape; + if (get(*c)) + shape = "box"; + else if (get(*c)) + shape = "box3d"; + else + shape = "oval"; + std::string id = std::to_string(reinterpret_cast(c.get())); result += id; - result += " [label=\\\""; - result += toString(*c, opts).c_str(); - result += "\\\"];\\n"; + result += " [label=\""; + result += toString(*c, opts); + result += "\" shape=" + shape + "];\n"; for (NotNull dep : c->dependencies) { @@ -63,7 +80,7 @@ static std::string dumpConstraintsToDot(std::vector>& result += std::to_string(reinterpret_cast(dep.get())); result += " -> "; result += id; - result += ";\\n"; + result += ";\n"; } } @@ -102,7 +119,7 @@ void ConstraintSolverLogger::captureBoundarySnapshot(const Scope* rootScope, std } void 
ConstraintSolverLogger::prepareStepSnapshot( - const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints) + const Scope* rootScope, NotNull current, std::vector>& unsolvedConstraints, bool force) { Json::JsonEmitter emitter; Json::ObjectEmitter o = emitter.writeObject(); @@ -110,6 +127,7 @@ void ConstraintSolverLogger::prepareStepSnapshot( o.writePair("constraintGraph", dumpConstraintsToDot(unsolvedConstraints, opts)); o.writePair("currentId", std::to_string(reinterpret_cast(current.get()))); o.writePair("current", toString(*current, opts)); + o.writePair("force", force); emitter.writeComma(); Json::write(emitter, "rootScope"); emitter.writeRaw(":"); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 8ab4e86..c8c5d4b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -455,7 +455,7 @@ CheckResult Frontend::check(const ModuleName& name, std::optional& buildQueue, CheckResult& chec return cyclic; } -ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config) +ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) { - ScopePtr result = typeChecker.globalScope; + ScopePtr result; + if (forAutocomplete) + result = typeCheckerForAutocomplete.globalScope; + else + result = typeChecker.globalScope; if (module.environmentName) result = getEnvironmentScope(*module.environmentName); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 2d05837..426ff9d 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -15,6 +15,7 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) LUAU_FASTFLAGVARIABLE(LuauLintComparisonPrecedence, false) +LUAU_FASTFLAGVARIABLE(LuauLintFixDeprecationMessage, false) namespace Luau { @@ -206,6 +207,24 @@ static bool similar(AstExpr* lhs, AstExpr* rhs) return true; } CASE(AstExprIfElse) return similar(le->condition, 
re->condition) && similar(le->trueExpr, re->trueExpr) && similar(le->falseExpr, re->falseExpr); + CASE(AstExprInterpString) + { + if (le->strings.size != re->strings.size) + return false; + + if (le->expressions.size != re->expressions.size) + return false; + + for (size_t i = 0; i < le->strings.size; ++i) + if (le->strings.data[i].size != re->strings.data[i].size || memcmp(le->strings.data[i].data, re->strings.data[i].data, le->strings.data[i].size) != 0) + return false; + + for (size_t i = 0; i < le->expressions.size; ++i) + if (!similar(le->expressions.data[i], re->expressions.data[i])) + return false; + + return true; + } else { LUAU_ASSERT(!"Unknown expression type"); @@ -288,11 +307,22 @@ private: emitWarning(*context, LintWarning::Code_UnknownGlobal, gv->location, "Unknown global '%s'", gv->name.value); else if (g->deprecated) { - if (*g->deprecated) - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", - gv->name.value, *g->deprecated); + if (FFlag::LuauLintFixDeprecationMessage) + { + if (const char* replacement = *g->deprecated; replacement && strlen(replacement)) + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", + gv->name.value, replacement); + else + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); + } else - emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); + { + if (*g->deprecated) + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated, use '%s' instead", + gv->name.value, *g->deprecated); + else + emitWarning(*context, LintWarning::Code_DeprecatedGlobal, gv->location, "Global '%s' is deprecated", gv->name.value); + } } } diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp index 6b677bb..0d989ca 100644 --- 
a/Analysis/src/ToDot.cpp +++ b/Analysis/src/ToDot.cpp @@ -73,7 +73,7 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) { if (get(ty)) - formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); + formatAppend(result, "n%d [label=\"%s\"];\n", index, toString(ty).c_str()); else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); } @@ -233,7 +233,7 @@ void StateDot::visitChildren(TypeId ty, int index) } else if (get(ty)) { - formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); + formatAppend(result, "PrimitiveTypeVar %s", toString(ty).c_str()); finishNodeLabel(ty); finishNode(); } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index e31e690..ace44cd 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,6 +13,7 @@ LUAU_FASTFLAG(LuauLowerBoundsCalculation) LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauSpecialTypesAsterisked, false) +LUAU_FASTFLAGVARIABLE(LuauFixNameMaps, false) /* * Prefix generic typenames with gen- @@ -116,7 +117,7 @@ static std::pair> canUseTypeNameInScope(ScopePtr struct StringifierState { - const ToStringOptions& opts; + ToStringOptions& opts; ToStringResult& result; std::unordered_map cycleNames; @@ -127,18 +128,28 @@ struct StringifierState bool exhaustive; - StringifierState(const ToStringOptions& opts, ToStringResult& result, const std::optional& nameMap) + StringifierState(ToStringOptions& opts, ToStringResult& result, const std::optional& DEPRECATED_nameMap) : opts(opts) , result(result) , exhaustive(opts.exhaustive) { - if (nameMap) - result.nameMap = *nameMap; + if (!FFlag::LuauFixNameMaps && DEPRECATED_nameMap) + result.DEPRECATED_nameMap = *DEPRECATED_nameMap; - for (const auto& [_, v] : result.nameMap.typeVars) - usedNames.insert(v); - for (const auto& [_, v] : result.nameMap.typePacks) - usedNames.insert(v); 
+ if (!FFlag::LuauFixNameMaps) + { + for (const auto& [_, v] : result.DEPRECATED_nameMap.typeVars) + usedNames.insert(v); + for (const auto& [_, v] : result.DEPRECATED_nameMap.typePacks) + usedNames.insert(v); + } + else + { + for (const auto& [_, v] : opts.nameMap.typeVars) + usedNames.insert(v); + for (const auto& [_, v] : opts.nameMap.typePacks) + usedNames.insert(v); + } } bool hasSeen(const void* tv) @@ -161,8 +172,8 @@ struct StringifierState std::string getName(TypeId ty) { - const size_t s = result.nameMap.typeVars.size(); - std::string& n = result.nameMap.typeVars[ty]; + const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars.size() : result.DEPRECATED_nameMap.typeVars.size(); + std::string& n = FFlag::LuauFixNameMaps ? opts.nameMap.typeVars[ty] : result.DEPRECATED_nameMap.typeVars[ty]; if (!n.empty()) return n; @@ -184,8 +195,8 @@ struct StringifierState std::string getName(TypePackId ty) { - const size_t s = result.nameMap.typePacks.size(); - std::string& n = result.nameMap.typePacks[ty]; + const size_t s = FFlag::LuauFixNameMaps ? opts.nameMap.typePacks.size() : result.DEPRECATED_nameMap.typePacks.size(); + std::string& n = FFlag::LuauFixNameMaps ? 
opts.nameMap.typePacks[ty] : result.DEPRECATED_nameMap.typePacks[ty]; if (!n.empty()) return n; @@ -377,7 +388,10 @@ struct TypeVarStringifier if (gtv.explicitName) { state.usedNames.insert(gtv.name); - state.result.nameMap.typeVars[ty] = gtv.name; + if (FFlag::LuauFixNameMaps) + state.opts.nameMap.typeVars[ty] = gtv.name; + else + state.result.DEPRECATED_nameMap.typeVars[ty] = gtv.name; state.emit(gtv.name); } else @@ -987,7 +1001,10 @@ struct TypePackStringifier if (pack.explicitName) { state.usedNames.insert(pack.name); - state.result.nameMap.typePacks[tp] = pack.name; + if (FFlag::LuauFixNameMaps) + state.opts.nameMap.typePacks[tp] = pack.name; + else + state.result.DEPRECATED_nameMap.typePacks[tp] = pack.name; state.emit(pack.name); } else @@ -1066,7 +1083,7 @@ static void assignCycleNames(const std::set& cycles, const std::set cycles; std::set cycleTPs; @@ -1176,7 +1195,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) return result; } -ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) +ToStringResult toStringDetailed(TypePackId tp, ToStringOptions& opts) { /* * 1. Walk the TypeVar and track seen TypeIds. When you reencounter a TypeId, add it to a set of seen cycles. @@ -1185,7 +1204,9 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) * 4. Print out the root of the type using the same algorithm as step 3. */ ToStringResult result; - StringifierState state{opts, result, opts.nameMap}; + StringifierState state = FFlag::LuauFixNameMaps + ? 
StringifierState{opts, result, opts.nameMap} + : StringifierState{opts, result, opts.DEPRECATED_nameMap}; std::set cycles; std::set cycleTPs; @@ -1248,30 +1269,32 @@ ToStringResult toStringDetailed(TypePackId tp, const ToStringOptions& opts) return result; } -std::string toString(TypeId ty, const ToStringOptions& opts) +std::string toString(TypeId ty, ToStringOptions& opts) { return toStringDetailed(ty, opts).name; } -std::string toString(TypePackId tp, const ToStringOptions& opts) +std::string toString(TypePackId tp, ToStringOptions& opts) { return toStringDetailed(tp, opts).name; } -std::string toString(const TypeVar& tv, const ToStringOptions& opts) +std::string toString(const TypeVar& tv, ToStringOptions& opts) { - return toString(const_cast(&tv), std::move(opts)); + return toString(const_cast(&tv), opts); } -std::string toString(const TypePackVar& tp, const ToStringOptions& opts) +std::string toString(const TypePackVar& tp, ToStringOptions& opts) { - return toString(const_cast(&tp), std::move(opts)); + return toString(const_cast(&tp), opts); } -std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, const ToStringOptions& opts) +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, ToStringOptions& opts) { ToStringResult result; - StringifierState state(opts, result, opts.nameMap); + StringifierState state = FFlag::LuauFixNameMaps + ? 
StringifierState{opts, result, opts.nameMap} + : StringifierState{opts, result, opts.DEPRECATED_nameMap}; TypeVarStringifier tvs{state}; state.emit(funcName); @@ -1403,69 +1426,67 @@ std::string toString(const Constraint& constraint, ToStringOptions& opts) auto go = [&opts](auto&& c) { using T = std::decay_t; + // TODO: Inline and delete this function when clipping FFlag::LuauFixNameMaps + auto tos = [](auto&& a, ToStringOptions& opts) + { + if (FFlag::LuauFixNameMaps) + return toString(a, opts); + else + { + ToStringResult tsr = toStringDetailed(a, opts); + opts.DEPRECATED_nameMap = std::move(tsr.DEPRECATED_nameMap); + return tsr.name; + } + }; + if constexpr (std::is_same_v) { - ToStringResult subStr = toStringDetailed(c.subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(c.superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; + std::string subStr = tos(c.subType, opts); + std::string superStr = tos(c.superType, opts); + return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - ToStringResult subStr = toStringDetailed(c.subPack, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(c.superPack, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " <: " + superStr.name; + std::string subStr = tos(c.subPack, opts); + std::string superStr = tos(c.superPack, opts); + return subStr + " <: " + superStr; } else if constexpr (std::is_same_v) { - ToStringResult subStr = toStringDetailed(c.generalizedType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(c.sourceType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ gen " + superStr.name; + std::string subStr = tos(c.generalizedType, opts); + std::string superStr = tos(c.sourceType, opts); + return subStr + " ~ gen " + superStr; } else if constexpr 
(std::is_same_v) { - ToStringResult subStr = toStringDetailed(c.subType, opts); - opts.nameMap = std::move(subStr.nameMap); - ToStringResult superStr = toStringDetailed(c.superType, opts); - opts.nameMap = std::move(superStr.nameMap); - return subStr.name + " ~ inst " + superStr.name; + std::string subStr = tos(c.subType, opts); + std::string superStr = tos(c.superType, opts); + return subStr + " ~ inst " + superStr; } else if constexpr (std::is_same_v) { - ToStringResult resultStr = toStringDetailed(c.resultType, opts); - opts.nameMap = std::move(resultStr.nameMap); - ToStringResult operandStr = toStringDetailed(c.operandType, opts); - opts.nameMap = std::move(operandStr.nameMap); + std::string resultStr = tos(c.resultType, opts); + std::string operandStr = tos(c.operandType, opts); - return resultStr.name + " ~ Unary<" + toString(c.op) + ", " + operandStr.name + ">"; + return resultStr + " ~ Unary<" + toString(c.op) + ", " + operandStr + ">"; } else if constexpr (std::is_same_v) { - ToStringResult resultStr = toStringDetailed(c.resultType); - opts.nameMap = std::move(resultStr.nameMap); - ToStringResult leftStr = toStringDetailed(c.leftType); - opts.nameMap = std::move(leftStr.nameMap); - ToStringResult rightStr = toStringDetailed(c.rightType); - opts.nameMap = std::move(rightStr.nameMap); + std::string resultStr = tos(c.resultType, opts); + std::string leftStr = tos(c.leftType, opts); + std::string rightStr = tos(c.rightType, opts); - return resultStr.name + " ~ Binary<" + toString(c.op) + ", " + leftStr.name + ", " + rightStr.name + ">"; + return resultStr + " ~ Binary<" + toString(c.op) + ", " + leftStr + ", " + rightStr + ">"; } else if constexpr (std::is_same_v) { - ToStringResult namedStr = toStringDetailed(c.namedType, opts); - opts.nameMap = std::move(namedStr.nameMap); - return "@name(" + namedStr.name + ") = " + c.name; + std::string namedStr = tos(c.namedType, opts); + return "@name(" + namedStr + ") = " + c.name; } else if constexpr (std::is_same_v) { 
- ToStringResult targetStr = toStringDetailed(c.target, opts); - opts.nameMap = std::move(targetStr.nameMap); - return "expand " + targetStr.name; + std::string targetStr = tos(c.target, opts); + return "expand " + targetStr; } else static_assert(always_false_v, "Non-exhaustive constraint switch"); diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 9feff1c..cdfe654 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -511,6 +511,28 @@ struct Printer writer.keyword("else"); visualize(*a->falseExpr); } + else if (const auto& a = expr.as()) + { + writer.symbol("`"); + + size_t index = 0; + + for (const auto& string : a->strings) + { + writer.write(escape(std::string_view(string.data, string.size), /* escapeForInterpString = */ true)); + + if (index < a->expressions.size) + { + writer.symbol("{"); + visualize(*a->expressions.data[index]); + writer.symbol("}"); + } + + index++; + } + + writer.symbol("`"); + } else if (const auto& a = expr.as()) { writer.symbol("(error-expr"); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index b3f60d3..74d7730 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -344,4 +344,16 @@ TypePackId TxnLog::follow(TypePackId tp) const }); } +std::pair, std::vector> TxnLog::getChanges() const +{ + std::pair, std::vector> result; + + for (const auto& [typeId, _newState] : typeVarChanges) + result.first.push_back(typeId); + for (const auto& [typePackId, _newState] : typePackChanges) + result.second.push_back(typePackId); + + return result; +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 9886fb1..7716805 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -1805,6 +1805,8 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp result = checkExpr(scope, *a); else if (auto a = expr.as()) result = checkExpr(scope, *a, expectedType); + else if (auto a = expr.as()) + result = 
checkExpr(scope, *a); else ice("Unhandled AstExpr?"); @@ -2999,6 +3001,14 @@ WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExp return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; } +WithPredicate TypeChecker::checkExpr(const ScopePtr& scope, const AstExprInterpString& expr) +{ + for (AstExpr* expr : expr.expressions) + checkExpr(scope, *expr); + + return {stringType}; +} + TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) { return checkLValueBinding(scope, expr); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 9020b1a..8974f8c 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAGVARIABLE(LuauDeduceGmatchReturnTypes, false) LUAU_FASTFLAGVARIABLE(LuauMaybeGenericIntersectionTypes, false) LUAU_FASTFLAGVARIABLE(LuauDeduceFindMatchReturnTypes, false) +LUAU_FASTFLAGVARIABLE(LuauStringFormatArgumentErrorFix, false) namespace Luau { @@ -1139,11 +1140,21 @@ std::optional> magicFunctionFormat( } // if we know the argument count or if we have too many arguments for sure, we can issue an error - size_t actualParamSize = params.size() - paramOffset; + if (FFlag::LuauStringFormatArgumentErrorFix) + { + size_t numActualParams = params.size(); + size_t numExpectedParams = expected.size() + 1; // + 1 for the format string - if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) - typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); + if (numExpectedParams != numActualParams && (!tail || numExpectedParams < numActualParams)) + typechecker.reportError(TypeError{expr.location, CountMismatch{numExpectedParams, numActualParams}}); + } + else + { + size_t actualParamSize = params.size() - paramOffset; + if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) + 
typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); + } return WithPredicate{arena.addTypePack({typechecker.stringType})}; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 1e164d0..612283f 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -134,6 +134,10 @@ public: { return visit((class AstExpr*)node); } + virtual bool visit(class AstExprInterpString* node) + { + return visit((class AstExpr*)node); + } virtual bool visit(class AstExprError* node) { return visit((class AstExpr*)node); @@ -594,9 +598,9 @@ public: LUAU_RTTI(AstExprFunction) AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, - std::optional argLocation = std::nullopt); + AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, const std::optional& returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, + bool hasEnd = false, const std::optional& argLocation = std::nullopt); void visit(AstVisitor* visitor) override; @@ -732,6 +736,22 @@ public: AstExpr* falseExpr; }; +class AstExprInterpString : public AstExpr +{ +public: + LUAU_RTTI(AstExprInterpString) + + AstExprInterpString(const Location& location, const AstArray>& strings, const AstArray& expressions); + + void visit(AstVisitor* visitor) override; + + /// An interpolated string such as `foo{bar}baz` is represented as + /// an array of strings for "foo" and "baz", and an array of expressions for "bar". + /// `strings` will always have one more element than `expressions`. 
+ AstArray> strings; + AstArray expressions; +}; + class AstStatBlock : public AstStat { public: diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 4f3dbbd..7e7fe76 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -61,6 +61,12 @@ struct Lexeme SkinnyArrow, DoubleColon, + InterpStringBegin, + InterpStringMid, + InterpStringEnd, + // An interpolated string with no expressions (like `x`) + InterpStringSimple, + AddAssign, SubAssign, MulAssign, @@ -80,6 +86,8 @@ struct Lexeme BrokenString, BrokenComment, BrokenUnicode, + BrokenInterpDoubleBrace, + Error, Reserved_BEGIN, @@ -208,6 +216,11 @@ private: Lexeme readLongString(const Position& start, int sep, Lexeme::Type ok, Lexeme::Type broken); Lexeme readQuotedString(); + Lexeme readInterpolatedStringBegin(); + Lexeme readInterpolatedStringSection(Position start, Lexeme::Type formatType, Lexeme::Type endType); + + void readBackslashInString(); + std::pair readName(); Lexeme readNumber(const Position& start, unsigned int startOffset); @@ -231,6 +244,14 @@ private: bool skipComments; bool readNames; + + enum class BraceType + { + InterpolatedString, + Normal + }; + + std::vector braceStack; }; inline bool isSpace(char ch) diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 046706d..956fcf6 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -11,6 +11,7 @@ #include #include +#include namespace Luau { @@ -109,8 +110,10 @@ private: // for namelist in explist do block end | AstStat* parseFor(); - // function funcname funcbody | // funcname ::= Name {`.' 
Name} [`:' Name] + AstExpr* parseFunctionName(Location start, bool& hasself, AstName& debugname); + + // function funcname funcbody AstStat* parseFunctionStat(); // local function Name funcbody | @@ -135,8 +138,10 @@ private: // var [`+=' | `-=' | `*=' | `/=' | `%=' | `^=' | `..='] exp AstStat* parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op); - // funcbody ::= `(' [parlist] `)' block end - // parlist ::= namelist [`,' `...'] | `...' + std::pair> prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args); + + // funcbodyhead ::= `(' [namelist [`,' `...'] | `...'] `)' [`:` TypeAnnotation] + // funcbody ::= funcbodyhead block end std::pair parseFunctionBody( bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName); @@ -148,7 +153,7 @@ private: // bindinglist ::= (binding | `...') {`,' bindinglist} // Returns the location of the vararg ..., or std::nullopt if the function is not vararg. - std::pair, AstTypePack*> parseBindingList(TempVector& result, bool allowDot3 = false); + std::tuple parseBindingList(TempVector& result, bool allowDot3 = false); AstType* parseOptionalTypeAnnotation(); @@ -228,6 +233,9 @@ private: // TODO: Add grammar rules here? 
AstExpr* parseIfElseExpr(); + // stringinterp ::= InterpStringBegin exp {InterpStringMid exp} InterpStringEnd + AstExpr* parseInterpString(); + // Name std::optional parseNameOpt(const char* context = nullptr); Name parseName(const char* context = nullptr); @@ -379,6 +387,7 @@ private: std::vector matchRecoveryStopOnToken; std::vector scratchStat; + std::vector> scratchString; std::vector scratchExpr; std::vector scratchExprAux; std::vector scratchName; diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 6ae9e97..dab7610 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -35,6 +35,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs); size_t hashRange(const char* data, size_t size); -std::string escape(std::string_view s); +std::string escape(std::string_view s, bool escapeForInterpString = false); bool isIdentifier(std::string_view s); } // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 3066b75..8291a5b 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -160,17 +160,17 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, - AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, - const AstName& debugname, std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, - std::optional argLocation) + AstLocal* self, const AstArray& args, bool vararg, const Location& varargLocation, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, const std::optional& returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, + const std::optional& argLocation) : AstExpr(ClassIndex(), location) , generics(generics) , genericPacks(genericPacks) , self(self) , args(args) , returnAnnotation(returnAnnotation) - , vararg(vararg.has_value()) - , varargLocation(vararg.value_or(Location())) + , vararg(vararg) + , varargLocation(varargLocation) , 
varargAnnotation(varargAnnotation) , body(body) , functionDepth(functionDepth) @@ -349,6 +349,22 @@ AstExprError::AstExprError(const Location& location, const AstArray& e { } +AstExprInterpString::AstExprInterpString(const Location& location, const AstArray>& strings, const AstArray& expressions) + : AstExpr(ClassIndex(), location) + , strings(strings) + , expressions(expressions) +{ +} + +void AstExprInterpString::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstExpr* expr : expressions) + expr->visit(visitor); + } +} + void AstExprError::visit(AstVisitor* visitor) { if (visitor->visit(this)) diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a1f1d46..b4db8bd 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -6,6 +6,8 @@ #include +LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) + namespace Luau { @@ -89,7 +91,18 @@ Lexeme::Lexeme(const Location& location, Type type, const char* data, size_t siz , length(unsigned(size)) , data(data) { - LUAU_ASSERT(type == RawString || type == QuotedString || type == Number || type == Comment || type == BlockComment); + LUAU_ASSERT( + type == RawString + || type == QuotedString + || type == InterpStringBegin + || type == InterpStringMid + || type == InterpStringEnd + || type == InterpStringSimple + || type == BrokenInterpDoubleBrace + || type == Number + || type == Comment + || type == BlockComment + ); } Lexeme::Lexeme(const Location& location, Type type, const char* name) @@ -160,6 +173,18 @@ std::string Lexeme::toString() const case QuotedString: return data ? format("\"%.*s\"", length, data) : "string"; + case InterpStringBegin: + return data ? format("`%.*s{", length, data) : "the beginning of an interpolated string"; + + case InterpStringMid: + return data ? format("}%.*s{", length, data) : "the middle of an interpolated string"; + + case InterpStringEnd: + return data ? format("}%.*s`", length, data) : "the end of an interpolated string"; + + case InterpStringSimple: + return data ? 
format("`%.*s`", length, data) : "interpolated string"; + case Number: return data ? format("'%.*s'", length, data) : "number"; @@ -175,6 +200,9 @@ std::string Lexeme::toString() const case BrokenComment: return "unfinished comment"; + case BrokenInterpDoubleBrace: + return "'{{', which is invalid (did you mean '\\{'?)"; + case BrokenUnicode: if (codepoint) { @@ -515,6 +543,32 @@ Lexeme Lexer::readLongString(const Position& start, int sep, Lexeme::Type ok, Le return Lexeme(Location(start, position()), broken); } +void Lexer::readBackslashInString() +{ + LUAU_ASSERT(peekch() == '\\'); + consume(); + switch (peekch()) + { + case '\r': + consume(); + if (peekch() == '\n') + consume(); + break; + + case 0: + break; + + case 'z': + consume(); + while (isSpace(peekch())) + consume(); + break; + + default: + consume(); + } +} + Lexeme Lexer::readQuotedString() { Position start = position(); @@ -535,27 +589,7 @@ Lexeme Lexer::readQuotedString() return Lexeme(Location(start, position()), Lexeme::BrokenString); case '\\': - consume(); - switch (peekch()) - { - case '\r': - consume(); - if (peekch() == '\n') - consume(); - break; - - case 0: - break; - - case 'z': - consume(); - while (isSpace(peekch())) - consume(); - break; - - default: - consume(); - } + readBackslashInString(); break; default: @@ -568,6 +602,69 @@ Lexeme Lexer::readQuotedString() return Lexeme(Location(start, position()), Lexeme::QuotedString, &buffer[startOffset], offset - startOffset - 1); } +Lexeme Lexer::readInterpolatedStringBegin() +{ + LUAU_ASSERT(peekch() == '`'); + + Position start = position(); + consume(); + + return readInterpolatedStringSection(start, Lexeme::InterpStringBegin, Lexeme::InterpStringSimple); +} + +Lexeme Lexer::readInterpolatedStringSection(Position start, Lexeme::Type formatType, Lexeme::Type endType) +{ + unsigned int startOffset = offset; + + while (peekch() != '`') + { + switch (peekch()) + { + case 0: + case '\r': + case '\n': + return Lexeme(Location(start, position()), 
Lexeme::BrokenString); + + case '\\': + // Allow for \u{}, which would otherwise be consumed by looking for { + if (peekch(1) == 'u' && peekch(2) == '{') + { + consume(); // backslash + consume(); // u + consume(); // { + break; + } + + readBackslashInString(); + break; + + case '{': + { + braceStack.push_back(BraceType::InterpolatedString); + + if (peekch(1) == '{') + { + Lexeme brokenDoubleBrace = Lexeme(Location(start, position()), Lexeme::BrokenInterpDoubleBrace, &buffer[startOffset], offset - startOffset); + consume(); + consume(); + return brokenDoubleBrace; + } + + Lexeme lexemeOutput(Location(start, position()), Lexeme::InterpStringBegin, &buffer[startOffset], offset - startOffset); + consume(); + return lexemeOutput; + } + + default: + consume(); + } + } + + consume(); + + return Lexeme(Location(start, position()), endType, &buffer[startOffset], offset - startOffset - 1); +} + Lexeme Lexer::readNumber(const Position& start, unsigned int startOffset) { LUAU_ASSERT(isDigit(peekch())); @@ -660,6 +757,36 @@ Lexeme Lexer::readNext() } } + case '{': + { + consume(); + + if (!braceStack.empty()) + braceStack.push_back(BraceType::Normal); + + return Lexeme(Location(start, 1), '{'); + } + + case '}': + { + consume(); + + if (braceStack.empty()) + { + return Lexeme(Location(start, 1), '}'); + } + + const BraceType braceStackTop = braceStack.back(); + braceStack.pop_back(); + + if (braceStackTop != BraceType::InterpolatedString) + { + return Lexeme(Location(start, 1), '}'); + } + + return readInterpolatedStringSection(position(), Lexeme::InterpStringMid, Lexeme::InterpStringEnd); + } + case '=': { consume(); @@ -716,6 +843,15 @@ Lexeme Lexer::readNext() case '\'': return readQuotedString(); + case '`': + if (FFlag::LuauInterpolatedStringBaseSupport) + return readInterpolatedStringBegin(); + else + { + consume(); + return Lexeme(Location(start, 1), '`'); + } + case '.': consume(); @@ -817,8 +953,6 @@ Lexeme Lexer::readNext() case '(': case ')': - case '{': - case '}': 
case ']': case ';': case ',': diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index e46eebf..b6de27d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -23,10 +23,14 @@ LUAU_FASTFLAGVARIABLE(LuauErrorDoubleHexPrefix, false) LUAU_FASTFLAGVARIABLE(LuauLintParseIntegerIssues, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseIntegerIssues, false) +LUAU_FASTFLAGVARIABLE(LuauInterpolatedStringBaseSupport, false) + bool lua_telemetry_parsed_out_of_range_bin_integer = false; bool lua_telemetry_parsed_out_of_range_hex_integer = false; bool lua_telemetry_parsed_double_prefix_hex_integer = false; +#define ERROR_INVALID_INTERP_DOUBLE_BRACE "Double braces are not permitted within interpolated strings. Did you mean '\\{'?" + namespace Luau { @@ -601,16 +605,11 @@ AstStat* Parser::parseFor() } } -// function funcname funcbody | // funcname ::= Name {`.' Name} [`:' Name] -AstStat* Parser::parseFunctionStat() +AstExpr* Parser::parseFunctionName(Location start, bool& hasself, AstName& debugname) { - Location start = lexer.current().location; - - Lexeme matchFunction = lexer.current(); - nextLexeme(); - - AstName debugname = (lexer.current().type == Lexeme::Name) ? 
AstName(lexer.current().name) : AstName(); + if (lexer.current().type == Lexeme::Name) + debugname = AstName(lexer.current().name); // parse funcname into a chain of indexing operators AstExpr* expr = parseNameExpr("function name"); @@ -636,8 +635,6 @@ AstStat* Parser::parseFunctionStat() recursionCounter = recursionCounterOld; // finish with : - bool hasself = false; - if (lexer.current().type == ':') { Position opPosition = lexer.current().location.begin; @@ -653,6 +650,21 @@ AstStat* Parser::parseFunctionStat() hasself = true; } + return expr; +} + +// function funcname funcbody +AstStat* Parser::parseFunctionStat() +{ + Location start = lexer.current().location; + + Lexeme matchFunction = lexer.current(); + nextLexeme(); + + bool hasself = false; + AstName debugname; + AstExpr* expr = parseFunctionName(start, hasself, debugname); + matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first; @@ -781,10 +793,11 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() TempVector args(scratchBinding); - std::optional vararg = std::nullopt; + bool vararg = false; + Location varargLocation; AstTypePack* varargAnnotation = nullptr; if (lexer.current().type != ')') - std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3 */ true); + std::tie(vararg, varargLocation, varargAnnotation) = parseBindingList(args, /* allowDot3 */ true); expectMatchAndConsume(')', matchParen); @@ -838,11 +851,12 @@ AstStat* Parser::parseDeclaration(const Location& start) TempVector args(scratchBinding); - std::optional vararg; + bool vararg = false; + Location varargLocation; AstTypePack* varargAnnotation = nullptr; if (lexer.current().type != ')') - std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + std::tie(vararg, varargLocation, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); expectMatchAndConsume(')', matchParen); @@ -965,6 +979,21 @@ 
AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op) return allocator.alloc(Location(initial->location, value->location), op, initial, value); } +std::pair> Parser::prepareFunctionArguments(const Location& start, bool hasself, const TempVector& args) +{ + AstLocal* self = nullptr; + + if (hasself) + self = pushLocal(Binding(Name(nameSelf, start), nullptr)); + + TempVector vars(scratchLocal); + + for (size_t i = 0; i < args.size(); ++i) + vars.push_back(pushLocal(args[i])); + + return {self, copy(vars)}; +} + // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // parlist ::= bindinglist [`,' `...'] | `...' std::pair Parser::parseFunctionBody( @@ -979,15 +1008,18 @@ std::pair Parser::parseFunctionBody( TempVector args(scratchBinding); - std::optional vararg; + bool vararg = false; + Location varargLocation; AstTypePack* varargAnnotation = nullptr; if (lexer.current().type != ')') - std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + std::tie(vararg, varargLocation, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); + + std::optional argLocation; + + if (matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')')) + argLocation = Location(matchParen.position, lexer.current().location.end); - std::optional argLocation = matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')') - ? 
std::make_optional(Location(matchParen.position, lexer.current().location.end)) - : std::nullopt; expectMatchAndConsume(')', matchParen, true); std::optional typelist = parseOptionalReturnTypeAnnotation(); @@ -1000,19 +1032,11 @@ std::pair Parser::parseFunctionBody( unsigned int localsBegin = saveLocals(); Function fun; - fun.vararg = vararg.has_value(); + fun.vararg = vararg; - functionStack.push_back(fun); + functionStack.emplace_back(fun); - AstLocal* self = nullptr; - - if (hasself) - self = pushLocal(Binding(Name(nameSelf, start), nullptr)); - - TempVector vars(scratchLocal); - - for (size_t i = 0; i < args.size(); ++i) - vars.push_back(pushLocal(args[i])); + auto [self, vars] = prepareFunctionArguments(start, hasself, args); AstStatBlock* body = parseBlock(); @@ -1024,8 +1048,8 @@ std::pair Parser::parseFunctionBody( bool hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchFunction); - return {allocator.alloc(Location(start, end), generics, genericPacks, self, copy(vars), vararg, body, functionStack.size(), - debugname, typelist, varargAnnotation, hasEnd, argLocation), + return {allocator.alloc(Location(start, end), generics, genericPacks, self, vars, vararg, varargLocation, body, + functionStack.size(), debugname, typelist, varargAnnotation, hasEnd, argLocation), funLocal}; } @@ -1056,7 +1080,7 @@ Parser::Binding Parser::parseBinding() } // bindinglist ::= (binding | `...') [`,' bindinglist] -std::pair, AstTypePack*> Parser::parseBindingList(TempVector& result, bool allowDot3) +std::tuple Parser::parseBindingList(TempVector& result, bool allowDot3) { while (true) { @@ -1072,7 +1096,7 @@ std::pair, AstTypePack*> Parser::parseBindingList(TempVe tailAnnotation = parseVariadicArgumentAnnotation(); } - return {varargLocation, tailAnnotation}; + return {true, varargLocation, tailAnnotation}; } result.push_back(parseBinding()); @@ -1082,7 +1106,7 @@ std::pair, AstTypePack*> Parser::parseBindingList(TempVe nextLexeme(); } - return {std::nullopt, nullptr}; + 
return {false, Location(), nullptr}; } AstType* Parser::parseOptionalTypeAnnotation() @@ -1567,6 +1591,12 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) else return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; } + else if (lexer.current().type == Lexeme::InterpStringBegin || lexer.current().type == Lexeme::InterpStringSimple) + { + parseInterpString(); + + return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "Interpolated string literals cannot be used as types")}; + } else if (lexer.current().type == Lexeme::BrokenString) { Location location = lexer.current().location; @@ -2215,15 +2245,24 @@ AstExpr* Parser::parseSimpleExpr() { return parseNumber(); } - else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString || (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringSimple)) { return parseString(); } + else if (FFlag::LuauInterpolatedStringBaseSupport && lexer.current().type == Lexeme::InterpStringBegin) + { + return parseInterpString(); + } else if (lexer.current().type == Lexeme::BrokenString) { nextLexeme(); return reportExprError(start, {}, "Malformed string"); } + else if (lexer.current().type == Lexeme::BrokenInterpDoubleBrace) + { + nextLexeme(); + return reportExprError(start, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + } else if (lexer.current().type == Lexeme::Dot3) { if (functionStack.back().vararg) @@ -2614,11 +2653,11 @@ AstArray Parser::parseTypeParams() std::optional> Parser::parseCharArray() { - LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); + LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::InterpStringSimple); 
scratchData.assign(lexer.current().data, lexer.current().length); - if (lexer.current().type == Lexeme::QuotedString) + if (lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::InterpStringSimple) { if (!Lexer::fixupQuotedString(scratchData)) { @@ -2645,6 +2684,70 @@ AstExpr* Parser::parseString() return reportExprError(location, {}, "String literal contains malformed escape sequence"); } +AstExpr* Parser::parseInterpString() +{ + TempVector> strings(scratchString); + TempVector expressions(scratchExpr); + + Location startLocation = lexer.current().location; + + do { + Lexeme currentLexeme = lexer.current(); + LUAU_ASSERT( + currentLexeme.type == Lexeme::InterpStringBegin + || currentLexeme.type == Lexeme::InterpStringMid + || currentLexeme.type == Lexeme::InterpStringEnd + || currentLexeme.type == Lexeme::InterpStringSimple + ); + + Location location = currentLexeme.location; + + Location startOfBrace = Location(location.end, 1); + + scratchData.assign(currentLexeme.data, currentLexeme.length); + + if (!Lexer::fixupQuotedString(scratchData)) + { + nextLexeme(); + return reportExprError(startLocation, {}, "Interpolated string literal contains malformed escape sequence"); + } + + AstArray chars = copy(scratchData); + + nextLexeme(); + + strings.push_back(chars); + + if (currentLexeme.type == Lexeme::InterpStringEnd || currentLexeme.type == Lexeme::InterpStringSimple) + { + AstArray> stringsArray = copy(strings); + AstArray expressionsArray = copy(expressions); + + return allocator.alloc(startLocation, stringsArray, expressionsArray); + } + + AstExpr* expression = parseExpr(); + + expressions.push_back(expression); + + switch (lexer.current().type) + { + case Lexeme::InterpStringBegin: + case Lexeme::InterpStringMid: + case Lexeme::InterpStringEnd: + break; + case Lexeme::BrokenInterpDoubleBrace: + nextLexeme(); + return reportExprError(location, {}, ERROR_INVALID_INTERP_DOUBLE_BRACE); + case Lexeme::BrokenString: + nextLexeme(); + return 
reportExprError(location, {}, "Malformed interpolated string, did you forget to add a '}'?"); + default: + return reportExprError(location, {}, "Malformed interpolated string, got %s", lexer.current().toString().c_str()); + } + } while (true); +} + AstExpr* Parser::parseNumber() { Location start = lexer.current().location; diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 0dc3f3f..11e0076 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -230,19 +230,25 @@ bool isIdentifier(std::string_view s) return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos); } -std::string escape(std::string_view s) +std::string escape(std::string_view s, bool escapeForInterpString) { std::string r; r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting for (uint8_t c : s) { - if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"' && c != '`' && c != '{') r += c; else { r += '\\'; + if (escapeForInterpString && (c == '`' || c == '{')) + { + r += c; + continue; + } + switch (c) { case '\a': diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index 71e76ff..809c78d 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -12,6 +12,7 @@ inline bool isFlagExperimental(const char* flag) // or critical bugs that are found after the code has been submitted. 
static const char* kList[] = { "LuauLowerBoundsCalculation", + "LuauInterpolatedStringBaseSupport", nullptr, // makes sure we always have at least one entry }; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index bd8744c..4429e4c 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -14,6 +14,8 @@ #include #include +#include + #include LUAU_FASTINTVARIABLE(LuauCompileLoopUnrollThreshold, 25) @@ -25,6 +27,8 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTFLAGVARIABLE(LuauCompileXEQ, false) +LUAU_FASTFLAG(LuauInterpolatedStringBaseSupport) + LUAU_FASTFLAGVARIABLE(LuauCompileOptimalAssignment, false) LUAU_FASTFLAGVARIABLE(LuauCompileExtractK, false) @@ -1585,6 +1589,76 @@ struct Compiler } } + void compileExprInterpString(AstExprInterpString* expr, uint8_t target, bool targetTemp) + { + size_t formatCapacity = 0; + for (AstArray string : expr->strings) + { + formatCapacity += string.size + std::count(string.data, string.data + string.size, '%'); + } + + std::string formatString; + formatString.reserve(formatCapacity); + + size_t stringsLeft = expr->strings.size; + + for (AstArray string : expr->strings) + { + if (memchr(string.data, '%', string.size)) + { + for (size_t characterIndex = 0; characterIndex < string.size; ++characterIndex) + { + char character = string.data[characterIndex]; + formatString.push_back(character); + + if (character == '%') + formatString.push_back('%'); + } + } + else + formatString.append(string.data, string.size); + + stringsLeft--; + + if (stringsLeft > 0) + formatString += "%*"; + } + + size_t formatStringSize = formatString.size(); + + // We can't use formatStringRef.data() directly, because short strings don't have their data + // pinned in memory, so when interpFormatStrings grows, these pointers will move and become invalid. 
+ std::unique_ptr formatStringPtr(new char[formatStringSize]); + memcpy(formatStringPtr.get(), formatString.data(), formatStringSize); + + AstArray formatStringArray{formatStringPtr.get(), formatStringSize}; + interpStrings.emplace_back(std::move(formatStringPtr)); // invalidates formatStringPtr, but keeps formatStringArray intact + + int32_t formatStringIndex = bytecode.addConstantString(sref(formatStringArray)); + if (formatStringIndex < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + RegScope rs(this); + + uint8_t baseReg = allocReg(expr, uint8_t(2 + expr->expressions.size)); + + emitLoadK(baseReg, formatStringIndex); + + for (size_t index = 0; index < expr->expressions.size; ++index) + compileExprTempTop(expr->expressions.data[index], uint8_t(baseReg + 2 + index)); + + BytecodeBuilder::StringRef formatMethod = sref(AstName("format")); + + int32_t formatMethodIndex = bytecode.addConstantString(formatMethod); + if (formatMethodIndex < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + bytecode.emitABC(LOP_NAMECALL, baseReg, baseReg, uint8_t(BytecodeBuilder::getStringHash(formatMethod))); + bytecode.emitAux(formatMethodIndex); + bytecode.emitABC(LOP_CALL, baseReg, uint8_t(expr->expressions.size + 2), 2); + bytecode.emitABC(LOP_MOVE, target, baseReg, 0); + } + static uint8_t encodeHashSize(unsigned int hashSize) { size_t hashSizeLog2 = 0; @@ -2059,6 +2133,10 @@ struct Compiler { compileExprIfElse(expr, target, targetTemp); } + else if (AstExprInterpString* interpString = node->as(); FFlag::LuauInterpolatedStringBaseSupport && interpString) + { + compileExprInterpString(interpString, target, targetTemp); + } else { LUAU_ASSERT(!"Unknown expression type"); @@ -2965,6 +3043,18 @@ struct Compiler uint8_t valueReg = kInvalidReg; }; + // This function analyzes assignments and marks assignment conflicts: cases when a variable is assigned on lhs + // but subsequently 
used on the rhs, assuming assignments are performed in order. Note that it's also possible + // for a variable to conflict on the lhs, if it's used in an lvalue expression after it's assigned. + // When conflicts are found, Assignment::conflictReg is allocated and that's where assignment is performed instead, + // until the final fixup in compileStatAssign. Assignment::valueReg is allocated by compileStatAssign as well. + // + // Per Lua manual, section 3.3.3 (Assignments), the proper assignment order is only guaranteed to hold for syntactic access: + // + // Note that this guarantee covers only accesses syntactically inside the assignment statement. If a function or a metamethod called + // during the assignment changes the value of a variable, Lua gives no guarantees about the order of that access. + // + // As such, we currently don't check if an assigned local is captured, which may mean it gets reassigned during a function call. void resolveAssignConflicts(AstStat* stat, std::vector& vars, const AstArray& values) { struct Visitor : AstVisitor @@ -3808,6 +3898,7 @@ struct Compiler std::vector loops; std::vector inlineFrames; std::vector captures; + std::vector> interpStrings; }; void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, const AstNameTable& names, const CompileOptions& inputOptions) @@ -3866,7 +3957,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const ParseResult& parseResult, c compiler.compileFunction(expr); AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), - /* self= */ nullptr, AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); + /* self= */ nullptr, AstArray(), /* vararg= */ true, /* varargLocation= */ Luau::Location(), root, /* functionDepth= */ 0, + /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main); const Compiler::Function* mainf = compiler.functions.find(&main); diff --git 
a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp index 34f7954..e35c883 100644 --- a/Compiler/src/ConstantFolding.cpp +++ b/Compiler/src/ConstantFolding.cpp @@ -349,6 +349,11 @@ struct ConstantVisitor : AstVisitor if (cond.type != Constant::Type_Unknown) result = cond.isTruthful() ? trueExpr : falseExpr; } + else if (AstExprInterpString* expr = node->as()) + { + for (AstExpr* expression : expr->expressions) + analyze(expression); + } else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Compiler/src/CostModel.cpp b/Compiler/src/CostModel.cpp index 81cbfd7..ffc1cb1 100644 --- a/Compiler/src/CostModel.cpp +++ b/Compiler/src/CostModel.cpp @@ -215,6 +215,16 @@ struct CostVisitor : AstVisitor { return model(expr->condition) + model(expr->trueExpr) + model(expr->falseExpr) + 2; } + else if (AstExprInterpString* expr = node->as()) + { + // Baseline cost of string.format + Cost cost = 3; + + for (AstExpr* innerExpression : expr->expressions) + cost += model(innerExpression); + + return cost; + } else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Makefile b/Makefile index 8d95aac..0db7b28 100644 --- a/Makefile +++ b/Makefile @@ -93,7 +93,7 @@ endif ifeq ($(config),fuzz) CXX=clang++ # our fuzzing infra relies on llvm fuzzer - CXXFLAGS+=-fsanitize=address,fuzzer -Ibuild/libprotobuf-mutator -Ibuild/libprotobuf-mutator/external.protobuf/include -O2 + CXXFLAGS+=-fsanitize=address,fuzzer -Ibuild/libprotobuf-mutator -O2 LDFLAGS+=-fsanitize=address,fuzzer endif @@ -115,7 +115,7 @@ $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -ICommon/include -IAst/include -ICompiler/ $(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread -fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a +fuzz-proto fuzz-prototest: 
LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a -lprotobuf # pseudo targets .PHONY: all test clean coverage format luau-size aliases @@ -195,7 +195,7 @@ $(BUILD)/%.c.o: %.c # protobuf fuzzer setup fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator - cd fuzz && ../build/libprotobuf-mutator/external.protobuf/bin/protoc luau.proto --cpp_out=. + cd fuzz && protoc luau.proto --cpp_out=. mv fuzz/luau.pb.cc fuzz/luau.pb.cpp $(BUILD)/fuzz/proto.cpp.o: fuzz/luau.pb.cpp @@ -203,7 +203,7 @@ $(BUILD)/fuzz/protoprint.cpp.o: fuzz/luau.pb.cpp build/libprotobuf-mutator: git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator - CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator -D CMAKE_BUILD_TYPE=Release -D LIB_PROTO_MUTATOR_DOWNLOAD_PROTOBUF=ON -D LIB_PROTO_MUTATOR_TESTING=OFF + CXX= cmake -S build/libprotobuf-mutator -B build/libprotobuf-mutator -D CMAKE_BUILD_TYPE=Release -D LIB_PROTO_MUTATOR_TESTING=OFF make -C build/libprotobuf-mutator -j8 # picks up include dependencies for all object files diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index af97dc4..4396e5d 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -24,16 +24,18 @@ * The caller is expected to handle stack reservation (by using less than LUA_MINSTACK slots or by calling lua_checkstack). * To ensure this is handled correctly, use api_incr_top(L) when pushing values to the stack. * - * Functions that push any collectable objects to the stack *should* call luaC_checkthreadsleep. Failure to do this can result - * in stack references that point to dead objects since sleeping threads don't get rescanned. + * Functions that push any collectable objects to the stack *should* call luaC_threadbarrier. Failure to do this can result + * in stack references that point to dead objects since black threads don't get rescanned. 
* - * Functions that push newly created objects to the stack *should* call luaC_checkGC in addition to luaC_checkthreadsleep. + * Functions that push newly created objects to the stack *should* call luaC_checkGC in addition to luaC_threadbarrier. * Failure to do this can result in OOM since GC may never run. * - * Note that luaC_checkGC may scan the thread and put it back to sleep; functions that call both before pushing objects must - * therefore call luaC_checkGC before luaC_checkthreadsleep to guarantee the object is pushed to an awake thread. + * Note that luaC_checkGC may mark the thread and paint it black; functions that call both before pushing objects must + * therefore call luaC_checkGC before luaC_threadbarrier to guarantee the object is pushed to a gray thread. */ +LUAU_FASTFLAG(LuauSimplerUpval) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -152,7 +154,7 @@ void lua_xmove(lua_State* from, lua_State* to, int n) api_checknelems(from, n); api_check(from, from->global == to->global); api_check(from, to->ci->top - to->top >= n); - luaC_checkthreadsleep(to); + luaC_threadbarrier(to); StkId ttop = to->top; StkId ftop = from->top - n; @@ -168,7 +170,7 @@ void lua_xmove(lua_State* from, lua_State* to, int n) void lua_xpush(lua_State* from, lua_State* to, int idx) { api_check(from, from->global == to->global); - luaC_checkthreadsleep(to); + luaC_threadbarrier(to); setobj2s(to, to->top, index2addr(from, idx)); api_incr_top(to); return; @@ -177,7 +179,7 @@ void lua_xpush(lua_State* from, lua_State* to, int idx) lua_State* lua_newthread(lua_State* L) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); lua_State* L1 = luaE_newthread(L); setthvalue(L, L->top, L1); api_incr_top(L); @@ -236,7 +238,7 @@ void lua_remove(lua_State* L, int idx) void lua_insert(lua_State* L, int idx) { - luaC_checkthreadsleep(L); + 
luaC_threadbarrier(L); StkId p = index2addr(L, idx); api_checkvalidindex(L, p); for (StkId q = L->top; q > p; q--) @@ -248,7 +250,7 @@ void lua_insert(lua_State* L, int idx) void lua_replace(lua_State* L, int idx) { api_checknelems(L, 1); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) @@ -276,7 +278,7 @@ void lua_replace(lua_State* L, int idx) void lua_pushvalue(lua_State* L, int idx) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId o = index2addr(L, idx); setobj2s(L, L->top, o); api_incr_top(L); @@ -427,7 +429,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) StkId o = index2addr(L, idx); if (!ttisstring(o)) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); if (!luaV_tostring(L, o)) { // conversion failed? if (len != NULL) @@ -607,7 +609,7 @@ void lua_pushvector(lua_State* L, float x, float y, float z) void lua_pushlstring(lua_State* L, const char* s, size_t len) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); setsvalue2s(L, L->top, luaS_newlstr(L, s, len)); api_incr_top(L); return; @@ -624,7 +626,7 @@ void lua_pushstring(lua_State* L, const char* s) const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); const char* ret = luaO_pushvfstring(L, fmt, argp); return ret; } @@ -632,7 +634,7 @@ const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); va_list argp; va_start(argp, fmt); const char* ret = luaO_pushvfstring(L, fmt, argp); @@ -643,7 +645,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) 
void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); api_checknelems(L, nup); Closure* cl = luaF_newCclosure(L, nup, getcurrenv(L)); cl->c.f = fn; @@ -674,7 +676,7 @@ void lua_pushlightuserdata(lua_State* L, void* p) int lua_pushthread(lua_State* L) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); setthvalue(L, L->top, L); api_incr_top(L); return L->global->mainthread == L; @@ -686,7 +688,7 @@ int lua_pushthread(lua_State* L) int lua_gettable(lua_State* L, int idx) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_gettable(L, t, L->top - 1, L->top - 1); @@ -695,7 +697,7 @@ int lua_gettable(lua_State* L, int idx) int lua_getfield(lua_State* L, int idx, const char* k) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId t = index2addr(L, idx); api_checkvalidindex(L, t); TValue key; @@ -707,7 +709,7 @@ int lua_getfield(lua_State* L, int idx, const char* k) int lua_rawgetfield(lua_State* L, int idx, const char* k) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); TValue key; @@ -719,7 +721,7 @@ int lua_rawgetfield(lua_State* L, int idx, const char* k) int lua_rawget(lua_State* L, int idx) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); @@ -728,7 +730,7 @@ int lua_rawget(lua_State* L, int idx) int lua_rawgeti(lua_State* L, int idx, int n) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); @@ -739,7 +741,7 @@ int lua_rawgeti(lua_State* L, int idx, int n) void lua_createtable(lua_State* L, int narray, int nrec) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + 
luaC_threadbarrier(L); sethvalue(L, L->top, luaH_new(L, narray, nrec)); api_incr_top(L); return; @@ -775,7 +777,7 @@ void lua_setsafeenv(lua_State* L, int objindex, int enabled) int lua_getmetatable(lua_State* L, int objindex) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); Table* mt = NULL; const TValue* obj = index2addr(L, objindex); switch (ttype(obj)) @@ -800,7 +802,7 @@ int lua_getmetatable(lua_State* L, int objindex) void lua_getfenv(lua_State* L, int idx) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) @@ -1161,7 +1163,7 @@ l_noret lua_error(lua_State* L) int lua_next(lua_State* L, int idx) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); int more = luaH_next(L, hvalue(t), L->top - 1); @@ -1180,13 +1182,13 @@ void lua_concat(lua_State* L, int n) if (n >= 2) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); luaV_concat(L, n, cast_int(L->top - L->base) - 1); L->top -= (n - 1); } else if (n == 0) { // push empty string - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); setsvalue2s(L, L->top, luaS_newlstr(L, "", 0)); api_incr_top(L); } @@ -1198,7 +1200,7 @@ void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) { api_check(L, unsigned(tag) < LUA_UTAG_LIMIT || tag == UTAG_PROXY); luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); Udata* u = luaU_newudata(L, sz, tag); setuvalue(L, L->top, u); api_incr_top(L); @@ -1208,7 +1210,7 @@ void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); // make sure sz + sizeof(dtor) doesn't overflow; luaU_newdata will reject SIZE_MAX correctly size_t as = sz < SIZE_MAX - sizeof(dtor) ? 
sz + sizeof(dtor) : SIZE_MAX; Udata* u = luaU_newudata(L, as, UTAG_IDTOR); @@ -1244,7 +1246,7 @@ static const char* aux_upvalue(StkId fi, int n, TValue** val) const char* lua_getupvalue(lua_State* L, int funcindex, int n) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); TValue* val; const char* name = aux_upvalue(index2addr(L, funcindex), n, &val); if (name) @@ -1266,7 +1268,8 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) L->top--; setobj(L, val, L->top); luaC_barrier(L, clvalue(fi), L->top); - luaC_upvalbarrier(L, cast_to(UpVal*, NULL), val); + if (!FFlag::LuauSimplerUpval) + luaC_upvalbarrier(L, cast_to(UpVal*, NULL), val); } return name; } @@ -1336,7 +1339,7 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(lua_State*, void*)) void lua_clonefunction(lua_State* L, int idx) { luaC_checkGC(L); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); StkId p = index2addr(L, idx); api_check(L, isLfunction(p)); Closure* cl = clvalue(p); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index c44ccbe..fee9aaa 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauDebuggerBreakpointHitOnNextBestLine, false); - static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -44,13 +42,13 @@ int lua_getargument(lua_State* L, int level, int n) { if (n <= fp->numparams) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); luaA_pushobject(L, ci->base + (n - 1)); res = 1; } else if (fp->is_vararg && n < ci->base - ci->func) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); luaA_pushobject(L, ci->func + n); res = 1; } @@ -69,7 +67,7 @@ const char* lua_getlocal(lua_State* L, int level, int n) const LocVar* var = fp ? luaF_getlocal(fp, n, currentpc(L, ci)) : NULL; if (var) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); luaA_pushobject(L, ci->base + var->reg); } const char* name = var ? 
getstr(var->varname) : NULL; @@ -185,7 +183,7 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) status = auxgetinfo(L, what, ar, f, ci); if (strchr(what, 'f')) { - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); setclvalue(L, L->top, f); incr_top(L); } @@ -437,29 +435,17 @@ static int getnextline(Proto* p, int line) int lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) { - int target = -1; + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); - if (FFlag::LuauDebuggerBreakpointHitOnNextBestLine) + Proto* p = clvalue(func)->l.p; + // Find line number to add the breakpoint to. + int target = getnextline(p, line); + + if (target != -1) { - const TValue* func = luaA_toobject(L, funcindex); - api_check(L, ttisfunction(func) && !clvalue(func)->isC); - - Proto* p = clvalue(func)->l.p; - // Find line number to add the breakpoint to. - target = getnextline(p, line); - - if (target != -1) - { - // Add breakpoint on the exact line - luaG_breakpoint(L, p, target, bool(enabled)); - } - } - else - { - const TValue* func = luaA_toobject(L, funcindex); - api_check(L, ttisfunction(func) && !clvalue(func)->isC); - - luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); + // Add breakpoint on the exact line + luaG_breakpoint(L, p, target, bool(enabled)); } return target; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 6016e41..51f63d3 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -158,7 +158,7 @@ l_noret luaD_throw(lua_State* L, int errcode) static void correctstack(lua_State* L, TValue* oldstack) { L->top = (L->top - oldstack) + L->stack; - for (UpVal* up = L->openupval; up != NULL; up = up->u.l.threadnext) + for (UpVal* up = L->openupval; up != NULL; up = up->u.open.threadnext) up->v = (up->v - oldstack) + L->stack; for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) { @@ -245,7 +245,7 @@ void luaD_call(lua_State* L, StkId func, int nResults) int oldactive = 
luaC_threadactive(L); l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); luau_execute(L); // call it @@ -454,7 +454,7 @@ int lua_resume(lua_State* L, lua_State* from, int nargs) L->baseCcalls = ++L->nCcalls; l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); status = luaD_rawrunprotected(L, resume, L->top - nargs); @@ -483,7 +483,7 @@ int lua_resumeerror(lua_State* L, lua_State* from) L->baseCcalls = ++L->nCcalls; l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); status = LUA_ERRRUN; diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index bd0f826..8c78083 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,6 +6,9 @@ #include "lmem.h" #include "lgc.h" +LUAU_FASTFLAG(LuauSimplerUpval) +LUAU_FASTFLAG(LuauNoSleepBit) + Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); @@ -71,59 +74,76 @@ UpVal* luaF_findupval(lua_State* L, StkId level) UpVal* p; while (*pp != NULL && (p = *pp)->v >= level) { - LUAU_ASSERT(p->v != &p->u.value); + LUAU_ASSERT(!FFlag::LuauSimplerUpval || !isdead(g, obj2gco(p))); + LUAU_ASSERT(upisopen(p)); if (p->v == level) - { // found a corresponding upvalue? - if (isdead(g, obj2gco(p))) // is it dead? - changewhite(obj2gco(p)); // resurrect it + { // found a corresponding upvalue? + if (!FFlag::LuauSimplerUpval && isdead(g, obj2gco(p))) // is it dead? 
+ changewhite(obj2gco(p)); // resurrect it return p; } - pp = &p->u.l.threadnext; + pp = &p->u.open.threadnext; } + LUAU_ASSERT(luaC_threadactive(L)); + LUAU_ASSERT(!luaC_threadsleeping(L)); + LUAU_ASSERT(!FFlag::LuauNoSleepBit || !isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black + UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); // not found: create a new one - uv->tt = LUA_TUPVAL; - uv->marked = luaC_white(g); - uv->memcat = L->activememcat; + luaC_init(L, uv, LUA_TUPVAL); + uv->markedopen = 0; uv->v = level; // current value lives in the stack // chain the upvalue in the threads open upvalue list at the proper position - UpVal* next = *pp; - uv->u.l.threadnext = next; - uv->u.l.threadprev = pp; - if (next) - next->u.l.threadprev = &uv->u.l.threadnext; + if (FFlag::LuauSimplerUpval) + { + uv->u.open.threadnext = *pp; + *pp = uv; + } + else + { + UpVal* next = *pp; + uv->u.open.threadnext = next; - *pp = uv; + uv->u.open.threadprev = pp; + if (next) + next->u.open.threadprev = &uv->u.open.threadnext; + + *pp = uv; + } // double link the upvalue in the global open upvalue list - uv->u.l.prev = &g->uvhead; - uv->u.l.next = g->uvhead.u.l.next; - uv->u.l.next->u.l.prev = uv; - g->uvhead.u.l.next = uv; - LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + uv->u.open.prev = &g->uvhead; + uv->u.open.next = g->uvhead.u.open.next; + uv->u.open.next->u.open.prev = uv; + g->uvhead.u.open.next = uv; + LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); + return uv; } + void luaF_unlinkupval(UpVal* uv) { + LUAU_ASSERT(!FFlag::LuauSimplerUpval); + // unlink upvalue from the global open upvalue list - LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); - uv->u.l.next->u.l.prev = uv->u.l.prev; - uv->u.l.prev->u.l.next = uv->u.l.next; + LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); + 
uv->u.open.next->u.open.prev = uv->u.open.prev; + uv->u.open.prev->u.open.next = uv->u.open.next; // unlink upvalue from the thread open upvalue list - *uv->u.l.threadprev = uv->u.l.threadnext; + *uv->u.open.threadprev = uv->u.open.threadnext; - if (UpVal* next = uv->u.l.threadnext) - next->u.l.threadprev = uv->u.l.threadprev; + if (UpVal* next = uv->u.open.threadnext) + next->u.open.threadprev = uv->u.open.threadprev; } void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { - if (uv->v != &uv->u.value) // is it open? - luaF_unlinkupval(uv); // remove from open list - luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); // free upvalue + if (!FFlag::LuauSimplerUpval && uv->v != &uv->u.value) // is it open? + luaF_unlinkupval(uv); // remove from open list + luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); // free upvalue } void luaF_close(lua_State* L, StkId level) @@ -133,26 +153,55 @@ void luaF_close(lua_State* L, StkId level) while (L->openupval != NULL && (uv = L->openupval)->v >= level) { GCObject* o = obj2gco(uv); - LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); + LUAU_ASSERT(!isblack(o) && upisopen(uv)); - // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue - luaF_unlinkupval(uv); - - if (isdead(g, o)) + if (FFlag::LuauSimplerUpval) { - // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again - uv->v = &uv->u.value; + LUAU_ASSERT(!isdead(g, o)); + + // unlink value *before* closing it since value storage overlaps + L->openupval = uv->u.open.threadnext; + + luaF_closeupval(L, uv, /* dead= */ false); } else { - setobj(L, &uv->u.value, uv->v); - uv->v = &uv->u.value; - // GC state of a new closed upvalue has to be initialized - luaC_initupval(L, uv); + // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue + luaF_unlinkupval(uv); + + if (isdead(g, o)) + { + // close the upvalue 
without copying the dead data so that luaF_freeupval will not unlink again + uv->v = &uv->u.value; + } + else + { + setobj(L, &uv->u.value, uv->v); + uv->v = &uv->u.value; + // GC state of a new closed upvalue has to be initialized + luaC_upvalclosed(L, uv); + } } } } +void luaF_closeupval(lua_State* L, UpVal* uv, bool dead) +{ + LUAU_ASSERT(FFlag::LuauSimplerUpval); + + // unlink value from all lists *before* closing it since value storage overlaps + LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); + uv->u.open.next->u.open.prev = uv->u.open.prev; + uv->u.open.prev->u.open.next = uv->u.open.next; + + if (dead) + return; + + setobj(L, &uv->u.value, uv->v); + uv->v = &uv->u.value; + luaC_upvalclosed(L, uv); +} + void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) { luaM_freearray(L, f->code, f->sizecode, Instruction, f->memcat); diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 59ab572..899d040 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -12,6 +12,7 @@ LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e); LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); LUAI_FUNC void luaF_close(lua_State* L, StkId level); +LUAI_FUNC void luaF_closeupval(lua_State* L, UpVal* uv, bool dead); LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f, struct lua_Page* page); LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page); LUAI_FUNC void luaF_unlinkupval(UpVal* uv); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index f7a851f..b95d6de 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,6 +13,117 @@ #include +/* + * Luau uses an incremental non-generational non-moving mark&sweep garbage collector. + * + * The collector runs in three stages: mark, atomic and sweep. 
Mark and sweep are incremental and try to do a limited amount + * of work every GC step; atomic is run once per GC cycle and is indivisible. In either case, the work happens during GC + * steps that are "scheduled" by the GC pacing algorithm - the steps happen either from explicit calls to lua_gc, or after + * the mutator (aka application) allocates some amount of memory, which is known as "GC assist". In either case, GC steps + * can't happen concurrently with other access to VM state. + * + * Current GC stage is stored in global_State::gcstate, and has two additional stages for pause and second-phase mark, explained below. + * + * GC pacer is an algorithm that tries to ensure that GC can always catch up to the application allocating garbage, but do this + * with minimal amount of effort. To configure the pacer Luau provides control over three variables: GC goal, defined as the + * target heap size during atomic phase in relation to live heap size (e.g. 200% goal means the heap's worst case size is double + * the total size of alive objects), step size (how many kilobytes should the application allocate for GC step to trigger), and + * GC multiplier (how much should the GC try to mark relative to how much the application allocated). It's critical that the GC + * multiplier is significantly above 1, as this is what allows the GC to catch up to the application's allocation rate, and + * GC goal and GC multiplier are linked in subtle ways, described in lua.h comments for LUA_GCSETGOAL. + * + * During mark, GC tries to identify all reachable objects and mark them as reachable, while keeping unreachable objects unmarked. + * During sweep, GC tries to sweep all objects that were not reachable at the end of mark. The atomic phase is needed to ensure + * that all pending marking has completed and all objects that are still marked as unreachable are, in fact, unreachable.
+ * + * Notably, during mark GC doesn't free any objects, and so the heap size constantly grows; during sweep, GC doesn't do any marking + * work, so it can't immediately free objects that became unreachable after sweeping started. + * + * Every collectable object has one of three colors at any given point in time: white, gray or black. This coloring scheme + * is necessary to implement incremental marking: white objects have not been marked and may be unreachable, black objects + * have been marked and will not be marked again if they stay black, and gray objects have been marked but may contain unmarked + * references. + * + * Objects are allocated as white; however, during sweep, we need to differentiate between objects that remained white in the mark + * phase (these are not reachable and can be freed) and objects that were allocated after the mark phase ended. Because of this, the + * colors are encoded using three bits inside GCheader::marked: white0, white1 and black (so technically we use a four-color scheme: + * any object can be white0, white1, gray or black). All bits are exclusive, and gray objects have all three bits unset. This allows + * us to have the "current" white bit, which is flipped during atomic stage - during sweeping, objects that have the white color from + * the previous mark may be deleted, and all other objects may or may not be reachable, and will be changed to the current white color, + * so that the next mark can start coloring objects from scratch again. + * + * Crucially, the coloring scheme comes with what's known as a tri-color invariant: a black object may never point to a white object. + * + * At the end of atomic stage, the expectation is that there are no gray objects anymore, which means all objects are either black + * (reachable) or white (unreachable = dead). Tri-color invariant is maintained throughout mark and atomic phase. 
To uphold this + * invariant, every modification of an object needs to check if the object is black and the new referent is white; if so, we + * need to either mark the referent, making it non-white (known as a forward barrier), or mark the object as gray and queue it + * for additional marking (known as a backward barrier). + * + * Luau uses both types of barriers. Forward barriers advance GC progress, since they don't create new outstanding work for GC, + * but they may be expensive when an object is modified many times in succession. Backward barriers are cheaper, as they defer + * most of the work until "later", but they require queueing the object for a rescan which isn't always possible. Table writes usually + * use backward barriers (but switch to forward barriers during second-phase mark), whereas upvalue writes and setmetatable use forward + * barriers. + * + * Since marking is incremental, it needs a way to track progress, which is implemented as a gray set: at any point, objects that + * are gray need to mark their white references, objects that are black have no pending work, and objects that are white have not yet + * been reached. Once the gray set is empty, the work completes; as such, incremental marking is as simple as removing an object from + * the gray set, and turning it to black (which requires turning all its white references to gray). The gray set is implemented as + * an intrusive singly linked list, using `gclist` field in multiple objects (functions, tables, threads and protos). When an object + * doesn't have gclist field, the marking of that object needs to be "immediate", changing the colors of all references in one go. + * + * When a black object is modified, it needs to become gray again. 
Objects like this are placed on a separate `grayagain` list by a + * barrier - this is important because it allows us to have a mark stage that terminates when the gray set is empty even if the mutator + * is constantly changing existing objects to gray. After mark stage finishes traversing `gray` list, we copy `grayagain` list to `gray` + * once and incrementally mark it again. During this phase of marking, we may get more objects marked as `grayagain`, so after we finish + * emptying out the `gray` list the second time, we finish the mark stage and do final marking of `grayagain` during atomic phase. + * GC works correctly without this second-phase mark (called GCSpropagateagain), but it reduces the time spent during atomic phase. + * + * Sweeping is also incremental, but instead of working at a granularity of an object, it works at a granularity of a page: all GC + * objects are allocated in special pages (see lmem.cpp for details), and sweeper traverses all objects in one page in one incremental + * step, freeing objects that aren't reachable (old white), and recoloring all other objects with the new white to prepare them for next + * mark. During sweeping we don't need to maintain the GC invariant, because our goal is to paint all objects with current white - + * however, some barriers will still trigger (because some reachable objects are still black as sweeping didn't get to them yet), and + * some barriers will proactively mark black objects as white to avoid extra barriers from triggering excessively. + * + * Most references that GC deals with are strong, and as such they fit neatly into the incremental marking scheme. Some, however, are + * weak - notably, tables can be marked as having weak keys/values (using __mode metafield). 
During incremental marking, we don't know + * for certain if a given object is alive - if it's marked as black, it definitely was reachable during marking, but if it's marked as + * white, we don't know if it's actually unreachable. Because of this, we need to defer weak table handling to the atomic phase; after + * all objects are marked, we traverse all weak tables (that are linked into special weak table lists using `gclist` during marking), + * and remove all entries that have white keys or values. If keys or values are strong, they are marked normally. + * + * The simplified scheme described above isn't fully accurate because of threads, upvalues and strings. + * + * Strings are semantically black (they are initially white, and when the mark stage reaches a string, it changes its color and never + * touches the object again), but they are technically marked as gray - the black bit is never set on a string object. This behavior + * is inherited from Lua 5.1 GC, but doesn't have a clear rationale - effectively, strings are marked as gray but are never part of + * a gray list. + * + * Threads are hard to deal with because for them to fit into the white-gray-black scheme, writes to thread stacks need to have barriers + * that turn the thread from black (already scanned) to gray - but this is very expensive because stack writes are very common. To + * get around this problem, threads have an "active" state which means that a thread is actively executing code. When GC reaches an active + * thread, it keeps it as gray, and rescans it during atomic phase. When a thread is inactive, GC instead paints the thread black. All + * API calls that can write to thread stacks outside of execution (which implies active) use a thread barrier that checks if the thread is + * black, and if it is, marks it as gray and puts it on a gray list to be rescanned during atomic phase. + * + * NOTE: The above is only true when LuauNoSleepBit is enabled.
+ * + * Upvalues are special objects that can be closed, in which case they contain the value (acting as a reference cell) and can be dealt + * with using the regular algorithm, or open, in which case they refer to a stack slot in some other thread. These are difficult to deal + * with because the stack writes are not monitored. Because of this open upvalues are treated in a somewhat special way: they are never marked + * as black (doing so would violate the GC invariant), and they are kept in a special global list (global_State::uvhead) which is traversed + * during atomic phase. This is needed because an open upvalue might point to a stack location in a dead thread that never marked the stack + * slot - upvalues like this are identified since they don't have `markedopen` bit set during thread traversal and closed in `clearupvals`. + * + * NOTE: The above is only true when LuauSimplerUpval is enabled. + */ + +LUAU_FASTFLAGVARIABLE(LuauSimplerUpval, false) +LUAU_FASTFLAGVARIABLE(LuauNoSleepBit, false) +LUAU_FASTFLAGVARIABLE(LuauEagerShrink, false) + #define GC_SWEEPPAGESTEPCOST 16 #define GC_INTERRUPT(state) \ @@ -150,8 +261,8 @@ static void reallymarkobject(global_State* g, GCObject* o) { UpVal* uv = gco2uv(o); markvalue(g, uv->v); - if (uv->v == &uv->u.value) // closed? - gray2black(o); // open upvalues are never black + if (!upisopen(uv)) // closed? + gray2black(o); // open upvalues are never black return; } case LUA_TFUNCTION: @@ -289,22 +400,34 @@ static void traverseclosure(global_State* g, Closure* cl) } } -static void traversestack(global_State* g, lua_State* l, bool clearstack) +static void traversestack(global_State* g, lua_State* l) { markobject(g, l->gt); if (l->namecall) stringmark(l->namecall); for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); - // final traversal? 
- if (g->gcstate == GCSatomic || clearstack) + if (FFlag::LuauSimplerUpval) { - StkId stack_end = l->stack + l->stacksize; - for (StkId o = l->top; o < stack_end; o++) // clear not-marked stack slice - setnilvalue(o); + for (UpVal* uv = l->openupval; uv; uv = uv->u.open.threadnext) + { + LUAU_ASSERT(upisopen(uv)); + uv->markedopen = 1; + markobject(g, uv); + } } } +static void clearstack(lua_State* l) +{ + StkId stack_end = l->stack + l->stacksize; + for (StkId o = l->top; o < stack_end; o++) // clear not-marked stack slice + setnilvalue(o); +} + +// TODO: pull function definition here when FFlag::LuauEagerShrink is removed +static void shrinkstack(lua_State* L); + /* ** traverse one gray object, turning it to black. ** Returns `quantity' traversed. @@ -338,14 +461,17 @@ static size_t propagatemark(global_State* g) LUAU_ASSERT(!luaC_threadsleeping(th)); - // threads that are executing and the main thread are not deactivated + // threads that are executing and the main thread remain gray bool active = luaC_threadactive(th) || th == th->global->mainthread; + // TODO: Refactor this logic after LuauNoSleepBit is removed if (!active && g->gcstate == GCSpropagate) { - traversestack(g, th, /* clearstack= */ true); + traversestack(g, th); + clearstack(th); - l_setbit(th->stackstate, THREAD_SLEEPINGBIT); + if (!FFlag::LuauNoSleepBit) + l_setbit(th->stackstate, THREAD_SLEEPINGBIT); } else { @@ -354,9 +480,17 @@ static size_t propagatemark(global_State* g) black2gray(o); - traversestack(g, th, /* clearstack= */ false); + traversestack(g, th); + + // final traversal? 
+ if (g->gcstate == GCSatomic) + clearstack(th); } + // we could shrink stack at any time but we opt to skip it during atomic since it's redundant to do that more than once per cycle + if (FFlag::LuauEagerShrink && g->gcstate != GCSatomic) + shrinkstack(th); + return sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; } case LUA_TPROTO: @@ -537,7 +671,7 @@ static bool deletegco(void* context, lua_Page* page, GCObject* gco) // we are in the process of deleting everything // threads with open upvalues will attempt to close them all on removal // but those upvalues might point to stack values that were already deleted - if (gco->gch.tt == LUA_TTHREAD) + if (!FFlag::LuauSimplerUpval && gco->gch.tt == LUA_TTHREAD) { lua_State* th = gco2th(gco); @@ -595,13 +729,53 @@ static void markroot(lua_State* L) static size_t remarkupvals(global_State* g) { size_t work = 0; - for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + + for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead; uv = uv->u.open.next) { work += sizeof(UpVal); - LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + + LUAU_ASSERT(upisopen(uv)); + LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); + LUAU_ASSERT(!isblack(obj2gco(uv))); // open upvalues are never black + if (isgray(obj2gco(uv))) markvalue(g, uv->v); } + + return work; +} + +static size_t clearupvals(lua_State* L) +{ + global_State* g = L->global; + + size_t work = 0; + + for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead;) + { + work += sizeof(UpVal); + + LUAU_ASSERT(upisopen(uv)); + LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); + LUAU_ASSERT(!isblack(obj2gco(uv))); // open upvalues are never black + LUAU_ASSERT(iswhite(obj2gco(uv)) || !iscollectable(uv->v) || !iswhite(gcvalue(uv->v))); + + if (uv->markedopen) + { + // upvalue is still open (belongs to alive thread) + 
LUAU_ASSERT(isgray(obj2gco(uv))); + uv->markedopen = 0; // for next cycle + uv = uv->u.open.next; + } + else + { + // upvalue is either dead, or alive but the thread is dead; unlink and close + UpVal* next = uv->u.open.next; + luaF_closeupval(L, uv, /* dead= */ iswhite(obj2gco(uv))); + uv = next; + } + } + return work; } @@ -654,6 +828,16 @@ static size_t atomic(lua_State* L) g->gcmetrics.currcycle.atomictimeclear += recordGcDeltaTime(currts); #endif + if (FFlag::LuauSimplerUpval) + { + // close orphaned live upvalues of dead threads and clear dead upvalues + work += clearupvals(L); + +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); +#endif + } + // flip current white g->currentwhite = cast_byte(otherwhite(g)); g->sweepgcopage = g->allgcopages; @@ -677,8 +861,11 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) if (alive) { - resetbit(th->stackstate, THREAD_SLEEPINGBIT); - shrinkstack(th); + if (!FFlag::LuauNoSleepBit) + resetbit(th->stackstate, THREAD_SLEEPINGBIT); + + if (!FFlag::LuauEagerShrink) + shrinkstack(th); } } @@ -945,7 +1132,7 @@ void luaC_fullgc(lua_State* L) startGcCycleMetrics(g); #endif - if (g->gcstate <= GCSatomic) + if (FFlag::LuauSimplerUpval ? 
keepinvariant(g) : g->gcstate <= GCSatomic) { // reset sweep marks to sweep all elements (returning them to white) g->sweepgcopage = g->allgcopages; @@ -955,7 +1142,7 @@ void luaC_fullgc(lua_State* L) g->weak = NULL; g->gcstate = GCSsweep; } - LUAU_ASSERT(g->gcstate == GCSsweep); + LUAU_ASSERT(g->gcstate == GCSpause || g->gcstate == GCSsweep); // finish any pending sweep phase while (g->gcstate != GCSpause) { @@ -963,6 +1150,16 @@ void luaC_fullgc(lua_State* L) gcstep(L, SIZE_MAX); } + if (FFlag::LuauSimplerUpval) + { + // clear markedopen bits for all open upvalues; these might be stuck from half-finished mark prior to full gc + for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead; uv = uv->u.open.next) + { + LUAU_ASSERT(upisopen(uv)); + uv->markedopen = 0; + } + } + #ifdef LUAI_GCMETRICS finishGcCycleMetrics(g); startGcCycleMetrics(g); @@ -999,6 +1196,7 @@ void luaC_fullgc(lua_State* L) void luaC_barrierupval(lua_State* L, GCObject* v) { + LUAU_ASSERT(!FFlag::LuauSimplerUpval); global_State* g = L->global; LUAU_ASSERT(iswhite(v) && !isdead(g, v)); @@ -1038,30 +1236,24 @@ void luaC_barriertable(lua_State* L, Table* t, GCObject* v) g->grayagain = o; } -void luaC_barrierback(lua_State* L, Table* t) +void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist) { global_State* g = L->global; - GCObject* o = obj2gco(t); LUAU_ASSERT(isblack(o) && !isdead(g, o)); LUAU_ASSERT(g->gcstate != GCSpause); - black2gray(o); // make table gray (again) - t->gclist = g->grayagain; + + black2gray(o); // make object gray (again) + *gclist = g->grayagain; g->grayagain = o; } -void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt) -{ - global_State* g = L->global; - o->gch.marked = luaC_white(g); - o->gch.tt = tt; - o->gch.memcat = L->activememcat; -} - -void luaC_initupval(lua_State* L, UpVal* uv) +void luaC_upvalclosed(lua_State* L, UpVal* uv) { global_State* g = L->global; GCObject* o = obj2gco(uv); + LUAU_ASSERT(!upisopen(uv)); // upvalue was closed but needs GC state 
fixup + if (isgray(o)) { if (keepinvariant(g)) @@ -1105,6 +1297,7 @@ int64_t luaC_allocationrate(lua_State* L) void luaC_wakethread(lua_State* L) { + LUAU_ASSERT(!FFlag::LuauNoSleepBit); if (!luaC_threadsleeping(L)) return; @@ -1116,6 +1309,8 @@ void luaC_wakethread(lua_State* L) { GCObject* o = obj2gco(L); + LUAU_ASSERT(isblack(o)); + L->gclist = g->grayagain; g->grayagain = o; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 7b03a25..69379c8 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,6 +6,8 @@ #include "lobject.h" #include "lstate.h" +LUAU_FASTFLAG(LuauNoSleepBit) + /* ** Default settings for GC tunables (settable via lua_gc) */ @@ -74,6 +76,7 @@ #define luaC_white(g) cast_to(uint8_t, ((g)->currentwhite) & WHITEBITS) // Thread stack states +// TODO: Remove with FFlag::LuauNoSleepBit and replace with lua_State::threadactive #define THREAD_ACTIVEBIT 0 // thread is currently active #define THREAD_SLEEPINGBIT 1 // thread is not executing and stack should not be modified @@ -109,7 +112,7 @@ #define luaC_barrierfast(L, t) \ { \ if (isblack(obj2gco(t))) \ - luaC_barrierback(L, t); \ + luaC_barrierback(L, obj2gco(t), &t->gclist); \ } #define luaC_objbarrier(L, p, o) \ @@ -118,29 +121,43 @@ luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } +// TODO: Remove with FFlag::LuauSimplerUpval #define luaC_upvalbarrier(L, uv, tv) \ { \ if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || (uv)->v != &(uv)->u.value)) \ luaC_barrierupval(L, gcvalue(tv)); \ } -#define luaC_checkthreadsleep(L) \ +#define luaC_threadbarrier(L) \ { \ - if (luaC_threadsleeping(L)) \ - luaC_wakethread(L); \ + if (FFlag::LuauNoSleepBit) \ + { \ + if (isblack(obj2gco(L))) \ + luaC_barrierback(L, obj2gco(L), &L->gclist); \ + } \ + else \ + { \ + if (luaC_threadsleeping(L)) \ + luaC_wakethread(L); \ + } \ } -#define luaC_init(L, o, tt) luaC_initobj(L, cast_to(GCObject*, (o)), tt) +#define luaC_init(L, o, tt_) \ + { \ + o->marked = luaC_white(L->global); \ + o->tt = tt_; \ + o->memcat = L->activememcat; 
\ + } LUAI_FUNC void luaC_freeall(lua_State* L); LUAI_FUNC size_t luaC_step(lua_State* L, bool assist); LUAI_FUNC void luaC_fullgc(lua_State* L); LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); -LUAI_FUNC void luaC_initupval(lua_State* L, UpVal* uv); +LUAI_FUNC void luaC_upvalclosed(lua_State* L, UpVal* uv); LUAI_FUNC void luaC_barrierupval(lua_State* L, GCObject* v); LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); -LUAI_FUNC void luaC_barrierback(lua_State* L, Table* t); +LUAI_FUNC void luaC_barrierback(lua_State* L, GCObject* o, GCObject** gclist); LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index b6204b1..2f9c175 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -102,10 +102,12 @@ static void validatestack(global_State* g, lua_State* l) if (l->namecall) validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - for (UpVal* uv = l->openupval; uv; uv = uv->u.l.threadnext) + for (UpVal* uv = l->openupval; uv; uv = uv->u.open.threadnext) { LUAU_ASSERT(uv->tt == LUA_TUPVAL); - LUAU_ASSERT(uv->v != &uv->u.value); + LUAU_ASSERT(upisopen(uv)); + LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); + LUAU_ASSERT(!isblack(obj2gco(uv))); // open upvalues are never black } } @@ -235,11 +237,12 @@ void luaC_validate(lua_State* L) luaM_visitgco(L, L, validategco); - for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + for (UpVal* uv = g->uvhead.u.open.next; uv != &g->uvhead; uv = uv->u.open.next) { LUAU_ASSERT(uv->tt == LUA_TUPVAL); - LUAU_ASSERT(uv->v != &uv->u.value); - LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + LUAU_ASSERT(upisopen(uv)); + 
LUAU_ASSERT(uv->u.open.next->u.open.prev == uv && uv->u.open.prev->u.open.next == uv); + LUAU_ASSERT(!isblack(obj2gco(uv))); // open upvalues are never black } } @@ -508,13 +511,14 @@ static void dumpproto(FILE* f, Proto* p) static void dumpupval(FILE* f, UpVal* uv) { - fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); + fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d,\"open\":%s", uv->memcat, int(sizeof(UpVal)), upisopen(uv) ? "true" : "false"); if (iscollectable(uv->v)) { fprintf(f, ",\"object\":"); dumpref(f, gcvalue(uv->v)); } + fprintf(f, "}"); } diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 2097e33..778e22b 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -232,7 +232,7 @@ typedef struct TString int16_t atom; // 2 byte padding - TString* next; // next string in the hash table bucket or the string buffer linked list + TString* next; // next string in the hash table bucket unsigned int hash; unsigned int len; @@ -316,7 +316,10 @@ typedef struct LocVar typedef struct UpVal { CommonHeader; - // 1 (x86) or 5 (x64) byte padding + uint8_t markedopen; // set if reachable from an alive thread (only valid during atomic) + + // 4 byte padding (x64) + TValue* v; // points to stack or to its own value union { @@ -327,14 +330,16 @@ typedef struct UpVal struct UpVal* prev; struct UpVal* next; - // thread double linked list (when open) + // thread linked list (when open) struct UpVal* threadnext; // note: this is the location of a pointer to this upvalue in the previous element that can be either an UpVal or a lua_State - struct UpVal** threadprev; - } l; + struct UpVal** threadprev; // TODO: remove with FFlag::LuauSimplerUpval + } open; } u; } UpVal; +#define upisopen(up) ((up)->v != &(up)->u.value) + /* ** Closures */ diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 4489f84..e1cb2ab 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,6 +10,8 @@ #include "ldo.h" #include "ldebug.h" 
+LUAU_FASTFLAG(LuauSimplerUpval) + /* ** Main thread combines a thread state and the global state */ @@ -119,8 +121,11 @@ lua_State* luaE_newthread(lua_State* L) void luaE_freethread(lua_State* L, lua_State* L1, lua_Page* page) { - luaF_close(L1, L1->stack); // close all upvalues for this thread - LUAU_ASSERT(L1->openupval == NULL); + if (!FFlag::LuauSimplerUpval) + { + luaF_close(L1, L1->stack); // close all upvalues for this thread + LUAU_ASSERT(L1->openupval == NULL); + } global_State* g = L->global; if (g->cb.userthread) g->cb.userthread(NULL, L1); @@ -175,8 +180,8 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->frealloc = f; g->ud = ud; g->mainthread = L; - g->uvhead.u.l.prev = &g->uvhead; - g->uvhead.u.l.next = &g->uvhead; + g->uvhead.u.open.prev = &g->uvhead; + g->uvhead.u.open.next = &g->uvhead; g->GCthreshold = 0; // mark it as unfinished state g->registryfree = 0; g->errorjmp = NULL; diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 72a0971..df47ce7 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -167,7 +167,7 @@ typedef struct global_State GCObject* grayagain; // list of objects to be traversed atomically GCObject* weak; // list of weak tables (to be cleared) - TString* strbufgc; // list of all string buffer objects + TString* strbufgc; // list of all string buffer objects; TODO: remove with LuauNoStrbufLink size_t GCthreshold; // when totalbytes > GCthreshold, run GC step diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 9c26603..f43d03b 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauNoStrbufLink, false) + unsigned int luaS_hash(const char* str, size_t len) { // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash @@ -70,40 +72,33 @@ void luaS_resize(lua_State* L, int newsize) static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) { - TString* ts; - stringtable* tb; if (l > MAXSSIZE) 
luaM_toobig(L); - ts = luaM_newgco(L, TString, sizestring(l), L->activememcat); - ts->len = unsigned(l); + + TString* ts = luaM_newgco(L, TString, sizestring(l), L->activememcat); + luaC_init(L, ts, LUA_TSTRING); + ts->atom = ATOM_UNDEF; ts->hash = h; - ts->marked = luaC_white(L->global); - ts->tt = LUA_TSTRING; - ts->memcat = L->activememcat; + ts->len = unsigned(l); + memcpy(ts->data, str, l); ts->data[l] = '\0'; // ending 0 - ts->atom = ATOM_UNDEF; - tb = &L->global->strt; + + stringtable* tb = &L->global->strt; h = lmod(h, tb->size); ts->next = tb->hash[h]; // chain new entry tb->hash[h] = ts; + tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) luaS_resize(L, tb->size * 2); // too crowded + return ts; } -static void linkstrbuf(lua_State* L, TString* ts) -{ - global_State* g = L->global; - - ts->next = g->strbufgc; - g->strbufgc = ts; - ts->marked = luaC_white(g); -} - static void unlinkstrbuf(lua_State* L, TString* ts) { + LUAU_ASSERT(!FFlag::LuauNoStrbufLink); global_State* g = L->global; TString** p = &g->strbufgc; @@ -129,14 +124,24 @@ TString* luaS_bufstart(lua_State* L, size_t size) if (size > MAXSSIZE) luaM_toobig(L); + global_State* g = L->global; + TString* ts = luaM_newgco(L, TString, sizestring(size), L->activememcat); - - ts->tt = LUA_TSTRING; - ts->memcat = L->activememcat; - linkstrbuf(L, ts); - + luaC_init(L, ts, LUA_TSTRING); + ts->atom = ATOM_UNDEF; + ts->hash = 0; // computed in luaS_buffinish ts->len = unsigned(size); + if (FFlag::LuauNoStrbufLink) + { + ts->next = NULL; + } + else + { + ts->next = g->strbufgc; + g->strbufgc = ts; + } + return ts; } @@ -159,7 +164,10 @@ TString* luaS_buffinish(lua_State* L, TString* ts) } } - unlinkstrbuf(L, ts); + if (FFlag::LuauNoStrbufLink) + LUAU_ASSERT(ts->next == NULL); + else + unlinkstrbuf(L, ts); ts->hash = h; ts->data[ts->len] = '\0'; // ending 0 @@ -214,11 +222,21 @@ static bool unlinkstr(lua_State* L, TString* ts) void luaS_free(lua_State* L, TString* ts, lua_Page* 
page) { - // Unchain from the string table - if (!unlinkstr(L, ts)) - unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands + if (FFlag::LuauNoStrbufLink) + { + if (unlinkstr(L, ts)) + L->global->strt.nuse--; + else + LUAU_ASSERT(ts->next == NULL); // orphaned string buffer + } else - L->global->strt.nuse--; + { + // Unchain from the string table + if (!unlinkstr(L, ts)) + unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands + else + L->global->strt.nuse--; + } luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); } diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 376dd40..7306b05 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,7 +16,8 @@ #include -LUAU_FASTFLAGVARIABLE(LuauNicerMethodErrors, false) +LUAU_FASTFLAG(LuauSimplerUpval) +LUAU_FASTFLAG(LuauNoSleepBit) // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ @@ -111,7 +112,7 @@ LUAU_FASTFLAGVARIABLE(LuauNicerMethodErrors, false) VM_DISPATCH_OP(LOP_LOADKX), VM_DISPATCH_OP(LOP_JUMPX), VM_DISPATCH_OP(LOP_FASTCALL), VM_DISPATCH_OP(LOP_COVERAGE), \ VM_DISPATCH_OP(LOP_CAPTURE), VM_DISPATCH_OP(LOP_JUMPIFEQK), VM_DISPATCH_OP(LOP_JUMPIFNOTEQK), VM_DISPATCH_OP(LOP_FASTCALL1), \ VM_DISPATCH_OP(LOP_FASTCALL2), VM_DISPATCH_OP(LOP_FASTCALL2K), VM_DISPATCH_OP(LOP_FORGPREP), VM_DISPATCH_OP(LOP_JUMPXEQKNIL), \ - VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), \ + VM_DISPATCH_OP(LOP_JUMPXEQKB), VM_DISPATCH_OP(LOP_JUMPXEQKN), VM_DISPATCH_OP(LOP_JUMPXEQKS), #if defined(__GNUC__) || defined(__clang__) #define VM_USE_CGOTO 1 @@ -317,6 +318,7 @@ static void luau_execute(lua_State* L) LUAU_ASSERT(isLua(L->ci)); LUAU_ASSERT(luaC_threadactive(L)); LUAU_ASSERT(!luaC_threadsleeping(L)); + LUAU_ASSERT(!FFlag::LuauNoSleepBit || !isblack(obj2gco(L))); // we don't use luaC_threadbarrier because active threads never turn black pc = L->ci->savedpc; cl = 
clvalue(L->ci->func); @@ -496,7 +498,8 @@ static void luau_execute(lua_State* L) setobj(L, uv->v, ra); luaC_barrier(L, uv, ra); - luaC_upvalbarrier(L, uv, uv->v); + if (!FFlag::LuauSimplerUpval) + luaC_upvalbarrier(L, uv, uv->v); VM_NEXT(); } @@ -932,7 +935,7 @@ static void luau_execute(lua_State* L) VM_PATCH_C(pc - 2, L->cachedslot); // recompute ra since stack might have been reallocated ra = VM_REG(LUAU_INSN_A(insn)); - if (FFlag::LuauNicerMethodErrors && ttisnil(ra)) + if (ttisnil(ra)) luaG_methoderror(L, ra + 1, tsvalue(kv)); } } @@ -973,7 +976,7 @@ static void luau_execute(lua_State* L) VM_PATCH_C(pc - 2, L->cachedslot); // recompute ra since stack might have been reallocated ra = VM_REG(LUAU_INSN_A(insn)); - if (FFlag::LuauNicerMethodErrors && ttisnil(ra)) + if (ttisnil(ra)) luaG_methoderror(L, ra + 1, tsvalue(kv)); } } @@ -984,7 +987,7 @@ static void luau_execute(lua_State* L) VM_PROTECT(luaV_gettable(L, rb, kv, ra)); // recompute ra since stack might have been reallocated ra = VM_REG(LUAU_INSN_A(insn)); - if (FFlag::LuauNicerMethodErrors && ttisnil(ra)) + if (ttisnil(ra)) luaG_methoderror(L, ra + 1, tsvalue(kv)); } } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 86afddd..0ae85ab 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -351,7 +351,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint32_t mainid = readVarInt(data, size, offset); Proto* main = protos[mainid]; - luaC_checkthreadsleep(L); + luaC_threadbarrier(L); Closure* cl = luaF_newLclosure(L, 0, envt, main); setclvalue(L, L->top, cl); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 8be241e..33d4702 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -10,9 +10,6 @@ #include "lnumutils.h" #include -#include - -LUAU_FASTFLAGVARIABLE(LuauBetterNewindex, false) // limit for table tag-method chains (to avoid loops) #define MAXTAGLOOP 100 @@ -142,46 +139,25 @@ void luaV_settable(lua_State* L, const TValue* t, TValue* key, 
StkId val) { // `t' is a table? Table* h = hvalue(t); - if (FFlag::LuauBetterNewindex) - { - const TValue* oldval = luaH_get(h, key); + const TValue* oldval = luaH_get(h, key); - // should we assign the key? (if key is valid or __newindex is not set) - if (!ttisnil(oldval) || (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) - { - if (h->readonly) - luaG_readonlyerror(L); - - // luaH_set would work but would repeat the lookup so we use luaH_setslot that can reuse oldval if it's safe - TValue* newval = luaH_setslot(L, h, oldval, key); - - L->cachedslot = gval2slot(h, newval); // remember slot to accelerate future lookups - - setobj2t(L, newval, val); - luaC_barriert(L, h, val); - return; - } - - // fallthrough to metamethod - } - else + // should we assign the key? (if key is valid or __newindex is not set) + if (!ttisnil(oldval) || (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) { if (h->readonly) luaG_readonlyerror(L); - TValue* oldval = luaH_set(L, h, key); // do a primitive set + // luaH_set would work but would repeat the lookup so we use luaH_setslot that can reuse oldval if it's safe + TValue* newval = luaH_setslot(L, h, oldval, key); - L->cachedslot = gval2slot(h, oldval); // remember slot to accelerate future lookups + L->cachedslot = gval2slot(h, newval); // remember slot to accelerate future lookups - if (!ttisnil(oldval) || // result is no nil? - (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) - { // or no TM? 
- setobj2t(L, oldval, val); - luaC_barriert(L, h, val); - return; - } - // else will try the tag method + setobj2t(L, newval, val); + luaC_barriert(L, h, val); + return; } + + // fallthrough to metamethod } else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_NEWINDEX))) luaG_indexerror(L, t, key); diff --git a/bench/bench.py b/bench/bench.py index 42a0ac9..0db3395 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -38,6 +38,7 @@ argumentParser.add_argument('--run-test', action='store', default=None, help='Re argumentParser.add_argument('--extra-loops', action='store',type=int,default=0, help='Amount of times to loop over one test (one test already performs multiple runs)') argumentParser.add_argument('--filename', action='store',type=str,default='bench', help='File name for graph and results file') argumentParser.add_argument('--callgrind', dest='callgrind',action='store_const',const=1,default=0,help='Use callgrind to run benchmarks') +argumentParser.add_argument('--show-commands', dest='show_commands',action='store_const',const=1,default=0,help='Show the command line used to launch the VM and tests') if matplotlib != None: argumentParser.add_argument('--absolute', dest='absolute',action='store_const',const=1,default=0,help='Display absolute values instead of relative (enabled by default when benchmarking a single VM)') @@ -87,17 +88,25 @@ def getCallgrindOutput(lines): return "".join(result) +def conditionallyShowCommand(cmd): + if arguments.show_commands: + print(f'{colored(Color.BLUE, "EXECUTING")}: {cmd}') + def getVmOutput(cmd): if os.name == "nt": try: - return subprocess.check_output("start /realtime /affinity 1 /b /wait cmd /C \"" + cmd + "\"", shell=True, cwd=scriptdir).decode() + fullCmd = "start /realtime /affinity 1 /b /wait cmd /C \"" + cmd + "\"" + conditionallyShowCommand(fullCmd) + return subprocess.check_output(fullCmd, shell=True, cwd=scriptdir).decode() except KeyboardInterrupt: exit(1) except: return "" elif arguments.callgrind: try: - 
subprocess.check_call("valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=scriptdir) + fullCmd = "valgrind --tool=callgrind --callgrind-out-file=callgrind.out --combine-dumps=yes --dump-line=no " + cmd + conditionallyShowCommand(fullCmd) + subprocess.check_call(fullCmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, cwd=scriptdir) path = os.path.join(scriptdir, "callgrind.out") with open(path, "r") as file: lines = file.readlines() @@ -106,6 +115,7 @@ def getVmOutput(cmd): except: return "" else: + conditionallyShowCommand(cmd) with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=scriptdir) as p: # Try to lock to a single processor if sys.platform != "darwin": diff --git a/tests/AstJsonEncoder.test.cpp b/tests/AstJsonEncoder.test.cpp index 3ff3674..a23f6f4 100644 --- a/tests/AstJsonEncoder.test.cpp +++ b/tests/AstJsonEncoder.test.cpp @@ -2,6 +2,7 @@ #include "Luau/Ast.h" #include "Luau/AstJsonEncoder.h" #include "Luau/Parser.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -175,6 +176,17 @@ TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprIfThen") CHECK(toJson(statement) == expected); } +TEST_CASE_FIXTURE(JsonEncoderFixture, "encode_AstExprInterpString") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + AstStat* statement = expectParseStatement("local a = `var = {x}`"); + + std::string_view expected = + R"({"type":"AstStatLocal","location":"0,0 - 0,17","vars":[{"luauType":null,"name":"a","type":"AstLocal","location":"0,6 - 0,7"}],"values":[{"type":"AstExprInterpString","location":"0,10 - 0,17","strings":["var = ",""],"expressions":[{"type":"AstExprGlobal","location":"0,18 - 0,19","global":"x"}]}]})"; + + CHECK(toJson(statement) == expected); +} TEST_CASE("encode_AstExprLocal") { diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 
6ec1426..2b650fa 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -138,4 +138,80 @@ print(workspace:) REQUIRE(ancestry.back()->is()); } +TEST_CASE_FIXTURE(Fixture, "Luau_query") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + CHECK(if_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_query_for_2nd_if_stat_which_doesnt_exist") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + CHECK(!if_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_nested_query") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + REQUIRE(if_); + AstExprConstantBool* bool_ = Luau::query(if_); + REQUIRE(bool_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_nested_query_but_first_query_failed") +{ + AstStatBlock* block = parse(R"( + if true then + end + )"); + + AstStatIf* if_ = Luau::query(block); + REQUIRE(!if_); + AstExprConstantBool* bool_ = Luau::query(if_); // ensure it doesn't crash + REQUIRE(!bool_); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_selectively_query_for_a_different_boolean") +{ + AstStatBlock* block = parse(R"( + local x = false and true + local y = true and false + )"); + + AstExprConstantBool* fst = Luau::query(block, {nth(), nth(2)}); + REQUIRE(fst); + REQUIRE(fst->value == true); + + AstExprConstantBool* snd = Luau::query(block, {nth(2), nth(2)}); + REQUIRE(snd); + REQUIRE(snd->value == false); +} + +TEST_CASE_FIXTURE(Fixture, "Luau_selectively_query_for_a_different_boolean_2") +{ + AstStatBlock* block = parse(R"( + local x = false and true + local y = true and false + )"); + + AstExprConstantBool* snd = Luau::query(block, {nth(2), nth()}); + REQUIRE(snd); + REQUIRE(snd->value == true); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 25447dd..988cbe8 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2708,6 +2708,15 @@ a = if temp 
then even else abc@3 CHECK(ac.entryMap.count("abcdef")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_interpolated_string") +{ + check(R"(f(`expression = {@1}`))"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("table")); + CHECK_EQ(ac.context, AutocompleteContext::Expression); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") { check(R"( diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index a2e748a..0a3c650 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1230,6 +1230,58 @@ RETURN R0 0 )"); } +TEST_CASE("InterpStringWithNoExpressions") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CHECK_EQ(compileFunction0(R"(return "hello")"), compileFunction0("return `hello`")); +} + +TEST_CASE("InterpStringZeroCost") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CHECK_EQ( + "\n" + compileFunction0(R"(local _ = `hello, {"world"}!`)"), + R"( +LOADK R1 K0 +LOADK R3 K1 +NAMECALL R1 R1 K2 +CALL R1 2 1 +MOVE R0 R1 +RETURN R0 0 +)" + ); +} + +TEST_CASE("InterpStringRegisterCleanup") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CHECK_EQ( + "\n" + compileFunction0(R"( + local a, b, c = nil, "um", "uh oh" + a = `foo{"bar"}` + print(a) + )"), + + R"( +LOADNIL R0 +LOADK R1 K0 +LOADK R2 K1 +LOADK R3 K2 +LOADK R5 K3 +NAMECALL R3 R3 K4 +CALL R3 2 1 +MOVE R0 R3 +GETIMPORT R3 6 +MOVE R4 R0 +CALL R3 1 0 +RETURN R0 0 +)" + ); +} + TEST_CASE("ConstantFoldArith") { CHECK_EQ("\n" + compileFunction0("return 10 + 2"), R"( @@ -2102,8 +2154,6 @@ TEST_CASE("RecursionParse") CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile"); } -#if 0 - // This currently requires too much stack space on MSVC/x64 and crashes with stack overflow at recursion depth 935 try { Luau::compileOrThrow(bcb, rep("function a() ", 1500) + "print()" + rep(" end", 1500)); @@ -2123,7 +2173,6 @@ TEST_CASE("RecursionParse") { 
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); } -#endif } TEST_CASE("ArrayIndexLiteral") diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index be2feac..f6f5b41 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -294,6 +294,14 @@ TEST_CASE("Strings") runConformance("strings.lua"); } +TEST_CASE("StringInterp") +{ + ScopedFastFlag sffInterpStrings{"LuauInterpolatedStringBaseSupport", true}; + ScopedFastFlag sffTostringFormat{"LuauTostringFormatSpecifier", true}; + + runConformance("stringinterp.lua"); +} + TEST_CASE("VarArg") { runConformance("vararg.lua"); @@ -311,15 +319,11 @@ TEST_CASE("Literals") TEST_CASE("Errors") { - ScopedFastFlag sff("LuauNicerMethodErrors", true); - runConformance("errors.lua"); } TEST_CASE("Events") { - ScopedFastFlag sff("LuauBetterNewindex", true); - runConformance("events.lua"); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 40be39a..4051f85 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -512,4 +512,41 @@ void dump(const std::vector& constraints) printf("%s\n", toString(c, opts).c_str()); } +FindNthOccurenceOf::FindNthOccurenceOf(Nth nth) + : requestedNth(nth) +{ +} + +bool FindNthOccurenceOf::checkIt(AstNode* n) +{ + if (theNode) + return false; + + if (n->classIndex == requestedNth.classIndex) + { + // Human factor: the requestedNth starts from 1 because of the term `nth`. 
+ if (currentOccurrence + 1 != requestedNth.nth) + ++currentOccurrence; + else + theNode = n; + } + + return !theNode; // once found, returns false and stops traversal +} + +bool FindNthOccurenceOf::visit(AstNode* n) +{ + return checkIt(n); +} + +bool FindNthOccurenceOf::visit(AstType* t) +{ + return checkIt(t); +} + +bool FindNthOccurenceOf::visit(AstTypePack* t) +{ + return checkIt(t); +} + } // namespace Luau diff --git a/tests/Fixture.h b/tests/Fixture.h index 8dc3dd2..e82ebf0 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -195,6 +195,76 @@ std::optional lookupName(ScopePtr scope, const std::string& name); // Wa std::optional linearSearchForBinding(Scope* scope, const char* name); +struct Nth +{ + int classIndex; + int nth; +}; + +template +Nth nth(int nth = 1) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + LUAU_ASSERT(nth > 0); // Did you mean to use `nth(1)`? + + return Nth{T::ClassIndex(), nth}; +} + +struct FindNthOccurenceOf : public AstVisitor +{ + Nth requestedNth; + size_t currentOccurrence = 0; + AstNode* theNode = nullptr; + + FindNthOccurenceOf(Nth nth); + + bool checkIt(AstNode* n); + + bool visit(AstNode* n) override; + bool visit(AstType* n) override; + bool visit(AstTypePack* n) override; +}; + +/** DSL querying of the AST. + * + * Given an AST, one can query for a particular node directly without having to manually unwrap the tree, for example: + * + * ``` + * if a and b then + * print(a + b) + * end + * + * function f(x, y) + * return x + y + * end + * ``` + * + * There are numerous ways to access the second AstExprBinary. + * 1. Luau::query(block, {nth(), nth()}) + * 2. Luau::query(Luau::query(block)) + * 3. 
Luau::query(block, {nth(2)}) + */ +template +T* query(AstNode* node, const std::vector& nths = {nth(N)}) +{ + static_assert(std::is_base_of_v, "T must be a derived class of AstNode"); + + // If a nested query call fails to find the node in question, subsequent calls can propagate rather than trying to do more. + // This supports `query(query(...))` + + for (Nth nth : nths) + { + if (!node) + return nullptr; + + FindNthOccurenceOf finder{nth}; + node->visit(&finder); + node = finder.theNode; + } + + return node ? node->as() : nullptr; +} + } // namespace Luau #define LUAU_REQUIRE_ERRORS(result) \ diff --git a/tests/Lexer.test.cpp b/tests/Lexer.test.cpp index 20d8d0d..890d100 100644 --- a/tests/Lexer.test.cpp +++ b/tests/Lexer.test.cpp @@ -138,4 +138,90 @@ TEST_CASE("lookahead") CHECK_EQ(lexer.lookahead().type, Lexeme::Eof); } +TEST_CASE("string_interpolation_basic") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`foo {"bar"}`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + Lexeme interpBegin = lexer.next(); + CHECK_EQ(interpBegin.type, Lexeme::InterpStringBegin); + + Lexeme quote = lexer.next(); + CHECK_EQ(quote.type, Lexeme::QuotedString); + + Lexeme interpEnd = lexer.next(); + CHECK_EQ(interpEnd.type, Lexeme::InterpStringEnd); +} + +TEST_CASE("string_interpolation_double_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`foo{{bad}}bar`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + auto brokenInterpBegin = lexer.next(); + CHECK_EQ(brokenInterpBegin.type, Lexeme::BrokenInterpDoubleBrace); + CHECK_EQ(std::string(brokenInterpBegin.data, brokenInterpBegin.length), std::string("foo")); + + CHECK_EQ(lexer.next().type, Lexeme::Name); + + auto interpEnd = lexer.next(); + CHECK_EQ(interpEnd.type, 
Lexeme::InterpStringEnd); + CHECK_EQ(std::string(interpEnd.data, interpEnd.length), std::string("}bar")); +} + +TEST_CASE("string_interpolation_double_but_unmatched_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`{{oops}`, 1)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + CHECK_EQ(lexer.next().type, Lexeme::BrokenInterpDoubleBrace); + CHECK_EQ(lexer.next().type, Lexeme::Name); + CHECK_EQ(lexer.next().type, Lexeme::InterpStringEnd); + CHECK_EQ(lexer.next().type, ','); + CHECK_EQ(lexer.next().type, Lexeme::Number); +} + +TEST_CASE("string_interpolation_unmatched_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"({ + `hello {"world"} + } -- this might be incorrectly parsed as a string)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + CHECK_EQ(lexer.next().type, '{'); + CHECK_EQ(lexer.next().type, Lexeme::InterpStringBegin); + CHECK_EQ(lexer.next().type, Lexeme::QuotedString); + CHECK_EQ(lexer.next().type, Lexeme::BrokenString); + CHECK_EQ(lexer.next().type, '}'); +} + +TEST_CASE("string_interpolation_with_unicode_escape") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + const std::string testInput = R"(`\u{1F41B}`)"; + Luau::Allocator alloc; + AstNameTable table(alloc); + Lexer lexer(testInput.c_str(), testInput.size(), table); + + CHECK_EQ(lexer.next().type, Lexeme::InterpStringSimple); + CHECK_EQ(lexer.next().type, Lexeme::Eof); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 64c6d3e..a7d09e8 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -43,6 +43,20 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") CHECK_EQ(result.warnings[0].text, "Global 'Wait' is deprecated, use 'wait' instead"); } +TEST_CASE_FIXTURE(Fixture, 
"DeprecatedGlobalNoReplacement") +{ + ScopedFastFlag sff{"LuauLintFixDeprecationMessage", true}; + + // Normally this would be defined externally, so hack it in for testing + const char* deprecationReplacementString = ""; + addGlobalBinding(typeChecker, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); + + LintResult result = lintTyped("Version()"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Global 'Version' is deprecated"); +} + TEST_CASE_FIXTURE(Fixture, "PlaceholderRead") { LintResult result = lint(R"( @@ -1662,17 +1676,31 @@ TEST_CASE_FIXTURE(Fixture, "WrongCommentOptimize") { LintResult result = lint(R"( --!optimize ---!optimize --!optimize me --!optimize 100500 --!optimize 2 )"); - REQUIRE_EQ(result.warnings.size(), 4); + REQUIRE_EQ(result.warnings.size(), 3); CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); - CHECK_EQ(result.warnings[1].text, "optimize directive requires an optimization level"); - CHECK_EQ(result.warnings[2].text, "optimize directive uses unknown optimization level 'me', 0..2 expected"); - CHECK_EQ(result.warnings[3].text, "optimize directive uses unknown optimization level '100500', 0..2 expected"); + CHECK_EQ(result.warnings[1].text, "optimize directive uses unknown optimization level 'me', 0..2 expected"); + CHECK_EQ(result.warnings[2].text, "optimize directive uses unknown optimization level '100500', 0..2 expected"); + + result = lint("--!optimize "); + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "optimize directive requires an optimization level"); +} + +TEST_CASE_FIXTURE(Fixture, "TestStringInterpolation") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + LintResult result = lint(R"( + --!nocheck + local _ = `unknown {foo}` + )"); + + REQUIRE_EQ(result.warnings.size(), 1); } TEST_CASE_FIXTURE(Fixture, "IntegerParsing") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp 
index c55ec18..2dd4770 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -905,6 +905,146 @@ TEST_CASE_FIXTURE(Fixture, "parse_compound_assignment_error_multiple") } } +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_begin") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = `{{oops}}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_double_brace_mid") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = `{nice} {{oops}}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Double braces are not permitted within interpolated strings. Did you mean '\\{'?", e.getErrors().front().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + auto columnOfEndBraceError = [this](const char* code) + { + try + { + parse(code); + FAIL("Expected ParseErrors to be thrown"); + return UINT_MAX; + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().size(), 1); + + auto error = e.getErrors().front(); + CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", error.getMessage()); + return error.getLocation().begin.column; + } + }; + + // This makes sure that the error is coming from the brace itself + CHECK_EQ(columnOfEndBraceError("_ = `{a`"), columnOfEndBraceError("_ = `{abcdefg`")); + CHECK_NE(columnOfEndBraceError("_ = `{a`"), columnOfEndBraceError("_ = `{a`")); +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_without_end_brace_in_table") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( 
+ _ = { `{a` } + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().size(), 2); + + CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_mid_without_end_brace_in_table") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = { `x {"y"} {z` } + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ(e.getErrors().size(), 2); + + CHECK_EQ("Malformed interpolated string, did you forget to add a '}'?", e.getErrors().front().getMessage()); + CHECK_EQ("Expected '}' (to close '{' at line 2), got ", e.getErrors().back().getMessage()); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_as_type_fail") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + local a: `what` = `???` + local b: `what {"the"}` = `???` + local c: `what {"the"} heck` = `???` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& parseErrors) + { + CHECK_EQ(parseErrors.getErrors().size(), 3); + + for (ParseError error : parseErrors.getErrors()) + CHECK_EQ(error.getMessage(), "Interpolated string literals cannot be used as types"); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_interpolated_string_call_without_parens") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + try + { + parse(R"( + _ = print `{42}` + )"); + FAIL("Expected ParseErrors to be thrown"); + } + catch (const ParseErrors& e) + { + CHECK_EQ("Expected identifier when parsing expression, got `{", e.getErrors().front().getMessage()); + } +} + TEST_CASE_FIXTURE(Fixture, "parse_nesting_based_end_detection") { try diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 
fe376d8..ab5d859 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -11,6 +11,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauRecursiveTypeParameterRestriction); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); +LUAU_FASTFLAG(LuauFixNameMaps); TEST_SUITE_BEGIN("ToString"); @@ -433,29 +434,40 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") LUAU_REQUIRE_NO_ERRORS(result); - TypeId id3Type = requireType("id3"); - ToStringResult nameData = toStringDetailed(id3Type); + ToStringOptions opts; + + TypeId id3Type = requireType("id3"); + ToStringResult nameData = toStringDetailed(id3Type, opts); + + if (FFlag::LuauFixNameMaps) + REQUIRE(3 == opts.nameMap.typeVars.size()); + else + REQUIRE_EQ(3, nameData.DEPRECATED_nameMap.typeVars.size()); - REQUIRE_EQ(3, nameData.nameMap.typeVars.size()); REQUIRE_EQ("(a, b, c) -> (a, b, c)", nameData.name); - ToStringOptions opts; - opts.nameMap = std::move(nameData.nameMap); + ToStringOptions opts2; // TODO: delete opts2 when clipping FFlag::LuauFixNameMaps + if (FFlag::LuauFixNameMaps) + opts2.nameMap = std::move(opts.nameMap); + else + opts2.DEPRECATED_nameMap = std::move(nameData.DEPRECATED_nameMap); const FunctionTypeVar* ftv = get(follow(id3Type)); REQUIRE(ftv != nullptr); auto params = flatten(ftv->argTypes).first; - REQUIRE_EQ(3, params.size()); + REQUIRE(3 == params.size()); - REQUIRE_EQ("a", toString(params[0], opts)); - REQUIRE_EQ("b", toString(params[1], opts)); - REQUIRE_EQ("c", toString(params[2], opts)); + CHECK("a" == toString(params[0], opts2)); + CHECK("b" == toString(params[1], opts2)); + CHECK("c" == toString(params[2], opts2)); } TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") { - ScopedFastFlag sff2{"DebugLuauSharedSelf", true}; + ScopedFastFlag sff[] = { + {"DebugLuauSharedSelf", true}, + }; CheckResult result = check(R"( local base = {} @@ -470,13 +482,18 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") )"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId tType = requireType("inst"); - 
ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); - CHECK_EQ(0, r.nameMap.typeVars.size()); - ToStringOptions opts; - opts.nameMap = r.nameMap; + + TypeId tType = requireType("inst"); + ToStringResult r = toStringDetailed(tType, opts); + CHECK_EQ("{ @metatable { __index: { @metatable {| __index: base |}, child } }, inst }", r.name); + if (FFlag::LuauFixNameMaps) + CHECK(0 == opts.nameMap.typeVars.size()); + else + CHECK_EQ(0, r.DEPRECATED_nameMap.typeVars.size()); + + if (!FFlag::LuauFixNameMaps) + opts.DEPRECATED_nameMap = r.DEPRECATED_nameMap; const MetatableTypeVar* tMeta = get(tType); REQUIRE(tMeta); @@ -499,7 +516,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "toStringDetailed2") REQUIRE(tMeta6); ToStringResult oneResult = toStringDetailed(tMeta5->props["one"].type, opts); - opts.nameMap = oneResult.nameMap; + if (!FFlag::LuauFixNameMaps) + opts.DEPRECATED_nameMap = oneResult.DEPRECATED_nameMap; std::string twoResult = toString(tMeta6->props["two"].type, opts); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index d2ed9ae..e79bc9b 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -6,6 +6,7 @@ #include "Luau/Transpiler.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -678,4 +679,22 @@ TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple_types") CHECK_EQ(code, transpile(code, {}, true).code); } +TEST_CASE_FIXTURE(Fixture, "transpile_string_interp") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + std::string code = R"( local _ = `hello {name}` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_string_literal_escape") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + std::string code = R"( local _ = ` bracket = \{, backtick = \` = {'ok'} ` )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); 
diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 10da0ef..12fb4aa 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -10,6 +10,7 @@ using namespace Luau; LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauSpecialTypesAsterisked); +LUAU_FASTFLAG(LuauStringFormatArgumentErrorFix) TEST_SUITE_BEGIN("BuiltinTests"); @@ -721,7 +722,14 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); + if (FFlag::LuauStringFormatArgumentErrorFix) + { + CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but 3 are specified", toString(result.errors[0])); + } + else + { + CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") @@ -736,6 +744,22 @@ TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument2") CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[1])); } +TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_use_correct_argument3") +{ + ScopedFastFlag LuauStringFormatArgumentErrorFix{"LuauStringFormatArgumentErrorFix", true}; + + CheckResult result = check(R"( + local s1 = string.format("%d") + local s2 = string.format("%d", 1) + local s3 = string.format("%d", 1, 2) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ("Argument count mismatch. Function expects 2 arguments, but only 1 is specified", toString(result.errors[0])); + CHECK_EQ("Argument count mismatch. 
Function expects 2 arguments, but 3 are specified", toString(result.errors[1])); +} + TEST_CASE_FIXTURE(BuiltinsFixture, "debug_traceback_is_crazy") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 9a917a6..c8fc7f2 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -544,4 +544,69 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "greedy_inference_with_shared_self_triggers_f CHECK_EQ("Not all codepaths in this function return 'self, a...'.", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "dcr_cant_partially_dispatch_a_constraint") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauSpecialTypesAsterisked", true}, + }; + + CheckResult result = check(R"( + local function hasDivisors(value: number) + end + + function prime_iter(state, index) + hasDivisors(index) + index += 1 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // We should be able to resolve this to number, but we're not there yet. + // Solving this requires recognizing that we can partially solve the + // following constraint: + // + // (*blocked*) -> () <: (number) -> (b...) + // + // The correct thing for us to do is to consider the constraint dispatched, + // but we need to also record a new constraint number <: *blocked* to finish + // the job later. 
+ CHECK("(a, *error-type*) -> ()" == toString(requireType("prime_iter"))); +} + +TEST_CASE_FIXTURE(Fixture, "free_options_cannot_be_unified_together") +{ + ScopedFastFlag sff[] = { + {"LuauFixNameMaps", true}, + }; + + TypeArena arena; + TypeId nilType = getSingletonTypes().nilType; + + std::unique_ptr scope = std::make_unique(getSingletonTypes().anyTypePack); + + TypeId free1 = arena.addType(FreeTypePack{scope.get()}); + TypeId option1 = arena.addType(UnionTypeVar{{nilType, free1}}); + + TypeId free2 = arena.addType(FreeTypePack{scope.get()}); + TypeId option2 = arena.addType(UnionTypeVar{{nilType, free2}}); + + InternalErrorReporter iceHandler; + UnifierSharedState sharedState{&iceHandler}; + Unifier u{&arena, Mode::Strict, NotNull{scope.get()}, Location{}, Variance::Covariant, sharedState}; + + u.tryUnify(option1, option2); + + CHECK(u.errors.empty()); + + u.log.commit(); + + ToStringOptions opts; + CHECK("a?" == toString(option1, opts)); + + // CHECK("a?" == toString(option2, opts)); // This should hold, but does not. + CHECK("b?" == toString(option2, opts)); // This should not hold. 
+} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 8088936..e1dc502 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -8,6 +8,7 @@ #include "Luau/VisitTypeVar.h" #include "Fixture.h" +#include "ScopedFlags.h" #include "doctest.h" @@ -828,6 +829,41 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "tc_interpolated_string_basic") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CheckResult result = check(R"( + local foo: string = `hello {"world"}` + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_interpolated_string_with_invalid_expression") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CheckResult result = check(R"( + local function f(x: number) end + + local foo: string = `hello {f("uh oh")}` + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_interpolated_string_constant_type") +{ + ScopedFastFlag sff{"LuauInterpolatedStringBaseSupport", true}; + + CheckResult result = check(R"( + local foo: "hello" = `hello` + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + /* * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. 
* diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 5804ea7..c49dbe7 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -252,25 +252,30 @@ if not rawget(_G, "_soft") then end -- create many threads with self-references and open upvalues -local thread_id = 0 -local threads = {} +do + local thread_id = 0 + local threads = {} -function fn(thread) - local x = {} - threads[thread_id] = function() - thread = x - end - coroutine.yield() + function fn(thread) + local x = {} + threads[thread_id] = function() + thread = x + end + coroutine.yield() + end + + while thread_id < 1000 do + local thread = coroutine.create(fn) + coroutine.resume(thread, thread) + thread_id = thread_id + 1 + end + + collectgarbage() + + -- ensure that we no longer have a lot of reachable threads for subsequent tests + threads = {} end -while thread_id < 1000 do - local thread = coroutine.create(fn) - coroutine.resume(thread, thread) - thread_id = thread_id + 1 -end - - - -- create a userdata to be collected when state is closed do local newproxy,assert,type,print,getmetatable = @@ -322,4 +327,27 @@ do collectgarbage() end +-- create a lot of threads with upvalues to force a case where full gc happens after we've marked some upvalues +do + local t = {} + for i = 1,100 do + local c = coroutine.wrap(function() + local uv = {i + 1} + local function f() + return uv[1] * 10 + end + coroutine.yield(uv[1]) + uv = {i + 2} + coroutine.yield(f()) + end) + + assert(c() == i + 1) + table.insert(t, c) + end + + t = {} + + collectgarbage() +end + return('OK') diff --git a/tests/conformance/stringinterp.lua b/tests/conformance/stringinterp.lua new file mode 100644 index 0000000..efb25ba --- /dev/null +++ b/tests/conformance/stringinterp.lua @@ -0,0 +1,59 @@ +local function assertEq(left, right) + assert(typeof(left) == "string", "left is a " .. typeof(left)) + assert(typeof(right) == "string", "right is a " .. 
typeof(right)) + + if left ~= right then + error(string.format("%q ~= %q", left, right)) + end +end + +assertEq(`hello {"world"}`, "hello world") +assertEq(`Welcome {"to"} {"Luau"}!`, "Welcome to Luau!") + +assertEq(`2 + 2 = {2 + 2}`, "2 + 2 = 4") + +assertEq(`{1} {2} {3} {4} {5} {6} {7}`, "1 2 3 4 5 6 7") + +local combo = {5, 2, 8, 9} +assertEq(`The lock combinations are: {table.concat(combo, ", ")}`, "The lock combinations are: 5, 2, 8, 9") + +assertEq(`true = {true}`, "true = true") + +local name = "Luau" +assertEq(`Welcome to { + name +}!`, "Welcome to Luau!") + +local nameNotConstantEvaluated = (function() return "Luau" end)() +assertEq(`Welcome to {nameNotConstantEvaluated}!`, "Welcome to Luau!") + +assertEq(`This {localName} does not exist`, "This nil does not exist") + +assertEq(`Welcome to \ +{name}!`, "Welcome to \nLuau!") + +assertEq(`empty`, "empty") + +assertEq(`Escaped brace: \{}`, "Escaped brace: {}") +assertEq(`Escaped brace \{} with {"expression"}`, "Escaped brace {} with expression") +assertEq(`Backslash \ that escapes the space is not a part of the string...`, "Backslash that escapes the space is not a part of the string...") +assertEq(`Escaped backslash \\`, "Escaped backslash \\") +assertEq(`Escaped backtick: \``, "Escaped backtick: `") + +assertEq(`Hello {`from inside {"a nested string"}`}`, "Hello from inside a nested string") + +assertEq(`1 {`2 {`3 {4}`}`}`, "1 2 3 4") + +local health = 50 +assert(`You have {health}% health` == "You have 50% health") + +local function shadowsString(string) + return `Value is {string}` +end + +assertEq(shadowsString("hello"), "Value is hello") +assertEq(shadowsString(1), "Value is 1") + +assertEq(`\u{0041}\t`, "A\t") + +return "OK" diff --git a/tools/faillist.txt b/tools/faillist.txt index 630bf9f..54e7ac0 100644 --- a/tools/faillist.txt +++ b/tools/faillist.txt @@ -53,6 +53,7 @@ AutocompleteTest.generic_types AutocompleteTest.get_suggestions_for_the_very_start_of_the_script 
AutocompleteTest.global_function_params AutocompleteTest.global_functions_are_not_scoped_lexically +AutocompleteTest.globals_are_order_independent AutocompleteTest.if_then_else_elseif_completions AutocompleteTest.keyword_methods AutocompleteTest.keyword_types @@ -588,7 +589,6 @@ TypeInferFunctions.another_recursive_local_function TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument TypeInferFunctions.complicated_return_types_require_an_explicit_annotation -TypeInferFunctions.cyclic_function_type_in_args TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_infer_parameter_types_for_functions_from_their_call_site TypeInferFunctions.duplicate_functions_with_different_signatures_not_allowed_in_nonstrict @@ -744,7 +744,6 @@ TypeInferUnknownNever.math_operators_and_never TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.unary_minus_of_never -TypeInferUnknownNever.unknown_is_reflexive TypePackTests.higher_order_function TypePackTests.multiple_varargs_inference_are_not_confused TypePackTests.no_return_size_should_be_zero diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index cb2f355..e45e4e2 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -195,7 +195,7 @@ openupval - u.l.threadnext + u.open.threadnext this diff --git a/tools/perfgraph.py b/tools/perfgraph.py index 25bb0dd..7d2639d 100644 --- a/tools/perfgraph.py +++ b/tools/perfgraph.py @@ -56,7 +56,63 @@ def nodeFromCallstackListFile(source_file): return root -def getDuration(obj): + +def getDuration(nodes, nid): + node = nodes[nid - 1] + total = node['TotalDuration'] + + for cid in node['NodeIds']: + total -= nodes[cid - 1]['TotalDuration'] + + return total + +def getFunctionKey(fn): 
+ return fn['Source'] + "," + fn['Name'] + "," + str(fn['Line']) + +def recursivelyBuildNodeTree(nodes, functions, parent, fid, nid): + ninfo = nodes[nid - 1] + finfo = functions[fid - 1] + + child = parent.child(getFunctionKey(finfo)) + child.source = finfo['Source'] + child.function = finfo['Name'] + child.line = int(finfo['Line']) if finfo['Line'] > 0 else 0 + + child.ticks = getDuration(nodes, nid) + + assert(len(ninfo['FunctionIds']) == len(ninfo['NodeIds'])) + + for i in range(0, len(ninfo['FunctionIds'])): + recursivelyBuildNodeTree(nodes, functions, child, ninfo['FunctionIds'][i], ninfo['NodeIds'][i]) + + return + +def nodeFromJSONV2(dump): + assert(dump['Version'] == 2) + + nodes = dump['Nodes'] + functions = dump['Functions'] + categories = dump['Categories'] + + root = Node() + + for category in categories: + nid = category['NodeId'] + node = nodes[nid - 1] + name = category['Name'] + + child = root.child(name) + child.function = name + child.ticks = getDuration(nodes, nid) + + assert(len(node['FunctionIds']) == len(node['NodeIds'])) + + for i in range(0, len(node['FunctionIds'])): + recursivelyBuildNodeTree(nodes, functions, child, node['FunctionIds'][i], node['NodeIds'][i]) + + return root + +def getDurationV1(obj): total = obj['TotalDuration'] if 'Children' in obj: @@ -73,7 +129,7 @@ def nodeFromJSONObject(node, key, obj): node.source = source node.line = int(line) if len(line) > 0 else 0 - node.ticks = getDuration(obj) + node.ticks = getDurationV1(obj) if 'Children' in obj: for key, obj in obj['Children'].items(): @@ -81,10 +137,8 @@ def nodeFromJSONObject(node, key, obj): return node - -def nodeFromJSONFile(source_file): - dump = json.load(source_file) - +def nodeFromJSONV1(dump): + assert(dump['Version'] == 1) root = Node() if 'Children' in dump: @@ -93,6 +147,16 @@ def nodeFromJSONFile(source_file): return root +def nodeFromJSONFile(source_file): + dump = json.load(source_file) + + if dump['Version'] == 2: + return nodeFromJSONV2(dump) + elif 
dump['Version'] == 1: + return nodeFromJSONV1(dump) + + return Node() + arguments = argumentParser.parse_args() diff --git a/tools/test_dcr.py b/tools/test_dcr.py index 0efea3c..da33706 100644 --- a/tools/test_dcr.py +++ b/tools/test_dcr.py @@ -14,12 +14,14 @@ def loadFailList(): with open(FAIL_LIST_PATH) as f: return set(map(str.strip, f.readlines())) + def safeParseInt(i, default=0): try: return int(i) except ValueError: return default + class Handler(x.ContentHandler): def __init__(self, failList): self.currentTest = [] @@ -47,7 +49,7 @@ class Handler(x.ContentHandler): r = self.results.get(dottedName, True) self.results[dottedName] = r and passed - elif name == 'OverallResultsTestCases': + elif name == "OverallResultsTestCases": self.numSkippedTests = safeParseInt(attrs.get("skipped", 0)) def endElement(self, name): @@ -104,9 +106,9 @@ def main(): for testName, passed in handler.results.items(): if passed and testName in failList: - print('UNEXPECTED: {} should have failed'.format(testName)) + print("UNEXPECTED: {} should have failed".format(testName)) elif not passed and testName not in failList: - print('UNEXPECTED: {} should have passed'.format(testName)) + print("UNEXPECTED: {} should have passed".format(testName)) if args.write: newFailList = sorted( @@ -123,17 +125,24 @@ def main(): print("Updated faillist.txt") if handler.numSkippedTests > 0: - print('{} test(s) were skipped! That probably means that a test segfaulted!'.format(handler.numSkippedTests), file=sys.stderr) + print( + "{} test(s) were skipped! 
That probably means that a test segfaulted!".format( + handler.numSkippedTests + ), + file=sys.stderr, + ) sys.exit(1) - sys.exit( - 0 - if all( - not passed == (dottedName in failList) - for dottedName, passed in handler.results.items() - ) - else 1 + ok = all( + not passed == (dottedName in failList) + for dottedName, passed in handler.results.items() ) + if ok: + print("Everything in order!", file=sys.stderr) + + sys.exit(0 if ok else 1) + + if __name__ == "__main__": main()