// Tree DP, all arithmetic modulo 1e9+7.
//
// Input (file "traverse.in"): two integers `c` and `T`, then T test cases.
// Each test case is `n k`, followed by n-1 edges `u v` (edge i is the i-th one
// read, 1-based), followed by k edge ids that get flagged in `is[]`.
// Output (file "traverse.out"): one integer per test case.
//
// NOTE(review): the combinatorial meaning of the three dp slots is not stated
// anywhere in this file. What the code shows is that dfs2 always consumes
// slots 1 and 2 as a sum, adds ret[1] to the answer and subtracts ret[0],
// i.e. the two groups act as opposite signs. Confirm against the problem
// statement before changing any formula below.
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

using namespace std;

using LL = long long;

const int mod = 1000000007, lim = 100000;

/// Reads the next non-negative decimal integer from stdin.
///
/// Skips every non-digit character first. Returns 0 if end-of-input is hit
/// before any digit is seen (the previous version stored getchar() in a `char`
/// and spun forever in that situation).
int read() {
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar();
    int x = 0;
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x;
}

/// Returns x^cnt modulo `mod` by binary exponentiation.
///
/// @param x    base; reduced modulo `mod` up front so the squaring cannot
///             overflow for large inputs.
/// @param cnt  exponent; a non-positive exponent yields 1.
LL qpow(LL x, LL cnt) {
    LL ret = 1, bas = x % mod, c = cnt;
    while (c > 0) {
        if (c & 1) ret = ret * bas % mod;
        bas = bas * bas % mod;
        c >>= 1;
    }
    return ret;
}

// c: first integer of the input, read but otherwise unused.
// T: number of test cases.  n, k: vertex count and number of flagged edges.
// e[1..k]: flagged edge ids (kept so they can be un-flagged afterwards).
// deg[v]: number of edges incident to v.
int c, T, n, k, e[100005], deg[100005];
// fac[i] = i!, invf[i] = (i!)^-1, inv[i] = i^-1 (all modulo `mod`).
// inv[0] is never written by init() and keeps its zero-initialised value.
// dp[v][0..2]: per-vertex accumulators filled by dfs1.
// ans: running answer of the current test case.  al: product over all vertices
// of fac[deg[v] - 1].
LL fac[100005], inv[100005], invf[100005], dp[100005][3], ans, al;
// is[id]: true when edge `id` is one of the k flagged edges.
bool is[100005];
// g[v]: adjacency list of (neighbour, edge id) pairs.
vector<pair<int, int>> g[100005];

/// x = (x + y) mod `mod`; requires 0 <= x, y < mod.
void Fplus(LL &x, LL y) {
    if ((x += y) >= mod) x -= mod;
}

/// x = (x - y) mod `mod`; requires 0 <= x, y < mod.
void Fminus(LL &x, LL y) {
    if ((x -= y) < 0) x += mod;
}

/// Product of two non-negative values, reduced into [0, mod).
static LL mulmod(LL a, LL b) {
    return a % mod * (b % mod) % mod;
}

/// Returns inv[deg[u] - 1], the factor every per-child term is scaled by.
///
/// Yields 0 when deg[u] <= 1: for deg[u] == 1 that is the stored inv[0], and
/// for deg[u] == 0 it avoids indexing inv[-1].
static LL invDegMinusOne(int u) {
    return deg[u] > 0 ? inv[deg[u] - 1] : 0;
}

/// Precomputes fac[], invf[] and inv[] for indices up to `lim`.
void init() {
    fac[0] = invf[0] = 1;
    for (int i = 1; i <= lim; i++) fac[i] = fac[i - 1] * i % mod;
    // Fermat inverse of lim!, then walk downwards: (i!)^-1 = ((i+1)!)^-1 * (i+1).
    invf[lim] = qpow(fac[lim], mod - 2);
    for (int i = lim - 1; i; i--) invf[i] = invf[i + 1] * (i + 1) % mod;
    // i^-1 = (i!)^-1 * (i-1)!
    for (int i = 1; i <= lim; i++) inv[i] = invf[i] * fac[i - 1] % mod;
}

