luau/tests/Fixture.h

291 lines
8.5 KiB
C++

// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
#pragma once
#include "Luau/Config.h"
#include "Luau/DcrLogger.h"
#include "Luau/FileResolver.h"
#include "Luau/Frontend.h"
#include "Luau/IostreamHelpers.h"
#include "Luau/Linter.h"
#include "Luau/Location.h"
#include "Luau/ModuleResolver.h"
#include "Luau/Scope.h"
#include "Luau/ToString.h"
#include "Luau/TypeInfer.h"
#include "Luau/TypeVar.h"

#include "IostreamOptional.h"
#include "ScopedFlags.h"

#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace Luau
{
struct TestFileResolver
: FileResolver
, ModuleResolver
{
std::optional<ModuleInfo> resolveModuleInfo(const ModuleName& currentModuleName, const AstExpr& pathExpr) override
{
if (auto name = pathExprToModuleName(currentModuleName, pathExpr))
return {{*name, false}};
return std::nullopt;
}
const ModulePtr getModule(const ModuleName& moduleName) const override
{
LUAU_ASSERT(false);
return nullptr;
}
bool moduleExists(const ModuleName& moduleName) const override
{
auto it = source.find(moduleName);
return (it != source.end());
}
std::optional<SourceCode> readSource(const ModuleName& name) override
{
auto it = source.find(name);
if (it == source.end())
return std::nullopt;
SourceCode::Type sourceType = SourceCode::Module;
auto it2 = sourceTypes.find(name);
if (it2 != sourceTypes.end())
sourceType = it2->second;
return SourceCode{it->second, sourceType};
}
std::optional<ModuleInfo> resolveModule(const ModuleInfo* context, AstExpr* expr) override;
std::string getHumanReadableModuleName(const ModuleName& name) const override;
std::optional<std::string> getEnvironmentForModule(const ModuleName& name) const override;
std::unordered_map<ModuleName, std::string> source;
std::unordered_map<ModuleName, SourceCode::Type> sourceTypes;
std::unordered_map<ModuleName, std::string> environments;
};
// Config resolver for tests: per-module configurations with a shared fallback.
struct TestConfigResolver : ConfigResolver
{
    // Returned for every module that has no entry in configFiles.
    Config defaultConfig;
    // Module name -> configuration override.
    std::unordered_map<ModuleName, Config> configFiles;

    // Returns the configuration registered for `name`, falling back to defaultConfig.
    // The reference points into this resolver, so it is only valid while the resolver
    // (and the corresponding map entry) is alive.
    const Config& getConfig(const ModuleName& name) const override
    {
        const auto entry = configFiles.find(name);
        return entry != configFiles.end() ? entry->second : defaultConfig;
    }
};
// Base fixture for tests that parse, type check or lint snippets of Luau source.
// Everything here is declared only; the implementations live out of line.
//
// NOTE: data members are initialized in declaration order. `typeChecker` is a reference and
// `singletonTypes` a NotNull, both presumably bound to state reachable through `frontend`
// (confirm in the constructor), so the relative order of the data members must be preserved.
struct Fixture
{
    explicit Fixture(bool freeze = true, bool prepareAutocomplete = false);
    ~Fixture();

    // ---- Parsing ----

    // Throws Luau::ParseErrors if the parse fails.
    AstStatBlock* parse(const std::string& source, const ParseOptions& parseOptions = {});

    /// Parse with all language extensions enabled
    ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {});
    ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {});
    ParseResult matchParseError(const std::string& source, const std::string& message, std::optional<Location> location = std::nullopt);

    // Verify a parse error occurs and the parse error message has the specified prefix
    ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix);

    // ---- Type checking and linting ----

    CheckResult check(Mode mode, std::string source);
    CheckResult check(const std::string& source);

    LintResult lint(const std::string& source, const std::optional<LintOptions>& lintOptions = {});
    LintResult lintTyped(const std::string& source, const std::optional<LintOptions>& lintOptions = {});

    // ---- Inspecting results ----

    ModulePtr getMainModule();
    SourceModule* getMainSourceModule();

    std::optional<PrimitiveTypeVar::Type> getPrimitiveType(TypeId ty);

    // The get*/find*/lookup* variants return std::optional; the require* variants return a
    // TypeId directly (presumably failing the test when nothing is found — confirm in the definitions).
    std::optional<TypeId> getType(const std::string& name);
    TypeId requireType(const std::string& name);
    TypeId requireType(const ModuleName& moduleName, const std::string& name);
    TypeId requireType(const ModulePtr& module, const std::string& name);
    TypeId requireType(const ScopePtr& scope, const std::string& name);

    std::optional<TypeId> findTypeAtPosition(Position position);
    TypeId requireTypeAtPosition(Position position);
    std::optional<TypeId> findExpectedTypeAtPosition(Position position);

    std::optional<TypeId> lookupType(const std::string& name);
    std::optional<TypeId> lookupImportedType(const std::string& moduleAlias, const std::string& name);

    // ---- State (declaration order is initialization order; do not reorder) ----

    // Flag overrides held as members, so they last as long as the fixture does.
    ScopedFastFlag sff_DebugLuauFreezeArena;
    ScopedFastFlag sff_UnknownNever{"LuauUnknownAndNeverType", true};

    TestFileResolver fileResolver;
    TestConfigResolver configResolver;
    NullModuleResolver moduleResolver;
    std::unique_ptr<SourceModule> sourceModule;
    Frontend frontend;
    InternalErrorReporter ice;
    TypeChecker& typeChecker;
    NotNull<SingletonTypes> singletonTypes;

    // ---- Diagnostics and setup helpers ----

    std::string decorateWithTypes(const std::string& code);

    void dumpErrors(std::ostream& os, const std::vector<TypeError>& errors);
    void dumpErrors(const CheckResult& cr);
    void dumpErrors(const ModulePtr& module);
    void dumpErrors(const Module& module);

    // Used by the LUAU_REQUIRE_* macros at the end of this header.
    void validateErrors(const std::vector<TypeError>& errors);
    std::string getErrors(const CheckResult& cr);

    void registerTestTypes();

    LoadDefinitionFileResult loadDefinition(const std::string& source);
};
// Fixture variant declared with the same constructor parameters as Fixture.
// NOTE(review): judging by the name it additionally registers the builtin definitions —
// confirm in the out-of-line constructor.
struct BuiltinsFixture : Fixture
{
    // `explicit` mirrors Fixture's constructor: the defaulted second parameter makes this
    // callable with a single bool, which would otherwise permit an implicit bool -> fixture conversion.
    explicit BuiltinsFixture(bool freeze = true, bool prepareAutocomplete = false);
};
// Fixture that bundles a ConstraintGraphBuilder with the objects declared alongside it.
// Members are initialized in declaration order by the out-of-line constructor; `cgb` is
// presumably constructed from the members declared before it (confirm in the definition),
// so the order below must be kept.
struct ConstraintGraphBuilderFixture : public Fixture
{
    TypeArena arena;
    ModulePtr mainModule;
    ConstraintGraphBuilder cgb;
    DcrLogger logger;

