如何找到一个类的所有子类给它的名字?

我需要一种工作方法来获取从Python基类继承的所有类。

155988 次浏览

新风格的类(即从object继承的子类,这是Python 3中的默认值)有一个__subclasses__方法,该方法返回子类:

class Foo(object): pass
class Bar(Foo): pass
class Baz(Foo): pass
class Bing(Bar): pass

下面是子类的名称:

print([cls.__name__ for cls in Foo.__subclasses__()])
# ['Bar', 'Baz']

下面是子类本身:

print(Foo.__subclasses__())
# [<class '__main__.Bar'>, <class '__main__.Baz'>]

确认子类确实列出Foo作为它们的基类:

for cls in Foo.__subclasses__():
print(cls.__base__)
# <class '__main__.Foo'>
# <class '__main__.Foo'>

注意,如果你想要子类,你必须递归:

def all_subclasses(cls):
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)])


print(all_subclasses(Foo))
# {<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>}

注意,如果一个子类的类定义还没有被执行——例如,如果子类的模块还没有被导入——那么这个子类还不存在,__subclasses__将找不到它。


你提到了“以其名字命名”。由于Python类是一级对象,所以不需要使用带有类名的字符串来代替类或类似的东西。您可以直接使用该类,而且您可能应该这样做。

如果你确实有一个表示类名的字符串,并且你想要找到该类的子类,那么有两个步骤:找到给定其名称的类,然后像上面那样找到带有__subclasses__的子类。

如何从名称中找到类取决于您希望在哪里找到它。如果您希望在与试图定位类的代码相同的模块中找到它,那么

cls = globals()[name]

会起作用,或者在不太可能的情况下,你期望在当地人身上找到它,

cls = locals()[name]

如果类可以在任何模块中,那么你的名称字符串应该包含完全限定名——比如'pkg.module.Foo',而不仅仅是'Foo'。使用importlib加载类的模块,然后检索相应的属性:

import importlib
modname, _, clsname = name.rpartition('.')
mod = importlib.import_module(modname)
cls = getattr(mod, clsname)

无论你如何找到这个类,cls.__subclasses__()都会返回它的子类列表。

这个答案不如使用@unutbu提到的特殊内置__subclasses__()类方法好,所以我只是把它作为一个练习。定义的subclasses()函数返回一个字典,该字典将所有子类名称映射到子类本身。

def traced_subclass(baseclass):
class _SubclassTracer(type):
def __new__(cls, classname, bases, classdict):
obj = type(classname, bases, classdict)
if baseclass in bases: # sanity check
attrname = '_%s__derived' % baseclass.__name__
derived = getattr(baseclass, attrname, {})
derived.update( {classname:obj} )
setattr(baseclass, attrname, derived)
return obj
return _SubclassTracer


def subclasses(baseclass):
attrname = '_%s__derived' % baseclass.__name__
return getattr(baseclass, attrname, None)




class BaseClass(object):
pass


class SubclassA(BaseClass):
__metaclass__ = traced_subclass(BaseClass)


class SubclassB(BaseClass):
__metaclass__ = traced_subclass(BaseClass)


print subclasses(BaseClass)

输出:

{'SubclassB': <class '__main__.SubclassB'>,
'SubclassA': <class '__main__.SubclassA'>}

如果你只想要直接的子类,那么.__subclasses__()就可以了。如果你想要所有的子类,子类的子类等等,你需要一个函数来为你做这些。

下面是一个简单易读的函数,它可以递归地找到给定类的所有子类:

def get_all_subclasses(cls):
all_subclasses = []


for subclass in cls.__subclasses__():
all_subclasses.append(subclass)
all_subclasses.extend(get_all_subclasses(subclass))


return all_subclasses

注意:我看到有人(不是@unutbu)改变了引用的答案,使它不再使用vars()['Foo'] -所以我的帖子的主要观点不再适用。

FWIW,这就是我所说的@unutbu的回答只使用本地定义的类—并且使用eval()而不是vars()将使它适用于任何可访问的类,而不仅仅是当前作用域中定义的类。

对于那些不喜欢使用eval()的人,也显示了一种避免使用它的方法。

首先,这里有一个具体的例子,演示了使用vars()的潜在问题:

class Foo(object): pass
class Bar(Foo): pass
class Baz(Foo): pass
class Bing(Bar): pass


# unutbu's approach
def all_subclasses(cls):
return cls.__subclasses__() + [g for s in cls.__subclasses__()
for g in all_subclasses(s)]


print(all_subclasses(vars()['Foo']))  # Fine because  Foo is in scope
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]


def func():  # won't work because Foo class is not locally defined
print(all_subclasses(vars()['Foo']))


try:
func()  # not OK because Foo is not local to func()
except Exception as e:
print('calling func() raised exception: {!r}'.format(e))
# -> calling func() raised exception: KeyError('Foo',)


print(all_subclasses(eval('Foo')))  # OK
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]


# using eval('xxx') instead of vars()['xxx']
def func2():
print(all_subclasses(eval('Foo')))


func2()  # Works
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]

这可以通过将eval('ClassName')向下移动到定义的函数中来改进,这使得它更容易使用,而不会失去使用eval()所获得的额外通用性,与vars()不同,eval()不是上下文敏感的:

