首页 > 数据结构 > Hash表 > HDU 4812-D Tree-分治-[解题报告]HOJ
2015
09-18

HDU 4812-D Tree-分治-[解题报告]HOJ

D Tree

问题描述 :

There is a skyscraping tree standing on the playground of Nanjing University of Science and Technology. On each branch of the tree is an integer (The tree can be treated as a connected graph with N vertices, while each branch can be treated as a vertex). Today the students under the tree are considering a problem: Can we find such a chain on the tree so that the multiplication of all integers on the chain (mod 106 + 3) equals to K?
Can you help them in solving this problem?

输入:

There are several test cases, please process till EOF.
Each test case starts with a line containing two integers N(1 <= N <= 105) and K(0 <=K < 106 + 3). The following line contains n numbers vi(1 <= vi < 106 + 3), where vi indicates the integer on vertex i. Then follows N – 1 lines. Each line contains two integers x and y, representing an undirected edge between vertex x and vertex y.

输出:

There are several test cases, please process till EOF.
Each test case starts with a line containing two integers N(1 <= N <= 105) and K(0 <=K < 106 + 3). The following line contains n numbers vi(1 <= vi < 106 + 3), where vi indicates the integer on vertex i. Then follows N – 1 lines. Each line contains two integers x and y, representing an undirected edge between vertex x and vertex y.

样例输入:

5 60
2 5 2 3 3
1 2
1 3
2 4
2 5
5 2
2 5 2 3 3
1 2
1 3
2 4
2 5

样例输出:

3 4
No solution

Hint
1. “please print the lexicographically smallest one.”是指: 先按照第一个数字的大小进行比较,若第一个数字大小相同,则按照第二个数字大小进行比较,依次类推。 2. 若出现栈溢出,推荐使用C++语言提交,并通过以下方式扩栈: #pragma comment(linker,"/STACK:102400000,102400000")

题意:给出一棵树,让你寻找一条路径,使得路径上的点相乘mod10^6+3等于k,输出路径的两个端点,按照字典序最小输出。

思路:这类问题很容易想到树的分治,每次找出树的重心,以重心为根,将树分成若干棵子树,然后对于每棵子树再一样的操作,现在就需要求一重心为根,寻找路径,依次遍历每一个子树,然后记录子树中点到根的权值的乘积X,然后通过在哈希表中寻找K×逆元(x),看是否存在,存在则更新答案。

#pragma comment(linker,"/STACK:102400000,102400000")
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define mod 1000003
#define N 100005
typedef long long ll;

struct E{
    int v, d, ne;
    E() {}
    E(int _v, int _ne):v(_v), ne(_ne){}
}e[N*2];

bool vis[N];
int size, head[N], ans[2], root, flag[mod], F[mod], sum[N], mi, cr, id[N];
ll val[N], ni[mod], path[N];

void init() {
    size = 0;
    memset(vis, false, sizeof(vis));
    memset(ans, -1, sizeof(ans));
    memset(head, -1, sizeof(head));
    memset(flag, 0, sizeof(flag));
}

void add(int u, int v) {
    e[size] = E(v, head[u]);
    head[u] = size++;
}

void dfs(int u, ll k) {
    int i, v;
    sum[u] = 1;
    vis[u] = true, id[cr] = u;
    path[cr++] = k*val[u]%mod;
    ll tm = path[cr-1];
    for (i = head[u];~i;i = e[i].ne) {
        v = e[i].v;
        if (vis[v]) continue;
        dfs(v, tm);
        sum[u] += sum[v];
    }
    vis[u] = false;
}

ll k;
int n, ca;
void getans(int a, int b) {
    if (a > b) swap(a,b);
    if (ans[0] == -1 || ans[0] > a) ans[0] = a, ans[1] = b;
    else if (ans[0] == a && ans[1] > b) ans[1] = b;
}

void getroot(int u) {
    int i, v, mx = 0;
    sum[u] = 1;
    vis[u] = true;
    for (i = head[u];~i;i = e[i].ne) {
        v = e[i].v;
        if (vis[v]) continue;
        getroot(v);
        sum[u] += sum[v];
        mx = max(mx, sum[v]);
    }
    mx = max(mx, sum[0]-sum[u]);
    if (mx < mi) mi = mx, root = u;
    vis[u] = false;
}
void cal(int u, int cnt) {
    if (cnt == 1) return;
    int i, v, j;
    mi = n;
    sum[0] = cnt;
    getroot(u);
    vis[root] = true;
    for (i = head[root];~i;i = e[i].ne) {
        v = e[i].v;
        if (vis[v]) continue;
        cr = 0;
        dfs(v, 1);
        for (j = 0;j < cr;j++) {
            if (path[j]*val[root]%mod == k) getans(root, id[j]);
            ll tm = k*ni[path[j]*val[root]%mod]%mod;
            if (flag[tm] != ca) continue;
            getans(F[tm], id[j]);
        }
        for (j = 0;j < cr;j++) {
            int tm = path[j];
            if (flag[tm] != ca || F[tm] > id[j]) F[tm] = id[j], flag[tm] = ca;
        }
    }
    ca++;
    for (i = head[root];~i;i = e[i].ne) {
        if (vis[e[i].v]) continue;
        cal(e[i].v, sum[e[i].v]);
    }
}

ll egcd(ll a,ll b, ll &x, ll &y) {//得到的是a*x+b*y=gcd(a,b)
	ll temp,tempx;
	if (b == 0) {
		x = 1;y = 0;
		return a;
	}
	temp = egcd(b,a % b, x, y);
	tempx = x;
	x = y;
	y = tempx - a / b * y;
	return temp;
}

int main() {
    int u, v, i, j;
    ll y;
    for (i = 0;i < mod;i++) {
        egcd(i*1ll, mod*1ll, ni[i], y);
        ni[i] %= mod, ni[i] = (ni[i]+mod)%mod;
    }
    while (~scanf("%d%I64d", &n, &k)) {
        init();
        ca = 1;
        for (i = 1;i <= n;i++) scanf("%I64d", &val[i]);
        for (i = 1;i < n;i++) {
            scanf("%d%d", &u, &v);
            add(u, v), add(v, u);
        }
        cal(1, n);
        if (ans[0] == -1) puts("No solution");
        else printf("%d %d\n", ans[0], ans[1]);
    }
}

参考:http://blog.csdn.net/u010089558/article/details/38108101