Sync to upstream/release/541 (#644)

- Fix autocomplete not suggesting globals defined after the cursor (fixes #622)
- Improve type checker stability
- Reduce parser C stack consumption which fixes some stack overflow crashes on deeply nested sources
- Improve performance of bit32.extract/replace when width is implied (~3% faster chess)
- Improve performance of bit32.extract when field/width are constants (~10% faster base64)
- Heap dump now annotates thread stacks with local variable/function names
This commit is contained in:
Arseny Kapoulkine 2022-08-18 14:32:08 -07:00 committed by GitHub
parent 0e118b54bb
commit be2769ad14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
56 changed files with 1803 additions and 953 deletions

View File

@ -0,0 +1,38 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/NotNull.h"
#include "Luau/Substitution.h"
#include "Luau/TypeVar.h"
#include <memory>
namespace Luau
{
struct TypeArena;
struct Scope;
struct InternalErrorReporter;
using ScopePtr = std::shared_ptr<Scope>;
// A substitution which replaces free types by any
struct Anyification : Substitution
{
Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack);
NotNull<Scope> scope;
InternalErrorReporter* iceHandler;
TypeId anyType;
TypePackId anyTypePack;
bool normalizationTooComplex = false;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
bool ignoreChildren(TypeId ty) override;
bool ignoreChildren(TypePackId ty) override;
};
}

View File

@ -40,7 +40,6 @@ struct GeneralizationConstraint
{ {
TypeId generalizedType; TypeId generalizedType;
TypeId sourceType; TypeId sourceType;
Scope* scope;
}; };
// subType ~ inst superType // subType ~ inst superType
@ -85,13 +84,14 @@ using ConstraintPtr = std::unique_ptr<struct Constraint>;
struct Constraint struct Constraint
{ {
explicit Constraint(ConstraintV&& c); Constraint(ConstraintV&& c, NotNull<Scope> scope);
Constraint(const Constraint&) = delete; Constraint(const Constraint&) = delete;
Constraint& operator=(const Constraint&) = delete; Constraint& operator=(const Constraint&) = delete;
ConstraintV c; ConstraintV c;
std::vector<NotNull<Constraint>> dependencies; std::vector<NotNull<Constraint>> dependencies;
NotNull<Scope> scope;
}; };
inline Constraint& asMutable(const Constraint& c) inline Constraint& asMutable(const Constraint& c)

View File

@ -72,10 +72,10 @@ struct ConstraintGraphBuilder
/** /**
* Fabricates a scope that is a child of another scope. * Fabricates a scope that is a child of another scope.
* @param location the lexical extent of the scope in the source code. * @param node the lexical node that the scope belongs to.
* @param parent the parent scope of the new scope. Must not be null. * @param parent the parent scope of the new scope. Must not be null.
*/ */
ScopePtr childScope(Location location, const ScopePtr& parent); ScopePtr childScope(AstNode* node, const ScopePtr& parent);
/** /**
* Adds a new constraint with no dependencies to a given scope. * Adds a new constraint with no dependencies to a given scope.
@ -105,10 +105,12 @@ struct ConstraintGraphBuilder
void visit(const ScopePtr& scope, AstStatLocal* local); void visit(const ScopePtr& scope, AstStatLocal* local);
void visit(const ScopePtr& scope, AstStatFor* for_); void visit(const ScopePtr& scope, AstStatFor* for_);
void visit(const ScopePtr& scope, AstStatWhile* while_); void visit(const ScopePtr& scope, AstStatWhile* while_);
void visit(const ScopePtr& scope, AstStatRepeat* repeat);
void visit(const ScopePtr& scope, AstStatLocalFunction* function); void visit(const ScopePtr& scope, AstStatLocalFunction* function);
void visit(const ScopePtr& scope, AstStatFunction* function); void visit(const ScopePtr& scope, AstStatFunction* function);
void visit(const ScopePtr& scope, AstStatReturn* ret); void visit(const ScopePtr& scope, AstStatReturn* ret);
void visit(const ScopePtr& scope, AstStatAssign* assign); void visit(const ScopePtr& scope, AstStatAssign* assign);
void visit(const ScopePtr& scope, AstStatCompoundAssign* assign);
void visit(const ScopePtr& scope, AstStatIf* ifStatement); void visit(const ScopePtr& scope, AstStatIf* ifStatement);
void visit(const ScopePtr& scope, AstStatTypeAlias* alias); void visit(const ScopePtr& scope, AstStatTypeAlias* alias);
void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal); void visit(const ScopePtr& scope, AstStatDeclareGlobal* declareGlobal);
@ -133,6 +135,7 @@ struct ConstraintGraphBuilder
TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr); TypeId check(const ScopePtr& scope, AstExprIndexExpr* indexExpr);
TypeId check(const ScopePtr& scope, AstExprUnary* unary); TypeId check(const ScopePtr& scope, AstExprUnary* unary);
TypeId check(const ScopePtr& scope, AstExprBinary* binary); TypeId check(const ScopePtr& scope, AstExprBinary* binary);
TypeId check(const ScopePtr& scope, AstExprIfElse* ifElse);
TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert); TypeId check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert);
struct FunctionSignature struct FunctionSignature

View File

@ -60,6 +60,9 @@ struct ConstraintSolver
// Memoized instantiations of type aliases. // Memoized instantiations of type aliases.
DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}}; DenseHashMap<InstantiationSignature, TypeId, HashInstantiationSignature> instantiatedAliases{{}};
// Recorded errors that take place within the solver.
ErrorVec errors;
ConstraintSolverLogger logger; ConstraintSolverLogger logger;
explicit ConstraintSolver(TypeArena* arena, NotNull<Scope> rootScope); explicit ConstraintSolver(TypeArena* arena, NotNull<Scope> rootScope);
@ -115,7 +118,7 @@ struct ConstraintSolver
* @param subType the sub-type to unify. * @param subType the sub-type to unify.
* @param superType the super-type to unify. * @param superType the super-type to unify.
*/ */
void unify(TypeId subType, TypeId superType); void unify(TypeId subType, TypeId superType, NotNull<Scope> scope);
/** /**
* Creates a new Unifier and performs a single unification operation. Commits * Creates a new Unifier and performs a single unification operation. Commits
@ -123,13 +126,15 @@ struct ConstraintSolver
* @param subPack the sub-type pack to unify. * @param subPack the sub-type pack to unify.
* @param superPack the super-type pack to unify. * @param superPack the super-type pack to unify.
*/ */
void unify(TypePackId subPack, TypePackId superPack); void unify(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope);
/** Pushes a new solver constraint to the solver. /** Pushes a new solver constraint to the solver.
* @param cv the body of the constraint. * @param cv the body of the constraint.
**/ **/
void pushConstraint(ConstraintV cv); void pushConstraint(ConstraintV cv, NotNull<Scope> scope);
void reportError(TypeErrorData&& data, const Location& location);
void reportError(TypeError e);
private: private:
/** /**
* Marks a constraint as being blocked on a type or type pack. The constraint * Marks a constraint as being blocked on a type or type pack. The constraint

View File

@ -77,6 +77,8 @@ struct Module
DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr}; DenseHashMap<const AstExpr*, TypeId> astOverloadResolvedTypes{nullptr};
DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr}; DenseHashMap<const AstType*, TypeId> astResolvedTypes{nullptr};
DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr}; DenseHashMap<const AstTypePack*, TypePackId> astResolvedTypePacks{nullptr};
// Map AST nodes to the scope they create. Cannot be NotNull<Scope> because we need a sentinel value for the map.
DenseHashMap<const AstNode*, Scope*> astScopes{nullptr};
std::unordered_map<Name, TypeId> declaredGlobals; std::unordered_map<Name, TypeId> declaredGlobals;
ErrorVec errors; ErrorVec errors;

View File

@ -1,20 +1,28 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Substitution.h"
#include "Luau/TypeVar.h"
#include "Luau/Module.h" #include "Luau/Module.h"
#include "Luau/NotNull.h"
#include "Luau/TypeVar.h"
#include <memory>
namespace Luau namespace Luau
{ {
struct InternalErrorReporter; struct InternalErrorReporter;
struct Module;
struct Scope;
bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice); using ModulePtr = std::shared_ptr<Module>;
bool isSubtype(TypePackId subTy, TypePackId superTy, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice); bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, InternalErrorReporter& ice);
bool isSubtype(TypePackId subTy, TypePackId superTy, NotNull<Scope> scope, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Scope> scope, TypeArena& arena, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Module> module, InternalErrorReporter& ice);
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice); std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, TypeArena& arena, InternalErrorReporter& ice); std::pair<TypePackId, bool> normalize(TypePackId ty, NotNull<Scope> scope, TypeArena& arena, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, NotNull<Module> module, InternalErrorReporter& ice);
std::pair<TypePackId, bool> normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice); std::pair<TypePackId, bool> normalize(TypePackId ty, const ModulePtr& module, InternalErrorReporter& ice);
} // namespace Luau } // namespace Luau

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once #pragma once
#include "Luau/Anyification.h"
#include "Luau/Predicate.h" #include "Luau/Predicate.h"
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Module.h" #include "Luau/Module.h"
@ -36,40 +37,6 @@ const AstStat* getFallthrough(const AstStat* node);
struct UnifierOptions; struct UnifierOptions;
struct Unifier; struct Unifier;
// A substitution which replaces free types by any
struct Anyification : Substitution
{
Anyification(TypeArena* arena, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack)
: Substitution(TxnLog::empty(), arena)
, iceHandler(iceHandler)
, anyType(anyType)
, anyTypePack(anyTypePack)
{
}
InternalErrorReporter* iceHandler;
TypeId anyType;
TypePackId anyTypePack;
bool normalizationTooComplex = false;
bool isDirty(TypeId ty) override;
bool isDirty(TypePackId tp) override;
TypeId clean(TypeId ty) override;
TypePackId clean(TypePackId tp) override;
bool ignoreChildren(TypeId ty) override
{
if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
return ty->persistent;
}
bool ignoreChildren(TypePackId ty) override
{
return ty->persistent;
}
};
struct GenericTypeDefinitions struct GenericTypeDefinitions
{ {
std::vector<GenericTypeDefinition> genericTypes; std::vector<GenericTypeDefinition> genericTypes;
@ -196,32 +163,32 @@ struct TypeChecker
/** Attempt to unify the types. /** Attempt to unify the types.
* Treat any failures as type errors in the final typecheck report. * Treat any failures as type errors in the final typecheck report.
*/ */
bool unify(TypeId subTy, TypeId superTy, const Location& location); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
bool unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options); bool unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options);
bool unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); bool unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg);
/** Attempt to unify the types. /** Attempt to unify the types.
* If this fails, and the subTy type can be instantiated, do so and try unification again. * If this fails, and the subTy type can be instantiated, do so and try unification again.
*/ */
bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location); bool unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state); void unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state);
/** Attempt to unify. /** Attempt to unify.
* If there are errors, undo everything and return the errors. * If there are errors, undo everything and return the errors.
* If there are no errors, commit and return an empty error vector. * If there are no errors, commit and return an empty error vector.
*/ */
template<typename Id> template<typename Id>
ErrorVec tryUnify_(Id subTy, Id superTy, const Location& location); ErrorVec tryUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location);
ErrorVec tryUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec tryUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const Location& location); ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location);
// Test whether the two type vars unify. Never commits the result. // Test whether the two type vars unify. Never commits the result.
template<typename Id> template<typename Id>
ErrorVec canUnify_(Id subTy, Id superTy, const Location& location); ErrorVec canUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location);
ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); ErrorVec canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location);
ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location);
void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location); void unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location);
std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors); std::optional<TypeId> findMetatableEntry(TypeId type, std::string entry, const Location& location, bool addErrors);
std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors); std::optional<TypeId> findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location, bool addErrors);
@ -290,7 +257,7 @@ private:
void reportErrorCodeTooComplex(const Location& location); void reportErrorCodeTooComplex(const Location& location);
private: private:
Unifier mkUnifier(const Location& location); Unifier mkUnifier(const ScopePtr& scope, const Location& location);
// These functions are only safe to call when we are in the process of typechecking a module. // These functions are only safe to call when we are in the process of typechecking a module.
@ -312,7 +279,7 @@ public:
std::pair<std::optional<TypeId>, bool> pickTypesFromSense(TypeId type, bool sense); std::pair<std::optional<TypeId>, bool> pickTypesFromSense(TypeId type, bool sense);
private: private:
TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); TypeId unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes = true);
// ex // ex
// TypeId id = addType(FreeTypeVar()); // TypeId id = addType(FreeTypeVar());

View File

@ -13,7 +13,10 @@ namespace Luau
using ScopePtr = std::shared_ptr<struct Scope>; using ScopePtr = std::shared_ptr<struct Scope>;
std::optional<TypeId> findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location); std::optional<TypeId> findMetatableEntry(ErrorVec& errors, TypeId type, const std::string& entry, Location location);
std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location); std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, const std::string& name, Location location);
std::optional<TypeId> getIndexTypeFromType(
const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, const Location& location, bool addErrors,
InternalErrorReporter& handle);
} // namespace Luau } // namespace Luau

View File

@ -10,7 +10,7 @@ namespace Luau
{ {
void* pagedAllocate(size_t size); void* pagedAllocate(size_t size);
void pagedDeallocate(void* ptr); void pagedDeallocate(void* ptr, size_t size);
void pagedFreeze(void* ptr, size_t size); void pagedFreeze(void* ptr, size_t size);
void pagedUnfreeze(void* ptr, size_t size); void pagedUnfreeze(void* ptr, size_t size);
@ -113,7 +113,7 @@ private:
for (size_t i = 0; i < blockSize; ++i) for (size_t i = 0; i < blockSize; ++i)
block[i].~T(); block[i].~T();
pagedDeallocate(block); pagedDeallocate(block, kBlockSizeBytes);
} }
stuff.clear(); stuff.clear();

View File

@ -3,9 +3,10 @@
#include "Luau/Error.h" #include "Luau/Error.h"
#include "Luau/Location.h" #include "Luau/Location.h"
#include "Luau/Scope.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeArena.h" #include "Luau/TypeArena.h"
#include "Luau/TypeInfer.h"
#include "Luau/UnifierSharedState.h" #include "Luau/UnifierSharedState.h"
#include <unordered_set> #include <unordered_set>
@ -48,6 +49,7 @@ struct Unifier
TypeArena* const types; TypeArena* const types;
Mode mode; Mode mode;
NotNull<Scope> scope; // const Scope maybe
TxnLog log; TxnLog log;
ErrorVec errors; ErrorVec errors;
Location location; Location location;
@ -57,7 +59,7 @@ struct Unifier
UnifierSharedState& sharedState; UnifierSharedState& sharedState;
Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); Unifier(TypeArena* types, Mode mode, NotNull<Scope> scope, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr);
// Test whether the two type vars unify. Never commits the result. // Test whether the two type vars unify. Never commits the result.
ErrorVec canUnify(TypeId subTy, TypeId superTy); ErrorVec canUnify(TypeId subTy, TypeId superTy);

View File

@ -69,12 +69,14 @@ struct GenericTypeVarVisitor
using Set = S; using Set = S;
Set seen; Set seen;
bool skipBoundTypes = false;
int recursionCounter = 0; int recursionCounter = 0;
GenericTypeVarVisitor() = default; GenericTypeVarVisitor() = default;
explicit GenericTypeVarVisitor(Set seen) explicit GenericTypeVarVisitor(Set seen, bool skipBoundTypes = false)
: seen(std::move(seen)) : seen(std::move(seen))
, skipBoundTypes(skipBoundTypes)
{ {
} }
@ -199,7 +201,9 @@ struct GenericTypeVarVisitor
if (auto btv = get<BoundTypeVar>(ty)) if (auto btv = get<BoundTypeVar>(ty))
{ {
if (visit(ty, *btv)) if (skipBoundTypes)
traverse(btv->boundTo);
else if (visit(ty, *btv))
traverse(btv->boundTo); traverse(btv->boundTo);
} }
else if (auto ftv = get<FreeTypeVar>(ty)) else if (auto ftv = get<FreeTypeVar>(ty))
@ -229,7 +233,11 @@ struct GenericTypeVarVisitor
else if (auto ttv = get<TableTypeVar>(ty)) else if (auto ttv = get<TableTypeVar>(ty))
{ {
// Some visitors want to see bound tables, that's why we traverse the original type // Some visitors want to see bound tables, that's why we traverse the original type
if (visit(ty, *ttv)) if (skipBoundTypes && ttv->boundTo)
{
traverse(*ttv->boundTo);
}
else if (visit(ty, *ttv))
{ {
if (ttv->boundTo) if (ttv->boundTo)
{ {
@ -394,13 +402,17 @@ struct GenericTypeVarVisitor
*/ */
struct TypeVarVisitor : GenericTypeVarVisitor<std::unordered_set<void*>> struct TypeVarVisitor : GenericTypeVarVisitor<std::unordered_set<void*>>
{ {
explicit TypeVarVisitor(bool skipBoundTypes = false)
: GenericTypeVarVisitor{{}, skipBoundTypes}
{
}
}; };
/// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. /// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it.
struct TypeVarOnceVisitor : GenericTypeVarVisitor<DenseHashSet<void*>> struct TypeVarOnceVisitor : GenericTypeVarVisitor<DenseHashSet<void*>>
{ {
TypeVarOnceVisitor() explicit TypeVarOnceVisitor(bool skipBoundTypes = false)
: GenericTypeVarVisitor{DenseHashSet<void*>{nullptr}} : GenericTypeVarVisitor{DenseHashSet<void*>{nullptr}, skipBoundTypes}
{ {
} }
}; };

View File

@ -0,0 +1,96 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/Anyification.h"
#include "Luau/Common.h"
#include "Luau/Normalize.h"
#include "Luau/TxnLog.h"
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau
{
Anyification::Anyification(TypeArena* arena, const ScopePtr& scope, InternalErrorReporter* iceHandler, TypeId anyType, TypePackId anyTypePack)
: Substitution(TxnLog::empty(), arena)
, scope(NotNull{scope.get()})
, iceHandler(iceHandler)
, anyType(anyType)
, anyTypePack(anyTypePack)
{
}
bool Anyification::isDirty(TypeId ty)
{
if (ty->persistent)
return false;
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed);
else if (log->getMutable<FreeTypeVar>(ty))
return true;
else if (get<ConstrainedTypeVar>(ty))
return true;
else
return false;
}
bool Anyification::isDirty(TypePackId tp)
{
if (tp->persistent)
return false;
if (log->getMutable<FreeTypePack>(tp))
return true;
else
return false;
}
TypeId Anyification::clean(TypeId ty)
{
LUAU_ASSERT(isDirty(ty));
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
{
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed};
clone.definitionModuleName = ttv->definitionModuleName;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.tags = ttv->tags;
TypeId res = addType(std::move(clone));
asMutable(res)->normal = ty->normal;
return res;
}
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
std::vector<TypeId> copy = ctv->parts;
for (TypeId& ty : copy)
ty = replace(ty);
TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)});
auto [t, ok] = normalize(res, scope, *arena, *iceHandler);
if (!ok)
normalizationTooComplex = true;
return t;
}
else
return anyType;
}
TypePackId Anyification::clean(TypePackId tp)
{
LUAU_ASSERT(isDirty(tp));
return anyTypePack;
}
bool Anyification::ignoreChildren(TypeId ty)
{
if (FFlag::LuauClassTypeVarsInSubstitution && get<ClassTypeVar>(ty))
return true;
return ty->persistent;
}
bool Anyification::ignoreChildren(TypePackId ty)
{
return ty->persistent;
}
}

View File

