diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 57a1907..07d897b 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -34,7 +34,6 @@ TypeId makeFunction( // Polymorphic std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachFunctionTag(TypeId ty, std::string constraint); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h new file mode 100644 index 0000000..f46df14 --- /dev/null +++ b/Analysis/include/Luau/Quantify.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +namespace Luau +{ + +struct Module; +using ModulePtr = std::shared_ptr; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level); + +} // namespace Luau diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 0897ec8..e5683fc 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); void dump(TypeId ty); void dump(TypePackId ty); +std::string generateName(size_t n); + } // namespace Luau diff --git a/Analysis/include/Luau/TopoSortStatements.h b/Analysis/include/Luau/TopoSortStatements.h index 751694f..4a4acfa 100644 --- a/Analysis/include/Luau/TopoSortStatements.h +++ b/Analysis/include/Luau/TopoSortStatements.h @@ -12,6 +12,7 @@ struct AstArray; class AstStat; bool containsFunctionCall(const AstStat& stat); +bool containsFunctionCallOrReturn(const AstStat& stat); bool isFunction(const AstStat& stat); void toposort(std::vector& stats); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 055441c..322abd1 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -3,19 +3,37 @@ #include "Luau/TypeVar.h" +LUAU_FASTFLAG(LuauShareTxnSeen); + namespace Luau { // Log of where what TypeIds we are rebinding and what they used to be struct TxnLog { - TxnLog() = default; - - explicit TxnLog(const std::vector>& seen) - : seen(seen) + TxnLog() + : originalSeenSize(0) + , ownedSeen() + , sharedSeen(&ownedSeen) { } + explicit TxnLog(std::vector>* sharedSeen) + : originalSeenSize(sharedSeen->size()) + , ownedSeen() + , sharedSeen(sharedSeen) + { + } + + explicit TxnLog(const std::vector>& ownedSeen) + : originalSeenSize(ownedSeen.size()) + , ownedSeen(ownedSeen) + , sharedSeen(nullptr) + { + // This is deprecated! + LUAU_ASSERT(!FFlag::LuauShareTxnSeen); + } + TxnLog(const TxnLog&) = delete; TxnLog& operator=(const TxnLog&) = delete; @@ -38,9 +56,11 @@ private: std::vector> typeVarChanges; std::vector> typePackChanges; std::vector>> tableChanges; + size_t originalSeenSize; public: - std::vector> seen; // used to avoid infinite recursion when types are cyclic + std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic + std::vector>* sharedSeen; // shared with all the descendent logs }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index d701eb2..9d62fef 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -11,6 +11,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" +#include "Luau/UnifierSharedState.h" #include #include @@ -121,7 +122,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatForIn& forin); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false); + void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false); void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); @@ -336,7 +337,7 @@ private: // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -383,6 +384,8 @@ public: std::function prepareModuleScope; InternalErrorReporter* iceHandler; + UnifierSharedState unifierState; + public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index e04aa2c..9611e88 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -540,4 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); +void attachTag(TypeId ty, const std::string& tagName); +void attachTag(Property& prop, const std::string& tagName); + +bool hasTag(TypeId ty, const std::string& tagName); +bool hasTag(const Property& prop, const std::string& tagName); +bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. + } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 522914b..56632e3 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -6,6 +6,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" #include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/UnifierSharedState.h" #include @@ -41,11 +42,14 @@ struct Unifier std::shared_ptr counters_DEPRECATED; - InternalErrorReporter* iceHandler; + UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. @@ -69,7 +73,8 @@ private: void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); - TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + void cacheResult(TypeId superTy, TypeId subTy); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -101,8 +106,9 @@ private: [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); - DenseHashSet tempSeenTy{nullptr}; - DenseHashSet tempSeenTp{nullptr}; + // Remove with FFlagLuauCacheUnifyTableResults + DenseHashSet tempSeenTy_DEPRECATED{nullptr}; + DenseHashSet tempSeenTp_DEPRECATED{nullptr}; }; } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h new file mode 100644 index 0000000..f252a00 --- /dev/null +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -0,0 +1,44 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ +struct InternalErrorReporter; + +struct TypeIdPairHash +{ + size_t hashOne(Luau::TypeId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } +}; + +struct UnifierSharedState +{ + UnifierSharedState(InternalErrorReporter* iceHandler) + : iceHandler(iceHandler) + { + } + + InternalErrorReporter* iceHandler; + + DenseHashSet seenAny{nullptr}; + DenseHashMap skipCacheForType{nullptr}; + DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index df0bd42..a866655 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/TypeVar.h" #include "Luau/TypePack.h" +LUAU_FASTFLAG(LuauCacheUnifyTableResults) + namespace Luau { @@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set& seen, const void* tv) return !seen.insert(ttv).second; } +inline bool hasSeen(DenseHashSet& seen, const void* tv) +{ + void* ttv = const_cast(tv); + + if (seen.contains(ttv)) + return true; + + seen.insert(ttv); + return false; +} + inline void unsee(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); seen.erase(ttv); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen); +inline void unsee(DenseHashSet& seen, const void* tv) +{ + // When DenseHashSet is used for 'visitOnce', where don't forget visited elements +} -template -void visit(TypeId ty, F& f, std::unordered_set& seen) +template +void visit(TypePackId tp, F& f, Set& seen); + +template +void visit(TypeId ty, F& f, Set& seen) { if (visit_detail::hasSeen(seen, ty)) { @@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) else if (auto ttv = get(ty)) { + // Some visitors want to see bound tables, that's why we visit the original type if (apply(ty, *ttv, seen, f)) { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) + if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo) { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); + visit(*ttv->boundTo, f, seen); + } + else + { + for (auto& [_name, prop] : ttv->props) + visit(prop.type, f, seen); + + if (ttv->indexer) + { + visit(ttv->indexer->indexType, f, seen); + visit(ttv->indexer->indexResultType, f, seen); + } } } } @@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) visit_detail::unsee(seen, ty); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen) +template +void visit(TypePackId tp, F& f, Set& seen) { if (visit_detail::hasSeen(seen, tp)) { @@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set& seen) visit_detail::unsee(seen, tp); } + } // namespace visit_detail template @@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f) visit_detail::visit(ty, f, seen); } +template +void visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) +{ + seen.clear(); + visit_detail::visit(ty, f, seen); +} + } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index b31d85a..3c43c80 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { InternalErrorReporter iceReporter; - Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter); + UnifierSharedState unifierState(&iceReporter); + Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); unifier.tryUnify(expectedType, actualType); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 3b0c216..f6f2363 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn) LUAU_ASSERT(!"Got a non functional type"); } -void attachFunctionTag(TypeId ty, std::string tag) -{ - if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(std::move(tag)); - } - else - { - LUAU_ASSERT(!"Got a non functional type"); - } -} - Property makeProperty(TypeId ty, std::optional documentationSymbol) { return { diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b252984..5e7af50 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -248,7 +249,7 @@ struct RequireCycle // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) std::vector getRequireCycles( - const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) + const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) { std::vector result; @@ -282,9 +283,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(node->name); + cycle.push_back(resolver->getHumanReadableModuleName(node->name)); - cycle.push_back(top->name); + cycle.push_back(resolver->getHumanReadableModuleName(top->name)); break; } } @@ -404,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name) // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck); + requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); // This is used by the type checker to replace the resulting type of cyclic modules with any sourceModule.cyclic = !requireCycles.empty(); @@ -458,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + if (FFlag::LuauClearScopes) + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index df6be76..2fd9589 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) namespace Luau { @@ -299,6 +300,14 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { + // If table is now bound to another one, we ignore the content of the original + if (FFlag::LuauCloneBoundTables && t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + seenTypes[typeId] = boundTo; + return; + } + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -321,8 +330,11 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + if (!FFlag::LuauCloneBoundTables) + { + if (t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + } for (TypeId& arg : ttv->instantiatedTypeParams) arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); @@ -335,7 +347,7 @@ void TypeCloner::operator()(const TableTypeVar& t) if (ttv->state == TableState::Free) { - if (!t.boundTo) + if (FFlag::LuauCloneBoundTables || !t.boundTo) { if (encounteredFreeType) *encounteredFreeType = true; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp new file mode 100644 index 0000000..bf6d81a --- /dev/null +++ b/Analysis/src/Quantify.cpp @@ -0,0 +1,90 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Quantify.h" + +#include "Luau/VisitTypeVar.h" + +namespace Luau +{ + +struct Quantifier +{ + ModulePtr module; + TypeLevel level; + std::vector generics; + std::vector genericPacks; + + Quantifier(ModulePtr module, TypeLevel level) + : module(module) + , level(level) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + if (!level.subsumes(ftv.level)) + return false; + + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); + + return false; + } + + template + bool operator()(TypeId ty, const T& t) + { + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + if (!level.subsumes(ttv.level)) + return false; + + if (ttv.state == TableState::Free) + ttv.state = TableState::Generic; + else if (ttv.state == TableState::Unsealed) + ttv.state = TableState::Sealed; + + ttv.level = level; + + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + if (!level.subsumes(ftp.level)) + return false; + + *asMutable(tp) = GenericTypePack{level}; + genericPacks.push_back(tp); + return true; + } +}; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level) +{ + Quantifier q{std::move(module), level}; + visitTypeVar(ty, q); + + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; +} + +} // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 4634eff..95910b5 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -182,7 +182,7 @@ struct RequireTracerOld : AstVisitor struct RequireTracer : AstVisitor { - RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) : result(result) , fileResolver(fileResolver) , currentModuleName(currentModuleName) @@ -260,7 +260,7 @@ struct RequireTracer : AstVisitor // seed worklist with require arguments work.reserve(requires.size()); - for (AstExprCall* require: requires) + for (AstExprCall* require : requires) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5651af7..cd8180d 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -159,15 +158,6 @@ struct StringifierState seen.erase(iter); } - static std::string generateName(size_t i) - { - std::string n; - n = char('a' + i % 26); - if (i >= 26) - n += std::to_string(i / 26); - return n; - } - std::string getName(TypeId ty) { const size_t s = result.nameMap.typeVars.size(); @@ -584,8 +574,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : &uv) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); if (isNil(el)) { @@ -649,8 +638,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : uv.parts) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); std::string saved = std::move(state.result.name); @@ -1204,4 +1192,13 @@ void dump(TypePackId ty) printf("%s\n", toString(ty, opts).c_str()); } +std::string generateName(size_t i) +{ + std::string n; + n = char('a' + i % 26); + if (i >= 26) + n += std::to_string(i / 26); + return n; +} + } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 2d35638..dba694b 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor struct ContainsFunctionCall : public AstVisitor { + bool alsoReturn = false; bool result = false; + ContainsFunctionCall() = default; + explicit ContainsFunctionCall(bool alsoReturn) + : alsoReturn(alsoReturn) + { + } + bool visit(AstExpr*) override { return !result; // short circuit if result is true @@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor return false; } + bool visit(AstStatReturn* stat) override + { + if (alsoReturn) + { + result = true; + return false; + } + else + return AstVisitor::visit(stat); + } + bool visit(AstExprFunction*) override { return false; @@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat) return cfc.result; } +bool containsFunctionCallOrReturn(const AstStat& stat) +{ + detail::ContainsFunctionCall cfc{true}; + const_cast(stat).visit(&cfc); + return cfc.result; +} + bool isFunction(const AstStat& stat) { return stat.is() || stat.is(); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 702d0ca..383bb05 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false) + namespace Luau { @@ -33,6 +35,12 @@ void TxnLog::rollback() for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) std::swap(it->first->boundTo, it->second); + + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); + sharedSeen->resize(originalSeenSize); + } } void TxnLog::concat(TxnLog rhs) @@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs) tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); rhs.tableChanges.clear(); - seen.swap(rhs.seen); - rhs.seen.clear(); + if (!FFlag::LuauShareTxnSeen) + { + ownedSeen.swap(rhs.ownedSeen); + rhs.ownedSeen.clear(); + } } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair)); + if (FFlag::LuauShareTxnSeen) + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); + else + return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair)); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - seen.push_back(sortedPair); + if (FFlag::LuauShareTxnSeen) + sharedSeen->push_back(sortedPair); + else + ownedSeen.push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - LUAU_ASSERT(sortedPair == seen.back()); - seen.pop_back(); + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); + } + else + { + LUAU_ASSERT(sortedPair == ownedSeen.back()); + ownedSeen.pop_back(); + } } } // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 266c198..49f8e0c 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -6,6 +6,7 @@ #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -33,14 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data return result; } +using SyntheticNames = std::unordered_map; + namespace Luau { + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + class TypeRehydrationVisitor { - mutable std::map seen; - mutable int count = 0; + std::map seen; + int count = 0; - bool hasSeen(const void* tv) const + bool hasSeen(const void* tv) { void* ttv = const_cast(tv); auto it = seen.find(ttv); @@ -52,15 +70,16 @@ class TypeRehydrationVisitor } public: - TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions()) + TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions()) : allocator(alloc) + , syntheticNames(syntheticNames) , options(options) { } - AstTypePack* rehydrate(TypePackId tp) const; + AstTypePack* rehydrate(TypePackId tp); - AstType* operator()(const PrimitiveTypeVar& ptv) const + AstType* operator()(const PrimitiveTypeVar& ptv) { switch (ptv.type) { @@ -78,11 +97,11 @@ public: return nullptr; } } - AstType* operator()(const AnyTypeVar&) const + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); } - AstType* operator()(const TableTypeVar& ttv) const + AstType* operator()(const TableTypeVar& ttv) { RecursionCounter counter(&count); @@ -144,12 +163,12 @@ public: return allocator->alloc(Location(), props, indexer); } - AstType* operator()(const MetatableTypeVar& mtv) const + AstType* operator()(const MetatableTypeVar& mtv) { return Luau::visit(*this, mtv.table->ty); } - AstType* operator()(const ClassTypeVar& ctv) const + AstType* operator()(const ClassTypeVar& ctv) { RecursionCounter counter(&count); @@ -176,7 +195,7 @@ public: return allocator->alloc(Location(), props); } - AstType* operator()(const FunctionTypeVar& ftv) const + AstType* operator()(const FunctionTypeVar& ftv) { RecursionCounter counter(&count); @@ -253,10 +272,12 @@ public: size_t i = 0; for (const auto& el : ftv.argNames) { + std::optional* arg = &argNames.data[i++]; + if (el) - argNames.data[i++] = {AstName(el->name.c_str()), el->location}; + new (arg) std::optional(AstArgumentName(AstName(el->name.c_str()), el->location)); else - argNames.data[i++] = {}; + new (arg) std::optional(); } AstArray returnTypes; @@ -290,23 +311,23 @@ public: return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); } - AstType* operator()(const Unifiable::Error&) const + AstType* operator()(const Unifiable::Error&) { return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); } - AstType* operator()(const GenericTypeVar& gtv) const + AstType* operator()(const GenericTypeVar& gtv) { - return allocator->alloc(Location(), std::nullopt, AstName(gtv.name.c_str())); + return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); } - AstType* operator()(const Unifiable::Bound& bound) const + AstType* operator()(const Unifiable::Bound& bound) { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(Unifiable::Free ftv) const + AstType* operator()(const FreeTypeVar& ftv) { return allocator->alloc(Location(), std::nullopt, AstName("free")); } - AstType* operator()(const UnionTypeVar& uv) const + AstType* operator()(const UnionTypeVar& uv) { AstArray unionTypes; unionTypes.size = uv.options.size(); @@ -317,7 +338,7 @@ public: } return allocator->alloc(Location(), unionTypes); } - AstType* operator()(const IntersectionTypeVar& uv) const + AstType* operator()(const IntersectionTypeVar& uv) { AstArray intersectionTypes; intersectionTypes.size = uv.parts.size(); @@ -328,23 +349,28 @@ public: } return allocator->alloc(Location(), intersectionTypes); } - AstType* operator()(const LazyTypeVar& ltv) const + AstType* operator()(const LazyTypeVar& ltv) { return allocator->alloc(Location(), std::nullopt, AstName("")); } private: Allocator* allocator; + SyntheticNames* syntheticNames; const TypeRehydrationOptions& options; }; class TypePackRehydrationVisitor { public: - TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor) : allocator(allocator) + , syntheticNames(syntheticNames) , typeVisitor(typeVisitor) { + LUAU_ASSERT(allocator); + LUAU_ASSERT(syntheticNames); + LUAU_ASSERT(typeVisitor); } AstTypePack* operator()(const BoundTypePack& btp) const @@ -359,7 +385,7 @@ public: head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); for (size_t i = 0; i < tp.head.size(); i++) - head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty); AstTypePack* tail = nullptr; @@ -371,12 +397,12 @@ public: AstTypePack* operator()(const VariadicTypePack& vtp) const { - return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } AstTypePack* operator()(const GenericTypePack& gtp) const { - return allocator->alloc(Location(), AstName(gtp.name.c_str())); + return allocator->alloc(Location(), AstName(getName(allocator, syntheticNames, gtp))); } AstTypePack* operator()(const FreeTypePack& gtp) const @@ -391,12 +417,13 @@ public: private: Allocator* allocator; - const TypeRehydrationVisitor& typeVisitor; + SyntheticNames* syntheticNames; + TypeRehydrationVisitor* typeVisitor; }; -AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) { - TypePackRehydrationVisitor tprv(allocator, *this); + TypePackRehydrationVisitor tprv(allocator, syntheticNames, this); return Luau::visit(tprv, tp->ty); } @@ -431,7 +458,7 @@ public: { if (!type) return nullptr; - return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty); + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty); } AstArray typeAstPack(TypePackId type) @@ -443,7 +470,7 @@ public: result.data = static_cast(allocator->allocate(sizeof(AstType*) * v.size())); for (size_t i = 0; i < v.size(); ++i) { - result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty); + result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty); } return result; } @@ -495,7 +522,7 @@ public: { if (FFlag::LuauTypeAliasPacks) { - variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); } else { @@ -515,6 +542,7 @@ public: private: Module& module; Allocator* allocator; + SyntheticNames syntheticNames; }; void attachTypeData(SourceModule& source, Module& result) @@ -525,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result) AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options) { - return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty); + SyntheticNames syntheticNames; + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index da93c8e..38e2e52 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" +#include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" @@ -33,18 +34,16 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) -LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -215,6 +214,7 @@ static bool isMetamethod(const Name& name) TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) : resolver(resolver) , iceHandler(iceHandler) + , unifierState(iceHandler) , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) @@ -370,13 +370,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) return; } + int subLevel = 0; + std::vector sorted(block.body.data, block.body.data + block.body.size); toposort(sorted); for (const auto& stat : sorted) { if (const auto& typealias = stat->as()) - check(scope, *typealias, true); + { + check(scope, *typealias, subLevel, true); + ++subLevel; + } } auto protoIter = sorted.begin(); @@ -399,8 +404,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } }; - int subLevel = 0; - while (protoIter != sorted.end()) { // protoIter walks forward @@ -433,7 +436,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCall(**protoIter)) + if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter)) { while (checkIter != protoIter) { @@ -1161,7 +1164,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare) +void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) { // This function should be called at most twice for each type alias. // Once with forwardDeclare, and once without. @@ -1189,11 +1192,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = childScope(scope, typealias.location); + ScopePtr aliasScope = + FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); if (FFlag::LuauTypeAliasPacks) { - auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); FreeTypeVar* ftv = getMutable(ty); @@ -1418,7 +1422,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo { ScopePtr funScope = childFunctionScope(scope, global.location); - auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks); + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); @@ -1610,25 +1614,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = resolveLValue(scope, *lvalue)) return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - if (FFlag::LuauExtraNilRecovery) - lhsType = stripFromNilAndReport(lhsType, expr.expr->location); + lhsType = stripFromNilAndReport(lhsType, expr.expr->location); if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - if (!FFlag::LuauMissingUnionPropertyError) - reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(lhsType)) - { - if (std::optional ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false)) - return {*ty}; - } - } - return {errorType}; } @@ -1694,61 +1684,37 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (const UnionTypeVar* utv = get(type)) { - if (FFlag::LuauMissingUnionPropertyError) + std::vector goodOptions; + std::vector badOptions; + + for (TypeId t : utv) { - std::vector goodOptions; - std::vector badOptions; + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - for (TypeId t : utv) - { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - goodOptions.push_back(*ty); - else - badOptions.push_back(t); - } - - if (!badOptions.empty()) - { - if (addErrors) - { - if (goodOptions.empty()) - reportError(location, UnknownProperty{type, name}); - else - reportError(location, MissingUnionProperty{type, badOptions, name}); - } - return std::nullopt; - } - - std::vector result = reduceUnion(goodOptions); - - if (result.size() == 1) - return result[0]; - - return addType(UnionTypeVar{std::move(result)}); + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + goodOptions.push_back(*ty); + else + badOptions.push_back(t); } - else + + if (!badOptions.empty()) { - std::vector options; - - for (TypeId t : utv->options) + if (addErrors) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - options.push_back(*ty); + if (goodOptions.empty()) + reportError(location, UnknownProperty{type, name}); else - return std::nullopt; + reportError(location, MissingUnionProperty{type, badOptions, name}); } - - std::vector result = reduceUnion(options); - - if (result.size() == 1) - return result[0]; - - return addType(UnionTypeVar{std::move(result)}); + return std::nullopt; } + + std::vector result = reduceUnion(goodOptions); + + if (result.size() == 1) + return result[0]; + + return addType(UnionTypeVar{std::move(result)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1765,7 +1731,7 @@ std::optional TypeChecker::getIndexTypeFromType( // If no parts of the intersection had the property we looked up for, it never existed at all. if (parts.empty()) { - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; } @@ -1779,7 +1745,7 @@ std::optional TypeChecker::getIndexTypeFromType( return addType(IntersectionTypeVar{result}); } - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; @@ -2062,8 +2028,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn case AstExprUnary::Len: tablify(operandType); - if (FFlag::LuauExtraNilRecovery) - operandType = stripFromNilAndReport(operandType, expr.location); + operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) return {errorType}; @@ -2635,8 +2600,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Name name = expr.index.value; - if (FFlag::LuauExtraNilRecovery) - lhs = stripFromNilAndReport(lhs, expr.expr->location); + lhs = stripFromNilAndReport(lhs, expr.expr->location); if (TableTypeVar* lhsTable = getMutableTableType(lhs)) { @@ -2710,8 +2674,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); - if (FFlag::LuauExtraNilRecovery) - exprType = stripFromNilAndReport(exprType, expr.expr->location); + exprType = stripFromNilAndReport(exprType, expr.expr->location); TypeId indexType = checkExpr(scope, *expr.index).type; @@ -2738,10 +2701,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { - if (FFlag::LuauExtraNilRecovery) - reportError(TypeError{expr.expr->location, NotATable{exprType}}); - else - reportError(TypeError{expr.location, NotATable{exprType}}); + reportError(TypeError{expr.expr->location, NotATable{exprType}}); return std::pair(errorType, nullptr); } @@ -2910,7 +2870,7 @@ std::pair TypeChecker::checkFunctionSignature( if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); } TypePackId retPack; @@ -3016,9 +2976,6 @@ std::pair TypeChecker::checkFunctionSignature( if (expectedArgsCurr != expectedArgsEnd) { argType = *expectedArgsCurr; - - if (!FFlag::LuauInferFunctionArgsFix) - ++expectedArgsCurr; } else if (auto expectedArgsTail = expectedArgsCurr.tail()) { @@ -3034,7 +2991,7 @@ std::pair TypeChecker::checkFunctionSignature( funScope->bindings[local] = {argType, local->location}; argTypes.push_back(argType); - if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd) + if (expectedArgsCurr != expectedArgsEnd) ++expectedArgsCurr; } @@ -3402,8 +3359,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A if (!FFlag::LuauRankNTypes) instantiate(scope, selfType, expr.func->location); - if (FFlag::LuauExtraNilRecovery) - selfType = stripFromNilAndReport(selfType, expr.func->location); + selfType = stripFromNilAndReport(selfType, expr.func->location); if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) { @@ -3412,34 +3368,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - if (!FFlag::LuauMissingUnionPropertyError) - reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(selfType)) - { - if (std::optional propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false)) - { - selfType = *strippedUnion; - - functionType = *propTy; - actualFunctionType = instantiate(scope, functionType, expr.func->location); - } - } - - if (!actualFunctionType) - { - functionType = errorType; - actualFunctionType = errorType; - } - } - else - { - functionType = errorType; - actualFunctionType = errorType; - } + functionType = errorType; + actualFunctionType = errorType; } } else @@ -3555,8 +3485,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& errors) { - if (FFlag::LuauExtraNilRecovery) - fn = stripFromNilAndReport(fn, expr.func->location); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) { @@ -4283,6 +4212,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) return ty; + if (FFlag::LuauQuantifyInPlace2) + { + Luau::quantify(currentModule, ty, scope->level); + return ty; + } + quantification.level = scope->level; quantification.generics.clear(); quantification.genericPacks.clear(); @@ -4491,12 +4426,12 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; } Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -4753,7 +4688,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); } // TODO: better error message CLI-39912 @@ -5041,10 +4976,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); + const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; + std::vector generics; for (const AstName& generic : genericNames) { @@ -5063,12 +5000,12 @@ std::pair, std::vector> TypeChecker::createGener { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) - cached = addType(GenericTypeVar{scope->level, n}); + cached = addType(GenericTypeVar{level, n}); g = cached; } else { - g = addType(Unifiable::Generic{scope->level, n}); + g = addType(Unifiable::Generic{level, n}); } generics.push_back(g); @@ -5093,12 +5030,12 @@ std::pair, std::vector> TypeChecker::createGener { TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); g = cached; } else { - g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } genericPacks.push_back(g); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index e963fc7..e82f751 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -22,6 +22,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) namespace Luau { @@ -217,8 +218,7 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); - primitiveType && primitiveType->metatable) + else if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -1490,4 +1490,86 @@ std::vector filterMap(TypeId type, TypeIdPredicate predicate) return {}; } +static Tags* getTags(TypeId ty) +{ + ty = follow(ty); + + if (auto ftv = getMutable(ty)) + return &ftv->tags; + else if (auto ttv = getMutable(ty)) + return &ttv->tags; + else if (auto ctv = getMutable(ty)) + return &ctv->tags; + + return nullptr; +} + +void attachTag(TypeId ty, const std::string& tagName) +{ + if (!FFlag::LuauRefactorTagging) + { + if (auto ftv = getMutable(ty)) + { + ftv->tags.emplace_back(tagName); + } + else + { + LUAU_ASSERT(!"Got a non functional type"); + } + } + else + { + if (auto tags = getTags(ty)) + tags->push_back(tagName); + else + LUAU_ASSERT(!"This TypeId does not support tags"); + } +} + +void attachTag(Property& prop, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + + prop.tags.push_back(tagName); +} + +// We would ideally not expose this because it could cause a footgun. +// If the Base class has a tag and you ask if Derived has that tag, it would return false. +// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. +bool hasTag(const Tags& tags, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + return std::find(tags.begin(), tags.end(), tagName) != tags.end(); +} + +bool hasTag(TypeId ty, const std::string& tagName) +{ + ty = follow(ty); + + // We special case classes because getTags only returns a pointer to one vector of tags. + // But classes has multiple vector of tags, represented throughout the hierarchy. + if (auto ctv = get(ty)) + { + while (ctv) + { + if (hasTag(ctv->tags, tagName)) + return true; + else if (!ctv->parent) + return false; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + else if (auto tags = getTags(ty)) + return hasTag(*tags, tagName); + + return false; +} + +bool hasTag(const Property& prop, const std::string& tagName) +{ + return hasTag(prop.tags, tagName); +} + } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 56f2515..2539650 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -7,6 +7,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/VisitTypeVar.h" #include @@ -22,9 +23,99 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) +LUAU_FASTFLAG(LuauShareTxnSeen); +LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) namespace Luau { +struct SkipCacheForType +{ + SkipCacheForType(const DenseHashMap& skipCacheForType) + : skipCacheForType(skipCacheForType) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const BoundTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const GenericTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.boundTo) + { + result = true; + return false; + } + + if (ttv.state != TableState::Sealed) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypeId ty, const T& t) + { + const bool* prev = skipCacheForType.find(ty); + + if (prev && *prev) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const BoundTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const GenericTypePack& ftp) + { + result = true; + return false; + } + + const DenseHashMap& skipCacheForType; + bool result = false; +}; static std::optional hasUnificationTooComplex(const ErrorVec& errors) { @@ -39,7 +130,7 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) , globalScope(std::move(globalScope)) @@ -47,24 +138,39 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , variance(variance) , counters(&countersData) , counters_DEPRECATED(std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) - , log(seen) + , log(ownedSeen) , location(location) , variance(variance) , counters(counters ? counters : &countersData) , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); +} + +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) + : types(types) + , mode(mode) + , globalScope(std::move(globalScope)) + , log(sharedSeen) + , location(location) + , variance(variance) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) + , sharedState(sharedState) +{ + LUAU_ASSERT(sharedState.iceHandler); } void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -74,7 +180,7 @@ void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool i else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); + tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -206,6 +312,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(subTy, superTy); + bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection; + auto& cache = sharedState.cachedUnify; + + // What if the types are immutable and we proved their relation before + if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + return; + // If we have seen this pair of types before, we are currently recursing into cyclic types. // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. @@ -257,6 +370,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (FFlag::LuauUnionHeuristic) { + bool found = false; + const std::string* subName = getName(subTy); if (subName) { @@ -264,6 +379,21 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) + { + found = true; + startIndex = i; + break; + } + } + } + + if (!found && cacheEnabled) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) { startIndex = i; break; @@ -311,8 +441,25 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; - for (TypeId type : uv->parts) + size_t startIndex = 0; + + if (cacheEnabled) { + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[i]; + + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + { + startIndex = i; + break; + } + } + } + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(superTy, type, isFunctionCall); @@ -342,8 +489,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool tryUnifyFunctions(superTy, subTy, isFunctionCall); else if (get(superTy) && get(subTy)) + { tryUnifyTables(superTy, subTy, isIntersection); + if (cacheEnabled && errors.empty()) + cacheResult(superTy, subTy); + } + // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. else if (get(superTy)) tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); @@ -364,6 +516,41 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } +void Unifier::cacheResult(TypeId superTy, TypeId subTy) +{ + LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults); + + bool* superTyInfo = sharedState.skipCacheForType.find(superTy); + + if (superTyInfo && *superTyInfo) + return; + + bool* subTyInfo = sharedState.skipCacheForType.find(subTy); + + if (subTyInfo && *subTyInfo) + return; + + auto skipCacheFor = [this](TypeId ty) { + SkipCacheForType visitor{sharedState.skipCacheForType}; + visitTypeVarOnce(ty, visitor, sharedState.seenAny); + + sharedState.skipCacheForType[ty] = visitor.result; + + return visitor.result; + }; + + if (!superTyInfo && skipCacheFor(superTy)) + return; + + if (!subTyInfo && skipCacheFor(subTy)) + return; + + sharedState.cachedUnify.insert({superTy, subTy}); + + if (variance == Invariant) + sharedState.cachedUnify.insert({subTy, superTy}); +} + struct WeirdIter { TypePackId packId; @@ -459,7 +646,7 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTp, subTp, isFunctionCall); + tryUnify_(superTp, subTp, isFunctionCall); } /* @@ -797,6 +984,40 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) std::vector missingProperties; std::vector extraProperties; + // Optimization: First test that the property sets are compatible without doing any recursive unification + if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free) + { + for (const auto& [propName, superProp] : lt->props) + { + auto subIter = rt->props.find(propName); + if (subIter == rt->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + missingProperties.push_back(propName); + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + } + + // And vice versa if we're invariant + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + { + for (const auto& [propName, subProp] : rt->props) + { + auto superIter = lt->props.find(propName); + if (superIter == lt->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + extraProperties.push_back(propName); + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + } + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -833,9 +1054,10 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) innerState.log.rollback(); } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (rt->state == TableState::Free) { log(rt); @@ -878,11 +1100,13 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) lt->props[name] = clone; } else if (variance == Covariant) - {} + { + } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (lt->state == TableState::Free) { log(lt); @@ -980,10 +1204,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); } else - return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); + return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) @@ -1697,10 +1921,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { std::vector queue = {ty}; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP); + } } else { @@ -1721,12 +1955,24 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { std::vector queue; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - queueTypePack(queue, tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); + + queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any); + } } else { @@ -1775,10 +2021,20 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTy.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTy.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTy_DEPRECATED.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) @@ -1851,10 +2107,20 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTp.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTp_DEPRECATED.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) @@ -1922,7 +2188,10 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; + if (FFlag::LuauShareTxnSeen) + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; + else + return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const @@ -1940,12 +2209,12 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId void Unifier::ice(const std::string& message, const Location& location) { - iceHandler->ice(message, location); + sharedState.iceHandler->ice(message, location); } void Unifier::ice(const std::string& message) { - iceHandler->ice(message); + sharedState.iceHandler->ice(message); } } // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 641dfd3..503eca6 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -194,20 +194,20 @@ LUAU_NOINLINE std::pair createScopeDa } // namespace Luau // Regular scope -#define LUAU_TIMETRACE_SCOPE(name, category) \ +#define LUAU_TIMETRACE_SCOPE(name, category) \ static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) // A scope without nested scopes that may be skipped if the time it took is less than the threshold -#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) // Extra key/value data can be added to regular scopes -#define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ - if (FFlag::DebugLuauTimeTracing) \ +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ lttScopeStatic.second.eventArgument(name, value); \ } while (false) @@ -216,8 +216,8 @@ LUAU_NOINLINE std::pair createScopeDa #define LUAU_TIMETRACE_SCOPE(name, category) #define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) #define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ + do \ + { \ } while (false) #endif diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e6aab20..ded50e5 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -77,7 +77,10 @@ struct GlobalContext // Ideally we would want all ThreadContext destructors to run // But in VS, not all thread_local object instances are destroyed for (ThreadContext* context : threads) - context->flushEvents(); + { + if (!context->events.empty()) + context->flushEvents(); + } if (traceFile) fclose(traceFile); diff --git a/Makefile b/Makefile index dd63ffb..ea416be 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +.SUFFIXES: MAKEFLAGS+=-r -j8 COMMA=, diff --git a/Sources.cmake b/Sources.cmake index 83ed523..c30cf77 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -46,6 +46,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/Predicate.h + Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h Analysis/include/Luau/Scope.h @@ -63,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeVar.h Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifier.h + Analysis/include/Luau/UnifierSharedState.h Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitTypeVar.h @@ -77,6 +79,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Linter.cpp Analysis/src/Module.cpp Analysis/src/Predicate.cpp + Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp Analysis/src/Substitution.cpp diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 0131536..f2e97c6 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) luaC_checkGC(L); luaC_checkthreadsleep(L); Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); - memcpy(u->data + sz, &dtor, sizeof(dtor)); + memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); return u->data; diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index a2515ba..9724c0e 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,8 +5,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) - #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) static const char* const statnames[] = {"running", "suspended", "normal", "dead"}; -static int costatus(lua_State* L, lua_State* co) +static int auxstatus(lua_State* L, lua_State* co) { if (co == L) return CO_RUN; @@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co) return CO_SUS; /* initial state */ } -static int luaB_costatus(lua_State* L) +static int costatus(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); - lua_pushstring(L, statnames[costatus(L, co)]); + lua_pushstring(L, statnames[auxstatus(L, co)]); return 1; } @@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg) // error handling for edge cases if (co->status != LUA_YIELD) { - int status = costatus(L, co); + int status = auxstatus(L, co); if (status != CO_SUS) { lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]); @@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co) } } -static int luaB_coresumefinish(lua_State* L, int r) +static int coresumefinish(lua_State* L, int r) { if (r < 0) { @@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r) } } -static int luaB_coresumey(lua_State* L) +static int coresumey(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_coresumecont(lua_State* L, int status) +static int coresumecont(lua_State* L, int status) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_auxwrapfinish(lua_State* L, int r) +static int auxwrapfinish(lua_State* L, int r) { if (r < 0) { @@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r) return r; } -static int luaB_auxwrapy(lua_State* L) +static int auxwrapy(lua_State* L) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); int narg = cast_int(L->top - L->base); @@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_auxwrapcont(lua_State* L, int status) +static int auxwrapcont(lua_State* L, int status) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); @@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_cocreate(lua_State* L) +static int cocreate(lua_State* L) { luaL_checktype(L, 1, LUA_TFUNCTION); lua_State* NL = lua_newthread(L); - - if (FFlag::LuauPreferXpush) - { - lua_xpush(L, NL, 1); // push function on top of NL - } - else - { - lua_pushvalue(L, 1); /* move function to top */ - lua_xmove(L, NL, 1); /* move function from L to NL */ - } - + lua_xpush(L, NL, 1); // push function on top of NL return 1; } -static int luaB_cowrap(lua_State* L) +static int cowrap(lua_State* L) { - luaB_cocreate(L); + cocreate(L); - lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont); + lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } -static int luaB_yield(lua_State* L) +static int coyield(lua_State* L) { int nres = cast_int(L->top - L->base); return lua_yield(L, nres); } -static int luaB_corunning(lua_State* L) +static int corunning(lua_State* L) { if (lua_pushthread(L)) lua_pushnil(L); /* main thread is not a coroutine */ return 1; } -static int luaB_yieldable(lua_State* L) +static int coyieldable(lua_State* L) { lua_pushboolean(L, lua_isyieldable(L)); return 1; } static const luaL_Reg co_funcs[] = { - {"create", luaB_cocreate}, - {"running", luaB_corunning}, - {"status", luaB_costatus}, - {"wrap", luaB_cowrap}, - {"yield", luaB_yield}, - {"isyieldable", luaB_yieldable}, + {"create", cocreate}, + {"running", corunning}, + {"status", costatus}, + {"wrap", cowrap}, + {"yield", coyield}, + {"isyieldable", coyieldable}, {NULL, NULL}, }; @@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont); + lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index fb4978d..328b47e 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -18,6 +18,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) +LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) /* ** {====================================================== @@ -536,6 +537,12 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } + if (FFlag::LuauCcallRestoreFix) + { + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; + } + // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) { @@ -549,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - L->nCcalls = oldnCcalls; + if (!FFlag::LuauCcallRestoreFix) + { + L->nCcalls = oldnCcalls; + } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 5de5d49..6553009 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -12,11 +12,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) -LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) -LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) +LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -66,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, g->gcstats.currcycle.marktime += seconds; // atomic step had to be performed during the switch and it's tracked separately - if (g->gcstate == GCSsweepstring) + if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; + case GCSatomic: + g->gcstats.currcycle.atomictime += seconds; + break; case GCSsweepstring: case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; break; + default: + LUAU_ASSERT(!"Unexpected GC state"); } if (assist) @@ -183,33 +186,15 @@ static int traversetable(global_State* g, Table* h) if (h->metatable) markobject(g, cast_to(Table*, h->metatable)); - if (FFlag::LuauShrinkWeakTables) + /* is there a weak mode? */ + if (const char* modev = gettablemode(g, h)) { - /* is there a weak mode? */ - if (const char* modev = gettablemode(g, h)) - { - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } - } - } - else - { - const TValue* mode = gfasttm(g, h->metatable, TM_MODE); - if (mode && ttisstring(mode)) - { /* is there a weak mode? */ - const char* modev = svalue(mode); - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } + weakkey = (strchr(modev, 'k') != NULL); + weakvalue = (strchr(modev, 'v') != NULL); + if (weakkey || weakvalue) + { /* is really weak? */ + h->gclist = g->weak; /* must be cleared after GC, ... */ + g->weak = obj2gco(h); /* ... so put in the appropriate list */ } } @@ -297,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack) for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); /* final traversal? */ - if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack)) + if (g->gcstate == GCSatomic || clearstack) { StkId stack_end = l->stack + l->stacksize; for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */ @@ -336,28 +321,16 @@ static size_t propagatemark(global_State* g) lua_State* th = gco2th(o); g->gray = th->gclist; - if (FFlag::LuauGcFullSkipInactiveThreads) + LUAU_ASSERT(!luaC_threadsleeping(th)); + + // threads that are executing and the main thread are not deactivated + bool active = luaC_threadactive(th) || th == th->global->mainthread; + + if (!active && g->gcstate == GCSpropagate) { - LUAU_ASSERT(!luaC_threadsleeping(th)); + traversestack(g, th, /* clearstack= */ true); - // threads that are executing and the main thread are not deactivated - bool active = luaC_threadactive(th) || th == th->global->mainthread; - - if (!active && g->gcstate == GCSpropagate) - { - traversestack(g, th, /* clearstack= */ true); - - l_setbit(th->stackstate, THREAD_SLEEPINGBIT); - } - else - { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); - - traversestack(g, th, /* clearstack= */ false); - } + l_setbit(th->stackstate, THREAD_SLEEPINGBIT); } else { @@ -385,12 +358,14 @@ static size_t propagatemark(global_State* g) } } -static void propagateall(global_State* g) +static size_t propagateall(global_State* g) { + size_t work = 0; while (g->gray) { - propagatemark(g); + work += propagatemark(g); } + return work; } /* @@ -415,11 +390,14 @@ static int isobjcleared(GCObject* o) /* ** clear collected entries from weaktables */ -static void cleartable(lua_State* L, GCObject* l) +static size_t cleartable(lua_State* L, GCObject* l) { + size_t work = 0; while (l) { Table* h = gco2h(l); + work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + int i = h->sizearray; while (i--) { @@ -433,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l) { LuaNode* n = gnode(h, i); - if (FFlag::LuauShrinkWeakTables) + // non-empty entry? + if (!ttisnil(gval(n))) { - // non-empty entry? - if (!ttisnil(gval(n))) - { - // can we clear key or value? - if (iscleared(gkey(n)) || iscleared(gval(n))) - { - setnilvalue(gval(n)); /* remove value ... */ - removeentry(n); /* remove entry from table */ - } - else - { - activevalues++; - } - } - } - else - { - if (!ttisnil(gval(n)) && /* non-empty entry? */ - (iscleared(gkey(n)) || iscleared(gval(n)))) + // can we clear key or value? + if (iscleared(gkey(n)) || iscleared(gval(n))) { setnilvalue(gval(n)); /* remove value ... */ removeentry(n); /* remove entry from table */ } + else + { + activevalues++; + } } } - if (FFlag::LuauShrinkWeakTables) + if (const char* modev = gettablemode(L->global, h)) { - if (const char* modev = gettablemode(L->global, h)) + // are we allowed to shrink this weak table? + if (strchr(modev, 's')) { - // are we allowed to shrink this weak table? - if (strchr(modev, 's')) - { - // shrink at 37.5% occupancy - if (activevalues < sizenode(h) * 3 / 8) - luaH_resizehash(L, h, activevalues); - } + // shrink at 37.5% occupancy + if (activevalues < sizenode(h) * 3 / 8) + luaH_resizehash(L, h, activevalues); } } l = h->gclist; } + return work; } static void shrinkstack(lua_State* L) @@ -655,37 +619,49 @@ static void markroot(lua_State* L) g->gcstate = GCSpropagate; } -static void remarkupvals(global_State* g) +static size_t remarkupvals(global_State* g) { - UpVal* uv; - for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + size_t work = 0; + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { + work += sizeof(UpVal); LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); if (isgray(obj2gco(uv))) markvalue(g, uv->v); } + return work; } -static void atomic(lua_State* L) +static size_t atomic(lua_State* L) { global_State* g = L->global; - g->gcstate = GCSatomic; + size_t work = 0; + + if (FFlag::LuauSeparateAtomic) + { + LUAU_ASSERT(g->gcstate == GCSatomic); + } + else + { + g->gcstate = GCSatomic; + } + /* remark occasional upvalues of (maybe) dead threads */ - remarkupvals(g); + work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ - propagateall(g); + work += propagateall(g); /* remark weak tables */ g->gray = g->weak; g->weak = NULL; LUAU_ASSERT(!iswhite(obj2gco(g->mainthread))); markobject(g, L); /* mark running thread */ markmt(g); /* mark basic metatables (again) */ - propagateall(g); + work += propagateall(g); /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; - propagateall(g); - cleartable(L, g->weak); /* remove collected objects from weak tables */ + work += propagateall(g); + work += cleartable(L, g->weak); /* remove collected objects from weak tables */ g->weak = NULL; /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); @@ -693,7 +669,12 @@ static void atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - GC_INTERRUPT(GCSatomic); + if (!FFlag::LuauSeparateAtomic) + { + GC_INTERRUPT(GCSatomic); + } + + return work; } static size_t singlestep(lua_State* L) @@ -705,46 +686,24 @@ static size_t singlestep(lua_State* L) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + if (g->gray) { - if (g->gray) - { - g->gcstats.currcycle.markitems++; + g->gcstats.currcycle.markitems++; - cost = propagatemark(g); - } - else - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; - - g->gcstate = GCSpropagateagain; - } + cost = propagatemark(g); } else { - if (g->gray) - { - g->gcstats.currcycle.markitems++; + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - cost = propagatemark(g); - } - else /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -758,17 +717,34 @@ static size_t singlestep(lua_State* L) } else /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { size_t traversedcount = 0; @@ -806,7 +782,7 @@ static size_t singlestep(lua_State* L) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; @@ -821,48 +797,25 @@ static size_t gcstep(lua_State* L, size_t limit) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + while (g->gray && cost < limit) { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; + g->gcstats.currcycle.markitems++; - cost += propagatemark(g); - } - - if (!g->gray) - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; - - g->gcstate = GCSpropagateagain; - } + cost += propagatemark(g); } - else + + if (!g->gray) { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - cost += propagatemark(g); - } - - if (!g->gray) /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -877,17 +830,34 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { while (g->sweepstrgc < g->strt.size && cost < limit) @@ -934,7 +904,7 @@ static size_t gcstep(lua_State* L, size_t limit) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; } @@ -1084,7 +1054,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1095,7 +1065,7 @@ void luaC_fullgc(lua_State* L) g->weak = NULL; g->gcstate = GCSsweepstring; } - LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain); + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); /* finish any pending sweep phase */ while (g->gcstate != GCSpause) { @@ -1143,14 +1113,11 @@ void luaC_fullgc(lua_State* L) void luaC_barrierupval(lua_State* L, GCObject* v) { - if (FFlag::LuauGcFullSkipInactiveThreads) - { - global_State* g = L->global; - LUAU_ASSERT(iswhite(v) && !isdead(g, v)); + global_State* g = L->global; + LUAU_ASSERT(iswhite(v) && !isdead(g, v)); - if (keepinvariant(g)) - reallymarkobject(g, v); - } + if (keepinvariant(g)) + reallymarkobject(g, v); } void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) @@ -1778,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 6fddd66..f434e50 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,8 +6,6 @@ #include "lobject.h" #include "lstate.h" -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - /* ** Possible states of the Garbage Collector */ @@ -25,7 +23,7 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) ** still-black objects. The invariant is restored when sweep ends and ** all objects are white again. */ -#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain) +#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic) /* ** some useful bit tricks @@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); LUAI_FUNC void luaC_wakethread(lua_State* L); -LUAI_FUNC const char* luaC_statename(int state); \ No newline at end of file +LUAI_FUNC const char* luaC_statename(int state); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 2759f3b..d8b265c 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) if (page->freeNext >= 0) { - block = page->data + page->freeNext; + block = &page->data + page->freeNext; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); page->freeNext -= page->blockSize; diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index d77e17c..18ee1cd 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u) void (*dtor)(void*) = nullptr; if (u->tag == UTAG_IDTOR) - memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor)); + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); else if (u->tag) dtor = L->global->udatagc[u->tag]; diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index b932a85..a168b65 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens -template +template struct TempBuffer { lua_State* L; @@ -346,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint32_t mainid = readVarInt(data, size, offset); Proto* main = protos[mainid]; + luaC_checkthreadsleep(L); + Closure* cl = luaF_newLclosure(L, 0, envt, main); setclvalue(L, L->top, cl); incr_top(L); diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua new file mode 100644 index 0000000..87b9abf --- /dev/null +++ b/bench/tests/chess.lua @@ -0,0 +1,849 @@ + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +local RANKS = "12345678" +local FILES = "abcdefgh" +local PieceSymbols = "PpRrNnBbQqKk" +local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"} +local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" + +-- +-- Lua 5.2 Compat +-- + +if not table.create then + function table.create(n, v) + local result = {} + for i=1,n do result[i] = v end + return result + end +end + +if not table.move then + function table.move(a, from, to, start, target) + local dx = start - from + for i=from,to do + target[i+dx] = a[i] + end + end +end + + +-- +-- Utils +-- + +local function square(s) + return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9 +end + +local function squareName(n) + local file = n % 8 + local rank = (n-file)/8 + return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1) +end + +local function moveName(v ) + local from = bit32.extract(v, 6, 6) + local to = bit32.extract(v, 0, 6) + local piece = bit32.extract(v, 20, 4) + local captured = bit32.extract(v, 25, 4) + + local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to) + + if bit32.extract(v,14) == 1 then + if to > from then + return "O-O" + else + return "O-O-O" + end + end + + local promote = bit32.extract(v,15,4) + if promote ~= 0 then + move = move .. "=" .. PieceSymbols:sub(promote,promote) + end + return move +end + +local function ucimove(m) + local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6)) + local promote = bit32.extract(m,15,4) + if promote > 0 then + mm = mm .. PieceSymbols:sub(promote,promote):lower() + end + return mm +end + +local _utils = {squareName, moveName} + +-- +-- Bitboards +-- + +local Bitboard = {} + + +function Bitboard:toString() + local out = {} + + local src = self.h + for x=7,0,-1 do + table.insert(out, RANKS:sub(x+1,x+1)) + table.insert(out, " ") + local bit = bit32.lshift(1,(x%4) * 8) + for x=0,7 do + if bit32.band(src, bit) ~= 0 then + table.insert(out, "x ") + else + table.insert(out, "- ") + end + bit = bit32.lshift(bit, 1) + end + if x == 4 then + src = self.l + end + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h) + return table.concat(out) +end + + +function Bitboard.from(l ,h ) + return setmetatable({l=l, h=h}, Bitboard) +end + +Bitboard.zero = Bitboard.from(0,0) +Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF) + +local Rank1 = Bitboard.from(0x000000FF, 0) +local Rank3 = Bitboard.from(0x00FF0000, 0) +local Rank6 = Bitboard.from(0, 0x0000FF00) +local Rank8 = Bitboard.from(0, 0xFF000000) +local FileA = Bitboard.from(0x01010101, 0x01010101) +local FileB = Bitboard.from(0x02020202, 0x02020202) +local FileC = Bitboard.from(0x04040404, 0x04040404) +local FileD = Bitboard.from(0x08080808, 0x08080808) +local FileE = Bitboard.from(0x10101010, 0x10101010) +local FileF = Bitboard.from(0x20202020, 0x20202020) +local FileG = Bitboard.from(0x40404040, 0x40404040) +local FileH = Bitboard.from(0x80808080, 0x80808080) + +local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH} + +-- These masks are filled out below for all files +local RightMasks = {FileH} +local LeftMasks = {FileA} + + + +local function popcnt32(i) + i = i - bit32.band(bit32.rshift(i,1), 0x55555555) + i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333) + return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24) +end + +function Bitboard:up() + return self:lshift(8) +end + +function Bitboard:down() + return self:rshift(8) +end + +function Bitboard:right() + return self:band(FileH:inverse()):lshift(1) +end + +function Bitboard:left() + return self:band(FileA:inverse()):rshift(1) +end + +function Bitboard:move(x,y) + local out = self + + if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end + if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end + + if y < 0 then out = out:rshift(-8 * y) end + if y > 0 then out = out:lshift(8 * y) end + return out +end + + +function Bitboard:popcnt() + return popcnt32(self.l) + popcnt32(self.h) +end + +function Bitboard:band(other ) + return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h)) +end + +function Bitboard:bandnot(other ) + return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h))) +end + +function Bitboard:bandempty(other ) + return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0 +end + +function Bitboard:bor(other ) + return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h)) +end + +function Bitboard:bxor(other ) + return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h)) +end + +function Bitboard:inverse() + return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF)) +end + +function Bitboard:empty() + return self.h == 0 and self.l == 0 +end + +function Bitboard:ctz() + local target = self.l + local offset = 0 + local result = 0 + + if target == 0 then + target = self.h + result = 32 + end + + if target == 0 then + return 64 + end + + while bit32.extract(target, offset) == 0 do + offset = offset + 1 + end + + return result + offset +end + +function Bitboard:ctzafter(start) + start = start + 1 + if start < 32 then + for i=start,31 do + if bit32.extract(self.l, i) == 1 then return i end + end + end + for i=math.max(32,start),63 do + if bit32.extract(self.h, i-32) == 1 then return i end + end + return 64 +end + + +function Bitboard:lshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + + if amt > 31 then + return Bitboard.from(0, bit32.lshift(self.l, amt-31)) + end + + local l = bit32.lshift(self.l, amt) + local h = bit32.bor( + bit32.lshift(self.h, amt), + bit32.extract(self.l, 32-amt, amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:rshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + local h = bit32.rshift(self.h, amt) + local l = bit32.bor( + bit32.rshift(self.l, amt), + bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:index(i) + if i > 31 then + return bit32.extract(self.h, i - 32) + else + return bit32.extract(self.l, i) + end +end + +function Bitboard:set(i , v) + if i > 31 then + return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32)) + else + return Bitboard.from(bit32.replace(self.l, v, i), self.h) + end +end + +function Bitboard:isolate(i) + return self:band(Bitboard.some(i)) +end + +function Bitboard.some(idx ) + return Bitboard.zero:set(idx, 1) +end + +Bitboard.__index = Bitboard +Bitboard.__tostring = Bitboard.toString + +for i=2,8 do + RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH) + LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA) +end +-- +-- Board +-- + +local Board = {} + + +function Board.new() + local boards = table.create(12, Bitboard.zero) + boards.ocupied = Bitboard.zero + boards.white = Bitboard.zero + boards.black = Bitboard.zero + boards.unocupied = Bitboard.full + boards.ep = Bitboard.zero + boards.castle = Bitboard.zero + boards.toMove = 1 + boards.hm = 0 + boards.moves = 0 + boards.material = 0 + + return setmetatable(boards, Board) +end + +function Board.fromFen(fen ) + local b = Board.new() + local i = 0 + local rank = 7 + local file = 0 + + while true do + i = i + 1 + local p = fen:sub(i,i) + if p == '/' then + rank = rank - 1 + file = 0 + elseif tonumber(p) ~= nil then + file = file + tonumber(p) + else + local pidx = PieceSymbols:find(p) + if pidx == nil then break end + b[pidx] = b[pidx]:set(rank*8+file, 1) + file = file + 1 + end + end + + + local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i) + if move == nil then print(fen:sub(i)) end + b.toMove = move == 'w' and 1 or 2 + + if ep ~= "-" then + b.ep = Bitboard.some(square(ep)) + end + + if castle ~= "-" then + local oo = Bitboard.zero + if castle:find("K") then + oo = oo:set(7, 1) + end + if castle:find("Q") then + oo = oo:set(0, 1) + end + if castle:find("k") then + oo = oo:set(63, 1) + end + if castle:find("q") then + oo = oo:set(56, 1) + end + + b.castle = oo + end + + b.hm = hm + b.moves = m + + b:updateCache() + return b + +end + +function Board:index(idx ) + if self.white:index(idx) == 1 then + for p=1,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + else + for p=2,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + end + + return 0 +end + +function Board:updateCache() + for i=1,11,2 do + self.white = self.white:bor(self[i]) + self.black = self.black:bor(self[i+1]) + end + + self.ocupied = self.black:bor(self.white) + self.unocupied = self.ocupied:inverse() + self.material = + 100*self[1]:popcnt() - 100*self[2]:popcnt() + + 500*self[3]:popcnt() - 500*self[4]:popcnt() + + 300*self[5]:popcnt() - 300*self[6]:popcnt() + + 300*self[7]:popcnt() - 300*self[8]:popcnt() + + 900*self[9]:popcnt() - 900*self[10]:popcnt() + +end + +function Board:fen() + local out = {} + local s = 0 + local idx = 56 + for i=0,63 do + if i % 8 == 0 and i > 0 then + idx = idx - 16 + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, '/') + end + local p = self:index(idx) + if p == 0 then + s = s + 1 + else + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, PieceSymbols:sub(p,p)) + end + + idx = idx + 1 + end + if s > 0 then + table.insert(out, '' .. s) + end + + table.insert(out, self.toMove == 1 and ' w ' or ' b ') + if self.castle:empty() then + table.insert(out, '-') + else + if self.castle:index(7) == 1 then table.insert(out, 'K') end + if self.castle:index(0) == 1 then table.insert(out, 'Q') end + if self.castle:index(63) == 1 then table.insert(out, 'k') end + if self.castle:index(56) == 1 then table.insert(out, 'q') end + end + + table.insert(out, ' ') + if self.ep:empty() then + table.insert(out, '-') + else + table.insert(out, squareName(self.ep:ctz())) + end + + table.insert(out, ' ' .. self.hm) + table.insert(out, ' ' .. self.moves) + + return table.concat(out) +end + +function Board:pmoves(idx) + return self:generate(idx) +end + +function Board:pcaptures(idx) + return self:generate(idx):band(self.ocupied) +end + +local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}} +local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}} +local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}} +local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}} + +function Board:generate(idx) + local piece = self:index(idx) + local r = Bitboard.some(idx) + local out = Bitboard.zero + local type = bit32.rshift(piece - 1, 1) + local cancapture = piece % 2 == 1 and self.black or self.white + + if piece == 0 then return Bitboard.zero end + + if type == 0 then + -- Pawn + local d = -(piece*2 - 3) + local movetwo = piece == 1 and Rank3 or Rank6 + + out = out:bor(r:move(0,d):band(self.unocupied)) + out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied)) + + local captures = r:move(0,d) + captures = captures:right():bor(captures:left()) + + if not captures:bandempty(self.ep) then + out = out:bor(self.ep) + end + + captures = captures:band(cancapture) + out = out:bor(captures) + + return out + elseif type == 5 then + -- King + for x=-1,1,1 do + for y = -1,1,1 do + local w = r:move(x,y) + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + end + elseif type == 2 then + -- Knight + for _,j in ipairs(KNIGHT_MOVES) do + local w = r:move(j[1],j[2]) + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + else + -- Sliders (Rook, Bishop, Queen) + local slides + if type == 1 then + slides = ROOK_SLIDES + elseif type == 3 then + slides = BISHOP_SLIDES + else + slides = QUEEN_SLIDES + end + + for _, op in ipairs(slides) do + local w = r + for i=1,7 do + w = w:move(op[1], op[2]) + if w:empty() then break end + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + break + end + end + end + end + + + return out +end + +-- 0-5 - From Square +-- 6-11 - To Square +-- 12 - is Check +-- 13 - Is EnPassent +-- 14 - Is Castle +-- 15-19 - Promotion Piece +-- 20-24 - Moved Pice +-- 25-29 - Captured Piece + + +function Board:toString(mark ) + local out = {} + for x=8,1,-1 do + table.insert(out, RANKS:sub(x,x) .. " ") + + for y=1,8 do + local n = 8*x+y-9 + local i = self:index(n) + if i == 0 then + table.insert(out, '-') + else + -- out = out .. PieceSymbols:sub(i,i) + table.insert(out, UnicodePieces[i]) + end + if mark ~= nil and mark:index(n) ~= 0 then + table.insert(out, ')') + elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then + table.insert(out, '(') + else + table.insert(out, ' ') + end + end + + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n") + return table.concat(out) +end + +function Board:moveList() + local tm = self.toMove == 1 and self.white or self.black + local castle_rank = self.toMove == 1 and Rank1 or Rank8 + local out = {} + local function emit(id) + if not self:applyMove(id):illegalyChecked() then + table.insert(out, id) + end + end + + local cr = tm:band(self.castle):band(castle_rank) + if not cr:empty() then + local p = self.toMove == 1 and 11 or 12 + local tcolor = self.toMove == 1 and self.black or self.white + local kidx = self[p]:ctz() + + + local castle = bit32.replace(0, p, 20, 4) + castle = bit32.replace(castle, kidx, 6, 6) + castle = bit32.replace(castle, 1, 14) + + + local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank) + local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p]) + if + not cr:bandempty(FileA) and + mustbeemptyl:bandempty(self.ocupied) and + not self:isSquareThreatened(cantbethreatened, tcolor) + then + emit(bit32.replace(castle, kidx - 2, 0, 6)) + end + + + local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank) + if + not cr:bandempty(FileH) and + mustbeemptyr:bandempty(self.ocupied) and + not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor) + then + emit(bit32.replace(castle, kidx + 2, 0, 6)) + end + end + + local sq = tm:ctz() + repeat + local p = self:index(sq) + local moves = self:pmoves(sq) + + while not moves:empty() do + local m = moves:ctz() + moves = moves:set(m, 0) + local id = bit32.replace(m, sq, 6, 6) + id = bit32.replace(id, p, 20, 4) + local mbb = Bitboard.some(m) + if not self.ocupied:bandempty(mbb) then + id = bit32.replace(id, self:index(m), 25, 4) + end + + -- Check if pawn needs to be promoted + if p == 1 and m >= 8*7 then + for i=3,9,2 do + emit(bit32.replace(id, i, 15, 4)) + end + elseif p == 2 and m < 8 then + for i=4,10,2 do + emit(bit32.replace(id, i, 15, 4)) + end + else + emit(id) + end + end + sq = tm:ctzafter(sq) + until sq == 64 + return out +end + +function Board:illegalyChecked() + local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")] + return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black) +end + +function Board:isSquareThreatened(target , color ) + local tm = color + local sq = tm:ctz() + repeat + local moves = self:pmoves(sq) + if not moves:bandempty(target) then + return true + end + sq = color:ctzafter(sq) + until sq == 64 + return false +end + +function Board:perft(depth ) + if depth == 0 then return 1 end + if depth == 1 then + return #self:moveList() + end + local result = 0 + for k,m in ipairs(self:moveList()) do + local c = self:applyMove(m):perft(depth - 1) + if c == 0 then + -- Perft only counts leaf nodes at target depth + -- result = result + 1 + else + result = result + c + end + end + return result +end + + +function Board:applyMove(move ) + local out = Board.new() + table.move(self, 1, 12, 1, out) + local from = bit32.extract(move, 6, 6) + local to = bit32.extract(move, 0, 6) + local promote = bit32.extract(move, 15, 4) + local piece = self:index(from) + local captured = self:index(to) + local tom = Bitboard.some(to) + local isCastle = bit32.extract(move, 14) + + if piece % 2 == 0 then + out.moves = self.moves + 1 + end + + if captured == 1 or piece < 3 then + out.hm = 0 + else + out.hm = self.hm + 1 + end + out.castle = self.castle + out.toMove = self.toMove == 1 and 2 or 1 + + if isCastle == 1 then + local rank = piece == 11 and Rank1 or Rank8 + local colorOffset = piece - 11 + + out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA) + out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank)) + + out[piece] = (from < to and FileG or FileC):band(rank) + out.castle = out.castle:bandnot(rank) + out:updateCache() + return out + end + + if piece < 3 then + local dist = math.abs(to - from) + -- Pawn moved two squares, set ep square + if dist == 16 then + out.ep = Bitboard.some((from + to) / 2) + end + + -- Remove enpasent capture + if not tom:bandempty(self.ep) then + if piece == 1 then + out[2] = out[2]:bandnot(self.ep:down()) + end + if piece == 2 then + out[1] = out[1]:bandnot(self.ep:up()) + end + end + end + + if piece == 3 or piece == 4 then + out.castle = out.castle:set(from, 0) + end + + if piece > 10 then + local rank = piece == 11 and Rank1 or Rank8 + out.castle = out.castle:bandnot(rank) + end + + out[piece] = out[piece]:set(from, 0) + if promote == 0 then + out[piece] = out[piece]:set(to, 1) + else + out[promote] = out[promote]:set(to, 1) + end + if captured ~= 0 then + out[captured] = out[captured]:set(to, 0) + end + + out:updateCache() + return out +end + +Board.__index = Board +Board.__tostring = Board.toString +-- +-- Main +-- + +local failures = 0 +local function test(fen, ply, target) + local b = Board.fromFen(fen) + if b:fen() ~= fen then + print("FEN MISMATCH", fen, b:fen()) + failures = failures + 1 + return + end + + local found = b:perft(ply) + if found ~= target then + print(fen, "Found", found, "target", target) + failures = failures + 1 + for k,v in pairs(b:moveList()) do + print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1')) + end + --error("Test Failure") + else + print("OK", found, fen) + end +end + +-- From https://www.chessprogramming.org/Perft_Results +-- If interpreter, computers, or algorithm gets too fast +-- feel free to go deeper + +local testCases = {} +local function addTest(...) table.insert(testCases, {...}) end + +addTest(StartingFen, 3, 8902) +addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039) +addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812) +addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467) +addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486) +addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079) + + +local function chess() + for k,v in ipairs(testCases) do + test(v[1],v[2],v[3]) + end +end + +bench.runCode(chess, "chess") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 07910a0..44b8362 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1596,7 +1596,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") local function target(a: number, b: string) return a + #b end local function d(a: n@1, b) - return target(a, b) + return target(a, b) end )"); @@ -1609,7 +1609,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: s@1) - return target(a, b) + return target(a, b) end )"); @@ -1622,7 +1622,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a:@1 @2, b) - return target(a, b) + return target(a, b) end )"); @@ -1640,7 +1640,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: @1)@2: number - return target(a, b) + return target(a, b) end )"); @@ -1682,7 +1682,7 @@ local x = target(function(a: n@1 local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end local x = target(function(a: n@1, b: @2) - return a + #b + return a + #b end) )"); @@ -1700,7 +1700,7 @@ end) local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a: n@1) - return a + return a end )"); @@ -1716,7 +1716,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(...:n@1) - return a + return a end )"); @@ -1729,7 +1729,7 @@ end local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a:number, b:number, ...:@1) - return a + b + return a + b end )"); @@ -1745,7 +1745,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") local function target(callback: () -> number) return callback() end local x = target(function(): n@1 - return 1 + return 1 end )"); @@ -1758,7 +1758,7 @@ end local function target(callback: () -> (number, number)) return callback() end local x = target(function(): (number, n@1 - return 1, 2 + return 1, 2 end )"); @@ -1774,7 +1774,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" local function target(callback: () -> ...number) return callback() end local x = target(function(): ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); @@ -1787,7 +1787,7 @@ end local function target(callback: () -> ...number) return callback() end local x = target(function(): (number, number, ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 54a31a6..bbac330 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf") local MaterialsListClass = {} function MaterialsListClass:_MakeToolTip(guiElement, text) - local function updateTooltipPosition() - self._tweakingTooltipFrame = 5 - end + local function updateTooltipPosition() + self._tweakingTooltipFrame = 5 + end - updateTooltipPosition() + updateTooltipPosition() end return MaterialsListClass @@ -2001,14 +2001,14 @@ TEST_CASE("UpvaluesLoopsBytecode") { CHECK_EQ("\n" + compileFunction(R"( function test() - for i=1,10 do + for i=1,10 do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2035,14 +2035,14 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - for i in ipairs(data) do + for i in ipairs(data) do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2068,17 +2068,17 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - while i < 5 do - local j + local i = 0 + while i < 5 do + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - end - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + end + return 0 end )", 1), @@ -2105,17 +2105,17 @@ RETURN R1 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - repeat - local j + local i = 0 + repeat + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - until i < 5 - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + until i < 5 + return 0 end )", 1), @@ -2304,10 +2304,10 @@ local Value1, Value2, Value3 = ... local Table = {} Table.SubTable["Key"] = { - Key1 = Value1, - Key2 = Value2, - Key3 = Value3, - Key4 = true, + Key1 = Value1, + Key2 = Value2, + Key3 = Value3, + Key4 = true, } )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5a697a4..06b3c52 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression") runConformance("ifelseexpr.lua"); } +TEST_CASE("TagMethodError") +{ + ScopedFastFlag sff{"LuauCcallRestoreFix", true}; + + runConformance("tmerror.lua", [](lua_State* L) { + auto* cb = lua_callbacks(L); + + cb->debugprotectederror = [](lua_State* L) { + CHECK(lua_isyieldable(L)); + }; + }); +} + TEST_SUITE_END(); diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index 9f87489..e55b5b0 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -2,6 +2,9 @@ #pragma once #include +#include + +namespace std { inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { @@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) } template -std::ostream& operator<<(std::ostream& lhs, const std::optional& t) +auto operator<<(std::ostream& lhs, const std::optional& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types { if (t) return lhs << *t; else return lhs << "none"; } + +} // namespace std diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index a9ed139..37f1b60 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") { LintResult result = lint(R"(--!strict type InputData = { - id: number, - inputType: EnumItem, - inputState: EnumItem, - updated: number, - position: Vector3, - keyCode: EnumItem, - name: string + id: number, + inputType: EnumItem, + inputState: EnumItem, + updated: number, + position: Vector3, + keyCode: EnumItem, + name: string } )"); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 045f023..f580604 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -554,4 +554,54 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") CHECK_EQ("{number | string}", toString(requireType("p"), {true})); } +/* + * We had a problem where all type aliases would be prototyped into a child scope that happened + * to have the same level. This caused a problem where, if a sibling function referred to that + * type alias in its type signature, it would erroneously be quantified away, even though it doesn't + * actually belong to the function. + * + * We solved this by ascribing a unique subLevel to each prototyped alias. + */ +TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases") +{ + CheckResult result = check(R"( + --!strict + + local KeyPool = {} + + local function newkey(pool: KeyPool, index) + return {} + end + + function newKeyPool() + local pool = { + available = {} :: {Key}, + } + + return setmetatable(pool, KeyPool) + end + + export type KeyPool = typeof(newKeyPool()) + export type Key = typeof(newkey(newKeyPool(), 1)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * We keep a cache of type alias onto TypeVar to prevent infinite types from + * being constructed via recursive or corecursive aliases. We have to adjust + * the TypeLevels of those generic TypeVars so that the unifier doesn't think + * they have improperly leaked out of their scope. + */ +TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases") +{ + CheckResult result = check(R"( + type Array = {T} + type Exclude = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 6da33a0..eabf7e6 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local b: Vector2? = nil local a = b.X + b.Z diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3a04a18..581375a 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } +TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") +{ + ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true}; + ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; + + CheckResult result = check(R"( +local exports = {} +local nested = {} + +nested.name = function(t, k) + local a = t.x.y + return rawget(t, k) +end + +exports.nested = nested +return exports + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index b00fddc..419da8a 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -9,6 +9,7 @@ #include LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -42,7 +43,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expected = R"( + const std::string old_expected = R"( function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -51,7 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end end )"; - CHECK_EQ(expected, decorateWithTypes(code)); + + const std::string expected = R"( + function f(a:{fn:()->(a,b...)}): () + if type(a) == 'boolean'then + local a1:boolean=a + elseif a.fn()then + local a2:{fn:()->(a,b...)}=a + end + end + )"; + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ(expected, decorateWithTypes(code)); + else + CHECK_EQ(old_expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") @@ -263,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { - ScopedFastInt sffi{"LuauTarjanChildLimit", 50}; - ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50}; + ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; + ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; CheckResult result = check(R"LUA( local Result diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 31739cd..36dcaa9 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauOrPredicate) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -698,10 +699,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); + else + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + else + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index b7f0dc7..f1451a8 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -617,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") REQUIRE_EQ(indexer.indexType, typeChecker.numberType); - REQUIRE(nullptr != get(indexer.indexResultType)); + REQUIRE(nullptr != get(follow(indexer.indexResultType))); } TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index dbc4538..4538175 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -180,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement") TEST_CASE_FIXTURE(Fixture, "generic_function") { - CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)"); + CheckResult result = check(R"( + function id(x) return x end + local a = id(55) + local b = id(nil) + )"); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*typeChecker.numberType, *requireType("a")); @@ -1889,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") REQUIRE_EQ(2, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); std::vector fArgs = flatten(fType->argTypes).first; @@ -1926,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") REQUIRE_EQ(6, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); } @@ -3842,10 +3847,10 @@ local T: any T = {} T.__index = T function T.new(...) - local self = {} - setmetatable(self, T) - self:construct(...) - return self + local self = {} + setmetatable(self, T) + self:construct(...) + return self end function T:construct(index) end @@ -4068,11 +4073,11 @@ function n:Clone() end local m = {} function m.a(x) - x:Clone() + x:Clone() end function m.b() - m.a(n) + m.a(n) end return m @@ -4393,8 +4398,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { - ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true); - // Simple direct arg to arg propagation CheckResult result = check(R"( type Table = { x: number, y: number } @@ -4681,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( @@ -4698,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1f4b63e..1192a8a 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -8,6 +8,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauQuantifyInPlace2); + using namespace Luau; struct TryUnifyFixture : Fixture @@ -15,7 +17,8 @@ struct TryUnifyFixture : Fixture TypeArena arena; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; - Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler}; + UnifierSharedState unifierState{&iceHandler}; + Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -139,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("(number) -> boolean", toString(requireType("f"))); + else + CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 3e1dedd..8dab260 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") std::vector applyArgs = flatten(applyType->argTypes).first; REQUIRE_EQ(3, applyArgs.size()); - const FunctionTypeVar* fType = get(applyArgs[0]); + const FunctionTypeVar* fType = get(follow(applyArgs[0])); REQUIRE(fType != nullptr); - const FunctionTypeVar* gType = get(applyArgs[1]); + const FunctionTypeVar* gType = get(follow(applyArgs[1])); REQUIRE(gType != nullptr); std::vector gArgs = flatten(gType->argTypes).first; @@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") { CheckResult result = check(R"( local _ = function():((...any)->(...any),()->()) - return function() end, function() end + return function() end, function() end end for y in _() do end diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index c0ed25d..34c25a9 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") { - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = {x: number} type B = {} @@ -242,8 +240,6 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") TEST_CASE_FIXTURE(Fixture, "optional_union_members") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = { a = { x = 1, y = 2 }, b = 3 } type A = typeof(a) @@ -259,8 +255,6 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a.foo(x:number, y:number) return x + y end @@ -276,8 +270,6 @@ local c = b.foo(1, 2) TEST_CASE_FIXTURE(Fixture, "optional_union_methods") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a:foo(x:number, y:number) return x + y end @@ -310,8 +302,6 @@ return f() TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local b: A? = { x = 2 } @@ -327,8 +317,6 @@ local d = b.y TEST_CASE_FIXTURE(Fixture, "optional_index_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -341,8 +329,6 @@ local b = a[1] TEST_CASE_FIXTURE(Fixture, "optional_call_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = (number) -> number local a: A? = function(a) return -a end @@ -355,8 +341,6 @@ local b = a(4) TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local a: A? = { x = 2 } @@ -378,8 +362,6 @@ a.x = 2 TEST_CASE_FIXTURE(Fixture, "optional_length_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -392,9 +374,6 @@ local b = #a TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: number } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index a679e3f..930c1a3 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -265,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); } +TEST_CASE("tagging_tables") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar ttv{TableTypeVar{}}; + CHECK(!Luau::hasTag(&ttv, "foo")); + Luau::attachTag(&ttv, "foo"); + CHECK(Luau::hasTag(&ttv, "foo")); +} + +TEST_CASE("tagging_classes") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + CHECK(!Luau::hasTag(&base, "foo")); + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); +} + +TEST_CASE("tagging_subclasses") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; + + CHECK(!Luau::hasTag(&base, "foo")); + CHECK(!Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); + CHECK(Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&derived, "bar"); + CHECK(!Luau::hasTag(&base, "bar")); + CHECK(Luau::hasTag(&derived, "bar")); +} + +TEST_CASE("tagging_functions") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypePackVar empty{TypePack{}}; + TypeVar ftv{FunctionTypeVar{&empty, &empty}}; + CHECK(!Luau::hasTag(&ftv, "foo")); + Luau::attachTag(&ftv, "foo"); + CHECK(Luau::hasTag(&ftv, "foo")); +} + +TEST_CASE("tagging_props") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + Property prop{}; + CHECK(!Luau::hasTag(prop, "foo")); + Luau::attachTag(prop, "foo"); + CHECK(Luau::hasTag(prop, "foo")); +} + TEST_SUITE_END(); diff --git a/tests/conformance/tmerror.lua b/tests/conformance/tmerror.lua new file mode 100644 index 0000000..1ad4dd1 --- /dev/null +++ b/tests/conformance/tmerror.lua @@ -0,0 +1,15 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes + +-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly +-- called via pcall. +-- This test is meant to detect a regression in handling errors inside a tag method + +local testtable = {} +setmetatable(testtable, { __index = function() error("Error") end }) + +pcall(function() + testtable.missingmethod() +end) + +return('OK') diff --git a/tools/gdb-printers.py b/tools/gdb-printers.py index c711c5e..017b9f9 100644 --- a/tools/gdb-printers.py +++ b/tools/gdb-printers.py @@ -11,9 +11,9 @@ class VariantPrinter: return type.name + " [" + str(value) + "]" def match_printer(val): - type = val.type.strip_typedefs() - if type.name and type.name.startswith('Luau::Variant<'): - return VariantPrinter(val) - return None + type = val.type.strip_typedefs() + if type.name and type.name.startswith('Luau::Variant<'): + return VariantPrinter(val) + return None gdb.pretty_printers.append(match_printer)