Bitcoin Core  22.99.0
P2P Digital Currency
sketch_impl.h
Go to the documentation of this file.
1 /**********************************************************************
2  * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko *
3  * Distributed under the MIT software license, see the accompanying *
4  * file LICENSE or http://www.opensource.org/licenses/mit-license.php.*
5  **********************************************************************/
6 
7 #ifndef _MINISKETCH_SKETCH_IMPL_H_
8 #define _MINISKETCH_SKETCH_IMPL_H_
9 
10 #include <random>
11 
12 #include "util.h"
13 #include "sketch.h"
14 #include "int_utils.h"
15 
/** Reduce val modulo mod in place; on return val holds the remainder of val / mod.
 *
 *  Polynomials are stored as coefficient vectors, constant term first.
 *  mod must be non-empty and monic (its last coefficient is 1). If val is at
 *  least as long as mod, its last coefficient must be nonzero. Trailing zero
 *  coefficients are stripped from the result, so a zero remainder is returned
 *  as an empty vector.
 */
template<typename F>
void PolyMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, const F& field) {
    const size_t mod_len = mod.size();
    CHECK_SAFE(mod_len > 0 && mod.back() == 1);
    // Degree of val is already below that of mod: nothing to reduce.
    if (val.size() < mod_len) return;
    CHECK_SAFE(val.back() != 0);
    do {
        // Peel off the current leading coefficient ...
        const auto lead = val.back();
        val.pop_back();
        if (lead == 0) continue;
        // ... and cancel it by adding lead * (mod without its leading 1),
        // aligned with the top of what remains of val.
        typename F::Multiplier scale(field, lead);
        const size_t base = val.size() + 1 - mod_len;
        for (size_t k = 0; k + 1 < mod_len; ++k) {
            val[base + k] ^= scale(mod[k]);
        }
    } while (val.size() >= mod_len);
    // Normalize: drop zero coefficients from the high end.
    while (!val.empty() && val.back() == 0) val.pop_back();
}
35 
/** Divide val by mod, storing the quotient in div and leaving the remainder in val.
 *
 *  mod must be non-empty and monic (its last coefficient is 1). If val is
 *  shorter than mod, div is cleared and val is left untouched. Otherwise val's
 *  last coefficient must be nonzero, and on return val has exactly
 *  mod.size() - 1 coefficients (trailing zeros are NOT stripped).
 */
template<typename F>
void DivMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, std::vector<typename F::Elem>& div, const F& field) {
    const size_t mod_len = mod.size();
    CHECK_SAFE(mod_len > 0 && mod.back() == 1);
    if (val.size() < mod_len) {
        // Quotient is zero; val already is the remainder.
        div.clear();
        return;
    }
    CHECK_SAFE(val.back() != 0);
    div.resize(val.size() + 1 - mod_len);
    // Produce quotient coefficients from the highest degree downwards.
    for (size_t q = div.size(); q-- > 0;) {
        const auto lead = val.back();
        div[q] = lead;
        val.pop_back();
        if (lead == 0) continue;
        // Subtract (= add, in characteristic 2) lead * x^q * mod from val.
        // The leading term was already removed by the pop_back above.
        typename F::Multiplier scale(field, lead);
        for (size_t k = 0; k + 1 < mod_len; ++k) {
            val[q + k] ^= scale(mod[k]);
        }
    }
}
59 
/** Make a polynomial monic by dividing every coefficient by the leading one.
 *
 *  a must be non-empty with a nonzero last coefficient.
 *
 *  Returns 0 if a was already monic (nothing is changed in that case),
 *  otherwise the inverse of the original leading coefficient, i.e. the factor
 *  every coefficient was multiplied by.
 */
