Sync to upstream/release/503 (#135)

- A series of major optimizations to type checking performance on complex
programs/types (up to two orders of magnitude speedup for programs
involving huge tagged unions)
- Fix a few issues encountered by UBSAN (and maybe fix s390x builds)
- Fix gcc-11 test builds
- Fix a rare corner case where luau_load wouldn't wake inactive threads
which could result in a use-after-free due to GC
- Fix CLI crash when error object that's not a string escapes to top level
- Fix Makefile suffixes on macOS

Co-authored-by: Rodactor <rodactor@roblox.com>
This commit is contained in:
Arseny Kapoulkine 2021-11-05 08:47:21 -07:00 committed by GitHub
parent c0b95b8961
commit 279855df91
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 2201 additions and 623 deletions

View File

@ -34,7 +34,6 @@ TypeId makeFunction( // Polymorphic
std::initializer_list<TypeId> paramTypes, std::initializer_list<std::string> paramNames, std::initializer_list<TypeId> retTypes);
void attachMagicFunction(TypeId ty, MagicFunction fn);
void attachFunctionTag(TypeId ty, std::string constraint);
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol = std::nullopt);
void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName);

View File

@ -0,0 +1,14 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/TypeVar.h"
namespace Luau
{
struct Module;
using ModulePtr = std::shared_ptr<Module>;
void quantify(ModulePtr module, TypeId ty, TypeLevel level);
} // namespace Luau

View File

@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {});
void dump(TypeId ty);
void dump(TypePackId ty);
std::string generateName(size_t n);
} // namespace Luau

View File

@ -12,6 +12,7 @@ struct AstArray;
class AstStat;
bool containsFunctionCall(const AstStat& stat);
bool containsFunctionCallOrReturn(const AstStat& stat);
bool isFunction(const AstStat& stat);
void toposort(std::vector<AstStat*>& stats);

View File

@ -3,19 +3,37 @@
#include "Luau/TypeVar.h"
LUAU_FASTFLAG(LuauShareTxnSeen);
namespace Luau
{
// Log of where what TypeIds we are rebinding and what they used to be
struct TxnLog
{
TxnLog() = default;
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& seen)
: seen(seen)
TxnLog()
: originalSeenSize(0)
, ownedSeen()
, sharedSeen(&ownedSeen)
{
}
explicit TxnLog(std::vector<std::pair<TypeId, TypeId>>* sharedSeen)
: originalSeenSize(sharedSeen->size())
, ownedSeen()
, sharedSeen(sharedSeen)
{
}
explicit TxnLog(const std::vector<std::pair<TypeId, TypeId>>& ownedSeen)
: originalSeenSize(ownedSeen.size())
, ownedSeen(ownedSeen)
, sharedSeen(nullptr)
{
// This is deprecated!
LUAU_ASSERT(!FFlag::LuauShareTxnSeen);
}
TxnLog(const TxnLog&) = delete;
TxnLog& operator=(const TxnLog&) = delete;
@ -38,9 +56,11 @@ private:
std::vector<std::pair<TypeId, TypeVar>> typeVarChanges;
std::vector<std::pair<TypePackId, TypePackVar>> typePackChanges;
std::vector<std::pair<TableTypeVar*, std::optional<TypeId>>> tableChanges;
size_t originalSeenSize;
public:
std::vector<std::pair<TypeId, TypeId>> seen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>> ownedSeen; // used to avoid infinite recursion when types are cyclic
std::vector<std::pair<TypeId, TypeId>>* sharedSeen; // shared with all the descendent logs
};
} // namespace Luau

View File

@ -11,6 +11,7 @@
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
#include "Luau/Unifier.h"
#include "Luau/UnifierSharedState.h"
#include <memory>
#include <unordered_map>
@ -121,7 +122,7 @@ struct TypeChecker
void check(const ScopePtr& scope, const AstStatForIn& forin);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function);
void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false);
void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass);
void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction);
@ -336,7 +337,7 @@ private:
// Note: `scope` must be a fresh scope.
std::pair<std::vector<TypeId>, std::vector<TypePackId>> createGenericTypes(
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames);
public:
ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense);
@ -383,6 +384,8 @@ public:
std::function<void(const ModuleName&, const ScopePtr&)> prepareModuleScope;
InternalErrorReporter* iceHandler;
UnifierSharedState unifierState;
public:
const TypeId nilType;
const TypeId numberType;

View File

@ -540,4 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv);
using TypeIdPredicate = std::function<std::optional<TypeId>(TypeId)>;
std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate);
void attachTag(TypeId ty, const std::string& tagName);
void attachTag(Property& prop, const std::string& tagName);
bool hasTag(TypeId ty, const std::string& tagName);
bool hasTag(const Property& prop, const std::string& tagName);
bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work.
} // namespace Luau

View File

@ -6,6 +6,7 @@
#include "Luau/TxnLog.h"
#include "Luau/TypeInfer.h"
#include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header.
#include "Luau/UnifierSharedState.h"
#include <unordered_set>
@ -41,11 +42,14 @@ struct Unifier
std::shared_ptr<UnifierCounters> counters_DEPRECATED;
InternalErrorReporter* iceHandler;
UnifierSharedState& sharedState;
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED = nullptr,
UnifierCounters* counters = nullptr);
// Test whether the two type vars unify. Never commits the result.
@ -69,7 +73,8 @@ private:
void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed);
void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed);
void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer);
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId,TypeId> seen = {});
TypeId deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> seen = {});
void cacheResult(TypeId superTy, TypeId subTy);
public:
void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false);
@ -101,8 +106,9 @@ private:
[[noreturn]] void ice(const std::string& message, const Location& location);
[[noreturn]] void ice(const std::string& message);
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
// Remove with FFlagLuauCacheUnifyTableResults
DenseHashSet<TypeId> tempSeenTy_DEPRECATED{nullptr};
DenseHashSet<TypePackId> tempSeenTp_DEPRECATED{nullptr};
};
} // namespace Luau

View File

@ -0,0 +1,44 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
#include <utility>
namespace Luau
{
struct InternalErrorReporter;
struct TypeIdPairHash
{
size_t hashOne(Luau::TypeId key) const
{
return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9);
}
size_t operator()(const std::pair<Luau::TypeId, Luau::TypeId>& x) const
{
return hashOne(x.first) ^ (hashOne(x.second) << 1);
}
};
struct UnifierSharedState
{
UnifierSharedState(InternalErrorReporter* iceHandler)
: iceHandler(iceHandler)
{
}
InternalErrorReporter* iceHandler;
DenseHashSet<void*> seenAny{nullptr};
DenseHashMap<TypeId, bool> skipCacheForType{nullptr};
DenseHashSet<std::pair<TypeId, TypeId>, TypeIdPairHash> cachedUnify{{nullptr, nullptr}};
DenseHashSet<TypeId> tempSeenTy{nullptr};
DenseHashSet<TypePackId> tempSeenTp{nullptr};
};
} // namespace Luau

View File

@ -1,9 +1,12 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/DenseHash.h"
#include "Luau/TypeVar.h"
#include "Luau/TypePack.h"
LUAU_FASTFLAG(LuauCacheUnifyTableResults)
namespace Luau
{
@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set<void*>& seen, const void* tv)
return !seen.insert(ttv).second;
}
inline bool hasSeen(DenseHashSet<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
if (seen.contains(ttv))
return true;
seen.insert(ttv);
return false;
}
inline void unsee(std::unordered_set<void*>& seen, const void* tv)
{
void* ttv = const_cast<void*>(tv);
seen.erase(ttv);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen);
inline void unsee(DenseHashSet<void*>& seen, const void* tv)
{
// When DenseHashSet is used for 'visitOnce', where don't forget visited elements
}
template<typename F>
void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen);
template<typename F, typename Set>
void visit(TypeId ty, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, ty))
{
@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
else if (auto ttv = get<TableTypeVar>(ty))
{
// Some visitors want to see bound tables, that's why we visit the original type
if (apply(ty, *ttv, seen, f))
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
visit(*ttv->boundTo, f, seen);
}
else
{
for (auto& [_name, prop] : ttv->props)
visit(prop.type, f, seen);
if (ttv->indexer)
{
visit(ttv->indexer->indexType, f, seen);
visit(ttv->indexer->indexResultType, f, seen);
}
}
}
}
@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, ty);
}
template<typename F>
void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
template<typename F, typename Set>
void visit(TypePackId tp, F& f, Set& seen)
{
if (visit_detail::hasSeen(seen, tp))
{
@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set<void*>& seen)
visit_detail::unsee(seen, tp);
}
} // namespace visit_detail
template<typename TID, typename F>
@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f)
visit_detail::visit(ty, f, seen);
}
template<typename TID, typename F>
void visitTypeVarOnce(TID ty, F& f, DenseHashSet<void*>& seen)
{
seen.clear();
visit_detail::visit(ty, f, seen);
}
} // namespace Luau

View File

@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) {
InternalErrorReporter iceReporter;
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter);
UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState);
unifier.tryUnify(expectedType, actualType);

View File

@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn)
LUAU_ASSERT(!"Got a non functional type");
}
void attachFunctionTag(TypeId ty, std::string tag)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(std::move(tag));
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
Property makeProperty(TypeId ty, std::optional<std::string> documentationSymbol)
{
return {

View File

@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false)
LUAU_FASTFLAG(LuauTraceRequireLookupChild)
LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false)
LUAU_FASTFLAG(LuauNewRequireTrace)
LUAU_FASTFLAGVARIABLE(LuauClearScopes, false)
namespace Luau
{
@ -248,7 +249,7 @@ struct RequireCycle
// Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V)
// However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true)
std::vector<RequireCycle> getRequireCycles(
const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
const FileResolver* resolver, const std::unordered_map<ModuleName, SourceNode>& sourceNodes, const SourceNode* start, bool stopAtFirst = false)
{
std::vector<RequireCycle> result;
@ -282,9 +283,9 @@ std::vector<RequireCycle> getRequireCycles(
if (top == start)
{
for (const SourceNode* node : path)
cycle.push_back(node->name);
cycle.push_back(resolver->getHumanReadableModuleName(node->name));
cycle.push_back(top->name);
cycle.push_back(resolver->getHumanReadableModuleName(top->name));
break;
}
}
@ -404,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name)
// however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term
// all correct programs must be acyclic so this code triggers rarely
if (cycleDetected)
requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck);
requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck);
// This is used by the type checker to replace the resulting type of cyclic modules with any
sourceModule.cyclic = !requireCycles.empty();
@ -458,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name)
module->astTypes.clear();
module->astExpectedTypes.clear();
module->astOriginalCallTypes.clear();
if (FFlag::LuauClearScopes)
module->scopes.resize(1);
}
if (mode != Mode::NoCheck)

View File

@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false)
LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel)
LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false)
namespace Luau
{
@ -299,6 +300,14 @@ void TypeCloner::operator()(const FunctionTypeVar& t)
void TypeCloner::operator()(const TableTypeVar& t)
{
// If table is now bound to another one, we ignore the content of the original
if (FFlag::LuauCloneBoundTables && t.boundTo)
{
TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
seenTypes[typeId] = boundTo;
return;
}
TypeId result = dest.addType(TableTypeVar{});
TableTypeVar* ttv = getMutable<TableTypeVar>(result);
LUAU_ASSERT(ttv != nullptr);
@ -321,8 +330,11 @@ void TypeCloner::operator()(const TableTypeVar& t)
ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType),
clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)};
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
if (!FFlag::LuauCloneBoundTables)
{
if (t.boundTo)
ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType);
}
for (TypeId& arg : ttv->instantiatedTypeParams)
arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType);
@ -335,7 +347,7 @@ void TypeCloner::operator()(const TableTypeVar& t)
if (ttv->state == TableState::Free)
{
if (!t.boundTo)
if (FFlag::LuauCloneBoundTables || !t.boundTo)
{
if (encounteredFreeType)
*encounteredFreeType = true;

90
Analysis/src/Quantify.cpp Normal file
View File

@ -0,0 +1,90 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Quantify.h"
#include "Luau/VisitTypeVar.h"
namespace Luau
{
struct Quantifier
{
ModulePtr module;
TypeLevel level;
std::vector<TypeId> generics;
std::vector<TypePackId> genericPacks;
Quantifier(ModulePtr module, TypeLevel level)
: module(module)
, level(level)
{
}
void cycle(TypeId) {}
void cycle(TypePackId) {}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
if (!level.subsumes(ftv.level))
return false;
*asMutable(ty) = GenericTypeVar{level};
generics.push_back(ty);
return false;
}
template<typename T>
bool operator()(TypeId ty, const T& t)
{
return true;
}
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
bool operator()(TypeId ty, const TableTypeVar&)
{
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
if (!level.subsumes(ttv.level))
return false;
if (ttv.state == TableState::Free)
ttv.state = TableState::Generic;
else if (ttv.state == TableState::Unsealed)
ttv.state = TableState::Sealed;
ttv.level = level;
return true;
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
if (!level.subsumes(ftp.level))
return false;
*asMutable(tp) = GenericTypePack{level};
genericPacks.push_back(tp);
return true;
}
};
void quantify(ModulePtr module, TypeId ty, TypeLevel level)
{
Quantifier q{std::move(module), level};
visitTypeVar(ty, q);
FunctionTypeVar* ftv = getMutable<FunctionTypeVar>(ty);
LUAU_ASSERT(ftv);
ftv->generics = q.generics;
ftv->genericPacks = q.genericPacks;
}
} // namespace Luau

View File

@ -182,7 +182,7 @@ struct RequireTracerOld : AstVisitor
struct RequireTracer : AstVisitor
{
RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName)
RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName)
: result(result)
, fileResolver(fileResolver)
, currentModuleName(currentModuleName)
@ -260,7 +260,7 @@ struct RequireTracer : AstVisitor
// seed worklist with require arguments
work.reserve(requires.size());
for (AstExprCall* require: requires)
for (AstExprCall* require : requires)
work.push_back(require->args.data[0]);
// push all dependent expressions to the work stack; note that the vector is modified during traversal

