首页 > 搜索 > DFS搜索 > HDU 4603-Color the Tree-动态规划-[解题报告]HOJ
2015
09-17

HDU 4603-Color the Tree-动态规划-[解题报告]HOJ

Color the Tree

问题描述 :

Alice and Bob are playing games again! This time, they invent a new game called "Color the Tree". They draw a tree with N nodes and, certainly, (N-1) edges connecting them to assure a path between each pair of the nodes. But the tree they play with is a little special – each edge is assigned to a color and a specific value. Initially, the value of each edge is settled while the colors are all white.
When the game starts, Alice and Bob each choose a node as her/his starting node. In each round, Alice firstly makes a move from her current node to another through an edge with a color of white or red, and if the edge is white, she colors it to red. After that, Bob makes a similar move through a white or blue edge from his current node, and if the edge is white, he colors it to blue. The game keeps going on until all the edges are colored to either red or blue.
Alice’s goal is to maximize the sum of values of the red edges, and Bob, wants to maximize that of the blue edges. Given the starting node of them, figure out the maximum sum that Alice is able to obtain if both of them take the best strategy in each round.

输入:

The first line of the input contains a single integer T, indicating there are T test cases.
In each case, the first line contains two integers N and M, which denotes the number of nodes and queries.
Each of the following (N-1) lines contains a triple of integers (a,b,c), indicating there is an edge connecting node a and node b with a value of c.
The following M lines describe the queries. Each of the M lines consists of two integers A and B, indicating the starting node of Alice and Bob, respectively.
(2≤N≤100000,1≤M≤100000,1≤a,b,A,B≤N,1≤c≤1000)

输出:

For each query, output one line containing the maximum sum of values of the red edges that Alice is able to obtain when both players play optimally.

样例输入:

2
2 1
1 2 3
1 2
3 2
1 2 3
1 3 1
2 3
1 3

样例输出:

3
3
4
Hint
For C++ user, "scanf" is recommended.

这道题细节真的很多

首先可以想到a和b的最优策略一定是沿着a和b在树上的链走,走到某个点停止,然后再依次占领和这个点邻接的边

所以,解决这道题的步骤如下:

预处理阶段:

step 1:取任意一个点为根节点,找出父子关系并且对这个树进行dp,求出从某个节点出发往下所包含的所有边的权值总和  复杂度O(n)

step 2:利用tree dp的结果,对每个节点计算从它的每条邻接边出发所包含的所有边的权值总和,并且对其从大到小进行排序 复杂度O(n*logn)

step 3:dfs求出这棵树的欧拉序(Euler tour),以及每个点的dfn,并且按欧拉序的顺序记录每个位置上节点的深度 复杂度O(2*n)

step 4:利用sparse table算法初始化step 3中的深度序列 复杂度 O(n*logn)

step 5:(倍增)计算出从每个节点往上走2^k步所到达的祖先节点  复杂度O(n*logn)

查询阶段:

关键是找到两点的 LCA 以及相遇点,并且找到一条或两条所经过且和相遇点邻接的边

分几种情况讨论

1. 两个点在一起 

2.两个点之间的距离为1

3.dep[a] == dep[b]

4.dep[a] > dep[b] + 1

5.dep[a] < dep[b]

6.dep[a] == dep[b]+1 

ps:少考虑第六种情况wa了一个下午

#pragma comment(linker, "/STACK:102400000,102400000")
#include<iostream>
#include<vector>
#include<algorithm>
#include<cstdio>
#include<queue>
#include<stack>
#include<string>
#include<map>
#include<set>
#include<cmath>
#include<cassert>
#include<cstring>
#include<iomanip>
#include<ctime>
using namespace std;
// 64-bit integer type and matching printf/scanf format strings:
// __int64 with %I64d when _WIN32 is defined, long long with %lld otherwise.
#ifdef _WIN32
typedef __int64 i64;
#define out64 "%I64d\n"
#define in64 "%I64d"
#else
typedef long long i64;
#define out64 "%lld\n"
#define in64 "%lld"
#endif
/************ for topcoder by zz1215 *******************/
// General competitive-programming shorthands. Of these, only SS (read one int
// with scanf) is used by the code in this file as shown; the rest are unused.
#define foreach(c,itr)  for(__typeof((c).begin()) itr=(c).begin();itr!=(c).end();itr++)
#define FOR(i,a,b)      for( int i = (a) ; i <= (b) ; i ++)
#define FF(i,a)         for( int i = 0 ; i < (a) ; i ++)
#define FFD(i,a,b)      for( int i = (a) ; i >= (b) ; i --)
#define S64(a)          scanf(in64,&a)
#define SS(a)           scanf("%d",&a)
#define LL(a)           ((a)<<1)
#define RR(a)           (((a)<<1)+1)
#define pb              push_back
#define pf              push_front
#define X               first
#define Y               second
#define CL(Q)           while(!Q.empty())Q.pop()
#define MM(name,what)   memset(name,what,sizeof(name))
#define MC(a,b)		memcpy(a,b,sizeof(b))
#define MAX(a,b)        ((a)>(b)?(a):(b))
#define MIN(a,b)        ((a)<(b)?(a):(b))
#define read            freopen("in.txt","r",stdin)
#define write           freopen("out.txt","w",stdout)