template<typename F>
typename F::Elem MakeMonic(std::vector<typename F::Elem>& a, const F& field) {
    CHECK_SAFE(a.back() != 0);
    auto& lead = a.back();
    if (lead == 1) return 0;
    const auto factor = field.Inv(lead);
    typename F::Multiplier scale(field, factor);
    // The leading coefficient becomes exactly 1; scale all the others.
    lead = 1;
    for (size_t idx = 0; idx + 1 < a.size(); ++idx) {
        a[idx] = scale(a[idx]);
    }
    return factor;
}
73 
/** Compute the GCD of two polynomials, putting the result in a.
 *
 *  Both inputs are used as scratch space; b's contents on return are
 *  unspecified. When the remainder chain reaches a constant polynomial the
 *  result is normalized to the constant 1.
 */
template<typename F>
void GCD(std::vector<typename F::Elem>& a, std::vector<typename F::Elem>& b, const F& field) {
    // Keep the longer polynomial in a, so that a is the one being reduced.
    if (a.size() < b.size()) std::swap(a, b);
    while (!b.empty()) {
        if (b.size() == 1) {
            // b is a constant, so the inputs share no non-trivial factor.
            a.resize(1);
            a.front() = 1;
            return;
        }
        // Euclidean step: a <- a mod b (PolyMod needs a monic modulus), then swap roles.
        MakeMonic(b, field);
        PolyMod(b, a, field);
        std::swap(a, b);
    }
}
89 
/** Square a polynomial in place.
 *
 *  The coefficient at index i of the input ends up, squared by field.Sqr, at
 *  index 2*i of the output; every odd-indexed coefficient of the output is 0.
 *  An empty polynomial is left empty.
 */
template<typename F>
void Sqr(std::vector<typename F::Elem>& poly, const F& field) {
    if (poly.empty()) return;
    poly.resize(poly.size() * 2 - 1);
    // Walk from the top down so that poly[x / 2] is still the original input
    // coefficient when it is read. The index is a size_t (not an int) so that
    // it cannot be narrowed for very long vectors.
    for (size_t x = poly.size(); x-- > 0;) {
        poly[x] = (x & 1) ? 0 : field.Sqr(poly[x / 2]);
    }
}
99 
/** Compute the trace map of (param*x) modulo mod, putting the result in out.
 *
 *  Starting from out = param*x, each of the field.Bits() - 1 rounds squares
 *  out, sets its linear coefficient to param, and reduces the result modulo
 *  mod.
 */