View File

@ -10,7 +10,6 @@
#include <algorithm>
#include <stdexcept>
LUAU_FASTFLAG(LuauExtraNilRecovery)
LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions)
LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false)
LUAU_FASTFLAG(LuauTypeAliasPacks)
@ -159,15 +158,6 @@ struct StringifierState
seen.erase(iter);
}
static std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
std::string getName(TypeId ty)
{
const size_t s = result.nameMap.typeVars.size();
@ -584,8 +574,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : &uv)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
if (isNil(el))
{
@ -649,8 +638,7 @@ struct TypeVarStringifier
std::vector<std::string> results = {};
for (auto el : uv.parts)
{
if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow)
el = follow(el);
el = follow(el);
std::string saved = std::move(state.result.name);
@ -1204,4 +1192,13 @@ void dump(TypePackId ty)
printf("%s\n", toString(ty, opts).c_str());
}
std::string generateName(size_t i)
{
std::string n;
n = char('a' + i % 26);
if (i >= 26)
n += std::to_string(i / 26);
return n;
}
} // namespace Luau

View File

@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor
struct ContainsFunctionCall : public AstVisitor
{
bool alsoReturn = false;
bool result = false;
ContainsFunctionCall() = default;
explicit ContainsFunctionCall(bool alsoReturn)
: alsoReturn(alsoReturn)
{
}
bool visit(AstExpr*) override
{
return !result; // short circuit if result is true
@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor
return false;
}
bool visit(AstStatReturn* stat) override
{
if (alsoReturn)
{
result = true;
return false;
}
else
return AstVisitor::visit(stat);
}
bool visit(AstExprFunction*) override
{
return false;
@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat)
return cfc.result;
}
bool containsFunctionCallOrReturn(const AstStat& stat)
{
detail::ContainsFunctionCall cfc{true};
const_cast<AstStat&>(stat).visit(&cfc);
return cfc.result;
}
bool isFunction(const AstStat& stat)
{
return stat.is<AstStatFunction>() || stat.is<AstStatLocalFunction>();

View File

@ -5,6 +5,8 @@
#include <algorithm>
LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false)
namespace Luau
{
@ -33,6 +35,12 @@ void TxnLog::rollback()
for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it)
std::swap(it->first->boundTo, it->second);
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(originalSeenSize <= sharedSeen->size());
sharedSeen->resize(originalSeenSize);
}
}
void TxnLog::concat(TxnLog rhs)
@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs)
tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end());
rhs.tableChanges.clear();
seen.swap(rhs.seen);
rhs.seen.clear();
if (!FFlag::LuauShareTxnSeen)
{
ownedSeen.swap(rhs.ownedSeen);
rhs.ownedSeen.clear();
}
}
bool TxnLog::haveSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair));
if (FFlag::LuauShareTxnSeen)
return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair));
else
return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair));
}
void TxnLog::pushSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
seen.push_back(sortedPair);
if (FFlag::LuauShareTxnSeen)
sharedSeen->push_back(sortedPair);
else
ownedSeen.push_back(sortedPair);
}
void TxnLog::popSeen(TypeId lhs, TypeId rhs)
{
const std::pair<TypeId, TypeId> sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs);
LUAU_ASSERT(sortedPair == seen.back());
seen.pop_back();
if (FFlag::LuauShareTxnSeen)
{
LUAU_ASSERT(sortedPair == sharedSeen->back());
sharedSeen->pop_back();
}
else
{
LUAU_ASSERT(sortedPair == ownedSeen.back());
ownedSeen.pop_back();
}
}
} // namespace Luau

View File

@ -6,6 +6,7 @@
#include "Luau/Parser.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypePack.h"
#include "Luau/TypeVar.h"
@ -33,14 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data
return result;
}
using SyntheticNames = std::unordered_map<const void*, char*>;
namespace Luau
{
static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen)
{
size_t s = syntheticNames->size();
char*& n = (*syntheticNames)[&gen];
if (!n)
{
std::string str = gen.explicitName ? gen.name : generateName(s);
n = static_cast<char*>(allocator->allocate(str.size() + 1));
strcpy(n, str.c_str());
}
return n;
}
class TypeRehydrationVisitor
{
mutable std::map<void*, int> seen;
mutable int count = 0;
std::map<void*, int> seen;
int count = 0;
bool hasSeen(const void* tv) const
bool hasSeen(const void* tv)
{
void* ttv = const_cast<void*>(tv);
auto it = seen.find(ttv);
@ -52,15 +70,16 @@ class TypeRehydrationVisitor
}
public:
TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions())
TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions())
: allocator(alloc)
, syntheticNames(syntheticNames)
, options(options)
{
}
AstTypePack* rehydrate(TypePackId tp) const;
AstTypePack* rehydrate(TypePackId tp);
AstType* operator()(const PrimitiveTypeVar& ptv) const
AstType* operator()(const PrimitiveTypeVar& ptv)
{
switch (ptv.type)
{
@ -78,11 +97,11 @@ public:
return nullptr;
}
}
AstType* operator()(const AnyTypeVar&) const
AstType* operator()(const AnyTypeVar&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("any"));
}
AstType* operator()(const TableTypeVar& ttv) const
AstType* operator()(const TableTypeVar& ttv)
{
RecursionCounter counter(&count);
@ -144,12 +163,12 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props, indexer);
}
AstType* operator()(const MetatableTypeVar& mtv) const
AstType* operator()(const MetatableTypeVar& mtv)
{
return Luau::visit(*this, mtv.table->ty);
}
AstType* operator()(const ClassTypeVar& ctv) const
AstType* operator()(const ClassTypeVar& ctv)
{
RecursionCounter counter(&count);
@ -176,7 +195,7 @@ public:
return allocator->alloc<AstTypeTable>(Location(), props);
}
AstType* operator()(const FunctionTypeVar& ftv) const
AstType* operator()(const FunctionTypeVar& ftv)
{
RecursionCounter counter(&count);
@ -253,10 +272,12 @@ public:
size_t i = 0;
for (const auto& el : ftv.argNames)
{
std::optional<AstArgumentName>* arg = &argNames.data[i++];
if (el)
argNames.data[i++] = {AstName(el->name.c_str()), el->location};
new (arg) std::optional<AstArgumentName>(AstArgumentName(AstName(el->name.c_str()), el->location));
else
argNames.data[i++] = {};
new (arg) std::optional<AstArgumentName>();
}
AstArray<AstType*> returnTypes;
@ -290,23 +311,23 @@ public:
return allocator->alloc<AstTypeFunction>(
Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation});
}
AstType* operator()(const Unifiable::Error&) const
AstType* operator()(const Unifiable::Error&)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("Unifiable<Error>"));
}
AstType* operator()(const GenericTypeVar& gtv) const
AstType* operator()(const GenericTypeVar& gtv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(gtv.name.c_str()));
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv)));
}
AstType* operator()(const Unifiable::Bound<TypeId>& bound) const
AstType* operator()(const Unifiable::Bound<TypeId>& bound)
{
return Luau::visit(*this, bound.boundTo->ty);
}
AstType* operator()(Unifiable::Free ftv) const
AstType* operator()(const FreeTypeVar& ftv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("free"));
}
AstType* operator()(const UnionTypeVar& uv) const
AstType* operator()(const UnionTypeVar& uv)
{
AstArray<AstType*> unionTypes;
unionTypes.size = uv.options.size();
@ -317,7 +338,7 @@ public:
}
return allocator->alloc<AstTypeUnion>(Location(), unionTypes);
}
AstType* operator()(const IntersectionTypeVar& uv) const
AstType* operator()(const IntersectionTypeVar& uv)
{
AstArray<AstType*> intersectionTypes;
intersectionTypes.size = uv.parts.size();
@ -328,23 +349,28 @@ public:
}
return allocator->alloc<AstTypeIntersection>(Location(), intersectionTypes);
}
AstType* operator()(const LazyTypeVar& ltv) const
AstType* operator()(const LazyTypeVar& ltv)
{
return allocator->alloc<AstTypeReference>(Location(), std::nullopt, AstName("<Lazy?>"));
}
private:
Allocator* allocator;
SyntheticNames* syntheticNames;
const TypeRehydrationOptions& options;
};
class TypePackRehydrationVisitor
{
public:
TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor)
TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor)
: allocator(allocator)
, syntheticNames(syntheticNames)
, typeVisitor(typeVisitor)
{
LUAU_ASSERT(allocator);
LUAU_ASSERT(syntheticNames);
LUAU_ASSERT(typeVisitor);
}
AstTypePack* operator()(const BoundTypePack& btp) const
@ -359,7 +385,7 @@ public:
head.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * tp.head.size()));
for (size_t i = 0; i < tp.head.size(); i++)
head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty);
head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty);
AstTypePack* tail = nullptr;
@ -371,12 +397,12 @@ public:
AstTypePack* operator()(const VariadicTypePack& vtp) const
{
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(typeVisitor, vtp.ty->ty));
return allocator->alloc<AstTypePackVariadic>(Location(), Luau::visit(*typeVisitor, vtp.ty->ty));
}
AstTypePack* operator()(const GenericTypePack& gtp) const
{
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(gtp.name.c_str()));
return allocator->alloc<AstTypePackGeneric>(Location(), AstName(getName(allocator, syntheticNames, gtp)));
}
AstTypePack* operator()(const FreeTypePack& gtp) const
@ -391,12 +417,13 @@ public:
private:
Allocator* allocator;
const TypeRehydrationVisitor& typeVisitor;
SyntheticNames* syntheticNames;
TypeRehydrationVisitor* typeVisitor;
};
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const
AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp)
{
TypePackRehydrationVisitor tprv(allocator, *this);
TypePackRehydrationVisitor tprv(allocator, syntheticNames, this);
return Luau::visit(tprv, tp->ty);
}
@ -431,7 +458,7 @@ public:
{
if (!type)
return nullptr;
return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty);
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty);
}
AstArray<Luau::AstType*> typeAstPack(TypePackId type)
@ -443,7 +470,7 @@ public:
result.data = static_cast<AstType**>(allocator->allocate(sizeof(AstType*) * v.size()));
for (size_t i = 0; i < v.size(); ++i)
{
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty);
result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty);
}
return result;
}
@ -495,7 +522,7 @@ public:
{
if (FFlag::LuauTypeAliasPacks)
{
variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail);
variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail);
}
else
{
@ -515,6 +542,7 @@ public:
private:
Module& module;
Allocator* allocator;
SyntheticNames syntheticNames;
};
void attachTypeData(SourceModule& source, Module& result)
@ -525,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result)
AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options)
{
return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty);
SyntheticNames syntheticNames;
return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty);
}
} // namespace Luau

View File

