본문 바로가기

알고리즘/소프티어

[Softeer] [21년 재직자 대회 본선] 거리 합 구하기 / problemid 635 [JAVA]

문제링크 : https://softeer.ai/practice/info.do?eventIdx=1&psProblemId=635&sw_prbl_sbms_sn=39796 

 

Softeer

Problem을 담을 Box를 선택해 주세요. 취소 확인

softeer.ai

 

모든 노드간 연결돼 있고 간선이 N-1개인 그래프는 트리다

 

모든 노드마다 다른 모든 노드와 거리의 합을 구한다면 N이 최대 2*10^5 이므로 단순하게는 불가능 해보인다

 

1번 노드를 루트로 하고 노드간 거리의 합을 먼저 구해보자

 

 

트리

1번 노드를 루트로 하는 트리

1번에서 3번,5번,6번 노드로 가기 위해서는 빨간 간선을 지나가야 한다. 그러므로 빨간 간선은 2*3인 6을 정답값에 더해주는 간선이다. 마찬가지로 가중치가 5인 간선은 노드 2만을 위해 지나는 간선이므로 5*1인 5를 더해준다. 

가중치가 4인 간선 : 4*1

가중치가 1인 간선 : 1*1

가중치가 8인 간선 : 8*2

가중치가 6인 간선 : 6*1

다 더하면 38이며 이것이 노드 1에서 모든 노드간 거리의 합이다

 

그 다음 3번 노드를 루트로 하는 트리

3번 노드는 (1), (2), (4), (7)번 노드로 가기 위해 빨간 간선을 지난다. 

그리고 (5), (6)번 노드는 1번 노드를 루트로 할 때와 마찬가지로 한번만 지난다.

마찬가지로, (2), (4), (7)번 노드로 가는 간선도 똑같은 횟수를 지난다.

달라지는 것은 빨간 간선을 지나는 횟수이다. (1번 노드를 루트로 할 때는 3번, 3번 노드를 루트로 할 때는 4번)

노트에 그리면서 알게된 것은 서브트리를 구성하는 노드의 수에 무언가 있어보였다.

 

 - 1번 노드가 루트일 때

1번 노드를 루트로 하는 서브트리의 노드의 개수는 7(전체 N)이다.

3번 노드를 루트로 하는 서브트리의 노드의 개수는 3이다.

1번과 3번 노드를 이어주는 간선(빨간 간선)은 1번 노드가 루트일 때는 3회(1번 노드에서 3번 노드를 포함하여 그 자식 노드(서브트리의 노드)에 도달하기 위해서 반드시 지나야 하는 간선을 지나는 횟수) 지난다.

 

 - 3번 노드가 루트일 때

3번 노드를 루트로 하는 서브트리의 노드의 개수는 7이다.

1번 노드를 루트로 하는 서브트리의 노드의 개수는 4이다.

1번과 3번 노드를 이어주는 간선은 3번 노드가 전체 트리의 루트일 때 4회 지난다.

 

나머지 간선들은 1번 노드와 3번 노드가 전체 트리의 루트일 때, 같은 횟수로 지난다.

 

다른 노드들간의 관계도 확인 해보자

(1)

3번 노드가 전체 트리의 루트일 때, 6번 노드 사이에 있는 간선을 1회 지난다.

6번 노드가 전체 트리의 루트일 때에는 3번 노드 사이에 있는 간선을 6회 지난다.

 

(2)

4번 노드가 전체 트리의 루트일 때, 1번 노드 사이에 있는 간선을 5회 지난다.

1번 노드가 전체 트리의 루트일 때, 4번 노드 사이에 있는 간선을 2회 지난다.

 

 어떤 간선은 그 간선으로 이어진 노드간 서로 루트일 때 지나는 횟수의 합은 N이다.

또, A가 전체 트리의 루트이고 그 자식인 B가 루트인 서브트리의 노드의 수가 x이면 A와 B 사이의 간선을

x번 지나며 B가 전체 트리의 루트일 때를 고려하지 않아도 그 간선을 N-x번 지나간다는 것을 알 수 있다.

또, 나머지 간선을 몇번 지나는지 구하지 않더라도 전체 노드간 거리의 합을 구할 수 있다는 것이다. 

 

위 규칙은 부모와 자식간 이어진 간선에서의 관계에 있다. 그래서 임의로 1을 루트로 하는 트리에서 모든 노드간 간선의 합을 먼저 구했다. 그 다음, 자식 노드에서의 거리의 합을 구하기 위해 parent가 전체 트리의 루트일 때와 현재노드인 child노드가 전체 트리의 루트일 때 parent와 child 사이에 있는 간선을 각각 몇번 지나는지 차이를 구했다. 그리고 나머지 간선은 parent와 같으므로 parent노드의 모든 간선간의 합 결과를 그대로 이용했다. 이것을 반복하면 해결된다. 

 

import java.io.*;
import java.util.*;
public class Main {
    public static int N;
    public static long ans[];
    public static int parent[];
    public static int childCnt[];
    public static class Node{
        int num;
        long cost;
        public Node(){}
        public Node(int a,long b){
            num = a;
            cost = b;
        }
    }
    public static class Info{
        int branchCnt;
        long costSum;
    }

    public static ArrayList<Node> adj[];

    public static Info dfs(int curNode, int par){
        Info ret = new Info();
        parent[curNode] = par;
        for(int i=0;i<adj[curNode].size();i++){
            int nextNode = adj[curNode].get(i).num;
            long cost = adj[curNode].get(i).cost;
            if(nextNode==par || nextNode==par) continue;
            Info child = dfs(nextNode,curNode);
            ret.costSum += cost * child.branchCnt + child.costSum;
            ret.branchCnt += child.branchCnt;
        }
        ret.branchCnt +=1;
        childCnt[curNode] = ret.branchCnt;
        return ret;
    }
    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        N = Integer.parseInt(br.readLine());
        adj = new ArrayList[N+1];
        ans = new long[N+1];
        parent = new int[N+1];
        childCnt = new int[N+1];
        for(int i=1;i<=N;i++){
            adj[i] = new ArrayList<Node>();
        }

        for(int i=0;i<N-1;i++){
            StringTokenizer st = new StringTokenizer(br.readLine());
            int s = Integer.parseInt(st.nextToken());
            int e= Integer.parseInt(st.nextToken());
            long c = Integer.parseInt(st.nextToken());
            adj[s].add(new Node(e,c));
            adj[e].add(new Node(s,c));
        }
        Info root = dfs(1,0);
        ans[1] = root.costSum;
        Queue<Node> q = new LinkedList<Node>();
        for(int i=0;i<adj[1].size();i++){
            q.add(new Node(adj[1].get(i).num,0));
        }
        while(!q.isEmpty()){
            Node curNode = q.poll();
            for(int i=0;i<adj[curNode.num].size();i++){
                int nextNode = adj[curNode.num].get(i).num;
                long cost = adj[curNode.num].get(i).cost;

                if(parent[curNode.num]==nextNode){
                    ans[curNode.num]+= ans[nextNode] + (N-2*childCnt[curNode.num])*cost;
                }
                if(parent[curNode.num] == nextNode) continue;
                q.add(new Node(nextNode,0));
            }
        }
        for(int i=1;i<=N;i++){
            System.out.println(ans[i]);
        }
    }
}