// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ConstraintGraphBuilder.h" namespace Luau { Constraint::Constraint(ConstraintV&& c) : c(std::move(c)) { } Constraint::Constraint(ConstraintV&& c, std::vector dependencies) : c(std::move(c)) , dependencies(dependencies) { } std::optional Scope2::lookup(Symbol sym) { Scope2* s = this; while (true) { auto it = s->bindings.find(sym); if (it != s->bindings.end()) return it->second; if (s->parent) s = s->parent; else return std::nullopt; } } ConstraintGraphBuilder::ConstraintGraphBuilder(TypeArena* arena) : singletonTypes(getSingletonTypes()) , arena(arena) , rootScope(nullptr) { LUAU_ASSERT(arena); } TypeId ConstraintGraphBuilder::freshType(Scope2* scope) { LUAU_ASSERT(scope); return arena->addType(FreeTypeVar{scope}); } TypePackId ConstraintGraphBuilder::freshTypePack(Scope2* scope) { LUAU_ASSERT(scope); FreeTypePack f{scope}; return arena->addTypePack(TypePackVar{std::move(f)}); } Scope2* ConstraintGraphBuilder::childScope(Location location, Scope2* parent) { LUAU_ASSERT(parent); auto scope = std::make_unique(); Scope2* borrow = scope.get(); scopes.emplace_back(location, std::move(scope)); borrow->parent = parent; borrow->returnType = parent->returnType; parent->children.push_back(borrow); return borrow; } void ConstraintGraphBuilder::addConstraint(Scope2* scope, ConstraintV cv) { LUAU_ASSERT(scope); scope->constraints.emplace_back(new Constraint{std::move(cv)}); } void ConstraintGraphBuilder::addConstraint(Scope2* scope, std::unique_ptr c) { LUAU_ASSERT(scope); scope->constraints.emplace_back(std::move(c)); } void ConstraintGraphBuilder::visit(AstStatBlock* block) { LUAU_ASSERT(scopes.empty()); LUAU_ASSERT(rootScope == nullptr); scopes.emplace_back(block->location, std::make_unique()); rootScope = scopes.back().second.get(); rootScope->returnType = freshTypePack(rootScope); visit(rootScope, block); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStat* stat) { LUAU_ASSERT(scope); if (auto s = stat->as()) visit(scope, s); else if (auto s = stat->as()) visit(scope, s); else if (auto f = stat->as()) visit(scope, f); else if (auto r = stat->as()) visit(scope, r); else LUAU_ASSERT(0); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocal* local) { LUAU_ASSERT(scope); std::vector varTypes; for (AstLocal* local : local->vars) { // TODO annotations TypeId ty = freshType(scope); varTypes.push_back(ty); scope->bindings[local] = ty; } for (size_t i = 0; i < local->vars.size; ++i) { if (i < local->values.size) { TypeId exprType = check(scope, local->values.data[i]); addConstraint(scope, SubtypeConstraint{varTypes[i], exprType}); } } } void addConstraints(Constraint* constraint, Scope2* scope) { LUAU_ASSERT(scope); scope->constraints.reserve(scope->constraints.size() + scope->constraints.size()); for (const auto& c : scope->constraints) constraint->dependencies.push_back(c.get()); for (Scope2* childScope : scope->children) addConstraints(constraint, childScope); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatLocalFunction* function) { LUAU_ASSERT(scope); // Local // Global // Dotted path // Self? TypeId functionType = nullptr; auto ty = scope->lookup(function->name); LUAU_ASSERT(!ty.has_value()); // The parser ensures that every local function has a distinct Symbol for its name. functionType = freshType(scope); scope->bindings[function->name] = functionType; Scope2* innerScope = childScope(function->func->body->location, scope); TypePackId returnType = freshTypePack(scope); innerScope->returnType = returnType; std::vector argTypes; for (AstLocal* local : function->func->args) { TypeId t = freshType(innerScope); argTypes.push_back(t); innerScope->bindings[local] = t; // TODO annotations } for (AstStat* stat : function->func->body->body) visit(innerScope, stat); FunctionTypeVar actualFunction{arena->addTypePack(argTypes), returnType}; TypeId actualFunctionType = arena->addType(std::move(actualFunction)); std::unique_ptr c{new Constraint{GeneralizationConstraint{functionType, actualFunctionType, innerScope}}}; addConstraints(c.get(), innerScope); addConstraint(scope, std::move(c)); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatReturn* ret) { LUAU_ASSERT(scope); TypePackId exprTypes = checkPack(scope, ret->list); addConstraint(scope, PackSubtypeConstraint{exprTypes, scope->returnType}); } void ConstraintGraphBuilder::visit(Scope2* scope, AstStatBlock* block) { LUAU_ASSERT(scope); for (AstStat* stat : block->body) visit(scope, stat); } TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstArray exprs) { LUAU_ASSERT(scope); if (exprs.size == 0) return arena->addTypePack({}); std::vector types; TypePackId last = nullptr; for (size_t i = 0; i < exprs.size; ++i) { if (i < exprs.size - 1) types.push_back(check(scope, exprs.data[i])); else last = checkPack(scope, exprs.data[i]); } LUAU_ASSERT(last != nullptr); return arena->addTypePack(TypePack{std::move(types), last}); } TypePackId ConstraintGraphBuilder::checkPack(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); // TEMP TEMP TEMP HACK HACK HACK FIXME FIXME TypeId t = check(scope, expr); return arena->addTypePack({t}); } TypeId ConstraintGraphBuilder::check(Scope2* scope, AstExpr* expr) { LUAU_ASSERT(scope); if (auto a = expr->as()) return singletonTypes.stringType; else if (auto a = expr->as()) return singletonTypes.numberType; else if (auto a = expr->as()) return singletonTypes.booleanType; else if (auto a = expr->as()) return singletonTypes.nilType; else if (auto a = expr->as()) { std::optional ty = scope->lookup(a->local); if (ty) return *ty; else return singletonTypes.errorRecoveryType(singletonTypes.anyType); // FIXME? Record an error at this point? } else if (auto a = expr->as()) { std::vector args; for (AstExpr* arg : a->args) { args.push_back(check(scope, arg)); } TypeId fnType = check(scope, a->func); TypeId instantiatedType = freshType(scope); addConstraint(scope, InstantiationConstraint{instantiatedType, fnType}); TypeId firstRet = freshType(scope); TypePackId rets = arena->addTypePack(TypePack{{firstRet}, arena->addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); FunctionTypeVar ftv(arena->addTypePack(TypePack{args, {}}), rets); TypeId inferredFnType = arena->addType(ftv); addConstraint(scope, SubtypeConstraint{inferredFnType, instantiatedType}); return firstRet; } else { LUAU_ASSERT(0); return freshType(scope); } } static void collectConstraints(std::vector& result, Scope2* scope) { for (const auto& c : scope->constraints) result.push_back(c.get()); for (Scope2* child : scope->children) collectConstraints(result, child); } std::vector collectConstraints(Scope2* rootScope) { std::vector result; collectConstraints(result, rootScope); return result; } } // namespace Luau