/// Post-order pass that fills dp[u][0..2] from the children of `u`.
///
/// @param u  current vertex.
/// @param f  parent of `u` (0 for the root); the edge back to it is skipped.
void dfs1(int u, int f) {
    for (const auto &edge : g[u]) {
        if (edge.first != f) dfs1(edge.first, u);
    }

    const LL w = invDegMinusOne(u);
    for (const auto &edge : g[u]) {
        const int x = edge.first, id = edge.second;
        if (x == f) continue;
        if (is[id]) {
            // A flagged edge seeds slot 2 and moves the child's slot 0 into
            // slot 1 and the child's slots 1+2 into slot 0.
            Fplus(dp[u][2], w);
            Fplus(dp[u][1], mulmod(w, dp[x][0]));
            Fplus(dp[u][0], mulmod(w, dp[x][1] + dp[x][2]));
        }
        // Every child additionally passes its three slots up unchanged.
        Fplus(dp[u][2], mulmod(dp[x][2], w));
        Fplus(dp[u][1], mulmod(dp[x][1], w));
        Fplus(dp[u][0], mulmod(dp[x][0], w));
    }
}

/// Adds the contribution of vertex `u` to `ans` (non-recursive part of dfs2).
///
/// Builds ret[0] and ret[1] from the dp values of u's children, then applies
/// ans += al * ret[1] and ans -= al * ret[0].
static void accumulateAt(int u, int f) {
    // Scratch buffers are static, as in the original code, so that they never
    // occupy space in the frames of the deep dfs2 recursion.
    static LL tmp[2][3];  // row 0: sum over all children, row 1: flagged only
    static LL total[3];   // untouched copy of row 0
    static LL ret[2];

    memset(tmp, 0, sizeof tmp);
    for (const auto &edge : g[u]) {
        const int x = edge.first, id = edge.second;
        if (x == f) continue;
        for (int i = 0; i < 3; i++) Fplus(tmp[0][i], dp[x][i]);
        if (is[id]) {
            for (int i = 0; i < 3; i++) Fplus(tmp[1][i], dp[x][i]);
        }
    }
    memcpy(total, tmp[0], sizeof total);
    ret[0] = ret[1] = 0;

    const LL w = invDegMinusOne(u);
    int cnt = 0;  // number of flagged child edges of u

    // cross u
    // Each child is removed from the running sums before it is paired with
    // them, so every unordered pair of children is visited exactly once.
    for (const auto &edge : g[u]) {
        const int x = edge.first, id = edge.second;
        if (x == f) continue;
        const LL a = dp[x][0];
        const LL b = (dp[x][1] + dp[x][2]) % mod;
        for (int i = 0; i < 3; i++) Fminus(tmp[0][i], dp[x][i]);
        if (is[id]) {
            ++cnt;
            for (int i = 0; i < 3; i++) Fminus(tmp[1][i], dp[x][i]);
            // choose 2
            const LL ka = tmp[1][0];
            const LL kb = (tmp[1][1] + tmp[1][2]) % mod;
            Fplus(ret[0], mulmod(mulmod(a, ka), w));
            Fplus(ret[0], mulmod(mulmod(b, kb), w));
            Fplus(ret[1], mulmod(mulmod(a, kb), w));
            Fplus(ret[1], mulmod(mulmod(b, ka), w));
        }
        // choose 0
        const LL ra = tmp[0][0];
        const LL rb = (tmp[0][1] + tmp[0][2]) % mod;
        Fplus(ret[0], mulmod(mulmod(a, ra), w));
        Fplus(ret[1], mulmod(mulmod(a, rb), w));
        Fplus(ret[1], mulmod(mulmod(b, ra), w));
        Fplus(ret[0], mulmod(mulmod(b, rb), w));
    }

    // choose 1
    // A flagged child is paired with the sum over all *other* children.
    for (const auto &edge : g[u]) {
        const int x = edge.first, id = edge.second;
        if (x == f || !is[id]) continue;
        const LL a = dp[x][0];
        const LL b = (dp[x][1] + dp[x][2]) % mod;
        LL oa = total[0], o1 = total[1], o2 = total[2];
        Fminus(oa, dp[x][0]);
        Fminus(o1, dp[x][1]);
        Fminus(o2, dp[x][2]);
        const LL ob = (o1 + o2) % mod;
        Fplus(ret[1], mulmod(mulmod(a, oa), w));
        Fplus(ret[0], mulmod(mulmod(a, ob), w));
        Fplus(ret[0], mulmod(mulmod(b, oa), w));
        Fplus(ret[1], mulmod(mulmod(b, ob), w));
        Fplus(ret[1], mulmod(oa, w));
        Fplus(ret[0], mulmod(ob, w));
    }

    // top is u
    for (const auto &edge : g[u]) {
        const int x = edge.first, id = edge.second;
        if (x == f || !is[id]) continue;
        const LL b = (dp[x][1] + dp[x][2]) % mod;
        Fplus(ret[0], dp[x][1]);
        Fplus(ret[1], dp[x][0]);
        Fplus(ret[0], dp[x][2]);
        // cnt >= 1 here, so cnt - 1 is non-negative.
        Fplus(ret[1], mulmod(mulmod(b, cnt - 1), w));
        Fplus(ret[0], mulmod(mulmod(dp[x][0], cnt - 1), w));
    }
    // Pairs of flagged child edges: cnt * (cnt - 1) / 2 (0 when cnt == 0).
    Fplus(ret[0], mulmod(1ll * (cnt - 1) * cnt / 2 % mod, w));

    Fplus(ans, al * ret[1] % mod);
    Fminus(ans, al * ret[0] % mod);
}

