structs.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import itertools
  2. from .compat import collections_abc
  class DirectedGraph(object):
      """A graph structure with directed edges.

      Every vertex has an entry in two adjacency tables: ``_forwards`` maps a
      vertex to the vertices it points at (its children), and ``_backwards``
      maps a vertex to the vertices pointing at it (its parents). Each edge is
      recorded in both tables so children and parents can each be looked up
      directly.
      """

      def __init__(self):
          self._vertices = set()
          self._forwards = {}  # vertex -> set of child vertices
          self._backwards = {}  # vertex -> set of parent vertices

      def __iter__(self):
          """Iterate over the vertices (set order, i.e. unspecified)."""
          return iter(self._vertices)

      def __len__(self):
          """Return the number of vertices."""
          vertex_count = len(self._vertices)
          return vertex_count

      def __contains__(self, key):
          """Return whether *key* is a vertex of this graph."""
          return key in self._vertices

      def copy(self):
          """Return a shallow copy of this graph.

          The vertex objects are shared, but the vertex set and every adjacency
          set are duplicated. Adding or removing vertices or edges on the copy
          therefore leaves the original untouched.
          """
          clone = DirectedGraph()
          clone._vertices = self._vertices.copy()
          clone._forwards = dict(
              (vertex, set(children)) for vertex, children in self._forwards.items()
          )
          clone._backwards = dict(
              (vertex, set(parents)) for vertex, parents in self._backwards.items()
          )
          return clone

      def add(self, key):
          """Add a new vertex to the graph.

          Raises ValueError if *key* is already a vertex.
          """
          if key in self._vertices:
              raise ValueError("vertex exists")
          self._forwards[key] = set()
          self._backwards[key] = set()
          self._vertices.add(key)

      def remove(self, key):
          """Remove a vertex from the graph, disconnecting all edges from/to it.

          Raises KeyError if *key* is not a vertex.
          """
          self._vertices.remove(key)
          # Unlink outgoing edges before popping the parent table of *key*.
          # With a self-loop (key -> key), the first loop then removes *key*
          # from its own parent set while that set still exists.
          for child in self._forwards.pop(key):
              self._backwards[child].remove(key)
          for parent in self._backwards.pop(key):
              self._forwards[parent].remove(key)

      def connected(self, f, t):
          """Return whether there is an edge from *f* to *t*."""
          if f not in self._backwards[t]:
              return False
          return t in self._forwards[f]

      def connect(self, f, t):
          """Connect two existing vertices.

          Nothing happens if the vertices are already connected. Raises
          KeyError if either vertex is missing.
          """
          # Check *t* explicitly. A missing *f* fails on the first lookup below
          # before anything is mutated. A missing *t* would otherwise only be
          # noticed after the forward edge had already been recorded.
          if t not in self._vertices:
              raise KeyError(t)
          children = self._forwards[f]
          children.add(t)
          self._backwards[t].add(f)

      def iter_edges(self):
          """Yield every edge as a ``(from, to)`` pair."""
          for source, targets in self._forwards.items():
              for target in targets:
                  yield source, target

      def iter_children(self, key):
          """Return an iterator over the vertices *key* points at."""
          return iter(self._forwards[key])

      def iter_parents(self, key):
          """Return an iterator over the vertices pointing at *key*."""
          return iter(self._backwards[key])
  class IteratorMapping(collections_abc.Mapping):
      """A read-only mapping whose values are iterators.

      Each key of *mapping* maps to ``accessor(mapping[key])`` followed by any
      extra items stored under the same key in *appends*. Keys that appear
      only in *appends* map to those extra items alone.
      """

      def __init__(self, mapping, accessor, appends=None):
          self._mapping = mapping
          self._accessor = accessor
          self._appends = appends or {}

      def __contains__(self, key):
          return key in self._mapping or key in self._appends

      def __getitem__(self, k):
          try:
              v = self._mapping[k]
          except KeyError:
              # Not in the primary mapping: fall back to the appended items.
              # A key missing from both raises KeyError here, as a Mapping should.
              return iter(self._appends[k])
          return itertools.chain(self._accessor(v), self._appends.get(k, ()))

      def __iter__(self):
          more = (k for k in self._appends if k not in self._mapping)
          return itertools.chain(self._mapping, more)

      def __len__(self):
          # len() does not accept a generator, so count the append-only keys
          # explicitly rather than calling len() on a generator expression.
          more = sum(1 for k in self._appends if k not in self._mapping)
          return len(self._mapping) + more

      def __bool__(self):
          # Non-empty iff either source has at least one key; this avoids a
          # full count of the append-only keys.
          return bool(self._mapping or self._appends)

      __nonzero__ = __bool__  # XXX: Python 2.
  class _FactoryIterableView(object):
      """Wrap an iterator factory returned by `find_matches()`.

      Each call to `iter()` on the view invokes the factory again, so the view
      can be iterated any number of times. It offers none of the random-access
      methods of built-in Python sequences, though. Truth testing and `repr()`
      also call the factory each time; nothing is cached.
      """

      def __init__(self, factory):
          self._factory = factory

      def __repr__(self):
          # Draw a fresh iterator from the factory purely for display.
          contents = list(self._factory())
          return "{}({})".format(type(self).__name__, contents)

      def __bool__(self):
          # The view is truthy iff a fresh iterator yields at least one item.
          try:
              next(self._factory())
          except StopIteration:
              return False
          else:
              return True

      __nonzero__ = __bool__  # XXX: Python 2.

      def __iter__(self):
          fresh = self._factory()
          return fresh
  class _SequenceIterableView(object):
      """Wrap an iterable returned by find_matches().

      This is a thin proxy around the underlying sequence. It exposes the same
      interface as `_FactoryIterableView` (iteration, truth testing and
      `repr()`), delegating each operation directly to the sequence.
      """

      def __init__(self, sequence):
          self._sequence = sequence

      def __repr__(self):
          class_name = type(self).__name__
          return "{}({})".format(class_name, self._sequence)

      def __bool__(self):
          # Truthiness mirrors the wrapped sequence (empty -> False).
          return bool(self._sequence)

      __nonzero__ = __bool__  # XXX: Python 2.

      def __iter__(self):
          iterator = iter(self._sequence)
          return iterator
  def build_iter_view(matches):
      """Build an iterable view from the value returned by `find_matches()`.

      A callable is treated as an iterator factory and wrapped in
      ``_FactoryIterableView``. Anything else is wrapped in
      ``_SequenceIterableView``, after first being materialized into a list
      when it is not already a ``Sequence`` (e.g. a one-shot iterator).
      """
      if not callable(matches):
          sequence = (
              matches
              if isinstance(matches, collections_abc.Sequence)
              else list(matches)
          )
          return _SequenceIterableView(sequence)
      return _FactoryIterableView(matches)