Disjoint Set Union (DSU)
- 2021.11.29
- Disjoint Set Union
為 disjoint (non-overlapping) sets 設計的資料結構
演算法
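Time Complexity: 建立 n 個單元素集合為 O(n);同時使用路徑壓縮(path compression)與按秩合併(union by rank)時,每次 find / union 的均攤時間為 O(α(n)),其中 α 為反阿克曼函數,實務上可視為常數
Space Complexity: O(n),n 為 node 的數量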
使用「秩(rank)」合併時,總是把秩較小的樹接到秩較大的樹的根下,使建立出的樹較為平衡(樹高不超過 O(log n))
「秩(rank)」的定義如下:
- 只有根節點的樹(即只有一個元素的集合),秩為0
- 當兩棵秩不同的樹合併後,新的樹的秩為原來兩棵樹的秩的較大者;
- 當兩棵秩相同的樹合併後,新的樹的秩為原來的樹的秩加一。
添加元素
def make_set(x):
    """Turn *x* into a one-element set: it becomes its own root."""
    x.root = x
使用「秩」
def make_set(x):
    """Turn *x* into a one-element set for union by rank.

    *x* becomes its own root, and its rank starts at 0 because a
    single-node tree has rank 0.
    """
    # The original one-liner had two statements with no separator,
    # which is a syntax error; they are now separate lines.
    x.root = x
    x.rank = 0
查找元素
def find(x):
    """Return the root of the set containing *x*, compressing the path.

    This version is iterative. Plain (unranked) union can build chains as
    long as the number of nodes, and a recursive version would exceed
    Python's recursion limit on such chains.
    """
    # First pass: walk parent links up to the root.
    root = x
    while root.root != root:
        root = root.root
    # Second pass: re-point every node on the path directly at the root.
    while x.root != root:
        parent = x.root
        x.root = root
        x = parent
    return root
連結兩元素
def union(x, y):
    """Merge the sets holding *x* and *y* by hanging x's root under y's root."""
    x_root, y_root = find(x), find(y)
    if x_root == y_root:
        # Already in the same set; nothing to link.
        return
    x_root.root = y_root
使用「秩」
def union(x, y):
    """Merge the sets holding *x* and *y* using union by rank.

    The root with the smaller rank is attached under the root with the
    larger rank. When the ranks are equal, the surviving root's rank grows
    by one.
    """
    x_root = find(x)
    y_root = find(y)
    if x_root == y_root:
        return
    if x_root.rank < y_root.rank:
        large, small = y_root, x_root
    else:
        large, small = x_root, y_root
    small.root = large
    # The original used `=` here, which is a syntax error; this is an
    # equality test. Only equal ranks increase the merged tree's rank.
    if large.rank == small.rank:
        large.rank += 1
包裝成 class
Python
無秩
class Dsu:
    """Disjoint Set Union over hashable elements (path compression, no rank).

    ``self.root`` maps each element to its parent. An element whose parent
    is itself is the root (representative) of its set.
    """

    def __init__(self, roots=None):
        """Create a DSU, adding one singleton set per item of *roots*, if given."""
        self.root = {}
        if roots:
            for r in roots:
                self.make_set(r)

    def make_set(self, x):
        """Add *x* as a singleton set; do nothing if *x* is already present.

        Resetting an existing element would detach it from its current set,
        so repeated calls are treated as no-ops.
        """
        if x not in self.root:
            self.root[x] = x

    def find(self, x):
        """Return the root of the set containing *x*, compressing the path.

        This version is iterative because `union` can build parent chains as
        long as the number of elements, which a recursive version cannot
        follow without hitting Python's recursion limit. Raises KeyError if
        *x* was never added.
        """
        root = x
        while self.root[root] != root:
            root = self.root[root]
        # Point every element on the walked path directly at the root.
        while x != root:
            parent = self.root[x]
            self.root[x] = root
            x = parent
        return root

    def union(self, x, y):
        """Merge the sets containing *x* and *y*; x's root is linked under y's root."""
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root != y_root:
            self.root[x_root] = y_root

    def are_in_same_set(self, x, y):
        """Return True if *x* and *y* have the same root."""
        return self.find(x) == self.find(y)

    def get_ds(self):
        """Return a defaultdict mapping each root to the set of its members."""
        # Local import: this snippet has no top-level import block, and the
        # original referenced `collections` without importing it.
        import collections

        disjoint_sets = collections.defaultdict(set)
        for key in self.root:
            disjoint_sets[self.find(key)].add(key)
        return disjoint_sets
有秩
class Dsu:
    """Disjoint Set Union over hashable elements (path compression + union by rank).

    ``self.root`` maps each element to its parent. An element whose parent
    is itself is the root of its set. ``self.rank`` holds each element's
    rank, which only changes while that element is a root.
    """

    def __init__(self, roots=None):
        """Create a DSU, adding one singleton set per item of *roots*, if given."""
        self.root = {}
        self.rank = {}
        if roots:
            for r in roots:
                self.make_set(r)

    def make_set(self, x):
        """Add *x* as a singleton set with rank 0; do nothing if already present.

        Resetting an existing element would detach it from its set and
        discard its rank, so repeated calls are treated as no-ops.
        """
        if x not in self.root:
            self.root[x] = x
            self.rank[x] = 0

    def find(self, x):
        """Return the root of the set containing *x*, compressing the path.

        Iterative two-pass version. Raises KeyError if *x* was never added.
        """
        root = x
        while self.root[root] != root:
            root = self.root[root]
        # Point every element on the walked path directly at the root.
        while x != root:
            parent = self.root[x]
            self.root[x] = root
            x = parent
        return root

    def union(self, x, y):
        """Merge the sets containing *x* and *y* using union by rank."""
        x_root = self.find(x)
        y_root = self.find(y)
        if x_root == y_root:
            return
        if self.rank[x_root] < self.rank[y_root]:
            large, small = y_root, x_root
        else:
            large, small = x_root, y_root
        self.root[small] = large
        # Only merging two trees of equal rank increases the rank.
        if self.rank[large] == self.rank[small]:
            self.rank[large] += 1

    def are_in_same_set(self, x, y):
        """Return True if *x* and *y* have the same root."""
        return self.find(x) == self.find(y)

    def get_ds(self):
        """Return a defaultdict mapping each root to the set of its members."""
        # Local import: this snippet has no top-level import block, and the
        # original referenced `collections` without importing it.
        import collections

        disjoint_sets = collections.defaultdict(set)
        for key in self.root:
            disjoint_sets[self.find(key)].add(key)
        return disjoint_sets
Ref: Wikipedia — Disjoint-set data structure
經典例題
- Leetcode # 200. Number of Islands
- Leetcode # 305. Number of Islands II
- Leetcode # 323. Number of Connected Components in an Undirected Graph
Last Updated on 2023/08/16 by A1go