You are given an integer array nums, and you can perform the following operation any number of times on nums:
- Swap the positions of two elements nums[i] and nums[j] if gcd(nums[i], nums[j]) > 1, where gcd(nums[i], nums[j]) is the greatest common divisor of nums[i] and nums[j].
Return true if it is possible to sort nums in non-decreasing order using the above swap method, or false otherwise.
Example 1:
Input: nums = [7,21,3]
Output: true
Explanation: We can sort [7,21,3] by performing the following operations:
- Swap 7 and 21 because gcd(7,21) = 7. nums = [21,7,3]
- Swap 21 and 3 because gcd(21,3) = 3. nums = [3,7,21]
Example 2:
Input: nums = [5,2,6,2]
Output: false
Explanation: It is impossible to sort the array because 5 cannot be swapped with any other element.
Example 3:
Input: nums = [10,5,9,3,15]
Output: true
Explanation: We can sort [10,5,9,3,15] by performing the following operations:
- Swap 10 and 15 because gcd(10,15) = 5. nums = [15,5,9,3,10]
- Swap 15 and 3 because gcd(15,3) = 3. nums = [3,5,9,15,10]
- Swap 10 and 15 because gcd(10,15) = 5. nums = [3,5,9,10,15]
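The swap sequences in the examples can be replayed mechanically. The sketch below uses a hypothetical helper, replaySwaps (not part of the original solution), which takes a list of index pairs, rejects any swap whose two values have gcd 1, and reports whether the final array is sorted. It assumes C++17 for std::gcd and structured bindings.

```cpp
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

// Hypothetical helper: applies the given index swaps to a copy of nums,
// returning false as soon as a swap is illegal (gcd of the two values is 1),
// and otherwise reporting whether the resulting array is sorted.
bool replaySwaps(std::vector<int> nums,
                 const std::vector<std::pair<int, int>>& swaps) {
  for (const auto& [i, j] : swaps) {
    if (std::gcd(nums[i], nums[j]) <= 1) return false;  // illegal swap
    std::swap(nums[i], nums[j]);
  }
  return std::is_sorted(nums.begin(), nums.end());
}
```

For Example 1 the index pairs are (0, 1) then (0, 2); for Example 3 they are (0, 4), (0, 3), (3, 4). In Example 2 any swap involving index 0 is rejected, since gcd(5, 2) = gcd(5, 6) = 1.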
Constraints:
1 <= nums.length <= 3 * 10^4
2 <= nums[i] <= 10^5
Solution: Union-Find
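Build a graph whose nodes are the values in nums, with an edge between two values if their gcd > 1. A swap only exchanges two values joined by an edge, so the set of positions occupied by each connected component never changes; conversely, the values of one component can be rearranged into any order over those positions. Let i be the target position of nums[j] in the sorted array: nums[j] can reach position i only if nums[i] and nums[j] are in the same connected component. Hence nums can be sorted if and only if, for every i, nums[i] and sorted(nums)[i] belong to the same component.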
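The component argument can be checked directly on small inputs. This sketch (hypothetical name gcdSortBruteForce, not the author's code) unions every pair of indices whose values have gcd > 1, sorts the values inside each component into that component's own positions, and tests whether the whole array ends up sorted. It performs O(n^2) gcd calls, so it is only meant as a reference for small arrays; it assumes C++17 for std::gcd.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Reference check: O(n^2) all-pairs union, then sort each component in place.
bool gcdSortBruteForce(const std::vector<int>& nums) {
  const int n = static_cast<int>(nums.size());
  std::vector<int> parent(n);
  std::iota(parent.begin(), parent.end(), 0);
  // Iterative find with path halving.
  auto find = [&](int x) {
    while (parent[x] != x) {
      parent[x] = parent[parent[x]];
      x = parent[x];
    }
    return x;
  };
  // Connect every pair of indices whose values share a factor > 1.
  for (int i = 0; i < n; ++i)
    for (int j = i + 1; j < n; ++j)
      if (std::gcd(nums[i], nums[j]) > 1) parent[find(i)] = find(j);
  // Group indices (in ascending order) by component root.
  std::vector<std::vector<int>> groups(n);
  for (int i = 0; i < n; ++i) groups[find(i)].push_back(i);
  // Any arrangement within a component is reachable, so place each
  // component's values in ascending order over its positions.
  std::vector<int> result(nums);
  for (const auto& idx : groups) {
    std::vector<int> vals;
    for (int i : idx) vals.push_back(nums[i]);
    std::sort(vals.begin(), vals.end());
    for (std::size_t k = 0; k < idx.size(); ++k) result[idx[k]] = vals[k];
  }
  return std::is_sorted(result.begin(), result.end());
}
```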
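We could union every pair of numbers whose gcd > 1, but checking all pairs takes O(n^2) gcd computations, which exceeds the time limit (TLE) for n up to 3 * 10^4. Instead, for each number x we union it with its divisors: for every d in [2, sqrt(x)] that divides x, we union x, d and x / d. Two numbers with gcd > 1 share some prime factor p, and each of them is either p itself or gets unioned with p (p shows up as d or as x / d), so both end up in the same set.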
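To see which unions the divisor loop performs, the hypothetical helper linkedDivisors below mirrors the inner loop of the solution and lists the values x gets unioned with. For example, 10 links to {2, 5} and 15 links to {3, 5}, so both meet at 5; a prime such as 7 links to nothing on its own and is joined later by any multiple of 7 in nums.

```cpp
#include <vector>

// Hypothetical helper mirroring the solution's inner loop: returns d and x / d
// for every divisor d with 2 <= d and d * d <= x (duplicates possible for squares).
std::vector<int> linkedDivisors(int x) {
  std::vector<int> out;
  for (int d = 2; d * d <= x; ++d) {
    if (x % d == 0) {
      out.push_back(d);
      out.push_back(x / d);
    }
  }
  return out;
}
```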
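Time complexity: O(n^2) for all pairs (TLE) -> O(sum(sqrt(nums[i]))) <= O(n*sqrt(m)) for the divisor loop, plus O(n log n) for sorting and O(m) to initialize the parent array, where m = max(nums).
Space complexity: O(n + m), for the parent array indexed by value (size m + 1) and the sorted copy of nums.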
C++
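```cpp
// Author: Huahua
#include <algorithm>
#include <functional>
#include <numeric>
#include <vector>
using namespace std;

class Solution {
public:
  bool gcdSort(vector<int>& nums) {
    const int m = *max_element(begin(nums), end(nums));
    const int n = nums.size();
    // Union-find over values 0..m: p[v] is the parent of value v.
    vector<int> p(m + 1);
    iota(begin(p), end(p), 0);
    // Find with path compression.
    function<int(int)> find = [&](int x) {
      return p[x] == x ? x : (p[x] = find(p[x]));
    };
    // Union x with each divisor pair (d, x / d); d * d <= x is the same
    // bound as d <= sqrt(x) without floating point.
    for (int x : nums)
      for (int d = 2; d * d <= x; ++d)
        if (x % d == 0)
          p[find(x)] = p[find(x / d)] = find(d);
    // Sortable iff each element shares a component with the value
    // that belongs at its position in sorted order.
    vector<int> sorted(nums);
    sort(begin(sorted), end(sorted));
    for (int i = 0; i < n; ++i)
      if (find(sorted[i]) != find(nums[i])) return false;
    return true;
  }
};
```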
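A commonly used alternative for the union step (a different technique from the trial-division loop above, not the author's code) is to union each value only with its distinct prime factors, computed with a smallest-prime-factor (SPF) sieve up to m = max(nums). The sketch below uses the hypothetical name gcdSortSpf and an iterative find with path halving. The sieve costs O(m log log m), and factoring each number takes O(log m) divisions.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// Variant: link each value to its prime factors found via an SPF sieve.
bool gcdSortSpf(const std::vector<int>& nums) {
  const int m = *std::max_element(nums.begin(), nums.end());
  // spf[v] = smallest prime dividing v, for 2 <= v <= m.
  std::vector<int> spf(m + 1, 0);
  for (int i = 2; i <= m; ++i)
    if (spf[i] == 0)  // i is prime
      for (int j = i; j <= m; j += i)
        if (spf[j] == 0) spf[j] = i;
  std::vector<int> parent(m + 1);
  std::iota(parent.begin(), parent.end(), 0);
  auto find = [&](int x) {
    while (parent[x] != x) {
      parent[x] = parent[parent[x]];  // path halving
      x = parent[x];
    }
    return x;
  };
  for (int x : nums) {
    // Peel off prime factors; union x with each distinct prime p.
    for (int v = x; v > 1;) {
      const int p = spf[v];
      parent[find(x)] = find(p);
      while (v % p == 0) v /= p;
    }
  }
  std::vector<int> sorted(nums);
  std::sort(sorted.begin(), sorted.end());
  for (std::size_t i = 0; i < nums.size(); ++i)
    if (find(nums[i]) != find(sorted[i])) return false;
  return true;
}
```

Values that share a factor > 1 share a prime, so they meet at that prime's node, giving the same components as the divisor-based version.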