// Generic sentinels and tolerances; none of inf/inf64/oo/eps/pi is referenced
// by the code shown. NOTE(review): 10e9 == 1e10 and 10e-9 == 1e-8, which may
// not be the magnitudes originally intended.
const int inf = 0x3f3f3f3f;
const long long inf64 = 0x3f3f3f3f3f3f3f3fLL;
const double oo = 10e9;
const double eps = 10e-9;
const double pi = acos(-1.0);
// Capacity for node-indexed arrays: N <= 100000 plus some slack.
const int maxn = 100111;
// Number of levels in the binary-lifting table up[][] and the sparse table
// st[][]; 2^20 = 1048576 already exceeds maxn.
const int maxlevel = 21;

// One directed half of an undirected tree edge. Each input edge (a, b, c) is
// stored twice: once in g[a] (now == a) and once in g[b] (now == b).
struct Node
{
	int now;   // endpoint whose adjacency list holds this entry
	int to;    // the other endpoint
	int c;     // value of the edge
	int tot;   // total value of all edges reached by leaving `now` through this
	           // edge, the edge itself included (filled by treedp)
	int ss;    // tot plus the ss of the entry two positions later in the sorted
	           // adjacency list (filled by start)

	// Orders entries by tot, largest first, so std::sort yields a descending list.
	// (A top-level const on a by-value bool return is meaningless, so it is omitted.)
	bool operator<(const Node& cmp) const {
		return tot > cmp.tot;
	}
};

// Sum of all edge values in the current test case (accumulated in main).
int all;
// Number of nodes and number of queries in the current test case.
int n, m;
// Adjacency lists: g[u] holds one Node per edge incident to u (with now == u).
// start() sorts each list by tot, largest first.
vector<Node>g[maxn];
// Parent of each node in the tree rooted at node 1; t[1] (and t[0]) is -1.
int t[maxn];
// Depth of each node, with the root at depth 0 (set by euler_circuit).
int dep[maxn];

// Next free position in the Euler tour; positions start at 1.
int df;
// First Euler-tour position at which each node appears.
int dfn[maxn];
// Node at each Euler-tour position (the tour has 2n-1 entries).
int dfv[maxn * 2];
// Sparse table: st[p][k] = minimum depth over tour positions [p, p + 2^k),
// clipped to the end of the tour.
int st[maxn * 2][maxlevel];
// Binary lifting: up[v][k] = 2^k-th ancestor of v, or -1 above the root.
int up[maxn][maxlevel];

int dp[maxn];   // sum of values of all edges inside the subtree of each node (treedp)
// xtof[x]: index in g[x] of the entry leading to x's parent.
int xtof[maxn];
// ftox[x]: index in g[t[x]] of the entry leading from x's parent down to x.
int ftox[maxn];

// Visited flags used by dfs().
int vis[maxn];

// lg2[i] = floor(log2(i)), filled at the start of main.
int lg2[maxn*2];

void dfs(int now)
{
	vis[now] = true;
	int to;
	for (int i = 0; i < (int)g[now].size(); i++) {
		to = g[now][i].to;
		if (!vis[to])	{
			t[to] = now;
			dfs(to);
		}
	}
}

/*
 * Tree DP over the subtree of `now` (children are neighbours other than t[u]).
 *   - dp[u]: total value of the edges inside u's subtree.
 *   - tot of each child entry g[u][i]: dp[child] + c.
 *   - tot of u's parent entry: all - dp[u], i.e. every edge outside u's subtree.
 * Returns dp[now].
 *
 * Iterative (BFS order processed in reverse) rather than recursive, so a deep
 * path-shaped tree cannot overflow the call stack. The "/STACK" linker pragma
 * only applies under MSVC.
 */
