BSGS算法学习笔记

拔山盖世 北上广深

参考资料

前言

来到ZR之后的第二天,神仙dyh老师就讲了这个名字玄学的指数同余方程求解算法.

虽然在之前做某SDOI三合一模板-计算器的时候接触了一下这个算法,但是了解的并不深入,经过杜老师讲解之后,赶紧把它记下来.

BSGS

BSGS算法,全称为Baby-step Gaint-step算法,是一种求解指数同余方程的算法.

BSGS求解的指数同余方程是长这个样子:

$$a^x \equiv b(\mod p)$$

其中p是一个质数(实际上这里只要满足$\gcd(a,p) = 1$即可).

我们考虑设$x = qt + r$,$(t = \lceil \sqrt{p} \rceil)$那么原式变为:

$$a^{qt + r} \equiv b(\mod p)$$

然后我们移一下项:

$$a^{qt} \equiv b \cdot a^r(\mod p)$$

然后我们就要开始暴力啦!

暴力枚举$r$,把同余式右边的$ba^r$的所有值存到一个哈希表或者map里,然后再暴力枚举$q$,计算出$a^{qt}$次方所对应的值,然后从哈希表里查找这个值是否存在,如果存在,那么这就是一个原方程的解.

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
inline long long BSGS(long long a,long long b,long long p){
std::map<int,int> hash;
hash.clear(); b %= p;
int t = (int)std::sqrt(p) + 1;
FOR(i,0,t-1,1){
int value = (long long)b * quick_power(a,i,p) % p;
hash[value] = i;
}
a = quick_power(a,t,p);
if(!a){if(!b) return 1; else return -1;}
FOR(i,0,t,1){
int value = quick_power(a,i,p);
int tmp = hash.find(value) == hash.end() ? -1 : hash[value];
if(tmp >= 0 && i * t - tmp >= 0) return i * t - tmp;
}
return -1;
}

exBSGS

如果当$\gcd(a,p) \not = 1$时,普通的BSGS就过不去了.

因为我们BSGS出来的可能不是原方程的解,于是我们就需要进行扩展.

首先,显然当$\gcd(a,p) \not | b and b \not = 1$时,原方程无正整数解。

我们把这个同余式两边同时除一个$\gcd(a,p)$:

$$a^{x-1} \times \frac {a}{\gcd(a,p)} \equiv \frac {b}{\gcd(a,p)} (\mod \frac {p}{\gcd(a,p)})$$

然后我们设$g = (\frac {a}{\gcd(a,p)})^{-1} (\mod p)$

所以:

$$a^{x-1} \equiv \frac {bg}{\gcd(a,p)} (\mod \frac {p}{\gcd(a,p)})$$

然后我们令$x_0 = x-1,b_0 = \frac {bg}{\gcd(a,p)} q_0 = \frac {p}{\gcd(a,p)}$

于是我们不断递归这个过程,直到出现$\gcd(a,p) = 1$的情况之后BSGS即可.

代码

我们以LuoguP4195为例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// luogu-judger-enable-o2
/* Headers */
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cctype>
#include<algorithm>
#include<vector>
#include<queue>
#include<stack>
#include<climits>
#include<iostream>
#include<map>
#define FOR(i,a,b,c) for(int i=(a);i<=(b);i+=(c))
#define ROF(i,a,b,c) for(int i=(a);i>=(b);i-=(c))
#define FORL(i,a,b,c) for(long long i=(a);i<=(b);i+=(c))
#define ROFL(i,a,b,c) for(long long i=(a);i>=(b);i-=(c))
#define FORR(i,a,b,c) for(register int i=(a);i<=(b);i+=(c))
#define ROFR(i,a,b,c) for(register int i=(a);i>=(b);i-=(c))
#define lowbit(x) x&(-x)
#define LeftChild(x) x<<1
#define RightChild(x) (x<<1)+1
#define RevEdge(x) x^1
#define FILE_IN(x) freopen(x,"r",stdin);
#define FILE_OUT(x) freopen(x,"w",stdout);
#define CLOSE_IN() fclose(stdin);
#define CLOSE_OUT() fclose(stdout);
#define IOS(x) std::ios::sync_with_stdio(x)
#define Dividing() printf("-----------------------------------\n");
namespace FastIO{
const int BUFSIZE = 1 << 20;
char ibuf[BUFSIZE],*is = ibuf,*its = ibuf;
char obuf[BUFSIZE],*os = obuf,*ot = obuf + BUFSIZE;
inline char getch(){
if(is == its)
its = (is = ibuf)+fread(ibuf,1,BUFSIZE,stdin);
return (is == its)?EOF:*is++;
}
inline int getint(){
int res = 0,neg = 0,ch = getch();
while(!(isdigit(ch) || ch == '-') && ch != EOF)
ch = getch();
if(ch == '-'){
neg = 1;ch = getch();
}
while(isdigit(ch)){
res = (res << 3) + (res << 1)+ (ch - '0');
ch = getch();
}
return neg?-res:res;
}
inline void flush(){
fwrite(obuf,1,os-obuf,stdout);
os = obuf;
}
inline void putch(char ch){
*os++ = ch;
if(os == ot) flush();
}
inline void putint(int res){
static char q[10];
if(res==0) putch('0');
else if(res < 0){putch('-');res = -res;}
int top = 0;
while(res){
q[top++] = res % 10 + '0';
res /= 10;
}
while(top--) putch(q[top]);
}
inline void space(bool x){
if(!x) putch('\n');
else putch(' ');
}
}
inline void read(int &x){
int rt = FastIO::getint();
x = rt;
}
inline void print(int x,bool enter){
FastIO::putint(x);
FastIO::flush();
FastIO::space(enter);
}
/* definitions */
int a,p,b;
std::map<int,int> hash;
/* functions */
inline int gcd(int a,int b){
return (!b) ? a : gcd(b,a%b);
}
inline int quick_multi(int a,int b,int \mod){
return 1ll * a * b % \mod;
}
inline int quick_power(int a,int b,int p){
int res = 1,base = a;
while(b){
if(b & 1) res = quick_multi(res,base,p) % p;
b >>= 1; base = quick_multi(base,base,p) % p;
}
return res;
}
inline void exBSGS(int x,int y){
if(y == 1){printf("0\n"); return;}
int d = gcd(x,p), k = 1,t = 0;
while(d ^ 1){
if(y % d){printf("No Solution\n"); return;}
t = t + 1; y /= d; p /= d;
k = quick_multi(k,x/d,p);
if(y == k){printf("%d\n",t); return;}
d = gcd(x,p);
}
int s = y, m = (int)std::sqrt(p) + 1;
hash.clear();
FOR(i,0,m-1,1){
hash[s] = i; s = quick_multi(s,x,p);
}
s = k; k = quick_power(x,m,p);
FOR(i,1,m,1){
s = quick_multi(s,k,p);
if(hash[s]) {printf("%d\n",i * m - hash[s] + t); return;}
}
printf("No Solution\n"); return;
}
int main(int argc,char *argv[]){
while(1){
read(a); read(p); read(b);
if(!a && !p && !b) break;
a %= p; b %= p;
exBSGS(a,b);
}
return 0;
}

THE END