/// Post-order pass that folds every vertex's contribution into `ans`.
///
/// Must run after dfs1 has filled dp[] for the whole tree.
///
/// @param u  current vertex.
/// @param f  parent of `u` (0 for the root).
void dfs2(int u, int f) {
    for (const auto &edge : g[u]) {
        if (edge.first != f) dfs2(edge.first, u);
    }
    accumulateAt(u, f);
}

/// Reads one test case, prints its answer and resets the per-test state.
///
/// NOTE(review): input is trusted; vertex and edge ids are assumed to lie in
/// 1..n and 1..n-1 respectively and n is assumed to be at most `lim`.
void work() {
    n = read(), k = read();
    for (int i = 1, u, v; i < n; i++) {
        u = read(), v = read();
        g[u].push_back({v, i});
        g[v].push_back({u, i});
    }
    for (int i = 1; i <= n; i++) deg[i] = static_cast<int>(g[i].size());
    for (int i = 1; i <= k; i++) {
        e[i] = read();
        is[e[i]] = true;
    }

    al = 1;
    for (int i = 1; i <= n; i++) {
        // A vertex without edges would index fac[-1]; it contributes nothing.
        if (deg[i] > 0) al = al * fac[deg[i] - 1] % mod;
    }
    ans = al * k % mod;

    dfs1(1, 0);
    dfs2(1, 0);
    printf("%lld\n", ans);

    // Reset only what this test case touched.
    for (int i = 1; i <= n; i++) {
        g[i].clear();
        dp[i][0] = dp[i][1] = dp[i][2] = 0;
    }
    for (int i = 1; i <= k; i++) is[e[i]] = 0;
}

/// Redirects stdin/stdout to traverse.in / traverse.out and runs all tests.
///
/// @return 0 on success, 1 if either file cannot be opened.
int main() {
    if (freopen("traverse.in", "r", stdin) == nullptr) {
        fprintf(stderr, "cannot open traverse.in\n");
        return 1;
    }
    if (freopen("traverse.out", "w", stdout) == nullptr) {
        fprintf(stderr, "cannot open traverse.out\n");
        return 1;
    }
    c = read(), T = read();
    init();
    while (T--) work();
    return 0;
}