16BJWC结业测试T1 - 数星星 star
16北京冬令营结业测试第一题。
做法好迷。
题目描述
给定二维平面上的 N 个点,求距离第 K 小的点对间距离。 其中 1 <= K, N <= 1E5,点座标非负且为 int。
做法
听说看 std 可以涨智商。
- 读入点。
 - 预处理:对点集进行遍历,找到一个小于 30 的 L 值 ,将各点座标 (x, y) 进行 trim 变换 (x » L, y » L) ,使得各点在 trim 座标系下邻接九宫格内元素个数和大于等于 K+1(换句话说,用玄学把非常离散的数据集中一下,再建个哈希索引)
 - 玄学搜索:建二维 vector 存点(第一维是哈希化的座标),对点集 (x, y) 进行遍历, 取其 trim 后的索引 (x » L, y » L) 在邻接 49 宫格所属的 vector 里进行搜索。
 - 维护一个堆,堆中存前 K+1 小的最近距离。那么分为这样两种情况
    
- 当前搜索宫格到中心宫格的最小距离大于堆顶(目前第 K+1 小元素),直接剪枝。(类似KD树的剪枝)
 - 遍历当前搜索宫格中的元素,计算到搜索点之间的距离,更新堆中答案。
 
 - 把当前遍历点扔到它对应的二维 vector 中。
 
icecathy 大佬说当时讲评的时候说的是切格子,然后我看代码理解出来是这样的 ↑
代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <cstring>
using namespace std;
const int MAX_N = 2E5;
typedef unsigned long long ULL;
typedef long long LL;
int N, K;
struct point {
    int x, y;
    point(int a, int b) : x(a), y(b) { }
    point() { }
} Raw[MAX_N];
ULL pack_int(int a, int b) {
    return (ULL(a)<<32) | ULL(b);
}
const int BSIZ = 100003;
struct hashmap {         // ULL -> int
    struct node {
        ULL k; int v; node* n;
    };
    node* BKT[BSIZ];
    node *POOL, *cur;
    hashmap(int N) {
        POOL = new node[N];
        cur = POOL;
        memset(BKT, 0, sizeof(BKT));
    }
    void set(ULL key) {
        ULL x = key % BSIZ;
        node* &p = BKT[x];
        while(p && p->k != key) p = p->n;
        if(p) {
            p->v += 1;
        } else {
            p = cur++;
            p->k = key; p->v = 1;
        }
    }
    int get(ULL key) {
        ULL x = key % BSIZ;
        node* p = BKT[x];
        while(p && p->k != key) p = p->n;
        return p ? p->v : 0;
    }
    ~hashmap() {
        delete []POOL;
    }
};
int L = 0;
void getL() {
    while(L < 30) {
        hashmap cnt(N);
        ULL s = 0;
        for(int i=0;i<N;i++) {
            for(int dx=-1;dx<=1;dx++) {
                for(int dy=-1;dy<=1;dy++) {
                    long long trkey = pack_int((Raw[i].x>>L) + dx, (Raw[i].y>>L) + dy);
                    s += cnt.get(trkey);
                }
            }
            long long key = pack_int(Raw[i].x>>L, Raw[i].y>>L);
            cnt.set(key);
        }
        if(s > ULL(K)) break;
        L += 1;
    }
}
void work() {
    priority_queue<ULL> Q;
    vector<vector<point> > B(BSIZ);
    for(int i=0;i<N;i++) {
        int kx=Raw[i].x>>L, ky = Raw[i].y>>L;
        for(int dx=-3;dx<=3;dx++) {
            for(int dy=-3;dy<=3;dy++) {
                if(int(Q.size()) == K+1 && (ULL(max(0,abs(dx)-1)*max(0,abs(dx)-1)+max(0,abs(dy)-1)*max(0,abs(dy)-1))<<(L<<1)) >= Q.top())         // 类 KD 树剪枝
                    continue;
                ULL key = pack_int(kx + dx, ky + dy);
                vector<point> &C = B[key % BSIZ];
                for(int k=0;k<int(C.size());k++) {
                    LL ofx = LL(Raw[i].x - C[k].x);
                    LL ofy = LL(Raw[i].y - C[k].y);
                    ULL dis = ULL(ofx * ofx + ofy * ofy);
                    if(int(Q.size()) == K+1 && dis < Q.top()) {
                        Q.pop(); Q.push(dis);
                    } else if(int(Q.size()) <= K) {
                        Q.push(dis);
                    }
                }
            }
        }
        vector<point> &C = B[pack_int(kx, ky) % BSIZ];
        C.push_back(Raw[i]);
    }
    cout << Q.top() << endl;
}
int main() {
    scanf("%d %d", &N, &K);
    for(int i=0;i<N;i++) {
        scanf("%d %d", &Raw[i].x, &Raw[i].y);
    }
    getL();
    work();
}
数据生成器
没有这题的数据(但有 std),所以写个了数据生成器。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#!/usr/bin/env python3
# Usage: ./gen.py {DATA_SCALE} > star.in
import sys
import random
if len(sys.argv) < 2:
    sys.exit(-1)
DATA_SIZE = int(sys.argv[1])
K_TH = random.randrange(1, DATA_SIZE - 1)
print(DATA_SIZE, K_TH)
for i in range(DATA_SIZE):
    print(random.randrange(1, 1E9), random.randrange(1, 1E9))