@ -14,6 +14,8 @@
LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3) LUAU_FASTFLAG(LuauSelfCallAutocompleteFix3)
LUAU_FASTFLAGVARIABLE(LuauAutocompleteFixGlobalOrder, false)
static const std::unordered_set<std::string> kStatementStartingKeywords = { static const std::unordered_set<std::string> kStatementStartingKeywords = {
"while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"};
@ -135,11 +137,11 @@ static std::optional<TypeId> findExpectedTypeAt(const Module& module, AstNode* n
return *it; return *it;
} }
static bool checkTypeMatch(TypeArena* typeArena, TypeId subTy, TypeId superTy) static bool checkTypeMatch(TypeId subTy, TypeId superTy, NotNull<Scope> scope, TypeArena* typeArena)
{ {
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter); UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); Unifier unifier(typeArena, Mode::Strict, scope, Location(), Variance::Covariant, unifierState);
return unifier.canUnify(subTy, superTy).empty(); return unifier.canUnify(subTy, superTy).empty();
} }
@ -148,12 +150,14 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
{ {
ty = follow(ty); ty = follow(ty);
auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { NotNull<Scope> moduleScope{module.getModuleScope().get()};
auto canUnify = [&typeArena, moduleScope](TypeId subTy, TypeId superTy) {
LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3); LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix3);
InternalErrorReporter iceReporter; InternalErrorReporter iceReporter;
UnifierSharedState unifierState(&iceReporter); UnifierSharedState unifierState(&iceReporter);
Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); Unifier unifier(typeArena, Mode::Strict, moduleScope, Location(), Variance::Covariant, unifierState);
unifier.tryUnify(subTy, superTy); unifier.tryUnify(subTy, superTy);
bool ok = unifier.errors.empty(); bool ok = unifier.errors.empty();
@ -167,11 +171,11 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
TypeId expectedType = follow(*typeAtPosition); TypeId expectedType = follow(*typeAtPosition);
auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { auto checkFunctionType = [typeArena, moduleScope, &canUnify, &expectedType](const FunctionTypeVar* ftv) {
if (FFlag::LuauSelfCallAutocompleteFix3) if (FFlag::LuauSelfCallAutocompleteFix3)
{ {
if (std::optional<TypeId> firstRetTy = first(ftv->retTypes)) if (std::optional<TypeId> firstRetTy = first(ftv->retTypes))
return checkTypeMatch(typeArena, *firstRetTy, expectedType); return checkTypeMatch(*firstRetTy, expectedType, moduleScope, typeArena);
return false; return false;
} }
@ -210,7 +214,7 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ
} }
if (FFlag::LuauSelfCallAutocompleteFix3) if (FFlag::LuauSelfCallAutocompleteFix3)
return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; return checkTypeMatch(ty, expectedType, NotNull{module.getModuleScope().get()}, typeArena) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
else else
return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None;
} }
@ -268,7 +272,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
return colonIndex; return colonIndex;
} }
}; };
auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { auto isWrongIndexer = [typeArena, &module, rootTy, indexType](Luau::TypeId type) {
LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3); LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix3);
if (indexType == PropIndexType::Key) if (indexType == PropIndexType::Key)
@ -276,7 +280,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
bool calledWithSelf = indexType == PropIndexType::Colon; bool calledWithSelf = indexType == PropIndexType::Colon;
auto isCompatibleCall = [typeArena, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { auto isCompatibleCall = [typeArena, &module, rootTy, calledWithSelf](const FunctionTypeVar* ftv) {
// Strong match with definition is a success // Strong match with definition is a success
if (calledWithSelf == ftv->hasSelf) if (calledWithSelf == ftv->hasSelf)
return true; return true;
@ -289,7 +293,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId
// When called with '.', but declared with 'self', it is considered invalid if first argument is compatible // When called with '.', but declared with 'self', it is considered invalid if first argument is compatible
if (std::optional<TypeId> firstArgTy = first(ftv->argTypes)) if (std::optional<TypeId> firstArgTy = first(ftv->argTypes))
{ {
if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) if (checkTypeMatch(rootTy, *firstArgTy, NotNull{module.getModuleScope().get()}, typeArena))
return calledWithSelf; return calledWithSelf;
} }
@ -1073,10 +1077,21 @@ T* extractStat(const std::vector<AstNode*>& ancestry)
return nullptr; return nullptr;
} }
static bool isBindingLegalAtCurrentPosition(const Binding& binding, Position pos) static bool isBindingLegalAtCurrentPosition(const Symbol& symbol, const Binding& binding, Position pos)
{ {
// Default Location used for global bindings, which are always legal. if (FFlag::LuauAutocompleteFixGlobalOrder)
return binding.location == Location() || binding.location.end < pos; {
if (symbol.local)
return binding.location.end < pos;
// Builtin globals have an empty location; for defined globals, we want pos to be outside of the definition range to suggest it
return binding.location == Location() || !binding.location.containsClosed(pos);
}
else
{
// Default Location used for global bindings, which are always legal.
return binding.location == Location() || binding.location.end < pos;
}
} }
static AutocompleteEntryMap autocompleteStatement( static AutocompleteEntryMap autocompleteStatement(
@ -1097,7 +1112,7 @@ static AutocompleteEntryMap autocompleteStatement(
{ {
for (const auto& [name, binding] : scope->bindings) for (const auto& [name, binding] : scope->bindings)
{ {
if (!isBindingLegalAtCurrentPosition(binding, position)) if (!isBindingLegalAtCurrentPosition(name, binding, position))
continue; continue;
std::string n = toString(name); std::string n = toString(name);
@ -1225,7 +1240,7 @@ static AutocompleteContext autocompleteExpression(const SourceModule& sourceModu
{ {
for (const auto& [name, binding] : scope->bindings) for (const auto& [name, binding] : scope->bindings)
{ {
if (!isBindingLegalAtCurrentPosition(binding, position)) if (!isBindingLegalAtCurrentPosition(name, binding, position))
continue; continue;
if (isBeingDefined(ancestry, name)) if (isBeingDefined(ancestry, name))

View File

@ -5,8 +5,9 @@
namespace Luau namespace Luau
{ {
Constraint::Constraint(ConstraintV&& c) Constraint::Constraint(ConstraintV&& c, NotNull<Scope> scope)
: c(std::move(c)) : c(std::move(c))
, scope(scope)
{ {
} }

View File

@ -2,6 +2,7 @@
#include "Luau/ConstraintGraphBuilder.h" #include "Luau/ConstraintGraphBuilder.h"
#include "Luau/Ast.h" #include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/Constraint.h" #include "Luau/Constraint.h"
#include "Luau/RecursionCounter.h" #include "Luau/RecursionCounter.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
@ -26,6 +27,7 @@ ConstraintGraphBuilder::ConstraintGraphBuilder(
, globalScope(globalScope) , globalScope(globalScope)
{ {
LUAU_ASSERT(arena); LUAU_ASSERT(arena);
LUAU_ASSERT(module);
} }
TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope) TypeId ConstraintGraphBuilder::freshType(const ScopePtr& scope)
@ -39,20 +41,22 @@ TypePackId ConstraintGraphBuilder::freshTypePack(const ScopePtr& scope)
return arena->addTypePack(TypePackVar{std::move(f)}); return arena->addTypePack(TypePackVar{std::move(f)});
} }
ScopePtr ConstraintGraphBuilder::childScope(Location location, const ScopePtr& parent) ScopePtr ConstraintGraphBuilder::childScope(AstNode* node, const ScopePtr& parent)
{ {
auto scope = std::make_shared<Scope>(parent); auto scope = std::make_shared<Scope>(parent);
scopes.emplace_back(location, scope); scopes.emplace_back(node->location, scope);
scope->returnType = parent->returnType; scope->returnType = parent->returnType;
parent->children.push_back(NotNull(scope.get()));
parent->children.push_back(NotNull{scope.get()});
module->astScopes[node] = scope.get();
return scope; return scope;
} }
void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, ConstraintV cv) void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, ConstraintV cv)
{ {
scope->constraints.emplace_back(new Constraint{std::move(cv)}); scope->constraints.emplace_back(new Constraint{std::move(cv), NotNull{scope.get()}});
} }
void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c) void ConstraintGraphBuilder::addConstraint(const ScopePtr& scope, std::unique_ptr<Constraint> c)
@ -67,6 +71,7 @@ void ConstraintGraphBuilder::visit(AstStatBlock* block)
ScopePtr scope = std::make_shared<Scope>(globalScope); ScopePtr scope = std::make_shared<Scope>(globalScope);
rootScope = scope.get(); rootScope = scope.get();
scopes.emplace_back(block->location, scope); scopes.emplace_back(block->location, scope);
module->astScopes[block] = NotNull{scope.get()};
rootScope->returnType = freshTypePack(scope); rootScope->returnType = freshTypePack(scope);
@ -115,7 +120,7 @@ void ConstraintGraphBuilder::visitBlockWithoutChildScope(const ScopePtr& scope,
ScopePtr defnScope = scope; ScopePtr defnScope = scope;
if (hasGenerics) if (hasGenerics)
{ {
defnScope = childScope(alias->location, scope); defnScope = childScope(alias, scope);
} }
TypeId initialType = freshType(scope); TypeId initialType = freshType(scope);
@ -155,6 +160,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat)
visit(scope, s); visit(scope, s);
else if (auto s = stat->as<AstStatWhile>()) else if (auto s = stat->as<AstStatWhile>())
visit(scope, s); visit(scope, s);
else if (auto s = stat->as<AstStatRepeat>())
visit(scope, s);
else if (auto f = stat->as<AstStatFunction>()) else if (auto f = stat->as<AstStatFunction>())
visit(scope, f); visit(scope, f);
else if (auto f = stat->as<AstStatLocalFunction>()) else if (auto f = stat->as<AstStatLocalFunction>())
@ -163,6 +170,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStat* stat)
visit(scope, r); visit(scope, r);
else if (auto a = stat->as<AstStatAssign>()) else if (auto a = stat->as<AstStatAssign>())
visit(scope, a); visit(scope, a);
else if (auto a = stat->as<AstStatCompoundAssign>())
visit(scope, a);
else if (auto e = stat->as<AstStatExpr>()) else if (auto e = stat->as<AstStatExpr>())
checkPack(scope, e->expr); checkPack(scope, e->expr);
else if (auto i = stat->as<AstStatIf>()) else if (auto i = stat->as<AstStatIf>())
@ -241,7 +250,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFor* for_)
checkNumber(for_->to); checkNumber(for_->to);
checkNumber(for_->step); checkNumber(for_->step);
ScopePtr forScope = childScope(for_->location, scope); ScopePtr forScope = childScope(for_, scope);
forScope->bindings[for_->var] = Binding{singletonTypes.numberType, for_->var->location}; forScope->bindings[for_->var] = Binding{singletonTypes.numberType, for_->var->location};
visit(forScope, for_->body); visit(forScope, for_->body);
@ -251,11 +260,22 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatWhile* while_)
{ {
check(scope, while_->condition); check(scope, while_->condition);
ScopePtr whileScope = childScope(while_->location, scope); ScopePtr whileScope = childScope(while_, scope);
visit(whileScope, while_->body); visit(whileScope, while_->body);
} }
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatRepeat* repeat)
{
ScopePtr repeatScope = childScope(repeat, scope);
visit(repeatScope, repeat->body);
// The condition does indeed have access to bindings from within the body of
// the loop.
check(repeatScope, repeat->condition);
}
void addConstraints(Constraint* constraint, NotNull<Scope> scope) void addConstraints(Constraint* constraint, NotNull<Scope> scope)
{ {
scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); scope->constraints.reserve(scope->constraints.size() + scope->constraints.size());
@ -286,8 +306,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocalFunction*
checkFunctionBody(sig.bodyScope, function->func); checkFunctionBody(sig.bodyScope, function->func);
std::unique_ptr<Constraint> c{ NotNull<Scope> constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()};
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}}; std::unique_ptr<Constraint> c = std::make_unique<Constraint>(GeneralizationConstraint{functionType, sig.signature}, constraintScope);
addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraints(c.get(), NotNull(sig.bodyScope.get()));
addConstraint(scope, std::move(c)); addConstraint(scope, std::move(c));
@ -356,8 +376,8 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatFunction* funct
checkFunctionBody(sig.bodyScope, function->func); checkFunctionBody(sig.bodyScope, function->func);
std::unique_ptr<Constraint> c{ NotNull<Scope> constraintScope{sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()};
new Constraint{GeneralizationConstraint{functionType, sig.signature, sig.signatureScope ? sig.signatureScope.get() : sig.bodyScope.get()}}}; std::unique_ptr<Constraint> c = std::make_unique<Constraint>(GeneralizationConstraint{functionType, sig.signature}, constraintScope);
addConstraints(c.get(), NotNull(sig.bodyScope.get())); addConstraints(c.get(), NotNull(sig.bodyScope.get()));
addConstraint(scope, std::move(c)); addConstraint(scope, std::move(c));
@ -371,7 +391,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatReturn* ret)
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatBlock* block)
{ {
ScopePtr innerScope = childScope(block->location, scope); ScopePtr innerScope = childScope(block, scope);
visitBlockWithoutChildScope(innerScope, block); visitBlockWithoutChildScope(innerScope, block);
} }
@ -384,16 +404,30 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatAssign* assign)
addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId}); addConstraint(scope, PackSubtypeConstraint{valuePack, varPackId});
} }
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatCompoundAssign* assign)
{
// Synthesize A = A op B from A op= B and then build constraints for that instead.
AstExprBinary exprBinary{assign->location, assign->op, assign->var, assign->value};
AstExpr* exprBinaryPtr = &exprBinary;
AstArray<AstExpr*> vars{&assign->var, 1};
AstArray<AstExpr*> values{&exprBinaryPtr, 1};
AstStatAssign syntheticAssign{assign->location, vars, values};
visit(scope, &syntheticAssign);
}
void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement) void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatIf* ifStatement)
{ {
check(scope, ifStatement->condition); check(scope, ifStatement->condition);
ScopePtr thenScope = childScope(ifStatement->thenbody->location, scope); ScopePtr thenScope = childScope(ifStatement->thenbody, scope);
visit(thenScope, ifStatement->thenbody); visit(thenScope, ifStatement->thenbody);
if (ifStatement->elsebody) if (ifStatement->elsebody)
{ {
ScopePtr elseScope = childScope(ifStatement->elsebody->location, scope); ScopePtr elseScope = childScope(ifStatement->elsebody, scope);
visit(elseScope, ifStatement->elsebody); visit(elseScope, ifStatement->elsebody);
} }
} }
@ -561,7 +595,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatDeclareFunction
ScopePtr funScope = scope; ScopePtr funScope = scope;
if (!generics.empty() || !genericPacks.empty()) if (!generics.empty() || !genericPacks.empty())
funScope = childScope(global->location, scope); funScope = childScope(global, scope);
TypePackId paramPack = resolveTypePack(funScope, global->params); TypePackId paramPack = resolveTypePack(funScope, global->params);
TypePackId retPack = resolveTypePack(funScope, global->retTypes); TypePackId retPack = resolveTypePack(funScope, global->retTypes);
@ -739,6 +773,8 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExpr* expr)
result = check(scope, unary); result = check(scope, unary);
else if (auto binary = expr->as<AstExprBinary>()) else if (auto binary = expr->as<AstExprBinary>())
result = check(scope, binary); result = check(scope, binary);
else if (auto ifElse = expr->as<AstExprIfElse>())
result = check(scope, ifElse);
else if (auto typeAssert = expr->as<AstExprTypeAssertion>()) else if (auto typeAssert = expr->as<AstExprTypeAssertion>())
result = check(scope, typeAssert); result = check(scope, typeAssert);
else if (auto err = expr->as<AstExprError>()) else if (auto err = expr->as<AstExprError>())
@ -819,6 +855,12 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binar
addConstraint(scope, SubtypeConstraint{leftType, rightType}); addConstraint(scope, SubtypeConstraint{leftType, rightType});
return leftType; return leftType;
} }
case AstExprBinary::Add:
{
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, BinaryConstraint{AstExprBinary::Add, leftType, rightType, resultType});
return resultType;
}
case AstExprBinary::Sub: case AstExprBinary::Sub:
{ {
TypeId resultType = arena->addType(BlockedTypeVar{}); TypeId resultType = arena->addType(BlockedTypeVar{});
@ -833,6 +875,24 @@ TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprBinary* binar
return nullptr; return nullptr;
} }
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprIfElse* ifElse)
{
check(scope, ifElse->condition);
TypeId thenType = check(scope, ifElse->trueExpr);
TypeId elseType = check(scope, ifElse->falseExpr);
if (ifElse->hasElse)
{
TypeId resultType = arena->addType(BlockedTypeVar{});
addConstraint(scope, SubtypeConstraint{thenType, resultType});
addConstraint(scope, SubtypeConstraint{elseType, resultType});
return resultType;
}
return thenType;
}
TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert) TypeId ConstraintGraphBuilder::check(const ScopePtr& scope, AstExprTypeAssertion* typeAssert)
{ {
check(scope, typeAssert->expr); check(scope, typeAssert->expr);
@ -905,14 +965,14 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
// generics properly. // generics properly.
if (hasGenerics) if (hasGenerics)
{ {
signatureScope = childScope(fn->location, parent); signatureScope = childScope(fn, parent);
// We need to assign returnType before creating bodyScope so that the // We need to assign returnType before creating bodyScope so that the
// return type gets propogated to bodyScope. // return type gets propogated to bodyScope.
returnType = freshTypePack(signatureScope); returnType = freshTypePack(signatureScope);
signatureScope->returnType = returnType; signatureScope->returnType = returnType;
bodyScope = childScope(fn->body->location, signatureScope); bodyScope = childScope(fn->body, signatureScope);
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureScope, fn->generics); std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureScope, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks);
@ -933,7 +993,7 @@ ConstraintGraphBuilder::FunctionSignature ConstraintGraphBuilder::checkFunctionS
} }
else else
{ {
bodyScope = childScope(fn->body->location, parent); bodyScope = childScope(fn->body, parent);
returnType = freshTypePack(bodyScope); returnType = freshTypePack(bodyScope);
bodyScope->returnType = returnType; bodyScope->returnType = returnType;
@ -1098,7 +1158,7 @@ TypeId ConstraintGraphBuilder::resolveType(const ScopePtr& scope, AstType* ty, b
// for the generic bindings to live on. // for the generic bindings to live on.
if (hasGenerics) if (hasGenerics)
{ {
signatureScope = childScope(fn->location, scope); signatureScope = childScope(fn, scope);
std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureScope, fn->generics); std::vector<std::pair<Name, GenericTypeDefinition>> genericDefinitions = createGenerics(signatureScope, fn->generics);
std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks); std::vector<std::pair<Name, GenericTypePackDefinition>> genericPackDefinitions = createGenericPacks(signatureScope, fn->genericPacks);

View File

@ -373,7 +373,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Con
else if (isBlocked(c.superType)) else if (isBlocked(c.superType))
return block(c.superType, constraint); return block(c.superType, constraint);
unify(c.subType, c.superType); unify(c.subType, c.superType, constraint->scope);
unblock(c.subType); unblock(c.subType);
unblock(c.superType); unblock(c.superType);
@ -383,7 +383,7 @@ bool ConstraintSolver::tryDispatch(const SubtypeConstraint& c, NotNull<const Con
bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force) bool ConstraintSolver::tryDispatch(const PackSubtypeConstraint& c, NotNull<const Constraint> constraint, bool force)
{ {
unify(c.subPack, c.superPack); unify(c.subPack, c.superPack, constraint->scope);
unblock(c.subPack); unblock(c.subPack);
unblock(c.superPack); unblock(c.superPack);
@ -398,9 +398,9 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNull<co
if (isBlocked(c.generalizedType)) if (isBlocked(c.generalizedType))
asMutable(c.generalizedType)->ty.emplace<BoundTypeVar>(c.sourceType); asMutable(c.generalizedType)->ty.emplace<BoundTypeVar>(c.sourceType);
else else
unify(c.generalizedType, c.sourceType); unify(c.generalizedType, c.sourceType, constraint->scope);
TypeId generalized = quantify(arena, c.sourceType, c.scope); TypeId generalized = quantify(arena, c.sourceType, constraint->scope);
*asMutable(c.sourceType) = *generalized; *asMutable(c.sourceType) = *generalized;
unblock(c.generalizedType); unblock(c.generalizedType);
@ -422,7 +422,7 @@ bool ConstraintSolver::tryDispatch(const InstantiationConstraint& c, NotNull<con
if (isBlocked(c.subType)) if (isBlocked(c.subType))
asMutable(c.subType)->ty.emplace<BoundTypeVar>(*instantiated); asMutable(c.subType)->ty.emplace<BoundTypeVar>(*instantiated);
else else
unify(c.subType, *instantiated); unify(c.subType, *instantiated, constraint->scope);
unblock(c.subType); unblock(c.subType);
@ -465,7 +465,7 @@ bool ConstraintSolver::tryDispatch(const BinaryConstraint& c, NotNull<const Cons
if (isNumber(leftType)) if (isNumber(leftType))
{ {
unify(leftType, rightType); unify(leftType, rightType, constraint->scope);
asMutable(c.resultType)->ty.emplace<BoundTypeVar>(leftType); asMutable(c.resultType)->ty.emplace<BoundTypeVar>(leftType);
return true; return true;
} }
@ -528,16 +528,18 @@ struct InstantiationQueuer : TypeVarOnceVisitor
{ {
ConstraintSolver* solver; ConstraintSolver* solver;
const InstantiationSignature& signature; const InstantiationSignature& signature;
NotNull<Scope> scope;
explicit InstantiationQueuer(ConstraintSolver* solver, const InstantiationSignature& signature) explicit InstantiationQueuer(ConstraintSolver* solver, const InstantiationSignature& signature, NotNull<Scope> scope)
: solver(solver) : solver(solver)
, signature(signature) , signature(signature)
, scope(scope)
{ {
} }
bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override bool visit(TypeId ty, const PendingExpansionTypeVar& petv) override
{ {
solver->pushConstraint(TypeAliasExpansionConstraint{ty}); solver->pushConstraint(TypeAliasExpansionConstraint{ty}, scope);
return false; return false;
} }
}; };
@ -686,7 +688,7 @@ bool ConstraintSolver::tryDispatch(const TypeAliasExpansionConstraint& c, NotNul
// The application is not recursive, so we need to queue up application of // The application is not recursive, so we need to queue up application of
// any child type function instantiations within the result in order for it // any child type function instantiations within the result in order for it
// to be complete. // to be complete.
InstantiationQueuer queuer{this, signature}; InstantiationQueuer queuer{this, signature, constraint->scope};
queuer.traverse(target); queuer.traverse(target);
instantiatedAliases[signature] = target; instantiatedAliases[signature] = target;
@ -766,30 +768,40 @@ bool ConstraintSolver::isBlocked(NotNull<const Constraint> constraint)
return blockedIt != blockedConstraints.end() && blockedIt->second > 0; return blockedIt != blockedConstraints.end() && blockedIt->second > 0;
} }
void ConstraintSolver::unify(TypeId subType, TypeId superType) void ConstraintSolver::unify(TypeId subType, TypeId superType, NotNull<Scope> scope)
{ {
UnifierSharedState sharedState{&iceReporter}; UnifierSharedState sharedState{&iceReporter};
Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; Unifier u{arena, Mode::Strict, scope, Location{}, Covariant, sharedState};
u.tryUnify(subType, superType); u.tryUnify(subType, superType);
u.log.commit(); u.log.commit();
} }
void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack) void ConstraintSolver::unify(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope)
{ {
UnifierSharedState sharedState{&iceReporter}; UnifierSharedState sharedState{&iceReporter};
Unifier u{arena, Mode::Strict, Location{}, Covariant, sharedState}; Unifier u{arena, Mode::Strict, scope, Location{}, Covariant, sharedState};
u.tryUnify(subPack, superPack); u.tryUnify(subPack, superPack);
u.log.commit(); u.log.commit();
} }
void ConstraintSolver::pushConstraint(ConstraintV cv) void ConstraintSolver::pushConstraint(ConstraintV cv, NotNull<Scope> scope)
{ {
std::unique_ptr<Constraint> c = std::make_unique<Constraint>(std::move(cv)); std::unique_ptr<Constraint> c = std::make_unique<Constraint>(std::move(cv), scope);
NotNull<Constraint> borrow = NotNull(c.get()); NotNull<Constraint> borrow = NotNull(c.get());
solverConstraints.push_back(std::move(c)); solverConstraints.push_back(std::move(c));
unsolvedConstraints.push_back(borrow); unsolvedConstraints.push_back(borrow);
} }
void ConstraintSolver::reportError(TypeErrorData&& data, const Location& location)
{
errors.emplace_back(location, std::move(data));
}
void ConstraintSolver::reportError(TypeError e)
{
errors.emplace_back(std::move(e));
}
} // namespace Luau } // namespace Luau

View File

@ -839,6 +839,9 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, const Sco
ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)}; ConstraintSolver cs{&result->internalTypes, NotNull(cgb.rootScope)};
cs.run(); cs.run();
for (TypeError& e : cs.errors)
result->errors.emplace_back(std::move(e));
result->scopes = std::move(cgb.scopes); result->scopes = std::move(cgb.scopes);
result->astTypes = std::move(cgb.astTypes); result->astTypes = std::move(cgb.astTypes);
result->astTypePacks = std::move(cgb.astTypePacks); result->astTypePacks = std::move(cgb.astTypePacks);

View File

@ -244,12 +244,12 @@ void Module::clonePublicInterface(InternalErrorReporter& ice)
if (FFlag::LuauLowerBoundsCalculation) if (FFlag::LuauLowerBoundsCalculation)
{ {
normalize(returnType, interfaceTypes, ice); normalize(returnType, NotNull{this}, ice);
if (FFlag::LuauForceExportSurfacesToBeNormal) if (FFlag::LuauForceExportSurfacesToBeNormal)
forceNormal.traverse(returnType); forceNormal.traverse(returnType);
if (varargPack) if (varargPack)
{ {
normalize(*varargPack, interfaceTypes, ice); normalize(*varargPack, NotNull{this}, ice);
if (FFlag::LuauForceExportSurfacesToBeNormal) if (FFlag::LuauForceExportSurfacesToBeNormal)
forceNormal.traverse(*varargPack); forceNormal.traverse(*varargPack);
} }
@ -265,7 +265,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice)
tf = clone(tf, interfaceTypes, cloneState); tf = clone(tf, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation) if (FFlag::LuauLowerBoundsCalculation)
{ {
normalize(tf.type, interfaceTypes, ice); normalize(tf.type, NotNull{this}, ice);
// We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables // We're about to freeze the memory. We know that the flag is conservative by design. Cyclic tables
// won't be marked normal. If the types aren't normal by now, they never will be. // won't be marked normal. If the types aren't normal by now, they never will be.
@ -276,7 +276,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice)
if (param.defaultValue) if (param.defaultValue)
{ {
normalize(*param.defaultValue, interfaceTypes, ice); normalize(*param.defaultValue, NotNull{this}, ice);
forceNormal.traverse(*param.defaultValue); forceNormal.traverse(*param.defaultValue);
} }
} }
@ -302,7 +302,7 @@ void Module::clonePublicInterface(InternalErrorReporter& ice)
ty = clone(ty, interfaceTypes, cloneState); ty = clone(ty, interfaceTypes, cloneState);
if (FFlag::LuauLowerBoundsCalculation) if (FFlag::LuauLowerBoundsCalculation)
{ {
normalize(ty, interfaceTypes, ice); normalize(ty, NotNull{this}, ice);
if (FFlag::LuauForceExportSurfacesToBeNormal) if (FFlag::LuauForceExportSurfacesToBeNormal)
forceNormal.traverse(ty); forceNormal.traverse(ty);

View File

@ -15,7 +15,6 @@ LUAU_FASTINTVARIABLE(LuauNormalizeIterationLimit, 1200);
LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false); LUAU_FASTFLAGVARIABLE(LuauNormalizeCombineTableFix, false);
LUAU_FASTFLAGVARIABLE(LuauFixNormalizationOfCyclicUnions, false); LUAU_FASTFLAGVARIABLE(LuauFixNormalizationOfCyclicUnions, false);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauQuantifyConstrained)
namespace Luau namespace Luau
{ {
@ -55,11 +54,11 @@ struct Replacer
} // anonymous namespace } // anonymous namespace
bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice) bool isSubtype(TypeId subTy, TypeId superTy, NotNull<Scope> scope, InternalErrorReporter& ice)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; Unifier u{&arena, Mode::Strict, scope, Location{}, Covariant, sharedState};
u.anyIsTop = true; u.anyIsTop = true;
u.tryUnify(subTy, superTy); u.tryUnify(subTy, superTy);
@ -67,11 +66,11 @@ bool isSubtype(TypeId subTy, TypeId superTy, InternalErrorReporter& ice)
return ok; return ok;
} }
bool isSubtype(TypePackId subPack, TypePackId superPack, InternalErrorReporter& ice) bool isSubtype(TypePackId subPack, TypePackId superPack, NotNull<Scope> scope, InternalErrorReporter& ice)
{ {
UnifierSharedState sharedState{&ice}; UnifierSharedState sharedState{&ice};
TypeArena arena; TypeArena arena;
Unifier u{&arena, Mode::Strict, Location{}, Covariant, sharedState}; Unifier u{&arena, Mode::Strict, scope, Location{}, Covariant, sharedState};
u.anyIsTop = true; u.anyIsTop = true;
u.tryUnify(subPack, superPack); u.tryUnify(subPack, superPack);
@ -134,13 +133,15 @@ struct Normalize final : TypeVarVisitor
{ {
using TypeVarVisitor::Set; using TypeVarVisitor::Set;
Normalize(TypeArena& arena, InternalErrorReporter& ice) Normalize(TypeArena& arena, NotNull<Scope> scope, InternalErrorReporter& ice)
: arena(arena) : arena(arena)
, scope(scope)
, ice(ice) , ice(ice)
{ {
} }
TypeArena& arena; TypeArena& arena;
NotNull<Scope> scope;
InternalErrorReporter& ice; InternalErrorReporter& ice;
int iterationLimit = 0; int iterationLimit = 0;
@ -215,22 +216,7 @@ struct Normalize final : TypeVarVisitor
traverse(part); traverse(part);
std::vector<TypeId> newParts = normalizeUnion(parts); std::vector<TypeId> newParts = normalizeUnion(parts);
ctv->parts = std::move(newParts);
if (FFlag::LuauQuantifyConstrained)
{
ctv->parts = std::move(newParts);
}
else
{
const bool normal = areNormal(newParts, seen, ice);
if (newParts.size() == 1)
*asMutable(ty) = BoundTypeVar{newParts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(newParts)};
asMutable(ty)->normal = normal;
}
return false; return false;
} }
@ -288,12 +274,7 @@ struct Normalize final : TypeVarVisitor
} }
// An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal. // An unsealed table can never be normal, ditto for free tables iff the type it is bound to is also not normal.
if (FFlag::LuauQuantifyConstrained) if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal))
{
if (ttv.state == TableState::Generic || ttv.state == TableState::Sealed || (ttv.state == TableState::Free && follow(ty)->normal))
asMutable(ty)->normal = normal;
}
else
asMutable(ty)->normal = normal; asMutable(ty)->normal = normal;
return false; return false;
@ -518,9 +499,9 @@ struct Normalize final : TypeVarVisitor
for (TypeId& part : result) for (TypeId& part : result)
{ {
if (isSubtype(ty, part, ice)) if (isSubtype(ty, part, scope, ice))
return; // no need to do anything return; // no need to do anything
else if (isSubtype(part, ty, ice)) else if (isSubtype(part, ty, scope, ice))
{ {
part = ty; // replace the less general type by the more general one part = ty; // replace the less general type by the more general one
return; return;
@ -572,12 +553,12 @@ struct Normalize final : TypeVarVisitor
bool merged = false; bool merged = false;
for (TypeId& part : result->parts) for (TypeId& part : result->parts)
{ {
if (isSubtype(part, ty, ice)) if (isSubtype(part, ty, scope, ice))
{ {
merged = true; merged = true;
break; // no need to do anything break; // no need to do anything
} }
else if (isSubtype(ty, part, ice)) else if (isSubtype(ty, part, scope, ice))
{ {
merged = true; merged = true;
part = ty; // replace the less general type by the more general one part = ty; // replace the less general type by the more general one
@ -710,13 +691,13 @@ struct Normalize final : TypeVarVisitor
/** /**
* @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully)
*/ */
std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorReporter& ice) std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Scope> scope, TypeArena& arena, InternalErrorReporter& ice)
{ {
CloneState state; CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing) if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(ty, arena, state); (void)clone(ty, arena, state);
Normalize n{arena, ice}; Normalize n{arena, scope, ice};
n.traverse(ty); n.traverse(ty);
return {ty, !n.limitExceeded}; return {ty, !n.limitExceeded};
@ -726,29 +707,39 @@ std::pair<TypeId, bool> normalize(TypeId ty, TypeArena& arena, InternalErrorRepo
// reclaim memory used by wantonly allocated intermediate types here. // reclaim memory used by wantonly allocated intermediate types here.
// The main wrinkle here is that we don't want clone() to copy a type if the source and dest // The main wrinkle here is that we don't want clone() to copy a type if the source and dest
// arena are the same. // arena are the same.
std::pair<TypeId, bool> normalize(TypeId ty, NotNull<Module> module, InternalErrorReporter& ice)
{
return normalize(ty, NotNull{module->getModuleScope().get()}, module->internalTypes, ice);
}
std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice) std::pair<TypeId, bool> normalize(TypeId ty, const ModulePtr& module, InternalErrorReporter& ice)
{ {
return normalize(ty, module->internalTypes, ice); return normalize(ty, NotNull{module.get()}, ice);
} }
/** /**
* @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully) * @returns A tuple of TypeId and a success indicator. (true indicates that the normalization completed successfully)
*/ */
std::pair<TypePackId, bool> normalize(TypePackId tp, TypeArena& arena, InternalErrorReporter& ice) std::pair<TypePackId, bool> normalize(TypePackId tp, NotNull<Scope> scope, TypeArena& arena, InternalErrorReporter& ice)
{ {
CloneState state; CloneState state;
if (FFlag::DebugLuauCopyBeforeNormalizing) if (FFlag::DebugLuauCopyBeforeNormalizing)
(void)clone(tp, arena, state); (void)clone(tp, arena, state);
Normalize n{arena, ice}; Normalize n{arena, scope, ice};
n.traverse(tp); n.traverse(tp);
return {tp, !n.limitExceeded}; return {tp, !n.limitExceeded};
} }
std::pair<TypePackId, bool> normalize(TypePackId tp, NotNull<Module> module, InternalErrorReporter& ice)
{
return normalize(tp, NotNull{module->getModuleScope().get()}, module->internalTypes, ice);
}
std::pair<TypePackId, bool> normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice) std::pair<TypePackId, bool> normalize(TypePackId tp, const ModulePtr& module, InternalErrorReporter& ice)
{ {
return normalize(tp, module->internalTypes, ice); return normalize(tp, NotNull{module.get()}, ice);
} }
} // namespace Luau } // namespace Luau

View File

@ -10,7 +10,6 @@
LUAU_FASTFLAG(DebugLuauSharedSelf) LUAU_FASTFLAG(DebugLuauSharedSelf)
LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution); LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution);
LUAU_FASTFLAGVARIABLE(LuauQuantifyConstrained, false)
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
namespace Luau namespace Luau
@ -82,30 +81,25 @@ struct Quantifier final : TypeVarOnceVisitor
bool visit(TypeId ty, const ConstrainedTypeVar&) override bool visit(TypeId ty, const ConstrainedTypeVar&) override
{ {
if (FFlag::LuauQuantifyConstrained) ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty);
{
ConstrainedTypeVar* ctv = getMutable<ConstrainedTypeVar>(ty);
seenMutableType = true; seenMutableType = true;
if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level))
return false;
std::vector<TypeId> opts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic
for (TypeId opt : opts)
traverse(opt);
if (opts.size() == 1)
*asMutable(ty) = BoundTypeVar{opts[0]};
else
*asMutable(ty) = UnionTypeVar{std::move(opts)};
if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ctv->scope) : !level.subsumes(ctv->level))
return false; return false;
}
std::vector<TypeId> opts = std::move(ctv->parts);
// We might transmute, so it's not safe to rely on the builtin traversal logic
for (TypeId opt : opts)
traverse(opt);
if (opts.size() == 1)
*asMutable(ty) = BoundTypeVar{opts[0]};
else else
return true; *asMutable(ty) = UnionTypeVar{std::move(opts)};
return false;
} }
bool visit(TypeId ty, const TableTypeVar&) override bool visit(TypeId ty, const TableTypeVar&) override
@ -119,12 +113,6 @@ struct Quantifier final : TypeVarOnceVisitor
if (ttv.state == TableState::Free) if (ttv.state == TableState::Free)
seenMutableType = true; seenMutableType = true;
if (!FFlag::LuauQuantifyConstrained)
{
if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic)
return false;
}
if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level)) if (FFlag::DebugLuauDeferredConstraintResolution ? !subsumes(scope, ttv.scope) : !level.subsumes(ttv.level))
{ {
if (ttv.state == TableState::Unsealed) if (ttv.state == TableState::Unsealed)

View File

@ -8,22 +8,59 @@
#include "Luau/Clone.h" #include "Luau/Clone.h"
#include "Luau/Instantiation.h" #include "Luau/Instantiation.h"
#include "Luau/Normalize.h" #include "Luau/Normalize.h"
#include "Luau/ToString.h"
#include "Luau/TxnLog.h" #include "Luau/TxnLog.h"
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/TypeVar.h" #include "Luau/TypeVar.h"
#include "Luau/Unifier.h" #include "Luau/Unifier.h"
#include "Luau/ToString.h"
namespace Luau namespace Luau
{ {
struct TypeChecker2 : public AstVisitor /* Push a scope onto the end of a stack for the lifetime of the StackPusher instance.
* TypeChecker2 uses this to maintain knowledge about which scope encloses every
* given AstNode.
*/
struct StackPusher
{
std::vector<NotNull<Scope>>* stack;
NotNull<Scope> scope;
explicit StackPusher(std::vector<NotNull<Scope>>& stack, Scope* scope)
: stack(&stack)
, scope(scope)
{
stack.push_back(NotNull{scope});
}
~StackPusher()
{
if (stack)
{
LUAU_ASSERT(stack->back() == scope);
stack->pop_back();
}
}
StackPusher(const StackPusher&) = delete;
StackPusher&& operator=(const StackPusher&) = delete;
StackPusher(StackPusher&& other)
: stack(std::exchange(other.stack, nullptr))
, scope(other.scope)
{
}
};
struct TypeChecker2
{ {
const SourceModule* sourceModule; const SourceModule* sourceModule;
Module* module; Module* module;
InternalErrorReporter ice; // FIXME accept a pointer from Frontend InternalErrorReporter ice; // FIXME accept a pointer from Frontend
SingletonTypes& singletonTypes; SingletonTypes& singletonTypes;
std::vector<NotNull<Scope>> stack;
TypeChecker2(const SourceModule* sourceModule, Module* module) TypeChecker2(const SourceModule* sourceModule, Module* module)
: sourceModule(sourceModule) : sourceModule(sourceModule)
, module(module) , module(module)
@ -31,7 +68,13 @@ struct TypeChecker2 : public AstVisitor
{ {
} }
using AstVisitor::visit; std::optional<StackPusher> pushStack(AstNode* node)
{
if (Scope** scope = module->astScopes.find(node))
return StackPusher{stack, *scope};
else
return std::nullopt;
}
TypePackId lookupPack(AstExpr* expr) TypePackId lookupPack(AstExpr* expr)
{ {
@ -118,11 +161,128 @@ struct TypeChecker2 : public AstVisitor
return bestScope; return bestScope;
} }
bool visit(AstStatLocal* local) override void visit(AstStat* stat)
{
auto pusher = pushStack(stat);
if (0)
{}
else if (auto s = stat->as<AstStatBlock>())
return visit(s);
else if (auto s = stat->as<AstStatIf>())
return visit(s);
else if (auto s = stat->as<AstStatWhile>())
return visit(s);
else if (auto s = stat->as<AstStatRepeat>())
return visit(s);
else if (auto s = stat->as<AstStatBreak>())
return visit(s);
else if (auto s = stat->as<AstStatContinue>())
return visit(s);
else if (auto s = stat->as<AstStatReturn>())
return visit(s);
else if (auto s = stat->as<AstStatExpr>())
return visit(s);
else if (auto s = stat->as<AstStatLocal>())
return visit(s);
else if (auto s = stat->as<AstStatFor>())
return visit(s);
else if (auto s = stat->as<AstStatForIn>())
return visit(s);
else if (auto s = stat->as<AstStatAssign>())
return visit(s);
else if (auto s = stat->as<AstStatCompoundAssign>())
return visit(s);
else if (auto s = stat->as<AstStatFunction>())
return visit(s);
else if (auto s = stat->as<AstStatLocalFunction>())
return visit(s);
else if (auto s = stat->as<AstStatTypeAlias>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareFunction>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareGlobal>())
return visit(s);
else if (auto s = stat->as<AstStatDeclareClass>())
return visit(s);
else if (auto s = stat->as<AstStatError>())
return visit(s);
else
LUAU_ASSERT(!"TypeChecker2 encountered an unknown node type");
}
void visit(AstStatBlock* block)
{
auto StackPusher = pushStack(block);
for (AstStat* statement : block->body)
visit(statement);
}
void visit(AstStatIf* ifStatement)
{
visit(ifStatement->condition);
visit(ifStatement->thenbody);
if (ifStatement->elsebody)
visit(ifStatement->elsebody);
}
void visit(AstStatWhile* whileStatement)
{
visit(whileStatement->condition);
visit(whileStatement->body);
}
void visit(AstStatRepeat* repeatStatement)
{
visit(repeatStatement->body);
visit(repeatStatement->condition);
}
void visit(AstStatBreak*)
{}
void visit(AstStatContinue*)
{}
void visit(AstStatReturn* ret)
{
Scope* scope = findInnermostScope(ret->location);
TypePackId expectedRetType = scope->returnType;
TypeArena arena;
TypePackId actualRetType = reconstructPack(ret->list, arena);
UnifierSharedState sharedState{&ice};
Unifier u{&arena, Mode::Strict, stack.back(), ret->location, Covariant, sharedState};
u.anyIsTop = true;
u.tryUnify(actualRetType, expectedRetType);
const bool ok = u.errors.empty() && u.log.empty();
if (!ok)
{
for (const TypeError& e : u.errors)
reportError(e);
}
for (AstExpr* expr : ret->list)
visit(expr);
}
void visit(AstStatExpr* expr)
{
visit(expr->expr);
}
void visit(AstStatLocal* local)
{ {
for (size_t i = 0; i < local->values.size; ++i) for (size_t i = 0; i < local->values.size; ++i)
{ {
AstExpr* value = local->values.data[i]; AstExpr* value = local->values.data[i];
visit(value);
if (i == local->values.size - 1) if (i == local->values.size - 1)
{ {
if (i < local->values.size) if (i < local->values.size)
@ -140,7 +300,7 @@ struct TypeChecker2 : public AstVisitor
if (var->annotation) if (var->annotation)
{ {
TypeId varType = lookupAnnotation(var->annotation); TypeId varType = lookupAnnotation(var->annotation);
if (!isSubtype(*it, varType, ice)) if (!isSubtype(*it, varType, stack.back(), ice))
{ {
reportError(TypeMismatch{varType, *it}, value->location); reportError(TypeMismatch{varType, *it}, value->location);
} }
@ -158,64 +318,244 @@ struct TypeChecker2 : public AstVisitor
if (var->annotation) if (var->annotation)
{ {
TypeId varType = lookupAnnotation(var->annotation); TypeId varType = lookupAnnotation(var->annotation);
if (!isSubtype(varType, valueType, ice)) if (!isSubtype(varType, valueType, stack.back(), ice))
{ {
reportError(TypeMismatch{varType, valueType}, value->location); reportError(TypeMismatch{varType, valueType}, value->location);
} }
} }
} }
} }
return true;
} }
bool visit(AstStatAssign* assign) override void visit(AstStatFor* forStatement)
{
if (forStatement->var->annotation)
visit(forStatement->var->annotation);
visit(forStatement->from);
visit(forStatement->to);
if (forStatement->step)
visit(forStatement->step);
visit(forStatement->body);
}
void visit(AstStatForIn* forInStatement)
{
for (AstLocal* local : forInStatement->vars)
{
if (local->annotation)
visit(local->annotation);
}
for (AstExpr* expr : forInStatement->values)
visit(expr);
visit(forInStatement->body);
}
void visit(AstStatAssign* assign)
{ {
size_t count = std::min(assign->vars.size, assign->values.size); size_t count = std::min(assign->vars.size, assign->values.size);
for (size_t i = 0; i < count; ++i) for (size_t i = 0; i < count; ++i)
{ {
AstExpr* lhs = assign->vars.data[i]; AstExpr* lhs = assign->vars.data[i];
visit(lhs);
TypeId lhsType = lookupType(lhs); TypeId lhsType = lookupType(lhs);
AstExpr* rhs = assign->values.data[i]; AstExpr* rhs = assign->values.data[i];
visit(rhs);
TypeId rhsType = lookupType(rhs); TypeId rhsType = lookupType(rhs);
if (!isSubtype(rhsType, lhsType, ice)) if (!isSubtype(rhsType, lhsType, stack.back(), ice))
{ {
reportError(TypeMismatch{lhsType, rhsType}, rhs->location); reportError(TypeMismatch{lhsType, rhsType}, rhs->location);
} }
} }
return true;
} }
bool visit(AstStatReturn* ret) override void visit(AstStatCompoundAssign* stat)
{ {
Scope* scope = findInnermostScope(ret->location); visit(stat->var);
TypePackId expectedRetType = scope->returnType; visit(stat->value);
}
TypeArena arena; void visit(AstStatFunction* stat)
TypePackId actualRetType = reconstructPack(ret->list, arena); {
visit(stat->name);
visit(stat->func);
}
UnifierSharedState sharedState{&ice}; void visit(AstStatLocalFunction* stat)
Unifier u{&arena, Mode::Strict, ret->location, Covariant, sharedState}; {
u.anyIsTop = true; visit(stat->func);
}
u.tryUnify(actualRetType, expectedRetType); void visit(const AstTypeList* typeList)
const bool ok = u.errors.empty() && u.log.empty(); {
for (AstType* ty : typeList->types)
visit(ty);
if (!ok) if (typeList->tailType)
visit(typeList->tailType);
}
void visit(AstStatTypeAlias* stat)
{
for (const AstGenericType& el : stat->generics)
{ {
for (const TypeError& e : u.errors) if (el.defaultValue)
reportError(e); visit(el.defaultValue);
} }
return true; for (const AstGenericTypePack& el : stat->genericPacks)
{
if (el.defaultValue)
visit(el.defaultValue);
}
visit(stat->type);
} }
bool visit(AstExprCall* call) override void visit(AstTypeList types)
{ {
for (AstType* type : types.types)
visit(type);
if (types.tailType)
visit(types.tailType);
}
void visit(AstStatDeclareFunction* stat)
{
visit(stat->params);
visit(stat->retTypes);
}
void visit(AstStatDeclareGlobal* stat)
{
visit(stat->type);
}
void visit(AstStatDeclareClass* stat)
{
for (const AstDeclaredClassProp& prop : stat->props)
visit(prop.ty);
}
void visit(AstStatError* stat)
{
for (AstExpr* expr : stat->expressions)
visit(expr);
for (AstStat* s : stat->statements)
visit(s);
}
void visit(AstExpr* expr)
{
auto StackPusher = pushStack(expr);
if (0)
{}
else if (auto e = expr->as<AstExprGroup>())
return visit(e);
else if (auto e = expr->as<AstExprConstantNil>())
return visit(e);
else if (auto e = expr->as<AstExprConstantBool>())
return visit(e);
else if (auto e = expr->as<AstExprConstantNumber>())
return visit(e);
else if (auto e = expr->as<AstExprConstantString>())
return visit(e);
else if (auto e = expr->as<AstExprLocal>())
return visit(e);
else if (auto e = expr->as<AstExprGlobal>())
return visit(e);
else if (auto e = expr->as<AstExprVarargs>())
return visit(e);
else if (auto e = expr->as<AstExprCall>())
return visit(e);
else if (auto e = expr->as<AstExprIndexName>())
return visit(e);
else if (auto e = expr->as<AstExprIndexExpr>())
return visit(e);
else if (auto e = expr->as<AstExprFunction>())
return visit(e);
else if (auto e = expr->as<AstExprTable>())
return visit(e);
else if (auto e = expr->as<AstExprUnary>())
return visit(e);
else if (auto e = expr->as<AstExprBinary>())
return visit(e);
else if (auto e = expr->as<AstExprTypeAssertion>())
return visit(e);
else if (auto e = expr->as<AstExprIfElse>())
return visit(e);
else if (auto e = expr->as<AstExprError>())
return visit(e);
else
LUAU_ASSERT(!"TypeChecker2 encountered an unknown expression type");
}
void visit(AstExprGroup* expr)
{
visit(expr->expr);
}
void visit(AstExprConstantNil* expr)
{
// TODO!
}
void visit(AstExprConstantBool* expr)
{
// TODO!
}
void visit(AstExprConstantNumber* number)
{
TypeId actualType = lookupType(number);
TypeId numberType = getSingletonTypes().numberType;
if (!isSubtype(numberType, actualType, stack.back(), ice))
{
reportError(TypeMismatch{actualType, numberType}, number->location);
}
}
void visit(AstExprConstantString* string)
{
TypeId actualType = lookupType(string);
TypeId stringType = getSingletonTypes().stringType;
if (!isSubtype(stringType, actualType, stack.back(), ice))
{
reportError(TypeMismatch{actualType, stringType}, string->location);
}
}
void visit(AstExprLocal* expr)
{
// TODO!
}
void visit(AstExprGlobal* expr)
{
// TODO!
}
void visit(AstExprVarargs* expr)
{
// TODO!
}
void visit(AstExprCall* call)
{
visit(call->func);
for (AstExpr* arg : call->args)
visit(arg);
TypeArena arena; TypeArena arena;
Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}}; Instantiation instantiation{TxnLog::empty(), &arena, TypeLevel{}};
@ -225,7 +565,7 @@ struct TypeChecker2 : public AstVisitor
LUAU_ASSERT(functionType); LUAU_ASSERT(functionType);
TypePack args; TypePack args;
for (const auto& arg : call->args) for (AstExpr* arg : call->args)
{ {
TypeId argTy = module->astTypes[arg]; TypeId argTy = module->astTypes[arg];
LUAU_ASSERT(argTy); LUAU_ASSERT(argTy);
@ -235,7 +575,7 @@ struct TypeChecker2 : public AstVisitor
TypePackId argsTp = arena.addTypePack(args); TypePackId argsTp = arena.addTypePack(args);
FunctionTypeVar ftv{argsTp, expectedRetType}; FunctionTypeVar ftv{argsTp, expectedRetType};
TypeId expectedType = arena.addType(ftv); TypeId expectedType = arena.addType(ftv);
if (!isSubtype(expectedType, instantiatedFunctionType, ice)) if (!isSubtype(expectedType, instantiatedFunctionType, stack.back(), ice))
{ {
unfreeze(module->interfaceTypes); unfreeze(module->interfaceTypes);
CloneState cloneState; CloneState cloneState;
@ -243,12 +583,36 @@ struct TypeChecker2 : public AstVisitor
freeze(module->interfaceTypes); freeze(module->interfaceTypes);
reportError(TypeMismatch{expectedType, functionType}, call->location); reportError(TypeMismatch{expectedType, functionType}, call->location);
} }
return true;
} }
bool visit(AstExprFunction* fn) override void visit(AstExprIndexName* indexName)
{ {
TypeId leftType = lookupType(indexName->expr);
TypeId resultType = lookupType(indexName);
// leftType must have a property called indexName->index
std::optional<TypeId> ty = getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true);
if (ty)
{
if (!isSubtype(resultType, *ty, stack.back(), ice))
{
reportError(TypeMismatch{resultType, *ty}, indexName->location);
}
}
}
void visit(AstExprIndexExpr* indexExpr)
{
// TODO!
visit(indexExpr->expr);
visit(indexExpr->index);
}
void visit(AstExprFunction* fn)
{
auto StackPusher = pushStack(fn);
TypeId inferredFnTy = lookupType(fn); TypeId inferredFnTy = lookupType(fn);
const FunctionTypeVar* inferredFtv = get<FunctionTypeVar>(inferredFnTy); const FunctionTypeVar* inferredFtv = get<FunctionTypeVar>(inferredFnTy);
LUAU_ASSERT(inferredFtv); LUAU_ASSERT(inferredFtv);
@ -264,7 +628,7 @@ struct TypeChecker2 : public AstVisitor
TypeId inferredArgTy = *argIt; TypeId inferredArgTy = *argIt;
TypeId annotatedArgTy = lookupAnnotation(arg->annotation); TypeId annotatedArgTy = lookupAnnotation(arg->annotation);
if (!isSubtype(annotatedArgTy, inferredArgTy, ice)) if (!isSubtype(annotatedArgTy, inferredArgTy, stack.back(), ice))
{ {
reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location); reportError(TypeMismatch{annotatedArgTy, inferredArgTy}, arg->location);
} }
@ -273,68 +637,64 @@ struct TypeChecker2 : public AstVisitor
++argIt; ++argIt;
} }
return true; visit(fn->body);
} }
bool visit(AstExprIndexName* indexName) override void visit(AstExprTable* expr)
{ {
TypeId leftType = lookupType(indexName->expr); // TODO!
TypeId resultType = lookupType(indexName); for (const AstExprTable::Item& item : expr->items)
// leftType must have a property called indexName->index
std::optional<TypeId> ty = getIndexTypeFromType(module->getModuleScope(), leftType, indexName->index.value, indexName->location, /* addErrors */ true);
if (ty)
{ {
if (!isSubtype(resultType, *ty, ice)) if (item.key)
{ visit(item.key);
reportError(TypeMismatch{resultType, *ty}, indexName->location); visit(item.value);
}
} }
return true;
} }
bool visit(AstExprConstantNumber* number) override void visit(AstExprUnary* expr)
{ {
TypeId actualType = lookupType(number); // TODO!
TypeId numberType = getSingletonTypes().numberType; visit(expr->expr);
if (!isSubtype(numberType, actualType, ice))
{
reportError(TypeMismatch{actualType, numberType}, number->location);
}
return true;
} }
bool visit(AstExprConstantString* string) override void visit(AstExprBinary* expr)
{ {
TypeId actualType = lookupType(string); // TODO!
TypeId stringType = getSingletonTypes().stringType; visit(expr->left);
visit(expr->right);
if (!isSubtype(stringType, actualType, ice))
{
reportError(TypeMismatch{actualType, stringType}, string->location);
}
return true;
} }
bool visit(AstExprTypeAssertion* expr) override void visit(AstExprTypeAssertion* expr)
{ {
visit(expr->expr);
visit(expr->annotation);
TypeId annotationType = lookupAnnotation(expr->annotation); TypeId annotationType = lookupAnnotation(expr->annotation);
TypeId computedType = lookupType(expr->expr); TypeId computedType = lookupType(expr->expr);
// Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
if (isSubtype(annotationType, computedType, ice)) if (isSubtype(annotationType, computedType, stack.back(), ice))
return true; return;
if (isSubtype(computedType, annotationType, ice)) if (isSubtype(computedType, annotationType, stack.back(), ice))
return true; return;
reportError(TypesAreUnrelated{computedType, annotationType}, expr->location); reportError(TypesAreUnrelated{computedType, annotationType}, expr->location);
return true; }
void visit(AstExprIfElse* expr)
{
// TODO!
visit(expr->condition);
visit(expr->trueExpr);
visit(expr->falseExpr);
}
void visit(AstExprError* expr)
{
// TODO!
for (AstExpr* e : expr->expressions)
visit(e);
} }
/** Extract a TypeId for the first type of the provided pack. /** Extract a TypeId for the first type of the provided pack.
@ -375,13 +735,32 @@ struct TypeChecker2 : public AstVisitor
ice.ice("flattenPack got a weird pack!"); ice.ice("flattenPack got a weird pack!");
} }
bool visit(AstType* ty) override void visit(AstType* ty)
{ {
return true; if (auto t = ty->as<AstTypeReference>())
return visit(t);
else if (auto t = ty->as<AstTypeTable>())
return visit(t);
else if (auto t = ty->as<AstTypeFunction>())
return visit(t);
else if (auto t = ty->as<AstTypeTypeof>())
return visit(t);
else if (auto t = ty->as<AstTypeUnion>())
return visit(t);
else if (auto t = ty->as<AstTypeIntersection>())
return visit(t);
} }
bool visit(AstTypeReference* ty) override void visit(AstTypeReference* ty)
{ {
for (const AstTypeOrPack& param : ty->parameters)
{
if (param.type)
visit(param.type);
else
visit(param.typePack);
}
Scope* scope = findInnermostScope(ty->location); Scope* scope = findInnermostScope(ty->location);
LUAU_ASSERT(scope); LUAU_ASSERT(scope);
@ -500,16 +879,76 @@ struct TypeChecker2 : public AstVisitor
reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location);
} }
} }
return true;
} }
bool visit(AstTypePack*) override void visit(AstTypeTable* table)
{ {
return true; // TODO!
for (const AstTableProp& prop : table->props)
visit(prop.type);
if (table->indexer)
{
visit(table->indexer->indexType);
visit(table->indexer->resultType);
}
} }
bool visit(AstTypePackGeneric* tp) override void visit(AstTypeFunction* ty)
{
// TODO!
visit(ty->argTypes);
visit(ty->returnTypes);
}
void visit(AstTypeTypeof* ty)
{
visit(ty->expr);
}
void visit(AstTypeUnion* ty)
{
// TODO!
for (AstType* type : ty->types)
visit(type);
}
void visit(AstTypeIntersection* ty)
{
// TODO!
for (AstType* type : ty->types)
visit(type);
}
void visit(AstTypePack* pack)
{
if (auto p = pack->as<AstTypePackExplicit>())
return visit(p);
else if (auto p = pack->as<AstTypePackVariadic>())
return visit(p);
else if (auto p = pack->as<AstTypePackGeneric>())
return visit(p);
}
void visit(AstTypePackExplicit* tp)
{
// TODO!
for (AstType* type : tp->typeList.types)
visit(type);
if (tp->typeList.tailType)
visit(tp->typeList.tailType);
}
void visit(AstTypePackVariadic* tp)
{
// TODO!
visit(tp->variadicType);
}
void visit(AstTypePackGeneric* tp)
{ {
Scope* scope = findInnermostScope(tp->location); Scope* scope = findInnermostScope(tp->location);
LUAU_ASSERT(scope); LUAU_ASSERT(scope);
@ -531,8 +970,6 @@ struct TypeChecker2 : public AstVisitor
reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location); reportError(UnknownSymbol{tp->genericName.value, UnknownSymbol::Context::Type}, tp->location);
} }
} }
return true;
} }
void reportError(TypeErrorData&& data, const Location& location) void reportError(TypeErrorData&& data, const Location& location)
@ -546,139 +983,9 @@ struct TypeChecker2 : public AstVisitor
} }
std::optional<TypeId> getIndexTypeFromType( std::optional<TypeId> getIndexTypeFromType(
const ScopePtr& scope, TypeId type, const Name& name, const Location& location, bool addErrors) const ScopePtr& scope, TypeId type, const std::string& prop, const Location& location, bool addErrors)
{ {
type = follow(type); return Luau::getIndexTypeFromType(scope, module->errors, &module->internalTypes, type, prop, location, addErrors, ice);
if (get<ErrorTypeVar>(type) || get<AnyTypeVar>(type) || get<NeverTypeVar>(type))
return type;
if (auto f = get<FreeTypeVar>(type))
*asMutable(type) = TableTypeVar{TableState::Free, f->level};
if (isString(type))
{
std::optional<TypeId> mtIndex = Luau::findMetatableEntry(module->errors, singletonTypes.stringType, "__index", location);
LUAU_ASSERT(mtIndex);
type = *mtIndex;
}
if (TableTypeVar* tableType = getMutableTableType(type))
{
return findTablePropertyRespectingMeta(module->errors, type, name, location);
}
else if (const ClassTypeVar* cls = get<ClassTypeVar>(type))
{
const Property* prop = lookupClassProp(cls, name);
if (prop)
return prop->type;
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(type))
{
std::vector<TypeId> goodOptions;
std::vector<TypeId> badOptions;
for (TypeId t : utv)
{
// TODO: we should probably limit recursion here?
// RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
// Not needed when we normalize types.
if (get<AnyTypeVar>(follow(t)))
return t;
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false))
goodOptions.push_back(*ty);
else
badOptions.push_back(t);
}
if (!badOptions.empty())
{
if (addErrors)
{
if (goodOptions.empty())
reportError(UnknownProperty{type, name}, location);
else
reportError(MissingUnionProperty{type, badOptions, name}, location);
}
return std::nullopt;
}
std::vector<TypeId> result = reduceUnion(goodOptions);
if (result.empty())
return singletonTypes.neverType;
if (result.size() == 1)
return result[0];
return module->internalTypes.addType(UnionTypeVar{std::move(result)});
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{
std::vector<TypeId> parts;
for (TypeId t : itv->parts)
{
// TODO: we should probably limit recursion here?
// RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, t, name, location, /* addErrors= */ false))
parts.push_back(*ty);
}
// If no parts of the intersection had the property we looked up for, it never existed at all.
if (parts.empty())
{
if (addErrors)
reportError(UnknownProperty{type, name}, location);
return std::nullopt;
}
if (parts.size() == 1)
return parts[0];
return module->internalTypes.addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct.
}
if (addErrors)
reportError(UnknownProperty{type, name}, location);
return std::nullopt;
}
std::vector<TypeId> reduceUnion(const std::vector<TypeId>& types)
{
std::vector<TypeId> result;
for (TypeId t : types)
{
t = follow(t);
if (get<NeverTypeVar>(t))
continue;
if (get<ErrorTypeVar>(t) || get<AnyTypeVar>(t))
return {t};
if (const UnionTypeVar* utv = get<UnionTypeVar>(t))
{
for (TypeId ty : utv)
{
ty = follow(ty);
if (get<NeverTypeVar>(ty))
continue;
if (get<ErrorTypeVar>(ty) || get<AnyTypeVar>(ty))
return {ty};
if (result.end() == std::find(result.begin(), result.end(), ty))
result.push_back(ty);
}
}
else if (std::find(result.begin(), result.end(), t) == result.end())
result.push_back(t);
}
return result;
} }
}; };
@ -686,7 +993,7 @@ void check(const SourceModule& sourceModule, Module* module)
{ {
TypeChecker2 typeChecker{&sourceModule, module}; TypeChecker2 typeChecker{&sourceModule, module};
sourceModule.root->visit(&typeChecker); typeChecker.visit(sourceModule.root);
} }
} // namespace Luau } // namespace Luau