template<typename F>
void TraceMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& out, const typename F::Elem& param, const F& field) {
    // Squaring roughly doubles the length before reduction; reserve for that up front.
    out.reserve(mod.size() * 2);
    out.resize(2);
    out[0] = 0;
    out[1] = param;

    for (int round = 1; round < field.Bits(); ++round) {
        Sqr(out, field);
        // The square only has even-degree terms; make room for the linear one.
        if (out.size() < 2) out.resize(2);
        out[1] = param;
        PolyMod(mod, out, field);
    }
}
115 
127 template<typename F>
128 bool RecFindRoots(std::vector<std::vector<typename F::Elem>>& stack, size_t pos, std::vector<typename F::Elem>& roots, bool fully_factorizable, int depth, typename F::Elem randv, const F& field) {
129  auto& ppoly = stack[pos];
130  // We assert ppoly.size() > 1 (instead of just ppoly.size() > 0) to additionally exclude
131  // constants polynomials because
132  // - ppoly is not constant initially (this is ensured by FindRoots()), and
133  // - we never recurse on a constant polynomial.
134  CHECK_SAFE(ppoly.size() > 1 && ppoly.back() == 1);
135  /* 1st degree input: constant term is the root. */
136  if (ppoly.size() == 2) {
137  roots.push_back(ppoly[0]);
138  return true;
139  }
140  /* 2nd degree input: use direct quadratic solver. */
141  if (ppoly.size() == 3) {
142  CHECK_RETURN(ppoly[1] != 0, false); // Equations of the form (x^2 + a) have two identical solutions; contradicts square-free assumption. */
143  auto input = field.Mul(ppoly[0], field.Sqr(field.Inv(ppoly[1])));
144  auto root = field.Qrt(input);
145  if ((field.Sqr(root) ^ root) != input) {
146  CHECK_SAFE(!fully_factorizable);
147  return false; // No root found.
148  }
149  auto sol = field.Mul(root, ppoly[1]);
150  roots.push_back(sol);
151  roots.push_back(sol ^ ppoly[1]);
152  return true;
153  }
154  /* 3rd degree input and more: recurse further. */
155  if (pos + 3 > stack.size()) {
156  // Allocate memory if necessary.
157  stack.resize((pos + 3) * 2);
158  }
159  auto& poly = stack[pos];
160  auto& tmp = stack[pos + 1];
161  auto& trace = stack[pos + 2];
162  trace.clear();
163  tmp.clear();
164  for (int iter = 0;; ++iter) {
165  // Compute the polynomial (trace(x*randv) mod poly(x)) symbolically,
166  // and put the result in `trace`.
167  TraceMod(poly, trace, randv, field);
168 
169  if (iter >= 1 && !fully_factorizable) {
170  // If the polynomial cannot be factorized completely (it has an
171  // irreducible factor of degree higher than 1), we want to avoid
172  // the case where this is only detected after trying all BITS
173  // independent split attempts fail (see the assert below).
174  //
175  // Observe that if we call y = randv*x, it is true that:
176  //
177  // trace = y + y^2 + y^4 + y^8 + ... y^(FIELDSIZE/2) mod poly
178  //
179  // Due to the Frobenius endomorphism, this means:
180  //
181  // trace^2 = y^2 + y^4 + y^8 + ... + y^FIELDSIZE mod poly
182  //
183  // Or, adding them up:
184  //
185  // trace + trace^2 = y + y^FIELDSIZE mod poly.
186  // = randv*x + randv^FIELDSIZE*x^FIELDSIZE
187  // = randv*x + randv*x^FIELDSIZE
188  // = randv*(x + x^FIELDSIZE).
189  // (all mod poly)
190  //
191  // x + x^FIELDSIZE is the polynomial which has every field element
192  // as root once. Whenever x + x^FIELDSIZE is multiple of poly,
193  // this means it only has unique first degree factors. The same
194  // holds for its constant multiple randv*(x + x^FIELDSIZE) =
195  // trace + trace^2.
196  //
197  // We use this test to quickly verify whether the polynomial is
198  // fully factorizable after already having computed a trace.
199  // We don't invoke it immediately; only when splitting has failed
200  // at least once, which avoids it for most polynomials that are
201  // fully factorizable (or at least pushes the test down the
202  // recursion to factors which are smaller and thus faster).
203  tmp = trace;
204  Sqr(tmp, field);
205  for (size_t i = 0; i < trace.size(); ++i) {
206  tmp[i] ^= trace[i];
207  }
208  while (tmp.size() && tmp.back() == 0) tmp.pop_back();
209  PolyMod(poly, tmp, field);
210 
211  // Whenever the test fails, we can immediately abort the root
212  // finding. Whenever it succeeds, we can remember and pass down
213  // the information that it is in fact fully factorizable, avoiding
214  // the need to run the test again.
215  if (tmp.size() != 0) return false;
216  fully_factorizable = true;
217  }
218 
219  if (fully_factorizable) {
220  // Every succesful iteration of this algorithm splits the input
221  // polynomial further into buckets, each corresponding to a subset
222  // of 2^(BITS-depth) roots. If after depth splits the degree of
223  // the polynomial is >= 2^(BITS-depth), something is wrong.
224  CHECK_RETURN(field.Bits() - depth >= std::numeric_limits<decltype(poly.size())>::digits ||
225  (poly.size() - 2) >> (field.Bits() - depth) == 0, false);
226  }
227 
228  depth++;
229  // In every iteration we multiply randv by 2. As a result, the set
230  // of randv values forms a GF(2)-linearly independent basis of splits.
231  randv = field.Mul2(randv);
232  tmp = poly;
233  GCD(trace, tmp, field);
234  if (trace.size() != poly.size() && trace.size() > 1) break;
235  }
236  MakeMonic(trace, field);
237  DivMod(trace, poly, tmp, field);
238  // At this point, the stack looks like [... (poly) tmp trace], and we want to recursively
239  // find roots of trace and tmp (= poly/trace). As we don't care about poly anymore, move
240  // trace into its position first.
241  std::swap(poly, trace);
242  // Now the stack is [... (trace) tmp ...]. First we factor tmp (at pos = pos+1), and then
243  // we factor trace (at pos = pos).
244  if (!RecFindRoots(stack, pos + 1, roots, fully_factorizable, depth, randv, field)) return false;
245  // The stack position pos contains trace, the polynomial with all of poly's roots which (after
246  // multiplication with randv) have trace 0. This is never the case for irreducible factors
247  // (which always end up in tmp), so we can set fully_factorizable to true when recursing.
248  bool ret = RecFindRoots(stack, pos, roots, true, depth, randv, field);
249  // Because of the above, recursion can never fail here.
250  CHECK_SAFE(ret);
251  return ret;
252 }
253 
/** Returns the roots of a fully factorizable polynomial.
 *
 *  poly   Coefficients, constant term first. Must be non-empty; RecFindRoots
 *         additionally requires it to be monic.
 *  basis  Nonzero field element used to seed the splitting attempts.
 *
 *  Returns an empty vector if poly is a constant or if it is not fully
 *  factorizable; otherwise returns its poly.size() - 1 roots.
 */