@ -4,6 +4,7 @@
#include "Luau/Common.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Parser.h"
#include "Luau/Quantify.h"
#include "Luau/RecursionCounter.h"
#include "Luau/Scope.h"
#include "Luau/Substitution.h"
@ -33,18 +34,16 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false)
LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false)
LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false)
LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false)
LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false)
LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false)
LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false)
LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false)
LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false)
LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false)
LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false)
LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false)
LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false)
LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false)
LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false)
LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes)
LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false)
LUAU_FASTFLAG(LuauNewRequireTrace)
LUAU_FASTFLAG(LuauTypeAliasPacks)
@ -215,6 +214,7 @@ static bool isMetamethod(const Name& name)
TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler)
: resolver(resolver)
, iceHandler(iceHandler)
, unifierState(iceHandler)
, nilType(singletonTypes.nilType)
, numberType(singletonTypes.numberType)
, stringType(singletonTypes.stringType)
@ -370,13 +370,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
return;
}
int subLevel = 0;
std::vector<AstStat*> sorted(block.body.data, block.body.data + block.body.size);
toposort(sorted);
for (const auto& stat : sorted)
{
if (const auto& typealias = stat->as<AstStatTypeAlias>())
check(scope, *typealias, true);
{
check(scope, *typealias, subLevel, true);
++subLevel;
}
}
auto protoIter = sorted.begin();
@ -399,8 +404,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
}
};
int subLevel = 0;
while (protoIter != sorted.end())
{
// protoIter walks forward
@ -433,7 +436,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block)
// function f<a>(x:a):a local x: number = g(37) return x end
// function g(x:number):number return f(x) end
// ```
if (containsFunctionCall(**protoIter))
if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter))
{
while (checkIter != protoIter)
{
@ -1161,7 +1164,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco
scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location};
}
void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare)
void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare)
{
// This function should be called at most twice for each type alias.
// Once with forwardDeclare, and once without.
@ -1189,11 +1192,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
}
else
{
ScopePtr aliasScope = childScope(scope, typealias.location);
ScopePtr aliasScope =
FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location);
if (FFlag::LuauTypeAliasPacks)
{
auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks);
auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks);
TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true));
FreeTypeVar* ftv = getMutable<FreeTypeVar>(ty);
@ -1418,7 +1422,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo
{
ScopePtr funScope = childFunctionScope(scope, global.location);
auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks);
auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks);
TypePackId argPack = resolveTypePack(funScope, global.params);
TypePackId retPack = resolveTypePack(funScope, global.retTypes);
@ -1610,25 +1614,11 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn
if (std::optional<TypeId> ty = resolveLValue(scope, *lvalue))
return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}};
if (FFlag::LuauExtraNilRecovery)
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
lhsType = stripFromNilAndReport(lhsType, expr.expr->location);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true))
return {*ty};
if (!FFlag::LuauMissingUnionPropertyError)
reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value});
if (!FFlag::LuauExtraNilRecovery)
{
// Try to recover using a union without 'nil' options
if (std::optional<TypeId> strippedUnion = tryStripUnionFromNil(lhsType))
{
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false))
return {*ty};
}
}
return {errorType};
}
@ -1694,61 +1684,37 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(type))
{
if (FFlag::LuauMissingUnionPropertyError)
std::vector<TypeId> goodOptions;
std::vector<TypeId> badOptions;
for (TypeId t : utv)
{
std::vector<TypeId> goodOptions;
std::vector<TypeId> badOptions;
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
for (TypeId t : utv)
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
goodOptions.push_back(*ty);
else
badOptions.push_back(t);
}
if (!badOptions.empty())
{
if (addErrors)
{
if (goodOptions.empty())
reportError(location, UnknownProperty{type, name});
else
reportError(location, MissingUnionProperty{type, badOptions, name});
}
return std::nullopt;
}
std::vector<TypeId> result = reduceUnion(goodOptions);
if (result.size() == 1)
return result[0];
return addType(UnionTypeVar{std::move(result)});
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
goodOptions.push_back(*ty);
else
badOptions.push_back(t);
}
else
if (!badOptions.empty())
{
std::vector<TypeId> options;
for (TypeId t : utv->options)
if (addErrors)
{
RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, false))
options.push_back(*ty);
if (goodOptions.empty())
reportError(location, UnknownProperty{type, name});
else
return std::nullopt;
reportError(location, MissingUnionProperty{type, badOptions, name});
}
std::vector<TypeId> result = reduceUnion(options);
if (result.size() == 1)
return result[0];
return addType(UnionTypeVar{std::move(result)});
return std::nullopt;
}
std::vector<TypeId> result = reduceUnion(goodOptions);
if (result.size() == 1)
return result[0];
return addType(UnionTypeVar{std::move(result)});
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{
@ -1765,7 +1731,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
// If no parts of the intersection had the property we looked up for, it never existed at all.
if (parts.empty())
{
if (FFlag::LuauMissingUnionPropertyError && addErrors)
if (addErrors)
reportError(location, UnknownProperty{type, name});
return std::nullopt;
}
@ -1779,7 +1745,7 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromType(
return addType(IntersectionTypeVar{result});
}
if (FFlag::LuauMissingUnionPropertyError && addErrors)
if (addErrors)
reportError(location, UnknownProperty{type, name});
return std::nullopt;
@ -2062,8 +2028,7 @@ ExprResult<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn
case AstExprUnary::Len:
tablify(operandType);
if (FFlag::LuauExtraNilRecovery)
operandType = stripFromNilAndReport(operandType, expr.location);
operandType = stripFromNilAndReport(operandType, expr.location);
if (get<ErrorTypeVar>(operandType))
return {errorType};
@ -2635,8 +2600,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
Name name = expr.index.value;
if (FFlag::LuauExtraNilRecovery)
lhs = stripFromNilAndReport(lhs, expr.expr->location);
lhs = stripFromNilAndReport(lhs, expr.expr->location);
if (TableTypeVar* lhsTable = getMutableTableType(lhs))
{
@ -2710,8 +2674,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
TypeId exprType = checkExpr(scope, *expr.expr).type;
tablify(exprType);
if (FFlag::LuauExtraNilRecovery)
exprType = stripFromNilAndReport(exprType, expr.expr->location);
exprType = stripFromNilAndReport(exprType, expr.expr->location);
TypeId indexType = checkExpr(scope, *expr.index).type;
@ -2738,10 +2701,7 @@ std::pair<TypeId, TypeId*> TypeChecker::checkLValueBinding(const ScopePtr& scope
if (!exprTable)
{
if (FFlag::LuauExtraNilRecovery)
reportError(TypeError{expr.expr->location, NotATable{exprType}});
else
reportError(TypeError{expr.location, NotATable{exprType}});
reportError(TypeError{expr.expr->location, NotATable{exprType}});
return std::pair(errorType, nullptr);
}
@ -2910,7 +2870,7 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
if (FFlag::LuauGenericFunctions)
{
std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks);
std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks);
}
TypePackId retPack;
@ -3016,9 +2976,6 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
if (expectedArgsCurr != expectedArgsEnd)
{
argType = *expectedArgsCurr;
if (!FFlag::LuauInferFunctionArgsFix)
++expectedArgsCurr;
}
else if (auto expectedArgsTail = expectedArgsCurr.tail())
{
@ -3034,7 +2991,7 @@ std::pair<TypeId, ScopePtr> TypeChecker::checkFunctionSignature(
funScope->bindings[local] = {argType, local->location};
argTypes.push_back(argType);
if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd)
if (expectedArgsCurr != expectedArgsEnd)
++expectedArgsCurr;
}
@ -3402,8 +3359,7 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
if (!FFlag::LuauRankNTypes)
instantiate(scope, selfType, expr.func->location);
if (FFlag::LuauExtraNilRecovery)
selfType = stripFromNilAndReport(selfType, expr.func->location);
selfType = stripFromNilAndReport(selfType, expr.func->location);
if (std::optional<TypeId> propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true))
{
@ -3412,34 +3368,8 @@ ExprResult<TypePackId> TypeChecker::checkExprPack(const ScopePtr& scope, const A
}
else
{
if (!FFlag::LuauMissingUnionPropertyError)
reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value});
if (!FFlag::LuauExtraNilRecovery)
{
// Try to recover using a union without 'nil' options
if (std::optional<TypeId> strippedUnion = tryStripUnionFromNil(selfType))
{
if (std::optional<TypeId> propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false))
{
selfType = *strippedUnion;
functionType = *propTy;
actualFunctionType = instantiate(scope, functionType, expr.func->location);
}
}
if (!actualFunctionType)
{
functionType = errorType;
actualFunctionType = errorType;
}
}
else
{
functionType = errorType;
actualFunctionType = errorType;
}
functionType = errorType;
actualFunctionType = errorType;
}
}
else
@ -3555,8 +3485,7 @@ std::optional<ExprResult<TypePackId>> TypeChecker::checkCallOverload(const Scope
TypePackId argPack, TypePack* args, const std::vector<Location>& argLocations, const ExprResult<TypePackId>& argListResult,
std::vector<TypeId>& overloadsThatMatchArgCount, std::vector<OverloadErrorEntry>& errors)
{
if (FFlag::LuauExtraNilRecovery)
fn = stripFromNilAndReport(fn, expr.func->location);
fn = stripFromNilAndReport(fn, expr.func->location);
if (get<AnyTypeVar>(fn))
{
@ -4283,6 +4212,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location
if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty())
return ty;
if (FFlag::LuauQuantifyInPlace2)
{
Luau::quantify(currentModule, ty, scope->level);
return ty;
}
quantification.level = scope->level;
quantification.generics.clear();
quantification.genericPacks.clear();
@ -4491,12 +4426,12 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r)
Unifier TypeChecker::mkUnifier(const Location& location)
{
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler};
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState};
}
Unifier TypeChecker::mkUnifier(const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location)
{
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler};
return Unifier{&currentModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState};
}
TypeId TypeChecker::freshType(const ScopePtr& scope)
@ -4753,7 +4688,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation
if (FFlag::LuauGenericFunctions)
{
std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks);
std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks);
}
// TODO: better error message CLI-39912
@ -5041,10 +4976,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf,
}
std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGenericTypes(
const ScopePtr& scope, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
const ScopePtr& scope, std::optional<TypeLevel> levelOpt, const AstNode& node, const AstArray<AstName>& genericNames, const AstArray<AstName>& genericPackNames)
{
LUAU_ASSERT(scope->parent);
const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level;
std::vector<TypeId> generics;
for (const AstName& generic : genericNames)
{
@ -5063,12 +5000,12 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
{
TypeId& cached = scope->parent->typeAliasTypeParameters[n];
if (!cached)
cached = addType(GenericTypeVar{scope->level, n});
cached = addType(GenericTypeVar{level, n});
g = cached;
}
else
{
g = addType(Unifiable::Generic{scope->level, n});
g = addType(Unifiable::Generic{level, n});
}
generics.push_back(g);
@ -5093,12 +5030,12 @@ std::pair<std::vector<TypeId>, std::vector<TypePackId>> TypeChecker::createGener
{
TypePackId& cached = scope->parent->typeAliasTypePackParameters[n];
if (!cached)
cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}});
g = cached;
}
else
{
g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}});
g = addTypePack(TypePackVar{Unifiable::Generic{level, n}});
}
genericPacks.push_back(g);

View File

@ -22,6 +22,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0)
LUAU_FASTFLAG(LuauRankNTypes)
LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses)
LUAU_FASTFLAG(LuauTypeAliasPacks)
LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false)
namespace Luau
{
@ -217,8 +218,7 @@ std::optional<TypeId> getMetatable(TypeId type)
return mtType->metatable;
else if (const ClassTypeVar* classType = get<ClassTypeVar>(type))
return classType->metatable;
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type);
primitiveType && primitiveType->metatable)
else if (const PrimitiveTypeVar* primitiveType = get<PrimitiveTypeVar>(type); primitiveType && primitiveType->metatable)
{
LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String);
return primitiveType->metatable;
@ -1490,4 +1490,86 @@ std::vector<TypeId> filterMap(TypeId type, TypeIdPredicate predicate)
return {};
}
static Tags* getTags(TypeId ty)
{
ty = follow(ty);
if (auto ftv = getMutable<FunctionTypeVar>(ty))
return &ftv->tags;
else if (auto ttv = getMutable<TableTypeVar>(ty))
return &ttv->tags;
else if (auto ctv = getMutable<ClassTypeVar>(ty))
return &ctv->tags;
return nullptr;
}
void attachTag(TypeId ty, const std::string& tagName)
{
if (!FFlag::LuauRefactorTagging)
{
if (auto ftv = getMutable<FunctionTypeVar>(ty))
{
ftv->tags.emplace_back(tagName);
}
else
{
LUAU_ASSERT(!"Got a non functional type");
}
}
else
{
if (auto tags = getTags(ty))
tags->push_back(tagName);
else
LUAU_ASSERT(!"This TypeId does not support tags");
}
}
void attachTag(Property& prop, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
prop.tags.push_back(tagName);
}
// We would ideally not expose this because it could cause a footgun.
// If the Base class has a tag and you ask if Derived has that tag, it would return false.
// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it.
bool hasTag(const Tags& tags, const std::string& tagName)
{
LUAU_ASSERT(FFlag::LuauRefactorTagging);
return std::find(tags.begin(), tags.end(), tagName) != tags.end();
}
bool hasTag(TypeId ty, const std::string& tagName)
{
ty = follow(ty);
// We special case classes because getTags only returns a pointer to one vector of tags.
// But classes has multiple vector of tags, represented throughout the hierarchy.
if (auto ctv = get<ClassTypeVar>(ty))
{
while (ctv)
{
if (hasTag(ctv->tags, tagName))
return true;
else if (!ctv->parent)
return false;
ctv = get<ClassTypeVar>(*ctv->parent);
LUAU_ASSERT(ctv);
}
}
else if (auto tags = getTags(ty))
return hasTag(*tags, tagName);
return false;
}
bool hasTag(const Property& prop, const std::string& tagName)
{
return hasTag(prop.tags, tagName);
}
} // namespace Luau

View File