View File

@ -33,15 +33,13 @@ LUAU_FASTINTVARIABLE(LuauVisitRecursionLimit, 500)
LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAG(LuauKnowsTheDataModel3)
LUAU_FASTFLAG(LuauAutocompleteDynamicLimits) LUAU_FASTFLAG(LuauAutocompleteDynamicLimits)
LUAU_FASTFLAGVARIABLE(LuauExpectedTableUnionIndexerType, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTableUnionIndexerType, false)
LUAU_FASTFLAGVARIABLE(LuauInplaceDemoteSkipAllBound, false)
LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false) LUAU_FASTFLAGVARIABLE(LuauLowerBoundsCalculation, false)
LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false)
LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix3, false)
LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false.
LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false); LUAU_FASTFLAGVARIABLE(DebugLuauSharedSelf, false);
LUAU_FASTFLAGVARIABLE(LuauReportErrorsOnIndexerKeyMismatch, false)
LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false) LUAU_FASTFLAGVARIABLE(LuauUnknownAndNeverType, false)
LUAU_FASTFLAG(LuauQuantifyConstrained)
LUAU_FASTFLAGVARIABLE(LuauFalsyPredicateReturnsNilInstead, false)
LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false) LUAU_FASTFLAGVARIABLE(LuauCheckGenericHOFTypes, false)
LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false) LUAU_FASTFLAGVARIABLE(LuauBinaryNeedsExpectedTypesToo, false)
LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false) LUAU_FASTFLAGVARIABLE(LuauNeverTypesAndOperatorsInference, false)
@ -473,7 +471,8 @@ struct InplaceDemoter : TypeVarOnceVisitor
TypeArena* arena; TypeArena* arena;
InplaceDemoter(TypeLevel level, TypeArena* arena) InplaceDemoter(TypeLevel level, TypeArena* arena)
: newLevel(level) : TypeVarOnceVisitor(/* skipBoundTypes= */ FFlag::LuauInplaceDemoteSkipAllBound)
, newLevel(level)
, arena(arena) , arena(arena)
{ {
} }
@ -494,6 +493,7 @@ struct InplaceDemoter : TypeVarOnceVisitor
bool visit(TypeId ty, const BoundTypeVar& btyRef) override bool visit(TypeId ty, const BoundTypeVar& btyRef) override
{ {
LUAU_ASSERT(!FFlag::LuauInplaceDemoteSkipAllBound);
return true; return true;
} }
@ -656,7 +656,7 @@ void TypeChecker::checkBlockWithoutRecursionCheck(const ScopePtr& scope, const A
TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level)); TypeId leftType = follow(checkFunctionName(scope, *fun->name, funScope->level));
unify(funTy, leftType, fun->location); unify(funTy, leftType, scope, fun->location);
} }
else if (auto fun = (*protoIter)->as<AstStatLocalFunction>()) else if (auto fun = (*protoIter)->as<AstStatLocalFunction>())
{ {
@ -768,20 +768,20 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement)
} }
template<typename Id> template<typename Id>
ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const Location& location) ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location)
{ {
Unifier state = mkUnifier(location); Unifier state = mkUnifier(scope, location);
return state.canUnify(subTy, superTy); return state.canUnify(subTy, superTy);
} }
ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const Location& location) ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location)
{ {
return canUnify_(subTy, superTy, location); return canUnify_(subTy, superTy, scope, location);
} }
ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Location& location) ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location)
{ {
return canUnify_(subTy, superTy, location); return canUnify_(subTy, superTy, scope, location);
} }
void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement)
@ -802,9 +802,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement)
checkExpr(repScope, *statement.condition); checkExpr(repScope, *statement.condition);
} }
void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const Location& location) void TypeChecker::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel demotedLevel, const ScopePtr& scope, const Location& location)
{ {
Unifier state = mkUnifier(location); Unifier state = mkUnifier(scope, location);
state.unifyLowerBound(subTy, superTy, demotedLevel); state.unifyLowerBound(subTy, superTy, demotedLevel);
state.log.commit(); state.log.commit();
@ -858,8 +858,6 @@ struct Demoter : Substitution
void demote(std::vector<std::optional<TypeId>>& expectedTypes) void demote(std::vector<std::optional<TypeId>>& expectedTypes)
{ {
if (!FFlag::LuauQuantifyConstrained)
return;
for (std::optional<TypeId>& ty : expectedTypes) for (std::optional<TypeId>& ty : expectedTypes)
{ {
if (ty) if (ty)
@ -897,7 +895,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_)
if (useConstrainedIntersections()) if (useConstrainedIntersections())
{ {
unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), return_.location); unifyLowerBound(retPack, scope->returnType, demoter.demotedLevel(scope->level), scope, return_.location);
return; return;
} }
@ -905,7 +903,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_)
// start typechecking everything across module boundaries. // start typechecking everything across module boundaries.
if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType))
{ {
ErrorVec errors = tryUnify(retPack, scope->returnType, return_.location); ErrorVec errors = tryUnify(retPack, scope->returnType, scope, return_.location);
if (!errors.empty()) if (!errors.empty())
currentModule->getModuleScope()->returnType = addTypePack({anyType}); currentModule->getModuleScope()->returnType = addTypePack({anyType});
@ -913,13 +911,13 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_)
return; return;
} }
unify(retPack, scope->returnType, return_.location, CountMismatch::Context::Return); unify(retPack, scope->returnType, scope, return_.location, CountMismatch::Context::Return);
} }
template<typename Id> template<typename Id>
ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const ScopePtr& scope, const Location& location)
{ {
Unifier state = mkUnifier(location); Unifier state = mkUnifier(scope, location);
if (FFlag::DebugLuauFreezeDuringUnification) if (FFlag::DebugLuauFreezeDuringUnification)
freeze(currentModule->internalTypes); freeze(currentModule->internalTypes);
@ -935,14 +933,14 @@ ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location)
return state.errors; return state.errors;
} }
ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const Location& location) ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location)
{ {
return tryUnify_(subTy, superTy, location); return tryUnify_(subTy, superTy, scope, location);
} }
ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Location& location) ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location)
{ {
return tryUnify_(subTy, superTy, location); return tryUnify_(subTy, superTy, scope, location);
} }
void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign)
@ -1036,9 +1034,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign)
{ {
// In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar.
if (isNonstrictMode() && get<FreeTypeVar>(follow(left)) && !get<FunctionTypeVar>(follow(right))) if (isNonstrictMode() && get<FreeTypeVar>(follow(left)) && !get<FunctionTypeVar>(follow(right)))
unify(anyType, left, loc); unify(anyType, left, scope, loc);
else else
unify(right, left, loc); unify(right, left, scope, loc);
} }
} }
} }
@ -1053,7 +1051,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi
TypeId result = checkBinaryOperation(scope, expr, left, right); TypeId result = checkBinaryOperation(scope, expr, left, right);
unify(result, left, assign.location); unify(result, left, scope, assign.location);
} }
void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local)
@ -1108,7 +1106,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local)
TypePackId valuePack = TypePackId valuePack =
checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type;
Unifier state = mkUnifier(local.location); Unifier state = mkUnifier(scope, local.location);
state.ctx = CountMismatch::Result; state.ctx = CountMismatch::Result;
state.tryUnify(valuePack, variablePack); state.tryUnify(valuePack, variablePack);
reportErrors(state.errors); reportErrors(state.errors);
@ -1184,7 +1182,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr)
TypeId loopVarType = numberType; TypeId loopVarType = numberType;
if (expr.var->annotation) if (expr.var->annotation)
unify(loopVarType, resolveType(scope, *expr.var->annotation), expr.location); unify(loopVarType, resolveType(scope, *expr.var->annotation), scope, expr.location);
loopScope->bindings[expr.var] = {loopVarType, expr.var->location}; loopScope->bindings[expr.var] = {loopVarType, expr.var->location};
@ -1194,11 +1192,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr)
if (!expr.to) if (!expr.to)
ice("Bad AstStatFor has no to expr"); ice("Bad AstStatFor has no to expr");
unify(checkExpr(loopScope, *expr.from).type, loopVarType, expr.from->location); unify(checkExpr(loopScope, *expr.from).type, loopVarType, scope, expr.from->location);
unify(checkExpr(loopScope, *expr.to).type, loopVarType, expr.to->location); unify(checkExpr(loopScope, *expr.to).type, loopVarType, scope, expr.to->location);
if (expr.step) if (expr.step)
unify(checkExpr(loopScope, *expr.step).type, loopVarType, expr.step->location); unify(checkExpr(loopScope, *expr.step).type, loopVarType, scope, expr.step->location);
check(loopScope, *expr.body); check(loopScope, *expr.body);
} }
@ -1251,12 +1249,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
if (get<Unifiable::Free>(callRetPack)) if (get<Unifiable::Free>(callRetPack))
{ {
iterTy = freshType(scope); iterTy = freshType(scope);
unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), forin.location); unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), scope, forin.location);
} }
else if (get<Unifiable::Error>(callRetPack) || !first(callRetPack)) else if (get<Unifiable::Error>(callRetPack) || !first(callRetPack))
{ {
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(errorRecoveryType(scope), var, forin.location); unify(errorRecoveryType(scope), var, scope, forin.location);
return check(loopScope, *forin.body); return check(loopScope, *forin.body);
} }
@ -1277,7 +1275,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
// TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments // TODO: this needs to typecheck all returned values by __iter as if they were for loop arguments
// the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types // the structure of the function makes it difficult to do this especially since we don't have actual expressions, only types
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(anyType, var, forin.location); unify(anyType, var, scope, forin.location);
return check(loopScope, *forin.body); return check(loopScope, *forin.body);
} }
@ -1289,25 +1287,25 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
if (iterTable->indexer) if (iterTable->indexer)
{ {
if (varTypes.size() > 0) if (varTypes.size() > 0)
unify(iterTable->indexer->indexType, varTypes[0], forin.location); unify(iterTable->indexer->indexType, varTypes[0], scope, forin.location);
if (varTypes.size() > 1) if (varTypes.size() > 1)
unify(iterTable->indexer->indexResultType, varTypes[1], forin.location); unify(iterTable->indexer->indexResultType, varTypes[1], scope, forin.location);
for (size_t i = 2; i < varTypes.size(); ++i) for (size_t i = 2; i < varTypes.size(); ++i)
unify(nilType, varTypes[i], forin.location); unify(nilType, varTypes[i], scope, forin.location);
} }
else if (isNonstrictMode()) else if (isNonstrictMode())
{ {
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(anyType, var, forin.location); unify(anyType, var, scope, forin.location);
} }
else else
{ {
TypeId varTy = errorRecoveryType(loopScope); TypeId varTy = errorRecoveryType(loopScope);
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(varTy, var, forin.location); unify(varTy, var, scope, forin.location);
reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"}); reportError(firstValue->location, GenericError{"Cannot iterate over a table without indexer"});
} }
@ -1321,7 +1319,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
TypeId varTy = get<AnyTypeVar>(iterTy) ? anyType : errorRecoveryType(loopScope); TypeId varTy = get<AnyTypeVar>(iterTy) ? anyType : errorRecoveryType(loopScope);
for (TypeId var : varTypes) for (TypeId var : varTypes)
unify(varTy, var, forin.location); unify(varTy, var, scope, forin.location);
if (!get<ErrorTypeVar>(iterTy) && !get<AnyTypeVar>(iterTy) && !get<FreeTypeVar>(iterTy) && !get<NeverTypeVar>(iterTy)) if (!get<ErrorTypeVar>(iterTy) && !get<AnyTypeVar>(iterTy) && !get<FreeTypeVar>(iterTy) && !get<NeverTypeVar>(iterTy))
reportError(firstValue->location, CannotCallNonFunction{iterTy}); reportError(firstValue->location, CannotCallNonFunction{iterTy});
@ -1346,7 +1344,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
argPack = addTypePack(TypePack{}); argPack = addTypePack(TypePack{});
} }
Unifier state = mkUnifier(firstValue->location); Unifier state = mkUnifier(loopScope, firstValue->location);
checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {});
state.log.commit(); state.log.commit();
@ -1365,10 +1363,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin)
AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()};
TypePackId retPack = checkExprPack(scope, exprCall).type; TypePackId retPack = checkExprPack(scope, exprCall).type;
unify(retPack, varPack, forin.location); unify(retPack, varPack, scope, forin.location);
} }
else else
unify(iterFunc->retTypes, varPack, forin.location); unify(iterFunc->retTypes, varPack, scope, forin.location);
check(loopScope, *forin.body); check(loopScope, *forin.body);
} }
@ -1603,7 +1601,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias
TypeId& bindingType = bindingsMap[name].type; TypeId& bindingType = bindingsMap[name].type;
if (unify(ty, bindingType, typealias.location)) if (unify(ty, bindingType, aliasScope, typealias.location))
bindingType = ty; bindingType = ty;
if (FFlag::LuauLowerBoundsCalculation) if (FFlag::LuauLowerBoundsCalculation)
@ -1891,7 +1889,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level; TypeLevel level = FFlag::LuauLowerBoundsCalculation ? ftp->level : scope->level;
TypeId head = freshType(level); TypeId head = freshType(level);
TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}}); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(level)}});
unify(pack, retPack, expr.location); unify(pack, retPack, scope, expr.location);
return {head, std::move(result.predicates)}; return {head, std::move(result.predicates)};
} }
if (get<Unifiable::Error>(retPack)) if (get<Unifiable::Error>(retPack))
@ -1983,20 +1981,15 @@ std::optional<TypeId> TypeChecker::getIndexTypeFromTypeImpl(
else if (auto indexer = tableType->indexer) else if (auto indexer = tableType->indexer)
{ {
// TODO: Property lookup should work with string singletons or unions thereof as the indexer key type. // TODO: Property lookup should work with string singletons or unions thereof as the indexer key type.
ErrorVec errors = tryUnify(stringType, indexer->indexType, location); ErrorVec errors = tryUnify(stringType, indexer->indexType, scope, location);
if (FFlag::LuauReportErrorsOnIndexerKeyMismatch) if (errors.empty())
{
if (errors.empty())
return indexer->indexResultType;
if (addErrors)
reportError(location, UnknownProperty{type, name});
return std::nullopt;
}
else
return indexer->indexResultType; return indexer->indexResultType;
if (addErrors)
reportError(location, UnknownProperty{type, name});
return std::nullopt;
} }
else if (tableType->state == TableState::Free) else if (tableType->state == TableState::Free)
{ {
@ -2228,8 +2221,8 @@ TypeId TypeChecker::checkExprTable(
if (indexer) if (indexer)
{ {
unify(numberType, indexer->indexType, value->location); unify(numberType, indexer->indexType, scope, value->location);
unify(valueType, indexer->indexResultType, value->location); unify(valueType, indexer->indexResultType, scope, value->location);
} }
else else
indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; indexer = TableIndexer{numberType, anyIfNonstrict(valueType)};
@ -2248,13 +2241,13 @@ TypeId TypeChecker::checkExprTable(
if (it != expectedTable->props.end()) if (it != expectedTable->props.end())
{ {
Property expectedProp = it->second; Property expectedProp = it->second;
ErrorVec errors = tryUnify(exprType, expectedProp.type, k->location); ErrorVec errors = tryUnify(exprType, expectedProp.type, scope, k->location);
if (errors.empty()) if (errors.empty())
exprType = expectedProp.type; exprType = expectedProp.type;
} }
else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType)) else if (expectedTable->indexer && maybeString(expectedTable->indexer->indexType))
{ {
ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, scope, k->location);
if (errors.empty()) if (errors.empty())
exprType = expectedTable->indexer->indexResultType; exprType = expectedTable->indexer->indexResultType;
} }
@ -2269,8 +2262,8 @@ TypeId TypeChecker::checkExprTable(
if (indexer) if (indexer)
{ {
unify(keyType, indexer->indexType, k->location); unify(keyType, indexer->indexType, scope, k->location);
unify(valueType, indexer->indexResultType, value->location); unify(valueType, indexer->indexResultType, scope, value->location);
} }
else if (isNonstrictMode()) else if (isNonstrictMode())
{ {
@ -2411,7 +2404,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
TypePackId retTypePack = freshTypePack(scope); TypePackId retTypePack = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true);
state.log.commit(); state.log.commit();
@ -2429,7 +2422,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
return {errorRecoveryType(scope)}; return {errorRecoveryType(scope)};
} }
reportErrors(tryUnify(operandType, numberType, expr.location)); reportErrors(tryUnify(operandType, numberType, scope, expr.location));
return {numberType}; return {numberType};
} }
case AstExprUnary::Len: case AstExprUnary::Len:
@ -2459,7 +2452,7 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
TypePackId retTypePack = addTypePack({numberType}); TypePackId retTypePack = addTypePack({numberType});
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true);
state.log.commit(); state.log.commit();
@ -2509,11 +2502,11 @@ std::string opToMetaTableEntry(const AstExprBinary::Op& op)
} }
} }
TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes) TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const ScopePtr& scope, const Location& location, bool unifyFreeTypes)
{ {
if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b))) if (unifyFreeTypes && (get<FreeTypeVar>(a) || get<FreeTypeVar>(b)))
{ {
if (unify(b, a, location)) if (unify(b, a, scope, location))
return a; return a;
return errorRecoveryType(anyType); return errorRecoveryType(anyType);
@ -2588,7 +2581,7 @@ TypeId TypeChecker::checkRelationalOperation(
{ {
ScopePtr subScope = childScope(scope, subexp->location); ScopePtr subScope = childScope(scope, subexp->location);
resolve(predicates, subScope, true); resolve(predicates, subScope, true);
return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), subScope, expr.location);
} }
} }
@ -2624,7 +2617,7 @@ TypeId TypeChecker::checkRelationalOperation(
* report any problems that might have been surfaced as a result of this step because we might already * report any problems that might have been surfaced as a result of this step because we might already
* have a better, more descriptive error teed up. * have a better, more descriptive error teed up.
*/ */
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
if (!isEquality) if (!isEquality)
{ {
state.tryUnify(rhsType, lhsType); state.tryUnify(rhsType, lhsType);
@ -2703,7 +2696,7 @@ TypeId TypeChecker::checkRelationalOperation(
{ {
if (isEquality) if (isEquality)
{ {
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
state.tryUnify(addTypePack({booleanType}), ftv->retTypes); state.tryUnify(addTypePack({booleanType}), ftv->retTypes);
if (!state.errors.empty()) if (!state.errors.empty())
@ -2755,11 +2748,11 @@ TypeId TypeChecker::checkRelationalOperation(
case AstExprBinary::And: case AstExprBinary::And:
if (lhsIsAny) if (lhsIsAny)
return lhsType; return lhsType;
return unionOfTypes(rhsType, booleanType, expr.location, false); return unionOfTypes(rhsType, booleanType, scope, expr.location, false);
case AstExprBinary::Or: case AstExprBinary::Or:
if (lhsIsAny) if (lhsIsAny)
return lhsType; return lhsType;
return unionOfTypes(lhsType, rhsType, expr.location); return unionOfTypes(lhsType, rhsType, scope, expr.location);
default: default:
LUAU_ASSERT(0); LUAU_ASSERT(0);
ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location); ice(format("checkRelationalOperation called with incorrect binary expression '%s'", toString(expr.op).c_str()), expr.location);
@ -2816,7 +2809,7 @@ TypeId TypeChecker::checkBinaryOperation(
} }
if (get<FreeTypeVar>(rhsType)) if (get<FreeTypeVar>(rhsType))
unify(rhsType, lhsType, expr.location); unify(rhsType, lhsType, scope, expr.location);
if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType))
{ {
@ -2826,7 +2819,7 @@ TypeId TypeChecker::checkBinaryOperation(
TypePackId retTypePack = freshTypePack(scope); TypePackId retTypePack = freshTypePack(scope);
TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack));
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true);
reportErrors(state.errors); reportErrors(state.errors);
@ -2876,8 +2869,8 @@ TypeId TypeChecker::checkBinaryOperation(
switch (expr.op) switch (expr.op)
{ {
case AstExprBinary::Concat: case AstExprBinary::Concat:
reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.left->location)); reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.left->location));
reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.right->location)); reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), scope, expr.right->location));
return stringType; return stringType;
case AstExprBinary::Add: case AstExprBinary::Add:
case AstExprBinary::Sub: case AstExprBinary::Sub:
@ -2885,8 +2878,8 @@ TypeId TypeChecker::checkBinaryOperation(
case AstExprBinary::Div: case AstExprBinary::Div:
case AstExprBinary::Mod: case AstExprBinary::Mod:
case AstExprBinary::Pow: case AstExprBinary::Pow:
reportErrors(tryUnify(lhsType, numberType, expr.left->location)); reportErrors(tryUnify(lhsType, numberType, scope, expr.left->location));
reportErrors(tryUnify(rhsType, numberType, expr.right->location)); reportErrors(tryUnify(rhsType, numberType, scope, expr.right->location));
return numberType; return numberType;
default: default:
// These should have been handled with checkRelationalOperation // These should have been handled with checkRelationalOperation
@ -2961,10 +2954,10 @@ WithPredicate<TypeId> TypeChecker::checkExpr(const ScopePtr& scope, const AstExp
WithPredicate<TypeId> result = checkExpr(scope, *expr.expr, annotationType); WithPredicate<TypeId> result = checkExpr(scope, *expr.expr, annotationType);
// Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case.
if (canUnify(annotationType, result.type, expr.location).empty()) if (canUnify(annotationType, result.type, scope, expr.location).empty())
return {annotationType, std::move(result.predicates)}; return {annotationType, std::move(result.predicates)};
if (canUnify(result.type, annotationType, expr.location).empty()) if (canUnify(result.type, annotationType, scope, expr.location).empty())
return {annotationType, std::move(result.predicates)}; return {annotationType, std::move(result.predicates)};
reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); reportError(expr.location, TypesAreUnrelated{result.type, annotationType});
@ -3101,7 +3094,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
} }
else if (auto indexer = lhsTable->indexer) else if (auto indexer = lhsTable->indexer)
{ {
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
state.tryUnify(stringType, indexer->indexType); state.tryUnify(stringType, indexer->indexType);
TypeId retType = indexer->indexResultType; TypeId retType = indexer->indexResultType;
if (!state.errors.empty()) if (!state.errors.empty())
@ -3213,7 +3206,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex
if (exprTable->indexer) if (exprTable->indexer)
{ {
const TableIndexer& indexer = *exprTable->indexer; const TableIndexer& indexer = *exprTable->indexer;
unify(indexType, indexer.indexType, expr.index->location); unify(indexType, indexer.indexType, scope, expr.index->location);
return indexer.indexResultType; return indexer.indexResultType;
} }
else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free)
@ -3829,7 +3822,7 @@ void TypeChecker::checkArgumentList(
{ {
// The use of unify here is deliberate. We don't want this unification // The use of unify here is deliberate. We don't want this unification
// to be undoable. // to be undoable.
unify(errorRecoveryType(scope), *argIter, state.location); unify(errorRecoveryType(scope), *argIter, scope, state.location);
++argIter; ++argIter;
} }
reportCountMismatchError(); reportCountMismatchError();
@ -3852,7 +3845,7 @@ void TypeChecker::checkArgumentList(
TypeId e = errorRecoveryType(scope); TypeId e = errorRecoveryType(scope);
while (argIter != endIter) while (argIter != endIter)
{ {
unify(e, *argIter, state.location); unify(e, *argIter, scope, state.location);
++argIter; ++argIter;
} }
@ -3869,7 +3862,7 @@ void TypeChecker::checkArgumentList(
if (argIndex < argLocations.size()) if (argIndex < argLocations.size())
location = argLocations[argIndex]; location = argLocations[argIndex];
unify(*argIter, vtp->ty, location); unify(*argIter, vtp->ty, scope, location);
++argIter; ++argIter;
++argIndex; ++argIndex;
} }
@ -3906,7 +3899,7 @@ void TypeChecker::checkArgumentList(
} }
else else
{ {
unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); unifyWithInstantiationIfNeeded(*argIter, *paramIter, scope, state);
++argIter; ++argIter;
++paramIter; ++paramIter;
} }
@ -4114,7 +4107,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
if (get<AnyTypeVar>(fn)) if (get<AnyTypeVar>(fn))
{ {
unify(anyTypePack, argPack, expr.location); unify(anyTypePack, argPack, scope, expr.location);
return {{anyTypePack}}; return {{anyTypePack}};
} }
@ -4160,7 +4153,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
UnifierOptions options; UnifierOptions options;
options.isFunctionCall = true; options.isFunctionCall = true;
unify(r, fn, expr.location, options); unify(r, fn, scope, expr.location, options);
return {{retPack}}; return {{retPack}};
} }
@ -4194,7 +4187,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
if (!ftv) if (!ftv)
{ {
reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}});
unify(errorRecoveryTypePack(scope), retPack, expr.func->location); unify(errorRecoveryTypePack(scope), retPack, scope, expr.func->location);
return {{errorRecoveryTypePack(retPack)}}; return {{errorRecoveryTypePack(retPack)}};
} }
@ -4207,7 +4200,7 @@ std::optional<WithPredicate<TypePackId>> TypeChecker::checkCallOverload(const Sc
return *ret; return *ret;
} }
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
// Unify return types // Unify return types
checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {}); checkArgumentList(scope, state, retPack, ftv->retTypes, /*argLocations*/ {});
@ -4269,7 +4262,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
std::vector<TypeId> editedParamList(args->head.begin() + 1, args->head.end()); std::vector<TypeId> editedParamList(args->head.begin() + 1, args->head.end());
TypePackId editedArgPack = addTypePack(TypePack{editedParamList}); TypePackId editedArgPack = addTypePack(TypePack{editedParamList});
Unifier editedState = mkUnifier(expr.location); Unifier editedState = mkUnifier(scope, expr.location);
checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations);
if (editedState.errors.empty()) if (editedState.errors.empty())
@ -4299,7 +4292,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal
editedArgList.insert(editedArgList.begin(), checkExpr(scope, *indexName->expr).type); editedArgList.insert(editedArgList.begin(), checkExpr(scope, *indexName->expr).type);
TypePackId editedArgPack = addTypePack(TypePack{editedArgList}); TypePackId editedArgPack = addTypePack(TypePack{editedArgList});
Unifier editedState = mkUnifier(expr.location); Unifier editedState = mkUnifier(scope, expr.location);
checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations); checkArgumentList(scope, editedState, editedArgPack, ftv->argTypes, editedArgLocations);
@ -4365,7 +4358,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast
for (size_t i = 0; i < overloadTypes.size(); ++i) for (size_t i = 0; i < overloadTypes.size(); ++i)
{ {
TypeId overload = overloadTypes[i]; TypeId overload = overloadTypes[i];
Unifier state = mkUnifier(expr.location); Unifier state = mkUnifier(scope, expr.location);
// Unify return types // Unify return types
if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload)) if (const FunctionTypeVar* ftv = get<FunctionTypeVar>(overload))
@ -4415,7 +4408,7 @@ WithPredicate<TypePackId> TypeChecker::checkExprList(const ScopePtr& scope, cons
size_t lastIndex = exprs.size - 1; size_t lastIndex = exprs.size - 1;
tp->head.reserve(lastIndex); tp->head.reserve(lastIndex);
Unifier state = mkUnifier(location); Unifier state = mkUnifier(scope, location);
std::vector<TxnLog> inverseLogs; std::vector<TxnLog> inverseLogs;
@ -4580,15 +4573,15 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const
return ty; return ty;
} }
bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location) bool TypeChecker::unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location)
{ {
UnifierOptions options; UnifierOptions options;
return unify(subTy, superTy, location, options); return unify(subTy, superTy, scope, location, options);
} }
bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options) bool TypeChecker::unify(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location, const UnifierOptions& options)
{ {
Unifier state = mkUnifier(location); Unifier state = mkUnifier(scope, location);
state.tryUnify(subTy, superTy, options.isFunctionCall); state.tryUnify(subTy, superTy, options.isFunctionCall);
state.log.commit(); state.log.commit();
@ -4598,9 +4591,9 @@ bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location,
return state.errors.empty(); return state.errors.empty();
} }
bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx) bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const ScopePtr& scope, const Location& location, CountMismatch::Context ctx)
{ {
Unifier state = mkUnifier(location); Unifier state = mkUnifier(scope, location);
state.ctx = ctx; state.ctx = ctx;
state.tryUnify(subTy, superTy); state.tryUnify(subTy, superTy);
@ -4611,10 +4604,10 @@ bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& lo
return state.errors.empty(); return state.errors.empty();
} }
bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location) bool TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, const Location& location)
{ {
Unifier state = mkUnifier(location); Unifier state = mkUnifier(scope, location);
unifyWithInstantiationIfNeeded(scope, subTy, superTy, state); unifyWithInstantiationIfNeeded(subTy, superTy, scope, state);
state.log.commit(); state.log.commit();
@ -4623,7 +4616,7 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s
return state.errors.empty(); return state.errors.empty();
} }
void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state) void TypeChecker::unifyWithInstantiationIfNeeded(TypeId subTy, TypeId superTy, const ScopePtr& scope, Unifier& state)
{ {
if (!maybeGeneric(subTy)) if (!maybeGeneric(subTy))
// Quick check to see if we definitely can't instantiate // Quick check to see if we definitely can't instantiate
@ -4662,77 +4655,6 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s
} }
} }
bool Anyification::isDirty(TypeId ty)
{
if (ty->persistent)
return false;
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
return (ttv->state == TableState::Free || ttv->state == TableState::Unsealed);
else if (log->getMutable<FreeTypeVar>(ty))
return true;
else if (get<ConstrainedTypeVar>(ty))
return true;
else
return false;
}
bool Anyification::isDirty(TypePackId tp)
{
if (tp->persistent)
return false;
if (log->getMutable<FreeTypePack>(tp))
return true;
else
return false;
}
TypeId Anyification::clean(TypeId ty)
{
LUAU_ASSERT(isDirty(ty));
if (const TableTypeVar* ttv = log->getMutable<TableTypeVar>(ty))
{
TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed};
clone.definitionModuleName = ttv->definitionModuleName;
clone.name = ttv->name;
clone.syntheticName = ttv->syntheticName;
clone.tags = ttv->tags;
TypeId res = addType(std::move(clone));
asMutable(res)->normal = ty->normal;
return res;
}
else if (auto ctv = get<ConstrainedTypeVar>(ty))
{
if (FFlag::LuauQuantifyConstrained)
{
std::vector<TypeId> copy = ctv->parts;
for (TypeId& ty : copy)
ty = replace(ty);
TypeId res = copy.size() == 1 ? copy[0] : addType(UnionTypeVar{std::move(copy)});
auto [t, ok] = normalize(res, *arena, *iceHandler);
if (!ok)
normalizationTooComplex = true;
return t;
}
else
{
auto [t, ok] = normalize(ty, *arena, *iceHandler);
if (!ok)
normalizationTooComplex = true;
return t;
}
}
else
return anyType;
}
TypePackId Anyification::clean(TypePackId tp)
{
LUAU_ASSERT(isDirty(tp));
return anyTypePack;
}
TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location) TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location)
{ {
ty = follow(ty); ty = follow(ty);
@ -4804,7 +4726,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location)
ty = t; ty = t;
} }
Anyification anyification{&currentModule->internalTypes, iceHandler, anyType, anyTypePack}; Anyification anyification{&currentModule->internalTypes, scope, iceHandler, anyType, anyTypePack};
std::optional<TypeId> any = anyification.substitute(ty); std::optional<TypeId> any = anyification.substitute(ty);
if (anyification.normalizationTooComplex) if (anyification.normalizationTooComplex)
reportError(location, NormalizationTooComplex{}); reportError(location, NormalizationTooComplex{});
@ -4827,7 +4749,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo
ty = t; ty = t;
} }
Anyification anyification{&currentModule->internalTypes, iceHandler, anyType, anyTypePack}; Anyification anyification{&currentModule->internalTypes, scope, iceHandler, anyType, anyTypePack};
std::optional<TypePackId> any = anyification.substitute(ty); std::optional<TypePackId> any = anyification.substitute(ty);
if (any.has_value()) if (any.has_value())
return *any; return *any;
@ -4963,9 +4885,9 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r)
}); });
} }
Unifier TypeChecker::mkUnifier(const Location& location) Unifier TypeChecker::mkUnifier(const ScopePtr& scope, const Location& location)
{ {
return Unifier{&currentModule->internalTypes, currentModule->mode, location, Variance::Covariant, unifierState}; return Unifier{&currentModule->internalTypes, currentModule->mode, NotNull{scope.get()}, location, Variance::Covariant, unifierState};
} }
TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(const ScopePtr& scope)
@ -5029,10 +4951,7 @@ TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense)
return sense ? std::nullopt : std::optional<TypeId>(ty); return sense ? std::nullopt : std::optional<TypeId>(ty);
// at this point, anything else is kept if sense is true, or replaced by nil // at this point, anything else is kept if sense is true, or replaced by nil
if (FFlag::LuauFalsyPredicateReturnsNilInstead) return sense ? ty : nilType;
return sense ? ty : nilType;
else
return sense ? std::optional<TypeId>(ty) : std::nullopt;
}; };
} }
@ -5875,8 +5794,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, RefinementMap& refis, const
{ {
auto predicate = [&](TypeId option) -> std::optional<TypeId> { auto predicate = [&](TypeId option) -> std::optional<TypeId> {
// This by itself is not truly enough to determine that A is stronger than B or vice versa. // This by itself is not truly enough to determine that A is stronger than B or vice versa.
bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); bool optionIsSubtype = canUnify(option, isaP.ty, scope, isaP.location).empty();
bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); bool targetIsSubtype = canUnify(isaP.ty, option, scope, isaP.location).empty();
// If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A.
if (!optionIsSubtype && targetIsSubtype) if (!optionIsSubtype && targetIsSubtype)
@ -6019,7 +5938,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, RefinementMap& refis, const Sc
if (maybeSingleton(eqP.type)) if (maybeSingleton(eqP.type))
{ {
// Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this.
if (!sense || canUnify(eqP.type, option, eqP.location).empty()) if (!sense || canUnify(eqP.type, option, scope, eqP.location).empty())
return sense ? eqP.type : option; return sense ? eqP.type : option;
// local variable works around an odd gcc 9.3 warning: <anonymous> may be used uninitialized // local variable works around an odd gcc 9.3 warning: <anonymous> may be used uninitialized
@ -6053,7 +5972,7 @@ std::vector<TypeId> TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp
size_t oldErrorsSize = currentModule->errors.size(); size_t oldErrorsSize = currentModule->errors.size();
unify(tp, expectedTypePack, location); unify(tp, expectedTypePack, scope, location);
// HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but // HACK: tryUnify would undo the changes to the expectedTypePack if the length mismatches, but
// we want to tie up free types to be error types, so we do this instead. // we want to tie up free types to be error types, so we do this instead.

View File

@ -1,6 +1,7 @@
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#include "Luau/TypeUtils.h" #include "Luau/TypeUtils.h"
#include "Luau/Normalize.h"
#include "Luau/Scope.h" #include "Luau/Scope.h"
#include "Luau/ToString.h" #include "Luau/ToString.h"
#include "Luau/TypeInfer.h" #include "Luau/TypeInfer.h"
@ -8,7 +9,7 @@
namespace Luau namespace Luau
{ {
std::optional<TypeId> findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location) std::optional<TypeId> findMetatableEntry(ErrorVec& errors, TypeId type, const std::string& entry, Location location)
{ {
type = follow(type); type = follow(type);
@ -35,7 +36,7 @@ std::optional<TypeId> findMetatableEntry(ErrorVec& errors, TypeId type, std::str
return std::nullopt; return std::nullopt;
} }
std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location) std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, const std::string& name, Location location)
{ {
if (get<AnyTypeVar>(ty)) if (get<AnyTypeVar>(ty))
return ty; return ty;
@ -83,4 +84,110 @@ std::optional<TypeId> findTablePropertyRespectingMeta(ErrorVec& errors, TypeId t
return std::nullopt; return std::nullopt;
} }
std::optional<TypeId> getIndexTypeFromType(
const ScopePtr& scope, ErrorVec& errors, TypeArena* arena, TypeId type, const std::string& prop, const Location& location, bool addErrors,
InternalErrorReporter& handle)
{
type = follow(type);
if (get<ErrorTypeVar>(type) || get<AnyTypeVar>(type) || get<NeverTypeVar>(type))
return type;
if (auto f = get<FreeTypeVar>(type))
*asMutable(type) = TableTypeVar{TableState::Free, f->level};
if (isString(type))
{
std::optional<TypeId> mtIndex = Luau::findMetatableEntry(errors, getSingletonTypes().stringType, "__index", location);
LUAU_ASSERT(mtIndex);
type = *mtIndex;
}
if (getTableType(type))
{
return findTablePropertyRespectingMeta(errors, type, prop, location);
}
else if (const ClassTypeVar* cls = get<ClassTypeVar>(type))
{
if (const Property* p = lookupClassProp(cls, prop))
return p->type;
}
else if (const UnionTypeVar* utv = get<UnionTypeVar>(type))
{
std::vector<TypeId> goodOptions;
std::vector<TypeId> badOptions;
for (TypeId t : utv)
{
// TODO: we should probably limit recursion here?
// RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
// Not needed when we normalize types.
if (get<AnyTypeVar>(follow(t)))
return t;
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, errors, arena, t, prop, location, /* addErrors= */ false, handle))
goodOptions.push_back(*ty);
else
badOptions.push_back(t);
}
if (!badOptions.empty())
{
if (addErrors)
{
if (goodOptions.empty())
errors.push_back(TypeError{location, UnknownProperty{type, prop}});
else
errors.push_back(TypeError{location, MissingUnionProperty{type, badOptions, prop}});
}
return std::nullopt;
}
if (goodOptions.empty())
return getSingletonTypes().neverType;
if (goodOptions.size() == 1)
return goodOptions[0];
// TODO: inefficient.
TypeId result = arena->addType(UnionTypeVar{std::move(goodOptions)});
auto [ty, ok] = normalize(result, NotNull{scope.get()}, *arena, handle);
if (!ok && addErrors)
errors.push_back(TypeError{location, NormalizationTooComplex{}});
return ok ? ty : getSingletonTypes().anyType;
}
else if (const IntersectionTypeVar* itv = get<IntersectionTypeVar>(type))
{
std::vector<TypeId> parts;
for (TypeId t : itv->parts)
{
// TODO: we should probably limit recursion here?
// RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit);
if (std::optional<TypeId> ty = getIndexTypeFromType(scope, errors, arena, t, prop, location, /* addErrors= */ false, handle))
parts.push_back(*ty);
}
// If no parts of the intersection had the property we looked up for, it never existed at all.
if (parts.empty())
{
if (addErrors)
errors.push_back(TypeError{location, UnknownProperty{type, prop}});
return std::nullopt;
}
if (parts.size() == 1)
return parts[0];
return arena->addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct.
}
if (addErrors)
errors.push_back(TypeError{location, UnknownProperty{type, prop}});
return std::nullopt;
}
} // namespace Luau } // namespace Luau

View File

@ -1135,7 +1135,7 @@ std::optional<WithPredicate<TypePackId>> magicFunctionFormat(
{ {
Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location;
typechecker.unify(params[i + paramOffset], expected[i], location); typechecker.unify(params[i + paramOffset], expected[i], scope, location);
} }
// if we know the argument count or if we have too many arguments for sure, we can issue an error // if we know the argument count or if we have too many arguments for sure, we can issue an error
@ -1234,7 +1234,7 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionGmatch(
if (returnTypes.empty()) if (returnTypes.empty())
return std::nullopt; return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypePackId emptyPack = arena.addTypePack({}); const TypePackId emptyPack = arena.addTypePack({});
const TypePackId returnList = arena.addTypePack(returnTypes); const TypePackId returnList = arena.addTypePack(returnTypes);
@ -1269,13 +1269,13 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionMatch(
if (returnTypes.empty()) if (returnTypes.empty())
return std::nullopt; return std::nullopt;
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}});
size_t initIndex = expr.self ? 1 : 2; size_t initIndex = expr.self ? 1 : 2;
if (params.size() == 3 && expr.args.size > initIndex) if (params.size() == 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location); typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
const TypePackId returnList = arena.addTypePack(returnTypes); const TypePackId returnList = arena.addTypePack(returnTypes);
return WithPredicate<TypePackId>{returnList}; return WithPredicate<TypePackId>{returnList};
@ -1320,17 +1320,17 @@ static std::optional<WithPredicate<TypePackId>> magicFunctionFind(
return std::nullopt; return std::nullopt;
} }
typechecker.unify(params[0], typechecker.stringType, expr.args.data[0]->location); typechecker.unify(params[0], typechecker.stringType, scope, expr.args.data[0]->location);
const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}}); const TypeId optionalNumber = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.numberType}});
const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}}); const TypeId optionalBoolean = arena.addType(UnionTypeVar{{typechecker.nilType, typechecker.booleanType}});
size_t initIndex = expr.self ? 1 : 2; size_t initIndex = expr.self ? 1 : 2;
if (params.size() >= 3 && expr.args.size > initIndex) if (params.size() >= 3 && expr.args.size > initIndex)
typechecker.unify(params[2], optionalNumber, expr.args.data[initIndex]->location); typechecker.unify(params[2], optionalNumber, scope, expr.args.data[initIndex]->location);
if (params.size() == 4 && expr.args.size > plainIndex) if (params.size() == 4 && expr.args.size > plainIndex)
typechecker.unify(params[3], optionalBoolean, expr.args.data[plainIndex]->location); typechecker.unify(params[3], optionalBoolean, scope, expr.args.data[plainIndex]->location);
returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber}); returnTypes.insert(returnTypes.begin(), {optionalNumber, optionalNumber});