int treedp(int now)
{
	// Collect the subtree in BFS order: every parent precedes its children.
	std::vector<int> order;
	order.push_back(now);
	for (size_t k = 0; k < order.size(); k++) {
		int u = order[k];
		for (int i = 0; i < (int)g[u].size(); i++) {
			if (g[u][i].to != t[u]) {
				order.push_back(g[u][i].to);
			}
		}
	}

	// Reverse BFS order guarantees dp[child] is final before its parent is processed.
	for (int k = (int)order.size() - 1; k >= 0; k--) {
		int u = order[k];
		int parent_edge = -1;
		dp[u] = 0;
		for (int i = 0; i < (int)g[u].size(); i++) {
			int to = g[u][i].to;
			if (to != t[u]) {
				int temp = dp[to] + g[u][i].c;
				g[u][i].tot = temp;
				dp[u] += temp;
			}
			else {
				parent_edge = i;
			}
		}
		if (t[u] != -1) {
			g[u][parent_edge].tot = all - dp[u];
		}
	}
	return dp[now];
}

/*
 * Builds the Euler tour of the subtree of `now`. The node is appended to dfv[]
 * on entry and again after each child is finished. Also records dfn[] (first
 * position) and dep[] (depth, with `now` at depth `step`). Positions are taken
 * from the global counter df.
 *
 * Iterative with an explicit stack rather than recursive, so a deep tree cannot
 * overflow the call stack. The visiting order, and therefore dfv[], is the
 * same as a recursive walk over g[u] in index order.
 */
void euler_circuit(int now ,int step)
{
	std::vector<int> node_stack;   // nodes on the current root-to-node path
	std::vector<int> next_edge;    // next adjacency index to examine for each of them

	dep[now] = step;
	dfn[now] = df;
	dfv[df++] = now;
	node_stack.push_back(now);
	next_edge.push_back(0);

	while (!node_stack.empty()) {
		int u = node_stack.back();
		int i = next_edge.back();
		int deg = (int)g[u].size();
		// Skip the entry that leads back up to the parent.
		while (i < deg && g[u][i].to == t[u]) {
			i++;
		}
		if (i == deg) {
			// u is finished: return to its parent, which reappears in the tour.
			node_stack.pop_back();
			next_edge.pop_back();
			if (!node_stack.empty()) {
				dfv[df++] = node_stack.back();
			}
			continue;
		}
		next_edge.back() = i + 1;
		int v = g[u][i].to;
		dep[v] = dep[u] + 1;
		dfn[v] = df;
		dfv[df++] = v;
		node_stack.push_back(v);
		next_edge.push_back(0);
	}
}

// Fills the binary-lifting table: up[v][k] is the 2^k-th ancestor of v, or -1
// once the jump would pass above the root. Level 0 is the parent array t[].
void get_up_node()
{
	for (int v = 1; v <= n; v++) {
		up[v][0] = t[v];
	}
	for (int k = 1; k < maxlevel; k++) {
		for (int v = 1; v <= n; v++) {
			const int half = up[v][k - 1];
			up[v][k] = (half == -1) ? -1 : up[half][k - 1];
		}
	}
}

// Builds st[p][k] = minimum depth over Euler-tour positions p .. p + 2^k - 1,
// clipped to the end of the tour, for every position 1 .. df-1.
void sparse_table()
{
	for (int pos = 1; pos < df; pos++) {
		st[pos][0] = dep[dfv[pos]];
	}

	const int top_level = lg2[n] + 1;
	for (int k = 1; k <= top_level; k++) {
		const int half = 1 << (k - 1);
		for (int pos = 1; pos < df; pos++) {
			const int left = st[pos][k - 1];
			if (pos + half >= df) {
				// Second half lies past the tour: keep the first half's minimum.
				st[pos][k] = left;
				continue;
			}
			const int right = st[pos + half][k - 1];
			st[pos][k] = (right < left) ? right : left;
		}
	}
}

void relation()
{
	int to;
	for (int now = 1; now <= n; now++){
		for (int i = 0; i < (int)g[now].size(); i++){
			to = g[now][i].to;
			if (to == t[now]){
				xtof[now] = i;
			}
			else{
				ftox[to] = i;
			}
		}
	}
}

