#include <bits/stdc++.h>
using namespace std;
// n: node count; a[i]: input marker of node i (non-zero = marked);
// x, y: scratch variables for reading edges; h: number of marked nodes;
// s: a marked node (the last one read) used as the DFS root; ans: result accumulator.
int n, a[100050], x, y, h = 0, s, ans = 0;
// Undirected adjacency lists, indexed by node number (nodes are 1-based).
std::vector<int> e[100050];
// vis[i] starts at 1 for marked nodes and 0 otherwise; after dfs() it holds
// the number of marked nodes in i's subtree.
int vis[100050];

/*
 * Turns vis[] into subtree sums over the tree rooted at `c`: after the call,
 * vis[u] equals the original vis[u] plus the sums of all u's children.
 *
 * @param c  root of the traversal.
 * @param f  parent of `c` (or a non-node such as 0); this neighbour is not
 *           descended into and vis[f] is left unchanged.
 *
 * Implemented iteratively so deep (e.g. path-shaped) trees cannot overflow
 * the call stack the way one recursive call per level could.
 */
void dfs(int c, int f)
{
    // (node, parent) pairs in the order the nodes are first popped.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> pending{{c, f}};
    while (!pending.empty())
    {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int v : e[cur.first])
            if (v != cur.second)
                pending.emplace_back(v, cur.first);
    }
    // Every node is popped before any of its descendants, so walking `order`
    // backwards finishes each vis[v] before folding it into its parent.
    // order[0] is the root itself; its parent `f` is skipped, as before.
    for (std::size_t k = order.size(); k-- > 1;)
        vis[order[k].second] += vis[order[k].first];
}
/*
 * Reads n, then a marker a[i] for each node 1..n (non-zero = marked), then
 * n-1 undirected edges "x y". Prints the number of unmarked nodes that lie on
 * the smallest connected subtree containing every marked node.
 *
 * Returns 1 (printing nothing) on malformed input, an n larger than the
 * global arrays can hold, or an edge endpoint outside [1, n].
 */
int main()
{
    // Largest node index the global arrays can hold (index 0 is unused).
    const int capacity = static_cast<int>(sizeof(a) / sizeof(a[0])) - 1;
    if (scanf("%d", &n) != 1 || n > capacity)
        return 1;
    for (int i = 1; i <= n; i++)
    {
        if (scanf("%d", a + i) != 1)
            return 1;
        if (a[i])
        {
            h++;        // count of marked nodes
            vis[i] = 1; // each marked node contributes 1 to its subtree sum
            s = i;      // any marked node can serve as the DFS root
        }
    }
    for (int i = 1; i < n; i++)
    {
        if (scanf("%d %d", &x, &y) != 2 || x < 1 || x > n || y < 1 || y > n)
            return 1;
        e[x].push_back(y);
        e[y].push_back(x);
    }
    // Rooted at a marked node, a node's subtree sum is non-zero exactly when it
    // lies on the path from the root to some marked node, i.e. when it belongs
    // to the minimal subtree spanning all marked nodes (assumes the edges form
    // a tree). With no marked node there is nothing to traverse.
    if (h > 0)
        dfs(s, 0);
    for (int i = 1; i <= n; i++)
        if (vis[i])
            ans++;
    // Spanning-subtree size minus the marked nodes themselves.
    std::cout << ans - h << '\n';
    return 0;
}