View File

@ -27,27 +27,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena)
namespace Luau namespace Luau
{ {
static void* systemAllocateAligned(size_t size, size_t align)
{
#ifdef _WIN32
return _aligned_malloc(size, align);
#elif defined(__ANDROID__) // for Android 4.1
return memalign(align, size);
#else
void* ptr;
return posix_memalign(&ptr, align, size) == 0 ? ptr : 0;
#endif
}
static void systemDeallocateAligned(void* ptr)
{
#ifdef _WIN32
_aligned_free(ptr);
#else
free(ptr);
#endif
}
static size_t pageAlign(size_t size) static size_t pageAlign(size_t size)
{ {
return (size + kPageSize - 1) & ~(kPageSize - 1); return (size + kPageSize - 1) & ~(kPageSize - 1);
@ -55,18 +34,31 @@ static size_t pageAlign(size_t size)
void* pagedAllocate(size_t size) void* pagedAllocate(size_t size)
{ {
if (FFlag::DebugLuauFreezeArena) // By default we use operator new/delete instead of malloc/free so that they can be overridden externally
return systemAllocateAligned(pageAlign(size), kPageSize); if (!FFlag::DebugLuauFreezeArena)
else
return ::operator new(size, std::nothrow); return ::operator new(size, std::nothrow);
// On Windows, VirtualAlloc results in 64K granularity allocations; we allocate in chunks of ~32K so aligned_malloc is a little more efficient
// On Linux, we must use mmap because using regular heap results in mprotect() fragmenting the page table and us bumping into 64K mmap limit.
#ifdef _WIN32
return _aligned_malloc(size, kPageSize);
#else
return mmap(nullptr, pageAlign(size), PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1, 0);
#endif
} }
void pagedDeallocate(void* ptr) void pagedDeallocate(void* ptr, size_t size)
{ {
if (FFlag::DebugLuauFreezeArena) // By default we use operator new/delete instead of malloc/free so that they can be overridden externally
systemDeallocateAligned(ptr); if (!FFlag::DebugLuauFreezeArena)
else return ::operator delete(ptr);
::operator delete(ptr);
#ifdef _WIN32
_aligned_free(ptr);
#else
int rc = munmap(ptr, size);
LUAU_ASSERT(rc == 0);
#endif
} }
void pagedFreeze(void* ptr, size_t size) void pagedFreeze(void* ptr, size_t size)

View File

@ -20,7 +20,6 @@ LUAU_FASTINTVARIABLE(LuauTypeInferLowerBoundsIterationLimit, 2000);
LUAU_FASTFLAG(LuauLowerBoundsCalculation); LUAU_FASTFLAG(LuauLowerBoundsCalculation);
LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauErrorRecoveryType);
LUAU_FASTFLAG(LuauUnknownAndNeverType) LUAU_FASTFLAG(LuauUnknownAndNeverType)
LUAU_FASTFLAG(LuauQuantifyConstrained)
LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false) LUAU_FASTFLAGVARIABLE(LuauScalarShapeSubtyping, false)
LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution)
@ -318,9 +317,10 @@ static std::optional<std::pair<Luau::Name, const SingletonTypeVar*>> getTableMat
return std::nullopt; return std::nullopt;
} }
Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) Unifier::Unifier(TypeArena* types, Mode mode, NotNull<Scope> scope, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog)
: types(types) : types(types)
, mode(mode) , mode(mode)
, scope(scope)
, log(parentLog) , log(parentLog)
, location(location) , location(location)
, variance(variance) , variance(variance)
@ -2091,13 +2091,11 @@ void Unifier::unifyLowerBound(TypePackId subTy, TypePackId superTy, TypeLevel de
if (!freeTailPack) if (!freeTailPack)
return; return;
TypeLevel level = FFlag::LuauQuantifyConstrained ? demotedLevel : freeTailPack->level;
TypePack* tp = getMutable<TypePack>(log.replace(tailPack, TypePack{})); TypePack* tp = getMutable<TypePack>(log.replace(tailPack, TypePack{}));
for (; subIter != subEndIter; ++subIter) for (; subIter != subEndIter; ++subIter)
{ {
tp->head.push_back(types->addType(ConstrainedTypeVar{level, {follow(*subIter)}})); tp->head.push_back(types->addType(ConstrainedTypeVar{demotedLevel, {follow(*subIter)}}));
} }
tp->tail = subIter.tail(); tp->tail = subIter.tail();
@ -2270,7 +2268,7 @@ bool Unifier::occursCheck(DenseHashSet<TypePackId>& seen, TypePackId needle, Typ
Unifier Unifier::makeChildUnifier() Unifier Unifier::makeChildUnifier()
{ {
Unifier u = Unifier{types, mode, location, variance, sharedState, &log}; Unifier u = Unifier{types, mode, scope, location, variance, sharedState, &log};
u.anyIsTop = anyIsTop; u.anyIsTop = anyIsTop;
return u; return u;
} }

View File

@ -138,7 +138,7 @@ private:
// funcbody ::= `(' [parlist] `)' block end // funcbody ::= `(' [parlist] `)' block end
// parlist ::= namelist [`,' `...'] | `...' // parlist ::= namelist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> parseFunctionBody( std::pair<AstExprFunction*, AstLocal*> parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional<Name> localName); bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName);
// explist ::= {exp `,'} exp // explist ::= {exp `,'} exp
void parseExprList(TempVector<AstExpr*>& result); void parseExprList(TempVector<AstExpr*>& result);
@ -217,7 +217,7 @@ private:
AstExpr* parseSimpleExpr(); AstExpr* parseSimpleExpr();
// args ::= `(' [explist] `)' | tableconstructor | String // args ::= `(' [explist] `)' | tableconstructor | String
AstExpr* parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation); AstExpr* parseFunctionArgs(AstExpr* func, bool self);
// tableconstructor ::= `{' [fieldlist] `}' // tableconstructor ::= `{' [fieldlist] `}'
// fieldlist ::= field {fieldsep field} [fieldsep] // fieldlist ::= field {fieldsep field} [fieldsep]
@ -241,6 +241,7 @@ private:
std::optional<AstArray<char>> parseCharArray(); std::optional<AstArray<char>> parseCharArray();
AstExpr* parseString(); AstExpr* parseString();
AstExpr* parseNumber();
AstLocal* pushLocal(const Binding& binding); AstLocal* pushLocal(const Binding& binding);
@ -253,11 +254,24 @@ private:
bool expectAndConsume(Lexeme::Type type, const char* context = nullptr); bool expectAndConsume(Lexeme::Type type, const char* context = nullptr);
void expectAndConsumeFail(Lexeme::Type type, const char* context); void expectAndConsumeFail(Lexeme::Type type, const char* context);
bool expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing = false); struct MatchLexeme
void expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra = nullptr); {
MatchLexeme(const Lexeme& l)
: type(l.type)
, position(l.location.begin)
{
}
bool expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin); Lexeme::Type type;
void expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin); Position position;
};
bool expectMatchAndConsume(char value, const MatchLexeme& begin, bool searchForMissing = false);
void expectMatchAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin, const char* extra = nullptr);
bool expectMatchAndConsumeRecover(char value, const MatchLexeme& begin, bool searchForMissing);
bool expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin);
void expectMatchEndAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin);
template<typename T> template<typename T>
AstArray<T> copy(const T* data, std::size_t size); AstArray<T> copy(const T* data, std::size_t size);
@ -283,6 +297,9 @@ private:
AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...) AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray<AstType*>& types, bool isMissing, const char* format, ...)
LUAU_PRINTF_ATTR(5, 6); LUAU_PRINTF_ATTR(5, 6);
AstExpr* reportFunctionArgsError(AstExpr* func, bool self);
void reportAmbiguousCallError();
void nextLexeme(); void nextLexeme();
struct Function struct Function
@ -350,7 +367,7 @@ private:
AstName nameError; AstName nameError;
AstName nameNil; AstName nameNil;
Lexeme endMismatchSuspect; MatchLexeme endMismatchSuspect;
std::vector<Function> functionStack; std::vector<Function> functionStack;

