From 1fa8311a188a932233d2ce2bc6412e5bab37c035 Mon Sep 17 00:00:00 2001 From: Andy Friesen Date: Fri, 10 Mar 2023 12:21:07 -0800 Subject: [PATCH] Sync to upstream/release/567 (#860) * Fix #817 * Fix #850 * Optimize math.floor/ceil/round with SSE4.1 * Results in a ~7-9% speedup on the math-cordic benchmark. * Optimized table.sort. * table.sort is now ~4.1x faster (when not using a predicate) and ~2.1x faster when using a simple predicate. Performance may improve further in the future. * Reorganize the memory ownership of builtin type definitions. * This is a small initial step toward affording parallel typechecking. The new type solver is coming along nicely. We are working on fixing crashes and bugs. A few major changes to native codegen landed this week: * Fixed lowering of Luau IR mod instruction when first argument is a constant * Added VM register data-flow/capture analysis * Fixed issues with optimizations in unreachable blocks --------- Co-authored-by: Arseny Kapoulkine Co-authored-by: Vyacheslav Egorov --- Analysis/include/Luau/BuiltinDefinitions.h | 27 +- Analysis/include/Luau/Error.h | 4 +- Analysis/include/Luau/Frontend.h | 24 +- Analysis/include/Luau/Quantify.h | 2 +- Analysis/include/Luau/Type.h | 8 +- Analysis/include/Luau/TypeInfer.h | 18 +- Analysis/include/Luau/Unifier.h | 5 + Analysis/src/AstQuery.cpp | 14 +- Analysis/src/Autocomplete.cpp | 49 +- Analysis/src/BuiltinDefinitions.cpp | 197 ++++---- Analysis/src/ConstraintGraphBuilder.cpp | 4 +- Analysis/src/ConstraintSolver.cpp | 34 +- Analysis/src/Error.cpp | 4 +- Analysis/src/Frontend.cpp | 146 ++++-- Analysis/src/Quantify.cpp | 5 +- Analysis/src/TxnLog.cpp | 1 + Analysis/src/Type.cpp | 8 +- Analysis/src/TypeChecker2.cpp | 73 ++- Analysis/src/TypeInfer.cpp | 51 +- Analysis/src/TypeReduction.cpp | 44 +- Analysis/src/Unifier.cpp | 158 ++++--- CLI/Analyze.cpp | 5 +- CodeGen/include/Luau/IrAnalysis.h | 56 +++ CodeGen/include/Luau/IrData.h | 12 +- CodeGen/include/Luau/IrDump.h | 9 +- CodeGen/include/Luau/IrUtils.h | 3 - CodeGen/src/CodeGen.cpp | 23 +- CodeGen/src/IrAnalysis.cpp | 517 +++++++++++++++++++++ CodeGen/src/IrBuilder.cpp | 2 +- CodeGen/src/IrDump.cpp | 163 +++++-- CodeGen/src/IrLoweringX64.cpp | 30 +- CodeGen/src/IrRegAllocX64.cpp | 22 +- CodeGen/src/IrRegAllocX64.h | 2 + CodeGen/src/IrUtils.cpp | 52 ++- Common/include/Luau/ExperimentalFlags.h | 1 - VM/include/luaconf.h | 12 + VM/src/lbuiltins.cpp | 105 +++++ VM/src/ldebug.cpp | 17 +- VM/src/lmathlib.cpp | 3 +- VM/src/ltablib.cpp | 154 +++++- VM/src/lvmexecute.cpp | 11 +- VM/src/lvmutils.cpp | 20 +- fuzz/proto.cpp | 34 +- tests/AstQuery.test.cpp | 13 + tests/Autocomplete.test.cpp | 128 ++--- tests/BuiltinDefinitions.test.cpp | 4 +- tests/ClassFixture.cpp | 41 +- tests/Conformance.test.cpp | 9 +- tests/ConstraintGraphBuilderFixture.cpp | 2 +- tests/Fixture.cpp | 61 +-- tests/Fixture.h | 1 - tests/Frontend.test.cpp | 12 +- tests/IrBuilder.test.cpp | 410 +++++++++++++--- tests/Linter.test.cpp | 49 +- tests/Module.test.cpp | 33 +- tests/NonstrictMode.test.cpp | 10 +- tests/Normalize.test.cpp | 6 +- tests/ToDot.test.cpp | 14 +- tests/ToString.test.cpp | 32 +- tests/TypeInfer.aliases.test.cpp | 33 +- tests/TypeInfer.annotations.test.cpp | 19 +- tests/TypeInfer.anyerror.test.cpp | 4 +- tests/TypeInfer.builtins.test.cpp | 32 +- tests/TypeInfer.definitions.test.cpp | 76 +-- tests/TypeInfer.functions.test.cpp | 55 ++- tests/TypeInfer.generics.test.cpp | 4 +- tests/TypeInfer.intersectionTypes.test.cpp | 14 +- tests/TypeInfer.loops.test.cpp | 26 +- tests/TypeInfer.modules.test.cpp | 39 ++ tests/TypeInfer.oop.test.cpp | 19 + tests/TypeInfer.operators.test.cpp | 44 +- tests/TypeInfer.primitives.test.cpp | 4 +- tests/TypeInfer.refinements.test.cpp | 39 +- tests/TypeInfer.tables.test.cpp | 78 ++-- tests/TypeInfer.test.cpp | 41 +- tests/TypeInfer.tryUnify.test.cpp | 43 +- tests/TypeInfer.typePacks.cpp | 32 +- tests/TypeInfer.unionTypes.test.cpp | 10 +- tests/TypeReduction.test.cpp | 18 + tests/TypeVar.test.cpp | 66 +-- tests/conformance/basic.lua | 2 + tests/conformance/math.lua | 2 + tests/conformance/sort.lua | 45 +- tools/faillist.txt | 13 +- 84 files changed, 2635 insertions(+), 1077 deletions(-) diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 0604b40..1621395 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -10,12 +10,13 @@ namespace Luau { struct Frontend; +struct GlobalTypes; struct TypeChecker; struct TypeArena; -void registerBuiltinTypes(Frontend& frontend); +void registerBuiltinTypes(GlobalTypes& globals); -void registerBuiltinGlobals(TypeChecker& typeChecker); +void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals); void registerBuiltinGlobals(Frontend& frontend); TypeId makeUnion(TypeArena& arena, std::vector&& types); @@ -23,8 +24,7 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types); /** Build an optional 't' */ -TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t); -TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t); +TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t); /** Small utility function for building up type definitions from C++. */ @@ -52,17 +52,12 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& std::string getBuiltinDefinitionSource(); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding); -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding); -void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding); -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding); -std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name); -Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name); -TypeId getGlobalBinding(Frontend& frontend, const std::string& name); -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name); +void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding); +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, Binding binding); +std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::string& name); +Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name); +TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name); } // namespace Luau diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 69d4cca..8571430 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -411,8 +411,8 @@ struct InternalErrorReporter std::function onInternalError; std::string moduleName; - [[noreturn]] void ice(const std::string& message, const Location& location); - [[noreturn]] void ice(const std::string& message); + [[noreturn]] void ice(const std::string& message, const Location& location) const; + [[noreturn]] void ice(const std::string& message) const; }; class InternalCompilerError : public std::exception diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 7c5dc4a..9c0366a 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -21,6 +21,7 @@ class ParseError; struct Frontend; struct TypeError; struct LintWarning; +struct GlobalTypes; struct TypeChecker; struct FileResolver; struct ModuleResolver; @@ -31,11 +32,12 @@ struct LoadDefinitionFileResult { bool success; ParseResult parseResult; + SourceModule sourceModule; ModulePtr module; }; -LoadDefinitionFileResult loadDefinitionFile( - TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); +LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view definition, + const std::string& packageName, bool captureComments); std::optional parseMode(const std::vector& hotcomments); @@ -152,14 +154,12 @@ struct Frontend void clear(); ScopePtr addEnvironment(const std::string& environmentName); - ScopePtr getEnvironmentScope(const std::string& environmentName); + ScopePtr getEnvironmentScope(const std::string& environmentName) const; - void registerBuiltinDefinition(const std::string& name, std::function); + void registerBuiltinDefinition(const std::string& name, std::function); void applyBuiltinDefinitionToEnvironment(const std::string& environmentName, const std::string& definitionName); - LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName); - - ScopePtr getGlobalScope(); + LoadDefinitionFileResult loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments); private: ModulePtr check(const SourceModule& sourceModule, Mode mode, std::vector requireCycles, bool forAutocomplete = false, bool recordJsonLog = false); @@ -171,10 +171,10 @@ private: static LintResult classifyLints(const std::vector& warnings, const Config& config); - ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete); + ScopePtr getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const; std::unordered_map environments; - std::unordered_map> builtinDefinitions; + std::unordered_map> builtinDefinitions; BuiltinTypes builtinTypes_; @@ -184,21 +184,19 @@ public: FileResolver* fileResolver; FrontendModuleResolver moduleResolver; FrontendModuleResolver moduleResolverForAutocomplete; + GlobalTypes globals; + GlobalTypes globalsForAutocomplete; TypeChecker typeChecker; TypeChecker typeCheckerForAutocomplete; ConfigResolver* configResolver; FrontendOptions options; InternalErrorReporter iceHandler; - TypeArena globalTypes; std::unordered_map sourceNodes; std::unordered_map sourceModules; std::unordered_map requireTrace; Stats stats = {}; - -private: - ScopePtr globalScope; }; ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index b350fab..c86512f 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -10,6 +10,6 @@ struct TypeArena; struct Scope; void quantify(TypeId ty, TypeLevel level); -TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope); +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope); } // namespace Luau diff --git a/Analysis/include/Luau/Type.h b/Analysis/include/Luau/Type.h index cf1f8da..ef2d4c6 100644 --- a/Analysis/include/Luau/Type.h +++ b/Analysis/include/Luau/Type.h @@ -640,10 +640,10 @@ struct BuiltinTypes BuiltinTypes(const BuiltinTypes&) = delete; void operator=(const BuiltinTypes&) = delete; - TypeId errorRecoveryType(TypeId guess); - TypePackId errorRecoveryTypePack(TypePackId guess); - TypeId errorRecoveryType(); - TypePackId errorRecoveryTypePack(); + TypeId errorRecoveryType(TypeId guess) const; + TypePackId errorRecoveryTypePack(TypePackId guess) const; + TypeId errorRecoveryType() const; + TypePackId errorRecoveryTypePack() const; private: std::unique_ptr arena; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 678bd41..21cb263 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -63,11 +63,22 @@ enum class ValueContext RValue }; +struct GlobalTypes +{ + GlobalTypes(NotNull builtinTypes); + + NotNull builtinTypes; // Global types are based on builtin types + + TypeArena globalTypes; + SourceModule globalNames; // names for symbols entered into globalScope + ScopePtr globalScope; // shared by all modules +}; + // All Types are retained via Environment::types. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker { - explicit TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); + explicit TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler); TypeChecker(const TypeChecker&) = delete; TypeChecker& operator=(const TypeChecker&) = delete; @@ -355,11 +366,10 @@ public: */ std::vector unTypePack(const ScopePtr& scope, TypePackId pack, size_t expectedLength, const Location& location); - TypeArena globalTypes; + // TODO: only const version of global scope should be available to make sure nothing else is modified inside of from users of TypeChecker + const GlobalTypes& globals; ModuleResolver* resolver; - SourceModule globalNames; // names for symbols entered into globalScope - ScopePtr globalScope; // shared by all modules ModulePtr currentModule; ModuleName currentModuleName; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 50024e3..fc886ac 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -90,6 +90,11 @@ struct Unifier private: void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnifyUnionWithType(TypeId subTy, const UnionType* uv, TypeId superTy); + + // Traverse the two types provided and block on any BlockedTypes we find. + // Returns true if any types were blocked on. + bool blockOnBlockedTypes(TypeId subTy, TypeId superTy); + void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall); void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionType* uv); void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionType* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index e95b001..b0c3750 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -12,7 +12,6 @@ #include LUAU_FASTFLAG(LuauCompleteTableKeysBetter); -LUAU_FASTFLAGVARIABLE(SupportTypeAliasGoToDeclaration, false); namespace Luau { @@ -195,17 +194,10 @@ struct FindFullAncestry final : public AstVisitor bool visit(AstType* type) override { - if (FFlag::SupportTypeAliasGoToDeclaration) - { - if (includeTypes) - return visit(static_cast(type)); - else - return false; - } + if (includeTypes) + return visit(static_cast(type)); else - { - return AstVisitor::visit(type); - } + return false; } bool visit(AstNode* node) override diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1e09497..1df4d3d 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,8 +14,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauCompleteTableKeysBetter, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInWhile, false); -LUAU_FASTFLAGVARIABLE(LuauFixAutocompleteInFor, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteSkipNormalization, false); static const std::unordered_set kStatementStartingKeywords = { @@ -1425,24 +1423,12 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (!statFor->hasDo || position < statFor->doLocation.begin) { - if (FFlag::LuauFixAutocompleteInFor) - { - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); + if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || + (statFor->step && statFor->step->location.containsClosed(position))) + return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } - else - { - if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - - if (statFor->from->location.containsClosed(position) || statFor->to->location.containsClosed(position) || - (statFor->step && statFor->step->location.containsClosed(position))) - return autocompleteExpression(sourceModule, *module, builtinTypes, typeArena, ancestry, position); - } + if (!statFor->from->is() && !statFor->to->is() && (!statFor->step || !statFor->step->is())) + return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; return {}; } @@ -1493,14 +1479,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (!statWhile->hasDo && !statWhile->condition->is() && position > statWhile->condition->location.end) { - if (FFlag::LuauFixAutocompleteInWhile) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else - { - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } + return autocompleteWhileLoopKeywords(ancestry); } if (!statWhile->hasDo || position < statWhile->doLocation.begin) @@ -1511,18 +1490,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (AstStatWhile* statWhile = extractStat(ancestry); - FFlag::LuauFixAutocompleteInWhile ? (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && - statWhile->condition && !statWhile->condition->location.containsClosed(position)) - : (statWhile && !statWhile->hasDo)) + (statWhile && (!statWhile->hasDo || statWhile->doLocation.containsClosed(position)) && statWhile->condition && + !statWhile->condition->location.containsClosed(position))) { - if (FFlag::LuauFixAutocompleteInWhile) - { - return autocompleteWhileLoopKeywords(ancestry); - } - else - { - return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, ancestry, AutocompleteContext::Keyword}; - } + return autocompleteWhileLoopKeywords(ancestry); } else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { @@ -1672,7 +1643,7 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName return {}; NotNull builtinTypes = frontend.builtinTypes; - Scope* globalScope = frontend.typeCheckerForAutocomplete.globalScope.get(); + Scope* globalScope = frontend.globalsForAutocomplete.globalScope.get(); TypeArena typeArena; return autocomplete(*sourceModule, module, builtinTypes, &typeArena, globalScope, position, callback); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index b111c50..d2ace49 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -52,14 +52,9 @@ TypeId makeIntersection(TypeArena& arena, std::vector&& types) return arena.addType(IntersectionType{std::move(types)}); } -TypeId makeOption(Frontend& frontend, TypeArena& arena, TypeId t) +TypeId makeOption(NotNull builtinTypes, TypeArena& arena, TypeId t) { - return makeUnion(arena, {frontend.typeChecker.nilType, t}); -} - -TypeId makeOption(TypeChecker& typeChecker, TypeArena& arena, TypeId t) -{ - return makeUnion(arena, {typeChecker.nilType, t}); + return makeUnion(arena, {builtinTypes->nilType, t}); } TypeId makeFunction( @@ -148,85 +143,52 @@ Property makeProperty(TypeId ty, std::optional documentationSymbol) }; } -void addGlobalBinding(Frontend& frontend, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const std::string& name, TypeId ty, const std::string& packageName) { - addGlobalBinding(frontend, frontend.getGlobalScope(), name, ty, packageName); + addGlobalBinding(globals, globals.globalScope, name, ty, packageName); } -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName); - -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const std::string& name, Binding binding) { - addGlobalBinding(typeChecker, typeChecker.globalScope, name, ty, packageName); + addGlobalBinding(globals, globals.globalScope, name, binding); } -void addGlobalBinding(Frontend& frontend, const std::string& name, Binding binding) -{ - addGlobalBinding(frontend, frontend.getGlobalScope(), name, binding); -} - -void addGlobalBinding(TypeChecker& typeChecker, const std::string& name, Binding binding) -{ - addGlobalBinding(typeChecker, typeChecker.globalScope, name, binding); -} - -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) { std::string documentationSymbol = packageName + "/global/" + name; - addGlobalBinding(frontend, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); + addGlobalBinding(globals, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); } -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, TypeId ty, const std::string& packageName) +void addGlobalBinding(GlobalTypes& globals, const ScopePtr& scope, const std::string& name, Binding binding) { - std::string documentationSymbol = packageName + "/global/" + name; - addGlobalBinding(typeChecker, scope, name, Binding{ty, Location{}, {}, {}, documentationSymbol}); + scope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = binding; } -void addGlobalBinding(Frontend& frontend, const ScopePtr& scope, const std::string& name, Binding binding) +std::optional tryGetGlobalBinding(GlobalTypes& globals, const std::string& name) { - addGlobalBinding(frontend.typeChecker, scope, name, binding); -} - -void addGlobalBinding(TypeChecker& typeChecker, const ScopePtr& scope, const std::string& name, Binding binding) -{ - scope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = binding; -} - -std::optional tryGetGlobalBinding(TypeChecker& typeChecker, const std::string& name) -{ - AstName astName = typeChecker.globalNames.names->getOrAdd(name.c_str()); - auto it = typeChecker.globalScope->bindings.find(astName); - if (it != typeChecker.globalScope->bindings.end()) + AstName astName = globals.globalNames.names->getOrAdd(name.c_str()); + auto it = globals.globalScope->bindings.find(astName); + if (it != globals.globalScope->bindings.end()) return it->second; return std::nullopt; } -TypeId getGlobalBinding(TypeChecker& typeChecker, const std::string& name) +TypeId getGlobalBinding(GlobalTypes& globals, const std::string& name) { - auto t = tryGetGlobalBinding(typeChecker, name); + auto t = tryGetGlobalBinding(globals, name); LUAU_ASSERT(t.has_value()); return t->typeId; } -TypeId getGlobalBinding(Frontend& frontend, const std::string& name) +Binding* tryGetGlobalBindingRef(GlobalTypes& globals, const std::string& name) { - return getGlobalBinding(frontend.typeChecker, name); -} - -std::optional tryGetGlobalBinding(Frontend& frontend, const std::string& name) -{ - return tryGetGlobalBinding(frontend.typeChecker, name); -} - -Binding* tryGetGlobalBindingRef(TypeChecker& typeChecker, const std::string& name) -{ - AstName astName = typeChecker.globalNames.names->get(name.c_str()); + AstName astName = globals.globalNames.names->get(name.c_str()); if (astName == AstName()) return nullptr; - auto it = typeChecker.globalScope->bindings.find(astName); - if (it != typeChecker.globalScope->bindings.end()) + auto it = globals.globalScope->bindings.find(astName); + if (it != globals.globalScope->bindings.end()) return &it->second; return nullptr; @@ -240,34 +202,33 @@ void assignPropDocumentationSymbols(TableType::Props& props, const std::string& } } -void registerBuiltinTypes(Frontend& frontend) +void registerBuiltinTypes(GlobalTypes& globals) { - frontend.getGlobalScope()->addBuiltinTypeBinding("any", TypeFun{{}, frontend.builtinTypes->anyType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("nil", TypeFun{{}, frontend.builtinTypes->nilType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("number", TypeFun{{}, frontend.builtinTypes->numberType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("string", TypeFun{{}, frontend.builtinTypes->stringType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("boolean", TypeFun{{}, frontend.builtinTypes->booleanType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("thread", TypeFun{{}, frontend.builtinTypes->threadType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("unknown", TypeFun{{}, frontend.builtinTypes->unknownType}); - frontend.getGlobalScope()->addBuiltinTypeBinding("never", TypeFun{{}, frontend.builtinTypes->neverType}); + globals.globalScope->addBuiltinTypeBinding("any", TypeFun{{}, globals.builtinTypes->anyType}); + globals.globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, globals.builtinTypes->nilType}); + globals.globalScope->addBuiltinTypeBinding("number", TypeFun{{}, globals.builtinTypes->numberType}); + globals.globalScope->addBuiltinTypeBinding("string", TypeFun{{}, globals.builtinTypes->stringType}); + globals.globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, globals.builtinTypes->booleanType}); + globals.globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, globals.builtinTypes->threadType}); + globals.globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, globals.builtinTypes->unknownType}); + globals.globalScope->addBuiltinTypeBinding("never", TypeFun{{}, globals.builtinTypes->neverType}); } -void registerBuiltinGlobals(TypeChecker& typeChecker) +void registerBuiltinGlobals(TypeChecker& typeChecker, GlobalTypes& globals) { - LUAU_ASSERT(!typeChecker.globalTypes.types.isFrozen()); - LUAU_ASSERT(!typeChecker.globalTypes.typePacks.isFrozen()); + LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); + LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - TypeId nilType = typeChecker.nilType; + TypeArena& arena = globals.globalTypes; + NotNull builtinTypes = globals.builtinTypes; - TypeArena& arena = typeChecker.globalTypes; - NotNull builtinTypes = typeChecker.builtinTypes; - - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); + LoadDefinitionFileResult loadResult = + Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); + TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); LUAU_ASSERT(stringMetatableTy); @@ -277,33 +238,33 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) auto it = stringMetatableTable->props.find("__index"); LUAU_ASSERT(it != stringMetatableTable->props.end()); - addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); + addGlobalBinding(globals, "string", it->second.type, "@luau"); // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(typeChecker, arena, genericK), genericV}}); - addGlobalBinding(typeChecker, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); + addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K, V), Table, nil) - addGlobalBinding(typeChecker, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericType{"MT"}); - TableType tab{TableState::Generic, typeChecker.globalScope->level}; + TableType tab{TableState::Generic, globals.globalScope->level}; TypeId tabTy = arena.addType(tab); TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - addGlobalBinding(typeChecker, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); + addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); // clang-format off // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(typeChecker, "setmetatable", + addGlobalBinding(globals, "setmetatable", arena.addType( FunctionType{ {genericMT}, @@ -315,7 +276,7 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) ); // clang-format on - for (const auto& pair : typeChecker.globalScope->bindings) + for (const auto& pair : globals.globalScope->bindings) { persist(pair.second.typeId); @@ -326,12 +287,12 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) } } - attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(typeChecker, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); @@ -349,26 +310,28 @@ void registerBuiltinGlobals(TypeChecker& typeChecker) attachDcrMagicFunction(ttv->props["pack"].type, dcrMagicFunctionPack); } - attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(typeChecker, "require"), dcrMagicFunctionRequire); + attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } void registerBuiltinGlobals(Frontend& frontend) { - LUAU_ASSERT(!frontend.globalTypes.types.isFrozen()); - LUAU_ASSERT(!frontend.globalTypes.typePacks.isFrozen()); + GlobalTypes& globals = frontend.globals; - registerBuiltinTypes(frontend); + LUAU_ASSERT(!globals.globalTypes.types.isFrozen()); + LUAU_ASSERT(!globals.globalTypes.typePacks.isFrozen()); - TypeArena& arena = frontend.globalTypes; - NotNull builtinTypes = frontend.builtinTypes; + registerBuiltinTypes(globals); - LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau"); + TypeArena& arena = globals.globalTypes; + NotNull builtinTypes = globals.builtinTypes; + + LoadDefinitionFileResult loadResult = frontend.loadDefinitionFile(getBuiltinDefinitionSource(), "@luau", /* captureComments */ false); LUAU_ASSERT(loadResult.success); TypeId genericK = arena.addType(GenericType{"K"}); TypeId genericV = arena.addType(GenericType{"V"}); - TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), frontend.getGlobalScope()->level, TableState::Generic}); + TypeId mapOfKtoV = arena.addType(TableType{{}, TableIndexer(genericK, genericV), globals.globalScope->level, TableState::Generic}); std::optional stringMetatableTy = getMetatable(builtinTypes->stringType, builtinTypes); LUAU_ASSERT(stringMetatableTy); @@ -378,33 +341,33 @@ void registerBuiltinGlobals(Frontend& frontend) auto it = stringMetatableTable->props.find("__index"); LUAU_ASSERT(it != stringMetatableTable->props.end()); - addGlobalBinding(frontend, "string", it->second.type, "@luau"); + addGlobalBinding(globals, "string", it->second.type, "@luau"); // next(t: Table, i: K?) -> (K?, V) - TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(frontend, arena, genericK)}}); - TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(frontend, arena, genericK), genericV}}); - addGlobalBinding(frontend, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); + TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(builtinTypes, arena, genericK)}}); + TypePackId nextRetsTypePack = arena.addTypePack(TypePack{{makeOption(builtinTypes, arena, genericK), genericV}}); + addGlobalBinding(globals, "next", arena.addType(FunctionType{{genericK, genericV}, {}, nextArgsTypePack, nextRetsTypePack}), "@luau"); TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); TypeId pairsNext = arena.addType(FunctionType{nextArgsTypePack, nextRetsTypePack}); - TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, frontend.builtinTypes->nilType}}); + TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, builtinTypes->nilType}}); // pairs(t: Table) -> ((Table, K?) -> (K?, V), Table, nil) - addGlobalBinding(frontend, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); + addGlobalBinding(globals, "pairs", arena.addType(FunctionType{{genericK, genericV}, {}, pairsArgsTypePack, pairsReturnTypePack}), "@luau"); TypeId genericMT = arena.addType(GenericType{"MT"}); - TableType tab{TableState::Generic, frontend.getGlobalScope()->level}; + TableType tab{TableState::Generic, globals.globalScope->level}; TypeId tabTy = arena.addType(tab); TypeId tableMetaMT = arena.addType(MetatableType{tabTy, genericMT}); - addGlobalBinding(frontend, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); + addGlobalBinding(globals, "getmetatable", makeFunction(arena, std::nullopt, {genericMT}, {}, {tableMetaMT}, {genericMT}), "@luau"); // clang-format off // setmetatable(T, MT) -> { @metatable MT, T } - addGlobalBinding(frontend, "setmetatable", + addGlobalBinding(globals, "setmetatable", arena.addType( FunctionType{ {genericMT}, @@ -416,7 +379,7 @@ void registerBuiltinGlobals(Frontend& frontend) ); // clang-format on - for (const auto& pair : frontend.getGlobalScope()->bindings) + for (const auto& pair : globals.globalScope->bindings) { persist(pair.second.typeId); @@ -427,12 +390,12 @@ void registerBuiltinGlobals(Frontend& frontend) } } - attachMagicFunction(getGlobalBinding(frontend, "assert"), magicFunctionAssert); - attachMagicFunction(getGlobalBinding(frontend, "setmetatable"), magicFunctionSetMetaTable); - attachMagicFunction(getGlobalBinding(frontend, "select"), magicFunctionSelect); - attachDcrMagicFunction(getGlobalBinding(frontend, "select"), dcrMagicFunctionSelect); + attachMagicFunction(getGlobalBinding(globals, "assert"), magicFunctionAssert); + attachMagicFunction(getGlobalBinding(globals, "setmetatable"), magicFunctionSetMetaTable); + attachMagicFunction(getGlobalBinding(globals, "select"), magicFunctionSelect); + attachDcrMagicFunction(getGlobalBinding(globals, "select"), dcrMagicFunctionSelect); - if (TableType* ttv = getMutable(getGlobalBinding(frontend, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(globals, "table"))) { // tabTy is a generic table type which we can't express via declaration syntax yet ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); @@ -449,8 +412,8 @@ void registerBuiltinGlobals(Frontend& frontend) attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); } - attachMagicFunction(getGlobalBinding(frontend, "require"), magicFunctionRequire); - attachDcrMagicFunction(getGlobalBinding(frontend, "require"), dcrMagicFunctionRequire); + attachMagicFunction(getGlobalBinding(globals, "require"), magicFunctionRequire); + attachDcrMagicFunction(getGlobalBinding(globals, "require"), dcrMagicFunctionRequire); } static std::optional> magicFunctionSelect( diff --git a/Analysis/src/ConstraintGraphBuilder.cpp b/Analysis/src/ConstraintGraphBuilder.cpp index 9ee2b08..711d357 100644 --- a/Analysis/src/ConstraintGraphBuilder.cpp +++ b/Analysis/src/ConstraintGraphBuilder.cpp @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauCheckRecursionLimit); LUAU_FASTFLAG(DebugLuauMagicTypes); LUAU_FASTFLAG(LuauNegatedClassTypes); -LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration); namespace Luau { @@ -587,8 +586,7 @@ void ConstraintGraphBuilder::visit(const ScopePtr& scope, AstStatLocal* local) if (ModulePtr module = moduleResolver->getModule(moduleInfo->name)) { scope->importedTypeBindings[name] = module->exportedTypeBindings; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->importedModules[name] = moduleName; + scope->importedModules[name] = moduleInfo->name; } } } diff --git a/Analysis/src/ConstraintSolver.cpp b/Analysis/src/ConstraintSolver.cpp index 3cb4e4e..3c306b4 100644 --- a/Analysis/src/ConstraintSolver.cpp +++ b/Analysis/src/ConstraintSolver.cpp @@ -521,12 +521,19 @@ bool ConstraintSolver::tryDispatch(const GeneralizationConstraint& c, NotNullscope); - - if (isBlocked(c.generalizedType)) - asMutable(c.generalizedType)->ty.emplace(generalized); + std::optional generalized = quantify(arena, c.sourceType, constraint->scope); + if (generalized) + { + if (isBlocked(c.generalizedType)) + asMutable(c.generalizedType)->ty.emplace(*generalized); + else + unify(c.generalizedType, *generalized, constraint->scope); + } else - unify(c.generalizedType, generalized, constraint->scope); + { + reportError(CodeTooComplex{}, constraint->location); + asMutable(c.generalizedType)->ty.emplace(builtinTypes->errorRecoveryType()); + } unblock(c.generalizedType); unblock(c.sourceType); @@ -1365,7 +1372,7 @@ static std::optional updateTheTableType(NotNull arena, TypeId if (it == tbl->props.end()) return std::nullopt; - t = it->second.type; + t = follow(it->second.type); } // The last path segment should not be a property of the table at all. @@ -1450,12 +1457,6 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) subjectType = follow(mt->table); - if (get(subjectType) || get(subjectType) || get(subjectType)) - { - bind(c.resultType, subjectType); - return true; - } - if (get(subjectType)) { TypeId ty = arena->freshType(constraint->scope); @@ -1501,16 +1502,13 @@ bool ConstraintSolver::tryDispatch(const SetPropConstraint& c, NotNull(subjectType)) + else { - // Classes and intersections never change shape as a result of property - // assignments. The result is always the subject. + // Other kinds of types don't change shape when properties are assigned + // to them. (if they allow properties at all!) bind(c.resultType, subjectType); return true; } - - LUAU_ASSERT(0); - return true; } bool ConstraintSolver::tryDispatch(const SetIndexerConstraint& c, NotNull constraint, bool force) diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index a527b24..84b9cb3 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -959,7 +959,7 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) visit(visitErrorData, error.data); } -void InternalErrorReporter::ice(const std::string& message, const Location& location) +void InternalErrorReporter::ice(const std::string& message, const Location& location) const { InternalCompilerError error(message, moduleName, location); @@ -969,7 +969,7 @@ void InternalErrorReporter::ice(const std::string& message, const Location& loca throw error; } -void InternalErrorReporter::ice(const std::string& message) +void InternalErrorReporter::ice(const std::string& message) const { InternalCompilerError error(message, moduleName); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b3e453d..722f1a2 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -31,6 +31,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) LUAU_FASTINTVARIABLE(LuauAutocompleteCheckTimeoutMs, 100) LUAU_FASTFLAGVARIABLE(DebugLuauDeferredConstraintResolution, false) +LUAU_FASTFLAGVARIABLE(LuauDefinitionFileSourceModule, false) LUAU_FASTFLAGVARIABLE(DebugLuauLogSolverToJson, false); namespace Luau @@ -83,32 +84,31 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) } } -LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName) +LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, const std::string& packageName, bool captureComments) { if (!FFlag::DebugLuauDeferredConstraintResolution) - return Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, source, packageName); + return Luau::loadDefinitionFile(typeChecker, globals, globals.globalScope, source, packageName, captureComments); LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); + Luau::SourceModule sourceModule; ParseOptions options; options.allowDeclarationSyntax = true; + options.captureComments = captureComments; - Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, nullptr}; + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; - Luau::SourceModule module; - module.root = parseResult.root; - module.mode = Mode::Definition; + sourceModule.root = parseResult.root; + sourceModule.mode = Mode::Definition; - ModulePtr checkedModule = check(module, Mode::Definition, {}); + ModulePtr checkedModule = check(sourceModule, Mode::Definition, {}); if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, checkedModule}; + return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; CloneState cloneState; @@ -117,20 +117,20 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, globalTypes, cloneState); + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); - globalScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + globals.globalScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; typesToPersist.push_back(globalTy); } for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { - TypeFun globalTy = clone(ty, globalTypes, cloneState); + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); - globalScope->exportedTypeBindings[name] = globalTy; + globals.globalScope->exportedTypeBindings[name] = globalTy; typesToPersist.push_back(globalTy.type); } @@ -140,10 +140,11 @@ LoadDefinitionFileResult Frontend::loadDefinitionFile(std::string_view source, c persist(ty); } - return LoadDefinitionFileResult{true, parseResult, checkedModule}; + return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } -LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) +LoadDefinitionFileResult loadDefinitionFile_DEPRECATED( + TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, const std::string& packageName) { LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); @@ -156,7 +157,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); if (parseResult.errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, nullptr}; + return LoadDefinitionFileResult{false, parseResult, {}, nullptr}; Luau::SourceModule module; module.root = parseResult.root; @@ -165,7 +166,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t ModulePtr checkedModule = typeChecker.check(module, Mode::Definition); if (checkedModule->errors.size() > 0) - return LoadDefinitionFileResult{false, parseResult, checkedModule}; + return LoadDefinitionFileResult{false, parseResult, {}, checkedModule}; CloneState cloneState; @@ -174,17 +175,17 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, cloneState); + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); - targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; typesToPersist.push_back(globalTy); } for (const auto& [name, ty] : checkedModule->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, cloneState); + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -197,7 +198,67 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t persist(ty); } - return LoadDefinitionFileResult{true, parseResult, checkedModule}; + return LoadDefinitionFileResult{true, parseResult, {}, checkedModule}; +} + +LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, GlobalTypes& globals, ScopePtr targetScope, std::string_view source, + const std::string& packageName, bool captureComments) +{ + if (!FFlag::LuauDefinitionFileSourceModule) + return loadDefinitionFile_DEPRECATED(typeChecker, globals, targetScope, source, packageName); + + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + + Luau::SourceModule sourceModule; + + ParseOptions options; + options.allowDeclarationSyntax = true; + options.captureComments = captureComments; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), *sourceModule.names, *sourceModule.allocator, options); + + if (parseResult.errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, nullptr}; + + sourceModule.root = parseResult.root; + sourceModule.mode = Mode::Definition; + + ModulePtr checkedModule = typeChecker.check(sourceModule, Mode::Definition); + + if (checkedModule->errors.size() > 0) + return LoadDefinitionFileResult{false, parseResult, sourceModule, checkedModule}; + + CloneState cloneState; + + std::vector typesToPersist; + typesToPersist.reserve(checkedModule->declaredGlobals.size() + checkedModule->exportedTypeBindings.size()); + + for (const auto& [name, ty] : checkedModule->declaredGlobals) + { + TypeId globalTy = clone(ty, globals.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/global/" + name; + generateDocumentationSymbols(globalTy, documentationSymbol); + targetScope->bindings[globals.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; + + typesToPersist.push_back(globalTy); + } + + for (const auto& [name, ty] : checkedModule->exportedTypeBindings) + { + TypeFun globalTy = clone(ty, globals.globalTypes, cloneState); + std::string documentationSymbol = packageName + "/globaltype/" + name; + generateDocumentationSymbols(globalTy.type, documentationSymbol); + targetScope->exportedTypeBindings[name] = globalTy; + + typesToPersist.push_back(globalTy.type); + } + + for (TypeId ty : typesToPersist) + { + persist(ty); + } + + return LoadDefinitionFileResult{true, parseResult, sourceModule, checkedModule}; } std::vector parsePathExpr(const AstExpr& pathExpr) @@ -414,11 +475,12 @@ Frontend::Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, c , fileResolver(fileResolver) , moduleResolver(this) , moduleResolverForAutocomplete(this) - , typeChecker(&moduleResolver, builtinTypes, &iceHandler) - , typeCheckerForAutocomplete(&moduleResolverForAutocomplete, builtinTypes, &iceHandler) + , globals(builtinTypes) + , globalsForAutocomplete(builtinTypes) + , typeChecker(globals, &moduleResolver, builtinTypes, &iceHandler) + , typeCheckerForAutocomplete(globalsForAutocomplete, &moduleResolverForAutocomplete, builtinTypes, &iceHandler) , configResolver(configResolver) , options(options) - , globalScope(typeChecker.globalScope) { } @@ -704,13 +766,13 @@ bool Frontend::parseGraph(std::vector& buildQueue, const ModuleName& return cyclic; } -ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) +ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config& config, bool forAutocomplete) const { ScopePtr result; if (forAutocomplete) - result = typeCheckerForAutocomplete.globalScope; + result = globalsForAutocomplete.globalScope; else - result = typeChecker.globalScope; + result = globals.globalScope; if (module.environmentName) result = getEnvironmentScope(*module.environmentName); @@ -848,16 +910,6 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons return const_cast(this)->getSourceModule(moduleName); } -ScopePtr Frontend::getGlobalScope() -{ - if (!globalScope) - { - globalScope = typeChecker.globalScope; - } - - return globalScope; -} - ModulePtr check(const SourceModule& sourceModule, const std::vector& requireCycles, NotNull builtinTypes, NotNull iceHandler, NotNull moduleResolver, NotNull fileResolver, const ScopePtr& globalScope, FrontendOptions options) @@ -946,7 +998,7 @@ ModulePtr Frontend::check(const SourceModule& sourceModule, Mode mode, std::vect { return Luau::check(sourceModule, requireCycles, builtinTypes, NotNull{&iceHandler}, NotNull{forAutocomplete ? &moduleResolverForAutocomplete : &moduleResolver}, NotNull{fileResolver}, - forAutocomplete ? typeCheckerForAutocomplete.globalScope : typeChecker.globalScope, options, recordJsonLog); + forAutocomplete ? globalsForAutocomplete.globalScope : globals.globalScope, options, recordJsonLog); } // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. @@ -1115,7 +1167,7 @@ ScopePtr Frontend::addEnvironment(const std::string& environmentName) if (environments.count(environmentName) == 0) { - ScopePtr scope = std::make_shared(typeChecker.globalScope); + ScopePtr scope = std::make_shared(globals.globalScope); environments[environmentName] = scope; return scope; } @@ -1123,14 +1175,16 @@ ScopePtr Frontend::addEnvironment(const std::string& environmentName) return environments[environmentName]; } -ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) +ScopePtr Frontend::getEnvironmentScope(const std::string& environmentName) const { - LUAU_ASSERT(environments.count(environmentName) > 0); + if (auto it = environments.find(environmentName); it != environments.end()) + return it->second; - return environments[environmentName]; + LUAU_ASSERT(!"environment doesn't exist"); + return {}; } -void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) +void Frontend::registerBuiltinDefinition(const std::string& name, std::function applicator) { LUAU_ASSERT(builtinDefinitions.count(name) == 0); @@ -1143,7 +1197,7 @@ void Frontend::applyBuiltinDefinitionToEnvironment(const std::string& environmen LUAU_ASSERT(builtinDefinitions.count(definitionName) > 0); if (builtinDefinitions.count(definitionName) > 0) - builtinDefinitions[definitionName](typeChecker, getEnvironmentScope(environmentName)); + builtinDefinitions[definitionName](typeChecker, globals, getEnvironmentScope(environmentName)); } LintResult Frontend::classifyLints(const std::vector& warnings, const Config& config) diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 845ae3a..9da43ed 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -253,11 +253,12 @@ struct PureQuantifier : Substitution } }; -TypeId quantify(TypeArena* arena, TypeId ty, Scope* scope) +std::optional quantify(TypeArena* arena, TypeId ty, Scope* scope) { PureQuantifier quantifier{arena, scope}; std::optional result = quantifier.substitute(ty); - LUAU_ASSERT(result); + if (!result) + return std::nullopt; FunctionType* ftv = getMutable(*result); LUAU_ASSERT(ftv); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 5040952..2661831 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -276,6 +276,7 @@ PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) { LUAU_ASSERT(get(ty)); + LUAU_ASSERT(ty != newBoundTo); PendingType* newTy = queue(ty); if (TableType* ttv = Luau::getMutable(newTy)) diff --git a/Analysis/src/Type.cpp b/Analysis/src/Type.cpp index f15f8c4..4bc1223 100644 --- a/Analysis/src/Type.cpp +++ b/Analysis/src/Type.cpp @@ -856,22 +856,22 @@ TypeId BuiltinTypes::makeStringMetatable() return arena->addType(TableType{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } -TypeId BuiltinTypes::errorRecoveryType() +TypeId BuiltinTypes::errorRecoveryType() const { return errorType; } -TypePackId BuiltinTypes::errorRecoveryTypePack() +TypePackId BuiltinTypes::errorRecoveryTypePack() const { return errorTypePack; } -TypeId BuiltinTypes::errorRecoveryType(TypeId guess) +TypeId BuiltinTypes::errorRecoveryType(TypeId guess) const { return guess; } -TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) +TypePackId BuiltinTypes::errorRecoveryTypePack(TypePackId guess) const { return guess; } diff --git a/Analysis/src/TypeChecker2.cpp b/Analysis/src/TypeChecker2.cpp index aacfd72..a160a1d 100644 --- a/Analysis/src/TypeChecker2.cpp +++ b/Analysis/src/TypeChecker2.cpp @@ -857,7 +857,7 @@ struct TypeChecker2 } void reportOverloadResolutionErrors(AstExprCall* call, std::vector overloads, TypePackId expectedArgTypes, - const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) + const std::vector& overloadsThatMatchArgCount, std::vector> overloadsErrors) { if (overloads.size() == 1) { @@ -883,8 +883,8 @@ struct TypeChecker2 const FunctionType* ftv = get(overload); LUAU_ASSERT(ftv); // overload must be a function type here - auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [ftv](const std::pair& e) { - return ftv == std::get<1>(e); + auto error = std::find_if(overloadsErrors.begin(), overloadsErrors.end(), [overload](const std::pair& e) { + return overload == e.second; }); LUAU_ASSERT(error != overloadsErrors.end()); @@ -1036,7 +1036,7 @@ struct TypeChecker2 TypePackId expectedArgTypes = arena->addTypePack(args); std::vector overloads = flattenIntersection(testFunctionType); - std::vector> overloadsErrors; + std::vector> overloadsErrors; overloadsErrors.reserve(overloads.size()); std::vector overloadsThatMatchArgCount; @@ -1060,7 +1060,7 @@ struct TypeChecker2 } else { - overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overloadFn); + overloadsErrors.emplace_back(std::vector{TypeError{call->func->location, UnificationTooComplex{}}}, overload); return; } } @@ -1086,7 +1086,7 @@ struct TypeChecker2 if (!argMismatch) overloadsThatMatchArgCount.push_back(overload); - overloadsErrors.emplace_back(std::move(overloadErrors), overloadFn); + overloadsErrors.emplace_back(std::move(overloadErrors), overload); } reportOverloadResolutionErrors(call, overloads, expectedArgTypes, overloadsThatMatchArgCount, overloadsErrors); @@ -1102,11 +1102,54 @@ struct TypeChecker2 visitCall(call); } + std::optional tryStripUnionFromNil(TypeId ty) + { + if (const UnionType* utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + std::vector result; + + for (TypeId option : utv) + { + if (!isNil(option)) + result.push_back(option); + } + + if (result.empty()) + return std::nullopt; + + return result.size() == 1 ? result[0] : module->internalTypes.addType(UnionType{std::move(result)}); + } + + return std::nullopt; + } + + TypeId stripFromNilAndReport(TypeId ty, const Location& location) + { + ty = follow(ty); + + if (auto utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + } + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) + { + reportError(OptionalValueAccess{ty}, location); + return follow(*strippedUnion); + } + + return ty; + } + void visitExprName(AstExpr* expr, Location location, const std::string& propName, ValueContext context) { visit(expr, RValue); - TypeId leftType = lookupType(expr); + TypeId leftType = stripFromNilAndReport(lookupType(expr), location); const NormalizedType* norm = normalizer.normalize(leftType); if (!norm) reportError(NormalizationTooComplex{}, location); @@ -1766,7 +1809,15 @@ struct TypeChecker2 } else { - reportError(UnknownSymbol{ty->name.value, UnknownSymbol::Context::Type}, ty->location); + std::string symbol = ""; + if (ty->prefix) + { + symbol += (*(ty->prefix)).value; + symbol += "."; + } + symbol += ty->name.value; + + reportError(UnknownSymbol{symbol, UnknownSymbol::Context::Type}, ty->location); } } } @@ -2032,7 +2083,11 @@ struct TypeChecker2 { if (foundOneProp) reportError(MissingUnionProperty{tableTy, typesMissingTheProp, prop}, location); - else if (context == LValue) + // For class LValues, we don't want to report an extension error, + // because classes come into being with full knowledge of their + // shape. We instead want to report the unknown property error of + // the `else` branch. + else if (context == LValue && !get(tableTy)) reportError(CannotExtendTable{tableTy, CannotExtendTable::Property, prop}, location); else reportError(UnknownProperty{tableTy, prop}, location); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 6aa8e6c..87d5686 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -42,7 +42,6 @@ LUAU_FASTFLAGVARIABLE(LuauIntersectionTestForEquality, false) LUAU_FASTFLAG(LuauNegatedClassTypes) LUAU_FASTFLAGVARIABLE(LuauAllowIndexClassParameters, false) LUAU_FASTFLAG(LuauUninhabitedSubAnything2) -LUAU_FASTFLAG(SupportTypeAliasGoToDeclaration) LUAU_FASTFLAGVARIABLE(LuauTypecheckTypeguards, false) namespace Luau @@ -212,8 +211,24 @@ size_t HashBoolNamePair::operator()(const std::pair& pair) const return std::hash()(pair.first) ^ std::hash()(pair.second); } -TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) - : resolver(resolver) +GlobalTypes::GlobalTypes(NotNull builtinTypes) + : builtinTypes(builtinTypes) +{ + globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); + + globalScope->addBuiltinTypeBinding("any", TypeFun{{}, builtinTypes->anyType}); + globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, builtinTypes->nilType}); + globalScope->addBuiltinTypeBinding("number", TypeFun{{}, builtinTypes->numberType}); + globalScope->addBuiltinTypeBinding("string", TypeFun{{}, builtinTypes->stringType}); + globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, builtinTypes->booleanType}); + globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, builtinTypes->threadType}); + globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, builtinTypes->unknownType}); + globalScope->addBuiltinTypeBinding("never", TypeFun{{}, builtinTypes->neverType}); +} + +TypeChecker::TypeChecker(const GlobalTypes& globals, ModuleResolver* resolver, NotNull builtinTypes, InternalErrorReporter* iceHandler) + : globals(globals) + , resolver(resolver) , builtinTypes(builtinTypes) , iceHandler(iceHandler) , unifierState(iceHandler) @@ -231,16 +246,6 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, NotNull builtin , uninhabitableTypePack(builtinTypes->uninhabitableTypePack) , duplicateTypeAliases{{false, {}}} { - globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); - - globalScope->addBuiltinTypeBinding("any", TypeFun{{}, anyType}); - globalScope->addBuiltinTypeBinding("nil", TypeFun{{}, nilType}); - globalScope->addBuiltinTypeBinding("number", TypeFun{{}, numberType}); - globalScope->addBuiltinTypeBinding("string", TypeFun{{}, stringType}); - globalScope->addBuiltinTypeBinding("boolean", TypeFun{{}, booleanType}); - globalScope->addBuiltinTypeBinding("thread", TypeFun{{}, threadType}); - globalScope->addBuiltinTypeBinding("unknown", TypeFun{{}, unknownType}); - globalScope->addBuiltinTypeBinding("never", TypeFun{{}, neverType}); } ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) @@ -273,7 +278,7 @@ ModulePtr TypeChecker::checkWithoutRecursionCheck(const SourceModule& module, Mo unifierState.counters.recursionLimit = FInt::LuauTypeInferRecursionLimit; unifierState.counters.iterationLimit = unifierIterationLimit ? *unifierIterationLimit : FInt::LuauTypeInferIterationLimit; - ScopePtr parentScope = environmentScope.value_or(globalScope); + ScopePtr parentScope = environmentScope.value_or(globals.globalScope); ScopePtr moduleScope = std::make_shared(parentScope); if (module.cyclic) @@ -1121,8 +1126,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (ModulePtr module = resolver->getModule(moduleInfo->name)) { scope->importedTypeBindings[name] = module->exportedTypeBindings; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->importedModules[name] = moduleInfo->name; + scope->importedModules[name] = moduleInfo->name; } // In non-strict mode we force the module type on the variable, in strict mode it is already unified @@ -1580,7 +1584,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea } else { - if (globalScope->builtinTypeNames.contains(name)) + if (globals.globalScope->builtinTypeNames.contains(name)) { reportError(typealias.location, DuplicateTypeDefinition{name}); duplicateTypeAliases.insert({typealias.exported, name}); @@ -1601,8 +1605,7 @@ void TypeChecker::prototype(const ScopePtr& scope, const AstStatTypeAlias& typea bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; scope->typeAliasLocations[name] = typealias.location; - if (FFlag::SupportTypeAliasGoToDeclaration) - scope->typeAliasNameLocations[name] = typealias.nameLocation; + scope->typeAliasNameLocations[name] = typealias.nameLocation; } } } @@ -3360,19 +3363,19 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T if (auto globalName = funName.as()) { - const ScopePtr& globalScope = currentModule->getModuleScope(); + const ScopePtr& moduleScope = currentModule->getModuleScope(); Symbol name = globalName->name; - if (globalScope->bindings.count(name)) + if (moduleScope->bindings.count(name)) { if (isNonstrictMode()) - return globalScope->bindings[name].typeId; + return moduleScope->bindings[name].typeId; return errorRecoveryType(scope); } else { TypeId ty = freshTy(); - globalScope->bindings[name] = {ty, funName.location}; + moduleScope->bindings[name] = {ty, funName.location}; return ty; } } @@ -5898,7 +5901,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, RefinementMap& r if (!typeguardP.isTypeof) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - auto typeFun = globalScope->lookupType(typeguardP.kind); + auto typeFun = globals.globalScope->lookupType(typeguardP.kind); if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) return addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); diff --git a/Analysis/src/TypeReduction.cpp b/Analysis/src/TypeReduction.cpp index abafa9f..6a9fadf 100644 --- a/Analysis/src/TypeReduction.cpp +++ b/Analysis/src/TypeReduction.cpp @@ -407,9 +407,6 @@ TypePackId TypeReducer::reduce(TypePackId tp) std::optional TypeReducer::intersectionType(TypeId left, TypeId right) { - LUAU_ASSERT(!get(left)); - LUAU_ASSERT(!get(right)); - if (get(left)) return left; // never & T ~ never else if (get(right)) @@ -442,6 +439,17 @@ std::optional TypeReducer::intersectionType(TypeId left, TypeId right) return std::nullopt; // *pending* & T ~ *pending* & T else if (get(right)) return std::nullopt; // T & *pending* ~ T & *pending* + else if (auto [utl, utr] = get2(left, right); utl && utr) + { + std::vector parts; + for (TypeId optionl : utl) + { + for (TypeId optionr : utr) + parts.push_back(apply(&TypeReducer::intersectionType, optionl, optionr)); + } + + return reduce(flatten(std::move(parts))); // (T | U) & (A | B) ~ (T & A) | (T & B) | (U & A) | (U & B) + } else if (auto ut = get(left)) return reduce(distribute(begin(ut), end(ut), &TypeReducer::intersectionType, right)); // (A | B) & T ~ (A & T) | (B & T) else if (get(right)) @@ -789,6 +797,36 @@ std::optional TypeReducer::unionType(TypeId left, TypeId right) return reduce(distribute(begin(it), end(it), &TypeReducer::unionType, left)); // ~T | (A & B) ~ (~T | A) & (~T | B) else if (auto [it, nt] = get2(left, right); it && nt) return unionType(right, left); // (A & B) | ~T ~ ~T | (A & B) + else if (auto it = get(left)) + { + bool didReduce = false; + std::vector parts; + for (TypeId part : it) + { + auto nt = get(part); + if (!nt) + { + parts.push_back(part); + continue; + } + + auto redex = unionType(part, right); + if (redex && get(*redex)) + { + didReduce = true; + continue; + } + + parts.push_back(part); + } + + if (didReduce) + return flatten(std::move(parts)); // (T & ~nil) | nil ~ T + else + return std::nullopt; // (T & ~nil) | U + } + else if (get(right)) + return unionType(right, left); // A | (T & U) ~ (T & U) | A else if (auto [nl, nr] = get2(left, right); nl && nr) { // These should've been reduced already. diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index aba6427..b53401d 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,10 +18,9 @@ LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAGVARIABLE(LuauInstantiateInSubtyping, false) -LUAU_FASTFLAGVARIABLE(LuauScalarShapeUnifyToMtOwner2, false) LUAU_FASTFLAGVARIABLE(LuauUninhabitedSubAnything2, false) LUAU_FASTFLAGVARIABLE(LuauMaintainScopesInUnifier, false) -LUAU_FASTFLAGVARIABLE(LuauTableUnifyInstantiationFix, false) +LUAU_FASTFLAGVARIABLE(LuauTinyUnifyNormalsFix, false) LUAU_FASTFLAG(LuauClassTypeVarsInSubstitution) LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) LUAU_FASTFLAG(LuauNegatedFunctionTypes) @@ -108,7 +107,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (!log.is(ty)) return true; promote(ty, log.getMutable(ty)); @@ -126,7 +125,7 @@ struct PromoteTypeLevels final : TypeOnceVisitor // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauScalarShapeUnifyToMtOwner2 && !log.is(ty)) + if (!log.is(ty)) return true; promote(ty, log.getMutable(ty)); @@ -690,6 +689,31 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionType* subUnion, Typ } } +struct BlockedTypeFinder : TypeOnceVisitor +{ + std::unordered_set blockedTypes; + + bool visit(TypeId ty, const BlockedType&) override + { + blockedTypes.insert(ty); + return true; + } +}; + +bool Unifier::blockOnBlockedTypes(TypeId subTy, TypeId superTy) +{ + BlockedTypeFinder blockedTypeFinder; + blockedTypeFinder.traverse(subTy); + blockedTypeFinder.traverse(superTy); + if (!blockedTypeFinder.blockedTypes.empty()) + { + blockedTypes.insert(end(blockedTypes), begin(blockedTypeFinder.blockedTypes), end(blockedTypeFinder.blockedTypes)); + return true; + } + + return false; +} + void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionType* uv, bool cacheEnabled, bool isFunctionCall) { // T <: A | B if T <: A or T <: B @@ -788,6 +812,11 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp } else if (!found && normalize) { + // We cannot normalize a type that contains blocked types. We have to + // stop for now if we find any. + if (blockOnBlockedTypes(subTy, superTy)) + return; + // It is possible that T <: A | B even though T (superTable)) - innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); - else if (get(subTable)) - innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); + + if (FFlag::LuauTinyUnifyNormalsFix) + innerState.tryUnify(subTable, superTable); else - innerState.tryUnifyTables(subTable, superTable); + { + if (get(superTable)) + innerState.tryUnifyWithMetatable(subTable, superTable, /* reversed */ false); + else if (get(subTable)) + innerState.tryUnifyWithMetatable(superTable, subTable, /* reversed */ true); + else + innerState.tryUnifyTables(subTable, superTable); + } + if (innerState.errors.empty()) { found = true; @@ -1782,7 +1828,6 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TypeId activeSubTy = subTy; TableType* superTable = log.getMutable(superTy); TableType* subTable = log.getMutable(subTy); - TableType* instantiatedSubTable = subTable; // TODO: remove with FFlagLuauTableUnifyInstantiationFix if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1799,16 +1844,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) std::optional instantiated = instantiation.substitute(subTy); if (instantiated.has_value()) { - if (FFlag::LuauTableUnifyInstantiationFix) - { - activeSubTy = *instantiated; - subTable = log.getMutable(activeSubTy); - } - else - { - subTable = log.getMutable(*instantiated); - instantiatedSubTable = subTable; - } + activeSubTy = *instantiated; + subTable = log.getMutable(activeSubTy); if (!subTable) ice("instantiation made a table type into a non-table type in tryUnifyTables"); @@ -1910,21 +1947,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated // txn log. - TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; - TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(activeSubTy) : activeSubTy; + TypeId superTyNew = log.follow(superTy); + TypeId subTyNew = log.follow(activeSubTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // If one of the types stopped being a table altogether, we need to restart from the top - if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); - } + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); // Otherwise, restart only the table unification TableType* newSuperTable = log.getMutable(superTyNew); TableType* newSubTable = log.getMutable(subTyNew); - if (superTable != newSuperTable || (subTable != newSubTable && (FFlag::LuauTableUnifyInstantiationFix || subTable != instantiatedSubTable))) + if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -1981,15 +2015,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else extraProperties.push_back(name); - TypeId superTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(superTy) : superTy; - TypeId subTyNew = FFlag::LuauScalarShapeUnifyToMtOwner2 ? log.follow(activeSubTy) : activeSubTy; + TypeId superTyNew = log.follow(superTy); + TypeId subTyNew = log.follow(activeSubTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // If one of the types stopped being a table altogether, we need to restart from the top - if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) - return tryUnify(subTy, superTy, false, isIntersection); - } + // If one of the types stopped being a table altogether, we need to restart from the top + if ((superTy != superTyNew || activeSubTy != subTyNew) && errors.empty()) + return tryUnify(subTy, superTy, false, isIntersection); // Recursive unification can change the txn log, and invalidate the old // table. If we detect that this has happened, we start over, with the updated @@ -1997,7 +2028,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TableType* newSuperTable = log.getMutable(superTyNew); TableType* newSubTable = log.getMutable(subTyNew); - if (superTable != newSuperTable || (subTable != newSubTable && (FFlag::LuauTableUnifyInstantiationFix || subTable != instantiatedSubTable))) + if (superTable != newSuperTable || subTable != newSubTable) { if (errors.empty()) return tryUnifyTables(subTy, superTy, isIntersection); @@ -2050,19 +2081,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // Changing the indexer can invalidate the table pointers. - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - superTable = log.getMutable(log.follow(superTy)); - subTable = log.getMutable(log.follow(activeSubTy)); + superTable = log.getMutable(log.follow(superTy)); + subTable = log.getMutable(log.follow(activeSubTy)); - if (!superTable || !subTable) - return; - } - else - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(activeSubTy); - } + if (!superTable || !subTable) + return; if (!missingProperties.empty()) { @@ -2135,18 +2158,15 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) Unifier child = makeChildUnifier(); child.tryUnify_(ty, superTy); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table - // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed - // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check - TypeId newSuperTy = child.log.follow(superTy); + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // There is a chance that it was unified with the origial subtype, but then, (subtype's metatable) <: subtype could've failed + // Here we check if we have a new supertype instead of the original free table and try original subtype <: new supertype check + TypeId newSuperTy = child.log.follow(superTy); - if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) - { - log.replace(superTy, BoundType{subTy}); - return; - } + if (superTy != newSuperTy && canUnify(subTy, newSuperTy).empty()) + { + log.replace(superTy, BoundType{subTy}); + return; } if (auto e = hasUnificationTooComplex(child.errors)) @@ -2156,13 +2176,10 @@ void Unifier::tryUnifyScalarShape(TypeId subTy, TypeId superTy, bool reversed) log.concat(std::move(child.log)); - if (FFlag::LuauScalarShapeUnifyToMtOwner2) - { - // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table - // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype - if (child.errors.empty()) - log.replace(superTy, BoundType{subTy}); - } + // To perform subtype <: free table unification, we have tried to unify (subtype's metatable) <: free table + // We return success because subtype <: free table which means that correct unification is to replace free table with the subtype + if (child.errors.empty()) + log.replace(superTy, BoundType{subTy}); return; } @@ -2379,6 +2396,11 @@ void Unifier::tryUnifyNegations(TypeId subTy, TypeId superTy) if (!log.get(subTy) && !log.get(superTy)) ice("tryUnifyNegations superTy or subTy must be a negation type"); + // We cannot normalize a type that contains blocked types. We have to + // stop for now if we find any. + if (blockOnBlockedTypes(subTy, superTy)) + return; + const NormalizedType* subNorm = normalizer->normalize(subTy); const NormalizedType* superNorm = normalizer->normalize(superTy); if (!subNorm || !superNorm) diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 6257e2f..d6f1822 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -121,6 +121,7 @@ static void displayHelp(const char* argv0) static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + fflush(stdout); return 1; } @@ -267,8 +268,8 @@ int main(int argc, char** argv) CliConfigResolver configResolver(mode); Luau::Frontend frontend(&fileResolver, &configResolver, frontendOptions); - Luau::registerBuiltinGlobals(frontend.typeChecker); - Luau::freeze(frontend.typeChecker.globalTypes); + Luau::registerBuiltinGlobals(frontend.typeChecker, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); #ifdef CALLGRIND CALLGRIND_ZERO_STATS; diff --git a/CodeGen/include/Luau/IrAnalysis.h b/CodeGen/include/Luau/IrAnalysis.h index d3e1a93..21fa755 100644 --- a/CodeGen/include/Luau/IrAnalysis.h +++ b/CodeGen/include/Luau/IrAnalysis.h @@ -1,7 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include #include +#include #include @@ -22,5 +24,59 @@ std::pair getLiveInOutValueCount(IrFunction& function, IrBlo uint32_t getLiveInValueCount(IrFunction& function, IrBlock& block); uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block); +struct RegisterSet +{ + std::bitset<256> regs; + + // If variadic sequence is active, we track register from which it starts + bool varargSeq = false; + uint8_t varargStart = 0; +}; + +struct CfgInfo +{ + std::vector predecessors; + std::vector predecessorsOffsets; + + std::vector successors; + std::vector successorsOffsets; + + std::vector in; + std::vector out; + + RegisterSet captured; +}; + +void computeCfgInfo(IrFunction& function); + +struct BlockIteratorWrapper +{ + uint32_t* itBegin = nullptr; + uint32_t* itEnd = nullptr; + + bool empty() const + { + return itBegin == itEnd; + } + + size_t size() const + { + return size_t(itEnd - itBegin); + } + + uint32_t* begin() const + { + return itBegin; + } + + uint32_t* end() const + { + return itEnd; + } +}; + +BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx); +BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx); + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/include/Luau/IrData.h b/CodeGen/include/Luau/IrData.h index 049d700..439abb9 100644 --- a/CodeGen/include/Luau/IrData.h +++ b/CodeGen/include/Luau/IrData.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/IrAnalysis.h" #include "Luau/Label.h" #include "Luau/RegisterX64.h" #include "Luau/RegisterA64.h" @@ -261,6 +262,7 @@ enum class IrCmd : uint8_t // A: Rn (value start) // B: unsigned int (number of registers to go over) // Note: result is stored in the register specified in 'A' + // Note: all referenced registers might be modified in the operation CONCAT, // Load function upvalue into stack slot @@ -382,16 +384,16 @@ enum class IrCmd : uint8_t LOP_RETURN, // Adjust loop variables for one iteration of a generic for loop, jump back to the loop header if loop needs to continue - // A: Rn (loop variable start, updates Rn+2 Rn+3 Rn+4) - // B: int (loop variable count, is more than 2, additional registers are set to nil) + // A: Rn (loop variable start, updates Rn+2 and 'B' number of registers starting from Rn+3) + // B: int (loop variable count, if more than 2, registers starting from Rn+5 are set to nil) // C: block (repeat) // D: block (exit) LOP_FORGLOOP, // Handle LOP_FORGLOOP fallback when variable being iterated is not a table // A: unsigned int (bytecode instruction index) - // B: Rn (loop state start, updates Rn+2 Rn+3 Rn+4 Rn+5) - // C: int (extra variable count or -1 for ipairs-style iteration) + // B: Rn (loop state start, updates Rn+2 and 'C' number of registers starting from Rn+3) + // C: int (loop variable count and a MSB set when it's an ipairs-like iteration loop) // D: block (repeat) // E: block (exit) LOP_FORGLOOP_FALLBACK, @@ -638,6 +640,8 @@ struct IrFunction Proto* proto = nullptr; + CfgInfo cfg; + IrBlock& blockOp(IrOp op) { LUAU_ASSERT(op.kind == IrOpKind::Block); diff --git a/CodeGen/include/Luau/IrDump.h b/CodeGen/include/Luau/IrDump.h index 47a5f9e..a6329ec 100644 --- a/CodeGen/include/Luau/IrDump.h +++ b/CodeGen/include/Luau/IrDump.h @@ -11,6 +11,8 @@ namespace Luau namespace CodeGen { +struct CfgInfo; + const char* getCmdName(IrCmd cmd); const char* getBlockKindName(IrBlockKind kind); @@ -19,6 +21,7 @@ struct IrToStringContext std::string& result; std::vector& blocks; std::vector& constants; + CfgInfo& cfg; }; void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index); @@ -27,10 +30,10 @@ void toString(IrToStringContext& ctx, IrOp op); void toString(std::string& result, IrConst constant); -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index); -void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index); // Block title +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo); +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo); // Block title -std::string toString(IrFunction& function, bool includeDetails); +std::string toString(IrFunction& function, bool includeUseInfo); std::string dump(IrFunction& function); diff --git a/CodeGen/include/Luau/IrUtils.h b/CodeGen/include/Luau/IrUtils.h index 153cf7a..3b14a8c 100644 --- a/CodeGen/include/Luau/IrUtils.h +++ b/CodeGen/include/Luau/IrUtils.h @@ -183,9 +183,6 @@ void kill(IrFunction& function, uint32_t start, uint32_t end); // Remove a block, including all instructions inside void kill(IrFunction& function, IrBlock& block); -void removeUse(IrFunction& function, IrInst& inst); -void removeUse(IrFunction& function, IrBlock& block); - // Replace a single operand and update use counts (can cause chain removal of dead code) void replace(IrFunction& function, IrOp& original, IrOp replacement); diff --git a/CodeGen/src/CodeGen.cpp b/CodeGen/src/CodeGen.cpp index 51bf174..c794972 100644 --- a/CodeGen/src/CodeGen.cpp +++ b/CodeGen/src/CodeGen.cpp @@ -68,9 +68,24 @@ static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState if (options.includeAssembly || options.includeIr) { if (proto->debugname) - build.logAppend("; function %s()", getstr(proto->debugname)); + build.logAppend("; function %s(", getstr(proto->debugname)); else - build.logAppend("; function()"); + build.logAppend("; function("); + + for (int i = 0; i < proto->numparams; i++) + { + LocVar* var = proto->locvars ? &proto->locvars[proto->sizelocvars - proto->numparams + i] : nullptr; + + if (var && var->varname) + build.logAppend("%s%s", i == 0 ? "" : ", ", getstr(var->varname)); + else + build.logAppend("%s$arg%d", i == 0 ? "" : ", ", i); + } + + if (proto->numparams != 0 && proto->is_vararg) + build.logAppend(", ...)"); + else + build.logAppend(")"); if (proto->linedefined >= 0) build.logAppend(" line %d\n", proto->linedefined); @@ -90,6 +105,10 @@ static NativeProto* assembleFunction(X64::AssemblyBuilderX64& build, NativeState constPropInBlockChains(builder); } + // TODO: cfg info has to be computed earlier to use in optimizations + // It's done here to appear in text output and to measure performance impact on code generation + computeCfgInfo(builder.function); + optimizeMemoryOperandsX64(builder.function); X64::IrLoweringX64 lowering(build, helpers, data, proto, builder.function); diff --git a/CodeGen/src/IrAnalysis.cpp b/CodeGen/src/IrAnalysis.cpp index aa3e19f..dc7d771 100644 --- a/CodeGen/src/IrAnalysis.cpp +++ b/CodeGen/src/IrAnalysis.cpp @@ -5,6 +5,10 @@ #include "Luau/IrData.h" #include "Luau/IrUtils.h" +#include "lobject.h" + +#include + #include namespace Luau @@ -116,5 +120,518 @@ uint32_t getLiveOutValueCount(IrFunction& function, IrBlock& block) return getLiveInOutValueCount(function, block).second; } +static void requireVariadicSequence(RegisterSet& sourceRs, const RegisterSet& defRs, uint8_t varargStart) +{ + if (!defRs.varargSeq) + { + LUAU_ASSERT(!sourceRs.varargSeq || sourceRs.varargStart == varargStart); + + sourceRs.varargSeq = true; + sourceRs.varargStart = varargStart; + } + else + { + // Variadic use sequence might include registers before def sequence + for (int i = varargStart; i < defRs.varargStart; i++) + { + if (!defRs.regs.test(i)) + sourceRs.regs.set(i); + } + } +} + +static RegisterSet computeBlockLiveInRegSet(IrFunction& function, const IrBlock& block, RegisterSet& defRs, std::bitset<256>& capturedRegs) +{ + RegisterSet inRs; + + auto def = [&](IrOp op, int offset = 0) { + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + defRs.regs.set(op.index + offset, true); + }; + + auto use = [&](IrOp op, int offset = 0) { + LUAU_ASSERT(op.kind == IrOpKind::VmReg); + if (!defRs.regs.test(op.index + offset)) + inRs.regs.set(op.index + offset, true); + }; + + auto maybeDef = [&](IrOp op) { + if (op.kind == IrOpKind::VmReg) + defRs.regs.set(op.index, true); + }; + + auto maybeUse = [&](IrOp op) { + if (op.kind == IrOpKind::VmReg) + { + if (!defRs.regs.test(op.index)) + inRs.regs.set(op.index, true); + } + }; + + auto defVarargs = [&](uint8_t varargStart) { + defRs.varargSeq = true; + defRs.varargStart = varargStart; + }; + + auto useVarargs = [&](uint8_t varargStart) { + requireVariadicSequence(inRs, defRs, varargStart); + + // Variadic sequence has been consumed + defRs.varargSeq = false; + defRs.varargStart = 0; + }; + + auto defRange = [&](int start, int count) { + if (count == -1) + { + defVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + defRs.regs.set(i, true); + } + }; + + auto useRange = [&](int start, int count) { + if (count == -1) + { + useVarargs(start); + } + else + { + for (int i = start; i < start + count; i++) + { + if (!defRs.regs.test(i)) + inRs.regs.set(i, true); + } + } + }; + + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + const IrInst& inst = function.instructions[instIdx]; + + // For correct analysis, all instruction uses must be handled before handling the definitions + switch (inst.cmd) + { + case IrCmd::LOAD_TAG: + case IrCmd::LOAD_POINTER: + case IrCmd::LOAD_DOUBLE: + case IrCmd::LOAD_INT: + case IrCmd::LOAD_TVALUE: + maybeUse(inst.a); // Argument can also be a VmConst + break; + case IrCmd::STORE_TAG: + case IrCmd::STORE_POINTER: + case IrCmd::STORE_DOUBLE: + case IrCmd::STORE_INT: + case IrCmd::STORE_TVALUE: + maybeDef(inst.a); // Argument can also be a pointer value + break; + case IrCmd::JUMP_IF_TRUTHY: + case IrCmd::JUMP_IF_FALSY: + use(inst.a); + break; + case IrCmd::JUMP_CMP_ANY: + use(inst.a); + use(inst.b); + break; + // A <- B, C + case IrCmd::DO_ARITH: + case IrCmd::GET_TABLE: + case IrCmd::SET_TABLE: + use(inst.b); + maybeUse(inst.c); // Argument can also be a VmConst + + def(inst.a); + break; + // A <- B + case IrCmd::DO_LEN: + use(inst.b); + + def(inst.a); + break; + case IrCmd::GET_IMPORT: + def(inst.a); + break; + case IrCmd::CONCAT: + useRange(inst.a.index, function.uintOp(inst.b)); + + defRange(inst.a.index, function.uintOp(inst.b)); + break; + case IrCmd::GET_UPVALUE: + def(inst.a); + break; + case IrCmd::SET_UPVALUE: + use(inst.b); + break; + case IrCmd::PREPARE_FORN: + use(inst.a); + use(inst.b); + use(inst.c); + + def(inst.a); + def(inst.b); + def(inst.c); + break; + case IrCmd::INTERRUPT: + break; + case IrCmd::BARRIER_OBJ: + case IrCmd::BARRIER_TABLE_FORWARD: + use(inst.b); + break; + case IrCmd::CLOSE_UPVALS: + // Closing an upvalue should be counted as a register use (it copies the fresh register value) + // But we lack the required information about the specific set of registers that are affected + // Because we don't plan to optimize captured registers atm, we skip full dataflow analysis for them right now + break; + case IrCmd::CAPTURE: + maybeUse(inst.a); + + if (function.boolOp(inst.b)) + capturedRegs.set(inst.a.index, true); + break; + case IrCmd::LOP_SETLIST: + use(inst.b); + useRange(inst.c.index, function.intOp(inst.d)); + break; + case IrCmd::LOP_NAMECALL: + use(inst.c); + + defRange(inst.b.index, 2); + break; + case IrCmd::LOP_CALL: + use(inst.b); + useRange(inst.b.index + 1, function.intOp(inst.c)); + + defRange(inst.b.index, function.intOp(inst.d)); + break; + case IrCmd::LOP_RETURN: + useRange(inst.b.index, function.intOp(inst.c)); + break; + case IrCmd::FASTCALL: + case IrCmd::INVOKE_FASTCALL: + if (int count = function.intOp(inst.e); count != -1) + { + if (count >= 3) + { + LUAU_ASSERT(inst.d.kind == IrOpKind::VmReg && inst.d.index == inst.c.index + 1); + + useRange(inst.c.index, count); + } + else + { + if (count >= 1) + use(inst.c); + + if (count >= 2) + maybeUse(inst.d); // Argument can also be a VmConst + } + } + else + { + useVarargs(inst.c.index); + } + + defRange(inst.b.index, function.intOp(inst.f)); + break; + case IrCmd::LOP_FORGLOOP: + // First register is not used by instruction, we check that it's still 'nil' with CHECK_TAG + use(inst.a, 1); + use(inst.a, 2); + + def(inst.a, 2); + defRange(inst.a.index + 3, function.intOp(inst.b)); + break; + case IrCmd::LOP_FORGLOOP_FALLBACK: + useRange(inst.b.index, 3); + + def(inst.b, 2); + defRange(inst.b.index + 3, uint8_t(function.intOp(inst.c))); // ignore most significant bit + break; + case IrCmd::LOP_FORGPREP_XNEXT_FALLBACK: + use(inst.b); + break; + // B <- C, D + case IrCmd::LOP_AND: + case IrCmd::LOP_OR: + use(inst.c); + use(inst.d); + + def(inst.b); + break; + // B <- C + case IrCmd::LOP_ANDK: + case IrCmd::LOP_ORK: + use(inst.c); + + def(inst.b); + break; + case IrCmd::FALLBACK_GETGLOBAL: + def(inst.b); + break; + case IrCmd::FALLBACK_SETGLOBAL: + use(inst.b); + break; + case IrCmd::FALLBACK_GETTABLEKS: + use(inst.c); + + def(inst.b); + break; + case IrCmd::FALLBACK_SETTABLEKS: + use(inst.b); + use(inst.c); + break; + case IrCmd::FALLBACK_NAMECALL: + use(inst.c); + + defRange(inst.b.index, 2); + break; + case IrCmd::FALLBACK_PREPVARARGS: + // No effect on explicitly referenced registers + break; + case IrCmd::FALLBACK_GETVARARGS: + defRange(inst.b.index, function.intOp(inst.c)); + break; + case IrCmd::FALLBACK_NEWCLOSURE: + def(inst.b); + break; + case IrCmd::FALLBACK_DUPCLOSURE: + def(inst.b); + break; + case IrCmd::FALLBACK_FORGPREP: + use(inst.b); + + defRange(inst.b.index, 3); + break; + case IrCmd::ADJUST_STACK_TO_REG: + case IrCmd::ADJUST_STACK_TO_TOP: + // While these can be considered as vararg producers and consumers, it is already handled in fastcall instruction + break; + + default: + break; + } + } + + return inRs; +} + +// The algorithm used here is commonly known as backwards data-flow analysis. +// For each block, we track 'upward-exposed' (live-in) uses of registers - a use of a register that hasn't been defined in the block yet. +// We also track the set of registers that were defined in the block. +// When initial live-in sets of registers are computed, propagation of those uses upwards through predecessors is performed. +// If predecessor doesn't define the register, we have to add it to the live-in set. +// Extending the set of live-in registers of a block requires re-checking of that block. +// Propagation runs iteratively, using a worklist of blocks to visit until a fixed point is reached. +// This algorithm can be easily extended to cover phi instructions, but we don't use those yet. +static void computeCfgLiveInOutRegSets(IrFunction& function) +{ + CfgInfo& info = function.cfg; + + // Try to compute Luau VM register use-def info + info.in.resize(function.blocks.size()); + info.out.resize(function.blocks.size()); + + // Captured registers are tracked for the whole function + // It should be possible to have a more precise analysis for them in the future + std::bitset<256> capturedRegs; + + std::vector defRss; + defRss.resize(function.blocks.size()); + + // First we compute live-in set of each block + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + if (block.kind == IrBlockKind::Dead) + continue; + + info.in[blockIdx] = computeBlockLiveInRegSet(function, block, defRss[blockIdx], capturedRegs); + } + + info.captured.regs = capturedRegs; + + // With live-in sets ready, we can arrive at a fixed point for both in/out registers by requesting required registers from predecessors + std::vector worklist; + + std::vector inWorklist; + inWorklist.resize(function.blocks.size(), false); + + // We will have to visit each block at least once, so we add all of them to the worklist immediately + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + if (block.kind == IrBlockKind::Dead) + continue; + + worklist.push_back(uint32_t(blockIdx)); + inWorklist[blockIdx] = true; + } + + while (!worklist.empty()) + { + uint32_t blockIdx = worklist.back(); + worklist.pop_back(); + inWorklist[blockIdx] = false; + + IrBlock& curr = function.blocks[blockIdx]; + RegisterSet& inRs = info.in[blockIdx]; + RegisterSet& outRs = info.out[blockIdx]; + RegisterSet& defRs = defRss[blockIdx]; + + // Current block has to provide all registers in successor blocks + for (uint32_t succIdx : successors(info, blockIdx)) + { + IrBlock& succ = function.blocks[succIdx]; + + // This is a step away from the usual definition of live range flow through CFG + // Exit from a regular block to a fallback block is not considered a block terminator + // This is because fallback blocks define an alternative implementation of the same operations + // This can cause the current block to define more registers that actually were available at fallback entry + if (curr.kind != IrBlockKind::Fallback && succ.kind == IrBlockKind::Fallback) + continue; + + const RegisterSet& succRs = info.in[succIdx]; + + outRs.regs |= succRs.regs; + + if (succRs.varargSeq) + { + LUAU_ASSERT(!outRs.varargSeq || outRs.varargStart == succRs.varargStart); + + outRs.varargSeq = true; + outRs.varargStart = succRs.varargStart; + } + } + + RegisterSet oldInRs = inRs; + + // If current block didn't define a live-out, it has to be live-in + inRs.regs |= outRs.regs & ~defRs.regs; + + if (outRs.varargSeq) + requireVariadicSequence(inRs, defRs, outRs.varargStart); + + // If we have new live-ins, we have to notify all predecessors + // We don't allow changes to the start of the variadic sequence, so we skip checking that member + if (inRs.regs != oldInRs.regs || inRs.varargSeq != oldInRs.varargSeq) + { + for (uint32_t predIdx : predecessors(info, blockIdx)) + { + if (!inWorklist[predIdx]) + { + worklist.push_back(predIdx); + inWorklist[predIdx] = true; + } + } + } + } + + // If Proto data is available, validate that entry block arguments match required registers + if (function.proto) + { + RegisterSet& entryIn = info.in[0]; + + LUAU_ASSERT(!entryIn.varargSeq); + + for (size_t i = 0; i < entryIn.regs.size(); i++) + LUAU_ASSERT(!entryIn.regs.test(i) || i < function.proto->numparams); + } +} + +static void computeCfgBlockEdges(IrFunction& function) +{ + CfgInfo& info = function.cfg; + + // Compute predecessors block edges + info.predecessorsOffsets.reserve(function.blocks.size()); + info.successorsOffsets.reserve(function.blocks.size()); + + int edgeCount = 0; + + for (const IrBlock& block : function.blocks) + { + info.predecessorsOffsets.push_back(edgeCount); + edgeCount += block.useCount; + } + + info.predecessors.resize(edgeCount); + info.successors.resize(edgeCount); + + edgeCount = 0; + + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + info.successorsOffsets.push_back(edgeCount); + + if (block.kind == IrBlockKind::Dead) + continue; + + for (uint32_t instIdx = block.start; instIdx <= block.finish; instIdx++) + { + const IrInst& inst = function.instructions[instIdx]; + + auto checkOp = [&](IrOp op) { + if (op.kind == IrOpKind::Block) + { + // We use a trick here, where we use the starting offset of the predecessor list as the position where to write next predecessor + // The values will be adjusted back in a separate loop later + info.predecessors[info.predecessorsOffsets[op.index]++] = uint32_t(blockIdx); + + info.successors[edgeCount++] = op.index; + } + }; + + checkOp(inst.a); + checkOp(inst.b); + checkOp(inst.c); + checkOp(inst.d); + checkOp(inst.e); + checkOp(inst.f); + } + } + + // Offsets into the predecessor list were used as iterators in the previous loop + // To adjust them back, block use count is subtracted (predecessor count is equal to how many uses block has) + for (size_t blockIdx = 0; blockIdx < function.blocks.size(); blockIdx++) + { + const IrBlock& block = function.blocks[blockIdx]; + + info.predecessorsOffsets[blockIdx] -= block.useCount; + } +} + +void computeCfgInfo(IrFunction& function) +{ + computeCfgBlockEdges(function); + computeCfgLiveInOutRegSets(function); +} + +BlockIteratorWrapper predecessors(CfgInfo& cfg, uint32_t blockIdx) +{ + LUAU_ASSERT(blockIdx < cfg.predecessorsOffsets.size()); + + uint32_t start = cfg.predecessorsOffsets[blockIdx]; + uint32_t end = blockIdx + 1 < cfg.predecessorsOffsets.size() ? cfg.predecessorsOffsets[blockIdx + 1] : uint32_t(cfg.predecessors.size()); + + return BlockIteratorWrapper{cfg.predecessors.data() + start, cfg.predecessors.data() + end}; +} + +BlockIteratorWrapper successors(CfgInfo& cfg, uint32_t blockIdx) +{ + LUAU_ASSERT(blockIdx < cfg.successorsOffsets.size()); + + uint32_t start = cfg.successorsOffsets[blockIdx]; + uint32_t end = blockIdx + 1 < cfg.successorsOffsets.size() ? cfg.successorsOffsets[blockIdx + 1] : uint32_t(cfg.successors.size()); + + return BlockIteratorWrapper{cfg.successors.data() + start, cfg.successors.data() + end}; +} + } // namespace CodeGen } // namespace Luau diff --git a/CodeGen/src/IrBuilder.cpp b/CodeGen/src/IrBuilder.cpp index 056ea60..0a700db 100644 --- a/CodeGen/src/IrBuilder.cpp +++ b/CodeGen/src/IrBuilder.cpp @@ -256,7 +256,7 @@ void IrBuilder::translateInst(LuauOpcode op, const Instruction* pc, int i) translateInstDupTable(*this, pc, i); break; case LOP_SETLIST: - inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_A(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); + inst(IrCmd::LOP_SETLIST, constUint(i), vmReg(LUAU_INSN_A(*pc)), vmReg(LUAU_INSN_B(*pc)), constInt(LUAU_INSN_C(*pc) - 1), constUint(pc[1])); break; case LOP_GETUPVAL: translateInstGetUpval(*this, pc, i); diff --git a/CodeGen/src/IrDump.cpp b/CodeGen/src/IrDump.cpp index cb203f7..2787fb1 100644 --- a/CodeGen/src/IrDump.cpp +++ b/CodeGen/src/IrDump.cpp @@ -306,7 +306,7 @@ void toString(IrToStringContext& ctx, const IrInst& inst, uint32_t index) void toString(IrToStringContext& ctx, const IrBlock& block, uint32_t index) { - append(ctx.result, "%s_%u:", getBlockKindName(block.kind), index); + append(ctx.result, "%s_%u", getBlockKindName(block.kind), index); } void toString(IrToStringContext& ctx, IrOp op) @@ -362,33 +362,151 @@ void toString(std::string& result, IrConst constant) } } -void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index) +void toStringDetailed(IrToStringContext& ctx, const IrInst& inst, uint32_t index, bool includeUseInfo) { size_t start = ctx.result.size(); toString(ctx, inst, index); - padToDetailColumn(ctx.result, start); - if (inst.useCount == 0 && hasSideEffects(inst.cmd)) - append(ctx.result, "; %%%u, has side-effects\n", index); + if (includeUseInfo) + { + padToDetailColumn(ctx.result, start); + + if (inst.useCount == 0 && hasSideEffects(inst.cmd)) + append(ctx.result, "; %%%u, has side-effects\n", index); + else + append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); + } else - append(ctx.result, "; useCount: %d, lastUse: %%%u\n", inst.useCount, inst.lastUse); + { + ctx.result.append("\n"); + } } -void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index) +static void appendBlockSet(IrToStringContext& ctx, BlockIteratorWrapper blocks) { + bool comma = false; + + for (uint32_t target : blocks) + { + if (comma) + append(ctx.result, ", "); + comma = true; + + toString(ctx, ctx.blocks[target], target); + } +} + +static void appendRegisterSet(IrToStringContext& ctx, const RegisterSet& rs) +{ + bool comma = false; + + for (size_t i = 0; i < rs.regs.size(); i++) + { + if (rs.regs.test(i)) + { + if (comma) + append(ctx.result, ", "); + comma = true; + + append(ctx.result, "R%d", int(i)); + } + } + + if (rs.varargSeq) + { + if (comma) + append(ctx.result, ", "); + + append(ctx.result, "R%d...", rs.varargStart); + } +} + +void toStringDetailed(IrToStringContext& ctx, const IrBlock& block, uint32_t index, bool includeUseInfo) +{ + // Report captured registers for entry block + if (block.useCount == 0 && block.kind != IrBlockKind::Dead && ctx.cfg.captured.regs.any()) + { + append(ctx.result, "; captured regs: "); + appendRegisterSet(ctx, ctx.cfg.captured); + append(ctx.result, "\n\n"); + } + size_t start = ctx.result.size(); toString(ctx, block, index); - padToDetailColumn(ctx.result, start); + append(ctx.result, ":"); - append(ctx.result, "; useCount: %d\n", block.useCount); + if (includeUseInfo) + { + padToDetailColumn(ctx.result, start); + + append(ctx.result, "; useCount: %d\n", block.useCount); + } + else + { + ctx.result.append("\n"); + } + + // Predecessor list + if (!ctx.cfg.predecessors.empty()) + { + BlockIteratorWrapper pred = predecessors(ctx.cfg, index); + + if (!pred.empty()) + { + append(ctx.result, "; predecessors: "); + + appendBlockSet(ctx, pred); + append(ctx.result, "\n"); + } + } + + // Successor list + if (!ctx.cfg.successors.empty()) + { + BlockIteratorWrapper succ = successors(ctx.cfg, index); + + if (!succ.empty()) + { + append(ctx.result, "; successors: "); + + appendBlockSet(ctx, succ); + append(ctx.result, "\n"); + } + } + + // Live-in VM regs + if (index < ctx.cfg.in.size()) + { + const RegisterSet& in = ctx.cfg.in[index]; + + if (in.regs.any() || in.varargSeq) + { + append(ctx.result, "; in regs: "); + appendRegisterSet(ctx, in); + append(ctx.result, "\n"); + } + } + + // Live-out VM regs + if (index < ctx.cfg.out.size()) + { + const RegisterSet& out = ctx.cfg.out[index]; + + if (out.regs.any() || out.varargSeq) + { + append(ctx.result, "; out regs: "); + appendRegisterSet(ctx, out); + append(ctx.result, "\n"); + } + } } -std::string toString(IrFunction& function, bool includeDetails) +std::string toString(IrFunction& function, bool includeUseInfo) { std::string result; - IrToStringContext ctx{result, function.blocks, function.constants}; + IrToStringContext ctx{result, function.blocks, function.constants, function.cfg}; for (size_t i = 0; i < function.blocks.size(); i++) { @@ -397,15 +515,7 @@ std::string toString(IrFunction& function, bool includeDetails) if (block.kind == IrBlockKind::Dead) continue; - if (includeDetails) - { - toStringDetailed(ctx, block, uint32_t(i)); - } - else - { - toString(ctx, block, uint32_t(i)); - ctx.result.append("\n"); - } + toStringDetailed(ctx, block, uint32_t(i), includeUseInfo); if (block.start == ~0u) { @@ -423,16 +533,7 @@ std::string toString(IrFunction& function, bool includeDetails) continue; append(ctx.result, " "); - - if (includeDetails) - { - toStringDetailed(ctx, inst, index); - } - else - { - toString(ctx, inst, index); - ctx.result.append("\n"); - } + toStringDetailed(ctx, inst, index, includeUseInfo); } append(ctx.result, "\n"); @@ -443,7 +544,7 @@ std::string toString(IrFunction& function, bool includeDetails) std::string dump(IrFunction& function) { - std::string result = toString(function, /* includeDetails */ true); + std::string result = toString(function, /* includeUseInfo */ true); printf("%s\n", result.c_str()); diff --git a/CodeGen/src/IrLoweringX64.cpp b/CodeGen/src/IrLoweringX64.cpp index 3833757..3b27d09 100644 --- a/CodeGen/src/IrLoweringX64.cpp +++ b/CodeGen/src/IrLoweringX64.cpp @@ -79,7 +79,7 @@ void IrLoweringX64::lower(AssemblyOptions options) } } - IrToStringContext ctx{build.text, function.blocks, function.constants}; + IrToStringContext ctx{build.text, function.blocks, function.constants, function.cfg}; // We use this to skip outlined fallback blocks from IR/asm text output size_t textSize = build.text.length(); @@ -112,7 +112,7 @@ void IrLoweringX64::lower(AssemblyOptions options) if (options.includeIr) { build.logAppend("# "); - toStringDetailed(ctx, block, blockIndex); + toStringDetailed(ctx, block, blockIndex, /* includeUseInfo */ true); } build.setLabel(block.label); @@ -145,7 +145,7 @@ void IrLoweringX64::lower(AssemblyOptions options) if (options.includeIr) { build.logAppend("# "); - toStringDetailed(ctx, inst, index); + toStringDetailed(ctx, inst, index, /* includeUseInfo */ true); } IrBlock& next = i + 1 < sortedBlocks.size() ? function.blocks[sortedBlocks[i + 1]] : dummy; @@ -416,7 +416,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); - RegisterX64 lhs = regOp(inst.a); + ScopedRegX64 optLhsTmp{regs}; + RegisterX64 lhs; + + if (inst.a.kind == IrOpKind::Constant) + { + optLhsTmp.alloc(SizeX64::xmmword); + + build.vmovsd(optLhsTmp.reg, memRegDoubleOp(inst.a)); + lhs = optLhsTmp.reg; + } + else + { + lhs = regOp(inst.a); + } if (inst.b.kind == IrOpKind::Inst) { @@ -444,14 +457,15 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, IrBlock& next) { inst.regX64 = regs.allocXmmRegOrReuse(index, {inst.a, inst.b}); - ScopedRegX64 tmp{regs, SizeX64::xmmword}; - + ScopedRegX64 optLhsTmp{regs}; RegisterX64 lhs; if (inst.a.kind == IrOpKind::Constant) { - build.vmovsd(tmp.reg, memRegDoubleOp(inst.a)); - lhs = tmp.reg; + optLhsTmp.alloc(SizeX64::xmmword); + + build.vmovsd(optLhsTmp.reg, memRegDoubleOp(inst.a)); + lhs = optLhsTmp.reg; } else { diff --git a/CodeGen/src/IrRegAllocX64.cpp b/CodeGen/src/IrRegAllocX64.cpp index 9186780..c527d03 100644 --- a/CodeGen/src/IrRegAllocX64.cpp +++ b/CodeGen/src/IrRegAllocX64.cpp @@ -169,13 +169,17 @@ void IrRegAllocX64::assertAllFree() const LUAU_ASSERT(free); } +ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner) + : owner(owner) + , reg(noreg) +{ +} + ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, SizeX64 size) : owner(owner) + , reg(noreg) { - if (size == SizeX64::xmmword) - reg = owner.allocXmmReg(); - else - reg = owner.allocGprReg(size); + alloc(size); } ScopedRegX64::ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg) @@ -190,6 +194,16 @@ ScopedRegX64::~ScopedRegX64() owner.freeReg(reg); } +void ScopedRegX64::alloc(SizeX64 size) +{ + LUAU_ASSERT(reg == noreg); + + if (size == SizeX64::xmmword) + reg = owner.allocXmmReg(); + else + reg = owner.allocGprReg(size); +} + void ScopedRegX64::free() { LUAU_ASSERT(reg != noreg); diff --git a/CodeGen/src/IrRegAllocX64.h b/CodeGen/src/IrRegAllocX64.h index ac072a3..497bb03 100644 --- a/CodeGen/src/IrRegAllocX64.h +++ b/CodeGen/src/IrRegAllocX64.h @@ -40,6 +40,7 @@ struct IrRegAllocX64 struct ScopedRegX64 { + explicit ScopedRegX64(IrRegAllocX64& owner); ScopedRegX64(IrRegAllocX64& owner, SizeX64 size); ScopedRegX64(IrRegAllocX64& owner, RegisterX64 reg); ~ScopedRegX64(); @@ -47,6 +48,7 @@ struct ScopedRegX64 ScopedRegX64(const ScopedRegX64&) = delete; ScopedRegX64& operator=(const ScopedRegX64&) = delete; + void alloc(SizeX64 size); void free(); IrRegAllocX64& owner; diff --git a/CodeGen/src/IrUtils.cpp b/CodeGen/src/IrUtils.cpp index 0808ad0..d8115be 100644 --- a/CodeGen/src/IrUtils.cpp +++ b/CodeGen/src/IrUtils.cpp @@ -14,6 +14,29 @@ namespace Luau namespace CodeGen { +static void removeInstUse(IrFunction& function, uint32_t instIdx) +{ + IrInst& inst = function.instructions[instIdx]; + + LUAU_ASSERT(inst.useCount); + inst.useCount--; + + if (inst.useCount == 0) + kill(function, inst); +} + +static void removeBlockUse(IrFunction& function, uint32_t blockIdx) +{ + IrBlock& block = function.blocks[blockIdx]; + + LUAU_ASSERT(block.useCount); + block.useCount--; + + // Entry block is never removed because is has an implicit use + if (block.useCount == 0 && blockIdx != 0) + kill(function, block); +} + void addUse(IrFunction& function, IrOp op) { if (op.kind == IrOpKind::Inst) @@ -25,9 +48,9 @@ void addUse(IrFunction& function, IrOp op) void removeUse(IrFunction& function, IrOp op) { if (op.kind == IrOpKind::Inst) - removeUse(function, function.instructions[op.index]); + removeInstUse(function, op.index); else if (op.kind == IrOpKind::Block) - removeUse(function, function.blocks[op.index]); + removeBlockUse(function, op.index); } bool isGCO(uint8_t tag) @@ -83,24 +106,6 @@ void kill(IrFunction& function, IrBlock& block) block.finish = ~0u; } -void removeUse(IrFunction& function, IrInst& inst) -{ - LUAU_ASSERT(inst.useCount); - inst.useCount--; - - if (inst.useCount == 0) - kill(function, inst); -} - -void removeUse(IrFunction& function, IrBlock& block) -{ - LUAU_ASSERT(block.useCount); - block.useCount--; - - if (block.useCount == 0) - kill(function, block); -} - void replace(IrFunction& function, IrOp& original, IrOp replacement) { // Add use before removing new one if that's the last one keeping target operand alive @@ -122,6 +127,9 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl addUse(function, replacement.e); addUse(function, replacement.f); + // An extra reference is added so block will not remove itself + block.useCount++; + // If we introduced an earlier terminating instruction, all following instructions become dead if (!isBlockTerminator(inst.cmd) && isBlockTerminator(replacement.cmd)) { @@ -142,6 +150,10 @@ void replace(IrFunction& function, IrBlock& block, uint32_t instIdx, IrInst repl removeUse(function, inst.f); inst = replacement; + + // Removing the earlier extra reference, this might leave the block without users without marking it as dead + // This will have to be handled by separate dead code elimination + block.useCount--; } void substitute(IrFunction& function, IrInst& inst, IrOp replacement) diff --git a/Common/include/Luau/ExperimentalFlags.h b/Common/include/Luau/ExperimentalFlags.h index afd3640..35e11ca 100644 --- a/Common/include/Luau/ExperimentalFlags.h +++ b/Common/include/Luau/ExperimentalFlags.h @@ -12,7 +12,6 @@ inline bool isFlagExperimental(const char* flag) // or critical bugs that are found after the code has been submitted. static const char* const kList[] = { "LuauInstantiateInSubtyping", // requires some fixes to lua-apps code - "LuauTryhardAnd", // waiting for a fix in graphql-lua -> apollo-client-lia -> lua-apps "LuauTypecheckTypeguards", // requires some fixes to lua-apps code (CLI-67030) // makes sure we always have at least one entry nullptr, diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index dcf5692..5d6b760 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -20,6 +20,18 @@ #define LUAU_FASTMATH_END #endif +// Some functions like floor/ceil have SSE4.1 equivalents but we currently support systems without SSE4.1 +// Note that we only need to do this when SSE4.1 support is not guaranteed by compiler settings, as otherwise compiler will optimize these for us. +#if (defined(__x86_64__) || defined(_M_X64)) && !defined(__SSE4_1__) && !defined(__AVX__) +#if defined(_MSC_VER) && !defined(__clang__) +#define LUAU_TARGET_SSE41 +#elif defined(__GNUC__) && defined(__has_attribute) +#if __has_attribute(target) +#define LUAU_TARGET_SSE41 __attribute__((target("sse4.1"))) +#endif +#endif +#endif + // Used on functions that have a printf-like interface to validate them statically #if defined(__GNUC__) #define LUA_PRINTF_ATTR(fmt, arg) __attribute__((format(printf, fmt, arg))) diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 71869b1..3c669bf 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -15,6 +15,16 @@ #include #endif +#ifdef LUAU_TARGET_SSE41 +#include + +#ifndef _MSC_VER +#include // on MSVC this comes from intrin.h +#endif +#endif + +LUAU_FASTFLAGVARIABLE(LuauBuiltinSSE41, false) + // luauF functions implement FASTCALL instruction that performs a direct execution of some builtin functions from the VM // The rule of thumb is that FASTCALL functions can not call user code, yield, fail, or reallocate stack. // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path @@ -95,7 +105,9 @@ static int luauF_atan(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } +// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN +LUAU_NOINLINE static int luauF_ceil(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -158,7 +170,9 @@ static int luauF_exp(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } +// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN +LUAU_NOINLINE static int luauF_floor(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -935,7 +949,9 @@ static int luauF_sign(lua_State* L, StkId res, TValue* arg0, int nresults, StkId return -1; } +// TODO: LUAU_NOINLINE can be removed with LuauBuiltinSSE41 LUAU_FASTMATH_BEGIN +LUAU_NOINLINE static int luauF_round(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) @@ -1244,6 +1260,78 @@ static int luauF_missing(lua_State* L, StkId res, TValue* arg0, int nresults, St return -1; } +#ifdef LUAU_TARGET_SSE41 +template +LUAU_TARGET_SSE41 inline double roundsd_sse41(double v) +{ + __m128d av = _mm_set_sd(v); + __m128d rv = _mm_round_sd(av, av, Rounding | _MM_FROUND_NO_EXC); + return _mm_cvtsd_f64(rv); +} + +LUAU_TARGET_SSE41 static int luauF_floor_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (!FFlag::LuauBuiltinSSE41) + return luauF_floor(L, res, arg0, nresults, args, nparams); + + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, roundsd_sse41<_MM_FROUND_TO_NEG_INF>(a1)); + return 1; + } + + return -1; +} + +LUAU_TARGET_SSE41 static int luauF_ceil_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (!FFlag::LuauBuiltinSSE41) + return luauF_ceil(L, res, arg0, nresults, args, nparams); + + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + setnvalue(res, roundsd_sse41<_MM_FROUND_TO_POS_INF>(a1)); + return 1; + } + + return -1; +} + +LUAU_TARGET_SSE41 static int luauF_round_sse41(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (!FFlag::LuauBuiltinSSE41) + return luauF_round(L, res, arg0, nresults, args, nparams); + + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + // roundsd only supports bankers rounding natively, so we need to emulate rounding by using truncation + // offset is prevfloat(0.5), which is important so that we round prevfloat(0.5) to 0. + const double offset = 0.49999999999999994; + setnvalue(res, roundsd_sse41<_MM_FROUND_TO_ZERO>(a1 + (a1 < 0 ? -offset : offset))); + return 1; + } + + return -1; +} + +static bool luau_hassse41() +{ + int cpuinfo[4] = {}; +#ifdef _MSC_VER + __cpuid(cpuinfo, 1); +#else + __cpuid(1, cpuinfo[0], cpuinfo[1], cpuinfo[2], cpuinfo[3]); +#endif + + // We requre SSE4.1 support for ROUNDSD + // https://en.wikipedia.org/wiki/CPUID#EAX=1:_Processor_Info_and_Feature_Bits + return (cpuinfo[2] & (1 << 19)) != 0; +} +#endif + const luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1253,12 +1341,24 @@ const luau_FastFunction luauF_table[256] = { luauF_asin, luauF_atan2, luauF_atan, + +#ifdef LUAU_TARGET_SSE41 + luau_hassse41() ? luauF_ceil_sse41 : luauF_ceil, +#else luauF_ceil, +#endif + luauF_cosh, luauF_cos, luauF_deg, luauF_exp, + +#ifdef LUAU_TARGET_SSE41 + luau_hassse41() ? luauF_floor_sse41 : luauF_floor, +#else luauF_floor, +#endif + luauF_fmod, luauF_frexp, luauF_ldexp, @@ -1300,7 +1400,12 @@ const luau_FastFunction luauF_table[256] = { luauF_clamp, luauF_sign, + +#ifdef LUAU_TARGET_SSE41 + luau_hassse41() ? luauF_round_sse41 : luauF_round, +#else luauF_round, +#endif luauF_rawset, luauF_rawget, diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 779ab4c..9d086f5 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCheckGetInfoIndex, false) - static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -176,18 +174,9 @@ int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) CallInfo* ci = NULL; if (level < 0) { - if (FFlag::LuauCheckGetInfoIndex) - { - const TValue* func = luaA_toobject(L, level); - api_check(L, ttisfunction(func)); - f = clvalue(func); - } - else - { - StkId func = L->top + level; - api_check(L, ttisfunction(func)); - f = clvalue(func); - } + const TValue* func = luaA_toobject(L, level); + api_check(L, ttisfunction(func)); + f = clvalue(func); } else if (unsigned(level) < unsigned(L->ci - L->base_ci)) { diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 12b834c..2d4e327 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -343,7 +343,8 @@ static float perlin(float x, float y, float z) int bb = p[b + 1] + zi; return math_lerp(w, - math_lerp(v, math_lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), math_lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), + math_lerp(v, math_lerp(u, grad(p[aa], xf, yf, zf), grad(p[ba], xf - 1, yf, zf)), + math_lerp(u, grad(p[ab], xf, yf - 1, zf), grad(p[bb], xf - 1, yf - 1, zf))), math_lerp(v, math_lerp(u, grad(p[aa + 1], xf, yf, zf - 1), grad(p[ba + 1], xf - 1, yf, zf - 1)), math_lerp(u, grad(p[ab + 1], xf, yf - 1, zf - 1), grad(p[bb + 1], xf - 1, yf - 1, zf - 1)))); } diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 0efa9ee..ddee3a7 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -10,6 +10,8 @@ #include "ldebug.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauOptimizedSort, false) + static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -305,12 +307,14 @@ static int tunpack(lua_State* L) static void set2(lua_State* L, int i, int j) { + LUAU_ASSERT(!FFlag::LuauOptimizedSort); lua_rawseti(L, 1, i); lua_rawseti(L, 1, j); } static int sort_comp(lua_State* L, int a, int b) { + LUAU_ASSERT(!FFlag::LuauOptimizedSort); if (!lua_isnil(L, 2)) { // function? int res; @@ -328,6 +332,7 @@ static int sort_comp(lua_State* L, int a, int b) static void auxsort(lua_State* L, int l, int u) { + LUAU_ASSERT(!FFlag::LuauOptimizedSort); while (l < u) { // for tail recursion int i, j; @@ -407,16 +412,145 @@ static void auxsort(lua_State* L, int l, int u) } // repeat the routine for the larger one } -static int sort(lua_State* L) +typedef int (*SortPredicate)(lua_State* L, const TValue* l, const TValue* r); + +static int sort_func(lua_State* L, const TValue* l, const TValue* r) { - luaL_checktype(L, 1, LUA_TTABLE); - int n = lua_objlen(L, 1); - luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 - if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? - luaL_checktype(L, 2, LUA_TFUNCTION); - lua_settop(L, 2); // make sure there is two arguments - auxsort(L, 1, n); - return 0; + LUAU_ASSERT(L->top == L->base + 2); // table, function + + setobj2s(L, L->top, &L->base[1]); + setobj2s(L, L->top + 1, l); + setobj2s(L, L->top + 2, r); + L->top += 3; // safe because of LUA_MINSTACK guarantee + luaD_call(L, L->top - 3, 1); + L->top -= 1; // maintain stack depth + + return !l_isfalse(L->top); +} + +inline void sort_swap(lua_State* L, Table* t, int i, int j) +{ + TValue* arr = t->array; + int n = t->sizearray; + LUAU_ASSERT(unsigned(i) < unsigned(n) && unsigned(j) < unsigned(n)); // contract maintained in sort_less after predicate call + + // no barrier required because both elements are in the array before and after the swap + TValue temp; + setobj2s(L, &temp, &arr[i]); + setobj2t(L, &arr[i], &arr[j]); + setobj2t(L, &arr[j], &temp); +} + +inline int sort_less(lua_State* L, Table* t, int i, int j, SortPredicate pred) +{ + TValue* arr = t->array; + int n = t->sizearray; + LUAU_ASSERT(unsigned(i) < unsigned(n) && unsigned(j) < unsigned(n)); // contract maintained in sort_less after predicate call + + int res = pred(L, &arr[i], &arr[j]); + + // predicate call may resize the table, which is invalid + if (t->sizearray != n) + luaL_error(L, "table modified during sorting"); + + return res; +} + +static void sort_rec(lua_State* L, Table* t, int l, int u, SortPredicate pred) +{ + // sort range [l..u] (inclusive, 0-based) + while (l < u) + { + int i, j; + // sort elements a[l], a[(l+u)/2] and a[u] + if (sort_less(L, t, u, l, pred)) // a[u] < a[l]? + sort_swap(L, t, u, l); // swap a[l] - a[u] + if (u - l == 1) + break; // only 2 elements + i = l + ((u - l) >> 1); // midpoint + if (sort_less(L, t, i, l, pred)) // a[i]= P + while (sort_less(L, t, ++i, p, pred)) + { + if (i >= u) + luaL_error(L, "invalid order function for sorting"); + } + // repeat --j until a[j] <= P + while (sort_less(L, t, p, --j, pred)) + { + if (j <= l) + luaL_error(L, "invalid order function for sorting"); + } + if (j < i) + break; + sort_swap(L, t, i, j); + } + // swap pivot (a[u-1]) with a[i], which is the new midpoint + sort_swap(L, t, u - 1, i); + // a[l..i-1] <= a[i] == P <= a[i+1..u] + // adjust so that smaller half is in [j..i] and larger one in [l..u] + if (i - l < u - i) + { + j = l; + i = i - 1; + l = i + 2; + } + else + { + j = i + 1; + i = u; + u = j - 2; + } + sort_rec(L, t, j, i, pred); // call recursively the smaller one + } // repeat the routine for the larger one +} + +static int tsort(lua_State* L) +{ + if (FFlag::LuauOptimizedSort) + { + luaL_checktype(L, 1, LUA_TTABLE); + Table* t = hvalue(L->base); + int n = luaH_getn(t); + if (t->readonly) + luaG_readonlyerror(L); + + SortPredicate pred = luaV_lessthan; + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? + { + luaL_checktype(L, 2, LUA_TFUNCTION); + pred = sort_func; + } + lua_settop(L, 2); // make sure there are two arguments + + if (n > 0) + sort_rec(L, t, 0, n - 1, pred); + return 0; + } + else + { + luaL_checktype(L, 1, LUA_TTABLE); + int n = lua_objlen(L, 1); + luaL_checkstack(L, 40, ""); // assume array is smaller than 2^40 + if (!lua_isnoneornil(L, 2)) // is there a 2nd argument? + luaL_checktype(L, 2, LUA_TFUNCTION); + lua_settop(L, 2); // make sure there is two arguments + auxsort(L, 1, n); + return 0; + } } // }====================================================== @@ -530,7 +664,7 @@ static const luaL_Reg tab_funcs[] = { {"maxn", maxn}, {"insert", tinsert}, {"remove", tremove}, - {"sort", sort}, + {"sort", tsort}, {"pack", tpack}, {"unpack", tunpack}, {"move", tmove}, diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 6f600e9..c8a184a 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -145,15 +145,16 @@ LUAU_NOINLINE void luau_callhook(lua_State* L, lua_Hook hook, void* userdata) L->base = L->ci->base; } + // note: the pc expectations of the hook are matching the general "pc points to next instruction" + // however, for the hook to be able to continue execution from the same point, this is called with savedpc at the *current* instruction + // this needs to be called before luaD_checkstack in case it fails to reallocate stack + if (L->ci->savedpc) + L->ci->savedpc++; + luaD_checkstack(L, LUA_MINSTACK); // ensure minimum stack size L->ci->top = L->top + LUA_MINSTACK; LUAU_ASSERT(L->ci->top <= L->stack_last); - // note: the pc expectations of the hook are matching the general "pc points to next instruction" - // however, for the hook to be able to continue execution from the same point, this is called with savedpc at the *current* instruction - if (L->ci->savedpc) - L->ci->savedpc++; - Closure* cl = clvalue(L->ci->func); lua_Debug ar; diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 05d3975..b77207d 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -201,15 +201,23 @@ static const TValue* get_compTM(lua_State* L, Table* mt1, Table* mt2, TMS event) return NULL; } -static int call_orderTM(lua_State* L, const TValue* p1, const TValue* p2, TMS event) +static int call_orderTM(lua_State* L, const TValue* p1, const TValue* p2, TMS event, bool error = false) { const TValue* tm1 = luaT_gettmbyobj(L, p1, event); const TValue* tm2; if (ttisnil(tm1)) + { + if (error) + luaG_ordererror(L, p1, p2, event); return -1; // no metamethod? + } tm2 = luaT_gettmbyobj(L, p2, event); if (!luaO_rawequalObj(tm1, tm2)) // different metamethods? + { + if (error) + luaG_ordererror(L, p1, p2, event); return -1; + } callTMres(L, L->top, tm1, p1, p2); return !l_isfalse(L->top); } @@ -239,16 +247,14 @@ int luaV_strcmp(const TString* ls, const TString* rs) int luaV_lessthan(lua_State* L, const TValue* l, const TValue* r) { - int res; - if (ttype(l) != ttype(r)) + if (LUAU_UNLIKELY(ttype(l) != ttype(r))) luaG_ordererror(L, l, r, TM_LT); - else if (ttisnumber(l)) + else if (LUAU_LIKELY(ttisnumber(l))) return luai_numlt(nvalue(l), nvalue(r)); else if (ttisstring(l)) return luaV_strcmp(tsvalue(l), tsvalue(r)) < 0; - else if ((res = call_orderTM(L, l, r, TM_LT)) == -1) - luaG_ordererror(L, l, r, TM_LT); - return res; + else + return call_orderTM(L, l, r, TM_LT, /* error= */ true); } int luaV_lessequal(lua_State* L, const TValue* l, const TValue* r) diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 5ca1657..c94f088 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -97,38 +97,39 @@ lua_State* createGlobalState() return L; } -int registerTypes(Luau::TypeChecker& env) +int registerTypes(Luau::TypeChecker& typeChecker, Luau::GlobalTypes& globals) { using namespace Luau; using std::nullopt; - Luau::registerBuiltinGlobals(env); + Luau::registerBuiltinGlobals(typeChecker, globals); - TypeArena& arena = env.globalTypes; + TypeArena& arena = globals.globalTypes; + BuiltinTypes& builtinTypes = *globals.builtinTypes; // Vector3 stub TypeId vector3MetaType = arena.addType(TableType{}); TypeId vector3InstanceType = arena.addType(ClassType{"Vector3", {}, nullopt, vector3MetaType, {}, {}, "Test"}); getMutable(vector3InstanceType)->props = { - {"X", {env.numberType}}, - {"Y", {env.numberType}}, - {"Z", {env.numberType}}, + {"X", {builtinTypes.numberType}}, + {"Y", {builtinTypes.numberType}}, + {"Z", {builtinTypes.numberType}}, }; getMutable(vector3MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector3InstanceType, vector3InstanceType}, {vector3InstanceType})}}, }; - env.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; + globals.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vector3InstanceType}; // Instance stub TypeId instanceType = arena.addType(ClassType{"Instance", {}, nullopt, nullopt, {}, {}, "Test"}); getMutable(instanceType)->props = { - {"Name", {env.stringType}}, + {"Name", {builtinTypes.stringType}}, }; - env.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; // Part stub TypeId partType = arena.addType(ClassType{"Part", {}, instanceType, nullopt, {}, {}, "Test"}); @@ -136,9 +137,9 @@ int registerTypes(Luau::TypeChecker& env) {"Position", {vector3InstanceType}}, }; - env.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, partType}; + globals.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, partType}; - for (const auto& [_, fun] : env.globalScope->exportedTypeBindings) + for (const auto& [_, fun] : globals.globalScope->exportedTypeBindings) persist(fun.type); return 0; @@ -146,11 +147,11 @@ int registerTypes(Luau::TypeChecker& env) static void setupFrontend(Luau::Frontend& frontend) { - registerTypes(frontend.typeChecker); - Luau::freeze(frontend.typeChecker.globalTypes); + registerTypes(frontend.typeChecker, frontend.globals); + Luau::freeze(frontend.globals.globalTypes); - registerTypes(frontend.typeCheckerForAutocomplete); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + registerTypes(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); frontend.iceHandler.onInternalError = [](const char* error) { printf("ICE: %s\n", error); @@ -264,6 +265,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) static Luau::Frontend frontend(&fileResolver, &configResolver, options); static int once = (setupFrontend(frontend), 0); + (void)once; // restart frontend.clear(); @@ -302,7 +304,7 @@ DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down // note: it's important for typeck to be destroyed at this point! - for (auto& p : frontend.typeChecker.globalScope->bindings) + for (auto& p : frontend.globals.globalScope->bindings) { Luau::ToStringOptions opts; opts.exhaustive = true; diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index a642334..25521e3 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -282,4 +282,17 @@ TEST_CASE_FIXTURE(Fixture, "Luau_selectively_query_for_a_different_boolean_2") REQUIRE(snd->value == true); } +TEST_CASE_FIXTURE(Fixture, "include_types_ancestry") +{ + check("local x: number = 4;"); + const Position pos(0, 10); + + std::vector ancestryNoTypes = findAstAncestryOfPosition(*getMainSourceModule(), pos); + std::vector ancestryTypes = findAstAncestryOfPosition(*getMainSourceModule(), pos, true); + + CHECK(ancestryTypes.size() > ancestryNoTypes.size()); + CHECK(!ancestryNoTypes.back()->asType()); + CHECK(ancestryTypes.back()->asType()); +} + TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 85bd550..aedb50a 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,8 +15,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauFixAutocompleteInWhile) -LUAU_FASTFLAG(LuauFixAutocompleteInFor) using namespace Luau; @@ -85,10 +83,11 @@ struct ACFixtureImpl : BaseType LoadDefinitionFileResult loadDefinition(const std::string& source) { - TypeChecker& typeChecker = this->frontend.typeCheckerForAutocomplete; - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, source, "@test"); - freeze(typeChecker.globalTypes); + GlobalTypes& globals = this->frontend.globalsForAutocomplete; + unfreeze(globals.globalTypes); + LoadDefinitionFileResult result = + loadDefinitionFile(this->frontend.typeChecker, globals, globals.globalScope, source, "@test", /* captureComments */ false); + freeze(globals.globalTypes); REQUIRE_MESSAGE(result.success, "loadDefinition: unable to load definition file"); return result; @@ -110,10 +109,10 @@ struct ACFixture : ACFixtureImpl ACFixture() : ACFixtureImpl() { - addGlobalBinding(frontend, "table", Binding{typeChecker.anyType}); - addGlobalBinding(frontend, "math", Binding{typeChecker.anyType}); - addGlobalBinding(frontend.typeCheckerForAutocomplete, "table", Binding{typeChecker.anyType}); - addGlobalBinding(frontend.typeCheckerForAutocomplete, "math", Binding{typeChecker.anyType}); + addGlobalBinding(frontend.globals, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globals, "math", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "table", Binding{builtinTypes->anyType}); + addGlobalBinding(frontend.globalsForAutocomplete, "math", Binding{builtinTypes->anyType}); } }; @@ -630,19 +629,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") )"); auto ac5 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInFor) - { - CHECK_EQ(ac5.entryMap.count("math"), 1); - CHECK_EQ(ac5.entryMap.count("do"), 0); - CHECK_EQ(ac5.entryMap.count("end"), 0); - CHECK_EQ(ac5.context, AutocompleteContext::Expression); - } - else - { - CHECK_EQ(ac5.entryMap.count("do"), 1); - CHECK_EQ(ac5.entryMap.count("end"), 0); - CHECK_EQ(ac5.context, AutocompleteContext::Keyword); - } + CHECK_EQ(ac5.entryMap.count("math"), 1); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("end"), 0); + CHECK_EQ(ac5.context, AutocompleteContext::Expression); check(R"( for x = 1, 2, 5 f@1 @@ -661,29 +651,26 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") CHECK_EQ(ac7.entryMap.count("end"), 1); CHECK_EQ(ac7.context, AutocompleteContext::Statement); - if (FFlag::LuauFixAutocompleteInFor) + check(R"(local Foo = 1 + for x = @11, @22, @35 + )"); + + for (int i = 0; i < 3; ++i) { - check(R"(local Foo = 1 - for x = @11, @22, @35 - )"); + auto ac8 = autocomplete('1' + i); + CHECK_EQ(ac8.entryMap.count("Foo"), 1); + CHECK_EQ(ac8.entryMap.count("do"), 0); + } - for (int i = 0; i < 3; ++i) - { - auto ac8 = autocomplete('1' + i); - CHECK_EQ(ac8.entryMap.count("Foo"), 1); - CHECK_EQ(ac8.entryMap.count("do"), 0); - } + check(R"(local Foo = 1 + for x = @11, @22 + )"); - check(R"(local Foo = 1 - for x = @11, @22 - )"); - - for (int i = 0; i < 2; ++i) - { - auto ac9 = autocomplete('1' + i); - CHECK_EQ(ac9.entryMap.count("Foo"), 1); - CHECK_EQ(ac9.entryMap.count("do"), 0); - } + for (int i = 0; i < 2; ++i) + { + auto ac9 = autocomplete('1' + i); + CHECK_EQ(ac9.entryMap.count("Foo"), 1); + CHECK_EQ(ac9.entryMap.count("do"), 0); } } @@ -776,18 +763,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac2 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInWhile) - { - CHECK_EQ(3, ac2.entryMap.size()); - CHECK_EQ(ac2.entryMap.count("do"), 1); - CHECK_EQ(ac2.entryMap.count("and"), 1); - CHECK_EQ(ac2.entryMap.count("or"), 1); - } - else - { - CHECK_EQ(1, ac2.entryMap.size()); - CHECK_EQ(ac2.entryMap.count("do"), 1); - } + CHECK_EQ(3, ac2.entryMap.size()); + CHECK_EQ(ac2.entryMap.count("do"), 1); + CHECK_EQ(ac2.entryMap.count("and"), 1); + CHECK_EQ(ac2.entryMap.count("or"), 1); CHECK_EQ(ac2.context, AutocompleteContext::Keyword); check(R"( @@ -803,31 +782,20 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") )"); auto ac4 = autocomplete('1'); - if (FFlag::LuauFixAutocompleteInWhile) - { - CHECK_EQ(3, ac4.entryMap.size()); - CHECK_EQ(ac4.entryMap.count("do"), 1); - CHECK_EQ(ac4.entryMap.count("and"), 1); - CHECK_EQ(ac4.entryMap.count("or"), 1); - } - else - { - CHECK_EQ(1, ac4.entryMap.size()); - CHECK_EQ(ac4.entryMap.count("do"), 1); - } + CHECK_EQ(3, ac4.entryMap.size()); + CHECK_EQ(ac4.entryMap.count("do"), 1); + CHECK_EQ(ac4.entryMap.count("and"), 1); + CHECK_EQ(ac4.entryMap.count("or"), 1); CHECK_EQ(ac4.context, AutocompleteContext::Keyword); - if (FFlag::LuauFixAutocompleteInWhile) - { - check(R"( - while t@1 - )"); + check(R"( + while t@1 + )"); - auto ac5 = autocomplete('1'); - CHECK_EQ(ac5.entryMap.count("do"), 0); - CHECK_EQ(ac5.entryMap.count("true"), 1); - CHECK_EQ(ac5.entryMap.count("false"), 1); - } + auto ac5 = autocomplete('1'); + CHECK_EQ(ac5.entryMap.count("do"), 0); + CHECK_EQ(ac5.entryMap.count("true"), 1); + CHECK_EQ(ac5.entryMap.count("false"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") @@ -3460,11 +3428,11 @@ TEST_CASE_FIXTURE(ACFixture, "string_contents_is_available_to_callback") declare function require(path: string): any )"); - std::optional require = frontend.typeCheckerForAutocomplete.globalScope->linearSearchForBinding("require"); + std::optional require = frontend.globalsForAutocomplete.globalScope->linearSearchForBinding("require"); REQUIRE(require); - Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); attachTag(require->typeId, "RequireCall"); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); check(R"( local x = require("testing/@1") diff --git a/tests/BuiltinDefinitions.test.cpp b/tests/BuiltinDefinitions.test.cpp index 188f219..75d054b 100644 --- a/tests/BuiltinDefinitions.test.cpp +++ b/tests/BuiltinDefinitions.test.cpp @@ -12,9 +12,9 @@ TEST_SUITE_BEGIN("BuiltinDefinitionsTest"); TEST_CASE_FIXTURE(BuiltinsFixture, "lib_documentation_symbols") { - CHECK(!typeChecker.globalScope->bindings.empty()); + CHECK(!frontend.globals.globalScope->bindings.empty()); - for (const auto& [name, binding] : typeChecker.globalScope->bindings) + for (const auto& [name, binding] : frontend.globals.globalScope->bindings) { std::string nameString(name.c_str()); std::string expectedRootSymbol = "@luau/global/" + nameString; diff --git a/tests/ClassFixture.cpp b/tests/ClassFixture.cpp index 087b88d..caf773b 100644 --- a/tests/ClassFixture.cpp +++ b/tests/ClassFixture.cpp @@ -11,8 +11,9 @@ namespace Luau ClassFixture::ClassFixture() { - TypeArena& arena = typeChecker.globalTypes; - TypeId numberType = typeChecker.numberType; + GlobalTypes& globals = frontend.globals; + TypeArena& arena = globals.globalTypes; + TypeId numberType = builtinTypes->numberType; unfreeze(arena); @@ -28,47 +29,47 @@ ClassFixture::ClassFixture() {"Clone", {makeFunction(arena, nullopt, {baseClassInstanceType}, {baseClassInstanceType})}}, {"New", {makeFunction(arena, nullopt, {}, {baseClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; - addGlobalBinding(frontend, "BaseClass", baseClassType, "@test"); + globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + addGlobalBinding(globals, "BaseClass", baseClassType, "@test"); TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { - {"Method", {makeFunction(arena, childClassInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, childClassInstanceType, {}, {builtinTypes->stringType})}}, }; TypeId childClassType = arena.addType(ClassType{"ChildClass", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(childClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {childClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - addGlobalBinding(frontend, "ChildClass", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + addGlobalBinding(globals, "ChildClass", childClassType, "@test"); TypeId grandChildInstanceType = arena.addType(ClassType{"GrandChild", {}, childClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(grandChildInstanceType)->props = { - {"Method", {makeFunction(arena, grandChildInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, grandChildInstanceType, {}, {builtinTypes->stringType})}}, }; TypeId grandChildType = arena.addType(ClassType{"GrandChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(grandChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {grandChildInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; - addGlobalBinding(frontend, "GrandChild", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["GrandChild"] = TypeFun{{}, grandChildInstanceType}; + addGlobalBinding(globals, "GrandChild", childClassType, "@test"); TypeId anotherChildInstanceType = arena.addType(ClassType{"AnotherChild", {}, baseClassInstanceType, nullopt, {}, {}, "Test"}); getMutable(anotherChildInstanceType)->props = { - {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {typeChecker.stringType})}}, + {"Method", {makeFunction(arena, anotherChildInstanceType, {}, {builtinTypes->stringType})}}, }; TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, baseClassType, nullopt, {}, {}, "Test"}); getMutable(anotherChildType)->props = { {"New", {makeFunction(arena, nullopt, {}, {anotherChildInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; - addGlobalBinding(frontend, "AnotherChild", childClassType, "@test"); + globals.globalScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildInstanceType}; + addGlobalBinding(globals, "AnotherChild", childClassType, "@test"); TypeId unrelatedClassInstanceType = arena.addType(ClassType{"UnrelatedClass", {}, nullopt, nullopt, {}, {}, "Test"}); @@ -76,8 +77,8 @@ ClassFixture::ClassFixture() getMutable(unrelatedClassType)->props = { {"New", {makeFunction(arena, nullopt, {}, {unrelatedClassInstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; - addGlobalBinding(frontend, "UnrelatedClass", unrelatedClassType, "@test"); + globals.globalScope->exportedTypeBindings["UnrelatedClass"] = TypeFun{{}, unrelatedClassInstanceType}; + addGlobalBinding(globals, "UnrelatedClass", unrelatedClassType, "@test"); TypeId vector2MetaType = arena.addType(TableType{}); @@ -94,17 +95,17 @@ ClassFixture::ClassFixture() getMutable(vector2MetaType)->props = { {"__add", {makeFunction(arena, nullopt, {vector2InstanceType, vector2InstanceType}, {vector2InstanceType})}}, }; - typeChecker.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; - addGlobalBinding(frontend, "Vector2", vector2Type, "@test"); + globals.globalScope->exportedTypeBindings["Vector2"] = TypeFun{{}, vector2InstanceType}; + addGlobalBinding(globals, "Vector2", vector2Type, "@test"); TypeId callableClassMetaType = arena.addType(TableType{}); TypeId callableClassType = arena.addType(ClassType{"CallableClass", {}, nullopt, callableClassMetaType, {}, {}, "Test"}); getMutable(callableClassMetaType)->props = { - {"__call", {makeFunction(arena, nullopt, {callableClassType, typeChecker.stringType}, {typeChecker.numberType})}}, + {"__call", {makeFunction(arena, nullopt, {callableClassType, builtinTypes->stringType}, {builtinTypes->numberType})}}, }; - typeChecker.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; + globals.globalScope->exportedTypeBindings["CallableClass"] = TypeFun{{}, callableClassType}; - for (const auto& [name, tf] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, tf] : globals.globalScope->exportedTypeBindings) persist(tf.type); freeze(arena); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 077310a..7d5f41a 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -503,14 +503,15 @@ TEST_CASE("Types") Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; Luau::BuiltinTypes builtinTypes; - Luau::TypeChecker env(&moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); + Luau::GlobalTypes globals{Luau::NotNull{&builtinTypes}}; + Luau::TypeChecker env(globals, &moduleResolver, Luau::NotNull{&builtinTypes}, &iceHandler); - Luau::registerBuiltinGlobals(env); - Luau::freeze(env.globalTypes); + Luau::registerBuiltinGlobals(env, globals); + Luau::freeze(globals.globalTypes); lua_newtable(L); - for (const auto& [name, binding] : env.globalScope->bindings) + for (const auto& [name, binding] : globals.globalScope->bindings) { populateRTTI(L, binding.typeId); lua_setfield(L, -2, toString(name).c_str()); diff --git a/tests/ConstraintGraphBuilderFixture.cpp b/tests/ConstraintGraphBuilderFixture.cpp index a9a43f0..81e5c41 100644 --- a/tests/ConstraintGraphBuilderFixture.cpp +++ b/tests/ConstraintGraphBuilderFixture.cpp @@ -22,7 +22,7 @@ void ConstraintGraphBuilderFixture::generateConstraints(const std::string& code) AstStatBlock* root = parse(code); dfg = std::make_unique(DataFlowGraphBuilder::build(root, NotNull{&ice})); cgb = std::make_unique("MainModule", mainModule, &arena, NotNull(&moduleResolver), builtinTypes, NotNull(&ice), - frontend.getGlobalScope(), &logger, NotNull{dfg.get()}); + frontend.globals.globalScope, &logger, NotNull{dfg.get()}); cgb->visit(root); rootScope = cgb->rootScope; constraints = Luau::borrowConstraints(cgb->constraints); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index cbceabb..a9c94ee 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -138,17 +138,16 @@ Fixture::Fixture(bool freeze, bool prepareAutocomplete) : sff_DebugLuauFreezeArena("DebugLuauFreezeArena", freeze) , frontend(&fileResolver, &configResolver, {/* retainFullTypeGraphs= */ true, /* forAutocomplete */ false, /* randomConstraintResolutionSeed */ randomSeed}) - , typeChecker(frontend.typeChecker) , builtinTypes(frontend.builtinTypes) { configResolver.defaultConfig.mode = Mode::Strict; configResolver.defaultConfig.enabledLint.warningMask = ~0ull; configResolver.defaultConfig.parseOptions.captureComments = true; - registerBuiltinTypes(frontend); + registerBuiltinTypes(frontend.globals); - Luau::freeze(frontend.typeChecker.globalTypes); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globals.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); Luau::setPrintLine([](auto s) {}); } @@ -178,11 +177,11 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars if (FFlag::DebugLuauDeferredConstraintResolution) { - Luau::check(*sourceModule, {}, frontend.builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, - typeChecker.globalScope, frontend.options); + Luau::check(*sourceModule, {}, builtinTypes, NotNull{&ice}, NotNull{&moduleResolver}, NotNull{&fileResolver}, + frontend.globals.globalScope, frontend.options); } else - typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); + frontend.typeChecker.check(*sourceModule, sourceModule->mode.value_or(Luau::Mode::Nonstrict)); } throw ParseErrors(result.errors); @@ -447,9 +446,9 @@ void Fixture::dumpErrors(std::ostream& os, const std::vector& errors) void Fixture::registerTestTypes() { - addGlobalBinding(frontend, "game", typeChecker.anyType, "@luau"); - addGlobalBinding(frontend, "workspace", typeChecker.anyType, "@luau"); - addGlobalBinding(frontend, "script", typeChecker.anyType, "@luau"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@luau"); + addGlobalBinding(frontend.globals, "workspace", builtinTypes->anyType, "@luau"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@luau"); } void Fixture::dumpErrors(const CheckResult& cr) @@ -499,9 +498,9 @@ void Fixture::validateErrors(const std::vector& errors) LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test"); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = frontend.loadDefinitionFile(source, "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); if (result.module) dumpErrors(result.module); @@ -512,16 +511,16 @@ LoadDefinitionFileResult Fixture::loadDefinition(const std::string& source) BuiltinsFixture::BuiltinsFixture(bool freeze, bool prepareAutocomplete) : Fixture(freeze, prepareAutocomplete) { - Luau::unfreeze(frontend.typeChecker.globalTypes); - Luau::unfreeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::unfreeze(frontend.globals.globalTypes); + Luau::unfreeze(frontend.globalsForAutocomplete.globalTypes); registerBuiltinGlobals(frontend); if (prepareAutocomplete) - registerBuiltinGlobals(frontend.typeCheckerForAutocomplete); + registerBuiltinGlobals(frontend.typeCheckerForAutocomplete, frontend.globalsForAutocomplete); registerTestTypes(); - Luau::freeze(frontend.typeChecker.globalTypes); - Luau::freeze(frontend.typeCheckerForAutocomplete.globalTypes); + Luau::freeze(frontend.globals.globalTypes); + Luau::freeze(frontend.globalsForAutocomplete.globalTypes); } ModuleName fromString(std::string_view name) @@ -581,23 +580,31 @@ std::optional linearSearchForBinding(Scope* scope, const char* name) void registerHiddenTypes(Frontend* frontend) { - TypeId t = frontend->globalTypes.addType(GenericType{"T"}); + GlobalTypes& globals = frontend->globals; + + unfreeze(globals.globalTypes); + + TypeId t = globals.globalTypes.addType(GenericType{"T"}); GenericTypeDefinition genericT{t}; - ScopePtr globalScope = frontend->getGlobalScope(); - globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, frontend->globalTypes.addType(NegationType{t})}; + ScopePtr globalScope = globals.globalScope; + globalScope->exportedTypeBindings["Not"] = TypeFun{{genericT}, globals.globalTypes.addType(NegationType{t})}; globalScope->exportedTypeBindings["fun"] = TypeFun{{}, frontend->builtinTypes->functionType}; globalScope->exportedTypeBindings["cls"] = TypeFun{{}, frontend->builtinTypes->classType}; globalScope->exportedTypeBindings["err"] = TypeFun{{}, frontend->builtinTypes->errorType}; globalScope->exportedTypeBindings["tbl"] = TypeFun{{}, frontend->builtinTypes->tableType}; + + freeze(globals.globalTypes); } void createSomeClasses(Frontend* frontend) { - TypeArena& arena = frontend->globalTypes; + GlobalTypes& globals = frontend->globals; + + TypeArena& arena = globals.globalTypes; unfreeze(arena); - ScopePtr moduleScope = frontend->getGlobalScope(); + ScopePtr moduleScope = globals.globalScope; TypeId parentType = arena.addType(ClassType{"Parent", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); @@ -606,22 +613,22 @@ void createSomeClasses(Frontend* frontend) parentClass->props["virtual_method"] = {makeFunction(arena, parentType, {}, {})}; - addGlobalBinding(*frontend, "Parent", {parentType}); + addGlobalBinding(globals, "Parent", {parentType}); moduleScope->exportedTypeBindings["Parent"] = TypeFun{{}, parentType}; TypeId childType = arena.addType(ClassType{"Child", {}, parentType, std::nullopt, {}, nullptr, "Test"}); - addGlobalBinding(*frontend, "Child", {childType}); + addGlobalBinding(globals, "Child", {childType}); moduleScope->exportedTypeBindings["Child"] = TypeFun{{}, childType}; TypeId anotherChildType = arena.addType(ClassType{"AnotherChild", {}, parentType, std::nullopt, {}, nullptr, "Test"}); - addGlobalBinding(*frontend, "AnotherChild", {anotherChildType}); + addGlobalBinding(globals, "AnotherChild", {anotherChildType}); moduleScope->exportedTypeBindings["AnotherChild"] = TypeFun{{}, anotherChildType}; TypeId unrelatedType = arena.addType(ClassType{"Unrelated", {}, frontend->builtinTypes->classType, std::nullopt, {}, nullptr, "Test"}); - addGlobalBinding(*frontend, "Unrelated", {unrelatedType}); + addGlobalBinding(globals, "Unrelated", {unrelatedType}); moduleScope->exportedTypeBindings["Unrelated"] = TypeFun{{}, unrelatedType}; for (const auto& [name, ty] : moduleScope->exportedTypeBindings) diff --git a/tests/Fixture.h b/tests/Fixture.h index a81a5e7..5db6ed1 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -101,7 +101,6 @@ struct Fixture std::unique_ptr sourceModule; Frontend frontend; InternalErrorReporter ice; - TypeChecker& typeChecker; NotNull builtinTypes; std::string decorateWithTypes(const std::string& code); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 1d31b28..e09990f 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -81,8 +81,8 @@ struct FrontendFixture : BuiltinsFixture { FrontendFixture() { - addGlobalBinding(frontend, "game", frontend.typeChecker.anyType, "@test"); - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); } }; @@ -852,12 +852,12 @@ TEST_CASE_FIXTURE(FrontendFixture, "environments") { ScopePtr testScope = frontend.addEnvironment("test"); - unfreeze(typeChecker.globalTypes); - loadDefinitionFile(typeChecker, testScope, R"( + unfreeze(frontend.globals.globalTypes); + loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( export type Foo = number | string )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); fileResolver.source["A"] = R"( --!nonstrict diff --git a/tests/IrBuilder.test.cpp b/tests/IrBuilder.test.cpp index 0896517..41146d7 100644 --- a/tests/IrBuilder.test.cpp +++ b/tests/IrBuilder.test.cpp @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptCheckTag") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into CHECK_TAG - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: CHECK_TAG R2, tnil, bb_fallback_1 CHECK_TAG K5, tnil, bb_fallback_1 @@ -135,7 +135,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptBinaryArith") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into second argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_DOUBLE R1 %2 = ADD_NUM %0, R2 @@ -165,7 +165,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag1") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %1 = LOAD_TAG R2 JUMP_EQ_TAG R1, %1, bb_1, bb_2 @@ -202,7 +202,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag2") // Load from memory is 'inlined' into second argument is it can't be done for the first one // We also swap first and second argument to generate memory access on the LHS - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 STORE_TAG R6, %0 @@ -239,7 +239,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptEqTag3") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_POINTER R1 %1 = GET_ARR_ADDR %0, 0i @@ -276,7 +276,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FinalX64OptJumpCmpNum") optimizeMemoryOperandsX64(build.function); // Load from memory is 'inlined' into first argument - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %1 = LOAD_DOUBLE R2 JUMP_CMP_NUM R1, %1, bb_1, bb_2 @@ -328,7 +328,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Numeric") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 30i STORE_INT R0, -2147483648i @@ -374,7 +374,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ControlFlowEq") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: JUMP bb_1 @@ -423,7 +423,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumToIndex") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 4i LOP_RETURN 0u @@ -458,7 +458,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "Guards") updateUseCounts(build.function); constantFold(); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: LOP_RETURN 0u @@ -579,7 +579,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTagsAndValues") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber STORE_INT R1, 10i @@ -625,7 +625,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "PropagateThroughTvalue") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber STORE_DOUBLE R0, 0.5 @@ -655,7 +655,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipCheckTag") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber LOP_RETURN 0u @@ -682,7 +682,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipOncePerBlockChecks") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: CHECK_SAFE_ENV CHECK_GC @@ -721,7 +721,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RememberTableState") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_POINTER R0 CHECK_NO_METATABLE %0, bb_fallback_1 @@ -753,7 +753,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SkipUselessBarriers") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber LOP_RETURN 0u @@ -782,7 +782,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "ConcatInvalidation") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber STORE_INT R1, 10i @@ -829,7 +829,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "BuiltinFastcallsMayInvalidateMemory") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_DOUBLE R0, 0.5 %1 = LOAD_POINTER R0 @@ -862,32 +862,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( -bb_0: - STORE_INT R0, 10i - STORE_DOUBLE R0, 0.5 - STORE_INT R0, 10i - LOP_RETURN 0u - -)"); -} - -TEST_CASE_FIXTURE(IrBuilderFixture, "RedundantStoreCheckConstantType") -{ - IrOp block = build.block(IrBlockKind::Internal); - - build.beginBlock(block); - - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - build.inst(IrCmd::STORE_DOUBLE, build.vmReg(0), build.constDouble(0.5)); - build.inst(IrCmd::STORE_INT, build.vmReg(0), build.constInt(10)); - - build.inst(IrCmd::LOP_RETURN, build.constUint(0)); - - updateUseCounts(build.function); - constPropInBlockChains(build); - - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R0, 10i STORE_DOUBLE R0, 0.5 @@ -917,7 +892,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagation") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -949,7 +924,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagCheckPropagationConflicting") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R0 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -985,7 +960,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TruthyTestRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 @@ -1024,7 +999,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "FalsyTestRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tnumber, bb_fallback_3 @@ -1059,7 +1034,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "TagEqRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R1 CHECK_TAG %0, tboolean @@ -1091,7 +1066,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "IntEqRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_INT R1, 5i JUMP bb_1 @@ -1122,7 +1097,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NumCmpRemoval") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_DOUBLE R1, 4 JUMP bb_2 @@ -1150,7 +1125,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataFlowsThroughDirectJumpToUniqueSuccessor updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1183,7 +1158,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "DataDoesNotFlowThroughDirectJumpToNonUnique updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1199,6 +1174,120 @@ bb_2: )"); } +TEST_CASE_FIXTURE(IrBuilderFixture, "EntryBlockUseRemoval") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, entry); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + STORE_TAG R0, tnumber + JUMP bb_1 + +bb_1: + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval1") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp block = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(block); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit, repeat); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, block); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + LOP_RETURN 0u, R0, 0i + +bb_1: + STORE_TAG R0, tnumber + JUMP bb_2 + +bb_2: + LOP_RETURN 0u, R0, 0i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "RecursiveSccUseRemoval2") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit1 = build.block(IrBlockKind::Internal); + IrOp block = build.block(IrBlockKind::Internal); + IrOp exit2 = build.block(IrBlockKind::Internal); + IrOp repeat = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::JUMP_EQ_INT, build.constInt(0), build.constInt(1), block, exit1); + + build.beginBlock(exit1); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(block); + build.inst(IrCmd::STORE_TAG, build.vmReg(0), build.constTag(tnumber)); + build.inst(IrCmd::JUMP_IF_TRUTHY, build.vmReg(0), exit2, repeat); + + build.beginBlock(exit2); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(0)); + + build.beginBlock(repeat); + build.inst(IrCmd::INTERRUPT, build.constUint(0)); + build.inst(IrCmd::JUMP, block); + + updateUseCounts(build.function); + constPropInBlockChains(build); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: + JUMP bb_1 + +bb_1: + LOP_RETURN 0u, R0, 0i + +bb_2: + STORE_TAG R0, tnumber + JUMP bb_3 + +bb_3: + LOP_RETURN 0u, R0, 0i + +)"); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("LinearExecutionFlowExtraction"); @@ -1240,7 +1329,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "SimplePathExtraction") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R2 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -1315,7 +1404,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "NoPathExtractionForBlocksWithLiveOutValues" updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: %0 = LOAD_TAG R2 CHECK_TAG %0, tnumber, bb_fallback_1 @@ -1366,7 +1455,7 @@ TEST_CASE_FIXTURE(IrBuilderFixture, "InfiniteLoopInPathAnalysis") updateUseCounts(build.function); constPropInBlockChains(build); - CHECK("\n" + toString(build.function, /* includeDetails */ false) == R"( + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( bb_0: STORE_TAG R0, tnumber JUMP bb_1 @@ -1379,3 +1468,212 @@ bb_1: } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("Analysis"); + +TEST_CASE_FIXTURE(IrBuilderFixture, "SimpleDiamond") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp a = build.block(IrBlockKind::Internal); + IrOp b = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::JUMP_EQ_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), a, b); + + build.beginBlock(a); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(2), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(b); + build.inst(IrCmd::STORE_TVALUE, build.vmReg(3), build.inst(IrCmd::LOAD_TVALUE, build.vmReg(1))); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(2), build.constInt(2)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1, bb_2 +; in regs: R0, R1, R2, R3 +; out regs: R1, R2, R3 + %0 = LOAD_TAG R0 + JUMP_EQ_TAG %0, tnumber, bb_1, bb_2 + +bb_1: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R3 +; out regs: R2, R3 + %2 = LOAD_TVALUE R1 + STORE_TVALUE R2, %2 + JUMP bb_3 + +bb_2: +; predecessors: bb_0 +; successors: bb_3 +; in regs: R1, R2 +; out regs: R2, R3 + %5 = LOAD_TVALUE R1 + STORE_TVALUE R3, %5 + JUMP bb_3 + +bb_3: +; predecessors: bb_1, bb_2 +; in regs: R2, R3 + LOP_RETURN 0u, R2, 2i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ImplicitFixedRegistersInVarargCall") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(3), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(5)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(5)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1, R2 +; out regs: R0, R1, R2, R3, R4 + FALLBACK_GETVARARGS 0u, R3, -1i + LOP_CALL 0u, R0, -1i, 5i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0, R1, R2, R3, R4 + LOP_RETURN 0u, R0, 5i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "ExplicitUseOfRegisterInVarargSequence") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + build.inst(IrCmd::INVOKE_FASTCALL, build.constUint(0), build.vmReg(0), build.vmReg(1), build.vmReg(2), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; out regs: R0... + FALLBACK_GETVARARGS 0u, R1, -1i + %1 = INVOKE_FASTCALL 0u, R0, R1, R2, -1i, -1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0... + LOP_RETURN 0u, R0, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "VariadicSequenceRestart") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(1), build.constInt(0), build.constInt(-1)); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_1 +; in regs: R0, R1 +; out regs: R0... + LOP_CALL 0u, R1, 0i, -1i + LOP_CALL 0u, R0, -1i, -1i + JUMP bb_1 + +bb_1: +; predecessors: bb_0 +; in regs: R0... + LOP_RETURN 0u, R0, -1i + +)"); +} + +TEST_CASE_FIXTURE(IrBuilderFixture, "FallbackDoesNotFlowUp") +{ + IrOp entry = build.block(IrBlockKind::Internal); + IrOp fallback = build.block(IrBlockKind::Fallback); + IrOp exit = build.block(IrBlockKind::Internal); + + build.beginBlock(entry); + build.inst(IrCmd::FALLBACK_GETVARARGS, build.constUint(0), build.vmReg(1), build.constInt(-1)); + build.inst(IrCmd::CHECK_TAG, build.inst(IrCmd::LOAD_TAG, build.vmReg(0)), build.constTag(tnumber), fallback); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(fallback); + build.inst(IrCmd::LOP_CALL, build.constUint(0), build.vmReg(0), build.constInt(-1), build.constInt(-1)); + build.inst(IrCmd::JUMP, exit); + + build.beginBlock(exit); + build.inst(IrCmd::LOP_RETURN, build.constUint(0), build.vmReg(0), build.constInt(-1)); + + updateUseCounts(build.function); + computeCfgInfo(build.function); + + CHECK("\n" + toString(build.function, /* includeUseInfo */ false) == R"( +bb_0: +; successors: bb_fallback_1, bb_2 +; in regs: R0 +; out regs: R0... + FALLBACK_GETVARARGS 0u, R1, -1i + %1 = LOAD_TAG R0 + CHECK_TAG %1, tnumber, bb_fallback_1 + LOP_CALL 0u, R0, -1i, -1i + JUMP bb_2 + +bb_fallback_1: +; predecessors: bb_0 +; successors: bb_2 +; in regs: R0, R1... +; out regs: R0... + LOP_CALL 0u, R0, -1i, -1i + JUMP bb_2 + +bb_2: +; predecessors: bb_0, bb_fallback_1 +; in regs: R0... + LOP_RETURN 0u, R0, -1i + +)"); +} + +TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c716982..ebd004d 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -35,7 +35,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobal") TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobal") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(frontend, "Wait", Binding{typeChecker.anyType, {}, true, "wait", "@test/global/Wait"}); + addGlobalBinding(frontend.globals, "Wait", Binding{builtinTypes->anyType, {}, true, "wait", "@test/global/Wait"}); LintResult result = lint("Wait(5)"); @@ -47,7 +47,7 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedGlobalNoReplacement") { // Normally this would be defined externally, so hack it in for testing const char* deprecationReplacementString = ""; - addGlobalBinding(frontend, "Version", Binding{typeChecker.anyType, {}, true, deprecationReplacementString}); + addGlobalBinding(frontend.globals, "Version", Binding{builtinTypes->anyType, {}, true, deprecationReplacementString}); LintResult result = lint("Version()"); @@ -373,7 +373,7 @@ return bar() TEST_CASE_FIXTURE(Fixture, "ImportUnused") { // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(frontend, "game", typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "game", builtinTypes->anyType, "@test"); LintResult result = lint(R"( local Roact = require(game.Packages.Roact) @@ -604,16 +604,16 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - unfreeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); TableType::Props instanceProps{ - {"ClassName", {typeChecker.anyType}}, + {"ClassName", {builtinTypes->anyType}}, }; - TableType instanceTable{instanceProps, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}; - TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); + TableType instanceTable{instanceProps, std::nullopt, frontend.globals.globalScope->level, Luau::TableState::Sealed}; + TypeId instanceType = frontend.globals.globalTypes.addType(instanceTable); TypeFun instanceTypeFun{{}, instanceType}; - typeChecker.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; + frontend.globals.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; LintResult result = lint(R"( local game = ... @@ -1270,12 +1270,12 @@ TEST_CASE_FIXTURE(Fixture, "no_spurious_warning_after_a_function_type_alias") TEST_CASE_FIXTURE(Fixture, "use_all_parent_scopes_for_globals") { ScopePtr testScope = frontend.addEnvironment("Test"); - unfreeze(typeChecker.globalTypes); - loadDefinitionFile(frontend.typeChecker, testScope, R"( + unfreeze(frontend.globals.globalTypes); + loadDefinitionFile(frontend.typeChecker, frontend.globals, testScope, R"( declare Foo: number )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); fileResolver.environments["A"] = "Test"; @@ -1444,31 +1444,32 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiTyped") { ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - unfreeze(typeChecker.globalTypes); - TypeId instanceType = typeChecker.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); + unfreeze(frontend.globals.globalTypes); + TypeId instanceType = frontend.globals.globalTypes.addType(ClassType{"Instance", {}, std::nullopt, std::nullopt, {}, {}, "Test"}); persist(instanceType); - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; + frontend.globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, instanceType}; getMutable(instanceType)->props = { - {"Name", {typeChecker.stringType}}, - {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, - {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, + {"Name", {builtinTypes->stringType}}, + {"DataCost", {builtinTypes->numberType, /* deprecated= */ true}}, + {"Wait", {builtinTypes->anyType, /* deprecated= */ true}}, }; - TypeId colorType = typeChecker.globalTypes.addType(TableType{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); + TypeId colorType = + frontend.globals.globalTypes.addType(TableType{{}, std::nullopt, frontend.globals.globalScope->level, Luau::TableState::Sealed}); - getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; + getMutable(colorType)->props = {{"toHSV", {builtinTypes->anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; - addGlobalBinding(frontend, "Color3", Binding{colorType, {}}); + addGlobalBinding(frontend.globals, "Color3", Binding{colorType, {}}); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; ttv->props["getn"].deprecated = true; ttv->props["getn"].deprecatedSuggestion = "#"; } - freeze(typeChecker.globalTypes); + freeze(frontend.globals.globalTypes); LintResult result = lintTyped(R"( return function (i: Instance) @@ -1495,7 +1496,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "DeprecatedApiUntyped") { ScopedFastFlag sff("LuauImproveDeprecatedApiLint", true); - if (TableType* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + if (TableType* ttv = getMutable(getGlobalBinding(frontend.globals, "table"))) { ttv->props["foreach"].deprecated = true; ttv->props["getn"].deprecated = true; diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2c45cc3..d2796b6 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -48,8 +48,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, cloneState); - CHECK_EQ(newNumber, typeChecker.numberType); + TypeId newNumber = clone(builtinTypes->numberType, dest, cloneState); + CHECK_EQ(newNumber, builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") @@ -58,9 +58,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") CloneState cloneState; // Create a new number type that isn't persistent - unfreeze(typeChecker.globalTypes); - TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveType{PrimitiveType::Number}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldNumber = frontend.globals.globalTypes.addType(PrimitiveType{PrimitiveType::Number}); + freeze(frontend.globals.globalTypes); TypeId newNumber = clone(oldNumber, dest, cloneState); CHECK_NE(newNumber, oldNumber); @@ -170,10 +170,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_point_into_globalTypes_arena") REQUIRE(signType != nullptr); CHECK(!isInArena(signType, module->interfaceTypes)); - if (FFlag::DebugLuauDeferredConstraintResolution) - CHECK(isInArena(signType, frontend.globalTypes)); - else - CHECK(isInArena(signType, typeChecker.globalTypes)); + CHECK(isInArena(signType, frontend.globals.globalTypes)); } TEST_CASE_FIXTURE(Fixture, "deepClone_union") @@ -181,9 +178,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") TypeArena dest; CloneState cloneState; - unfreeze(typeChecker.globalTypes); - TypeId oldUnion = typeChecker.globalTypes.addType(UnionType{{typeChecker.numberType, typeChecker.stringType}}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldUnion = frontend.globals.globalTypes.addType(UnionType{{builtinTypes->numberType, builtinTypes->stringType}}); + freeze(frontend.globals.globalTypes); TypeId newUnion = clone(oldUnion, dest, cloneState); CHECK_NE(newUnion, oldUnion); @@ -196,9 +193,9 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") TypeArena dest; CloneState cloneState; - unfreeze(typeChecker.globalTypes); - TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionType{{typeChecker.numberType, typeChecker.stringType}}); - freeze(typeChecker.globalTypes); + unfreeze(frontend.globals.globalTypes); + TypeId oldIntersection = frontend.globals.globalTypes.addType(IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}); + freeze(frontend.globals.globalTypes); TypeId newIntersection = clone(oldIntersection, dest, cloneState); CHECK_NE(newIntersection, oldIntersection); @@ -210,13 +207,13 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") { Type exampleMetaClass{ClassType{"ExampleClassMeta", { - {"__add", {typeChecker.anyType}}, + {"__add", {builtinTypes->anyType}}, }, std::nullopt, std::nullopt, {}, {}, "Test"}}; Type exampleClass{ClassType{"ExampleClass", { - {"PropOne", {typeChecker.numberType}}, - {"PropTwo", {typeChecker.stringType}}, + {"PropOne", {builtinTypes->numberType}}, + {"PropTwo", {builtinTypes->stringType}}, }, std::nullopt, &exampleMetaClass, {}, {}, "Test"}}; diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index a84e263..fddab80 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -64,7 +64,7 @@ TEST_CASE_FIXTURE(Fixture, "return_annotation_is_still_checked") LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE_NE(*typeChecker.anyType, *requireType("foo")); + REQUIRE_NE(*builtinTypes->anyType, *requireType("foo")); } #endif @@ -107,7 +107,7 @@ TEST_CASE_FIXTURE(Fixture, "locals_are_any_by_default") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("m")); + CHECK_EQ(*builtinTypes->anyType, *requireType("m")); } TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") @@ -173,7 +173,7 @@ TEST_CASE_FIXTURE(Fixture, "table_props_are_any") TypeId fooProp = ttv->props["foo"].type; REQUIRE(fooProp != nullptr); - CHECK_EQ(*fooProp, *typeChecker.anyType); + CHECK_EQ(*fooProp, *builtinTypes->anyType); } TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") @@ -192,8 +192,8 @@ TEST_CASE_FIXTURE(Fixture, "inline_table_props_are_also_any") TableType* ttv = getMutable(requireType("T")); REQUIRE_MESSAGE(ttv, "Should be a table: " << toString(requireType("T"))); - CHECK_EQ(*typeChecker.anyType, *ttv->props["one"].type); - CHECK_EQ(*typeChecker.anyType, *ttv->props["two"].type); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["one"].type); + CHECK_EQ(*builtinTypes->anyType, *ttv->props["two"].type); CHECK_MESSAGE(get(follow(ttv->props["three"].type)), "Should be a function: " << *ttv->props["three"].type); } diff --git a/tests/Normalize.test.cpp b/tests/Normalize.test.cpp index c45932c..b86af0e 100644 --- a/tests/Normalize.test.cpp +++ b/tests/Normalize.test.cpp @@ -325,9 +325,9 @@ TEST_CASE_FIXTURE(IsSubtypeFixture, "classes") check(""); // Ensure that we have a main Module. - TypeId p = typeChecker.globalScope->lookupType("Parent")->type; - TypeId c = typeChecker.globalScope->lookupType("Child")->type; - TypeId u = typeChecker.globalScope->lookupType("Unrelated")->type; + TypeId p = frontend.globals.globalScope->lookupType("Parent")->type; + TypeId c = frontend.globals.globalScope->lookupType("Child")->type; + TypeId u = frontend.globals.globalScope->lookupType("Unrelated")->type; CHECK(isSubtype(c, p)); CHECK(!isSubtype(p, c)); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 11dca11..73ae477 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -15,7 +15,7 @@ struct ToDotClassFixture : Fixture { ToDotClassFixture() { - TypeArena& arena = typeChecker.globalTypes; + TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); @@ -23,17 +23,17 @@ struct ToDotClassFixture : Fixture TypeId baseClassInstanceType = arena.addType(ClassType{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}, "Test"}); getMutable(baseClassInstanceType)->props = { - {"BaseField", {typeChecker.numberType}}, + {"BaseField", {builtinTypes->numberType}}, }; - typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + frontend.globals.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; TypeId childClassInstanceType = arena.addType(ClassType{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}, "Test"}); getMutable(childClassInstanceType)->props = { - {"ChildField", {typeChecker.stringType}}, + {"ChildField", {builtinTypes->stringType}}, }; - typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + frontend.globals.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; - for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, ty] : frontend.globals.globalScope->exportedTypeBindings) persist(ty.type); freeze(arena); @@ -373,7 +373,7 @@ n1 [label="GenericTypePack T"]; TEST_CASE_FIXTURE(Fixture, "bound_pack") { - TypePackVar pack{TypePackVariant{TypePack{{typeChecker.numberType}, {}}}}; + TypePackVar pack{TypePackVariant{TypePack{{builtinTypes->numberType}, {}}}}; TypePackVar bound{TypePackVariant{BoundTypePack{&pack}}}; ToDotOptions opts; diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 2fc5187..fd24539 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -184,27 +184,27 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "exhaustive_toString_of_cyclic_table") TEST_CASE_FIXTURE(Fixture, "intersection_parenthesized_only_if_needed") { - auto utv = Type{UnionType{{typeChecker.numberType, typeChecker.stringType}}}; - auto itv = Type{IntersectionType{{&utv, typeChecker.booleanType}}}; + auto utv = Type{UnionType{{builtinTypes->numberType, builtinTypes->stringType}}}; + auto itv = Type{IntersectionType{{&utv, builtinTypes->booleanType}}}; CHECK_EQ(toString(&itv), "(number | string) & boolean"); } TEST_CASE_FIXTURE(Fixture, "union_parenthesized_only_if_needed") { - auto itv = Type{IntersectionType{{typeChecker.numberType, typeChecker.stringType}}}; - auto utv = Type{UnionType{{&itv, typeChecker.booleanType}}}; + auto itv = Type{IntersectionType{{builtinTypes->numberType, builtinTypes->stringType}}}; + auto utv = Type{UnionType{{&itv, builtinTypes->booleanType}}}; CHECK_EQ(toString(&utv), "(number & string) | boolean"); } TEST_CASE_FIXTURE(Fixture, "functions_are_always_parenthesized_in_unions_or_intersections") { - auto stringAndNumberPack = TypePackVar{TypePack{{typeChecker.stringType, typeChecker.numberType}}}; - auto numberAndStringPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.stringType}}}; + auto stringAndNumberPack = TypePackVar{TypePack{{builtinTypes->stringType, builtinTypes->numberType}}}; + auto numberAndStringPack = TypePackVar{TypePack{{builtinTypes->numberType, builtinTypes->stringType}}}; auto sn2ns = Type{FunctionType{&stringAndNumberPack, &numberAndStringPack}}; - auto ns2sn = Type{FunctionType(typeChecker.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; + auto ns2sn = Type{FunctionType(frontend.globals.globalScope->level, &numberAndStringPack, &stringAndNumberPack)}; auto utv = Type{UnionType{{&ns2sn, &sn2ns}}}; auto itv = Type{IntersectionType{{&ns2sn, &sn2ns}}}; @@ -250,7 +250,7 @@ TEST_CASE_FIXTURE(Fixture, "quit_stringifying_table_type_when_length_is_exceeded { TableType ttv{}; for (char c : std::string("abcdefghijklmno")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; @@ -264,7 +264,7 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_is_still_capped_when_exhaust { TableType ttv{}; for (char c : std::string("abcdefg")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; @@ -339,7 +339,7 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table { TableType ttv{TableState::Sealed, TypeLevel{}}; for (char c : std::string("abcdefghij")) - ttv.props[std::string(1, c)] = {typeChecker.numberType}; + ttv.props[std::string(1, c)] = {builtinTypes->numberType}; Type tv{ttv}; @@ -350,7 +350,7 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_union_type_bails_early") { - Type tv{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; + Type tv{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; UnionType* utv = getMutable(&tv); utv->options.push_back(&tv); utv->options.push_back(&tv); @@ -371,11 +371,11 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_cyclic_intersection_type_bails_early") TEST_CASE_FIXTURE(Fixture, "stringifying_array_uses_array_syntax") { TableType ttv{TableState::Sealed, TypeLevel{}}; - ttv.indexer = TableIndexer{typeChecker.numberType, typeChecker.stringType}; + ttv.indexer = TableIndexer{builtinTypes->numberType, builtinTypes->stringType}; CHECK_EQ("{string}", toString(Type{ttv})); - ttv.props["A"] = {typeChecker.numberType}; + ttv.props["A"] = {builtinTypes->numberType}; CHECK_EQ("{| [number]: string, A: number |}", toString(Type{ttv})); ttv.props.clear(); @@ -562,15 +562,15 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T Type tv1{TableType{}}; TableType* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; - ttv->props["hello"] = {typeChecker.numberType}; - ttv->props["world"] = {typeChecker.numberType}; + ttv->props["hello"] = {builtinTypes->numberType}; + ttv->props["world"] = {builtinTypes->numberType}; TypePackVar tpv1{TypePack{{&tv1}}}; Type tv2{TableType{}}; TableType* bttv = getMutable(&tv2); bttv->state = TableState::Free; - bttv->props["hello"] = {typeChecker.numberType}; + bttv->props["hello"] = {builtinTypes->numberType}; bttv->boundTo = &tv1; TypePackVar tpv2{TypePack{{&tv2}}}; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index a2fc0c7..b55c774 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -168,7 +168,7 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_whe TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "mutually_recursive_aliases") @@ -329,7 +329,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") @@ -345,7 +345,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_typ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->givenType); } // Check that recursive intersection type doesn't generate an OOM @@ -520,7 +520,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "type_alias_import_mutation") CheckResult result = check("type t10 = typeof(table)"); LUAU_REQUIRE_NO_ERRORS(result); - TypeId ty = getGlobalBinding(frontend, "table"); + TypeId ty = getGlobalBinding(frontend.globals, "table"); CHECK(toString(ty) == "typeof(table)"); @@ -922,4 +922,29 @@ TEST_CASE_FIXTURE(Fixture, "cannot_create_cyclic_type_with_unknown_module") CHECK(toString(result.errors[0]) == "Unknown type 'B.AAA'"); } +TEST_CASE_FIXTURE(Fixture, "type_alias_locations") +{ + check(R"( + type T = number + + do + type T = string + type X = boolean + end + )"); + + ModulePtr mod = getMainModule(); + REQUIRE(mod); + REQUIRE(mod->scopes.size() == 8); + + REQUIRE(mod->scopes[0].second->typeAliasNameLocations.count("T") > 0); + CHECK(mod->scopes[0].second->typeAliasNameLocations["T"] == Location(Position(1, 13), 1)); + + REQUIRE(mod->scopes[3].second->typeAliasNameLocations.count("T") > 0); + CHECK(mod->scopes[3].second->typeAliasNameLocations["T"] == Location(Position(4, 17), 1)); + + REQUIRE(mod->scopes[3].second->typeAliasNameLocations.count("X") > 0); + CHECK(mod->scopes[3].second->typeAliasNameLocations["X"] == Location(Position(5, 17), 1)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index d5f9537..2c87cb4 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -86,7 +86,7 @@ TEST_CASE_FIXTURE(Fixture, "function_return_annotations_are_checked") REQUIRE_EQ(1, tp->head.size()); - REQUIRE_EQ(typeChecker.anyType, follow(tp->head[0])); + REQUIRE_EQ(builtinTypes->anyType, follow(tp->head[0])); } TEST_CASE_FIXTURE(Fixture, "function_return_multret_annotations_are_checked") @@ -166,11 +166,12 @@ TEST_CASE_FIXTURE(Fixture, "infer_type_of_value_a_via_typeof_with_assignment") a = "foo" )"); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{4, 12}, Position{4, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ( + result.errors[0], (TypeError{Location{Position{4, 12}, Position{4, 17}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); } TEST_CASE_FIXTURE(Fixture, "table_annotation") @@ -459,7 +460,7 @@ TEST_CASE_FIXTURE(Fixture, "type_alias_always_resolve_to_a_real_type") )"); TypeId fType = requireType("aa"); - REQUIRE(follow(fType) == typeChecker.numberType); + REQUIRE(follow(fType) == builtinTypes->numberType); LUAU_REQUIRE_NO_ERRORS(result); } @@ -480,7 +481,7 @@ TEST_CASE_FIXTURE(Fixture, "interface_types_belong_to_interface_arena") const TypeFun& a = mod.exportedTypeBindings["A"]; CHECK(isInArena(a.type, mod.interfaceTypes)); - CHECK(!isInArena(a.type, typeChecker.globalTypes)); + CHECK(!isInArena(a.type, frontend.globals.globalTypes)); std::optional exportsType = first(mod.returnType); REQUIRE(exportsType); @@ -559,7 +560,7 @@ TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definiti TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -585,7 +586,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "use_type_required_from_another_file") TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict @@ -611,7 +612,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_use_nonexported_type") TEST_CASE_FIXTURE(BuiltinsFixture, "builtin_types_are_not_exported") { - addGlobalBinding(frontend, "script", frontend.typeChecker.anyType, "@test"); + addGlobalBinding(frontend.globals, "script", builtinTypes->anyType, "@test"); fileResolver.source["Modules/Main"] = R"( --!strict diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp index 9988a1f..0488196 100644 --- a/tests/TypeInfer.anyerror.test.cpp +++ b/tests/TypeInfer.anyerror.test.cpp @@ -30,7 +30,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ(builtinTypes->anyType, requireType("a")); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") @@ -209,7 +209,7 @@ TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") LUAU_REQUIRE_NO_ERRORS(result); TypeId aType = requireType("A"); - CHECK_EQ(aType, typeChecker.anyType); + CHECK_EQ(aType, builtinTypes->anyType); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 860dcfd..5318b40 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -106,7 +106,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_concat_returns_string") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("r")); + CHECK_EQ(*builtinTypes->stringType, *requireType("r")); } TEST_CASE_FIXTURE(BuiltinsFixture, "sort") @@ -156,7 +156,7 @@ TEST_CASE_FIXTURE(Fixture, "strings_have_methods") )LUA"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") @@ -166,7 +166,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_variatic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); } TEST_CASE_FIXTURE(BuiltinsFixture, "math_max_checks_for_numbers") @@ -365,7 +365,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_ )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.stringType, requireType("s")); + CHECK_EQ(builtinTypes->stringType, requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "table_insert_correctly_infers_type_of_array_3_args_overload") @@ -429,7 +429,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "gcinfo") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); } TEST_CASE_FIXTURE(BuiltinsFixture, "getfenv") @@ -446,9 +446,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "os_time_takes_optional_date_table") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n1")); - CHECK_EQ(*typeChecker.numberType, *requireType("n2")); - CHECK_EQ(*typeChecker.numberType, *requireType("n3")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n1")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n2")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n3")); } TEST_CASE_FIXTURE(BuiltinsFixture, "thread_is_a_type") @@ -552,8 +552,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_correctly_ordered_types") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(tm->wantedType, typeChecker.stringType); - CHECK_EQ(tm->givenType, typeChecker.numberType); + CHECK_EQ(tm->wantedType, builtinTypes->stringType); + CHECK_EQ(tm->givenType, builtinTypes->numberType); } TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_tostring_specifier") @@ -722,8 +722,8 @@ TEST_CASE_FIXTURE(Fixture, "string_format_as_method") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(tm->wantedType, typeChecker.stringType); - CHECK_EQ(tm->givenType, typeChecker.numberType); + CHECK_EQ(tm->wantedType, builtinTypes->stringType); + CHECK_EQ(tm->givenType, builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "string_format_use_correct_argument") @@ -860,9 +860,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "string_format_report_all_type_errors_at_corr string.format("%s%d%s", 1, "hello", true) )"); - TypeId stringType = typeChecker.stringType; - TypeId numberType = typeChecker.numberType; - TypeId booleanType = typeChecker.booleanType; + TypeId stringType = builtinTypes->stringType; + TypeId numberType = builtinTypes->numberType; + TypeId booleanType = builtinTypes->booleanType; LUAU_REQUIRE_ERROR_COUNT(6, result); @@ -1027,7 +1027,7 @@ local function f(a: typeof(f)) end TEST_CASE_FIXTURE(BuiltinsFixture, "no_persistent_typelevel_change") { - TypeId mathTy = requireType(typeChecker.globalScope, "math"); + TypeId mathTy = requireType(frontend.globals.globalScope, "math"); REQUIRE(mathTy); TableType* ttv = getMutable(mathTy); REQUIRE(ttv); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 2a681d1..f3f4641 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -19,13 +19,13 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_simple") declare foo2: typeof(foo) )"); - TypeId globalFooTy = getGlobalBinding(frontend, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend.globals, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - TypeId globalBarTy = getGlobalBinding(frontend, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend.globals, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend.globals, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); CheckResult result = check(R"( @@ -48,20 +48,20 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") declare function var(...: any): string )"); - TypeId globalFooTy = getGlobalBinding(frontend, "foo"); + TypeId globalFooTy = getGlobalBinding(frontend.globals, "foo"); CHECK_EQ(toString(globalFooTy), "number"); - std::optional globalAsdfTy = frontend.getGlobalScope()->lookupType("Asdf"); + std::optional globalAsdfTy = frontend.globals.globalScope->lookupType("Asdf"); REQUIRE(bool(globalAsdfTy)); CHECK_EQ(toString(globalAsdfTy->type), "number | string"); - TypeId globalBarTy = getGlobalBinding(frontend, "bar"); + TypeId globalBarTy = getGlobalBinding(frontend.globals, "bar"); CHECK_EQ(toString(globalBarTy), "(number) -> string"); - TypeId globalFoo2Ty = getGlobalBinding(frontend, "foo2"); + TypeId globalFoo2Ty = getGlobalBinding(frontend.globals, "foo2"); CHECK_EQ(toString(globalFoo2Ty), "number"); - TypeId globalVarTy = getGlobalBinding(frontend, "var"); + TypeId globalVarTy = getGlobalBinding(frontend.globals, "var"); CHECK_EQ(toString(globalVarTy), "(...any) -> string"); @@ -77,25 +77,25 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_loading") TEST_CASE_FIXTURE(Fixture, "load_definition_file_errors_do_not_pollute_global_scope") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult parseFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult parseFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare foo )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!parseFailResult.success); - std::optional fooTy = tryGetGlobalBinding(frontend, "foo"); + std::optional fooTy = tryGetGlobalBinding(frontend.globals, "foo"); CHECK(!fooTy.has_value()); - LoadDefinitionFileResult checkFailResult = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + LoadDefinitionFileResult checkFailResult = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( local foo: string = 123 declare bar: typeof(foo) )", - "@test"); + "@test", /* captureComments */ false); REQUIRE(!checkFailResult.success); - std::optional barTy = tryGetGlobalBinding(frontend, "bar"); + std::optional barTy = tryGetGlobalBinding(frontend.globals, "bar"); CHECK(!barTy.has_value()); } @@ -139,15 +139,15 @@ TEST_CASE_FIXTURE(Fixture, "definition_file_classes") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare class A X: number X: string end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); CHECK_EQ(result.parseResult.errors.size(), 0); @@ -160,15 +160,15 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_overload_non_function") TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( type NotAClass = {} declare class Foo extends NotAClass end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); CHECK_EQ(result.parseResult.errors.size(), 0); @@ -181,16 +181,16 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_cannot_extend_non_class") TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare class Foo extends Bar end declare class Bar extends Foo end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(!result.success); } @@ -281,16 +281,16 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") } )"); - std::optional xBinding = typeChecker.globalScope->linearSearchForBinding("x"); + std::optional xBinding = frontend.globals.globalScope->linearSearchForBinding("x"); REQUIRE(bool(xBinding)); // note: loadDefinition uses the @test package name. CHECK_EQ(xBinding->documentationSymbol, "@test/global/x"); - std::optional fooTy = typeChecker.globalScope->lookupType("Foo"); + std::optional fooTy = frontend.globals.globalScope->lookupType("Foo"); REQUIRE(bool(fooTy)); CHECK_EQ(fooTy->type->documentationSymbol, "@test/globaltype/Foo"); - std::optional barTy = typeChecker.globalScope->lookupType("Bar"); + std::optional barTy = frontend.globals.globalScope->lookupType("Bar"); REQUIRE(bool(barTy)); CHECK_EQ(barTy->type->documentationSymbol, "@test/globaltype/Bar"); @@ -299,7 +299,7 @@ TEST_CASE_FIXTURE(Fixture, "definitions_documentation_symbols") REQUIRE_EQ(barClass->props.count("prop"), 1); CHECK_EQ(barClass->props["prop"].documentationSymbol, "@test/globaltype/Bar.prop"); - std::optional yBinding = typeChecker.globalScope->linearSearchForBinding("y"); + std::optional yBinding = frontend.globals.globalScope->linearSearchForBinding("y"); REQUIRE(bool(yBinding)); CHECK_EQ(yBinding->documentationSymbol, "@test/global/y"); @@ -319,7 +319,7 @@ TEST_CASE_FIXTURE(Fixture, "definitions_symbols_are_generated_for_recursively_re declare function myFunc(): MyClass )"); - std::optional myClassTy = typeChecker.globalScope->lookupType("MyClass"); + std::optional myClassTy = frontend.globals.globalScope->lookupType("MyClass"); REQUIRE(bool(myClassTy)); CHECK_EQ(myClassTy->type->documentationSymbol, "@test/globaltype/MyClass"); } @@ -330,7 +330,7 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type export type Evil = string )"); - std::optional ty = typeChecker.globalScope->lookupType("Evil"); + std::optional ty = frontend.globals.globalScope->lookupType("Evil"); REQUIRE(bool(ty)); CHECK_EQ(ty->type->documentationSymbol, std::nullopt); } @@ -396,8 +396,8 @@ TEST_CASE_FIXTURE(Fixture, "class_definition_string_props") TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") { - unfreeze(typeChecker.globalTypes); - LoadDefinitionFileResult result = loadDefinitionFile(typeChecker, typeChecker.globalScope, R"( + unfreeze(frontend.globals.globalTypes); + LoadDefinitionFileResult result = loadDefinitionFile(frontend.typeChecker, frontend.globals, frontend.globals.globalScope, R"( declare class Channel Messages: { Message } OnMessage: (message: Message) -> () @@ -408,8 +408,8 @@ TEST_CASE_FIXTURE(Fixture, "class_definitions_reference_other_classes") Channel: Channel end )", - "@test"); - freeze(typeChecker.globalTypes); + "@test", /* captureComments */ false); + freeze(frontend.globals.globalTypes); REQUIRE(result.success); } diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 7c2e451..482a6b7 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -33,8 +33,8 @@ TEST_CASE_FIXTURE(Fixture, "check_function_bodies") LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.booleanType, + builtinTypes->numberType, + builtinTypes->booleanType, }})); } @@ -70,7 +70,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_return_type") std::vector retVec = flatten(takeFiveType->retTypes).first; REQUIRE(!retVec.empty()); - REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); + REQUIRE_EQ(*follow(retVec[0]), *builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") @@ -78,7 +78,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") CheckResult result = check("function take_five() return 5 end local five = take_five()"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); + CHECK_EQ(*builtinTypes->numberType, *follow(requireType("five"))); } TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") @@ -92,7 +92,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{builtinTypes->numberType}})); } TEST_CASE_FIXTURE(Fixture, "generalize_table_property") @@ -171,8 +171,8 @@ TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_ TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); ExtraInformation* ei = get(result.errors[1]); REQUIRE(ei); @@ -208,8 +208,8 @@ TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argume TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") @@ -847,13 +847,13 @@ TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, + builtinTypes->numberType, + builtinTypes->stringType, }})); CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ - typeChecker.stringType, - typeChecker.numberType, + builtinTypes->stringType, + builtinTypes->numberType, }})); } @@ -1669,6 +1669,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_infer_parameter_types_for_functions_from_their_ LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("(a) -> a", toString(requireType("f"))); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK_EQ("({+ p: {+ q: a +} +}) -> a & ~false", toString(requireType("g"))); + else + CHECK_EQ("({+ p: {+ q: nil +} +}) -> nil", toString(requireType("g"))); } TEST_CASE_FIXTURE(Fixture, "dont_mutate_the_underlying_head_of_typepack_when_calling_with_self") @@ -1851,4 +1855,29 @@ end LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "dont_assert_when_the_tarjan_limit_is_exceeded_during_generalization") +{ + ScopedFastInt sfi{"LuauTarjanChildLimit", 2}; + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauClonePublicInterfaceLess", true}, + {"LuauSubstitutionReentrant", true}, + {"LuauSubstitutionFixMissingFields", true}, + }; + + CheckResult result = check(R"( + function f(t) + t.x.y.z = 441 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_MESSAGE(get(result.errors[0]), "Expected CodeTooComplex but got: " << toString(result.errors[0])); + CHECK(Location({1, 17}, {1, 18}) == result.errors[0].location); + + CHECK_MESSAGE(get(result.errors[1]), "Expected UnificationTooComplex but got: " << toString(result.errors[1])); + CHECK(Location({0, 0}, {4, 4}) == result.errors[1].location); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 0ba889c..b3b2e4c 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -844,8 +844,8 @@ TEST_CASE_FIXTURE(Fixture, "generic_function") LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("(a) -> a", toString(requireType("id"))); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.nilType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->nilType, *requireType("b")); } TEST_CASE_FIXTURE(Fixture, "generic_table_method") diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index ea6fff7..f6d04a9 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -22,8 +22,8 @@ TEST_CASE_FIXTURE(Fixture, "select_correct_union_fn") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("b"), typeChecker.stringType); - CHECK_EQ(requireType("c"), typeChecker.numberType); + CHECK_EQ(requireType("b"), builtinTypes->stringType); + CHECK_EQ(requireType("c"), builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "table_combines") @@ -123,11 +123,11 @@ TEST_CASE_FIXTURE(Fixture, "should_still_pick_an_overload_whose_arguments_are_un LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("a1"), *typeChecker.numberType); - CHECK_EQ(*requireType("a2"), *typeChecker.numberType); + CHECK_EQ(*requireType("a1"), *builtinTypes->numberType); + CHECK_EQ(*requireType("a2"), *builtinTypes->numberType); - CHECK_EQ(*requireType("b1"), *typeChecker.stringType); - CHECK_EQ(*requireType("b2"), *typeChecker.stringType); + CHECK_EQ(*requireType("b1"), *builtinTypes->stringType); + CHECK_EQ(*requireType("b2"), *builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "propagates_name") @@ -249,7 +249,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_t )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("r")); + CHECK_EQ(*builtinTypes->anyType, *requireType("r")); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_all_parts_missing_the_property") diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp index 30cbe1d..50e9f80 100644 --- a/tests/TypeInfer.loops.test.cpp +++ b/tests/TypeInfer.loops.test.cpp @@ -28,7 +28,7 @@ TEST_CASE_FIXTURE(Fixture, "for_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("q")); + CHECK_EQ(*builtinTypes->numberType, *requireType("q")); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") @@ -44,8 +44,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") @@ -61,8 +61,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_with_next") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("n")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") @@ -218,8 +218,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_factory_not_returning_t TypeMismatch* tm = get(result.errors[1]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(BuiltinsFixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") @@ -281,8 +281,8 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); + CHECK_EQ(builtinTypes->stringType, tm->givenType); } TEST_CASE_FIXTURE(Fixture, "while_loop") @@ -296,7 +296,7 @@ TEST_CASE_FIXTURE(Fixture, "while_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("i")); + CHECK_EQ(*builtinTypes->numberType, *requireType("i")); } TEST_CASE_FIXTURE(Fixture, "repeat_loop") @@ -310,7 +310,7 @@ TEST_CASE_FIXTURE(Fixture, "repeat_loop") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("i")); + CHECK_EQ(*builtinTypes->stringType, *requireType("i")); } TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") @@ -547,7 +547,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_basic") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("key")); + CHECK_EQ(*builtinTypes->numberType, *requireType("key")); } TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") @@ -561,7 +561,7 @@ TEST_CASE_FIXTURE(Fixture, "loop_iter_trailing_nil") )"); LUAU_REQUIRE_ERROR_COUNT(0, result); - CHECK_EQ(*typeChecker.nilType, *requireType("extra")); + CHECK_EQ(*builtinTypes->nilType, *requireType("extra")); } TEST_CASE_FIXTURE(Fixture, "loop_iter_no_indexer_strict") diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index ed3af11..ab07ee2 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -12,6 +12,7 @@ LUAU_FASTFLAG(LuauInstantiateInSubtyping) LUAU_FASTFLAG(LuauTypeMismatchInvarianceInError) +LUAU_FASTFLAG(DebugLuauDeferredConstraintResolution) using namespace Luau; @@ -482,4 +483,42 @@ return unpack(l0[_]) LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(BuiltinsFixture, "check_imported_module_names") +{ + fileResolver.source["game/A"] = R"( +return function(...) end + )"; + + fileResolver.source["game/B"] = R"( +local l0 = require(game.A) +return l0 + )"; + + CheckResult result = check(R"( +local l0 = require(game.B) +if true then + local l1 = require(game.A) +end +return l0 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ModulePtr mod = getMainModule(); + REQUIRE(mod); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + REQUIRE(mod->scopes.size() >= 4); + CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); + CHECK(mod->scopes[3].second->importedModules["l1"] == "game/A"); + } + else + { + + REQUIRE(mod->scopes.size() >= 3); + CHECK(mod->scopes[0].second->importedModules["l0"] == "game/B"); + CHECK(mod->scopes[2].second->importedModules["l1"] == "game/A"); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp index cf27518..ab41ce3 100644 --- a/tests/TypeInfer.oop.test.cpp +++ b/tests/TypeInfer.oop.test.cpp @@ -290,4 +290,23 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "set_prop_of_intersection_containing_metatabl )"); } +// DCR once had a bug in the following code where it would erroneously bind the 'self' table to itself. +TEST_CASE_FIXTURE(Fixture, "dont_bind_free_tables_to_themselves") +{ + CheckResult result = check(R"( + local T = {} + local b: any + + function T:m() + local a = b[i] + if a then + self:n() + if self:p(a) then + self:n() + end + end + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index d75f00a..dcdc2e3 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -48,7 +48,7 @@ TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") local x:string = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("s"), *typeChecker.stringType); + CHECK_EQ(*requireType("s"), *builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "and_does_not_always_add_boolean") @@ -72,7 +72,7 @@ TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") local x:boolean = s )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("x"), *typeChecker.booleanType); + CHECK_EQ(*requireType("x"), *builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "and_or_ternary") @@ -99,9 +99,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "primitive_arith_no_metatable") std::optional retType = first(functionType->retTypes); REQUIRE(retType.has_value()); - CHECK_EQ(typeChecker.numberType, follow(*retType)); - CHECK_EQ(requireType("n"), typeChecker.numberType); - CHECK_EQ(requireType("s"), typeChecker.stringType); + CHECK_EQ(builtinTypes->numberType, follow(*retType)); + CHECK_EQ(requireType("n"), builtinTypes->numberType); + CHECK_EQ(requireType("s"), builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") @@ -112,7 +112,7 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); + CHECK_EQ(requireType("SOLAR_MASS"), builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") @@ -248,8 +248,12 @@ TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_m LUAU_REQUIRE_ERROR_COUNT(1, result); GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK(gen->message == "Types 'a' and 'b' cannot be compared with < because neither type has a metatable"); + else + REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); } TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") @@ -270,7 +274,11 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_indirectly_compare_types_that_do_not_ GenericError* gen = get(result.errors[0]); REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); + + if (FFlag::DebugLuauDeferredConstraintResolution) + CHECK(gen->message == "Types 'M' and 'M' cannot be compared with < because neither type's metatable has a '__lt' metamethod"); + else + REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); } TEST_CASE_FIXTURE(BuiltinsFixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") @@ -353,7 +361,7 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") s += true )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{builtinTypes->numberType, builtinTypes->booleanType}})); } TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") @@ -364,8 +372,8 @@ TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") )"); LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); } TEST_CASE_FIXTURE(BuiltinsFixture, "compound_assign_metatable") @@ -521,7 +529,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_minus_error") CHECK_EQ("string", toString(requireType("a"))); TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.booleanType); + REQUIRE_EQ(*tm->wantedType, *builtinTypes->booleanType); // given type is the typeof(foo) which is complex to compare against } @@ -547,8 +555,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "typecheck_unary_len_error") CHECK_EQ("number", toString(requireType("a"))); TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + REQUIRE_EQ(*tm->wantedType, *builtinTypes->numberType); + REQUIRE_EQ(*tm->givenType, *builtinTypes->stringType); } TEST_CASE_FIXTURE(BuiltinsFixture, "unary_not_is_boolean") @@ -596,8 +604,8 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(*tm->wantedType, *typeChecker.numberType); - CHECK_EQ(*tm->givenType, *typeChecker.stringType); + CHECK_EQ(*tm->wantedType, *builtinTypes->numberType); + CHECK_EQ(*tm->givenType, *builtinTypes->stringType); GenericError* gen1 = get(result.errors[1]); REQUIRE(gen1); @@ -608,7 +616,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "disallow_string_and_types_without_metatables TypeMismatch* tm2 = get(result.errors[2]); REQUIRE(tm2); - CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm2->wantedType, *builtinTypes->numberType); CHECK_EQ(*tm2->givenType, *requireType("foo")); } @@ -802,7 +810,7 @@ local b: number = 1 or a TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); CHECK_EQ("number?", toString(tm->givenType)); } diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp index 02fdfa3..949e64a 100644 --- a/tests/TypeInfer.primitives.test.cpp +++ b/tests/TypeInfer.primitives.test.cpp @@ -57,7 +57,7 @@ TEST_CASE_FIXTURE(Fixture, "string_method") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("p"), *typeChecker.numberType); + CHECK_EQ(*requireType("p"), *builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "string_function_indirect") @@ -69,7 +69,7 @@ TEST_CASE_FIXTURE(Fixture, "string_function_indirect") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*requireType("p"), *typeChecker.stringType); + CHECK_EQ(*requireType("p"), *builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 570cf27..064ec16 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -30,9 +30,8 @@ std::optional> magicFunctionInstanceIsA( if (!lvalue || !tfun) return std::nullopt; - unfreeze(typeChecker.globalTypes); - TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); - freeze(typeChecker.globalTypes); + ModulePtr module = typeChecker.currentModule; + TypePackId booleanPack = module->internalTypes.addTypePack({typeChecker.booleanType}); return WithPredicate{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; } @@ -62,47 +61,47 @@ struct RefinementClassFixture : BuiltinsFixture { RefinementClassFixture() { - TypeArena& arena = typeChecker.globalTypes; - NotNull scope{typeChecker.globalScope.get()}; + TypeArena& arena = frontend.globals.globalTypes; + NotNull scope{frontend.globals.globalScope.get()}; - std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(typeChecker.builtinTypes->classType) : std::nullopt; + std::optional rootSuper = FFlag::LuauNegatedClassTypes ? std::make_optional(builtinTypes->classType) : std::nullopt; unfreeze(arena); TypeId vec3 = arena.addType(ClassType{"Vector3", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); getMutable(vec3)->props = { - {"X", Property{typeChecker.numberType}}, - {"Y", Property{typeChecker.numberType}}, - {"Z", Property{typeChecker.numberType}}, + {"X", Property{builtinTypes->numberType}}, + {"Y", Property{builtinTypes->numberType}}, + {"Z", Property{builtinTypes->numberType}}, }; TypeId inst = arena.addType(ClassType{"Instance", {}, rootSuper, std::nullopt, {}, nullptr, "Test"}); - TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); - TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypePackId isAParams = arena.addTypePack({inst, builtinTypes->stringType}); + TypePackId isARets = arena.addTypePack({builtinTypes->booleanType}); TypeId isA = arena.addType(FunctionType{isAParams, isARets}); getMutable(isA)->magicFunction = magicFunctionInstanceIsA; getMutable(isA)->dcrMagicRefinement = dcrMagicRefinementInstanceIsA; getMutable(inst)->props = { - {"Name", Property{typeChecker.stringType}}, + {"Name", Property{builtinTypes->stringType}}, {"IsA", Property{isA}}, }; - TypeId folder = typeChecker.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); - TypeId part = typeChecker.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId folder = frontend.globals.globalTypes.addType(ClassType{"Folder", {}, inst, std::nullopt, {}, nullptr, "Test"}); + TypeId part = frontend.globals.globalTypes.addType(ClassType{"Part", {}, inst, std::nullopt, {}, nullptr, "Test"}); getMutable(part)->props = { {"Position", Property{vec3}}, }; - typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; - typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; - typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + frontend.globals.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + frontend.globals.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + frontend.globals.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + frontend.globals.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; - for (const auto& [name, ty] : typeChecker.globalScope->exportedTypeBindings) + for (const auto& [name, ty] : frontend.globals.globalScope->exportedTypeBindings) persist(ty.type); - freeze(typeChecker.globalTypes); + freeze(frontend.globals.globalTypes); } }; } // namespace diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 2a87f0e..0f5e3d3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -198,7 +198,7 @@ TEST_CASE_FIXTURE(Fixture, "call_method") CheckResult result = check("local T = {} T.x = 0 function T:method() return self.x end local a = T:method()"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("a")); } TEST_CASE_FIXTURE(Fixture, "call_method_with_explicit_self_argument") @@ -576,8 +576,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_array") REQUIRE(bool(ttv->indexer)); - CHECK_EQ(*ttv->indexer->indexType, *typeChecker.numberType); - CHECK_EQ(*ttv->indexer->indexResultType, *typeChecker.stringType); + CHECK_EQ(*ttv->indexer->indexType, *builtinTypes->numberType); + CHECK_EQ(*ttv->indexer->indexResultType, *builtinTypes->stringType); } /* This is a bit weird. @@ -685,8 +685,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_array_like_table") REQUIRE(bool(ttv->indexer)); const TableIndexer& indexer = *ttv->indexer; - CHECK_EQ(*typeChecker.numberType, *indexer.indexType); - CHECK_EQ(*typeChecker.stringType, *indexer.indexResultType); + CHECK_EQ(*builtinTypes->numberType, *indexer.indexType); + CHECK_EQ(*builtinTypes->stringType, *indexer.indexResultType); } TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_value_property_in_literal") @@ -740,8 +740,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_from_its_variable_type_and_unifiable") REQUIRE(tTy != nullptr); REQUIRE(tTy->indexer); - CHECK_EQ(*typeChecker.numberType, *tTy->indexer->indexType); - CHECK_EQ(*typeChecker.stringType, *tTy->indexer->indexResultType); + CHECK_EQ(*builtinTypes->numberType, *tTy->indexer->indexType); + CHECK_EQ(*builtinTypes->stringType, *tTy->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "indexer_mismatch") @@ -844,7 +844,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_type_when_indexing_from_a_table_indexer") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + CHECK_EQ(*builtinTypes->stringType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_possible") @@ -865,13 +865,13 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.stringType, *requireType("a1")); - CHECK_EQ(*typeChecker.stringType, *requireType("a2")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a1")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a2")); - CHECK_EQ(*typeChecker.numberType, *requireType("b1")); - CHECK_EQ(*typeChecker.numberType, *requireType("b2")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b1")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b2")); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } @@ -932,7 +932,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s LUAU_REQUIRE_NO_ERRORS(result); - CHECK("string" == toString(*typeChecker.stringType)); + CHECK("string" == toString(*builtinTypes->stringType)); TableType* tableType = getMutable(requireType("t")); REQUIRE(tableType != nullptr); @@ -941,7 +941,7 @@ TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_s TypeId propertyA = tableType->props["a"].type; REQUIRE(propertyA != nullptr); - CHECK_EQ(*typeChecker.stringType, *propertyA); + CHECK_EQ(*builtinTypes->stringType, *propertyA); } TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") @@ -964,7 +964,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oop_indexer_works") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("words")); + CHECK_EQ(*builtinTypes->stringType, *requireType("words")); } TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") @@ -977,7 +977,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_table") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("b")); + CHECK_EQ(*builtinTypes->stringType, *requireType("b")); } TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") @@ -988,7 +988,7 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "indexer_fn") )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); } TEST_CASE_FIXTURE(BuiltinsFixture, "meta_add") @@ -1102,10 +1102,10 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "oop_polymorphic") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.booleanType, *requireType("alive")); - CHECK_EQ(*typeChecker.stringType, *requireType("movement")); - CHECK_EQ(*typeChecker.stringType, *requireType("name")); - CHECK_EQ(*typeChecker.numberType, *requireType("speed")); + CHECK_EQ(*builtinTypes->booleanType, *requireType("alive")); + CHECK_EQ(*builtinTypes->stringType, *requireType("movement")); + CHECK_EQ(*builtinTypes->stringType, *requireType("name")); + CHECK_EQ(*builtinTypes->numberType, *requireType("speed")); } TEST_CASE_FIXTURE(Fixture, "user_defined_table_types_are_named") @@ -2477,7 +2477,7 @@ TEST_CASE_FIXTURE(Fixture, "table_length") LUAU_REQUIRE_NO_ERRORS(result); CHECK(nullptr != get(requireType("t"))); - CHECK_EQ(*typeChecker.numberType, *requireType("s")); + CHECK_EQ(*builtinTypes->numberType, *requireType("s")); } TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") @@ -2498,8 +2498,8 @@ TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") CHECK((Location{Position{3, 15}, Position{3, 18}}) == result.errors[0].location); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK(tm->wantedType == typeChecker.numberType); - CHECK(tm->givenType == typeChecker.stringType); + CHECK(tm->wantedType == builtinTypes->numberType); + CHECK(tm->givenType == builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") @@ -2510,8 +2510,8 @@ TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(result.errors[0], (TypeError{Location{Position{2, 17}, Position{2, 20}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.nilType, + builtinTypes->numberType, + builtinTypes->nilType, }})); } @@ -2709,7 +2709,7 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") Fixture fix; // inherit env from parent fixture checker - fix.typeChecker.globalScope = typeChecker.globalScope; + fix.frontend.globals.globalScope = frontend.globals.globalScope; fix.check(R"( --!nonstrict @@ -2723,7 +2723,7 @@ end // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down // note: it's important for typeck to be destroyed at this point! { - for (auto& p : typeChecker.globalScope->bindings) + for (auto& p : frontend.globals.globalScope->bindings) { toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas } @@ -3318,8 +3318,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "a_free_shape_can_turn_into_a_scalar_if_it_is_compatible") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; // Changes argument from table type to primitive - CheckResult result = check(R"( local function f(s): string local foo = s:lower() @@ -3350,8 +3348,6 @@ caused by: TEST_CASE_FIXTURE(BuiltinsFixture, "a_free_shape_can_turn_into_a_scalar_directly") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( local function stringByteList(str) local out = {} @@ -3457,8 +3453,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "tables_should_be_fully_populated") TEST_CASE_FIXTURE(Fixture, "fuzz_table_indexer_unification_can_bound_owner_to_string") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( sin,_ = nil _ = _[_.sin][_._][_][_]._ @@ -3470,8 +3464,6 @@ _[_] = _ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_extra_prop_unification_can_bound_owner_to_string") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( l0,_ = nil _ = _,_[_.n5]._[_][_][_]._ @@ -3483,8 +3475,6 @@ _._.foreach[_],_ = _[_],_._ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_typelevel_promote_on_changed_table_type") { - ScopedFastFlag luauScalarShapeUnifyToMtOwner{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( _._,_ = nil _ = _.foreach[_]._,_[_.n5]._[_.foreach][_][_]._ @@ -3498,8 +3488,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_instantiated_table") { ScopedFastFlag sff[]{ {"LuauInstantiateInSubtyping", true}, - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, }; CheckResult result = check(R"( @@ -3517,8 +3505,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzz_table_unify_instantiated_table_with_prop_reallo { ScopedFastFlag sff[]{ {"LuauInstantiateInSubtyping", true}, - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, }; CheckResult result = check(R"( @@ -3537,12 +3523,6 @@ end) TEST_CASE_FIXTURE(BuiltinsFixture, "fuzz_table_unify_prop_realloc") { - // For this test, we don't need LuauInstantiateInSubtyping - ScopedFastFlag sff[]{ - {"LuauScalarShapeUnifyToMtOwner2", true}, - {"LuauTableUnifyInstantiationFix", true}, - }; - CheckResult result = check(R"( n3,_ = nil _ = _[""]._,_[l0][_._][{[_]=_,_=_,}][_G].number diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 3865e83..417f80a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -45,7 +45,8 @@ TEST_CASE_FIXTURE(Fixture, "tc_error") CheckResult result = check("local a = 7 local b = 'hi' a = b"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ( + result.errors[0], (TypeError{Location{Position{0, 35}, Position{0, 36}}, TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); } TEST_CASE_FIXTURE(Fixture, "tc_error_2") @@ -55,7 +56,7 @@ TEST_CASE_FIXTURE(Fixture, "tc_error_2") CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 18}, Position{0, 22}}, TypeMismatch{ requireType("a"), - typeChecker.stringType, + builtinTypes->stringType, }})); } @@ -123,8 +124,8 @@ TEST_CASE_FIXTURE(Fixture, "if_statement") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("a")); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); + CHECK_EQ(*builtinTypes->stringType, *requireType("a")); + CHECK_EQ(*builtinTypes->numberType, *requireType("b")); } TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") @@ -256,7 +257,13 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi end )"); - LUAU_REQUIRE_NO_ERRORS(result); + if (FFlag::DebugLuauDeferredConstraintResolution) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(get(result.errors[0])); + } + else + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") @@ -580,7 +587,7 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(builtinTypes->numberType, tm->wantedType); CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } @@ -1150,8 +1157,6 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "it_is_ok_to_have_inconsistent_number_of_retu TEST_CASE_FIXTURE(Fixture, "fuzz_free_table_type_change_during_index_check") { - ScopedFastFlag sff{"LuauScalarShapeUnifyToMtOwner2", true}; - CheckResult result = check(R"( local _ = nil while _["" >= _] do @@ -1175,4 +1180,24 @@ local b = typeof(foo) ~= 'nil' CHECK(toString(result.errors[1]) == "Unknown global 'foo'"); } +TEST_CASE_FIXTURE(Fixture, "dcr_delays_expansion_of_function_containing_blocked_parameter_type") +{ + ScopedFastFlag sff[] = { + {"DebugLuauDeferredConstraintResolution", true}, + {"LuauTinyUnifyNormalsFix", true}, + }; + + CheckResult result = check(R"( + local b: any + + function f(x) + local a = b[1] or 'Cn' + local c = x[1] + + if a:sub(1, #c) == c then + end + end + )"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 47b140a..66e0701 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -38,7 +38,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") { Type functionOne{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; Type functionTwo{TypeVariant{ FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; @@ -55,13 +55,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") { TypePackVar argPackOne{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionOne{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.numberType}))}}; + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->numberType}))}}; Type functionOneSaved = functionOne; TypePackVar argPackTwo{TypePack{{arena.freshType(globalScope->level)}, std::nullopt}}; Type functionTwo{ - TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({typeChecker.stringType}))}}; + TypeVariant{FunctionType(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({builtinTypes->stringType}))}}; Type functionTwoSaved = functionTwo; @@ -96,12 +96,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") { Type tableOne{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.numberType}}}, std::nullopt, globalScope->level, + TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->numberType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; Type tableTwo{TypeVariant{ - TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {typeChecker.stringType}}}, std::nullopt, globalScope->level, + TableType{{{"foo", {arena.freshType(globalScope->level)}}, {"bar", {builtinTypes->stringType}}}, std::nullopt, globalScope->level, TableState::Unsealed}, }}; @@ -214,8 +214,8 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") { - TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; - TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; + TypePackVar testPack{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, std::nullopt}}; + TypePackVar variadicPack{VariadicTypePack{builtinTypes->numberType}}; state.tryUnify(&testPack, &variadicPack); CHECK(!state.errors.empty()); @@ -223,9 +223,9 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") { - TypePackVar variadicPack{VariadicTypePack{typeChecker.booleanType}}; - TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; - TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; + TypePackVar variadicPack{VariadicTypePack{builtinTypes->booleanType}}; + TypePackVar a{TypePack{{builtinTypes->numberType, builtinTypes->stringType, builtinTypes->booleanType, builtinTypes->booleanType}}}; + TypePackVar b{TypePack{{builtinTypes->numberType, builtinTypes->stringType}, &variadicPack}}; state.tryUnify(&b, &a); CHECK(state.errors.empty()); @@ -266,8 +266,9 @@ TEST_CASE_FIXTURE(BuiltinsFixture, "cli_41095_concat_log_in_sealed_table_unifica TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") { - TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); - TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + TypePackId threeNumbers = + arena.addTypePack(TypePack{{builtinTypes->numberType, builtinTypes->numberType, builtinTypes->numberType}, std::nullopt}); + TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{builtinTypes->numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); CHECK(unifyErrors.size() == 0); @@ -279,7 +280,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") Type table{TableType{}}; Type metatable{MetatableType{&redirect, &table}}; redirect = BoundType{&metatable}; // Now we have a metatable that is recursive on the table type - Type variant{UnionType{{&metatable, typeChecker.numberType}}}; + Type variant{UnionType{{&metatable, builtinTypes->numberType}}}; state.tryUnify(&metatable, &variant); } @@ -293,13 +294,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") state.tryUnify(&free, &target); // Shouldn't assert or error. - state.tryUnify(&func, typeChecker.anyType); + state.tryUnify(&func, builtinTypes->anyType); } TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") { TypeId a = arena.addType(Type{FreeType{TypeLevel{}}}); - TypeId b = typeChecker.numberType; + TypeId b = builtinTypes->numberType; state.tryUnify(a, b); state.log.commit(); @@ -310,7 +311,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_type_owner") TEST_CASE_FIXTURE(TryUnifyFixture, "txnlog_preserves_pack_owner") { TypePackId a = arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}}); - TypePackId b = typeChecker.anyTypePack; + TypePackId b = builtinTypes->anyTypePack; state.tryUnify(a, b); state.log.commit(); @@ -323,13 +324,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table ScopedFastFlag sff("DebugLuauDeferredConstraintResolution", true); TableType::Props freeProps{ - {"foo", {typeChecker.numberType}}, + {"foo", {builtinTypes->numberType}}, }; TypeId free = arena.addType(TableType{freeProps, std::nullopt, TypeLevel{}, TableState::Free}); TableType::Props indexProps{ - {"foo", {typeChecker.stringType}}, + {"foo", {builtinTypes->stringType}}, }; TypeId index = arena.addType(TableType{indexProps, std::nullopt, TypeLevel{}, TableState::Sealed}); @@ -356,9 +357,9 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "metatables_unify_against_shape_of_free_table TEST_CASE_FIXTURE(TryUnifyFixture, "fuzz_tail_unification_issue") { - TypePackVar variadicAny{VariadicTypePack{typeChecker.anyType}}; - TypePackVar packTmp{TypePack{{typeChecker.anyType}, &variadicAny}}; - TypePackVar packSub{TypePack{{typeChecker.anyType, typeChecker.anyType}, &packTmp}}; + TypePackVar variadicAny{VariadicTypePack{builtinTypes->anyType}}; + TypePackVar packTmp{TypePack{{builtinTypes->anyType}, &variadicAny}}; + TypePackVar packSub{TypePack{{builtinTypes->anyType, builtinTypes->anyType}, &packTmp}}; Type freeTy{FreeType{TypeLevel{}}}; TypePackVar freeTp{FreeTypePack{TypeLevel{}}}; diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 78eb6d4..4411916 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -27,8 +27,8 @@ TEST_CASE_FIXTURE(Fixture, "infer_multi_return") const auto& [returns, tail] = flatten(takeTwoType->retTypes); CHECK_EQ(2, returns.size()); - CHECK_EQ(typeChecker.numberType, follow(returns[0])); - CHECK_EQ(typeChecker.numberType, follow(returns[1])); + CHECK_EQ(builtinTypes->numberType, follow(returns[0])); + CHECK_EQ(builtinTypes->numberType, follow(returns[1])); CHECK(!tail); } @@ -74,9 +74,9 @@ TEST_CASE_FIXTURE(Fixture, "last_element_of_return_statement_can_itself_be_a_pac const auto& [rets, tail] = flatten(takeOneMoreType->retTypes); REQUIRE_EQ(3, rets.size()); - CHECK_EQ(typeChecker.numberType, follow(rets[0])); - CHECK_EQ(typeChecker.numberType, follow(rets[1])); - CHECK_EQ(typeChecker.numberType, follow(rets[2])); + CHECK_EQ(builtinTypes->numberType, follow(rets[0])); + CHECK_EQ(builtinTypes->numberType, follow(rets[1])); + CHECK_EQ(builtinTypes->numberType, follow(rets[2])); CHECK(!tail); } @@ -184,28 +184,28 @@ TEST_CASE_FIXTURE(Fixture, "parenthesized_varargs_returns_any") TEST_CASE_FIXTURE(Fixture, "variadic_packs") { - TypeArena& arena = typeChecker.globalTypes; + TypeArena& arena = frontend.globals.globalTypes; unfreeze(arena); - TypePackId listOfNumbers = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.stringType}}); + TypePackId listOfNumbers = arena.addTypePack(TypePackVar{VariadicTypePack{builtinTypes->numberType}}); + TypePackId listOfStrings = arena.addTypePack(TypePackVar{VariadicTypePack{builtinTypes->stringType}}); // clang-format off - addGlobalBinding(frontend, "foo", + addGlobalBinding(frontend.globals, "foo", arena.addType( FunctionType{ listOfNumbers, - arena.addTypePack({typeChecker.numberType}) + arena.addTypePack({builtinTypes->numberType}) } ), "@test" ); - addGlobalBinding(frontend, "bar", + addGlobalBinding(frontend.globals, "bar", arena.addType( FunctionType{ - arena.addTypePack({{typeChecker.numberType}, listOfStrings}), - arena.addTypePack({typeChecker.numberType}) + arena.addTypePack({{builtinTypes->numberType}, listOfStrings}), + arena.addTypePack({builtinTypes->numberType}) } ), "@test" @@ -223,9 +223,11 @@ TEST_CASE_FIXTURE(Fixture, "variadic_packs") LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location(Position{3, 21}, Position{3, 26}), TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ( + result.errors[0], (TypeError{Location(Position{3, 21}, Position{3, 26}), TypeMismatch{builtinTypes->numberType, builtinTypes->stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location(Position{4, 29}, Position{4, 30}), TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); + CHECK_EQ( + result.errors[1], (TypeError{Location(Position{4, 29}, Position{4, 30}), TypeMismatch{builtinTypes->stringType, builtinTypes->numberType}})); } TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 6f69d68..704e2a3 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -131,7 +131,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_property_guaranteed_to_ex )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.numberType, *requireType("r")); + CHECK_EQ(*builtinTypes->numberType, *requireType("r")); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_mixed_types") @@ -211,7 +211,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any" )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.anyType, *requireType("r")); + CHECK_EQ(*builtinTypes->anyType, *requireType("r")); } TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") @@ -245,7 +245,7 @@ local c = bf.a.y )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } @@ -260,7 +260,7 @@ TEST_CASE_FIXTURE(Fixture, "optional_union_functions") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } @@ -275,7 +275,7 @@ local c = b:foo(1, 2) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.numberType, *requireType("c")); + CHECK_EQ(*builtinTypes->numberType, *requireType("c")); CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } diff --git a/tests/TypeReduction.test.cpp b/tests/TypeReduction.test.cpp index 582725b..5f11a71 100644 --- a/tests/TypeReduction.test.cpp +++ b/tests/TypeReduction.test.cpp @@ -31,6 +31,12 @@ struct ReductionFixture : Fixture return *reducedTy; } + std::optional tryReduce(const std::string& annotation) + { + check("type _Res = " + annotation); + return reduction.reduce(requireTypeAlias("_Res")); + } + TypeId reductionof(const std::string& annotation) { check("type _Res = " + annotation); @@ -1488,4 +1494,16 @@ TEST_CASE_FIXTURE(ReductionFixture, "cycles") } } +TEST_CASE_FIXTURE(ReductionFixture, "string_singletons") +{ + TypeId ty = reductionof("(string & Not<\"A\">)?"); + CHECK("(string & ~\"A\")?" == toStringFull(ty)); +} + +TEST_CASE_FIXTURE(ReductionFixture, "string_singletons_2") +{ + TypeId ty = reductionof("Not<\"A\"> & Not<\"B\"> & (string?)"); + CHECK("(string & ~\"A\" & ~\"B\")?" == toStringFull(ty)); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 36e437e..64ba63c 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -16,13 +16,13 @@ TEST_SUITE_BEGIN("TypeTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") { - REQUIRE_EQ(typeChecker.booleanType, typeChecker.booleanType); + REQUIRE_EQ(builtinTypes->booleanType, builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "bound_type_is_equal_to_that_which_it_is_bound") { - Type bound(BoundType(typeChecker.booleanType)); - REQUIRE_EQ(bound, *typeChecker.booleanType); + Type bound(BoundType(builtinTypes->booleanType)); + REQUIRE_EQ(bound, *builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "equivalent_cyclic_tables_are_equal") @@ -54,8 +54,8 @@ TEST_CASE_FIXTURE(Fixture, "different_cyclic_tables_are_not_equal") TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just_one_value") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> number", res); @@ -64,8 +64,8 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_not_parenthesized_if_just TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_not_just_one_value") { auto emptyArgumentPack = TypePackVar{TypePack{}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType, typeChecker.numberType}}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType, builtinTypes->numberType}}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ("() -> (number, number)", res); @@ -76,8 +76,8 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_ auto emptyArgumentPack = TypePackVar{TypePack{}}; auto free = Unifiable::Free(TypeLevel()); auto freePack = TypePackVar{TypePackVariant{free}}; - auto returnPack = TypePackVar{TypePack{{typeChecker.numberType}, &freePack}}; - auto returnsTwo = Type(FunctionType(typeChecker.globalScope->level, &emptyArgumentPack, &returnPack)); + auto returnPack = TypePackVar{TypePack{{builtinTypes->numberType}, &freePack}}; + auto returnsTwo = Type(FunctionType(frontend.globals.globalScope->level, &emptyArgumentPack, &returnPack)); std::string res = toString(&returnsTwo); CHECK_EQ(res, "() -> (number, a...)"); @@ -86,9 +86,9 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_function_is_parenthesized_if_tail_is_ TEST_CASE_FIXTURE(Fixture, "subset_check") { UnionType super, sub, notSub; - super.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType}; - sub.options = {typeChecker.numberType, typeChecker.stringType}; - notSub.options = {typeChecker.numberType, typeChecker.nilType}; + super.options = {builtinTypes->numberType, builtinTypes->stringType, builtinTypes->booleanType}; + sub.options = {builtinTypes->numberType, builtinTypes->stringType}; + notSub.options = {builtinTypes->numberType, builtinTypes->nilType}; CHECK(isSubset(super, sub)); CHECK(!isSubset(super, notSub)); @@ -97,7 +97,7 @@ TEST_CASE_FIXTURE(Fixture, "subset_check") TEST_CASE_FIXTURE(Fixture, "iterate_over_UnionType") { UnionType utv; - utv.options = {typeChecker.numberType, typeChecker.stringType, typeChecker.anyType}; + utv.options = {builtinTypes->numberType, builtinTypes->stringType, builtinTypes->anyType}; std::vector result; for (TypeId ty : &utv) @@ -110,19 +110,19 @@ TEST_CASE_FIXTURE(Fixture, "iterating_over_nested_UnionTypes") { Type subunion{UnionType{}}; UnionType* innerUtv = getMutable(&subunion); - innerUtv->options = {typeChecker.numberType, typeChecker.stringType}; + innerUtv->options = {builtinTypes->numberType, builtinTypes->stringType}; UnionType utv; - utv.options = {typeChecker.anyType, &subunion}; + utv.options = {builtinTypes->anyType, &subunion}; std::vector result; for (TypeId ty : &utv) result.push_back(ty); REQUIRE_EQ(result.size(), 3); - CHECK_EQ(result[0], typeChecker.anyType); - CHECK_EQ(result[2], typeChecker.stringType); - CHECK_EQ(result[1], typeChecker.numberType); + CHECK_EQ(result[0], builtinTypes->anyType); + CHECK_EQ(result[2], builtinTypes->stringType); + CHECK_EQ(result[1], builtinTypes->numberType); } TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_them") @@ -132,8 +132,8 @@ TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_th Type btv{UnionType{}}; UnionType* utv2 = getMutable(&btv); - utv2->options.push_back(typeChecker.numberType); - utv2->options.push_back(typeChecker.stringType); + utv2->options.push_back(builtinTypes->numberType); + utv2->options.push_back(builtinTypes->stringType); utv2->options.push_back(&atv); utv1->options.push_back(&btv); @@ -143,14 +143,14 @@ TEST_CASE_FIXTURE(Fixture, "iterator_detects_cyclic_UnionTypes_and_skips_over_th result.push_back(ty); REQUIRE_EQ(result.size(), 2); - CHECK_EQ(result[0], typeChecker.numberType); - CHECK_EQ(result[1], typeChecker.stringType); + CHECK_EQ(result[0], builtinTypes->numberType); + CHECK_EQ(result[1], builtinTypes->stringType); } TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") { - Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; - Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + Type tv1{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; + Type tv2{UnionType{{&tv1, builtinTypes->booleanType}}}; auto utv = get(&tv2); std::vector result; @@ -158,19 +158,19 @@ TEST_CASE_FIXTURE(Fixture, "iterator_descends_on_nested_in_first_operator*") result.push_back(ty); REQUIRE_EQ(result.size(), 3); - CHECK_EQ(result[0], typeChecker.stringType); - CHECK_EQ(result[1], typeChecker.numberType); - CHECK_EQ(result[2], typeChecker.booleanType); + CHECK_EQ(result[0], builtinTypes->stringType); + CHECK_EQ(result[1], builtinTypes->numberType); + CHECK_EQ(result[2], builtinTypes->booleanType); } TEST_CASE_FIXTURE(Fixture, "UnionTypeIterator_with_vector_iter_ctor") { - Type tv1{UnionType{{typeChecker.stringType, typeChecker.numberType}}}; - Type tv2{UnionType{{&tv1, typeChecker.booleanType}}}; + Type tv1{UnionType{{builtinTypes->stringType, builtinTypes->numberType}}}; + Type tv2{UnionType{{&tv1, builtinTypes->booleanType}}}; auto utv = get(&tv2); std::vector actual(begin(utv), end(utv)); - std::vector expected{typeChecker.stringType, typeChecker.numberType, typeChecker.booleanType}; + std::vector expected{builtinTypes->stringType, builtinTypes->numberType, builtinTypes->booleanType}; CHECK_EQ(actual, expected); } @@ -273,10 +273,10 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId root = &ttvTweenResult; - typeChecker.currentModule = std::make_shared(); - typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(builtinTypes->anyTypePack)); + frontend.typeChecker.currentModule = std::make_shared(); + frontend.typeChecker.currentModule->scopes.emplace_back(Location{}, std::make_shared(builtinTypes->anyTypePack)); - TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); + TypeId result = frontend.typeChecker.anyify(frontend.globals.globalScope, root, Location{}); CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); } diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index e23c1a5..f4a91fc 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -120,6 +120,8 @@ assert((function() local a a = nil local b = 2 b = a and b return b end)() == ni assert((function() local a a = 1 local b = 2 b = a or b return b end)() == 1) assert((function() local a a = nil local b = 2 b = a or b return b end)() == 2) +assert((function(a) return 12 % a end)(5) == 2) + -- binary arithmetics coerces strings to numbers (sadly) assert(1 + "2" == 3) assert(2 * "0xa" == 20) diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index e2f68e6..18ed137 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -281,6 +281,8 @@ assert(math.round(-0.4) == 0) assert(math.round(-0.5) == -1) assert(math.round(-3.5) == -4) assert(math.round(math.huge) == math.huge) +assert(math.round(0.49999999999999994) == 0) +assert(math.round(-0.49999999999999994) == 0) -- fmod assert(math.fmod(3, 2) == 1) diff --git a/tests/conformance/sort.lua b/tests/conformance/sort.lua index 95940e1..693a10d 100644 --- a/tests/conformance/sort.lua +++ b/tests/conformance/sort.lua @@ -2,6 +2,47 @@ -- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes print"testing sort" +function checksort(t, f, ...) + assert(#t == select('#', ...)) + local copy = table.clone(t) + table.sort(copy, f) + for i=1,#t do assert(copy[i] == select(i, ...)) end +end + +-- basic edge cases +checksort({}, nil) +checksort({1}, nil, 1) + +-- small inputs +checksort({1, 2}, nil, 1, 2) +checksort({2, 1}, nil, 1, 2) + +checksort({1, 2, 3}, nil, 1, 2, 3) +checksort({2, 1, 3}, nil, 1, 2, 3) +checksort({1, 3, 2}, nil, 1, 2, 3) +checksort({3, 2, 1}, nil, 1, 2, 3) +checksort({3, 1, 2}, nil, 1, 2, 3) + +-- "large" input +checksort({3, 8, 1, 7, 10, 2, 5, 4, 9, 6}, nil, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) +checksort({"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}, nil, "Apr", "Aug", "Dec", "Feb", "Jan", "Jul", "Jun", "Mar", "May", "Nov", "Oct", "Sep") + +-- duplicates +checksort({3, 1, 1, 7, 1, 3, 5, 1, 9, 3}, nil, 1, 1, 1, 1, 3, 3, 3, 5, 7, 9) + +-- predicates +checksort({3, 8, 1, 7, 10, 2, 5, 4, 9, 6}, function (a, b) return a > b end, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1) + +-- can't sort readonly tables +assert(pcall(table.sort, table.freeze({2, 1})) == false) + +-- first argument must be a table, second argument must be nil or function +assert(pcall(table.sort) == false) +assert(pcall(table.sort, "abc") == false) +assert(pcall(table.sort, {}, 42) == false) +assert(pcall(table.sort, {}, {}) == false) + +-- legacy Lua tests function check (a, f) f = f or function (x,y) return x