template<typename F>
std::vector<typename F::Elem> FindRoots(const std::vector<typename F::Elem>& poly, typename F::Elem basis, const F& field) {
    using Elem = typename F::Elem;
    std::vector<Elem> roots;
    CHECK_RETURN(poly.size() != 0, {});
    CHECK_RETURN(basis != 0, {});
    // A constant polynomial has no roots.
    if (poly.size() == 1) return roots;

    const size_t degree = poly.size() - 1;
    roots.reserve(degree);
    // The recursion works on a scratch stack whose bottom entry is a copy of poly.
    std::vector<std::vector<Elem>> stack(1, poly);

    // Invoke the recursive factorization algorithm.
    const bool factored = RecFindRoots(stack, 0, roots, false, 0, basis, field);
    if (!factored) {
        // Not fully factorizable.
        return {};
    }
    CHECK_RETURN(degree == roots.size(), {});
    return roots;
}
279 
/** Run the Berlekamp-Massey algorithm over the given syndromes.
 *
 *  syndromes   The input sequence.
 *  max_degree  Upper bound on the degree of the result.
 *
 *  Returns the resulting polynomial as a coefficient vector whose element 0
 *  is 1, or an empty vector if its degree would exceed max_degree.
 */
template<typename F>
std::vector<typename F::Elem> BerlekampMassey(const std::vector<typename F::Elem>& syndromes, size_t max_degree, const F& field) {
    using Elem = typename F::Elem;
    const size_t reserve_len = syndromes.size() / 2 + 1;

    // One precomputed multiplier per syndrome processed so far.
    std::vector<typename F::Multiplier> table;
    table.reserve(syndromes.size());

    // current: the polynomial being built; prev: its value before the last
    // length change; saved: scratch copy used while updating.
    std::vector<Elem> current, prev, saved;
    current.reserve(reserve_len);
    prev.reserve(reserve_len);
    saved.reserve(reserve_len);
    current.resize(1);
    current[0] = 1;
    prev.resize(1);
    prev[0] = 1;

    // b is the discrepancy at the last length change. Its inverse is computed
    // lazily, only when the next nonzero discrepancy shows up.
    Elem b = 1, b_inv = 1;
    bool b_have_inv = true;

    for (size_t n = 0; n != syndromes.size(); ++n) {
        table.emplace_back(field, syndromes[n]);
        auto discrepancy = syndromes[n];
        for (size_t i = 1; i < current.size(); ++i) discrepancy ^= table[n - i](current[i]);
        // current already predicts syndromes[n]; nothing to adjust.
        if (discrepancy == 0) continue;

        int shift = n + 1 - (current.size() - 1) - (prev.size() - 1);
        if (!b_have_inv) {
            b_inv = field.Inv(b);
            b_have_inv = true;
        }
        const bool lengthen = 2 * (current.size() - 1) <= n;
        if (lengthen) {
            if (prev.size() + shift - 1 > max_degree) return {}; // We'd exceed maximum degree
            saved = current;
            current.resize(prev.size() + shift);
        }
        // current += (discrepancy / b) * x^shift * prev
        typename F::Multiplier mul(field, field.Mul(discrepancy, b_inv));
        for (size_t i = 0; i < prev.size(); ++i) current[i + shift] ^= mul(prev[i]);
        if (lengthen) {
            std::swap(prev, saved);
            b = discrepancy;
            b_have_inv = false;
        }
    }
    CHECK_RETURN(current.size() && current.back() != 0, {});
    return current;
}
323 
/** Expand the odd-numbered syndromes into the full syndrome list.
 *
 *  With 0-based indices, output entry 2*k is odd_syndromes[k] and output
 *  entry 2*k + 1 is field.Sqr of output entry k. The result has twice as many
 *  entries as the input.
 */