// Minimum depth over Euler-tour positions [l, r) (half-open, r > l), taken
// from two overlapping power-of-two windows of st[][].
int rmq(int l,int r)
{
	const int k = lg2[r - l];
	const int head = st[l][k];
	const int tail = st[r - (1 << k)][k];
	return (tail < head) ? tail : head;
}

// Returns g[x][k].ss, the sum of tot over positions k, k+2, k+4, ... of x's
// adjacency list, or 0 once k is past the end of the list.
static int branch_suffix(int x, int k)
{
	if (k < (int)g[x].size()) {
		return g[x][k].ss;
	}
	return 0;
}

/*
 * Splits the branches at node x between two players who pick alternately,
 * after removing branch id1 (and id2, if given). g[x] is sorted by tot,
 * largest first (see start()), so the side that picks first collects the
 * 1st, 3rd, 5th, ... largest of the remaining branches.
 *
 *   x     - node whose adjacency list g[x] is being split
 *   first - true: return the first picker's total; false: the second picker's
 *   id1   - index in g[x] of a branch to exclude
 *   id2   - optional second index to exclude (-1 means none)
 */
int calculate(int x,bool first,int id1,int id2=-1)
{
	const bool has_second = (id2 != -1);
	if (has_second && id1 > id2) {
		std::swap(id1, id2);
	}

	// Total tot over every branch of x that was not excluded.
	int remaining = g[x][0].ss + branch_suffix(x, 1) - g[x][id1].tot;
	if (has_second) {
		remaining -= g[x][id2].tot;
	}

	// Start from the even positions (first picker, nothing removed). Removing a
	// branch flips the parity of every later branch. That amounts to swapping
	// one alternating suffix for the other from position `from` onward.
	int picked = g[x][0].ss;
	int from = (id1 % 2) ? id1 + 1 : id1;
	picked += branch_suffix(x, from + 1) - branch_suffix(x, from);
	if (has_second) {
		// With id1 < id2 gone, id2 sits one rank below its index.
		from = (id2 % 2) ? id2 : id2 + 1;
		picked += branch_suffix(x, from + 1) - branch_suffix(x, from);
	}

	return first ? picked : remaining - picked;
}

// Returns the ancestor of `now` that is x levels higher. x is decomposed into
// powers of two, following up[][] once per set bit. If x exceeds the depth of
// `now`, up[][] yields -1 and the next lookup would index up[-1].
int go_up(int now, int x)
{
	for (int level = 0; x != 0; level++, x >>= 1) {
		if (x & 1) {
			now = up[now][level];
		}
	}
	return now;
}

/*
 * Answers one query with Alice starting at node a and Bob at node b. Returns
 * the value printed for it; the problem asks for the largest red-edge sum
 * Alice can secure.
 *
 * The depth of LCA(a, b) comes from the sparse table over the Euler tour. The
 * branches below then split on how dep[a] and dep[b] compare and pick a node
 * on the a-b path (via go_up). At that node they add the tot of one directed
 * edge to calculate() over the remaining branches.
 */