    // Flag override that lasts as long as the fixture; which flag is forced is decided in the constructor.
    ScopedFastFlag forceTheFlag;

    ConstraintGraphBuilderFixture();
};
// Declared here and defined out of line.
// NOTE(review): presumably converts a textual name into a ModuleName — confirm against the definition.
ModuleName fromString(std::string_view name);
// Looks `name` up in `map` and returns a copy of the mapped value,
// or std::nullopt when the key is not present.
template<typename T>
std::optional<T> get(const std::map<Name, T>& map, const Name& name)
{
    const auto found = map.find(name);
    if (found == map.end())
        return std::nullopt;

    return std::optional<T>(found->second);
}
// Free helper functions for tests. They are only declared in this header; the definitions live elsewhere.
// NOTE(review): the one-line descriptions below are inferred from names and signatures — confirm against the definitions.
// Presumably returns `s` repeated `n` times.
std::string rep(const std::string& s, size_t n);
// Presumably reports whether the type `t` belongs to `arena`.
bool isInArena(TypeId t, const TypeArena& arena);
// Free-function overloads with the same signatures as the Fixture::dumpErrors members that take a module.
void dumpErrors(const ModulePtr& module);
void dumpErrors(const Module& module);
// Debug dump helpers for a named type and for a list of constraints.
void dump(const std::string& name, TypeId ty);
void dump(const std::vector<Constraint>& constraints);
// Returns std::nullopt when no type is found for `name`.
std::optional<TypeId> lookupName(ScopePtr scope, const std::string& name); // Warning: This function runs in O(n**2)
// Presumably scans the bindings of `scope` one by one for `name`.
std::optional<TypeId> linearSearchForBinding(Scope* scope, const char* name);
// One step of an AST query: "the nth node whose class index is classIndex".
// Built by nth<T>() and consumed by FindNthOccurenceOf / query().
struct Nth
{
    // Class index of the node type to match (nth<T>() fills this with T::ClassIndex()).
    int classIndex = 0;
    // 1-based occurrence to select. Defaults to the first one, matching nth<T>()'s default argument.
    int nth = 1;
};
// Builds the query step "the nth occurrence of AST node class T".
// `nth` is 1-based; values below 1 trip the assertion.
template<typename T>
Nth nth(int nth = 1)
{
    static_assert(std::is_base_of_v<AstNode, T>, "T must be a derived class of AstNode");
    LUAU_ASSERT(nth > 0); // Did you mean to use `nth<T>(1)`?

    Nth request;
    request.classIndex = T::ClassIndex();
    request.nth = nth;
    return request;
}
// AST visitor that query() runs over a subtree to locate one requested node.
// Only the interface is declared here; the member functions are defined out of line.
struct FindNthOccurenceOf : public AstVisitor
{
    // The step being searched for (node class index plus 1-based occurrence).
    Nth requestedNth;
    // Running counter; presumably the number of matching nodes seen so far (confirm in checkIt).
    int currentOccurrence = 0;
    // Result slot read by query() after the traversal; a null value is treated as "not found".
    AstNode* theNode = nullptr;

