// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include #include "Luau/DenseHash.h" #include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" LUAU_FASTINT(LuauVisitRecursionLimit) LUAU_FASTFLAG(LuauNormalizeFlagIsConservative) namespace Luau { namespace visit_detail { /** * Apply f(tid, t, seen) if doing so would pass type checking, else apply f(tid, t) * * We do this to permit (but not require) TypeVar visitors to accept the seen set as an argument. */ template auto apply(A tid, const B& t, C& c, F& f) -> decltype(f(tid, t, c)) { return f(tid, t, c); } template auto apply(A tid, const B& t, C&, F& f) -> decltype(f(tid, t)) { return f(tid, t); } inline bool hasSeen(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); return !seen.insert(ttv).second; } inline bool hasSeen(DenseHashSet& seen, const void* tv) { void* ttv = const_cast(tv); if (seen.contains(ttv)) return true; seen.insert(ttv); return false; } inline void unsee(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); seen.erase(ttv); } inline void unsee(DenseHashSet& seen, const void* tv) { // When DenseHashSet is used for 'visitTypeVarOnce', where don't forget visited elements } } // namespace visit_detail template struct GenericTypeVarVisitor { using Set = S; Set seen; int recursionCounter = 0; GenericTypeVarVisitor() = default; explicit GenericTypeVarVisitor(Set seen) : seen(std::move(seen)) { } virtual void cycle(TypeId) {} virtual void cycle(TypePackId) {} virtual bool visit(TypeId ty) { return true; } virtual bool visit(TypeId ty, const BoundTypeVar& btv) { return visit(ty); } virtual bool visit(TypeId ty, const FreeTypeVar& ftv) { return visit(ty); } virtual bool visit(TypeId ty, const GenericTypeVar& gtv) { return visit(ty); } virtual bool visit(TypeId ty, const ErrorTypeVar& etv) { return visit(ty); } virtual bool visit(TypeId ty, const ConstrainedTypeVar& ctv) { return visit(ty); } virtual bool visit(TypeId ty, const PrimitiveTypeVar& ptv) { return visit(ty); } virtual bool visit(TypeId ty, const FunctionTypeVar& ftv) { return visit(ty); } virtual bool visit(TypeId ty, const TableTypeVar& ttv) { return visit(ty); } virtual bool visit(TypeId ty, const MetatableTypeVar& mtv) { return visit(ty); } virtual bool visit(TypeId ty, const ClassTypeVar& ctv) { return visit(ty); } virtual bool visit(TypeId ty, const AnyTypeVar& atv) { return visit(ty); } virtual bool visit(TypeId ty, const UnionTypeVar& utv) { return visit(ty); } virtual bool visit(TypeId ty, const IntersectionTypeVar& itv) { return visit(ty); } virtual bool visit(TypePackId tp) { return true; } virtual bool visit(TypePackId tp, const BoundTypePack& btp) { return visit(tp); } virtual bool visit(TypePackId tp, const FreeTypePack& ftp) { return visit(tp); } virtual bool visit(TypePackId tp, const GenericTypePack& gtp) { return visit(tp); } virtual bool visit(TypePackId tp, const Unifiable::Error& etp) { return visit(tp); } virtual bool visit(TypePackId tp, const TypePack& pack) { return visit(tp); } virtual bool visit(TypePackId tp, const VariadicTypePack& vtp) { return visit(tp); } void traverse(TypeId ty) { RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit}; if (visit_detail::hasSeen(seen, ty)) { cycle(ty); return; } if (auto btv = get(ty)) { if (visit(ty, *btv)) traverse(btv->boundTo); } else if (auto ftv = get(ty)) visit(ty, *ftv); else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) visit(ty, *etv); else if (auto ctv = get(ty)) { if (visit(ty, *ctv)) { for (TypeId part : ctv->parts) traverse(part); } } else if (auto ptv = get(ty)) visit(ty, *ptv); else if (auto ftv = get(ty)) { if (visit(ty, *ftv)) { traverse(ftv->argTypes); traverse(ftv->retTypes); } } else if (auto ttv = get(ty)) { // Some visitors want to see bound tables, that's why we traverse the original type if (visit(ty, *ttv)) { if (ttv->boundTo) { traverse(*ttv->boundTo); } else { for (auto& [_name, prop] : ttv->props) traverse(prop.type); if (ttv->indexer) { traverse(ttv->indexer->indexType); traverse(ttv->indexer->indexResultType); } } } } else if (auto mtv = get(ty)) { if (visit(ty, *mtv)) { traverse(mtv->table); traverse(mtv->metatable); } } else if (auto ctv = get(ty)) { if (visit(ty, *ctv)) { for (const auto& [name, prop] : ctv->props) traverse(prop.type); if (ctv->parent) traverse(*ctv->parent); if (ctv->metatable) traverse(*ctv->metatable); } } else if (auto atv = get(ty)) visit(ty, *atv); else if (auto utv = get(ty)) { if (visit(ty, *utv)) { for (TypeId optTy : utv->options) traverse(optTy); } } else if (auto itv = get(ty)) { if (visit(ty, *itv)) { for (TypeId partTy : itv->parts) traverse(partTy); } } visit_detail::unsee(seen, ty); } void traverse(TypePackId tp) { if (visit_detail::hasSeen(seen, tp)) { cycle(tp); return; } if (auto btv = get(tp)) { if (visit(tp, *btv)) traverse(btv->boundTo); } else if (auto ftv = get(tp)) visit(tp, *ftv); else if (auto gtv = get(tp)) visit(tp, *gtv); else if (auto etv = get(tp)) visit(tp, *etv); else if (auto pack = get(tp)) { bool res = visit(tp, *pack); if (!FFlag::LuauNormalizeFlagIsConservative || res) { for (TypeId ty : pack->head) traverse(ty); if (pack->tail) traverse(*pack->tail); } } else if (auto pack = get(tp)) { bool res = visit(tp, *pack); if (!FFlag::LuauNormalizeFlagIsConservative || res) traverse(pack->ty); } else LUAU_ASSERT(!"GenericTypeVarVisitor::traverse(TypePackId) is not exhaustive!"); visit_detail::unsee(seen, tp); } }; /** Visit each type under a given type. Skips over cycles and keeps recursion depth under control. * * The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use * TypeVarOnceVisitor. */ struct TypeVarVisitor : GenericTypeVarVisitor> { }; /// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. struct TypeVarOnceVisitor : GenericTypeVarVisitor> { TypeVarOnceVisitor() : GenericTypeVarVisitor{DenseHashSet{nullptr}} { } }; } // namespace Luau