View File

@ -14,8 +14,6 @@
LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000)
LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100)
LUAU_FASTFLAGVARIABLE(LuauParserFunctionKeywordAsTypeHelp, false)
LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false) LUAU_FASTFLAGVARIABLE(LuauFixNamedFunctionParse, false)
LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false) LUAU_DYNAMIC_FASTFLAGVARIABLE(LuaReportParseWrongNamedType, false)
@ -177,7 +175,7 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc
, lexer(buffer, bufferSize, names) , lexer(buffer, bufferSize, names)
, allocator(allocator) , allocator(allocator)
, recursionCounter(0) , recursionCounter(0)
, endMismatchSuspect(Location(), Lexeme::Eof) , endMismatchSuspect(Lexeme(Location(), Lexeme::Eof))
, localMap(AstName()) , localMap(AstName())
{ {
Function top; Function top;
@ -657,7 +655,7 @@ AstStat* Parser::parseFunctionStat()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, {}).first; AstExprFunction* body = parseFunctionBody(hasself, matchFunction, debugname, nullptr).first;
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
@ -686,7 +684,7 @@ AstStat* Parser::parseLocal()
matchRecoveryStopOnToken[Lexeme::ReservedEnd]++; matchRecoveryStopOnToken[Lexeme::ReservedEnd]++;
auto [body, var] = parseFunctionBody(false, matchFunction, name.name, name); auto [body, var] = parseFunctionBody(false, matchFunction, name.name, &name);
matchRecoveryStopOnToken[Lexeme::ReservedEnd]--; matchRecoveryStopOnToken[Lexeme::ReservedEnd]--;
@ -778,7 +776,7 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod()
genericPacks.size = 0; genericPacks.size = 0;
genericPacks.data = nullptr; genericPacks.data = nullptr;
Lexeme matchParen = lexer.current(); MatchLexeme matchParen = lexer.current();
expectAndConsume('(', "function parameter list start"); expectAndConsume('(', "function parameter list start");
TempVector<Binding> args(scratchBinding); TempVector<Binding> args(scratchBinding);
@ -834,7 +832,7 @@ AstStat* Parser::parseDeclaration(const Location& start)
auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
Lexeme matchParen = lexer.current(); MatchLexeme matchParen = lexer.current();
expectAndConsume('(', "global function declaration"); expectAndConsume('(', "global function declaration");
@ -970,13 +968,13 @@ AstStat* Parser::parseCompoundAssignment(AstExpr* initial, AstExprBinary::Op op)
// funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end // funcbody ::= `(' [parlist] `)' [`:' ReturnType] block end
// parlist ::= bindinglist [`,' `...'] | `...' // parlist ::= bindinglist [`,' `...'] | `...'
std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody( std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
bool hasself, const Lexeme& matchFunction, const AstName& debugname, std::optional<Name> localName) bool hasself, const Lexeme& matchFunction, const AstName& debugname, const Name* localName)
{ {
Location start = matchFunction.location; Location start = matchFunction.location;
auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false);
Lexeme matchParen = lexer.current(); MatchLexeme matchParen = lexer.current();
expectAndConsume('(', "function"); expectAndConsume('(', "function");
TempVector<Binding> args(scratchBinding); TempVector<Binding> args(scratchBinding);
@ -988,7 +986,7 @@ std::pair<AstExprFunction*, AstLocal*> Parser::parseFunctionBody(
std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true); std::tie(vararg, varargAnnotation) = parseBindingList(args, /* allowDot3= */ true);
std::optional<Location> argLocation = matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')') std::optional<Location> argLocation = matchParen.type == Lexeme::Type('(') && lexer.current().type == Lexeme::Type(')')
? std::make_optional(Location(matchParen.location.begin, lexer.current().location.end)) ? std::make_optional(Location(matchParen.position, lexer.current().location.end))
: std::nullopt; : std::nullopt;
expectMatchAndConsume(')', matchParen, true); expectMatchAndConsume(')', matchParen, true);
@ -1255,7 +1253,7 @@ AstType* Parser::parseTableTypeAnnotation()
Location start = lexer.current().location; Location start = lexer.current().location;
Lexeme matchBrace = lexer.current(); MatchLexeme matchBrace = lexer.current();
expectAndConsume('{', "table type"); expectAndConsume('{', "table type");
while (lexer.current().type != '}') while (lexer.current().type != '}')
@ -1628,7 +1626,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack)
{ {
return parseFunctionTypeAnnotation(allowPack); return parseFunctionTypeAnnotation(allowPack);
} }
else if (FFlag::LuauParserFunctionKeywordAsTypeHelp && lexer.current().type == Lexeme::ReservedFunction) else if (lexer.current().type == Lexeme::ReservedFunction)
{ {
Location location = lexer.current().location; Location location = lexer.current().location;
@ -1912,14 +1910,14 @@ AstExpr* Parser::parsePrefixExpr()
{ {
if (lexer.current().type == '(') if (lexer.current().type == '(')
{ {
Location start = lexer.current().location; Position start = lexer.current().location.begin;
Lexeme matchParen = lexer.current(); MatchLexeme matchParen = lexer.current();
nextLexeme(); nextLexeme();
AstExpr* expr = parseExpr(); AstExpr* expr = parseExpr();
Location end = lexer.current().location; Position end = lexer.current().location.end;
if (lexer.current().type != ')') if (lexer.current().type != ')')
{ {
@ -1927,7 +1925,7 @@ AstExpr* Parser::parsePrefixExpr()
expectMatchAndConsumeFail(static_cast<Lexeme::Type>(')'), matchParen, suggestion); expectMatchAndConsumeFail(static_cast<Lexeme::Type>(')'), matchParen, suggestion);
end = lexer.previousLocation(); end = lexer.previousLocation().end;
} }
else else
{ {
@ -1945,7 +1943,7 @@ AstExpr* Parser::parsePrefixExpr()
// primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs } // primaryexp -> prefixexp { `.' NAME | `[' exp `]' | `:' NAME funcargs | funcargs }
AstExpr* Parser::parsePrimaryExpr(bool asStatement) AstExpr* Parser::parsePrimaryExpr(bool asStatement)
{ {
Location start = lexer.current().location; Position start = lexer.current().location.begin;
AstExpr* expr = parsePrefixExpr(); AstExpr* expr = parsePrefixExpr();
@ -1960,16 +1958,16 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement)
Name index = parseIndexName(nullptr, opPosition); Name index = parseIndexName(nullptr, opPosition);
expr = allocator.alloc<AstExprIndexName>(Location(start, index.location), expr, index.name, index.location, opPosition, '.'); expr = allocator.alloc<AstExprIndexName>(Location(start, index.location.end), expr, index.name, index.location, opPosition, '.');
} }
else if (lexer.current().type == '[') else if (lexer.current().type == '[')
{ {
Lexeme matchBracket = lexer.current(); MatchLexeme matchBracket = lexer.current();
nextLexeme(); nextLexeme();
AstExpr* index = parseExpr(); AstExpr* index = parseExpr();
Location end = lexer.current().location; Position end = lexer.current().location.end;
expectMatchAndConsume(']', matchBracket); expectMatchAndConsume(']', matchBracket);
@ -1981,27 +1979,24 @@ AstExpr* Parser::parsePrimaryExpr(bool asStatement)
nextLexeme(); nextLexeme();
Name index = parseIndexName("method name", opPosition); Name index = parseIndexName("method name", opPosition);
AstExpr* func = allocator.alloc<AstExprIndexName>(Location(start, index.location), expr, index.name, index.location, opPosition, ':'); AstExpr* func = allocator.alloc<AstExprIndexName>(Location(start, index.location.end), expr, index.name, index.location, opPosition, ':');
expr = parseFunctionArgs(func, true, index.location); expr = parseFunctionArgs(func, true);
} }
else if (lexer.current().type == '(') else if (lexer.current().type == '(')
{ {
// This error is handled inside 'parseFunctionArgs' as well, but for better error recovery we need to break out the current loop here // This error is handled inside 'parseFunctionArgs' as well, but for better error recovery we need to break out the current loop here
if (!asStatement && expr->location.end.line != lexer.current().location.begin.line) if (!asStatement && expr->location.end.line != lexer.current().location.begin.line)
{ {
report(lexer.current().location, reportAmbiguousCallError();
"Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of "
"new statement; use ';' to separate statements");
break; break;
} }
expr = parseFunctionArgs(expr, false, Location()); expr = parseFunctionArgs(expr, false);
} }
else if (lexer.current().type == '{' || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) else if (lexer.current().type == '{' || lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)
{ {
expr = parseFunctionArgs(expr, false, Location()); expr = parseFunctionArgs(expr, false);
} }
else else
{ {
@ -2156,7 +2151,7 @@ static ConstantNumberParseResult parseInteger(double& result, const char* data,
return ConstantNumberParseResult::Ok; return ConstantNumberParseResult::Ok;
} }
static ConstantNumberParseResult parseNumber(double& result, const char* data) static ConstantNumberParseResult parseDouble(double& result, const char* data)
{ {
LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues); LUAU_ASSERT(FFlag::LuauLintParseIntegerIssues);
@ -2214,61 +2209,11 @@ AstExpr* Parser::parseSimpleExpr()
Lexeme matchFunction = lexer.current(); Lexeme matchFunction = lexer.current();
nextLexeme(); nextLexeme();
return parseFunctionBody(false, matchFunction, AstName(), {}).first; return parseFunctionBody(false, matchFunction, AstName(), nullptr).first;
} }
else if (lexer.current().type == Lexeme::Number) else if (lexer.current().type == Lexeme::Number)
{ {
scratchData.assign(lexer.current().data, lexer.current().length); return parseNumber();
// Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al
if (scratchData.find('_') != std::string::npos)
{
scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end());
}
if (FFlag::LuauLintParseIntegerIssues)
{
double value = 0;
ConstantNumberParseResult result = parseNumber(value, scratchData.c_str());
nextLexeme();
if (result == ConstantNumberParseResult::Malformed)
return reportExprError(start, {}, "Malformed number");
return allocator.alloc<AstExprConstantNumber>(start, value, result);
}
else if (DFFlag::LuaReportParseIntegerIssues)
{
double value = 0;
if (const char* error = parseNumber_DEPRECATED2(value, scratchData.c_str()))
{
nextLexeme();
return reportExprError(start, {}, "%s", error);
}
else
{
nextLexeme();
return allocator.alloc<AstExprConstantNumber>(start, value);
}
}
else
{
double value = 0;
if (parseNumber_DEPRECATED(value, scratchData.c_str()))
{
nextLexeme();
return allocator.alloc<AstExprConstantNumber>(start, value);
}
else
{
nextLexeme();
return reportExprError(start, {}, "Malformed number");
}
}
} }
else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)
{ {
@ -2309,18 +2254,15 @@ AstExpr* Parser::parseSimpleExpr()
} }
// args ::= `(' [explist] `)' | tableconstructor | String // args ::= `(' [explist] `)' | tableconstructor | String
AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self, const Location& selfLocation) AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self)
{ {
if (lexer.current().type == '(') if (lexer.current().type == '(')
{ {
Position argStart = lexer.current().location.end; Position argStart = lexer.current().location.end;
if (func->location.end.line != lexer.current().location.begin.line) if (func->location.end.line != lexer.current().location.begin.line)
{ reportAmbiguousCallError();
report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of "
"new statement; use ';' to separate statements");
}
Lexeme matchParen = lexer.current(); MatchLexeme matchParen = lexer.current();
nextLexeme(); nextLexeme();
TempVector<AstExpr*> args(scratchExpr); TempVector<AstExpr*> args(scratchExpr);
@ -2352,18 +2294,29 @@ AstExpr* Parser::parseFunctionArgs(AstExpr* func, bool self, const Location& sel
} }
else else
{ {
if (self && lexer.current().location.begin.line != func->location.end.line) return reportFunctionArgsError(func, self);
{
return reportExprError(func->location, copy({func}), "Expected function call arguments after '('");
}
else
{
return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}),
"Expected '(', '{' or <string> when parsing function call, got %s", lexer.current().toString().c_str());
}
} }
} }
LUAU_NOINLINE AstExpr* Parser::reportFunctionArgsError(AstExpr* func, bool self)
{
if (self && lexer.current().location.begin.line != func->location.end.line)
{
return reportExprError(func->location, copy({func}), "Expected function call arguments after '('");
}
else
{
return reportExprError(Location(func->location.begin, lexer.current().location.begin), copy({func}),
"Expected '(', '{' or <string> when parsing function call, got %s", lexer.current().toString().c_str());
}
}
LUAU_NOINLINE void Parser::reportAmbiguousCallError()
{
report(lexer.current().location, "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of "
"new statement; use ';' to separate statements");
}
// tableconstructor ::= `{' [fieldlist] `}' // tableconstructor ::= `{' [fieldlist] `}'
// fieldlist ::= field {fieldsep field} [fieldsep] // fieldlist ::= field {fieldsep field} [fieldsep]
// field ::= `[' exp `]' `=' exp | Name `=' exp | exp // field ::= `[' exp `]' `=' exp | Name `=' exp | exp
@ -2374,14 +2327,14 @@ AstExpr* Parser::parseTableConstructor()
Location start = lexer.current().location; Location start = lexer.current().location;
Lexeme matchBrace = lexer.current(); MatchLexeme matchBrace = lexer.current();
expectAndConsume('{', "table literal"); expectAndConsume('{', "table literal");
while (lexer.current().type != '}') while (lexer.current().type != '}')
{ {
if (lexer.current().type == '[') if (lexer.current().type == '[')
{ {
Lexeme matchLocationBracket = lexer.current(); MatchLexeme matchLocationBracket = lexer.current();
nextLexeme(); nextLexeme();
AstExpr* key = parseExpr(); AstExpr* key = parseExpr();
@ -2692,6 +2645,63 @@ AstExpr* Parser::parseString()
return reportExprError(location, {}, "String literal contains malformed escape sequence"); return reportExprError(location, {}, "String literal contains malformed escape sequence");
} }
AstExpr* Parser::parseNumber()
{
Location start = lexer.current().location;
scratchData.assign(lexer.current().data, lexer.current().length);
// Remove all internal _ - they don't hold any meaning and this allows parsing code to just pass the string pointer to strtod et al
if (scratchData.find('_') != std::string::npos)
{
scratchData.erase(std::remove(scratchData.begin(), scratchData.end(), '_'), scratchData.end());
}
if (FFlag::LuauLintParseIntegerIssues)
{
double value = 0;
ConstantNumberParseResult result = parseDouble(value, scratchData.c_str());
nextLexeme();
if (result == ConstantNumberParseResult::Malformed)
return reportExprError(start, {}, "Malformed number");
return allocator.alloc<AstExprConstantNumber>(start, value, result);
}
else if (DFFlag::LuaReportParseIntegerIssues)
{
double value = 0;
if (const char* error = parseNumber_DEPRECATED2(value, scratchData.c_str()))
{
nextLexeme();
return reportExprError(start, {}, "%s", error);
}
else
{
nextLexeme();
return allocator.alloc<AstExprConstantNumber>(start, value);
}
}
else
{
double value = 0;
if (parseNumber_DEPRECATED(value, scratchData.c_str()))
{
nextLexeme();
return allocator.alloc<AstExprConstantNumber>(start, value);
}
else
{
nextLexeme();
return reportExprError(start, {}, "Malformed number");
}
}
}
AstLocal* Parser::pushLocal(const Binding& binding) AstLocal* Parser::pushLocal(const Binding& binding)
{ {
const Name& name = binding.name; const Name& name = binding.name;
@ -2763,7 +2773,7 @@ LUAU_NOINLINE void Parser::expectAndConsumeFail(Lexeme::Type type, const char* c
report(lexer.current().location, "Expected %s, got %s", typeString.c_str(), currLexemeString.c_str()); report(lexer.current().location, "Expected %s, got %s", typeString.c_str(), currLexemeString.c_str());
} }
bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchForMissing) bool Parser::expectMatchAndConsume(char value, const MatchLexeme& begin, bool searchForMissing)
{ {
Lexeme::Type type = static_cast<Lexeme::Type>(static_cast<unsigned char>(value)); Lexeme::Type type = static_cast<Lexeme::Type>(static_cast<unsigned char>(value));
@ -2771,42 +2781,7 @@ bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchF
{ {
expectMatchAndConsumeFail(type, begin); expectMatchAndConsumeFail(type, begin);
if (searchForMissing) return expectMatchAndConsumeRecover(value, begin, searchForMissing);
{
// previous location is taken because 'current' lexeme is already the next token
unsigned currentLine = lexer.previousLocation().end.line;
// search to the end of the line for expected token
// we will also stop if we hit a token that can be handled by parsing function above the current one
Lexeme::Type lexemeType = lexer.current().type;
while (currentLine == lexer.current().location.begin.line && lexemeType != type && matchRecoveryStopOnToken[lexemeType] == 0)
{
nextLexeme();
lexemeType = lexer.current().type;
}
if (lexemeType == type)
{
nextLexeme();
return true;
}
}
else
{
// check if this is an extra token and the expected token is next
if (lexer.lookahead().type == type)
{
// skip invalid and consume expected
nextLexeme();
nextLexeme();
return true;
}
}
return false;
} }
else else
{ {
@ -2816,21 +2791,64 @@ bool Parser::expectMatchAndConsume(char value, const Lexeme& begin, bool searchF
} }
} }
// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is LUAU_NOINLINE bool Parser::expectMatchAndConsumeRecover(char value, const MatchLexeme& begin, bool searchForMissing)
// cold
LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const Lexeme& begin, const char* extra)
{ {
std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString(); Lexeme::Type type = static_cast<Lexeme::Type>(static_cast<unsigned char>(value));
if (lexer.current().location.begin.line == begin.location.begin.line) if (searchForMissing)
report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), begin.toString().c_str(), {
begin.location.begin.column + 1, lexer.current().toString().c_str(), extra ? extra : ""); // previous location is taken because 'current' lexeme is already the next token
unsigned currentLine = lexer.previousLocation().end.line;
// search to the end of the line for expected token
// we will also stop if we hit a token that can be handled by parsing function above the current one
Lexeme::Type lexemeType = lexer.current().type;
while (currentLine == lexer.current().location.begin.line && lexemeType != type && matchRecoveryStopOnToken[lexemeType] == 0)
{
nextLexeme();
lexemeType = lexer.current().type;
}
if (lexemeType == type)
{
nextLexeme();
return true;
}
}
else else
report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), begin.toString().c_str(), {
begin.location.begin.line + 1, lexer.current().toString().c_str(), extra ? extra : ""); // check if this is an extra token and the expected token is next
if (lexer.lookahead().type == type)
{
// skip invalid and consume expected
nextLexeme();
nextLexeme();
return true;
}
}
return false;
} }
bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin) // LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is
// cold
LUAU_NOINLINE void Parser::expectMatchAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin, const char* extra)
{
std::string typeString = Lexeme(Location(Position(0, 0), 0), type).toString();
std::string matchString = Lexeme(Location(Position(0, 0), 0), begin.type).toString();
if (lexer.current().location.begin.line == begin.position.line)
report(lexer.current().location, "Expected %s (to close %s at column %d), got %s%s", typeString.c_str(), matchString.c_str(),
begin.position.column + 1, lexer.current().toString().c_str(), extra ? extra : "");
else
report(lexer.current().location, "Expected %s (to close %s at line %d), got %s%s", typeString.c_str(), matchString.c_str(),
begin.position.line + 1, lexer.current().toString().c_str(), extra ? extra : "");
}
bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const MatchLexeme& begin)
{ {
if (lexer.current().type != type) if (lexer.current().type != type)
{ {
@ -2852,9 +2870,9 @@ bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin)
{ {
// If the token matches on a different line and a different column, it suggests misleading indentation // If the token matches on a different line and a different column, it suggests misleading indentation
// This can be used to pinpoint the problem location for a possible future *actual* mismatch // This can be used to pinpoint the problem location for a possible future *actual* mismatch
if (lexer.current().location.begin.line != begin.location.begin.line && if (lexer.current().location.begin.line != begin.position.line &&
lexer.current().location.begin.column != begin.location.begin.column && lexer.current().location.begin.column != begin.position.column &&
endMismatchSuspect.location.begin.line < begin.location.begin.line) // Only replace the previous suspect with more recent suspects endMismatchSuspect.position.line < begin.position.line) // Only replace the previous suspect with more recent suspects
{ {
endMismatchSuspect = begin; endMismatchSuspect = begin;
} }
@ -2867,12 +2885,12 @@ bool Parser::expectMatchEndAndConsume(Lexeme::Type type, const Lexeme& begin)
// LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is // LUAU_NOINLINE is used to limit the stack cost of this function due to std::string objects, and to increase caller performance since this code is
// cold // cold
LUAU_NOINLINE void Parser::expectMatchEndAndConsumeFail(Lexeme::Type type, const Lexeme& begin) LUAU_NOINLINE void Parser::expectMatchEndAndConsumeFail(Lexeme::Type type, const MatchLexeme& begin)
{ {
if (endMismatchSuspect.type != Lexeme::Eof && endMismatchSuspect.location.begin.line > begin.location.begin.line) if (endMismatchSuspect.type != Lexeme::Eof && endMismatchSuspect.position.line > begin.position.line)
{ {
std::string suggestion = std::string matchString = Lexeme(Location(Position(0, 0), 0), endMismatchSuspect.type).toString();
format("; did you forget to close %s at line %d?", endMismatchSuspect.toString().c_str(), endMismatchSuspect.location.begin.line + 1); std::string suggestion = format("; did you forget to close %s at line %d?", matchString.c_str(), endMismatchSuspect.position.line + 1);
expectMatchAndConsumeFail(type, begin, suggestion.c_str()); expectMatchAndConsumeFail(type, begin, suggestion.c_str());
} }

View File

@ -515,6 +515,9 @@ enum LuauBuiltinFunction
// rawlen // rawlen
LBF_RAWLEN, LBF_RAWLEN,
// bit32.extract(_, k, k)
LBF_BIT32_EXTRACTK,
}; };
// Capture type, used in LOP_CAPTURE // Capture type, used in LOP_CAPTURE

View File

@ -319,11 +319,12 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count)
break; break;
case LBF_BIT32_EXTRACT: case LBF_BIT32_EXTRACT:
if (count == 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number) if (count >= 2 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number &&
(count == 2 || args[2].type == Constant::Type_Number))
{ {
uint32_t u = bit32(args[0].valueNumber); uint32_t u = bit32(args[0].valueNumber);
int f = int(args[1].valueNumber); int f = int(args[1].valueNumber);
int w = int(args[2].valueNumber); int w = count == 2 ? 1 : int(args[2].valueNumber);
if (f >= 0 && w > 0 && f + w <= 32) if (f >= 0 && w > 0 && f + w <= 32)
{ {
@ -356,13 +357,13 @@ Constant foldBuiltin(int bfid, const Constant* args, size_t count)
break; break;
case LBF_BIT32_REPLACE: case LBF_BIT32_REPLACE:
if (count == 4 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number && if (count >= 3 && args[0].type == Constant::Type_Number && args[1].type == Constant::Type_Number && args[2].type == Constant::Type_Number &&
args[3].type == Constant::Type_Number) (count == 3 || args[3].type == Constant::Type_Number))
{ {
uint32_t n = bit32(args[0].valueNumber); uint32_t n = bit32(args[0].valueNumber);
uint32_t v = bit32(args[1].valueNumber); uint32_t v = bit32(args[1].valueNumber);
int f = int(args[2].valueNumber); int f = int(args[2].valueNumber);
int w = int(args[3].valueNumber); int w = count == 3 ? 1 : int(args[3].valueNumber);
if (f >= 0 && w > 0 && f + w <= 32) if (f >= 0 && w > 0 && f + w <= 32)
{ {

View File

@ -23,13 +23,12 @@ LUAU_FASTINTVARIABLE(LuauCompileInlineThreshold, 25)
LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300) LUAU_FASTINTVARIABLE(LuauCompileInlineThresholdMaxBoost, 300)
LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5) LUAU_FASTINTVARIABLE(LuauCompileInlineDepth, 5)
LUAU_FASTFLAGVARIABLE(LuauCompileNoIpairs, false)
LUAU_FASTFLAGVARIABLE(LuauCompileFreeReassign, false)
LUAU_FASTFLAGVARIABLE(LuauCompileXEQ, false) LUAU_FASTFLAGVARIABLE(LuauCompileXEQ, false)
LUAU_FASTFLAGVARIABLE(LuauCompileOptimalAssignment, false) LUAU_FASTFLAGVARIABLE(LuauCompileOptimalAssignment, false)
LUAU_FASTFLAGVARIABLE(LuauCompileExtractK, false)
namespace Luau namespace Luau
{ {
@ -403,18 +402,37 @@ struct Compiler
} }
} }
void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid) void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid, int bfK = -1)
{ {
LUAU_ASSERT(!expr->self); LUAU_ASSERT(!expr->self);
LUAU_ASSERT(expr->args.size <= 2); LUAU_ASSERT(expr->args.size >= 1);
LUAU_ASSERT(expr->args.size <= 2 || (bfid == LBF_BIT32_EXTRACTK && expr->args.size == 3));
LUAU_ASSERT(bfid == LBF_BIT32_EXTRACTK ? bfK >= 0 : bfK < 0);
LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2;
uint32_t args[2] = {}; if (FFlag::LuauCompileExtractK)
{
opc = expr->args.size == 1 ? LOP_FASTCALL1 : (bfK >= 0 || isConstant(expr->args.data[1])) ? LOP_FASTCALL2K : LOP_FASTCALL2;
}
uint32_t args[3] = {};
for (size_t i = 0; i < expr->args.size; ++i) for (size_t i = 0; i < expr->args.size; ++i)
{ {
if (i > 0) if (FFlag::LuauCompileExtractK)
{
if (i > 0 && opc == LOP_FASTCALL2K)
{
int32_t cid = getConstantIndex(expr->args.data[i]);
if (cid < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
args[i] = cid;
continue; // TODO: remove this and change if below to else if
}
}
else if (i > 0)
{ {
if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0)
{ {
@ -425,7 +443,9 @@ struct Compiler
} }
if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0) if (int reg = getExprLocalReg(expr->args.data[i]); reg >= 0)
{
args[i] = uint8_t(reg); args[i] = uint8_t(reg);
}
else else
{ {
args[i] = uint8_t(regs + 1 + i); args[i] = uint8_t(regs + 1 + i);
@ -437,21 +457,31 @@ struct Compiler
bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0);
if (opc != LOP_FASTCALL1) if (opc != LOP_FASTCALL1)
bytecode.emitAux(args[1]); bytecode.emitAux(bfK >= 0 ? bfK : args[1]);
// Set up a traditional Lua stack for the subsequent LOP_CALL. // Set up a traditional Lua stack for the subsequent LOP_CALL.
// Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for
// these FASTCALL variants. // these FASTCALL variants.
for (size_t i = 0; i < expr->args.size; ++i) for (size_t i = 0; i < expr->args.size; ++i)
{ {
if (i > 0 && opc == LOP_FASTCALL2K) if (FFlag::LuauCompileExtractK)
{ {
emitLoadK(uint8_t(regs + 1 + i), args[i]); if (i > 0 && opc == LOP_FASTCALL2K)
break; emitLoadK(uint8_t(regs + 1 + i), args[i]);
else if (args[i] != regs + 1 + i)
bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0);
} }
else
{
if (i > 0 && opc == LOP_FASTCALL2K)
{
emitLoadK(uint8_t(regs + 1 + i), args[i]);
break;
}
if (args[i] != regs + 1 + i) if (args[i] != regs + 1 + i)
bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0);
}
} }
// note, these instructions are normally not executed and are used as a fallback for FASTCALL // note, these instructions are normally not executed and are used as a fallback for FASTCALL
@ -600,7 +630,7 @@ struct Compiler
} }
else else
{ {
AstExprLocal* le = FFlag::LuauCompileFreeReassign ? getExprLocal(arg) : arg->as<AstExprLocal>(); AstExprLocal* le = getExprLocal(arg);
Variable* lv = le ? variables.find(le->local) : nullptr; Variable* lv = le ? variables.find(le->local) : nullptr;
// if the argument is a local that isn't mutated, we will simply reuse the existing register // if the argument is a local that isn't mutated, we will simply reuse the existing register
@ -723,6 +753,26 @@ struct Compiler
bfid = -1; bfid = -1;
} }
// Optimization: for bit32.extract with constant in-range f/w we compile using FASTCALL2K and a special builtin
if (FFlag::LuauCompileExtractK && bfid == LBF_BIT32_EXTRACT && expr->args.size == 3 && isConstant(expr->args.data[1]) && isConstant(expr->args.data[2]))
{
Constant fc = getConstant(expr->args.data[1]);
Constant wc = getConstant(expr->args.data[2]);
int fi = fc.type == Constant::Type_Number ? int(fc.valueNumber) : -1;
int wi = wc.type == Constant::Type_Number ? int(wc.valueNumber) : -1;
if (fi >= 0 && wi > 0 && fi + wi <= 32)
{
int fwp = fi | ((wi - 1) << 5);
int32_t cid = bytecode.addConstantNumber(fwp);
if (cid < 0)
CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile");
return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, LBF_BIT32_EXTRACTK, cid);
}
}
// Optimization: for 1/2 argument fast calls use specialized opcodes // Optimization: for 1/2 argument fast calls use specialized opcodes
if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1])) if (bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2 && !isExprMultRet(expr->args.data[expr->args.size - 1]))
return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid);
@ -1218,7 +1268,7 @@ struct Compiler
{ {
// disambiguation: there's 4 cases (we only need truthy or falsy results based on onlyTruth) // disambiguation: there's 4 cases (we only need truthy or falsy results based on onlyTruth)
// onlyTruth = 1: a and b transforms to a ? b : dontcare // onlyTruth = 1: a and b transforms to a ? b : dontcare
// onlyTruth = 1: a or b transforms to a ? a : a // onlyTruth = 1: a or b transforms to a ? a : b
// onlyTruth = 0: a and b transforms to !a ? a : b // onlyTruth = 0: a and b transforms to !a ? a : b
// onlyTruth = 0: a or b transforms to !a ? b : dontcare // onlyTruth = 0: a or b transforms to !a ? b : dontcare
if (onlyTruth == (expr->op == AstExprBinary::And)) if (onlyTruth == (expr->op == AstExprBinary::And))
@ -2576,7 +2626,7 @@ struct Compiler
return; return;
// Optimization: for 1-1 local assignments, we can reuse the register *if* neither local is mutated // Optimization: for 1-1 local assignments, we can reuse the register *if* neither local is mutated
if (FFlag::LuauCompileFreeReassign && options.optimizationLevel >= 1 && stat->vars.size == 1 && stat->values.size == 1) if (options.optimizationLevel >= 1 && stat->vars.size == 1 && stat->values.size == 1)
{ {
if (AstExprLocal* re = getExprLocal(stat->values.data[0])) if (AstExprLocal* re = getExprLocal(stat->values.data[0]))
{ {
@ -2790,7 +2840,6 @@ struct Compiler
LUAU_ASSERT(vars == regs + 3); LUAU_ASSERT(vars == regs + 3);
LuauOpcode skipOp = LOP_FORGPREP; LuauOpcode skipOp = LOP_FORGPREP;
LuauOpcode loopOp = LOP_FORGLOOP;
// Optimization: when we iterate via pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration index // Optimization: when we iterate via pairs/ipairs, we generate special bytecode that optimizes the traversal using internal iteration index
// These instructions dynamically check if generator is equal to next/inext and bail out // These instructions dynamically check if generator is equal to next/inext and bail out
@ -2802,25 +2851,16 @@ struct Compiler
Builtin builtin = getBuiltin(stat->values.data[0]->as<AstExprCall>()->func, globals, variables); Builtin builtin = getBuiltin(stat->values.data[0]->as<AstExprCall>()->func, globals, variables);
if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) if (builtin.isGlobal("ipairs")) // for .. in ipairs(t)
{
skipOp = LOP_FORGPREP_INEXT; skipOp = LOP_FORGPREP_INEXT;
loopOp = FFlag::LuauCompileNoIpairs ? LOP_FORGLOOP : LOP_FORGLOOP_INEXT;
}
else if (builtin.isGlobal("pairs")) // for .. in pairs(t) else if (builtin.isGlobal("pairs")) // for .. in pairs(t)
{
skipOp = LOP_FORGPREP_NEXT; skipOp = LOP_FORGPREP_NEXT;
loopOp = LOP_FORGLOOP;
}
} }
else if (stat->values.size == 2) else if (stat->values.size == 2)
{ {
Builtin builtin = getBuiltin(stat->values.data[0], globals, variables); Builtin builtin = getBuiltin(stat->values.data[0], globals, variables);
if (builtin.isGlobal("next")) // for .. in next,t if (builtin.isGlobal("next")) // for .. in next,t
{
skipOp = LOP_FORGPREP_NEXT; skipOp = LOP_FORGPREP_NEXT;
loopOp = LOP_FORGLOOP;
}
} }
} }
@ -2846,19 +2886,9 @@ struct Compiler
size_t backLabel = bytecode.emitLabel(); size_t backLabel = bytecode.emitLabel();
bytecode.emitAD(loopOp, regs, 0); // FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit
bytecode.emitAD(LOP_FORGLOOP, regs, 0);
if (FFlag::LuauCompileNoIpairs) bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size));
{
// TODO: remove loopOp as it's a constant now
LUAU_ASSERT(loopOp == LOP_FORGLOOP);
// FORGLOOP uses aux to encode variable count and fast path flag for ipairs traversal in the high bit
bytecode.emitAux((skipOp == LOP_FORGPREP_INEXT ? 0x80000000 : 0) | uint32_t(stat->vars.size));
}
// note: FORGLOOP needs variable count encoded in AUX field, other loop instructions assume a fixed variable count
else if (loopOp == LOP_FORGLOOP)
bytecode.emitAux(uint32_t(stat->vars.size));
size_t endLabel = bytecode.emitLabel(); size_t endLabel = bytecode.emitLabel();

View File

@ -66,6 +66,7 @@ target_sources(Luau.CodeGen PRIVATE
# Luau.Analysis Sources # Luau.Analysis Sources
target_sources(Luau.Analysis PRIVATE target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Anyification.h
Analysis/include/Luau/ApplyTypeFunction.h Analysis/include/Luau/ApplyTypeFunction.h
Analysis/include/Luau/AstJsonEncoder.h Analysis/include/Luau/AstJsonEncoder.h
Analysis/include/Luau/AstQuery.h Analysis/include/Luau/AstQuery.h
@ -115,6 +116,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/include/Luau/Variant.h Analysis/include/Luau/Variant.h
Analysis/include/Luau/VisitTypeVar.h Analysis/include/Luau/VisitTypeVar.h
Analysis/src/Anyification.cpp
Analysis/src/ApplyTypeFunction.cpp Analysis/src/ApplyTypeFunction.cpp
Analysis/src/AstJsonEncoder.cpp Analysis/src/AstJsonEncoder.cpp
Analysis/src/AstQuery.cpp Analysis/src/AstQuery.cpp
@ -126,6 +128,7 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/ConstraintGraphBuilder.cpp Analysis/src/ConstraintGraphBuilder.cpp
Analysis/src/ConstraintSolver.cpp Analysis/src/ConstraintSolver.cpp
Analysis/src/ConstraintSolverLogger.cpp Analysis/src/ConstraintSolverLogger.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp
Analysis/src/Error.cpp Analysis/src/Error.cpp
Analysis/src/Frontend.cpp Analysis/src/Frontend.cpp
Analysis/src/Instantiation.cpp Analysis/src/Instantiation.cpp
@ -155,7 +158,6 @@ target_sources(Luau.Analysis PRIVATE
Analysis/src/TypeVar.cpp Analysis/src/TypeVar.cpp
Analysis/src/Unifiable.cpp Analysis/src/Unifiable.cpp
Analysis/src/Unifier.cpp Analysis/src/Unifier.cpp
Analysis/src/EmbeddedBuiltinDefinitions.cpp
) )
# Luau.VM Sources # Luau.VM Sources

View File

@ -1209,9 +1209,8 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*))
{ {
luaC_checkGC(L); luaC_checkGC(L);
luaC_checkthreadsleep(L); luaC_checkthreadsleep(L);
size_t as = sz + sizeof(dtor); // make sure sz + sizeof(dtor) doesn't overflow; luaU_newdata will reject SIZE_MAX correctly
if (as < sizeof(dtor)) size_t as = sz < SIZE_MAX - sizeof(dtor) ? sz + sizeof(dtor) : SIZE_MAX;
as = SIZE_MAX; // Will cause a memory error in luaU_newudata.
Udata* u = luaU_newudata(L, as, UTAG_IDTOR); Udata* u = luaU_newudata(L, as, UTAG_IDTOR);
memcpy(&u->data + sz, &dtor, sizeof(dtor)); memcpy(&u->data + sz, &dtor, sizeof(dtor));
setuvalue(L, L->top, u); setuvalue(L, L->top, u);

View File

@ -15,6 +15,8 @@
#include <intrin.h> #include <intrin.h>
#endif #endif
LUAU_FASTFLAGVARIABLE(LuauFasterBit32NoWidth, false)
// luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM
// The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack.
// If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path
@ -600,24 +602,39 @@ static int luauF_btest(lua_State* L, StkId res, TValue* arg0, int nresults, StkI
static int luauF_extract(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_extract(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) if (nparams >= (3 - FFlag::LuauFasterBit32NoWidth) && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args))
{ {
double a1 = nvalue(arg0); double a1 = nvalue(arg0);
double a2 = nvalue(args); double a2 = nvalue(args);
double a3 = nvalue(args + 1);
unsigned n; unsigned n;
luai_num2unsigned(n, a1); luai_num2unsigned(n, a1);
int f = int(a2); int f = int(a2);
int w = int(a3);
if (f >= 0 && w > 0 && f + w <= 32) if (nparams == 2)
{ {
uint32_t m = ~(0xfffffffeu << (w - 1)); if (unsigned(f) < 32)
uint32_t r = (n >> f) & m; {
uint32_t m = 1;
uint32_t r = (n >> f) & m;
setnvalue(res, double(r)); setnvalue(res, double(r));
return 1; return 1;
}
}
else if (ttisnumber(args + 1))
{
double a3 = nvalue(args + 1);
int w = int(a3);
if (f >= 0 && w > 0 && f + w <= 32)
{
uint32_t m = ~(0xfffffffeu << (w - 1));
uint32_t r = (n >> f) & m;
setnvalue(res, double(r));
return 1;
}
} }
} }
@ -676,26 +693,41 @@ static int luauF_lshift(lua_State* L, StkId res, TValue* arg0, int nresults, Stk
static int luauF_replace(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) static int luauF_replace(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{ {
if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) if (nparams >= (4 - FFlag::LuauFasterBit32NoWidth) && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1))
{ {
double a1 = nvalue(arg0); double a1 = nvalue(arg0);
double a2 = nvalue(args); double a2 = nvalue(args);
double a3 = nvalue(args + 1); double a3 = nvalue(args + 1);
double a4 = nvalue(args + 2);
unsigned n, v; unsigned n, v;
luai_num2unsigned(n, a1); luai_num2unsigned(n, a1);
luai_num2unsigned(v, a2); luai_num2unsigned(v, a2);
int f = int(a3); int f = int(a3);
int w = int(a4);
if (f >= 0 && w > 0 && f + w <= 32) if (nparams == 3)
{ {
uint32_t m = ~(0xfffffffeu << (w - 1)); if (unsigned(f) < 32)
uint32_t r = (n & ~(m << f)) | ((v & m) << f); {
uint32_t m = 1;
uint32_t r = (n & ~(m << f)) | ((v & m) << f);
setnvalue(res, double(r)); setnvalue(res, double(r));
return 1; return 1;
}
}
else if (ttisnumber(args + 2))
{
double a4 = nvalue(args + 2);
int w = int(a4);
if (f >= 0 && w > 0 && f + w <= 32)
{
uint32_t m = ~(0xfffffffeu << (w - 1));
uint32_t r = (n & ~(m << f)) | ((v & m) << f);
setnvalue(res, double(r));
return 1;
}
} }
} }
@ -1138,6 +1170,31 @@ static int luauF_rawlen(lua_State* L, StkId res, TValue* arg0, int nresults, Stk
return -1; return -1;
} }
static int luauF_extractk(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams)
{
// args is known to contain a number constant with packed in-range f/w
if (nparams >= 2 && nresults <= 1 && ttisnumber(arg0))
{
double a1 = nvalue(arg0);
double a2 = nvalue(args);
unsigned n;
luai_num2unsigned(n, a1);
int fw = int(a2);
int f = fw & 31;
int w1 = fw >> 5;
uint32_t m = ~(0xfffffffeu << w1);
uint32_t r = (n >> f) & m;
setnvalue(res, double(r));
return 1;
}
return -1;
}
luau_FastFunction luauF_table[256] = { luau_FastFunction luauF_table[256] = {
NULL, NULL,
luauF_assert, luauF_assert,
@ -1211,4 +1268,6 @@ luau_FastFunction luauF_table[256] = {
luauF_select, luauF_select,
luauF_rawlen, luauF_rawlen,
luauF_extractk,
}; };

View File

@ -175,8 +175,7 @@ void luaF_freeclosure(lua_State* L, Closure* c, lua_Page* page)
const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc) const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc)
{ {
int i; for (int i = 0; i < f->sizelocvars; i++)
for (i = 0; i < f->sizelocvars; i++)
{ {
if (pc >= f->locvars[i].startpc && pc < f->locvars[i].endpc) if (pc >= f->locvars[i].startpc && pc < f->locvars[i].endpc)
{ // is variable active? { // is variable active?
@ -185,5 +184,15 @@ const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc)
return &f->locvars[i]; return &f->locvars[i];
} }
} }
return NULL; // not found
}
const LocVar* luaF_findlocal(const Proto* f, int local_reg, int pc)
{
for (int i = 0; i < f->sizelocvars; i++)
if (local_reg == f->locvars[i].reg && pc >= f->locvars[i].startpc && pc < f->locvars[i].endpc)
return &f->locvars[i];
return NULL; // not found return NULL; // not found
} }

View File

@ -17,3 +17,4 @@ LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page)
LUAI_FUNC void luaF_unlinkupval(UpVal* uv); LUAI_FUNC void luaF_unlinkupval(UpVal* uv);
LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page);
LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc);
LUAI_FUNC const LocVar* luaF_findlocal(const Proto* func, int local_reg, int pc);