@ -7,6 +7,7 @@
#include "Luau/TypePack.h"
#include "Luau/TypeUtils.h"
#include "Luau/TimeTrace.h"
#include "Luau/VisitTypeVar.h"
#include <algorithm>
@ -22,9 +23,99 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false)
LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false)
LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false)
LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false)
LUAU_FASTFLAG(LuauShareTxnSeen);
LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false)
namespace Luau
{
struct SkipCacheForType
{
SkipCacheForType(const DenseHashMap<TypeId, bool>& skipCacheForType)
: skipCacheForType(skipCacheForType)
{
}
void cycle(TypeId) {}
void cycle(TypePackId) {}
bool operator()(TypeId ty, const FreeTypeVar& ftv)
{
result = true;
return false;
}
bool operator()(TypeId ty, const BoundTypeVar& btv)
{
result = true;
return false;
}
bool operator()(TypeId ty, const GenericTypeVar& btv)
{
result = true;
return false;
}
bool operator()(TypeId ty, const TableTypeVar&)
{
TableTypeVar& ttv = *getMutable<TableTypeVar>(ty);
if (ttv.boundTo)
{
result = true;
return false;
}
if (ttv.state != TableState::Sealed)
{
result = true;
return false;
}
return true;
}
template<typename T>
bool operator()(TypeId ty, const T& t)
{
const bool* prev = skipCacheForType.find(ty);
if (prev && *prev)
{
result = true;
return false;
}
return true;
}
template<typename T>
bool operator()(TypePackId, const T&)
{
return true;
}
bool operator()(TypePackId tp, const FreeTypePack& ftp)
{
result = true;
return false;
}
bool operator()(TypePackId tp, const BoundTypePack& ftp)
{
result = true;
return false;
}
bool operator()(TypePackId tp, const GenericTypePack& ftp)
{
result = true;
return false;
}
const DenseHashMap<TypeId, bool>& skipCacheForType;
bool result = false;
};
static std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
{
@ -39,7 +130,7 @@ static std::optional<TypeError> hasUnificationTooComplex(const ErrorVec& errors)
return *it;
}
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler)
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState)
: types(types)
, mode(mode)
, globalScope(std::move(globalScope))
@ -47,24 +138,39 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati
, variance(variance)
, counters(&countersData)
, counters_DEPRECATED(std::make_shared<UnifierCounters>())
, iceHandler(iceHandler)
, sharedState(sharedState)
{
LUAU_ASSERT(iceHandler);
LUAU_ASSERT(sharedState.iceHandler);
}
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& seen, const Location& location,
Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector<std::pair<TypeId, TypeId>>& ownedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
: types(types)
, mode(mode)
, globalScope(std::move(globalScope))
, log(seen)
, log(ownedSeen)
, location(location)
, variance(variance)
, counters(counters ? counters : &countersData)
, counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared<UnifierCounters>())
, iceHandler(iceHandler)
, sharedState(sharedState)
{
LUAU_ASSERT(iceHandler);
LUAU_ASSERT(sharedState.iceHandler);
}
Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector<std::pair<TypeId, TypeId>>* sharedSeen, const Location& location,
Variance variance, UnifierSharedState& sharedState, const std::shared_ptr<UnifierCounters>& counters_DEPRECATED, UnifierCounters* counters)
: types(types)
, mode(mode)
, globalScope(std::move(globalScope))
, log(sharedSeen)
, location(location)
, variance(variance)
, counters(counters ? counters : &countersData)
, counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared<UnifierCounters>())
, sharedState(sharedState)
{
LUAU_ASSERT(sharedState.iceHandler);
}
void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
@ -74,7 +180,7 @@ void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool i
else
counters_DEPRECATED->iterationCount = 0;
return tryUnify_(superTy, subTy, isFunctionCall, isIntersection);
tryUnify_(superTy, subTy, isFunctionCall, isIntersection);
}
void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection)
@ -206,6 +312,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
if (get<ErrorTypeVar>(subTy) || get<AnyTypeVar>(subTy))
return tryUnifyWithAny(subTy, superTy);
bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection;
auto& cache = sharedState.cachedUnify;
// What if the types are immutable and we proved their relation before
if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy})))
return;
// If we have seen this pair of types before, we are currently recursing into cyclic types.
// Here, we assume that the types unify. If they do not, we will find out as we roll back
// the stack.
@ -257,6 +370,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
if (FFlag::LuauUnionHeuristic)
{
bool found = false;
const std::string* subName = getName(subTy);
if (subName)
{
@ -264,6 +379,21 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
{
const std::string* optionName = getName(uv->options[i]);
if (optionName && *optionName == *subName)
{
found = true;
startIndex = i;
break;
}
}
}
if (!found && cacheEnabled)
{
for (size_t i = 0; i < uv->options.size(); ++i)
{
TypeId type = uv->options[i];
if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type})))
{
startIndex = i;
break;
@ -311,8 +441,25 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
bool found = false;
std::optional<TypeError> unificationTooComplex;
for (TypeId type : uv->parts)
size_t startIndex = 0;
if (cacheEnabled)
{
for (size_t i = 0; i < uv->parts.size(); ++i)
{
TypeId type = uv->parts[i];
if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy})))
{
startIndex = i;
break;
}
}
}
for (size_t i = 0; i < uv->parts.size(); ++i)
{
TypeId type = uv->parts[(i + startIndex) % uv->parts.size()];
Unifier innerState = makeChildUnifier();
innerState.tryUnify_(superTy, type, isFunctionCall);
@ -342,8 +489,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
tryUnifyFunctions(superTy, subTy, isFunctionCall);
else if (get<TableTypeVar>(superTy) && get<TableTypeVar>(subTy))
{
tryUnifyTables(superTy, subTy, isIntersection);
if (cacheEnabled && errors.empty())
cacheResult(superTy, subTy);
}
// tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical.
else if (get<MetatableTypeVar>(superTy))
tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false);
@ -364,6 +516,41 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool
log.popSeen(superTy, subTy);
}
void Unifier::cacheResult(TypeId superTy, TypeId subTy)
{
LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults);
bool* superTyInfo = sharedState.skipCacheForType.find(superTy);
if (superTyInfo && *superTyInfo)
return;
bool* subTyInfo = sharedState.skipCacheForType.find(subTy);
if (subTyInfo && *subTyInfo)
return;
auto skipCacheFor = [this](TypeId ty) {
SkipCacheForType visitor{sharedState.skipCacheForType};
visitTypeVarOnce(ty, visitor, sharedState.seenAny);
sharedState.skipCacheForType[ty] = visitor.result;
return visitor.result;
};
if (!superTyInfo && skipCacheFor(superTy))
return;
if (!subTyInfo && skipCacheFor(subTy))
return;
sharedState.cachedUnify.insert({superTy, subTy});
if (variance == Invariant)
sharedState.cachedUnify.insert({subTy, superTy});
}
struct WeirdIter
{
TypePackId packId;
@ -459,7 +646,7 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall
else
counters_DEPRECATED->iterationCount = 0;
return tryUnify_(superTp, subTp, isFunctionCall);
tryUnify_(superTp, subTp, isFunctionCall);
}
/*
@ -797,6 +984,40 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
std::vector<std::string> missingProperties;
std::vector<std::string> extraProperties;
// Optimization: First test that the property sets are compatible without doing any recursive unification
if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free)
{
for (const auto& [propName, superProp] : lt->props)
{
auto subIter = rt->props.find(propName);
if (subIter == rt->props.end() && !isOptional(superProp.type) && !get<AnyTypeVar>(follow(superProp.type)))
missingProperties.push_back(propName);
}
if (!missingProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}});
return;
}
}
// And vice versa if we're invariant
if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free)
{
for (const auto& [propName, subProp] : rt->props)
{
auto superIter = lt->props.find(propName);
if (superIter == lt->props.end() && !isOptional(subProp.type) && !get<AnyTypeVar>(follow(subProp.type)))
extraProperties.push_back(propName);
}
if (!extraProperties.empty())
{
errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}});
return;
}
}
// Reminder: left is the supertype, right is the subtype.
// Width subtyping: any property in the supertype must be in the subtype,
// and the types must agree.
@ -833,9 +1054,10 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
innerState.log.rollback();
}
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{}
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{
}
else if (rt->state == TableState::Free)
{
log(rt);
@ -878,11 +1100,13 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
lt->props[name] = clone;
}
else if (variance == Covariant)
{}
{
}
else if (isOptional(prop.type) || get<AnyTypeVar>(follow(prop.type)))
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{}
// TODO: this case is unsound, but without it our test suite fails. CLI-46031
// TODO: should isOptional(anyType) be true?
{
}
else if (lt->state == TableState::Free)
{
log(lt);
@ -980,10 +1204,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map<TypeId, TypeId> see
TableTypeVar* resultTtv = getMutable<TableTypeVar>(result);
for (auto& [name, prop] : resultTtv->props)
prop.type = deeplyOptional(prop.type, seen);
return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});;
return types->addType(UnionTypeVar{{singletonTypes.nilType, result}});
}
else
return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }});
return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}});
}
void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection)
@ -1697,10 +1921,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty)
{
std::vector<TypeId> queue = {ty};
tempSeenTy.clear();
tempSeenTp.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
sharedState.tempSeenTy.clear();
sharedState.tempSeenTp.clear();
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP);
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP);
}
else
{
tempSeenTy_DEPRECATED.clear();
tempSeenTp_DEPRECATED.clear();
Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP);
}
}
else
{
@ -1721,12 +1955,24 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty)
{
std::vector<TypeId> queue;
tempSeenTy.clear();
tempSeenTp.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
sharedState.tempSeenTy.clear();
sharedState.tempSeenTp.clear();
queueTypePack(queue, tempSeenTp, *this, ty, any);
queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any);
Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any);
Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any);
}
else
{
tempSeenTy_DEPRECATED.clear();
tempSeenTp_DEPRECATED.clear();
queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any);
Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any);
}
}
else
{
@ -1775,10 +2021,20 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack)
{
std::unordered_set<TypeId> seen_DEPRECATED;
if (FFlag::LuauTypecheckOpts)
tempSeenTy.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
if (FFlag::LuauTypecheckOpts)
sharedState.tempSeenTy.clear();
return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack);
return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack);
}
else
{
if (FFlag::LuauTypecheckOpts)
tempSeenTy_DEPRECATED.clear();
return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack);
}
}
void Unifier::occursCheck(std::unordered_set<TypeId>& seen_DEPRECATED, DenseHashSet<TypeId>& seen, TypeId needle, TypeId haystack)
@ -1851,10 +2107,20 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack)
{
std::unordered_set<TypePackId> seen_DEPRECATED;
if (FFlag::LuauTypecheckOpts)
tempSeenTp.clear();
if (FFlag::LuauCacheUnifyTableResults)
{
if (FFlag::LuauTypecheckOpts)
sharedState.tempSeenTp.clear();
return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack);
return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack);
}
else
{
if (FFlag::LuauTypecheckOpts)
tempSeenTp_DEPRECATED.clear();
return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack);
}
}
void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, DenseHashSet<TypePackId>& seen, TypePackId needle, TypePackId haystack)
@ -1922,7 +2188,10 @@ void Unifier::occursCheck(std::unordered_set<TypePackId>& seen_DEPRECATED, Dense
Unifier Unifier::makeChildUnifier()
{
return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters};
if (FFlag::LuauShareTxnSeen)
return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters};
else
return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters};
}
bool Unifier::isNonstrictMode() const
@ -1940,12 +2209,12 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId
void Unifier::ice(const std::string& message, const Location& location)
{
iceHandler->ice(message, location);
sharedState.iceHandler->ice(message, location);
}
void Unifier::ice(const std::string& message)
{
iceHandler->ice(message);
sharedState.iceHandler->ice(message);
}
} // namespace Luau

View File

@ -194,20 +194,20 @@ LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeDa
} // namespace Luau
// Regular scope
#define LUAU_TIMETRACE_SCOPE(name, category) \
#define LUAU_TIMETRACE_SCOPE(name, category) \
static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first)
// A scope without nested scopes that may be skipped if the time it took is less than the threshold
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \
static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \
Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec)
// Extra key/value data can be added to regular scopes
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
if (FFlag::DebugLuauTimeTracing) \
lttScopeStatic.second.eventArgument(name, value); \
} while (false)
@ -216,8 +216,8 @@ LUAU_NOINLINE std::pair<uint16_t, Luau::TimeTrace::ThreadContext&> createScopeDa
#define LUAU_TIMETRACE_SCOPE(name, category)
#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec)
#define LUAU_TIMETRACE_ARGUMENT(name, value) \
do \
{ \
do \
{ \
} while (false)
#endif

View File

@ -77,7 +77,10 @@ struct GlobalContext
// Ideally we would want all ThreadContext destructors to run
// But in VS, not all thread_local object instances are destroyed
for (ThreadContext* context : threads)
context->flushEvents();
{
if (!context->events.empty())
context->flushEvents();
}
if (traceFile)
fclose(traceFile);

View File

@ -1,4 +1,5 @@
# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
.SUFFIXES:
MAKEFLAGS+=-r -j8
COMMA=,

View File

@ -46,6 +46,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Module.h
Analysis/include/Luau/ModuleResolver.h
Analysis/include/Luau/Predicate.h
Analysis/include/Luau/Quantify.h
Analysis/include/Luau/RecursionCounter.h
Analysis/include/Luau/RequireTracer.h
Analysis/include/Luau/Scope.h
@ -63,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/TypeVar.h
Analysis/include/Luau/Unifiable.h
Analysis/include/Luau/Unifier.h
Analysis/include/Luau/UnifierSharedState.h
Analysis/include/Luau/Variant.h
Analysis/include/Luau/VisitTypeVar.h
@ -77,6 +79,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/Linter.cpp
Analysis/src/Module.cpp
Analysis/src/Predicate.cpp
Analysis/src/Quantify.cpp
Analysis/src/RequireTracer.cpp
Analysis/src/Scope.cpp
Analysis/src/Substitution.cpp

View File

@ -13,8 +13,6 @@
#include <string.h>
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n"
"$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n"
"$URL: www.lua.org $\n";
@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*))
luaC_checkGC(L);
luaC_checkthreadsleep(L);
Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR);
memcpy(u->data + sz, &dtor, sizeof(dtor));
memcpy(&u->data + sz, &dtor, sizeof(dtor));
setuvalue(L, L->top, u);
api_incr_top(L);
return u->data;

View File

@ -5,8 +5,6 @@
#include "lstate.h"
#include "lvm.h"
LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
#define CO_RUN 0 /* running */
#define CO_SUS 1 /* suspended */
#define CO_NOR 2 /* 'normal' (it resumed another coroutine) */
@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false)
static const char* const statnames[] = {"running", "suspended", "normal", "dead"};
static int costatus(lua_State* L, lua_State* co)
static int auxstatus(lua_State* L, lua_State* co)
{
if (co == L)
return CO_RUN;
@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co)
return CO_SUS; /* initial state */
}
static int luaB_costatus(lua_State* L)
static int costatus(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
lua_pushstring(L, statnames[costatus(L, co)]);
lua_pushstring(L, statnames[auxstatus(L, co)]);
return 1;
}
@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg)
// error handling for edge cases
if (co->status != LUA_YIELD)
{
int status = costatus(L, co);
int status = auxstatus(L, co);
if (status != CO_SUS)
{
lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]);
@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co)
}
}
static int luaB_coresumefinish(lua_State* L, int r)
static int coresumefinish(lua_State* L, int r)
{
if (r < 0)
{
@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r)
}
}
static int luaB_coresumey(lua_State* L)
static int coresumey(lua_State* L)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_coresumecont(lua_State* L, int status)
static int coresumecont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, 1);
luaL_argexpected(L, co, 1, "thread");
@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_coresumefinish(L, r);
return coresumefinish(L, r);
}
static int luaB_auxwrapfinish(lua_State* L, int r)
static int auxwrapfinish(lua_State* L, int r)
{
if (r < 0)
{
@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r)
return r;
}
static int luaB_auxwrapy(lua_State* L)
static int auxwrapy(lua_State* L)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
int narg = cast_int(L->top - L->base);
@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L)
if (r == CO_STATUS_BREAK)
return interruptThread(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_auxwrapcont(lua_State* L, int status)
static int auxwrapcont(lua_State* L, int status)
{
lua_State* co = lua_tothread(L, lua_upvalueindex(1));
@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status)
int r = auxresumecont(L, co);
return luaB_auxwrapfinish(L, r);
return auxwrapfinish(L, r);
}
static int luaB_cocreate(lua_State* L)
static int cocreate(lua_State* L)
{
luaL_checktype(L, 1, LUA_TFUNCTION);
lua_State* NL = lua_newthread(L);
if (FFlag::LuauPreferXpush)
{
lua_xpush(L, NL, 1); // push function on top of NL
}
else
{
lua_pushvalue(L, 1); /* move function to top */
lua_xmove(L, NL, 1); /* move function from L to NL */
}
lua_xpush(L, NL, 1); // push function on top of NL
return 1;
}
static int luaB_cowrap(lua_State* L)
static int cowrap(lua_State* L)
{
luaB_cocreate(L);
cocreate(L);
lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont);
lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont);
return 1;
}
static int luaB_yield(lua_State* L)
static int coyield(lua_State* L)
{
int nres = cast_int(L->top - L->base);
return lua_yield(L, nres);
}
static int luaB_corunning(lua_State* L)
static int corunning(lua_State* L)
{
if (lua_pushthread(L))
lua_pushnil(L); /* main thread is not a coroutine */
return 1;
}
static int luaB_yieldable(lua_State* L)
static int coyieldable(lua_State* L)
{
lua_pushboolean(L, lua_isyieldable(L));
return 1;
}
static const luaL_Reg co_funcs[] = {
{"create", luaB_cocreate},
{"running", luaB_corunning},
{"status", luaB_costatus},
{"wrap", luaB_cowrap},
{"yield", luaB_yield},
{"isyieldable", luaB_yieldable},
{"create", cocreate},
{"running", corunning},
{"status", costatus},
{"wrap", cowrap},
{"yield", coyield},
{"isyieldable", coyieldable},
{NULL, NULL},
};
@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L)
{
luaL_register(L, LUA_COLIBNAME, co_funcs);
lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont);
lua_pushcfunction(L, coresumey, "resume", 0, coresumecont);
lua_setfield(L, -2, "resume");
return 1;