    // `explicit` prevents an Nth from silently converting into a visitor object.
    explicit FindNthOccurenceOf(Nth nth);

    bool checkIt(AstNode* n);

    bool visit(AstNode* n) override;
    bool visit(AstType* n) override;
    bool visit(AstTypePack* n) override;
};
/** DSL for querying the AST.
 *
 * Given an AST, a particular node can be fetched directly instead of unwrapping the tree by hand. For example, with:
 *
 * ```
 * if a and b then
 *     print(a + b)
 * end
 *
 * function f(x, y)
 *     return x + y
 * end
 * ```
 *
 * there are several ways to reach the AstExprBinary inside the function:
 * 1. Luau::query<AstExprBinary>(block, {nth<AstStatFunction>(), nth<AstExprBinary>()})
 * 2. Luau::query<AstExprBinary>(Luau::query<AstStatFunction>(block))
 * 3. Luau::query<AstExprBinary>(block, {nth<AstExprBinary>(2)})
 *
 * NOTE(review): form 3 counts occurrences from the start of the block. If `a and b` is itself an
 * AstExprBinary, the second occurrence would be `a + b` rather than `x + y` — confirm before relying on it.
 *
 * Each entry of `nths` is applied in turn, searching below the node found by the previous entry.
 * The result is nullptr when `node` is null, when any step finds nothing, or when the final node is not a T.
 * With no explicit `nths`, the Nth occurrence of T is looked up (N defaults to 1).
 */
template<typename T, int N = 1>
T* query(AstNode* node, const std::vector<Nth>& nths = {nth<T>(N)})
{
    static_assert(std::is_base_of_v<AstNode, T>, "T must be a derived class of AstNode");

    AstNode* current = node;

    for (const Nth& step : nths)
    {
        // A failed lookup (including a null result fed in from a nested query(query(...)) call)
        // short-circuits here instead of attempting further searches.
        if (current == nullptr)
            return nullptr;

        FindNthOccurenceOf finder{step};
        current->visit(&finder);
        current = finder.theNode;
    }

    if (current == nullptr)
        return nullptr;

    return current->as<T>();
}
} // namespace Luau
// Requires that `result` carries at least one error.
// `result` is evaluated exactly once; its `errors` member is first passed to validateErrors()
// (found by unqualified lookup at the point of use), then checked for non-emptiness with REQUIRE.
// The local is given a long, macro-specific name so that an argument spelled `r` (or any other
// short caller identifier) cannot collide with it.
#define LUAU_REQUIRE_ERRORS(result) \
    do \
    { \
        auto&& luauRequireResult_ = (result); \
        validateErrors(luauRequireResult_.errors); \
        REQUIRE(!luauRequireResult_.errors.empty()); \
    } while (false)
// Requires that `result` carries exactly `count` errors; on failure the text produced by
// getErrors() is attached as the message.
// `result` is evaluated exactly once and its `errors` member is passed to validateErrors() first.
// `count` is parenthesized so that compound arguments (e.g. `a ? 1 : 2`) compare as a whole, and the
// local uses a long, macro-specific name so that an argument spelled `r` cannot collide with it.
#define LUAU_REQUIRE_ERROR_COUNT(count, result) \
    do \
    { \
        auto&& luauRequireResult_ = (result); \
        validateErrors(luauRequireResult_.errors); \
        REQUIRE_MESSAGE((count) == luauRequireResult_.errors.size(), getErrors(luauRequireResult_)); \
    } while (false)
// Requires that `result` carries no errors at all: shorthand for an expected error count of 0.
#define LUAU_REQUIRE_NO_ERRORS(result) \
    LUAU_REQUIRE_ERROR_COUNT(0, result)