Disjoint Set Union (DSU)

A data structure designed for disjoint (non-overlapping) sets.

Algorithm

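Time Complexity: amortized O(α(n)) per find or union when both path compression and union by rank are used, where n is the number of nodes and α is the inverse Ackermann function (at most 4 for any practical n)
Space Complexity: O(n)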

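Using a rank keeps the trees balanced when two elements are linked: the root of the lower-rank tree is attached under the root of the higher-rank tree, so tree height grows slowly.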
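To see why this matters, the sketch below (a hypothetical list-based illustration, not the article's code) links elements 0..n-1 one after another and measures the resulting tree height, once with naive linking and once with union by rank. For this particular insertion order, naive linking produces a single chain, while union by rank keeps everything one level below the root.

```python
def build_parents(n, use_rank):
    """Union 0 with 1, then 1 with 2, ..., and return the resulting parent list."""
    parent = list(range(n))
    rank = [0] * n

    def find_root(x):
        # Walk up without path compression so the raw tree shape is preserved.
        while parent[x] != x:
            x = parent[x]
        return x

    for i in range(1, n):
        a, b = find_root(i - 1), find_root(i)
        if a == b:
            continue
        if not use_rank:
            parent[a] = b  # naive: always hang the first root under the second
        else:
            if rank[a] < rank[b]:
                a, b = b, a  # make a the root with the larger (or equal) rank
            parent[b] = a
            if rank[a] == rank[b]:
                rank[a] += 1
    return parent

def height(parent):
    """Longest node-to-root path, counted in edges."""
    best = 0
    for x in range(len(parent)):
        h = 0
        while parent[x] != x:
            x = parent[x]
            h += 1
        best = max(best, h)
    return best

print(height(build_parents(1000, use_rank=False)))  # 999: a single chain
print(height(build_parents(1000, use_rank=True)))   # 1
```

In general, union by rank bounds the height of any tree by log2(n), whatever the order of unions.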

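The rank is defined as follows:

  • A tree with only a root node (i.e., a set with a single element) has rank 0.
  • When two trees of different ranks are merged, the new tree's rank is the larger of the two ranks.
  • When two trees of equal rank are merged, the new tree's rank is that rank plus one.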
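The three rules above can be written as a small function. This is only an illustration (the name `merged_rank` is hypothetical). A consequence of these rules is that a tree of rank r contains at least 2^r nodes, which is where the log2(n) height bound comes from.

```python
def merged_rank(rank_a: int, rank_b: int) -> int:
    """Rank of the tree obtained by merging two trees with ranks rank_a and rank_b."""
    if rank_a != rank_b:
        return max(rank_a, rank_b)  # different ranks: keep the larger one
    return rank_a + 1               # equal ranks: the rank grows by one

# Merge four single-element trees (rank 0) pairwise, then merge the two results.
left = merged_rank(0, 0)         # 1
right = merged_rank(0, 0)        # 1
print(merged_rank(left, right))  # 2
print(merged_rank(2, 0))         # 2
```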

Adding an element

def make_set(x):
  x.root = x  # each new element is the root of its own one-element tree

With rank

def make_set(x):
  x.root = x  # each new element is its own root
  x.rank = 0  # a one-element tree has rank 0
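The pseudocode assumes each element is an object that can carry `root` and `rank` attributes. A minimal runnable sketch, using a hypothetical `Node` class:

```python
class Node:
    """Hypothetical element type: any object that can hold `root` and `rank` attributes."""
    def __init__(self, name):
        self.name = name

def make_set(x):
    x.root = x  # a new element is the root of its own one-element tree
    x.rank = 0  # a one-element tree has rank 0

a = Node("a")
make_set(a)
print(a.root is a, a.rank)  # True 0
```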

Finding an element's root

def find(x):
  if x.root == x:
    return x

  # path compression: point x directly at the root of its tree
  x.root = find(x.root)
  return x.root
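The recursive find can exceed Python's default recursion limit (usually 1000) on a long chain, which can occur when unions are done without rank. One alternative, sketched here as an assumption rather than the article's method, is an iterative two-pass find: first walk up to the root, then re-point every node on the path directly at it. The `Node` class is hypothetical.

```python
class Node:
    """Hypothetical element type carrying a `root` pointer."""
    def __init__(self, name):
        self.name = name
        self.root = self  # make_set: each node starts as its own root

def find(x):
    # Pass 1: follow root pointers until reaching a node that points to itself.
    r = x
    while r.root is not r:
        r = r.root
    # Pass 2 (path compression): re-point every node on the path straight at r.
    while x.root is not r:
        nxt = x.root
        x.root = r
        x = nxt
    return r

# Build a chain n0 -> n1 -> ... -> n4999, deeper than the default recursion limit.
nodes = [Node(i) for i in range(5000)]
for child, parent in zip(nodes, nodes[1:]):
    child.root = parent

root = find(nodes[0])
print(root.name)                           # 4999
print(all(n.root is root for n in nodes))  # True
```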

Linking two elements

def union(x, y):
  x_root = find(x)
  y_root = find(y)

  if x_root != y_root:
    x_root.root = y_root  # attach x's tree under y's root
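Putting make_set, find and union (without rank) together, a small runnable sketch on a hypothetical `Node` class shows how set membership changes as elements are linked:

```python
class Node:
    """Hypothetical element type; the functions below mirror the pseudocode above."""
    def __init__(self, name):
        self.name = name

def make_set(x):
    x.root = x

def find(x):
    if x.root is x:
        return x
    x.root = find(x.root)  # path compression
    return x.root

def union(x, y):
    x_root, y_root = find(x), find(y)
    if x_root is not y_root:
        x_root.root = y_root  # hang x's tree under y's root

a, b, c, d = (Node(s) for s in "abcd")
for n in (a, b, c, d):
    make_set(n)

union(a, b)
union(c, d)
print(find(a) is find(b))  # True
print(find(a) is find(c))  # False
union(b, c)
print(find(a) is find(d))  # True
```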

With rank

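def union(x, y):
  x_root = find(x)
  y_root = find(y)

  if x_root != y_root:
    # attach the root with the smaller rank under the one with the larger rank
    if x_root.rank < y_root.rank:
      large = y_root
      small = x_root
    else:
      large = x_root
      small = y_root

    small.root = large
    # equal ranks: the merged tree is one level taller
    if large.rank == small.rank:
      large.rank += 1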
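A runnable version of the rank-based operations, again on a hypothetical `Node` class, makes the rank updates visible: two equal-rank merges raise the root's rank, and the root with the larger rank stays on top.

```python
class Node:
    """Hypothetical element type carrying `root` and `rank` attributes."""
    def __init__(self, name):
        self.name = name
        self.root = self  # make_set with rank
        self.rank = 0

def find(x):
    if x.root is not x:
        x.root = find(x.root)  # path compression
    return x.root

def union(x, y):
    x_root, y_root = find(x), find(y)
    if x_root is y_root:
        return
    # Attach the root with the smaller rank under the one with the larger rank.
    if x_root.rank < y_root.rank:
        large, small = y_root, x_root
    else:
        large, small = x_root, y_root
    small.root = large
    if large.rank == small.rank:  # equal ranks: the merged rank grows by one
        large.rank += 1

a, b, c, d = (Node(s) for s in "abcd")
union(a, b)  # ranks 0 and 0: b goes under a, a.rank becomes 1
union(c, d)  # d goes under c, c.rank becomes 1
union(a, c)  # ranks 1 and 1: c goes under a, a.rank becomes 2
print(find(d).name, a.rank)  # a 2
```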

Wrapping it in a class

Python

Without rank
import collections

class Dsu:
  def __init__(self, roots=None):
    self.root = {}
    if roots:
      for r in roots:
        self.make_set(r)

  def make_set(self, x): 
    self.root[x] = x

  def find(self, x):
    if self.root[x] == x:
      return x
    self.root[x] = self.find(self.root[x])
    return self.root[x]

  def union(self, x, y): 
    x_root = self.find(x) 
    y_root = self.find(y) 
    if x_root != y_root: 
      self.root[x_root] = y_root

  def are_in_same_set(self, x, y):
    return self.find(x) == self.find(y)

  def get_ds(self): # get disjoint sets
    disjoint_sets = collections.defaultdict(set)
    for key in self.root:
      disjoint_sets[self.find(key)].add(key)
    return disjoint_sets
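A common convenience, not part of the class above, is to create sets lazily so that `find` and `union` accept elements never passed to `make_set`. The sketch below assumes that behavior is wanted; the class name `LazyDsu` is hypothetical. It uses `dict.setdefault` and groups elements the same way `get_ds` does.

```python
import collections

class LazyDsu:
    """Dictionary-based DSU without rank; unseen elements become singleton sets on first use."""
    def __init__(self):
        self.root = {}

    def find(self, x):
        # setdefault inserts x as its own root the first time x is seen.
        if self.root.setdefault(x, x) == x:
            return x
        self.root[x] = self.find(self.root[x])  # path compression
        return self.root[x]

    def union(self, x, y):
        x_root, y_root = self.find(x), self.find(y)
        if x_root != y_root:
            self.root[x_root] = y_root

    def get_ds(self):  # get disjoint sets
        disjoint_sets = collections.defaultdict(set)
        for key in list(self.root):
            disjoint_sets[self.find(key)].add(key)
        return disjoint_sets

dsu = LazyDsu()
dsu.union("a", "b")
dsu.union("c", "d")
dsu.union("b", "d")
dsu.find("e")
groups = sorted(sorted(s) for s in dsu.get_ds().values())
print(groups)  # [['a', 'b', 'c', 'd'], ['e']]
```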

With rank
import collections

class Dsu:
  def __init__(self, roots=None):
    self.root = {}
    self.rank = {}
    if roots:
      for r in roots:
        self.make_set(r)

  def make_set(self, x): 
    self.root[x] = x
    self.rank[x] = 0

  def find(self, x):
    if self.root[x] == x:
      return x
    self.root[x] = self.find(self.root[x])
    return self.root[x]

  def union(self, x, y): 
    x_root = self.find(x) 
    y_root = self.find(y) 
    if x_root != y_root: 
      if self.rank[x_root] < self.rank[y_root]:
        large = y_root
        small = x_root 
      else:
        large = x_root 
        small = y_root
    
      self.root[small] = large
      if self.rank[large] == self.rank[small]:
        self.rank[large] += 1

  def are_in_same_set(self, x, y):
    return self.find(x) == self.find(y)

  def get_ds(self): # get disjoint sets
    disjoint_sets = collections.defaultdict(set)
    for key in self.root:
      disjoint_sets[self.find(key)].add(key)
    return disjoint_sets
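One typical use is counting the connected components of an undirected graph: start with one set per vertex and union the endpoints of every edge; each union that actually merges two sets lowers the count by one. The sketch below is a compact list-based variant using the same union-by-rank rule as the class above; the function name and the edge-list input format (vertices numbered 0..n-1) are assumptions for illustration.

```python
def count_components(n, edges):
    """Number of connected components in an undirected graph with vertices 0..n-1."""
    root = list(range(n))  # make_set for every vertex
    rank = [0] * n

    def find(x):
        if root[x] != x:
            root[x] = find(root[x])  # path compression
        return root[x]

    components = n
    for x, y in edges:
        x_root, y_root = find(x), find(y)
        if x_root == y_root:
            continue  # already connected: this edge would close a cycle
        if rank[x_root] < rank[y_root]:
            x_root, y_root = y_root, x_root  # x_root now has the larger (or equal) rank
        root[y_root] = x_root
        if rank[x_root] == rank[y_root]:
            rank[x_root] += 1
        components -= 1  # two sets merged into one
    return components

print(count_components(5, [(0, 1), (1, 2), (3, 4)]))  # 2
print(count_components(4, []))                          # 4
```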

 

Ref: Wikipedia

Classic problems

Last Updated on 2023/08/16 by A1go