View File

@ -18,6 +18,7 @@
#include <string.h>
LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false)
LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false)
/*
** {======================================================
@ -536,6 +537,12 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
status = LUA_ERRERR;
}
if (FFlag::LuauCcallRestoreFix)
{
// Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored.
L->nCcalls = oldnCcalls;
}
// an error occurred, check if we have a protected error callback
if (L->global->cb.debugprotectederror)
{
@ -549,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e
StkId oldtop = restorestack(L, old_top);
luaF_close(L, oldtop); /* close eventual pending closures */
seterrorobj(L, status, oldtop);
L->nCcalls = oldnCcalls;
if (!FFlag::LuauCcallRestoreFix)
{
L->nCcalls = oldnCcalls;
}
L->ci = restoreci(L, old_ci);
L->base = L->ci->base;
restore_stack_limit(L);

View File

@ -12,11 +12,9 @@
#include <string.h>
#include <stdio.h>
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false)
LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false)
LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false)
LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false)
LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false)
LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false)
LUAU_FASTFLAG(LuauArrayBoundary)
@ -66,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds,
g->gcstats.currcycle.marktime += seconds;
// atomic step had to be performed during the switch and it's tracked separately
if (g->gcstate == GCSsweepstring)
if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring)
g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime;
break;
case GCSatomic:
g->gcstats.currcycle.atomictime += seconds;
break;
case GCSsweepstring:
case GCSsweep:
g->gcstats.currcycle.sweeptime += seconds;
break;
default:
LUAU_ASSERT(!"Unexpected GC state");
}
if (assist)
@ -183,33 +186,15 @@ static int traversetable(global_State* g, Table* h)
if (h->metatable)
markobject(g, cast_to(Table*, h->metatable));
if (FFlag::LuauShrinkWeakTables)
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
/* is there a weak mode? */
if (const char* modev = gettablemode(g, h))
{
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
}
else
{
const TValue* mode = gfasttm(g, h->metatable, TM_MODE);
if (mode && ttisstring(mode))
{ /* is there a weak mode? */
const char* modev = svalue(mode);
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
weakkey = (strchr(modev, 'k') != NULL);
weakvalue = (strchr(modev, 'v') != NULL);
if (weakkey || weakvalue)
{ /* is really weak? */
h->gclist = g->weak; /* must be cleared after GC, ... */
g->weak = obj2gco(h); /* ... so put in the appropriate list */
}
}
@ -297,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack)
for (StkId o = l->stack; o < l->top; o++)
markvalue(g, o);
/* final traversal? */
if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack))
if (g->gcstate == GCSatomic || clearstack)
{
StkId stack_end = l->stack + l->stacksize;
for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */
@ -336,28 +321,16 @@ static size_t propagatemark(global_State* g)
lua_State* th = gco2th(o);
g->gray = th->gclist;
if (FFlag::LuauGcFullSkipInactiveThreads)
LUAU_ASSERT(!luaC_threadsleeping(th));
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
LUAU_ASSERT(!luaC_threadsleeping(th));
traversestack(g, th, /* clearstack= */ true);
// threads that are executing and the main thread are not deactivated
bool active = luaC_threadactive(th) || th == th->global->mainthread;
if (!active && g->gcstate == GCSpropagate)
{
traversestack(g, th, /* clearstack= */ true);
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
th->gclist = g->grayagain;
g->grayagain = o;
black2gray(o);
traversestack(g, th, /* clearstack= */ false);
}
l_setbit(th->stackstate, THREAD_SLEEPINGBIT);
}
else
{
@ -385,12 +358,14 @@ static size_t propagatemark(global_State* g)
}
}
static void propagateall(global_State* g)
static size_t propagateall(global_State* g)
{
size_t work = 0;
while (g->gray)
{
propagatemark(g);
work += propagatemark(g);
}
return work;
}
/*
@ -415,11 +390,14 @@ static int isobjcleared(GCObject* o)
/*
** clear collected entries from weaktables
*/
static void cleartable(lua_State* L, GCObject* l)
static size_t cleartable(lua_State* L, GCObject* l)
{
size_t work = 0;
while (l)
{
Table* h = gco2h(l);
work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h);
int i = h->sizearray;
while (i--)
{
@ -433,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l)
{
LuaNode* n = gnode(h, i);
if (FFlag::LuauShrinkWeakTables)
// non-empty entry?
if (!ttisnil(gval(n)))
{
// non-empty entry?
if (!ttisnil(gval(n)))
{
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
else
{
if (!ttisnil(gval(n)) && /* non-empty entry? */
(iscleared(gkey(n)) || iscleared(gval(n))))
// can we clear key or value?
if (iscleared(gkey(n)) || iscleared(gval(n)))
{
setnilvalue(gval(n)); /* remove value ... */
removeentry(n); /* remove entry from table */
}
else
{
activevalues++;
}
}
}
if (FFlag::LuauShrinkWeakTables)
if (const char* modev = gettablemode(L->global, h))
{
if (const char* modev = gettablemode(L->global, h))
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// are we allowed to shrink this weak table?
if (strchr(modev, 's'))
{
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
// shrink at 37.5% occupancy
if (activevalues < sizenode(h) * 3 / 8)
luaH_resizehash(L, h, activevalues);
}
}
l = h->gclist;
}
return work;
}
static void shrinkstack(lua_State* L)
@ -655,37 +619,49 @@ static void markroot(lua_State* L)
g->gcstate = GCSpropagate;
}
static void remarkupvals(global_State* g)
static size_t remarkupvals(global_State* g)
{
UpVal* uv;
for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
size_t work = 0;
for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next)
{
work += sizeof(UpVal);
LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv);
if (isgray(obj2gco(uv)))
markvalue(g, uv->v);
}
return work;
}
static void atomic(lua_State* L)
static size_t atomic(lua_State* L)
{
global_State* g = L->global;
g->gcstate = GCSatomic;
size_t work = 0;
if (FFlag::LuauSeparateAtomic)
{
LUAU_ASSERT(g->gcstate == GCSatomic);
}
else
{
g->gcstate = GCSatomic;
}
/* remark occasional upvalues of (maybe) dead threads */
remarkupvals(g);
work += remarkupvals(g);
/* traverse objects caught by write barrier and by 'remarkupvals' */
propagateall(g);
work += propagateall(g);
/* remark weak tables */
g->gray = g->weak;
g->weak = NULL;
LUAU_ASSERT(!iswhite(obj2gco(g->mainthread)));
markobject(g, L); /* mark running thread */
markmt(g); /* mark basic metatables (again) */
propagateall(g);
work += propagateall(g);
/* remark gray again */
g->gray = g->grayagain;
g->grayagain = NULL;
propagateall(g);
cleartable(L, g->weak); /* remove collected objects from weak tables */
work += propagateall(g);
work += cleartable(L, g->weak); /* remove collected objects from weak tables */
g->weak = NULL;
/* flip current white */
g->currentwhite = cast_byte(otherwhite(g));
@ -693,7 +669,12 @@ static void atomic(lua_State* L)
g->sweepgc = &g->rootgc;
g->gcstate = GCSsweepstring;
GC_INTERRUPT(GCSatomic);
if (!FFlag::LuauSeparateAtomic)
{
GC_INTERRUPT(GCSatomic);
}
return work;
}
static size_t singlestep(lua_State* L)
@ -705,46 +686,24 @@ static size_t singlestep(lua_State* L)
case GCSpause:
{
markroot(L); /* start a new collection */
LUAU_ASSERT(g->gcstate == GCSpropagate);
break;
}
case GCSpropagate:
{
if (FFlag::LuauRescanGrayAgain)
if (g->gray)
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
g->gcstats.currcycle.markitems++;
cost = propagatemark(g);
}
else
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
cost = propagatemark(g);
}
else
{
if (g->gray)
{
g->gcstats.currcycle.markitems++;
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
cost = propagatemark(g);
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
g->gcstate = GCSpropagateagain;
}
break;
}
@ -758,17 +717,34 @@ static size_t singlestep(lua_State* L)
}
else /* no more `gray' objects */
{
double starttimestamp = lua_clock();
if (FFlag::LuauSeparateAtomic)
{
g->gcstate = GCSatomic;
}
else
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSatomic:
{
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
cost = atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
break;
}
case GCSsweepstring:
{
size_t traversedcount = 0;
@ -806,7 +782,7 @@ static size_t singlestep(lua_State* L)
break;
}
default:
LUAU_ASSERT(0);
LUAU_ASSERT(!"Unexpected GC state");
}
return cost;
@ -821,48 +797,25 @@ static size_t gcstep(lua_State* L, size_t limit)
case GCSpause:
{
markroot(L); /* start a new collection */
LUAU_ASSERT(g->gcstate == GCSpropagate);
break;
}
case GCSpropagate:
{
if (FFlag::LuauRescanGrayAgain)
while (g->gray && cost < limit)
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
g->gcstats.currcycle.markitems++;
cost += propagatemark(g);
}
if (!g->gray)
{
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
g->gcstate = GCSpropagateagain;
}
cost += propagatemark(g);
}
else
if (!g->gray)
{
while (g->gray && cost < limit)
{
g->gcstats.currcycle.markitems++;
// perform one iteration over 'gray again' list
g->gray = g->grayagain;
g->grayagain = NULL;
cost += propagatemark(g);
}
if (!g->gray) /* no more `gray' objects */
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
g->gcstate = GCSpropagateagain;
}
break;
}
@ -877,17 +830,34 @@ static size_t gcstep(lua_State* L, size_t limit)
if (!g->gray) /* no more `gray' objects */
{
double starttimestamp = lua_clock();
if (FFlag::LuauSeparateAtomic)
{
g->gcstate = GCSatomic;
}
else
{
double starttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
g->gcstats.currcycle.atomicstarttimestamp = starttimestamp;
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
atomic(L); /* finish mark phase */
atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp;
}
}
break;
}
case GCSatomic:
{
g->gcstats.currcycle.atomicstarttimestamp = lua_clock();
g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes;
cost = atomic(L); /* finish mark phase */
LUAU_ASSERT(g->gcstate == GCSsweepstring);
break;
}
case GCSsweepstring:
{
while (g->sweepstrgc < g->strt.size && cost < limit)
@ -934,7 +904,7 @@ static size_t gcstep(lua_State* L, size_t limit)
break;
}
default:
LUAU_ASSERT(0);
LUAU_ASSERT(!"Unexpected GC state");
}
return cost;
}
@ -1084,7 +1054,7 @@ void luaC_fullgc(lua_State* L)
if (g->gcstate == GCSpause)
startGcCycleStats(g);
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
/* reset sweep marks to sweep all elements (returning them to white) */
g->sweepstrgc = 0;
@ -1095,7 +1065,7 @@ void luaC_fullgc(lua_State* L)
g->weak = NULL;
g->gcstate = GCSsweepstring;
}
LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain);
LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep);
/* finish any pending sweep phase */
while (g->gcstate != GCSpause)
{
@ -1143,14 +1113,11 @@ void luaC_fullgc(lua_State* L)
void luaC_barrierupval(lua_State* L, GCObject* v)
{
if (FFlag::LuauGcFullSkipInactiveThreads)
{
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
global_State* g = L->global;
LUAU_ASSERT(iswhite(v) && !isdead(g, v));
if (keepinvariant(g))
reallymarkobject(g, v);
}
if (keepinvariant(g))
reallymarkobject(g, v);
}
void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v)
@ -1778,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L)
global_State* g = L->global;
const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms
if (g->gcstate <= GCSpropagateagain)
if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain))
{
double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp;

View File

@ -6,8 +6,6 @@
#include "lobject.h"
#include "lstate.h"
LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
/*
** Possible states of the Garbage Collector
*/
@ -25,7 +23,7 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads)
** still-black objects. The invariant is restored when sweep ends and
** all objects are white again.
*/
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain)
#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic)
/*
** some useful bit tricks
@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L);
LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat));
LUAI_FUNC int64_t luaC_allocationrate(lua_State* L);
LUAI_FUNC void luaC_wakethread(lua_State* L);
LUAI_FUNC const char* luaC_statename(int state);
LUAI_FUNC const char* luaC_statename(int state);

View File

@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass)
if (page->freeNext >= 0)
{
block = page->data + page->freeNext;
block = &page->data + page->freeNext;
ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize);
page->freeNext -= page->blockSize;

View File

@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u)
void (*dtor)(void*) = nullptr;
if (u->tag == UTAG_IDTOR)
memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor));
memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor));
else if (u->tag)
dtor = L->global->udatagc[u->tag];

View File

@ -13,7 +13,7 @@
#include <string.h>
// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens
template <typename T>
template<typename T>
struct TempBuffer
{
lua_State* L;
@ -346,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size
uint32_t mainid = readVarInt(data, size, offset);
Proto* main = protos[mainid];
luaC_checkthreadsleep(L);
Closure* cl = luaF_newLclosure(L, 0, envt, main);
setclvalue(L, L->top, cl);
incr_top(L);

849
bench/tests/chess.lua Normal file
View File

@ -0,0 +1,849 @@
local bench = script and require(script.Parent.bench_support) or require("bench_support")
local RANKS = "12345678"
local FILES = "abcdefgh"
local PieceSymbols = "PpRrNnBbQqKk"
local UnicodePieces = {"", "", "", "", "", "", "", "", "", "", "", ""}
local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
--
-- Lua 5.2 Compat
--
if not table.create then
function table.create(n, v)
local result = {}
for i=1,n do result[i] = v end
return result
end
end
if not table.move then
function table.move(a, from, to, start, target)
local dx = start - from
for i=from,to do
target[i+dx] = a[i]
end
end
end
--
-- Utils
--
local function square(s)
return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9
end
local function squareName(n)
local file = n % 8
local rank = (n-file)/8
return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1)
end
local function moveName(v )
local from = bit32.extract(v, 6, 6)
local to = bit32.extract(v, 0, 6)
local piece = bit32.extract(v, 20, 4)
local captured = bit32.extract(v, 25, 4)
local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to)
if bit32.extract(v,14) == 1 then
if to > from then
return "O-O"
else
return "O-O-O"
end
end
local promote = bit32.extract(v,15,4)
if promote ~= 0 then
move = move .. "=" .. PieceSymbols:sub(promote,promote)
end
return move
end
local function ucimove(m)
local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6))
local promote = bit32.extract(m,15,4)
if promote > 0 then
mm = mm .. PieceSymbols:sub(promote,promote):lower()
end
return mm
end
local _utils = {squareName, moveName}
--
-- Bitboards
--
local Bitboard = {}
function Bitboard:toString()
local out = {}
local src = self.h
for x=7,0,-1 do
table.insert(out, RANKS:sub(x+1,x+1))
table.insert(out, " ")
local bit = bit32.lshift(1,(x%4) * 8)
for x=0,7 do
if bit32.band(src, bit) ~= 0 then
table.insert(out, "x ")
else
table.insert(out, "- ")
end
bit = bit32.lshift(bit, 1)
end
if x == 4 then
src = self.l
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h)
return table.concat(out)
end
function Bitboard.from(l ,h )
return setmetatable({l=l, h=h}, Bitboard)
end
Bitboard.zero = Bitboard.from(0,0)
Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF)
local Rank1 = Bitboard.from(0x000000FF, 0)
local Rank3 = Bitboard.from(0x00FF0000, 0)
local Rank6 = Bitboard.from(0, 0x0000FF00)
local Rank8 = Bitboard.from(0, 0xFF000000)
local FileA = Bitboard.from(0x01010101, 0x01010101)
local FileB = Bitboard.from(0x02020202, 0x02020202)
local FileC = Bitboard.from(0x04040404, 0x04040404)
local FileD = Bitboard.from(0x08080808, 0x08080808)
local FileE = Bitboard.from(0x10101010, 0x10101010)
local FileF = Bitboard.from(0x20202020, 0x20202020)
local FileG = Bitboard.from(0x40404040, 0x40404040)
local FileH = Bitboard.from(0x80808080, 0x80808080)
local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH}
-- These masks are filled out below for all files
local RightMasks = {FileH}
local LeftMasks = {FileA}
local function popcnt32(i)
i = i - bit32.band(bit32.rshift(i,1), 0x55555555)
i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333)
return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24)
end
function Bitboard:up()
return self:lshift(8)
end
function Bitboard:down()
return self:rshift(8)
end
function Bitboard:right()
return self:band(FileH:inverse()):lshift(1)
end
function Bitboard:left()
return self:band(FileA:inverse()):rshift(1)
end
function Bitboard:move(x,y)
local out = self
if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end
if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end
if y < 0 then out = out:rshift(-8 * y) end
if y > 0 then out = out:lshift(8 * y) end
return out
end
function Bitboard:popcnt()
return popcnt32(self.l) + popcnt32(self.h)
end
function Bitboard:band(other )
return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h))
end
function Bitboard:bandnot(other )
return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h)))
end
function Bitboard:bandempty(other )
return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0
end
function Bitboard:bor(other )
return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h))
end
function Bitboard:bxor(other )
return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h))
end
function Bitboard:inverse()
return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF))
end
function Bitboard:empty()
return self.h == 0 and self.l == 0
end
function Bitboard:ctz()
local target = self.l
local offset = 0
local result = 0
if target == 0 then
target = self.h
result = 32
end
if target == 0 then
return 64
end
while bit32.extract(target, offset) == 0 do
offset = offset + 1
end
return result + offset
end
function Bitboard:ctzafter(start)
start = start + 1
if start < 32 then
for i=start,31 do
if bit32.extract(self.l, i) == 1 then return i end
end
end
for i=math.max(32,start),63 do
if bit32.extract(self.h, i-32) == 1 then return i end
end
return 64
end
function Bitboard:lshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
if amt > 31 then
return Bitboard.from(0, bit32.lshift(self.l, amt-31))
end
local l = bit32.lshift(self.l, amt)
local h = bit32.bor(
bit32.lshift(self.h, amt),
bit32.extract(self.l, 32-amt, amt)
)
return Bitboard.from(l, h)
end
function Bitboard:rshift(amt)
assert(amt >= 0)
if amt == 0 then return self end
local h = bit32.rshift(self.h, amt)
local l = bit32.bor(
bit32.rshift(self.l, amt),
bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt)
)
return Bitboard.from(l, h)
end
function Bitboard:index(i)
if i > 31 then
return bit32.extract(self.h, i - 32)
else
return bit32.extract(self.l, i)
end
end
function Bitboard:set(i , v)
if i > 31 then
return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32))
else
return Bitboard.from(bit32.replace(self.l, v, i), self.h)
end
end
function Bitboard:isolate(i)
return self:band(Bitboard.some(i))
end
function Bitboard.some(idx )
return Bitboard.zero:set(idx, 1)
end
Bitboard.__index = Bitboard
Bitboard.__tostring = Bitboard.toString
for i=2,8 do
RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH)
LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA)
end
--
-- Board
--
local Board = {}
function Board.new()
local boards = table.create(12, Bitboard.zero)
boards.ocupied = Bitboard.zero
boards.white = Bitboard.zero
boards.black = Bitboard.zero
boards.unocupied = Bitboard.full
boards.ep = Bitboard.zero
boards.castle = Bitboard.zero
boards.toMove = 1
boards.hm = 0
boards.moves = 0
boards.material = 0
return setmetatable(boards, Board)
end
function Board.fromFen(fen )
local b = Board.new()
local i = 0
local rank = 7
local file = 0
while true do
i = i + 1
local p = fen:sub(i,i)
if p == '/' then
rank = rank - 1
file = 0
elseif tonumber(p) ~= nil then
file = file + tonumber(p)
else
local pidx = PieceSymbols:find(p)
if pidx == nil then break end
b[pidx] = b[pidx]:set(rank*8+file, 1)
file = file + 1
end
end
local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i)
if move == nil then print(fen:sub(i)) end
b.toMove = move == 'w' and 1 or 2
if ep ~= "-" then
b.ep = Bitboard.some(square(ep))
end
if castle ~= "-" then
local oo = Bitboard.zero
if castle:find("K") then
oo = oo:set(7, 1)
end
if castle:find("Q") then
oo = oo:set(0, 1)
end
if castle:find("k") then
oo = oo:set(63, 1)
end
if castle:find("q") then
oo = oo:set(56, 1)
end
b.castle = oo
end
b.hm = hm
b.moves = m
b:updateCache()
return b
end
function Board:index(idx )
if self.white:index(idx) == 1 then
for p=1,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
else
for p=2,12,2 do
if self[p]:index(idx) == 1 then
return p
end
end
end
return 0
end
function Board:updateCache()
for i=1,11,2 do
self.white = self.white:bor(self[i])
self.black = self.black:bor(self[i+1])
end
self.ocupied = self.black:bor(self.white)
self.unocupied = self.ocupied:inverse()
self.material =
100*self[1]:popcnt() - 100*self[2]:popcnt() +
500*self[3]:popcnt() - 500*self[4]:popcnt() +
300*self[5]:popcnt() - 300*self[6]:popcnt() +
300*self[7]:popcnt() - 300*self[8]:popcnt() +
900*self[9]:popcnt() - 900*self[10]:popcnt()
end
function Board:fen()
local out = {}
local s = 0
local idx = 56
for i=0,63 do
if i % 8 == 0 and i > 0 then
idx = idx - 16
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, '/')
end
local p = self:index(idx)
if p == 0 then
s = s + 1
else
if s > 0 then
table.insert(out, '' .. s)
s = 0
end
table.insert(out, PieceSymbols:sub(p,p))
end
idx = idx + 1
end
if s > 0 then
table.insert(out, '' .. s)
end
table.insert(out, self.toMove == 1 and ' w ' or ' b ')
if self.castle:empty() then
table.insert(out, '-')
else
if self.castle:index(7) == 1 then table.insert(out, 'K') end
if self.castle:index(0) == 1 then table.insert(out, 'Q') end
if self.castle:index(63) == 1 then table.insert(out, 'k') end
if self.castle:index(56) == 1 then table.insert(out, 'q') end
end
table.insert(out, ' ')
if self.ep:empty() then
table.insert(out, '-')
else
table.insert(out, squareName(self.ep:ctz()))
end
table.insert(out, ' ' .. self.hm)
table.insert(out, ' ' .. self.moves)
return table.concat(out)
end
function Board:pmoves(idx)
return self:generate(idx)
end
function Board:pcaptures(idx)
return self:generate(idx):band(self.ocupied)
end
local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}}
local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}}
local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}}
local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}}
function Board:generate(idx)
local piece = self:index(idx)
local r = Bitboard.some(idx)
local out = Bitboard.zero
local type = bit32.rshift(piece - 1, 1)
local cancapture = piece % 2 == 1 and self.black or self.white
if piece == 0 then return Bitboard.zero end
if type == 0 then
-- Pawn
local d = -(piece*2 - 3)
local movetwo = piece == 1 and Rank3 or Rank6
out = out:bor(r:move(0,d):band(self.unocupied))
out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied))
local captures = r:move(0,d)
captures = captures:right():bor(captures:left())
if not captures:bandempty(self.ep) then
out = out:bor(self.ep)
end
captures = captures:band(cancapture)
out = out:bor(captures)
return out
elseif type == 5 then
-- King
for x=-1,1,1 do
for y = -1,1,1 do
local w = r:move(x,y)
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
end
elseif type == 2 then
-- Knight
for _,j in ipairs(KNIGHT_MOVES) do
local w = r:move(j[1],j[2])
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
end
end
else
-- Sliders (Rook, Bishop, Queen)
local slides
if type == 1 then
slides = ROOK_SLIDES
elseif type == 3 then
slides = BISHOP_SLIDES
else
slides = QUEEN_SLIDES
end
for _, op in ipairs(slides) do
local w = r
for i=1,7 do
w = w:move(op[1], op[2])
if w:empty() then break end
if self.ocupied:bandempty(w) then
out = out:bor(w)
else
if not cancapture:bandempty(w) then
out = out:bor(w)
end
break
end
end
end
end
return out
end
-- 0-5 - From Square
-- 6-11 - To Square
-- 12 - is Check
-- 13 - Is EnPassent
-- 14 - Is Castle
-- 15-19 - Promotion Piece
-- 20-24 - Moved Pice
-- 25-29 - Captured Piece
function Board:toString(mark )
local out = {}
for x=8,1,-1 do
table.insert(out, RANKS:sub(x,x) .. " ")
for y=1,8 do
local n = 8*x+y-9
local i = self:index(n)
if i == 0 then
table.insert(out, '-')
else
-- out = out .. PieceSymbols:sub(i,i)
table.insert(out, UnicodePieces[i])
end
if mark ~= nil and mark:index(n) ~= 0 then
table.insert(out, ')')
elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then
table.insert(out, '(')
else
table.insert(out, ' ')
end
end
table.insert(out, "\n")
end
table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n')
table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n")
return table.concat(out)
end
function Board:moveList()
local tm = self.toMove == 1 and self.white or self.black
local castle_rank = self.toMove == 1 and Rank1 or Rank8
local out = {}
local function emit(id)
if not self:applyMove(id):illegalyChecked() then
table.insert(out, id)
end
end
local cr = tm:band(self.castle):band(castle_rank)
if not cr:empty() then
local p = self.toMove == 1 and 11 or 12
local tcolor = self.toMove == 1 and self.black or self.white
local kidx = self[p]:ctz()
local castle = bit32.replace(0, p, 20, 4)
castle = bit32.replace(castle, kidx, 6, 6)
castle = bit32.replace(castle, 1, 14)
local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank)
local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p])
if
not cr:bandempty(FileA) and
mustbeemptyl:bandempty(self.ocupied) and
not self:isSquareThreatened(cantbethreatened, tcolor)
then
emit(bit32.replace(castle, kidx - 2, 0, 6))
end
local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank)
if
not cr:bandempty(FileH) and
mustbeemptyr:bandempty(self.ocupied) and
not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor)
then
emit(bit32.replace(castle, kidx + 2, 0, 6))
end
end
local sq = tm:ctz()
repeat
local p = self:index(sq)
local moves = self:pmoves(sq)
while not moves:empty() do
local m = moves:ctz()
moves = moves:set(m, 0)
local id = bit32.replace(m, sq, 6, 6)
id = bit32.replace(id, p, 20, 4)
local mbb = Bitboard.some(m)
if not self.ocupied:bandempty(mbb) then
id = bit32.replace(id, self:index(m), 25, 4)
end
-- Check if pawn needs to be promoted
if p == 1 and m >= 8*7 then
for i=3,9,2 do
emit(bit32.replace(id, i, 15, 4))
end
elseif p == 2 and m < 8 then
for i=4,10,2 do
emit(bit32.replace(id, i, 15, 4))
end
else
emit(id)
end
end
sq = tm:ctzafter(sq)
until sq == 64
return out
end
function Board:illegalyChecked()
local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")]
return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black)
end
function Board:isSquareThreatened(target , color )
local tm = color
local sq = tm:ctz()
repeat
local moves = self:pmoves(sq)
if not moves:bandempty(target) then
return true
end
sq = color:ctzafter(sq)
until sq == 64
return false
end
function Board:perft(depth )
if depth == 0 then return 1 end
if depth == 1 then
return #self:moveList()
end
local result = 0
for k,m in ipairs(self:moveList()) do
local c = self:applyMove(m):perft(depth - 1)
if c == 0 then
-- Perft only counts leaf nodes at target depth
-- result = result + 1
else
result = result + c
end
end
return result
end
function Board:applyMove(move )
local out = Board.new()
table.move(self, 1, 12, 1, out)
local from = bit32.extract(move, 6, 6)
local to = bit32.extract(move, 0, 6)
local promote = bit32.extract(move, 15, 4)
local piece = self:index(from)
local captured = self:index(to)
local tom = Bitboard.some(to)
local isCastle = bit32.extract(move, 14)
if piece % 2 == 0 then
out.moves = self.moves + 1
end
if captured == 1 or piece < 3 then
out.hm = 0
else
out.hm = self.hm + 1
end
out.castle = self.castle
out.toMove = self.toMove == 1 and 2 or 1
if isCastle == 1 then
local rank = piece == 11 and Rank1 or Rank8
local colorOffset = piece - 11
out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA)
out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank))
out[piece] = (from < to and FileG or FileC):band(rank)
out.castle = out.castle:bandnot(rank)
out:updateCache()
return out
end
if piece < 3 then
local dist = math.abs(to - from)
-- Pawn moved two squares, set ep square
if dist == 16 then
out.ep = Bitboard.some((from + to) / 2)
end
-- Remove enpasent capture
if not tom:bandempty(self.ep) then
if piece == 1 then
out[2] = out[2]:bandnot(self.ep:down())
end
if piece == 2 then
out[1] = out[1]:bandnot(self.ep:up())
end
end
end
if piece == 3 or piece == 4 then
out.castle = out.castle:set(from, 0)
end
if piece > 10 then
local rank = piece == 11 and Rank1 or Rank8
out.castle = out.castle:bandnot(rank)
end
out[piece] = out[piece]:set(from, 0)
if promote == 0 then
out[piece] = out[piece]:set(to, 1)
else
out[promote] = out[promote]:set(to, 1)
end
if captured ~= 0 then
out[captured] = out[captured]:set(to, 0)
end
out:updateCache()
return out
end
Board.__index = Board
Board.__tostring = Board.toString
--
-- Main
--
local failures = 0
local function test(fen, ply, target)
local b = Board.fromFen(fen)
if b:fen() ~= fen then
print("FEN MISMATCH", fen, b:fen())
failures = failures + 1
return
end
local found = b:perft(ply)
if found ~= target then
print(fen, "Found", found, "target", target)
failures = failures + 1
for k,v in pairs(b:moveList()) do
print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1'))
end
--error("Test Failure")
else
print("OK", found, fen)
end
end
-- From https://www.chessprogramming.org/Perft_Results
-- If interpreter, computers, or algorithm gets too fast
-- feel free to go deeper
local testCases = {}
local function addTest(...) table.insert(testCases, {...}) end
addTest(StartingFen, 3, 8902)
addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039)
addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812)
addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467)
addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486)
addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079)
local function chess()
for k,v in ipairs(testCases) do
test(v[1],v[2],v[3])
end
end
bench.runCode(chess, "chess")

View File

@ -1596,7 +1596,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion")
local function target(a: number, b: string) return a + #b end
local function d(a: n@1, b)
return target(a, b)
return target(a, b)
end
)");
@ -1609,7 +1609,7 @@ end
local function target(a: number, b: string) return a + #b end
local function d(a, b: s@1)
return target(a, b)
return target(a, b)
end
)");
@ -1622,7 +1622,7 @@ end
local function target(a: number, b: string) return a + #b end
local function d(a:@1 @2, b)
return target(a, b)
return target(a, b)
end
)");
@ -1640,7 +1640,7 @@ end
local function target(a: number, b: string) return a + #b end
local function d(a, b: @1)@2: number
return target(a, b)
return target(a, b)
end
)");
@ -1682,7 +1682,7 @@ local x = target(function(a: n@1
local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end
local x = target(function(a: n@1, b: @2)
return a + #b
return a + #b
end)
)");
@ -1700,7 +1700,7 @@ end)
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
local x = target(function(a: n@1)
return a
return a
end
)");
@ -1716,7 +1716,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
local x = target(function(...:n@1)
return a
return a
end
)");
@ -1729,7 +1729,7 @@ end
local function target(callback: (...number) -> number) return callback(1, 2, 3) end
local x = target(function(a:number, b:number, ...:@1)
return a + b
return a + b
end
)");
@ -1745,7 +1745,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion")
local function target(callback: () -> number) return callback() end
local x = target(function(): n@1
return 1
return 1
end
)");
@ -1758,7 +1758,7 @@ end
local function target(callback: () -> (number, number)) return callback() end
local x = target(function(): (number, n@1
return 1, 2
return 1, 2
end
)");
@ -1774,7 +1774,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion"
local function target(callback: () -> ...number) return callback() end
local x = target(function(): ...n@1
return 1, 2, 3
return 1, 2, 3
end
)");
@ -1787,7 +1787,7 @@ end
local function target(callback: () -> ...number) return callback() end
local x = target(function(): (number, number, ...n@1
return 1, 2, 3
return 1, 2, 3
end
)");

View File

@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf")
local MaterialsListClass = {}
function MaterialsListClass:_MakeToolTip(guiElement, text)
local function updateTooltipPosition()
self._tweakingTooltipFrame = 5
end
local function updateTooltipPosition()
self._tweakingTooltipFrame = 5
end
updateTooltipPosition()
updateTooltipPosition()
end
return MaterialsListClass
@ -2001,14 +2001,14 @@ TEST_CASE("UpvaluesLoopsBytecode")
{
CHECK_EQ("\n" + compileFunction(R"(
function test()
for i=1,10 do
for i=1,10 do
i = i
foo(function() return i end)
if bar then
break
end
end
return 0
foo(function() return i end)
if bar then
break
end
end
return 0
end
)",
1),
@ -2035,14 +2035,14 @@ RETURN R0 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
for i in ipairs(data) do
for i in ipairs(data) do
i = i
foo(function() return i end)
if bar then
break
end
end
return 0
foo(function() return i end)
if bar then
break
end
end
return 0
end
)",
1),
@ -2068,17 +2068,17 @@ RETURN R0 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
local i = 0
while i < 5 do
local j
local i = 0
while i < 5 do
local j
j = i
foo(function() return j end)
i = i + 1
if bar then
break
end
end
return 0
foo(function() return j end)
i = i + 1
if bar then
break
end
end
return 0
end
)",
1),
@ -2105,17 +2105,17 @@ RETURN R1 1
CHECK_EQ("\n" + compileFunction(R"(
function test()
local i = 0
repeat
local j
local i = 0
repeat
local j
j = i
foo(function() return j end)
i = i + 1
if bar then
break
end
until i < 5
return 0
foo(function() return j end)
i = i + 1
if bar then
break
end
until i < 5
return 0
end
)",
1),
@ -2304,10 +2304,10 @@ local Value1, Value2, Value3 = ...
local Table = {}
Table.SubTable["Key"] = {
Key1 = Value1,
Key2 = Value2,
Key3 = Value3,
Key4 = true,
Key1 = Value1,
Key2 = Value2,
Key3 = Value3,
Key4 = true,
}
)");

View File

@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression")
runConformance("ifelseexpr.lua");
}
TEST_CASE("TagMethodError")
{
ScopedFastFlag sff{"LuauCcallRestoreFix", true};
runConformance("tmerror.lua", [](lua_State* L) {
auto* cb = lua_callbacks(L);
cb->debugprotectederror = [](lua_State* L) {
CHECK(lua_isyieldable(L));
};
});
}
TEST_SUITE_END();

View File

@ -2,6 +2,9 @@
#pragma once
#include <ostream>
#include <optional>
namespace std {
inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
{
@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&)
}
template<typename T>
std::ostream& operator<<(std::ostream& lhs, const std::optional<T>& t)
auto operator<<(std::ostream& lhs, const std::optional<T>& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types
{
if (t)
return lhs << *t;
else
return lhs << "none";
}
} // namespace std

View File

@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings")
{
LintResult result = lint(R"(--!strict
type InputData = {
id: number,
inputType: EnumItem,
inputState: EnumItem,
updated: number,
position: Vector3,
keyCode: EnumItem,
name: string
id: number,
inputType: EnumItem,
inputState: EnumItem,
updated: number,
position: Vector3,
keyCode: EnumItem,
name: string
}
)");

View File

@ -554,4 +554,54 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name")
CHECK_EQ("{number | string}", toString(requireType("p"), {true}));
}
/*
* We had a problem where all type aliases would be prototyped into a child scope that happened
* to have the same level. This caused a problem where, if a sibling function referred to that
* type alias in its type signature, it would erroneously be quantified away, even though it doesn't
* actually belong to the function.
*
* We solved this by ascribing a unique subLevel to each prototyped alias.
*/
TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases")
{
CheckResult result = check(R"(
--!strict
local KeyPool = {}
local function newkey(pool: KeyPool, index)
return {}
end
function newKeyPool()
local pool = {
available = {} :: {Key},
}
return setmetatable(pool, KeyPool)
end
export type KeyPool = typeof(newKeyPool())
export type Key = typeof(newkey(newKeyPool(), 1))
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
/*
* We keep a cache of type alias onto TypeVar to prevent infinite types from
* being constructed via recursive or corecursive aliases. We have to adjust
* the TypeLevels of those generic TypeVars so that the unifier doesn't think
* they have improperly leaked out of their scope.
*/
TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases")
{
CheckResult result = check(R"(
type Array<T> = {T}
type Exclude<T, V> = T
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View File

@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde
TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local b: Vector2? = nil
local a = b.X + b.Z

View File

@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types")
CHECK(requireType("y1") == requireType("y2"));
}
TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields")
{
ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true};
ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true};
CheckResult result = check(R"(
local exports = {}
local nested = {}
nested.name = function(t, k)
local a = t.x.y
return rawget(t, k)
end
exports.nested = nested
return exports
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_SUITE_END();

View File

@ -9,6 +9,7 @@
#include <algorithm>
LUAU_FASTFLAG(LuauEqConstraint)
LUAU_FASTFLAG(LuauQuantifyInPlace2)
using namespace Luau;
@ -42,7 +43,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
end
)";
const std::string expected = R"(
const std::string old_expected = R"(
function f(a:{fn:()->(free,free...)}): ()
if type(a) == 'boolean'then
local a1:boolean=a
@ -51,7 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete")
end
end
)";
CHECK_EQ(expected, decorateWithTypes(code));
const std::string expected = R"(
function f(a:{fn:()->(a,b...)}): ()
if type(a) == 'boolean'then
local a1:boolean=a
elseif a.fn()then
local a2:{fn:()->(a,b...)}=a
end
end
)";
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ(expected, decorateWithTypes(code));
else
CHECK_EQ(old_expected, decorateWithTypes(code));
}
TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns")
@ -263,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap")
TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5))
{
ScopedFastInt sffi{"LuauTarjanChildLimit", 50};
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50};
ScopedFastInt sffi{"LuauTarjanChildLimit", 1};
ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1};
CheckResult result = check(R"LUA(
local Result

View File

@ -8,6 +8,7 @@
LUAU_FASTFLAG(LuauWeakEqConstraint)
LUAU_FASTFLAG(LuauOrPredicate)
LUAU_FASTFLAG(LuauQuantifyInPlace2)
using namespace Luau;
@ -698,10 +699,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector")
CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector"
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0]));
else
CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0]));
CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance"
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
else
CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance"
}
TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector")