View File

@ -345,6 +345,9 @@ static void dumpclosure(FILE* f, Closure* cl)
if (cl->isC) if (cl->isC)
{ {
if (cl->c.debugname)
fprintf(f, ",\"name\":\"%s\"", cl->c.debugname + 0);
if (cl->nupvalues) if (cl->nupvalues)
{ {
fprintf(f, ",\"upvalues\":["); fprintf(f, ",\"upvalues\":[");
@ -354,6 +357,9 @@ static void dumpclosure(FILE* f, Closure* cl)
} }
else else
{ {
if (cl->l.p->debugname)
fprintf(f, ",\"name\":\"%s\"", getstr(cl->l.p->debugname));
fprintf(f, ",\"proto\":"); fprintf(f, ",\"proto\":");
dumpref(f, obj2gco(cl->l.p)); dumpref(f, obj2gco(cl->l.p));
if (cl->nupvalues) if (cl->nupvalues)
@ -403,7 +409,7 @@ static void dumpthread(FILE* f, lua_State* th)
fprintf(f, ",\"source\":\""); fprintf(f, ",\"source\":\"");
dumpstringdata(f, p->source->data, p->source->len); dumpstringdata(f, p->source->data, p->source->len);
fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); fprintf(f, "\",\"line\":%d", p->linedefined);
} }
if (th->top > th->stack) if (th->top > th->stack)
@ -411,6 +417,55 @@ static void dumpthread(FILE* f, lua_State* th)
fprintf(f, ",\"stack\":["); fprintf(f, ",\"stack\":[");
dumprefs(f, th->stack, th->top - th->stack); dumprefs(f, th->stack, th->top - th->stack);
fprintf(f, "]"); fprintf(f, "]");
CallInfo* ci = th->base_ci;
bool first = true;
fprintf(f, ",\"stacknames\":[");
for (StkId v = th->stack; v < th->top; ++v)
{
if (!iscollectable(v))
continue;
while (ci < th->ci && v >= (ci + 1)->func)
ci++;
if (!first)
fputc(',', f);
first = false;
if (v == ci->func)
{
Closure* cl = ci_func(ci);
if (cl->isC)
{
fprintf(f, "\"frame:%s\"", cl->c.debugname ? cl->c.debugname : "[C]");
}
else
{
Proto* p = cl->l.p;
fprintf(f, "\"frame:");
if (p->source)
dumpstringdata(f, p->source->data, p->source->len);
fprintf(f, ":%d:%s\"", p->linedefined, p->debugname ? getstr(p->debugname) : "");
}
}
else if (isLua(ci))
{
Proto* p = ci_func(ci)->l.p;
int pc = pcRel(ci->savedpc, p);
const LocVar* var = luaF_findlocal(p, int(v - ci->base), pc);
if (var && var->varname)
fprintf(f, "\"%s\"", getstr(var->varname));
else
fprintf(f, "null");
}
else
fprintf(f, "null");
}
fprintf(f, "]");
} }
fprintf(f, "}"); fprintf(f, "}");
} }

