최대 1 분 소요

문제링크

풀이

import sys

input = lambda: sys.stdin.readline().strip()

n, m = map(int, input().split())

graph = [[] for _ in range(n + 1)]
roots = [i for i in range(n + 1)]

def find(x):
    if roots[x] == x:
        return x
    else:
        roots[x] = find(roots[x])
        return roots[x]


def union(x, y):
    x = find(x)
    y = find(y)

    if y == x:
        return
    else:
        if y < x:
            roots[x] = y
        else:
            roots[y] = x



for _ in range(m):
    u, v = map(int, input().split())
    graph[u].append(v)
    graph[v].append(u)
    union(u, v)
    union(v, u)


dic = dict()
for i in range(1, n + 1):
    if roots[i] not in dic:
        dic[roots[i]] = [i]
    else:
        dic[roots[i]].append(i)

ans = 1
for key in dic.keys():
    if len(dic[key]) != 0:
        ans = ans * len(dic[key])
        ans = ans % 1000000007

print(ans % 1000000007)

입력으로 주어지는 그래프가 포레스트 형태(사이클이 없는 그래프) 즉 트리들의 집합이다.

또한, 전달자가 전달하는 경우의 수를 최소화해야하므로, 각 트리마다 트리를 구성하는 노드의 수만큼 경우의 수가 존재한다.

그렇다면 유니온 파인드로 각 트리를 분별한 후 각 트리들을 구성하는 노드의 수를 구해준 후 곱해주면 답이다.

댓글남기기