View File

@ -617,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too")
REQUIRE_EQ(indexer.indexType, typeChecker.numberType);
REQUIRE(nullptr != get<GenericTypeVar>(indexer.indexResultType));
REQUIRE(nullptr != get<GenericTypeVar>(follow(indexer.indexResultType)));
}
TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2")

View File

@ -180,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement")
TEST_CASE_FIXTURE(Fixture, "generic_function")
{
CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)");
CheckResult result = check(R"(
function id(x) return x end
local a = id(55)
local b = id(nil)
)");
LUAU_REQUIRE_NO_ERRORS(result);
CHECK_EQ(*typeChecker.numberType, *requireType("a"));
@ -1889,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function")
REQUIRE_EQ(2, argVec.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
REQUIRE(fType != nullptr);
std::vector<TypeId> fArgs = flatten(fType->argTypes).first;
@ -1926,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2")
REQUIRE_EQ(6, argVec.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(argVec[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(argVec[0]));
REQUIRE(fType != nullptr);
}
@ -3842,10 +3847,10 @@ local T: any
T = {}
T.__index = T
function T.new(...)
local self = {}
setmetatable(self, T)
self:construct(...)
return self
local self = {}
setmetatable(self, T)
self:construct(...)
return self
end
function T:construct(index)
end
@ -4068,11 +4073,11 @@ function n:Clone() end
local m = {}
function m.a(x)
x:Clone()
x:Clone()
end
function m.b()
m.a(n)
m.a(n)
end
return m
@ -4393,8 +4398,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload")
TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments")
{
ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true);
// Simple direct arg to arg propagation
CheckResult result = check(R"(
type Table = { x: number, y: number }
@ -4681,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early")
{
ScopedFastFlag sffs[] = {
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
{"LuauExtraNilRecovery", true},
};
CheckResult result = check(R"(
@ -4698,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch")
{
ScopedFastFlag sffs[] = {
{"LuauSlightlyMoreFlexibleBinaryPredicates", true},
{"LuauExtraNilRecovery", true},
};
CheckResult result = check(R"(

View File

@ -8,6 +8,8 @@
#include "doctest.h"
LUAU_FASTFLAG(LuauQuantifyInPlace2);
using namespace Luau;
struct TryUnifyFixture : Fixture
@ -15,7 +17,8 @@ struct TryUnifyFixture : Fixture
TypeArena arena;
ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}};
InternalErrorReporter iceHandler;
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler};
UnifierSharedState unifierState{&iceHandler};
Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState};
};
TEST_SUITE_BEGIN("TryUnifyTests");
@ -139,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails"
)");
LUAU_REQUIRE_ERROR_COUNT(1, result);
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
if (FFlag::LuauQuantifyInPlace2)
CHECK_EQ("(number) -> boolean", toString(requireType("f")));
else
CHECK_EQ("(number) -> (boolean)", toString(requireType("f")));
}
TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification")