View File

@ -3189,4 +3189,27 @@ a.@1
CHECK(ac.entryMap.count("y")); CHECK(ac.entryMap.count("y"));
} }
TEST_CASE_FIXTURE(ACFixture, "globals_are_order_independent")
{
ScopedFastFlag sff("LuauAutocompleteFixGlobalOrder", true);
check(R"(
local myLocal = 4
function abc0()
local myInnerLocal = 1
@1
end
function abc1()
local myInnerLocal = 1
end
)");
auto ac = autocomplete('1');
CHECK(ac.entryMap.count("myLocal"));
CHECK(ac.entryMap.count("myInnerLocal"));
CHECK(ac.entryMap.count("abc0"));
CHECK(ac.entryMap.count("abc1"));
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -261,8 +261,6 @@ L1: RETURN R0 0
TEST_CASE("ForBytecode") TEST_CASE("ForBytecode")
{ {
ScopedFastFlag sff("LuauCompileNoIpairs", true);
// basic for loop: variable directly refers to internal iteration index (R2) // basic for loop: variable directly refers to internal iteration index (R2)
CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"( CHECK_EQ("\n" + compileFunction0("for i=1,5 do print(i) end"), R"(
LOADN R2 1 LOADN R2 1
@ -349,8 +347,6 @@ RETURN R0 0
TEST_CASE("ForBytecodeBuiltin") TEST_CASE("ForBytecodeBuiltin")
{ {
ScopedFastFlag sff("LuauCompileNoIpairs", true);
// we generally recognize builtins like pairs/ipairs and emit special opcodes // we generally recognize builtins like pairs/ipairs and emit special opcodes
CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"( CHECK_EQ("\n" + compileFunction0("for k,v in ipairs({}) do end"), R"(
GETIMPORT R0 1 GETIMPORT R0 1
@ -2065,6 +2061,69 @@ TEST_CASE("RecursionParse")
{ {
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile"); CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile");
} }
try
{
Luau::compileOrThrow(bcb, rep("a(", 1500) + "42" + rep(")", 1500));
CHECK(!"Expected exception");
}
catch (std::exception& e)
{
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile");
}
try
{
Luau::compileOrThrow(bcb, "return " + rep("{", 1500) + "42" + rep("}", 1500));
CHECK(!"Expected exception");
}
catch (std::exception& e)
{
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile");
}
try
{
Luau::compileOrThrow(bcb, rep("while true do ", 1500) + "print()" + rep(" end", 1500));
CHECK(!"Expected exception");
}
catch (std::exception& e)
{
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile");
}
try
{
Luau::compileOrThrow(bcb, rep("for i=1,1 do ", 1500) + "print()" + rep(" end", 1500));
CHECK(!"Expected exception");
}
catch (std::exception& e)
{
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your expression to make the code compile");
}
#if 0
// This currently requires too much stack space on MSVC/x64 and crashes with stack overflow at recursion depth 935
try
{
Luau::compileOrThrow(bcb, rep("function a() ", 1500) + "print()" + rep(" end", 1500));
CHECK(!"Expected exception");
}
catch (std::exception& e)
{
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile");
}
try
{
Luau::compileOrThrow(bcb, "return " + rep("function() return ", 1500) + "42" + rep(" end", 1500));
CHECK(!"Expected exception");
}
catch (std::exception& e)
{
CHECK_EQ(std::string(e.what()), "Exceeded allowed recursion depth; simplify your block to make the code compile");
}
#endif
} }
TEST_CASE("ArrayIndexLiteral") TEST_CASE("ArrayIndexLiteral")
@ -2111,8 +2170,6 @@ L1: RETURN R3 -1
TEST_CASE("UpvaluesLoopsBytecode") TEST_CASE("UpvaluesLoopsBytecode")
{ {
ScopedFastFlag sff("LuauCompileNoIpairs", true);
CHECK_EQ("\n" + compileFunction(R"( CHECK_EQ("\n" + compileFunction(R"(
function test() function test()
for i=1,10 do for i=1,10 do
@ -3790,8 +3847,6 @@ RETURN R0 1
TEST_CASE("SharedClosure") TEST_CASE("SharedClosure")
{ {
ScopedFastFlag sff("LuauCompileFreeReassign", true);
// closures can be shared even if functions refer to upvalues, as long as upvalues are top-level // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level
CHECK_EQ("\n" + compileFunction(R"( CHECK_EQ("\n" + compileFunction(R"(
local val = ... local val = ...
@ -6004,6 +6059,8 @@ return
math.clamp(-1, 0, 1), math.clamp(-1, 0, 1),
math.sign(77), math.sign(77),
math.round(7.6), math.round(7.6),
bit32.extract(-1, 31),
bit32.replace(100, 1, 0),
(type("fin")) (type("fin"))
)", )",
0, 2), 0, 2),
@ -6055,8 +6112,10 @@ LOADK R43 K2
LOADN R44 0 LOADN R44 0
LOADN R45 1 LOADN R45 1
LOADN R46 8 LOADN R46 8
LOADK R47 K3 LOADN R47 1
RETURN R0 48 LOADN R48 101
LOADK R49 K3
RETURN R0 50
)"); )");
} }
@ -6067,7 +6126,8 @@ return
math.abs(), math.abs(),
math.max(1, true), math.max(1, true),
string.byte("abc", 42), string.byte("abc", 42),
bit32.rshift(10, 42) bit32.rshift(10, 42),
bit32.extract(1, 2, "3")
)", )",
0, 2), 0, 2),
R"( R"(
@ -6088,8 +6148,14 @@ L2: LOADN R4 10
FASTCALL2K 39 R4 K7 L3 FASTCALL2K 39 R4 K7 L3
LOADK R5 K7 LOADK R5 K7
GETIMPORT R3 13 GETIMPORT R3 13
CALL R3 2 -1 CALL R3 2 1
L3: RETURN R0 -1 L3: LOADN R5 1
LOADN R6 2
LOADK R7 K14
FASTCALL 34 L4
GETIMPORT R4 16
CALL R4 3 -1
L4: RETURN R0 -1
)"); )");
} }
@ -6146,8 +6212,6 @@ RETURN R0 1
TEST_CASE("LocalReassign") TEST_CASE("LocalReassign")
{ {
ScopedFastFlag sff("LuauCompileFreeReassign", true);
// locals can be re-assigned and the register gets reused // locals can be re-assigned and the register gets reused
CHECK_EQ("\n" + compileFunction0(R"( CHECK_EQ("\n" + compileFunction0(R"(
local function test(a, b) local function test(a, b)
@ -6459,4 +6523,26 @@ RETURN R0 0
)"); )");
} }
TEST_CASE("BuiltinExtractK")
{
ScopedFastFlag sff("LuauCompileExtractK", true);
// below, K0 refers to a packed f+w constant for bit32.extractk builtin
// K1 and K2 refer to 1 and 3 and are only used during fallback path
CHECK_EQ("\n" + compileFunction0(R"(
local v = ...
return bit32.extract(v, 1, 3)
)"), R"(
GETVARARGS R0 1
FASTCALL2K 59 R0 K0 L0
MOVE R2 R0
LOADK R3 K1
LOADK R4 K2
GETIMPORT R1 5
CALL R1 3 -1
L0: RETURN R1 -1
)");
}
TEST_SUITE_END(); TEST_SUITE_END();

View File

@ -727,7 +727,7 @@ TEST_CASE("NewUserdataOverflow")
// The overflow might segfault in the following call. // The overflow might segfault in the following call.
lua_getmetatable(L1, -1); lua_getmetatable(L1, -1);
return 0; return 0;
}, "PCall"); }, nullptr);
CHECK(lua_pcall(L, 0, 0, 0) == LUA_ERRRUN); CHECK(lua_pcall(L, 0, 0, 0) == LUA_ERRRUN);
CHECK(strcmp(lua_tostring(L, -1), "memory allocation error: block too big") == 0); CHECK(strcmp(lua_tostring(L, -1), "memory allocation error: block too big") == 0);

View File