int find(int a,int b)
{
	int l = dfn[a];
	int r = dfn[b];
	if (l == r){
		// a == b: total of the 1st, 3rd, 5th, ... largest branches at a.
		return g[a][0].ss;
	}
	if (l > r){
		swap(l, r);
	}
	int lca = rmq(l, r + 1);   // depth of the LCA (not its node id)
	if (dep[a] - lca + dep[b] - lca == 1){
		// a and b are joined by a single edge.
		if (dep[a] == lca){
			// a is the parent of b.
			return g[b][xtof[b]].tot + calculate(b, false, xtof[b]);
		}
		else if (dep[b] == lca){
			// b is the parent of a.
			return g[b][ftox[a]].tot + calculate(b, false, ftox[a]);
		}
	}
	else if (dep[a] > dep[b]+1){
		// a is at least two levels deeper than b. child is a's ancestor at
		// depth mid + 1, so t[child] is a's ancestor at depth mid. The parity
		// of temp selects calculate's `first` argument.
		int temp = dep[a] - dep[b];
		int mid = lca + temp / 2;
		int child = go_up(a, dep[a] - mid - 1);
		if (temp % 2){
			return g[t[child]][ftox[child]].tot + calculate(t[child], false, ftox[child], xtof[t[child]]);
		}
		else{
			return g[t[child]][ftox[child]].tot + calculate(t[child], true, ftox[child], xtof[t[child]]);
		}
	}
	else if (dep[a] == dep[b] + 1) {
		// a is exactly one level deeper but not adjacent to b, so the LCA lies
		// strictly above both. ca and cb are its children toward a and b.
		int ca = go_up(a, dep[a] - lca - 1);
		int cb = go_up(b, dep[b] - lca - 1);
		int meet = t[ca];
		return g[meet][ftox[ca]].tot + calculate(meet, false, ftox[ca], ftox[cb]);
	}
	else if (dep[a] < dep[b]){
		// b is deeper. child is b's ancestor at depth mid + 1, so t[child] is at
		// depth mid. The parity of temp selects calculate's `first` argument.
		int temp = dep[b] - dep[a];
		int mid = lca + (temp + 1)/ 2;
		int child = go_up(b, dep[b] - mid - 1);
		if (temp % 2){
			return g[t[child]][xtof[t[child]]].tot + calculate(t[child], false, xtof[t[child]], ftox[child]);
		}
		else{
			return g[t[child]][xtof[t[child]]].tot + calculate(t[child], true, xtof[t[child]], ftox[child]);
		}
	}
	else if(dep[a] == dep[b]) {
		// Same depth with a != b: ca and cb are the LCA's children toward a and b.
		int ca = go_up(a, dep[a] - lca - 1);
		int cb = go_up(b, dep[b] - lca - 1);
		int meet = t[ca];
		return g[meet][ftox[ca]].tot + calculate(meet, true, ftox[ca], ftox[cb]);
	}
	assert(false);
	// Unreachable: the branches above cover every depth relation. The explicit
	// return keeps NDEBUG builds (assert compiled out) from falling off the end
	// of a non-void function, which is undefined behaviour.
	return 0;
}

/*
 * Per-test-case preprocessing, run once the edges are in g[1..n]:
 * roots the tree at node 1 and fills tot via treedp. It then sorts every
 * adjacency list by tot (largest first) and builds the alternating suffix
 * sums ss. Finally it builds the Euler tour, the binary-lifting table, the
 * sparse table and the xtof/ftox edge indices.
 */
void start()
{
	std::fill(vis + 1, vis + n + 1, 0);
	t[0] = -1;
	t[1] = -1;
	dfs(1);
	treedp(1);

	for (int u = 1; u <= n; u++) {
		std::sort(g[u].begin(), g[u].end());
		// ss[i] = tot[i] + tot[i+2] + tot[i+4] + ... over the sorted list.
		const int deg = (int)g[u].size();
		for (int i = deg - 1; i >= 0; i--) {
			const int skip = (i + 2 < deg) ? g[u][i + 2].ss : 0;
			g[u][i].ss = g[u][i].tot + skip;
		}
	}

	df = 1;   // Euler-tour positions are 1-based
	euler_circuit(1, 0);
	get_up_node();
	sparse_table();
	relation();
}

/*
 * Reads T test cases. Each case is n and m, then the n-1 edges "a b c", then m
 * queries "A B". Prints find(A, B) for every query on its own line.
 */
int main()
{
	// lg2[i] = floor(log2(i)) for 1 <= i < maxn*2. Exact powers of two are set
	// first; every other entry then inherits its predecessor's value.
	for (int i = 0; i < maxlevel; i++){
		if ( (1<<i) < maxn*2) {
			lg2[1 << i] = i;
		}
	}
	for (int i = 3; i < maxn*2; i++) {
		if (!lg2[i]){
			lg2[i] = lg2[i - 1];
		}
	}

	int T = 0;   // stays 0 (nothing processed) if the count cannot be read
	cin >> T;
	while (T--) {
		all = 0;
		cin >> n >> m;
		for (int i = 0; i <= n; i++){
			g[i].clear();
		}
		// Value-initialised so tot/ss are 0 instead of indeterminate when the
		// struct is copied into g; treedp() and start() overwrite them later.
		Node node = Node();
		for (int i = 1; i <= n - 1; i++) {
			SS(node.now);
			SS(node.to);
			SS(node.c);
			// Store the edge in both endpoints' lists, each with now == owner.
			g[node.now].push_back(node);
			swap(node.now, node.to);
			g[node.now].push_back(node);
			all += node.c;
		}
		start();
		int a, b;
		for (int i = 1; i <= m; i++){
			SS(a); SS(b);
			printf("%d\n", find(a, b));
		}
	}
	return 0;
}

版权声明:本文为博主原创文章,未经博主允许不得转载。

参考:http://blog.csdn.net/zz_1215/article/details/38070643