fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  // Assumed problem (CEOI 2011 "Mousetrap"; the statement is not part of this
  // file, so treat this summary as an assumption to confirm): the n rooms form
  // a tree, the trap is room T and the mouse starts in room M. Each turn the
  // manager performs one operation -- permanently block a corridor or clean a
  // dirty one -- and then the mouse must move through an adjacent clean,
  // unblocked corridor if one exists, which makes that corridor dirty. The
  // answer is the minimum number of operations that guarantees the mouse
  // ends up in T.

  constexpr int N = 1000000 + 5;

  int n, T, M;
  std::vector<int> adj[N];

  int par[N], dep[N];  // parent / depth in the tree rooted at T (par[T] == 0)
  int weightNode[N];

  /*
    weightNode[u]: for a vertex u off the M -> T path, the number of operations
    the manager needs from the moment the mouse enters u (from its parent)
    until it is forced back up, NOT counting the final cleaning of the
    corridor (parent(u), u).

    When the mouse arrives at u, the manager blocks the child with the largest
    weight, the mouse dives into the child with the second largest weight, and
    every remaining child corridor has to be blocked so the mouse comes
    straight back. Hence
        weightNode[u] = secondLargest(weightNode[children of u]) + #children(u),
    where secondLargest is 0 if u has fewer than two children (a leaf gets 0).

    Implemented iteratively: a recursive DFS overflows the call stack on a
    path-shaped tree with ~1e6 vertices.
  */
  void dfsWeight(int u, int p)
  {
      par[u] = p;

      // Pre-order via an explicit stack: every parent precedes its children.
      std::vector<int> order;
      std::vector<int> pending{u};
      while (!pending.empty()) {
          const int x = pending.back();
          pending.pop_back();
          order.push_back(x);
          for (int v : adj[x]) {
              if (v == par[x]) continue;
              par[v] = x;
              dep[v] = dep[x] + 1;
              pending.push_back(v);
          }
      }

      // Reverse pre-order handles every child before its parent.
      for (auto it = order.rbegin(); it != order.rend(); ++it) {
          const int x = *it;
          int best = 0, second = 0, children = 0;
          for (int v : adj[x]) {
              if (v == par[x]) continue;
              ++children;
              const int w = weightNode[v];
              if (w > best) {
                  second = best;
                  best = w;
              } else if (w > second) {
                  second = w;
              }
          }
          weightNode[x] = second + children;
      }
  }

  /*
    check(X): can the manager guarantee trapping the mouse with at most X
    operations? Requires dfsWeight(T, 0) to have been run.

    The mouse walks from M up towards T. Standing on path vertex u it may
    instead step into a side child v (a child of u not on the path). Let
    S(u) be the number of side children of all path vertices from u up to
    (excluding) T, and B the corridors already blocked at lower path vertices.
    That detour then costs the manager
        B + weightNode[v] + S(u)
    operations in total (bring the mouse back out of v, clean (u, v), and
    block every other side branch between u and T). Every side child whose
    cost exceeds X must be blocked before the mouse can use it. The manager
    gains one operation per mouse step, so it fails if it runs out of
    operations, or if the blocks alone exceed X.
  */
  bool check(int X)
  {
      // Path vertices from M up to (excluding) T; empty when M == T.
      std::vector<int> path;
      for (int u = M; u != T; u = par[u]) path.push_back(u);

      const int len = static_cast<int>(path.size());
      // The child of path[i] that lies on the path (0 for M: none).
      auto below = [&](int i) { return i > 0 ? path[i - 1] : 0; };

      // sideAbove[i] = S(path[i]), accumulated from the T end of the path.
      std::vector<int> sideAbove(len + 1, 0);
      for (int i = len - 1; i >= 0; --i) {
          const int u = path[i];
          int side = 0;
          for (int v : adj[u])
              if (v != par[u] && v != below(i)) ++side;
          sideAbove[i] = sideAbove[i + 1] + side;
      }

      int budget = X;   // X minus the blocks already spent lower on the path
      int freeOps = 0;  // operations earned but not yet used
      for (int i = 0; i < len; ++i) {
          ++freeOps;    // the manager acts before each mouse step
          const int u = path[i];
          int blockedHere = 0;
          for (int v : adj[u]) {
              if (v == par[u] || v == below(i)) continue;
              if (weightNode[v] + sideAbove[i] > budget) {
                  if (freeOps == 0) return false;
                  --freeOps;
                  ++blockedHere;
              }
          }
          budget -= blockedHere;
          if (budget < 0) return false;
      }
      return true;
  }
  99.  
  100. int main()
  101. {
  102. ios::sync_with_stdio(false);
  103. cin.tie(nullptr);
  104.  
  105. cin >> n >> T >> M;
  106.  
  107. for(int i = 1; i < n; ++i){
  108. int u, v;
  109. cin >> u >> v;
  110.  
  111. adj[u].push_back(v);
  112. adj[v].push_back(u);
  113. }
  114.  
  115. dfsWeight(T, 0);
  116.  
  117. int L = 0, R = n, ans = n;
  118.  
  119. while(L <= R){
  120. int mid = (L + R) >> 1;
  121.  
  122. if(check(mid)){
  123. ans = mid;
  124. R = mid - 1;
  125. }
  126. else{
  127. L = mid + 1;
  128. }
  129. }
  130.  
  131. cout << ans << '\n';
  132.  
  133. return 0;
  134. }
  135.  
Success #stdin #stdout 0.01s 30920KB
stdin
10 1 4
1 2
2 3
2 4
3 9
3 5
4 7
4 6
6 8
7 10
stdout
2