View File

@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function")
std::vector<TypeId> applyArgs = flatten(applyType->argTypes).first;
REQUIRE_EQ(3, applyArgs.size());
const FunctionTypeVar* fType = get<FunctionTypeVar>(applyArgs[0]);
const FunctionTypeVar* fType = get<FunctionTypeVar>(follow(applyArgs[0]));
REQUIRE(fType != nullptr);
const FunctionTypeVar* gType = get<FunctionTypeVar>(applyArgs[1]);
const FunctionTypeVar* gType = get<FunctionTypeVar>(follow(applyArgs[1]));
REQUIRE(gType != nullptr);
std::vector<TypeId> gArgs = flatten(gType->argTypes).first;
@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail")
{
CheckResult result = check(R"(
local _ = function():((...any)->(...any),()->())
return function() end, function() end
return function() end, function() end
end
for y in _() do
end

View File

@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property")
TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property")
{
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
CheckResult result = check(R"(
type A = {x: number}
type B = {}
@ -242,8 +240,6 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons")
TEST_CASE_FIXTURE(Fixture, "optional_union_members")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = { a = { x = 1, y = 2 }, b = 3 }
type A = typeof(a)
@ -259,8 +255,6 @@ local c = bf.a.y
TEST_CASE_FIXTURE(Fixture, "optional_union_functions")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = {}
function a.foo(x:number, y:number) return x + y end
@ -276,8 +270,6 @@ local c = b.foo(1, 2)
TEST_CASE_FIXTURE(Fixture, "optional_union_methods")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
local a = {}
function a:foo(x:number, y:number) return x + y end
@ -310,8 +302,6 @@ return f()
TEST_CASE_FIXTURE(Fixture, "optional_field_access_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = { x: number }
local b: A? = { x = 2 }
@ -327,8 +317,6 @@ local d = b.y
TEST_CASE_FIXTURE(Fixture, "optional_index_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = {number}
local a: A? = {1, 2, 3}
@ -341,8 +329,6 @@ local b = a[1]
TEST_CASE_FIXTURE(Fixture, "optional_call_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = (number) -> number
local a: A? = function(a) return -a end
@ -355,8 +341,6 @@ local b = a(4)
TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = { x: number }
local a: A? = { x = 2 }
@ -378,8 +362,6 @@ a.x = 2
TEST_CASE_FIXTURE(Fixture, "optional_length_error")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
CheckResult result = check(R"(
type A = {number}
local a: A? = {1, 2, 3}
@ -392,9 +374,6 @@ local b = #a
TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details")
{
ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true);
ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true);
CheckResult result = check(R"(
type A = { x: number, y: number }
type B = { x: number, y: number }

View File

@ -265,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure")
CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result));
}
TEST_CASE("tagging_tables")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar ttv{TableTypeVar{}};
CHECK(!Luau::hasTag(&ttv, "foo"));
Luau::attachTag(&ttv, "foo");
CHECK(Luau::hasTag(&ttv, "foo"));
}
TEST_CASE("tagging_classes")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
CHECK(!Luau::hasTag(&base, "foo"));
Luau::attachTag(&base, "foo");
CHECK(Luau::hasTag(&base, "foo"));
}
TEST_CASE("tagging_subclasses")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}};
TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}};
CHECK(!Luau::hasTag(&base, "foo"));
CHECK(!Luau::hasTag(&derived, "foo"));
Luau::attachTag(&base, "foo");
CHECK(Luau::hasTag(&base, "foo"));
CHECK(Luau::hasTag(&derived, "foo"));
Luau::attachTag(&derived, "bar");
CHECK(!Luau::hasTag(&base, "bar"));
CHECK(Luau::hasTag(&derived, "bar"));
}
TEST_CASE("tagging_functions")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
TypePackVar empty{TypePack{}};
TypeVar ftv{FunctionTypeVar{&empty, &empty}};
CHECK(!Luau::hasTag(&ftv, "foo"));
Luau::attachTag(&ftv, "foo");
CHECK(Luau::hasTag(&ftv, "foo"));
}
TEST_CASE("tagging_props")
{
ScopedFastFlag sff{"LuauRefactorTagging", true};
Property prop{};
CHECK(!Luau::hasTag(prop, "foo"));
Luau::attachTag(prop, "foo");
CHECK(Luau::hasTag(prop, "foo"));
}
TEST_SUITE_END();

View File

@ -0,0 +1,15 @@
-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes
-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly
-- called via pcall.
-- This test is meant to detect a regression in handling errors inside a tag method
local testtable = {}
setmetatable(testtable, { __index = function() error("Error") end })
pcall(function()
testtable.missingmethod()
end)
return('OK')

View File

@ -11,9 +11,9 @@ class VariantPrinter:
return type.name + " [" + str(value) + "]"
def match_printer(val):
type = val.type.strip_typedefs()
if type.name and type.name.startswith('Luau::Variant<'):
return VariantPrinter(val)
return None
type = val.type.strip_typedefs()
if type.name and type.name.startswith('Luau::Variant<'):
return VariantPrinter(val)
return None
gdb.pretty_printers.append(match_printer)