template<typename F>
std::vector<typename F::Elem> ReconstructAllSyndromes(const std::vector<typename F::Elem>& odd_syndromes, const F& field) {
    const size_t count = odd_syndromes.size();
    std::vector<typename F::Elem> all_syndromes(count * 2);
    for (size_t k = 0; k < count; ++k) {
        const size_t slot = k * 2;
        all_syndromes[slot] = odd_syndromes[k];
        // Entry k has been filled in by now, since k <= slot.
        all_syndromes[slot + 1] = field.Sqr(all_syndromes[k]);
    }
    return all_syndromes;
}
334 
/** XOR the odd powers of data into the odd syndromes.
 *
 *  osyndromes[k] is XOR-ed with data multiplied k times by field.Sqr(data),
 *  i.e. with data, data^3, data^5, ... in field arithmetic.
 */
template<typename F>
void AddToOddSyndromes(std::vector<typename F::Elem>& osyndromes, typename F::Elem data, const F& field) {
    // Going from one odd power to the next is a multiplication by data^2.
    auto data_sq = field.Sqr(data);
    typename F::Multiplier next_power(field, data_sq);
    for (size_t k = 0; k < osyndromes.size(); ++k) {
        osyndromes[k] ^= data;
        data = next_power(data);
    }
}
344 
/** Decode a list of odd syndromes into the set of elements they encode.
 *
 *  Runs the same pipeline as SketchImpl::Decode: reconstruct all syndromes,
 *  run Berlekamp-Massey, reverse the resulting polynomial and find its roots.
 *  The maximum degree passed to Berlekamp-Massey is the number of odd
 *  syndromes, and the root-finding basis is fixed to 1 (the same value
 *  SketchImpl::SetSeed uses for the seed (uint64_t)-1).
 *
 *  Returns the decoded elements, or an empty vector if decoding fails or if
 *  there is nothing to decode.
 */