# easier to use version
def all_subclasses2(classname):
direct_subclasses = eval(classname).__subclasses__()
return direct_subclasses + [g for s in direct_subclasses
for g in all_subclasses2(s.__name__)]


# pass 'xxx' instead of eval('xxx')
def func_ez():
print(all_subclasses2('Foo'))  # simpler


func_ez()
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]

最后,出于安全原因避免使用eval()是可能的,在某些情况下甚至是重要的,所以这里有一个不使用它的版本:

def get_all_subclasses(cls):
""" Generator of all a class's subclasses. """
try:
for subclass in cls.__subclasses__():
yield subclass
for subclass in get_all_subclasses(subclass):
yield subclass
except TypeError:
return


def all_subclasses3(classname):
for cls in get_all_subclasses(object):  # object is base of all new-style classes.
if cls.__name__.split('.')[-1] == classname:
break
else:
raise ValueError('class %s not found' % classname)
direct_subclasses = cls.__subclasses__()
return direct_subclasses + [g for s in direct_subclasses
for g in all_subclasses3(s.__name__)]


# no eval('xxx')
def func3():
print(all_subclasses3('Foo'))


func3()  # Also works
# -> [<class '__main__.Bar'>, <class '__main__.Baz'>, <class '__main__.Bing'>]

一般形式的最简单解:

def get_subclasses(cls):
for subclass in cls.__subclasses__():
yield from get_subclasses(subclass)
yield subclass

和类方法,如果你有一个单一的类,你继承:

@classmethod
def get_subclasses(cls):
for subclass in cls.__subclasses__():
yield from subclass.get_subclasses()
yield subclass

获取所有子类列表的一个更短的版本:

from itertools import chain


def subclasses(cls):
return list(
chain.from_iterable(
[list(chain.from_iterable([[x], subclasses(x)])) for x in cls.__subclasses__()]
)
)

__abc1 - __abc0

正如其他回答提到的,你可以检查__subclasses__属性来获得子类的列表,因为python 3.6你可以通过覆盖__init_subclass__方法来修改这个属性的创建。

class PluginBase:
subclasses = []


def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.subclasses.append(cls)


class Plugin1(PluginBase):
pass


class Plugin2(PluginBase):
pass

这样,如果你知道你在做什么,你可以重写__subclasses__的行为,并从这个列表中省略/添加子类。

下面是一个没有递归的版本:

def get_subclasses_gen(cls):


def _subclasses(classes, seen):
while True:
subclasses = sum((x.__subclasses__() for x in classes), [])
yield from classes
yield from seen
found = []
if not subclasses:
return


classes = subclasses
seen = found


return _subclasses([cls], [])
这与其他实现的不同之处在于它返回原始的类。 这是因为它使代码更简单,并且:

class Ham(object):
pass


assert(issubclass(Ham, Ham)) # True

如果get_subclasses_gen看起来有点奇怪,那是因为它是通过将尾递归实现转换为循环生成器创建的:

def get_subclasses(cls):


def _subclasses(classes, seen):
subclasses = sum(*(frozenset(x.__subclasses__()) for x in classes))
found = classes + seen
if not subclasses:
return found


return _subclasses(subclasses, found)


return _subclasses([cls], [])

我怎么能找到一个类的所有子类给它的名字?

我们当然可以很容易地做到这一点,只要能访问对象本身。

仅仅给出它的名字是一个糟糕的想法,因为可以有多个同名的类,甚至在同一个模块中定义。

我为另一个回答创建了一个实现,因为它回答了这个问题,而且它比这里的其他解决方案更优雅,所以它是:

def get_subclasses(cls):
"""returns all subclasses of argument, cls"""
if issubclass(cls, type):
subclasses = cls.__subclasses__(cls)
else:
subclasses = cls.__subclasses__()
for subclass in subclasses:
subclasses.extend(get_subclasses(subclass))
return subclasses

用法:

>>> import pprint
>>> list_of_classes = get_subclasses(int)
>>> pprint.pprint(list_of_classes)
[<class 'bool'>,
<enum 'IntEnum'>,
<enum 'IntFlag'>,
<class 'sre_constants._NamedIntConstant'>,
<class 'subprocess.Handle'>,
<enum '_ParameterKind'>,
<enum 'Signals'>,
<enum 'Handlers'>,
<enum 'RegexFlag'>]

下面是一个简单但有效的代码版本:

def get_all_subclasses(cls):
subclass_list = []


def recurse(klass):
for subclass in klass.__subclasses__():
subclass_list.append(subclass)
recurse(subclass)


recurse(cls)


return set(subclass_list)

它的时间复杂度是O(n),其中n是所有子类的个数,如果没有多重继承的话。 它比递归地创建列表或使用生成器生成类的函数更有效,后者的复杂度可能是(1)当类层次结构是平衡树时O(nlogn),或(2)当类层次结构是有偏树时O(n^2)

虽然我非常倾向于__init_subclass__方法,但这将保留定义顺序,并避免组合增长顺序,如果你有一个非常密集的层次结构,到处都有多个继承:

def descendents(cls):
'''Does not return the class itself'''
R = {}
def visit(cls):
for subCls in cls.__subclasses__():
R[subCls] = True
visit(subCls)
visit(cls)
return list(R.keys())

这是因为字典会记住键的插入顺序。列表方法也会起作用。