// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once #include #include "Luau/DenseHash.h" #include "Luau/RecursionCounter.h" #include "Luau/TypePack.h" #include "Luau/Type.h" LUAU_FASTINT(LuauVisitRecursionLimit) namespace Luau { namespace visit_detail { /** * Apply f(tid, t, seen) if doing so would pass type checking, else apply f(tid, t) * * We do this to permit (but not require) Type visitors to accept the seen set as an argument. */ template auto apply(A tid, const B& t, C& c, F& f) -> decltype(f(tid, t, c)) { return f(tid, t, c); } template auto apply(A tid, const B& t, C&, F& f) -> decltype(f(tid, t)) { return f(tid, t); } inline bool hasSeen(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); return !seen.insert(ttv).second; } inline bool hasSeen(DenseHashSet& seen, const void* tv) { void* ttv = const_cast(tv); if (seen.contains(ttv)) return true; seen.insert(ttv); return false; } inline void unsee(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); seen.erase(ttv); } inline void unsee(DenseHashSet& seen, const void* tv) { // When DenseHashSet is used for 'visitTypeOnce', where don't forget visited elements } } // namespace visit_detail template struct GenericTypeVisitor { using Set = S; Set seen; bool skipBoundTypes = false; int recursionCounter = 0; GenericTypeVisitor() = default; explicit GenericTypeVisitor(Set seen, bool skipBoundTypes = false) : seen(std::move(seen)) , skipBoundTypes(skipBoundTypes) { } virtual void cycle(TypeId) {} virtual void cycle(TypePackId) {} virtual bool visit(TypeId ty) { return true; } virtual bool visit(TypeId ty, const BoundType& btv) { return visit(ty); } virtual bool visit(TypeId ty, const FreeType& ftv) { return visit(ty); } virtual bool visit(TypeId ty, const GenericType& gtv) { return visit(ty); } virtual bool visit(TypeId ty, const ErrorType& etv) { return visit(ty); } virtual bool visit(TypeId ty, const PrimitiveType& ptv) { return visit(ty); } virtual bool visit(TypeId ty, const FunctionType& ftv) { return visit(ty); } virtual bool visit(TypeId ty, const TableType& ttv) { return visit(ty); } virtual bool visit(TypeId ty, const MetatableType& mtv) { return visit(ty); } virtual bool visit(TypeId ty, const ClassType& ctv) { return visit(ty); } virtual bool visit(TypeId ty, const AnyType& atv) { return visit(ty); } virtual bool visit(TypeId ty, const UnknownType& utv) { return visit(ty); } virtual bool visit(TypeId ty, const NeverType& ntv) { return visit(ty); } virtual bool visit(TypeId ty, const UnionType& utv) { return visit(ty); } virtual bool visit(TypeId ty, const IntersectionType& itv) { return visit(ty); } virtual bool visit(TypeId ty, const BlockedType& btv) { return visit(ty); } virtual bool visit(TypeId ty, const PendingExpansionType& petv) { return visit(ty); } virtual bool visit(TypeId ty, const SingletonType& stv) { return visit(ty); } virtual bool visit(TypeId ty, const NegationType& ntv) { return visit(ty); } virtual bool visit(TypePackId tp) { return true; } virtual bool visit(TypePackId tp, const BoundTypePack& btp) { return visit(tp); } virtual bool visit(TypePackId tp, const FreeTypePack& ftp) { return visit(tp); } virtual bool visit(TypePackId tp, const GenericTypePack& gtp) { return visit(tp); } virtual bool visit(TypePackId tp, const Unifiable::Error& etp) { return visit(tp); } virtual bool visit(TypePackId tp, const TypePack& pack) { return visit(tp); } virtual bool visit(TypePackId tp, const VariadicTypePack& vtp) { return visit(tp); } virtual bool visit(TypePackId tp, const BlockedTypePack& btp) { return visit(tp); } void traverse(TypeId ty) { RecursionLimiter limiter{&recursionCounter, FInt::LuauVisitRecursionLimit}; if (visit_detail::hasSeen(seen, ty)) { cycle(ty); return; } if (auto btv = get(ty)) { if (skipBoundTypes) traverse(btv->boundTo); else if (visit(ty, *btv)) traverse(btv->boundTo); } else if (auto ftv = get(ty)) visit(ty, *ftv); else if (auto gtv = get(ty)) visit(ty, *gtv); else if (auto etv = get(ty)) visit(ty, *etv); else if (auto ptv = get(ty)) visit(ty, *ptv); else if (auto ftv = get(ty)) { if (visit(ty, *ftv)) { traverse(ftv->argTypes); traverse(ftv->retTypes); } } else if (auto ttv = get(ty)) { // Some visitors want to see bound tables, that's why we traverse the original type if (skipBoundTypes && ttv->boundTo) { traverse(*ttv->boundTo); } else if (visit(ty, *ttv)) { if (ttv->boundTo) { traverse(*ttv->boundTo); } else { for (auto& [_name, prop] : ttv->props) traverse(prop.type); if (ttv->indexer) { traverse(ttv->indexer->indexType); traverse(ttv->indexer->indexResultType); } } } } else if (auto mtv = get(ty)) { if (visit(ty, *mtv)) { traverse(mtv->table); traverse(mtv->metatable); } } else if (auto ctv = get(ty)) { if (visit(ty, *ctv)) { for (const auto& [name, prop] : ctv->props) traverse(prop.type); if (ctv->parent) traverse(*ctv->parent); if (ctv->metatable) traverse(*ctv->metatable); } } else if (auto atv = get(ty)) visit(ty, *atv); else if (auto utv = get(ty)) { if (visit(ty, *utv)) { for (TypeId optTy : utv->options) traverse(optTy); } } else if (auto itv = get(ty)) { if (visit(ty, *itv)) { for (TypeId partTy : itv->parts) traverse(partTy); } } else if (get(ty)) { // Visiting into LazyType may necessarily cause infinite expansion, so we don't do that on purpose. // Asserting also makes no sense, because the type _will_ happen here, most likely as a property of some ClassType // that doesn't need to be expanded. } else if (auto stv = get(ty)) visit(ty, *stv); else if (auto btv = get(ty)) visit(ty, *btv); else if (auto utv = get(ty)) visit(ty, *utv); else if (auto ntv = get(ty)) visit(ty, *ntv); else if (auto petv = get(ty)) { if (visit(ty, *petv)) { for (TypeId a : petv->typeArguments) traverse(a); for (TypePackId a : petv->packArguments) traverse(a); } } else if (auto ntv = get(ty)) { if (visit(ty, *ntv)) traverse(ntv->ty); } else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypeId) is not exhaustive!"); visit_detail::unsee(seen, ty); } void traverse(TypePackId tp) { if (visit_detail::hasSeen(seen, tp)) { cycle(tp); return; } if (auto btv = get(tp)) { if (visit(tp, *btv)) traverse(btv->boundTo); } else if (auto ftv = get(tp)) visit(tp, *ftv); else if (auto gtv = get(tp)) visit(tp, *gtv); else if (auto etv = get(tp)) visit(tp, *etv); else if (auto pack = get(tp)) { bool res = visit(tp, *pack); if (res) { for (TypeId ty : pack->head) traverse(ty); if (pack->tail) traverse(*pack->tail); } } else if (auto pack = get(tp)) { bool res = visit(tp, *pack); if (res) traverse(pack->ty); } else if (auto btp = get(tp)) visit(tp, *btp); else LUAU_ASSERT(!"GenericTypeVisitor::traverse(TypePackId) is not exhaustive!"); visit_detail::unsee(seen, tp); } }; /** Visit each type under a given type. Skips over cycles and keeps recursion depth under control. * * The same type may be visited multiple times if there are multiple distinct paths to it. If this is undesirable, use * TypeOnceVisitor. */ struct TypeVisitor : GenericTypeVisitor> { explicit TypeVisitor(bool skipBoundTypes = false) : GenericTypeVisitor{{}, skipBoundTypes} { } }; /// Visit each type under a given type. Each type will only be checked once even if there are multiple paths to it. struct TypeOnceVisitor : GenericTypeVisitor> { explicit TypeOnceVisitor(bool skipBoundTypes = false) : GenericTypeVisitor{DenseHashSet{nullptr}, skipBoundTypes} { } }; } // namespace Luau