template<typename F>
std::vector<typename F::Elem> FullDecode(const std::vector<typename F::Elem>& osyndromes, const F& field) {
    // F is deduced from `field`; it must not be given explicitly as F::Elem.
    auto asyndromes = ReconstructAllSyndromes(osyndromes, field);
    auto poly = BerlekampMassey(asyndromes, osyndromes.size(), field);
    // BerlekampMassey returns an empty polynomial on failure, and FindRoots
    // requires a non-empty input.
    if (poly.empty()) return {};
    std::reverse(poly.begin(), poly.end());
    const typename F::Elem basis = 1;
    return FindRoots(poly, basis, field);
}
352 
353 template<typename F>
354 class SketchImpl final : public Sketch
355 {
356  const F m_field;
357  std::vector<typename F::Elem> m_syndromes;
358  typename F::Elem m_basis;
359 
360 public:
361  template<typename... Args>
362  SketchImpl(int implementation, int bits, const Args&... args) : Sketch(implementation, bits), m_field(args...) {
363  std::random_device rng;
364  std::uniform_int_distribution<uint64_t> dist;
365  m_basis = m_field.FromSeed(dist(rng));
366  }
367 
368  size_t Syndromes() const override { return m_syndromes.size(); }
369  void Init(int count) override { m_syndromes.assign(count, 0); }
370 
371  void Add(uint64_t val) override
372  {
373  auto elem = m_field.FromUint64(val);
375  }
376 
377  void Serialize(unsigned char* ptr) const override
378  {
379  BitWriter writer(ptr);
380  for (const auto& val : m_syndromes) {
381  m_field.Serialize(writer, val);
382  }
383  writer.Flush();
384  }
385 
386  void Deserialize(const unsigned char* ptr) override
387  {
388  BitReader reader(ptr);
389  for (auto& val : m_syndromes) {
390  val = m_field.Deserialize(reader);
391  }
392  }
393 
394  int Decode(int max_count, uint64_t* out) const override
395  {
396  auto all_syndromes = ReconstructAllSyndromes(m_syndromes, m_field);
397  auto poly = BerlekampMassey(all_syndromes, max_count, m_field);
398  if (poly.size() == 0) return -1;
399  if (poly.size() == 1) return 0;
400  if ((int)poly.size() > 1 + max_count) return -1;
401  std::reverse(poly.begin(), poly.end());
402  auto roots = FindRoots(poly, m_basis, m_field);
403  if (roots.size() == 0) return -1;
404 
405  for (const auto& root : roots) {
406  *(out++) = m_field.ToUint64(root);
407  }
408  return roots.size();
409  }
410 
411  size_t Merge(const Sketch* other_sketch) override
412  {
413  // Sad cast. This is safe only because the caller code in minisketch.cpp checks
414  // that implementation and field size match.
415  const SketchImpl* other = static_cast<const SketchImpl*>(other_sketch);
416  m_syndromes.resize(std::min(m_syndromes.size(), other->m_syndromes.size()));
417  for (size_t i = 0; i < m_syndromes.size(); ++i) {
418  m_syndromes[i] ^= other->m_syndromes[i];
419  }
420  return m_syndromes.size();
421  }
422 
423  void SetSeed(uint64_t seed) override
424  {
425  if (seed == (uint64_t)-1) {
426  m_basis = 1;
427  } else {
428  m_basis = m_field.FromSeed(seed);
429  }
430  }
431 };
432 
433 #endif
int_utils.h
RecFindRoots
bool RecFindRoots(std::vector< std::vector< typename F::Elem >> &stack, size_t pos, std::vector< typename F::Elem > &roots, bool fully_factorizable, int depth, typename F::Elem randv, const F &field)
One step of the root finding algorithm; finds roots of stack[pos] and adds them to roots.
Definition: sketch_impl.h:128
BitReader
Definition: int_utils.h:89
count
static int count
Definition: tests.c:41
SketchImpl::SketchImpl
SketchImpl(int implementation, int bits, const Args &... args)
Definition: sketch_impl.h:362
CHECK_SAFE
#define CHECK_SAFE(cond)
Check macro that does nothing in normal non-verify builds but crashes in verify builds.
Definition: util.h:50
BitWriter
Definition: int_utils.h:52
GCD
void GCD(std::vector< typename F::Elem > &a, std::vector< typename F::Elem > &b, const F &field)
Compute the GCD of two polynomials, putting the result in a.
Definition: sketch_impl.h:76
PolyMod
void PolyMod(const std::vector< typename F::Elem > &mod, std::vector< typename F::Elem > &val, const F &field)
Compute the remainder of a polynomial division of val by mod, putting the result in val.
Definition: sketch_impl.h:18
SketchImpl::m_field
const F m_field
Definition: sketch_impl.h:356
AddToOddSyndromes
void AddToOddSyndromes(std::vector< typename F::Elem > &osyndromes, typename F::Elem data, const F &field)
Definition: sketch_impl.h:336
BitWriter::Flush
void Flush()
Definition: int_utils.h:80
util.h
Sketch
Abstract class for internal representation of a minisketch object.
Definition: sketch.h:14
Sqr
void Sqr(std::vector< typename F::Elem > &poly, const F &field)
Square a polynomial.
Definition: sketch_impl.h:92
sketch.h
SketchImpl
Definition: sketch_impl.h:354
SketchImpl::Serialize
void Serialize(unsigned char *ptr) const override
Definition: sketch_impl.h:377
DivMod
void DivMod(const std::vector< typename F::Elem > &mod, std::vector< typename F::Elem > &val, std::vector< typename F::Elem > &div, const F &field)
Compute the quotient of a polynomial division of val by mod, putting the quotient in div and the remainder in val.
Definition: sketch_impl.h:38
SketchImpl::Deserialize
void Deserialize(const unsigned char *ptr) override
Definition: sketch_impl.h:386
FindRoots
std::vector< typename F::Elem > FindRoots(const std::vector< typename F::Elem > &poly, typename F::Elem basis, const F &field)
Returns the roots of a fully factorizable polynomial.
Definition: sketch_impl.h:263
SketchImpl::m_basis
F::Elem m_basis
Definition: sketch_impl.h:358
SketchImpl::SetSeed
void SetSeed(uint64_t seed) override
Definition: sketch_impl.h:423
TraceMod
void TraceMod(const std::vector< typename F::Elem > &mod, std::vector< typename F::Elem > &out, const typename F::Elem &param, const F &field)
Compute the trace map of (param*x) modulo mod, putting the result in out.
Definition: sketch_impl.h:102
SketchImpl::Init
void Init(int count) override
Definition: sketch_impl.h:369
MakeMonic
F::Elem MakeMonic(std::vector< typename F::Elem > &a, const F &field)
Make a polynomial monic.
Definition: sketch_impl.h:62
SketchImpl::m_syndromes
std::vector< typename F::Elem > m_syndromes
Definition: sketch_impl.h:357
SketchImpl::Syndromes
size_t Syndromes() const override
Definition: sketch_impl.h:368
ReconstructAllSyndromes
std::vector< typename F::Elem > ReconstructAllSyndromes(const std::vector< typename F::Elem > &odd_syndromes, const F &field)
Definition: sketch_impl.h:325
CHECK_RETURN
#define CHECK_RETURN(cond, rvar)
Check a condition and return on failure in non-verify builds, crash in verify builds.
Definition: util.h:67
FullDecode
std::vector< typename F::Elem > FullDecode(const std::vector< typename F::Elem > &osyndromes, const F &field)
Definition: sketch_impl.h:346
SketchImpl::Decode
int Decode(int max_count, uint64_t *out) const override
Definition: sketch_impl.h:394
SketchImpl::Merge
size_t Merge(const Sketch *other_sketch) override
Definition: sketch_impl.h:411
BerlekampMassey
std::vector< typename F::Elem > BerlekampMassey(const std::vector< typename F::Elem > &syndromes, size_t max_degree, const F &field)
Definition: sketch_impl.h:281
SketchImpl::Add
void Add(uint64_t val) override
Definition: sketch_impl.h:371