@ -443,7 +443,8 @@ BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete)
ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture() ConstraintGraphBuilderFixture::ConstraintGraphBuilderFixture()
: Fixture() : Fixture()
, cgb(mainModuleName, getMainModule(), &arena, NotNull(&ice), frontend.getGlobalScope()) , mainModule(new Module)
, cgb(mainModuleName, mainModule, &arena, NotNull(&ice), frontend.getGlobalScope())
, forceTheFlag{"DebugLuauDeferredConstraintResolution", true} , forceTheFlag{"DebugLuauDeferredConstraintResolution", true}
{ {
BlockedTypeVar::nextIndex = 0; BlockedTypeVar::nextIndex = 0;

View File

@ -162,6 +162,7 @@ struct BuiltinsFixture : Fixture
struct ConstraintGraphBuilderFixture : Fixture struct ConstraintGraphBuilderFixture : Fixture
{ {
TypeArena arena; TypeArena arena;
ModulePtr mainModule;
ConstraintGraphBuilder cgb; ConstraintGraphBuilder cgb;
ScopedFastFlag forceTheFlag; ScopedFastFlag forceTheFlag;

View File

@ -12,6 +12,11 @@ using namespace Luau;
struct NormalizeFixture : Fixture struct NormalizeFixture : Fixture
{ {
ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true}; ScopedFastFlag sff1{"LuauLowerBoundsCalculation", true};
bool isSubtype(TypeId a, TypeId b)
{
return ::Luau::isSubtype(a, b, NotNull{getMainModule()->getModuleScope().get()}, ice);
}
}; };
void createSomeClasses(TypeChecker& typeChecker) void createSomeClasses(TypeChecker& typeChecker)
@ -49,12 +54,6 @@ void createSomeClasses(TypeChecker& typeChecker)
freeze(arena); freeze(arena);
} }
static bool isSubtype(TypeId a, TypeId b)
{
InternalErrorReporter ice;
return isSubtype(a, b, ice);
}
TEST_SUITE_BEGIN("isSubtype"); TEST_SUITE_BEGIN("isSubtype");
TEST_CASE_FIXTURE(NormalizeFixture, "primitives") TEST_CASE_FIXTURE(NormalizeFixture, "primitives")
@ -511,6 +510,8 @@ TEST_CASE_FIXTURE(NormalizeFixture, "classes")
{ {
createSomeClasses(typeChecker); createSomeClasses(typeChecker);
check(""); // Ensure that we have a main Module.
TypeId p = typeChecker.globalScope->lookupType("Parent")->type; TypeId p = typeChecker.globalScope->lookupType("Parent")->type;
TypeId c = typeChecker.globalScope->lookupType("Child")->type; TypeId c = typeChecker.globalScope->lookupType("Child")->type;
TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type;
@ -595,6 +596,7 @@ TEST_CASE_FIXTURE(NormalizeFixture, "union_with_overlapping_field_that_has_a_sub
)"); )");
ModulePtr tempModule{new Module}; ModulePtr tempModule{new Module};
tempModule->scopes.emplace_back(Location(), std::make_shared<Scope>(getSingletonTypes().anyTypePack));
// HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze // HACK: Normalization is an in-place operation. We need to cheat a little here and unfreeze
// the arena that the type lives in. // the arena that the type lives in.
@ -880,7 +882,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect
{ {
ScopedFastFlag flags[] = { ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true}, {"LuauLowerBoundsCalculation", true},
{"LuauQuantifyConstrained", true},
}; };
// We use a function and inferred parameter types to prevent intermediate normalizations from being performed. // We use a function and inferred parameter types to prevent intermediate normalizations from being performed.
@ -921,7 +922,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect
{ {
ScopedFastFlag flags[] = { ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true}, {"LuauLowerBoundsCalculation", true},
{"LuauQuantifyConstrained", true},
}; };
// We use a function and inferred parameter types to prevent intermediate normalizations from being performed. // We use a function and inferred parameter types to prevent intermediate normalizations from being performed.
@ -961,7 +961,6 @@ TEST_CASE_FIXTURE(Fixture, "intersection_inside_a_table_inside_another_intersect
{ {
ScopedFastFlag flags[] = { ScopedFastFlag flags[] = {
{"LuauLowerBoundsCalculation", true}, {"LuauLowerBoundsCalculation", true},
{"LuauQuantifyConstrained", true},
}; };
// We use a function and inferred parameter types to prevent intermediate normalizations from being performed. // We use a function and inferred parameter types to prevent intermediate normalizations from being performed.
@ -1149,7 +1148,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "normalization_does_not_convert_ever")
{ {
ScopedFastFlag sff[]{ ScopedFastFlag sff[]{
{"LuauLowerBoundsCalculation", true}, {"LuauLowerBoundsCalculation", true},
{"LuauQuantifyConstrained", true},
}; };
CheckResult result = check(R"( CheckResult result = check(R"(

View File

@ -2535,7 +2535,6 @@ end
TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation") TEST_CASE_FIXTURE(Fixture, "error_message_for_using_function_as_type_annotation")
{ {
ScopedFastFlag sff{"LuauParserFunctionKeywordAsTypeHelp", true};
ParseResult result = tryParse(R"( ParseResult result = tryParse(R"(
type Foo = function type Foo = function
)"); )");

View File

@ -1637,7 +1637,6 @@ TEST_CASE_FIXTURE(Fixture, "quantify_constrained_types")
{ {
ScopedFastFlag sff[]{ ScopedFastFlag sff[]{
{"LuauLowerBoundsCalculation", true}, {"LuauLowerBoundsCalculation", true},
{"LuauQuantifyConstrained", true},
}; };
CheckResult result = check(R"( CheckResult result = check(R"(
@ -1662,7 +1661,6 @@ TEST_CASE_FIXTURE(Fixture, "call_o_with_another_argument_after_foo_was_quantifie
{ {
ScopedFastFlag sff[]{ ScopedFastFlag sff[]{
{"LuauLowerBoundsCalculation", true}, {"LuauLowerBoundsCalculation", true},
{"LuauQuantifyConstrained", true},
}; };
CheckResult result = check(R"( CheckResult result = check(R"(

View File

@ -329,6 +329,35 @@ function tbl:foo(b: number, c: number)
-- introduce BoundTypeVar to imported type -- introduce BoundTypeVar to imported type
arrayops.foo(self._regions) arrayops.foo(self._regions)
end end
-- this alias decreases function type level and causes a demotion of its type
type Table = typeof(tbl)
)");
LUAU_REQUIRE_NO_ERRORS(result);
}
TEST_CASE_FIXTURE(BuiltinsFixture, "do_not_modify_imported_types_5")
{
ScopedFastFlag luauInplaceDemoteSkipAllBound{"LuauInplaceDemoteSkipAllBound", true};
fileResolver.source["game/A"] = R"(
export type Type = {x: number, y: number}
local arrayops = {}
function arrayops.foo(x: Type) end
return arrayops
)";
CheckResult result = check(R"(
local arrayops = require(game.A)
local tbl = {}
tbl.a = 2
function tbl:foo(b: number, c: number)
-- introduce boundTo TableTypeVar to imported type
self.x.a = 2
arrayops.foo(self.x)
end
-- this alias decreases function type level and causes a demotion of its type
type Table = typeof(tbl) type Table = typeof(tbl)
)"); )");

View File

@ -485,7 +485,6 @@ TEST_CASE_FIXTURE(Fixture, "constrained_is_level_dependent")
{ {
ScopedFastFlag sff[]{ ScopedFastFlag sff[]{
{"LuauLowerBoundsCalculation", true}, {"LuauLowerBoundsCalculation", true},
{"LuauQuantifyConstrained", true},
}; };
CheckResult result = check(R"( CheckResult result = check(R"(

View File

@ -41,6 +41,7 @@ struct RefinementClassFixture : Fixture
RefinementClassFixture() RefinementClassFixture()
{ {
TypeArena& arena = typeChecker.globalTypes; TypeArena& arena = typeChecker.globalTypes;
NotNull<Scope> scope{typeChecker.globalScope.get()};
unfreeze(arena); unfreeze(arena);
TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"});
@ -49,7 +50,7 @@ struct RefinementClassFixture : Fixture
{"Y", Property{typeChecker.numberType}}, {"Y", Property{typeChecker.numberType}},
{"Z", Property{typeChecker.numberType}}, {"Z", Property{typeChecker.numberType}},
}; };
normalize(vec3, arena, *typeChecker.iceHandler); normalize(vec3, scope, arena, *typeChecker.iceHandler);
TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"}); TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr, "Test"});
@ -57,21 +58,21 @@ struct RefinementClassFixture : Fixture
TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); TypePackId isARets = arena.addTypePack({typeChecker.booleanType});
TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets});
getMutable<FunctionTypeVar>(isA)->magicFunction = magicFunctionInstanceIsA; getMutable<FunctionTypeVar>(isA)->magicFunction = magicFunctionInstanceIsA;
normalize(isA, arena, *typeChecker.iceHandler); normalize(isA, scope, arena, *typeChecker.iceHandler);
getMutable<ClassTypeVar>(inst)->props = { getMutable<ClassTypeVar>(inst)->props = {
{"Name", Property{typeChecker.stringType}}, {"Name", Property{typeChecker.stringType}},
{"IsA", Property{isA}}, {"IsA", Property{isA}},
}; };
normalize(inst, arena, *typeChecker.iceHandler); normalize(inst, scope, arena, *typeChecker.iceHandler);
TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"});
normalize(folder, arena, *typeChecker.iceHandler); normalize(folder, scope, arena, *typeChecker.iceHandler);
TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"});
getMutable<ClassTypeVar>(part)->props = { getMutable<ClassTypeVar>(part)->props = {
{"Position", Property{vec3}}, {"Position", Property{vec3}},
}; };
normalize(part, arena, *typeChecker.iceHandler); normalize(part, scope, arena, *typeChecker.iceHandler);
typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3};
typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst};
@ -934,8 +935,6 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip
TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x")
{ {
ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true};
CheckResult result = check(R"( CheckResult result = check(R"(
type T = {tag: "missing", x: nil} | {tag: "exists", x: string} type T = {tag: "missing", x: nil} | {tag: "exists", x: string}
@ -1230,8 +1229,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "refine_unknowns")
TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil") TEST_CASE_FIXTURE(BuiltinsFixture, "falsiness_of_TruthyPredicate_narrows_into_nil")
{ {
ScopedFastFlag sff{"LuauFalsyPredicateReturnsNilInstead", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local function f(t: {number}) local function f(t: {number})
local x = t[1] local x = t[1]

View File

@ -3003,8 +3003,6 @@ TEST_CASE_FIXTURE(Fixture, "expected_indexer_from_table_union")
TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches") TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches")
{ {
ScopedFastFlag sff{"LuauReportErrorsOnIndexerKeyMismatch", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local t: {number} = {} local t: {number} = {}
local x = t.x local x = t.x
@ -3016,8 +3014,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_key_whose_types_mismatches")
TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_types_mismatches") TEST_CASE_FIXTURE(Fixture, "prop_access_on_unions_of_indexers_where_key_whose_types_mismatches")
{ {
ScopedFastFlag sff{"LuauReportErrorsOnIndexerKeyMismatch", true};
CheckResult result = check(R"( CheckResult result = check(R"(
local t: { [number]: number } | { [boolean]: number } = {} local t: { [number]: number } | { [boolean]: number } = {}
local u = t.x local u = t.x

View File

@ -17,7 +17,8 @@ struct TryUnifyFixture : Fixture
ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}};
InternalErrorReporter iceHandler; InternalErrorReporter iceHandler;
UnifierSharedState unifierState{&iceHandler}; UnifierSharedState unifierState{&iceHandler};
Unifier state{&arena, Mode::Strict, Location{}, Variance::Covariant, unifierState};
Unifier state{&arena, Mode::Strict, NotNull{globalScope.get()}, Location{}, Variance::Covariant, unifierState};
}; };
TEST_SUITE_BEGIN("TryUnifyTests"); TEST_SUITE_BEGIN("TryUnifyTests");

View File

@ -193,6 +193,11 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_only_cyclic_union")
CHECK(actual.empty()); CHECK(actual.empty());
} }
/* FIXME: This test is pretty weird. It would be much nicer if we could
* perform this operation without a TypeChecker so that we don't have to jam
* all this state into it to make stuff work.
*/
TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure")
{ {
TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; TypeVar ftv11{FreeTypeVar{TypeLevel{}}};
@ -268,6 +273,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure")
TypeId root = &ttvTweenResult; TypeId root = &ttvTweenResult;
typeChecker.currentModule = std::make_shared<Module>(); typeChecker.currentModule = std::make_shared<Module>();
typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared<Scope>(getSingletonTypes().anyTypePack));
TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{});

View File

@ -100,6 +100,9 @@ assert(bit32.extract(0xa0001111, 28, 4) == 0xa)
assert(bit32.extract(0xa0001111, 31, 1) == 1) assert(bit32.extract(0xa0001111, 31, 1) == 1)
assert(bit32.extract(0x50000111, 31, 1) == 0) assert(bit32.extract(0x50000111, 31, 1) == 0)
assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679) assert(bit32.extract(0xf2345679, 0, 32) == 0xf2345679)
assert(bit32.extract(0xa0001111, 16) == 0)
assert(bit32.extract(0xa0001111, 31) == 1)
assert(bit32.extract(42, 1, 3) == 5)
assert(not pcall(bit32.extract, 0, -1)) assert(not pcall(bit32.extract, 0, -1))
assert(not pcall(bit32.extract, 0, 32)) assert(not pcall(bit32.extract, 0, 32))
@ -152,5 +155,6 @@ assert(bit32.btest(1, "3") == true)
assert(bit32.btest("1", 3) == true) assert(bit32.btest("1", 3) == true)
assert(bit32.countlz("42") == 26) assert(bit32.countlz("42") == 26)
assert(bit32.countrz("42") == 1) assert(bit32.countrz("42") == 1)
assert(bit32.extract("42", 1, 3) == 5)
return('OK') return('OK')

View File

@ -10,12 +10,10 @@ AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag
AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag_handler AnnotationTests.luau_ice_triggers_an_ice_exception_with_flag_handler
AnnotationTests.luau_ice_triggers_an_ice_handler AnnotationTests.luau_ice_triggers_an_ice_handler
AnnotationTests.luau_print_is_magic_if_the_flag_is_set AnnotationTests.luau_print_is_magic_if_the_flag_is_set
AnnotationTests.luau_print_is_not_special_without_the_flag
AnnotationTests.occurs_check_on_cyclic_intersection_typevar AnnotationTests.occurs_check_on_cyclic_intersection_typevar
AnnotationTests.occurs_check_on_cyclic_union_typevar AnnotationTests.occurs_check_on_cyclic_union_typevar
AnnotationTests.too_many_type_params AnnotationTests.too_many_type_params
AnnotationTests.two_type_params AnnotationTests.two_type_params
AnnotationTests.unknown_type_reference_generates_error
AnnotationTests.use_type_required_from_another_file AnnotationTests.use_type_required_from_another_file
AstQuery.last_argument_function_call_type AstQuery.last_argument_function_call_type
AstQuery::getDocumentationSymbolAtPosition.overloaded_fn AstQuery::getDocumentationSymbolAtPosition.overloaded_fn
@ -27,17 +25,13 @@ AutocompleteTest.autocomplete_end_with_lambda
AutocompleteTest.autocomplete_first_function_arg_expected_type AutocompleteTest.autocomplete_first_function_arg_expected_type
AutocompleteTest.autocomplete_for_in_middle_keywords AutocompleteTest.autocomplete_for_in_middle_keywords
AutocompleteTest.autocomplete_for_middle_keywords AutocompleteTest.autocomplete_for_middle_keywords
AutocompleteTest.autocomplete_if_else_regression
AutocompleteTest.autocomplete_if_middle_keywords AutocompleteTest.autocomplete_if_middle_keywords
AutocompleteTest.autocomplete_ifelse_expressions
AutocompleteTest.autocomplete_on_string_singletons AutocompleteTest.autocomplete_on_string_singletons
AutocompleteTest.autocomplete_oop_implicit_self AutocompleteTest.autocomplete_oop_implicit_self
AutocompleteTest.autocomplete_repeat_middle_keyword AutocompleteTest.autocomplete_repeat_middle_keyword
AutocompleteTest.autocomplete_string_singleton_equality AutocompleteTest.autocomplete_string_singleton_equality
AutocompleteTest.autocomplete_string_singleton_escape AutocompleteTest.autocomplete_string_singleton_escape
AutocompleteTest.autocomplete_string_singletons AutocompleteTest.autocomplete_string_singletons
AutocompleteTest.autocomplete_until_expression
AutocompleteTest.autocomplete_until_in_repeat
AutocompleteTest.autocomplete_while_middle_keywords AutocompleteTest.autocomplete_while_middle_keywords
AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic AutocompleteTest.autocompleteProp_index_function_metamethod_is_variadic
AutocompleteTest.bias_toward_inner_scope AutocompleteTest.bias_toward_inner_scope
@ -60,7 +54,6 @@ AutocompleteTest.get_suggestions_for_the_very_start_of_the_script
AutocompleteTest.global_function_params AutocompleteTest.global_function_params
AutocompleteTest.global_functions_are_not_scoped_lexically AutocompleteTest.global_functions_are_not_scoped_lexically
AutocompleteTest.if_then_else_elseif_completions AutocompleteTest.if_then_else_elseif_completions
AutocompleteTest.if_then_else_full_keywords
AutocompleteTest.keyword_methods AutocompleteTest.keyword_methods
AutocompleteTest.keyword_types AutocompleteTest.keyword_types
AutocompleteTest.library_non_self_calls_are_fine AutocompleteTest.library_non_self_calls_are_fine
@ -181,7 +174,6 @@ DefinitionTests.single_class_type_identity_in_global_types
FrontendTest.ast_node_at_position FrontendTest.ast_node_at_position
FrontendTest.automatically_check_dependent_scripts FrontendTest.automatically_check_dependent_scripts
FrontendTest.check_without_builtin_next FrontendTest.check_without_builtin_next
FrontendTest.clearStats
FrontendTest.dont_reparse_clean_file_when_linting FrontendTest.dont_reparse_clean_file_when_linting
FrontendTest.environments FrontendTest.environments
FrontendTest.imported_table_modification_2 FrontendTest.imported_table_modification_2
@ -195,7 +187,6 @@ FrontendTest.reexport_cyclic_type
FrontendTest.reexport_type_alias FrontendTest.reexport_type_alias
FrontendTest.report_require_to_nonexistent_file FrontendTest.report_require_to_nonexistent_file
FrontendTest.report_syntax_error_in_required_file FrontendTest.report_syntax_error_in_required_file
FrontendTest.stats_are_not_reset_between_checks
FrontendTest.trace_requires_in_nonstrict_mode FrontendTest.trace_requires_in_nonstrict_mode
GenericsTests.apply_type_function_nested_generics1 GenericsTests.apply_type_function_nested_generics1
GenericsTests.apply_type_function_nested_generics2 GenericsTests.apply_type_function_nested_generics2
@ -213,8 +204,6 @@ GenericsTests.duplicate_generic_types
GenericsTests.error_detailed_function_mismatch_generic_pack GenericsTests.error_detailed_function_mismatch_generic_pack
GenericsTests.error_detailed_function_mismatch_generic_types GenericsTests.error_detailed_function_mismatch_generic_types
GenericsTests.factories_of_generics GenericsTests.factories_of_generics
GenericsTests.function_arguments_can_be_polytypes
GenericsTests.function_results_can_be_polytypes
GenericsTests.generic_argument_count_too_few GenericsTests.generic_argument_count_too_few
GenericsTests.generic_argument_count_too_many GenericsTests.generic_argument_count_too_many
GenericsTests.generic_factories GenericsTests.generic_factories
@ -243,7 +232,6 @@ GenericsTests.properties_can_be_instantiated_polytypes
GenericsTests.rank_N_types_via_typeof GenericsTests.rank_N_types_via_typeof
GenericsTests.reject_clashing_generic_and_pack_names GenericsTests.reject_clashing_generic_and_pack_names
GenericsTests.self_recursive_instantiated_param GenericsTests.self_recursive_instantiated_param
GenericsTests.variadic_generics
IntersectionTypes.argument_is_intersection IntersectionTypes.argument_is_intersection
IntersectionTypes.error_detailed_intersection_all IntersectionTypes.error_detailed_intersection_all
IntersectionTypes.error_detailed_intersection_part IntersectionTypes.error_detailed_intersection_part
@ -289,7 +277,6 @@ Normalize.cyclic_intersection
Normalize.cyclic_table_normalizes_sensibly Normalize.cyclic_table_normalizes_sensibly
Normalize.cyclic_union Normalize.cyclic_union
Normalize.fuzz_failure_bound_type_is_normal_but_not_its_bounded_to Normalize.fuzz_failure_bound_type_is_normal_but_not_its_bounded_to
Normalize.higher_order_function
Normalize.intersection_combine_on_bound_self Normalize.intersection_combine_on_bound_self
Normalize.intersection_inside_a_table_inside_another_intersection Normalize.intersection_inside_a_table_inside_another_intersection
Normalize.intersection_inside_a_table_inside_another_intersection_2 Normalize.intersection_inside_a_table_inside_another_intersection_2
@ -304,7 +291,6 @@ Normalize.normalization_does_not_convert_ever
Normalize.normalize_module_return_type Normalize.normalize_module_return_type
Normalize.normalize_unions_containing_never Normalize.normalize_unions_containing_never
Normalize.normalize_unions_containing_unknown Normalize.normalize_unions_containing_unknown
Normalize.return_type_is_not_a_constrained_intersection
Normalize.union_of_distinct_free_types Normalize.union_of_distinct_free_types
Normalize.variadic_tail_is_marked_normal Normalize.variadic_tail_is_marked_normal
Normalize.visiting_a_type_twice_is_not_considered_normal Normalize.visiting_a_type_twice_is_not_considered_normal
@ -456,7 +442,6 @@ TableTests.inferred_return_type_of_free_table
TableTests.inferring_crazy_table_should_also_be_quick TableTests.inferring_crazy_table_should_also_be_quick
TableTests.instantiate_table_cloning_3 TableTests.instantiate_table_cloning_3
TableTests.instantiate_tables_at_scope_level TableTests.instantiate_tables_at_scope_level
TableTests.invariant_table_properties_means_instantiating_tables_in_assignment_is_unsound
TableTests.leaking_bad_metatable_errors TableTests.leaking_bad_metatable_errors
TableTests.length_operator_intersection TableTests.length_operator_intersection
TableTests.length_operator_non_table_union TableTests.length_operator_non_table_union
@ -521,7 +506,6 @@ ToDot.function
ToDot.metatable ToDot.metatable
ToDot.table ToDot.table
ToString.exhaustive_toString_of_cyclic_table ToString.exhaustive_toString_of_cyclic_table
ToString.function_type_with_argument_names_and_self
ToString.function_type_with_argument_names_generic ToString.function_type_with_argument_names_generic
ToString.no_parentheses_around_cyclic_function_type_in_union ToString.no_parentheses_around_cyclic_function_type_in_union
ToString.toStringDetailed2 ToString.toStringDetailed2
@ -565,10 +549,7 @@ TypeInfer.do_not_bind_a_free_table_to_a_union_containing_that_table
TypeInfer.dont_report_type_errors_within_an_AstStatError TypeInfer.dont_report_type_errors_within_an_AstStatError
TypeInfer.globals TypeInfer.globals
TypeInfer.globals2 TypeInfer.globals2
TypeInfer.infer_assignment_value_types
TypeInfer.infer_assignment_value_types_mutable_lval TypeInfer.infer_assignment_value_types_mutable_lval
TypeInfer.infer_through_group_expr
TypeInfer.no_heap_use_after_free_error
TypeInfer.no_stack_overflow_from_isoptional TypeInfer.no_stack_overflow_from_isoptional
TypeInfer.tc_after_error_recovery_no_replacement_name_in_error TypeInfer.tc_after_error_recovery_no_replacement_name_in_error
TypeInfer.tc_if_else_expressions1 TypeInfer.tc_if_else_expressions1
@ -589,11 +570,8 @@ TypeInferAnyError.for_in_loop_iterator_returns_any
TypeInferAnyError.for_in_loop_iterator_returns_any2 TypeInferAnyError.for_in_loop_iterator_returns_any2
TypeInferAnyError.length_of_error_type_does_not_produce_an_error TypeInferAnyError.length_of_error_type_does_not_produce_an_error
TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any TypeInferAnyError.replace_every_free_type_when_unifying_a_complex_function_with_any
TypeInferAnyError.type_error_addition
TypeInferClasses.call_base_method TypeInferClasses.call_base_method
TypeInferClasses.call_instance_method TypeInferClasses.call_instance_method
TypeInferClasses.can_read_prop_of_base_class
TypeInferClasses.can_read_prop_of_base_class_using_string
TypeInferClasses.class_type_mismatch_with_name_conflict TypeInferClasses.class_type_mismatch_with_name_conflict
TypeInferClasses.classes_can_have_overloaded_operators TypeInferClasses.classes_can_have_overloaded_operators
TypeInferClasses.classes_without_overloaded_operators_cannot_be_added TypeInferClasses.classes_without_overloaded_operators_cannot_be_added
@ -609,8 +587,6 @@ TypeInferFunctions.another_indirect_function_case_where_it_is_ok_to_provide_too_
TypeInferFunctions.another_recursive_local_function TypeInferFunctions.another_recursive_local_function
TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types TypeInferFunctions.calling_function_with_anytypepack_doesnt_leak_free_types
TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument TypeInferFunctions.calling_function_with_incorrect_argument_type_yields_errors_spanning_argument
TypeInferFunctions.cannot_hoist_interior_defns_into_signature
TypeInferFunctions.check_function_before_lambda_that_uses_it
TypeInferFunctions.complicated_return_types_require_an_explicit_annotation TypeInferFunctions.complicated_return_types_require_an_explicit_annotation
TypeInferFunctions.cyclic_function_type_in_args TypeInferFunctions.cyclic_function_type_in_args
TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists TypeInferFunctions.dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists
@ -635,7 +611,6 @@ TypeInferFunctions.ignored_return_values
TypeInferFunctions.inconsistent_higher_order_function TypeInferFunctions.inconsistent_higher_order_function
TypeInferFunctions.inconsistent_return_types TypeInferFunctions.inconsistent_return_types
TypeInferFunctions.infer_anonymous_function_arguments TypeInferFunctions.infer_anonymous_function_arguments
TypeInferFunctions.infer_anonymous_function_arguments_outside_call
TypeInferFunctions.infer_return_type_from_selected_overload TypeInferFunctions.infer_return_type_from_selected_overload
TypeInferFunctions.infer_that_function_does_not_return_a_table TypeInferFunctions.infer_that_function_does_not_return_a_table
TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals TypeInferFunctions.it_is_ok_not_to_supply_enough_retvals
@ -680,9 +655,6 @@ TypeInferLoops.loop_iter_no_indexer_strict
TypeInferLoops.loop_iter_trailing_nil TypeInferLoops.loop_iter_trailing_nil
TypeInferLoops.loop_typecheck_crash_on_empty_optional TypeInferLoops.loop_typecheck_crash_on_empty_optional
TypeInferLoops.properly_infer_iteratee_is_a_free_table TypeInferLoops.properly_infer_iteratee_is_a_free_table
TypeInferLoops.repeat_loop
TypeInferLoops.repeat_loop_condition_binds_to_its_block
TypeInferLoops.symbols_in_repeat_block_should_not_be_visible_beyond_until_condition
TypeInferLoops.unreachable_code_after_infinite_loop TypeInferLoops.unreachable_code_after_infinite_loop
TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free TypeInferLoops.varlist_declared_by_for_in_loop_should_be_free
TypeInferModules.do_not_modify_imported_types TypeInferModules.do_not_modify_imported_types
@ -718,8 +690,6 @@ TypeInferOperators.cannot_indirectly_compare_types_that_do_not_offer_overloaded_
TypeInferOperators.cli_38355_recursive_union TypeInferOperators.cli_38355_recursive_union
TypeInferOperators.compare_numbers TypeInferOperators.compare_numbers
TypeInferOperators.compare_strings TypeInferOperators.compare_strings
TypeInferOperators.compound_assign_basic
TypeInferOperators.compound_assign_metatable
TypeInferOperators.compound_assign_mismatch_metatable TypeInferOperators.compound_assign_mismatch_metatable
TypeInferOperators.compound_assign_mismatch_op TypeInferOperators.compound_assign_mismatch_op
TypeInferOperators.compound_assign_mismatch_result TypeInferOperators.compound_assign_mismatch_result
@ -775,7 +745,6 @@ TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable
TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2 TypeInferUnknownNever.type_packs_containing_never_is_itself_uninhabitable2
TypeInferUnknownNever.unary_minus_of_never TypeInferUnknownNever.unary_minus_of_never
TypeInferUnknownNever.unknown_is_reflexive TypeInferUnknownNever.unknown_is_reflexive
TypePackTests.cyclic_type_packs
TypePackTests.higher_order_function TypePackTests.higher_order_function
TypePackTests.multiple_varargs_inference_are_not_confused TypePackTests.multiple_varargs_inference_are_not_confused
TypePackTests.no_return_size_should_be_zero TypePackTests.no_return_size_should_be_zero

View File

@ -132,8 +132,16 @@ while offset < len(queue):
queue.append((obj["metatable"], node.child("__meta"))) queue.append((obj["metatable"], node.child("__meta")))
elif obj["type"] == "thread": elif obj["type"] == "thread":
queue.append((obj["env"], node.child("__env"))) queue.append((obj["env"], node.child("__env")))
for a in obj.get("stack", []): stack = obj.get("stack")
queue.append((a, node.child("__stack"))) stacknames = obj.get("stacknames", [])
stacknode = node.child("__stack")
framenode = None
for i in range(len(stack)):
name = stacknames[i] if stacknames else None
if name and name.startswith("frame:"):
framenode = stacknode.child(name[6:])
name = None
queue.append((stack[i], framenode.child(name) if framenode and name else framenode or stacknode))
elif obj["type"] == "proto": elif obj["type"] == "proto":
for a in obj.get("constants", []): for a in obj.get("constants", []):
queue.append((a, node)) queue.append((a, node))

View File

@ -59,5 +59,6 @@ if len(size_category) != 0:
print("objects by category:") print("objects by category:")
for type, (count, size) in sortedsize(size_category.items()): for type, (count, size) in sortedsize(size_category.items()):
name = dump["stats"]["categories"][type]["name"] cat = dump["stats"]["categories"][type]
name = cat["name"] if "name" in cat else str(type)
print(name.ljust(30), str(size).rjust(8), "bytes", str(count).rjust(5), "objects") print(name.ljust(30), str(size).rjust(8), "bytes", str(count).rjust(5), "objects")

View File

@ -56,18 +56,28 @@ def nodeFromCallstackListFile(source_file):
return root return root
def getDuration(obj):
total = obj['TotalDuration']
def nodeFromJSONbject(node, key, obj): if 'Children' in obj:
for key, obj in obj['Children'].items():
total -= obj['TotalDuration']
return total
def nodeFromJSONObject(node, key, obj):
source, function, line = key.split(",") source, function, line = key.split(",")
node.function = function node.function = function
node.source = source node.source = source
node.line = int(line) if len(line) > 0 else 0 node.line = int(line) if len(line) > 0 else 0
node.ticks = obj['Duration'] node.ticks = getDuration(obj)
for key, obj in obj['Children'].items(): if 'Children' in obj:
nodeFromJSONbject(node.child(key), key, obj) for key, obj in obj['Children'].items():
nodeFromJSONObject(node.child(key), key, obj)
return node return node
@ -77,8 +87,9 @@ def nodeFromJSONFile(source_file):
root = Node() root = Node()
for key, obj in dump['Children'].items(): if 'Children' in dump:
nodeFromJSONbject(root.child(key), key, obj) for key, obj in dump['Children'].items():
nodeFromJSONObject(root.child(key), key, obj)
return root return root