Falcon source files (reference implementation)


sign.c

/*
 * Falcon signature generation.
 *
 * ==========================(LICENSE BEGIN)============================
 *
 * Copyright (c) 2017-2019  Falcon Project
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * ===========================(LICENSE END)=============================
 *
 * @author   Thomas Pornin <thomas.pornin@nccgroup.com>
 */

#include "inner.h"

/* =================================================================== */

/*
 * Compute degree N from logarithm 'logn'.
 */
#define MKN(logn)   ((size_t)1 << (logn))
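
/*
 * (Informative) MKN(9) = 512 and MKN(10) = 1024, matching the
 * Falcon-512 and Falcon-1024 parameter sets.
 */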

/* =================================================================== */
/*
 * Binary case:
 *   N = 2^logn
 *   phi = X^N+1
 */

/*
 * Get the size of the LDL tree for an input with polynomials of size
 * 2^logn. The size is expressed in the number of elements.
 */
static inline unsigned
ffLDL_treesize(unsigned logn)
{
        /*
         * For logn = 0 (polynomials are constant), the "tree" is a
         * single element. Otherwise, the tree node has size 2^logn, and
         * has two child trees for size logn-1 each. Thus, treesize s()
         * must fulfill these two relations:
         *
         *   s(0) = 1
         *   s(logn) = (2^logn) + 2*s(logn-1)
         */
        return (logn + 1) << logn;
}
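
/*
 * (Informative) The closed form above satisfies both required
 * relations: s(0) = (0+1)*2^0 = 1, and if s(logn-1) = logn*2^(logn-1),
 * then s(logn) = 2^logn + 2*logn*2^(logn-1) = (logn+1)*2^logn.
 */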

/*
 * Inner function for ffLDL_fft(). It expects the matrix to be both
 * auto-adjoint and quasicyclic; also, it uses the source operands
 * as modifiable temporaries.
 *
 * tmp[] must have room for at least one polynomial.
 */
static void
ffLDL_fft_inner(fpr *restrict tree,
        fpr *restrict g0, fpr *restrict g1, unsigned logn, fpr *restrict tmp)
{
        size_t n, hn;

        n = MKN(logn);
        if (n == 1) {
                tree[0] = g0[0];
                return;
        }
        hn = n >> 1;

        /*
         * The LDL decomposition yields L (which is written in the tree)
         * and the diagonal of D. Since d00 = g0, we just write d11
         * into tmp.
         */
        Zf(poly_LDLmv_fft)(tmp, tree, g0, g1, g0, logn);

        /*
         * Split d00 (currently in g0) and d11 (currently in tmp). We
         * reuse g0 and g1 as temporary storage spaces:
         *   d00 splits into g1, g1+hn
         *   d11 splits into g0, g0+hn
         */
        Zf(poly_split_fft)(g1, g1 + hn, g0, logn);
        Zf(poly_split_fft)(g0, g0 + hn, tmp, logn);

        /*
         * Each split result is the first row of a new auto-adjoint
         * quasicyclic matrix for the next recursive step.
         */
        ffLDL_fft_inner(tree + n,
                g1, g1 + hn, logn - 1, tmp);
        ffLDL_fft_inner(tree + n + ffLDL_treesize(logn - 1),
                g0, g0 + hn, logn - 1, tmp);
}

/*
 * Compute the ffLDL tree of an auto-adjoint matrix G. The matrix
 * is provided as three polynomials (FFT representation).
 *
 * The "tree" array is filled with the computed tree, of size
 * (logn+1)*(2^logn) elements (see ffLDL_treesize()).
 *
 * Input arrays MUST NOT overlap, except possibly the three unmodified
 * arrays g00, g01 and g11. tmp[] should have room for at least three
 * polynomials of 2^logn elements each.
 */
static void
ffLDL_fft(fpr *restrict tree, const fpr *restrict g00,
        const fpr *restrict g01, const fpr *restrict g11,
        unsigned logn, fpr *restrict tmp)
{
        size_t n, hn;
        fpr *d00, *d11;

        n = MKN(logn);
        if (n == 1) {
                tree[0] = g00[0];
                return;
        }
        hn = n >> 1;
        d00 = tmp;
        d11 = tmp + n;
        tmp += n << 1;

        memcpy(d00, g00, n * sizeof *g00);
        Zf(poly_LDLmv_fft)(d11, tree, g00, g01, g11, logn);

        Zf(poly_split_fft)(tmp, tmp + hn, d00, logn);
        Zf(poly_split_fft)(d00, d00 + hn, d11, logn);
        memcpy(d11, tmp, n * sizeof *tmp);
        ffLDL_fft_inner(tree + n,
                d11, d11 + hn, logn - 1, tmp);
        ffLDL_fft_inner(tree + n + ffLDL_treesize(logn - 1),
                d00, d00 + hn, logn - 1, tmp);
}

/*
 * Normalize an ffLDL tree: each leaf of value x is replaced with
 * sigma / sqrt(x).
 */
static void
ffLDL_binary_normalize(fpr *tree, unsigned logn)
{
        /*
         * TODO: make an iterative version.
         */
        size_t n;

        n = MKN(logn);
        if (n == 1) {
                /*
                 * We actually store in the tree leaf the inverse of
                 * the value mandated by the specification: this
                 * saves a division both here and in the sampler.
                 */
                tree[0] = fpr_mul(fpr_sqrt(tree[0]), fpr_inv_sigma);
        } else {
                ffLDL_binary_normalize(tree + n, logn - 1);
                ffLDL_binary_normalize(tree + n + ffLDL_treesize(logn - 1),
                        logn - 1);
        }
}
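
/*
 * (Informative) Per the comment inside the function, the stored leaf
 * value is thus sqrt(x)/sigma, the inverse of the sigma/sqrt(x) value
 * mandated by the specification; this inverse is precisely the format
 * expected by the sampler (see sampler() below, whose third parameter
 * is 1/sigma).
 */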

/* =================================================================== */

/*
 * Convert an integer polynomial (with small values) into the
 * representation with complex numbers.
 */
static void
smallints_to_fpr(fpr *r, const int8_t *t, unsigned logn)
{
        size_t n, u;

        n = MKN(logn);
        for (u = 0; u < n; u ++) {
                r[u] = fpr_of(t[u]);
        }
}

/*
 * The expanded private key contains:
 *  - The B0 matrix (four elements)
 *  - The ffLDL tree
 */

static inline size_t
skoff_b00(unsigned logn)
{
        (void)logn;
        return 0;
}

static inline size_t
skoff_b01(unsigned logn)
{
        return MKN(logn);
}

static inline size_t
skoff_b10(unsigned logn)
{
        return 2 * MKN(logn);
}

static inline size_t
skoff_b11(unsigned logn)
{
        return 3 * MKN(logn);
}

static inline size_t
skoff_tree(unsigned logn)
{
        return 4 * MKN(logn);
}
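
/*
 * (Informative) The B0 matrix thus occupies 4*2^logn slots and the
 * tree occupies (logn+1)*2^logn slots (see ffLDL_treesize()), for a
 * total expanded key size of (logn+5)*2^logn fpr elements.
 */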

/* see inner.h */
void
Zf(expand_privkey)(fpr *restrict expanded_key,
        const int8_t *f, const int8_t *g,
        const int8_t *F, const int8_t *G,
        unsigned logn, uint8_t *restrict tmp)
{
        size_t n;
        fpr *rf, *rg, *rF, *rG;
        fpr *b00, *b01, *b10, *b11;
        fpr *g00, *g01, *g11, *gxx;
        fpr *tree;

        n = MKN(logn);
        b00 = expanded_key + skoff_b00(logn);
        b01 = expanded_key + skoff_b01(logn);
        b10 = expanded_key + skoff_b10(logn);
        b11 = expanded_key + skoff_b11(logn);
        tree = expanded_key + skoff_tree(logn);

        /*
         * We load the private key elements directly into the B0 matrix,
         * since B0 = [[g, -f], [G, -F]].
         */
        rf = b01;
        rg = b00;
        rF = b11;
        rG = b10;

        smallints_to_fpr(rf, f, logn);
        smallints_to_fpr(rg, g, logn);
        smallints_to_fpr(rF, F, logn);
        smallints_to_fpr(rG, G, logn);

        /*
         * Compute the FFT for the key elements, and negate f and F.
         */
        Zf(FFT)(rf, logn);
        Zf(FFT)(rg, logn);
        Zf(FFT)(rF, logn);
        Zf(FFT)(rG, logn);
        Zf(poly_neg)(rf, logn);
        Zf(poly_neg)(rF, logn);

        /*
         * The Gram matrix is G = B·B*. Formulas are:
         *   g00 = b00*adj(b00) + b01*adj(b01)
         *   g01 = b00*adj(b10) + b01*adj(b11)
         *   g10 = b10*adj(b00) + b11*adj(b01)
         *   g11 = b10*adj(b10) + b11*adj(b11)
         *
         * For historical reasons, this implementation uses
         * g00, g01 and g11 (upper triangle).
         */
        g00 = (fpr *)tmp;
        g01 = g00 + n;
        g11 = g01 + n;
        gxx = g11 + n;

        memcpy(g00, b00, n * sizeof *b00);
        Zf(poly_mulselfadj_fft)(g00, logn);
        memcpy(gxx, b01, n * sizeof *b01);
        Zf(poly_mulselfadj_fft)(gxx, logn);
        Zf(poly_add)(g00, gxx, logn);

        memcpy(g01, b00, n * sizeof *b00);
        Zf(poly_muladj_fft)(g01, b10, logn);
        memcpy(gxx, b01, n * sizeof *b01);
        Zf(poly_muladj_fft)(gxx, b11, logn);
        Zf(poly_add)(g01, gxx, logn);

        memcpy(g11, b10, n * sizeof *b10);
        Zf(poly_mulselfadj_fft)(g11, logn);
        memcpy(gxx, b11, n * sizeof *b11);
        Zf(poly_mulselfadj_fft)(gxx, logn);
        Zf(poly_add)(g11, gxx, logn);

        /*
         * Compute the Falcon tree.
         */
        ffLDL_fft(tree, g00, g01, g11, logn, gxx);

        /*
         * Normalize tree.
         */
        ffLDL_binary_normalize(tree, logn);
}

typedef int (*samplerZ)(void *ctx, fpr mu, fpr sigma);
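
/*
 * (Informative) A samplerZ function returns an integer sampled from a
 * discrete Gaussian of center mu; in this file the third parameter
 * actually receives the inverse 1/sigma of the desired standard
 * deviation (see sampler() and ffLDL_binary_normalize()).
 */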

/*
 * Perform Fast Fourier Sampling for target vector t. The Gram matrix
 * is provided (G = [[g00, g01], [adj(g01), g11]]). The sampled vector
 * is written over (t0,t1). The Gram matrix is modified as well. The
 * tmp[] buffer must have room for four polynomials.
 */
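/*
 * (Informative) This is the "dynamic tree" variant of Fast Fourier
 * sampling: the LDL tree is computed on the fly, level by level,
 * instead of being read from a precomputed expanded key (compare with
 * ffSampling_fft() below); this saves RAM at the cost of extra
 * arithmetic.
 */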
TARGET_AVX2
static void
ffSampling_fft_dyntree(samplerZ samp, void *samp_ctx,
        fpr *restrict t0, fpr *restrict t1,
        fpr *restrict g00, fpr *restrict g01, fpr *restrict g11,
        unsigned logn, fpr *restrict tmp)
{
        size_t n, hn;
        fpr *z0, *z1;

        /*
         * Deepest level: the LDL tree leaf value is just g00 (the
         * array has length only 1 at this point); we normalize it
         * with regard to sigma, then use it for sampling.
         */
        if (logn == 0) {
                fpr leaf;

                leaf = g00[0];
                leaf = fpr_mul(fpr_sqrt(leaf), fpr_inv_sigma);
                t0[0] = fpr_of(samp(samp_ctx, t0[0], leaf));
                t1[0] = fpr_of(samp(samp_ctx, t1[0], leaf));
                return;
        }

        n = (size_t)1 << logn;
        hn = n >> 1;

        /*
         * Decompose G into LDL. We only need d00 (identical to g00),
         * d11, and l10; we do that in place.
         */
        Zf(poly_LDL_fft)(g00, g01, g11, logn);

        /*
         * Split d00 and d11 and expand them into half-size quasi-cyclic
         * Gram matrices. We also save l10 in tmp[].
         */
        Zf(poly_split_fft)(tmp, tmp + hn, g00, logn);
        memcpy(g00, tmp, n * sizeof *tmp);
        Zf(poly_split_fft)(tmp, tmp + hn, g11, logn);
        memcpy(g11, tmp, n * sizeof *tmp);
        memcpy(tmp, g01, n * sizeof *g01);
        memcpy(g01, g00, hn * sizeof *g00);
        memcpy(g01 + hn, g11, hn * sizeof *g00);

        /*
         * The half-size Gram matrices for the recursive LDL tree
         * building are now:
         *   - left sub-tree: g00, g00+hn, g01
         *   - right sub-tree: g11, g11+hn, g01+hn
         * l10 is in tmp[].
         */

        /*
         * We split t1 and use the first recursive call on the two
         * halves, using the right sub-tree. The result is merged
         * back into tmp + 2*n.
         */
        z1 = tmp + n;
        Zf(poly_split_fft)(z1, z1 + hn, t1, logn);
        ffSampling_fft_dyntree(samp, samp_ctx, z1, z1 + hn,
                g11, g11 + hn, g01 + hn, logn - 1, z1 + n);
        Zf(poly_merge_fft)(tmp + (n << 1), z1, z1 + hn, logn);

        /*
         * Compute tb0 = t0 + (t1 - z1) * l10.
         * At that point, l10 is in tmp, t1 is unmodified, and z1 is
         * in tmp + (n << 1). The buffer in z1 is free.
         *
         * In the end, z1 is written over t1, and tb0 is in t0.
         */
        memcpy(z1, t1, n * sizeof *t1);
        Zf(poly_sub)(z1, tmp + (n << 1), logn);
        memcpy(t1, tmp + (n << 1), n * sizeof *tmp);
        Zf(poly_mul_fft)(tmp, z1, logn);
        Zf(poly_add)(t0, tmp, logn);

        /*
         * Second recursive invocation, on the split tb0 (currently in t0)
         * and the left sub-tree.
         */
        z0 = tmp;
        Zf(poly_split_fft)(z0, z0 + hn, t0, logn);
        ffSampling_fft_dyntree(samp, samp_ctx, z0, z0 + hn,
                g00, g00 + hn, g01, logn - 1, z0 + n);
        Zf(poly_merge_fft)(t0, z0, z0 + hn, logn);
}

/*
 * Perform Fast Fourier Sampling for target vector t and LDL tree T.
 * tmp[] must have size for at least two polynomials of size 2^logn.
 */
TARGET_AVX2
static void
ffSampling_fft(samplerZ samp, void *samp_ctx,
        fpr *restrict z0, fpr *restrict z1,
        const fpr *restrict tree,
        const fpr *restrict t0, const fpr *restrict t1, unsigned logn,
        fpr *restrict tmp)
{
        size_t n, hn;
        const fpr *tree0, *tree1;

        /*
         * When logn == 2, we inline the last two recursion levels.
         */
        if (logn == 2) {
#if FALCON_AVX2  // yyyAVX2+1
                fpr w0, w1, w2, w3, sigma;
                __m128d ww0, ww1, wa, wb, wc, wd;
                __m128d wy0, wy1, wz0, wz1;
                __m128d half, invsqrt8, invsqrt2, neghi, neglo;
                int si0, si1, si2, si3;

                tree0 = tree + 4;
                tree1 = tree + 8;

                half = _mm_set1_pd(0.5);
                invsqrt8 = _mm_set1_pd(0.353553390593273762200422181052);
                invsqrt2 = _mm_set1_pd(0.707106781186547524400844362105);
                neghi = _mm_set_pd(-0.0, 0.0);
                neglo = _mm_set_pd(0.0, -0.0);

                /*
                 * We split t1 into w*, then do the recursive invocation,
                 * with output in w*. We finally merge back into z1.
                 */
                ww0 = _mm_loadu_pd(&t1[0].v);
                ww1 = _mm_loadu_pd(&t1[2].v);
                wa = _mm_unpacklo_pd(ww0, ww1);
                wb = _mm_unpackhi_pd(ww0, ww1);
                wc = _mm_add_pd(wa, wb);
                ww0 = _mm_mul_pd(wc, half);
                wc = _mm_sub_pd(wa, wb);
                wd = _mm_xor_pd(_mm_permute_pd(wc, 1), neghi);
                ww1 = _mm_mul_pd(_mm_add_pd(wc, wd), invsqrt8);

                w2.v = _mm_cvtsd_f64(ww1);
                w3.v = _mm_cvtsd_f64(_mm_permute_pd(ww1, 1));
                wa = ww1;
                sigma = tree1[3];
                si2 = samp(samp_ctx, w2, sigma);
                si3 = samp(samp_ctx, w3, sigma);
                ww1 = _mm_set_pd((double)si3, (double)si2);
                wa = _mm_sub_pd(wa, ww1);
                wb = _mm_loadu_pd(&tree1[0].v);
                wc = _mm_mul_pd(wa, wb);
                wd = _mm_mul_pd(wa, _mm_permute_pd(wb, 1));
                wa = _mm_unpacklo_pd(wc, wd);
                wb = _mm_unpackhi_pd(wc, wd);
                ww0 = _mm_add_pd(ww0, _mm_add_pd(wa, _mm_xor_pd(wb, neglo)));
                w0.v = _mm_cvtsd_f64(ww0);
                w1.v = _mm_cvtsd_f64(_mm_permute_pd(ww0, 1));
                sigma = tree1[2];
                si0 = samp(samp_ctx, w0, sigma);
                si1 = samp(samp_ctx, w1, sigma);
                ww0 = _mm_set_pd((double)si1, (double)si0);

                wc = _mm_mul_pd(
                        _mm_set_pd((double)(si2 + si3), (double)(si2 - si3)),
                        invsqrt2);
                wa = _mm_add_pd(ww0, wc);
                wb = _mm_sub_pd(ww0, wc);
                ww0 = _mm_unpacklo_pd(wa, wb);
                ww1 = _mm_unpackhi_pd(wa, wb);
                _mm_storeu_pd(&z1[0].v, ww0);
                _mm_storeu_pd(&z1[2].v, ww1);

                /*
                 * Compute tb0 = t0 + (t1 - z1) * L. Value tb0 ends up in w*.
                 */
                wy0 = _mm_sub_pd(_mm_loadu_pd(&t1[0].v), ww0);
                wy1 = _mm_sub_pd(_mm_loadu_pd(&t1[2].v), ww1);
                wz0 = _mm_loadu_pd(&tree[0].v);
                wz1 = _mm_loadu_pd(&tree[2].v);
                ww0 = _mm_sub_pd(_mm_mul_pd(wy0, wz0), _mm_mul_pd(wy1, wz1));
                ww1 = _mm_add_pd(_mm_mul_pd(wy0, wz1), _mm_mul_pd(wy1, wz0));
                ww0 = _mm_add_pd(ww0, _mm_loadu_pd(&t0[0].v));
                ww1 = _mm_add_pd(ww1, _mm_loadu_pd(&t0[2].v));

                /*
                 * Second recursive invocation.
                 */
                wa = _mm_unpacklo_pd(ww0, ww1);
                wb = _mm_unpackhi_pd(ww0, ww1);
                wc = _mm_add_pd(wa, wb);
                ww0 = _mm_mul_pd(wc, half);
                wc = _mm_sub_pd(wa, wb);
                wd = _mm_xor_pd(_mm_permute_pd(wc, 1), neghi);
                ww1 = _mm_mul_pd(_mm_add_pd(wc, wd), invsqrt8);

                w2.v = _mm_cvtsd_f64(ww1);
                w3.v = _mm_cvtsd_f64(_mm_permute_pd(ww1, 1));
                wa = ww1;
                sigma = tree0[3];
                si2 = samp(samp_ctx, w2, sigma);
                si3 = samp(samp_ctx, w3, sigma);
                ww1 = _mm_set_pd((double)si3, (double)si2);
                wa = _mm_sub_pd(wa, ww1);
                wb = _mm_loadu_pd(&tree0[0].v);
                wc = _mm_mul_pd(wa, wb);
                wd = _mm_mul_pd(wa, _mm_permute_pd(wb, 1));
                wa = _mm_unpacklo_pd(wc, wd);
                wb = _mm_unpackhi_pd(wc, wd);
                ww0 = _mm_add_pd(ww0, _mm_add_pd(wa, _mm_xor_pd(wb, neglo)));
                w0.v = _mm_cvtsd_f64(ww0);
                w1.v = _mm_cvtsd_f64(_mm_permute_pd(ww0, 1));
                sigma = tree0[2];
                si0 = samp(samp_ctx, w0, sigma);
                si1 = samp(samp_ctx, w1, sigma);
                ww0 = _mm_set_pd((double)si1, (double)si0);

                wc = _mm_mul_pd(
                        _mm_set_pd((double)(si2 + si3), (double)(si2 - si3)),
                        invsqrt2);
                wa = _mm_add_pd(ww0, wc);
                wb = _mm_sub_pd(ww0, wc);
                ww0 = _mm_unpacklo_pd(wa, wb);
                ww1 = _mm_unpackhi_pd(wa, wb);
                _mm_storeu_pd(&z0[0].v, ww0);
                _mm_storeu_pd(&z0[2].v, ww1);

                return;
#else  // yyyAVX2+0
                fpr x0, x1, y0, y1, w0, w1, w2, w3, sigma;
                fpr a_re, a_im, b_re, b_im, c_re, c_im;

                tree0 = tree + 4;
                tree1 = tree + 8;

                /*
                 * We split t1 into w*, then do the recursive invocation,
                 * with output in w*. We finally merge back into z1.
                 */
                a_re = t1[0];
                a_im = t1[2];
                b_re = t1[1];
                b_im = t1[3];
                c_re = fpr_add(a_re, b_re);
                c_im = fpr_add(a_im, b_im);
                w0 = fpr_half(c_re);
                w1 = fpr_half(c_im);
                c_re = fpr_sub(a_re, b_re);
                c_im = fpr_sub(a_im, b_im);
                w2 = fpr_mul(fpr_add(c_re, c_im), fpr_invsqrt8);
                w3 = fpr_mul(fpr_sub(c_im, c_re), fpr_invsqrt8);

                x0 = w2;
                x1 = w3;
                sigma = tree1[3];
                w2 = fpr_of(samp(samp_ctx, x0, sigma));
                w3 = fpr_of(samp(samp_ctx, x1, sigma));
                a_re = fpr_sub(x0, w2);
                a_im = fpr_sub(x1, w3);
                b_re = tree1[0];
                b_im = tree1[1];
                c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
                c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
                x0 = fpr_add(c_re, w0);
                x1 = fpr_add(c_im, w1);
                sigma = tree1[2];
                w0 = fpr_of(samp(samp_ctx, x0, sigma));
                w1 = fpr_of(samp(samp_ctx, x1, sigma));

                a_re = w0;
                a_im = w1;
                b_re = w2;
                b_im = w3;
                c_re = fpr_mul(fpr_sub(b_re, b_im), fpr_invsqrt2);
                c_im = fpr_mul(fpr_add(b_re, b_im), fpr_invsqrt2);
                z1[0] = w0 = fpr_add(a_re, c_re);
                z1[2] = w2 = fpr_add(a_im, c_im);
                z1[1] = w1 = fpr_sub(a_re, c_re);
                z1[3] = w3 = fpr_sub(a_im, c_im);

                /*
                 * Compute tb0 = t0 + (t1 - z1) * L. Value tb0 ends up in w*.
                 */
                w0 = fpr_sub(t1[0], w0);
                w1 = fpr_sub(t1[1], w1);
                w2 = fpr_sub(t1[2], w2);
                w3 = fpr_sub(t1[3], w3);

                a_re = w0;
                a_im = w2;
                b_re = tree[0];
                b_im = tree[2];
                w0 = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
                w2 = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
                a_re = w1;
                a_im = w3;
                b_re = tree[1];
                b_im = tree[3];
                w1 = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
                w3 = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));

                w0 = fpr_add(w0, t0[0]);
                w1 = fpr_add(w1, t0[1]);
                w2 = fpr_add(w2, t0[2]);
                w3 = fpr_add(w3, t0[3]);

                /*
                 * Second recursive invocation.
                 */
                a_re = w0;
                a_im = w2;
                b_re = w1;
                b_im = w3;
                c_re = fpr_add(a_re, b_re);
                c_im = fpr_add(a_im, b_im);
                w0 = fpr_half(c_re);
                w1 = fpr_half(c_im);
                c_re = fpr_sub(a_re, b_re);
                c_im = fpr_sub(a_im, b_im);
                w2 = fpr_mul(fpr_add(c_re, c_im), fpr_invsqrt8);
                w3 = fpr_mul(fpr_sub(c_im, c_re), fpr_invsqrt8);

                x0 = w2;
                x1 = w3;
                sigma = tree0[3];
                w2 = y0 = fpr_of(samp(samp_ctx, x0, sigma));
                w3 = y1 = fpr_of(samp(samp_ctx, x1, sigma));
                a_re = fpr_sub(x0, y0);
                a_im = fpr_sub(x1, y1);
                b_re = tree0[0];
                b_im = tree0[1];
                c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
                c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
                x0 = fpr_add(c_re, w0);
                x1 = fpr_add(c_im, w1);
                sigma = tree0[2];
                w0 = fpr_of(samp(samp_ctx, x0, sigma));
                w1 = fpr_of(samp(samp_ctx, x1, sigma));

                a_re = w0;
                a_im = w1;
                b_re = w2;
                b_im = w3;
                c_re = fpr_mul(fpr_sub(b_re, b_im), fpr_invsqrt2);
                c_im = fpr_mul(fpr_add(b_re, b_im), fpr_invsqrt2);
                z0[0] = fpr_add(a_re, c_re);
                z0[2] = fpr_add(a_im, c_im);
                z0[1] = fpr_sub(a_re, c_re);
                z0[3] = fpr_sub(a_im, c_im);

                return;
#endif  // yyyAVX2-
        }

        /*
         * Case logn == 1 is reachable only when using Falcon-2 (the
         * smallest size for which Falcon is mathematically defined, but
         * of course way too insecure to be of any use).
         */
        if (logn == 1) {
                fpr x0, x1, y0, y1, sigma;
                fpr a_re, a_im, b_re, b_im, c_re, c_im;

                x0 = t1[0];
                x1 = t1[1];
                sigma = tree[3];
                z1[0] = y0 = fpr_of(samp(samp_ctx, x0, sigma));
                z1[1] = y1 = fpr_of(samp(samp_ctx, x1, sigma));
                a_re = fpr_sub(x0, y0);
                a_im = fpr_sub(x1, y1);
                b_re = tree[0];
                b_im = tree[1];
                c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
                c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
                x0 = fpr_add(c_re, t0[0]);
                x1 = fpr_add(c_im, t0[1]);
                sigma = tree[2];
                z0[0] = fpr_of(samp(samp_ctx, x0, sigma));
                z0[1] = fpr_of(samp(samp_ctx, x1, sigma));

                return;
        }

        /*
         * Normal end of recursion is for logn == 0. Since the last
         * steps of the recursion were inlined in the blocks above
         * (when logn == 1 or 2), this case is not reachable, and is
         * retained here only for documentation purposes.

        if (logn == 0) {
                fpr x0, x1, sigma;

                x0 = t0[0];
                x1 = t1[0];
                sigma = tree[0];
                z0[0] = fpr_of(samp(samp_ctx, x0, sigma));
                z1[0] = fpr_of(samp(samp_ctx, x1, sigma));
                return;
        }

         */

        /*
         * General recursive case (logn >= 3).
         */

        n = (size_t)1 << logn;
        hn = n >> 1;
        tree0 = tree + n;
        tree1 = tree + n + ffLDL_treesize(logn - 1);

        /*
         * We split t1 into z1 (reused as temporary storage), then do
         * the recursive invocation, with output in tmp. We finally
         * merge back into z1.
         */
        Zf(poly_split_fft)(z1, z1 + hn, t1, logn);
        ffSampling_fft(samp, samp_ctx, tmp, tmp + hn,
                tree1, z1, z1 + hn, logn - 1, tmp + n);
        Zf(poly_merge_fft)(z1, tmp, tmp + hn, logn);

        /*
         * Compute tb0 = t0 + (t1 - z1) * L. Value tb0 ends up in tmp[].
         */
        memcpy(tmp, t1, n * sizeof *t1);
        Zf(poly_sub)(tmp, z1, logn);
        Zf(poly_mul_fft)(tmp, tree, logn);
        Zf(poly_add)(tmp, t0, logn);

        /*
         * Second recursive invocation.
         */
        Zf(poly_split_fft)(z0, z0 + hn, tmp, logn);
        ffSampling_fft(samp, samp_ctx, tmp, tmp + hn,
                tree0, z0, z0 + hn, logn - 1, tmp + n);
        Zf(poly_merge_fft)(z0, tmp, tmp + hn, logn);
}

/*
 * Compute a signature: the signature contains two vectors, s1 and s2.
 * The s1 vector is not returned. The squared norm of (s1,s2) is
 * computed, and if it is short enough, then s2 is returned into the
 * s2[] buffer, and 1 is returned; otherwise, s2[] is untouched and 0 is
 * returned; the caller should then try again. This function uses an
 * expanded key.
 *
 * tmp[] must have room for at least six polynomials.
 */
static int
do_sign_tree(samplerZ samp, void *samp_ctx, int16_t *s2,
        const fpr *restrict expanded_key,
        const uint16_t *hm,
        unsigned logn, fpr *restrict tmp)
{
        size_t n, u;
        fpr *t0, *t1, *tx, *ty;
        const fpr *b00, *b01, *b10, *b11, *tree;
        fpr ni;
        uint32_t sqn, ng;
        int16_t *s2tmp;

        n = MKN(logn);
        t0 = tmp;
        t1 = t0 + n;
        b00 = expanded_key + skoff_b00(logn);
        b01 = expanded_key + skoff_b01(logn);
        b10 = expanded_key + skoff_b10(logn);
        b11 = expanded_key + skoff_b11(logn);
        tree = expanded_key + skoff_tree(logn);

        /*
         * Set the target vector to [hm, 0] (hm is the hashed message).
         */
        for (u = 0; u < n; u ++) {
                t0[u] = fpr_of(hm[u]);
                /* This is implicit.
                t1[u] = fpr_zero;
                */
        }

        /*
         * Apply the lattice basis to obtain the real target
         * vector (after normalization with regard to the modulus).
         */
        Zf(FFT)(t0, logn);
        ni = fpr_inverse_of_q;
        memcpy(t1, t0, n * sizeof *t0);
        Zf(poly_mul_fft)(t1, b01, logn);
        Zf(poly_mulconst)(t1, fpr_neg(ni), logn);
        Zf(poly_mul_fft)(t0, b11, logn);
        Zf(poly_mulconst)(t0, ni, logn);

        tx = t1 + n;
        ty = tx + n;

        /*
         * Apply sampling. Output is written back in [tx, ty].
         */
        ffSampling_fft(samp, samp_ctx, tx, ty, tree, t0, t1, logn, ty + n);

        /*
         * Get the lattice point corresponding to that tiny vector.
         */
        memcpy(t0, tx, n * sizeof *tx);
        memcpy(t1, ty, n * sizeof *ty);
        Zf(poly_mul_fft)(tx, b00, logn);
        Zf(poly_mul_fft)(ty, b10, logn);
        Zf(poly_add)(tx, ty, logn);
        memcpy(ty, t0, n * sizeof *t0);
        Zf(poly_mul_fft)(ty, b01, logn);

        memcpy(t0, tx, n * sizeof *tx);
        Zf(poly_mul_fft)(t1, b11, logn);
        Zf(poly_add)(t1, ty, logn);

        Zf(iFFT)(t0, logn);
        Zf(iFFT)(t1, logn);

        /*
         * Compute the signature.
         */
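        /*
         * s1 = hm - round(t0). We accumulate the squared norm of s1 in
         * sqn; ng collects the OR of all partial sums, so bit 31 of ng
         * is set whenever some partial sum had bit 31 set, in which
         * case sqn is saturated to 2^32-1 below (such a vector is
         * never short enough anyway).
         */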
        sqn = 0;
        ng = 0;
        for (u = 0; u < n; u ++) {
                int32_t z;

                z = (int32_t)hm[u] - (int32_t)fpr_rint(t0[u]);
                sqn += (uint32_t)(z * z);
                ng |= sqn;
        }
        sqn |= -(ng >> 31);

        /*
         * With "normal" degrees (e.g. 512 or 1024), it is very
         * improbable that the computed vector is not short enough;
         * however, it may happen in practice for the very reduced
         * versions (e.g. degree 16 or below). In that case, the caller
         * will loop, and we must not write anything into s2[] because
         * s2[] may overlap with the hashed message hm[] and we need
         * hm[] for the next iteration.
         */
        s2tmp = (int16_t *)tmp;
        for (u = 0; u < n; u ++) {
                s2tmp[u] = (int16_t)-fpr_rint(t1[u]);
        }
        if (Zf(is_short_half)(sqn, s2tmp, logn)) {
                memcpy(s2, s2tmp, n * sizeof *s2);
                return 1;
        }
        return 0;
}

/*
 * Compute a signature: the signature contains two vectors, s1 and s2.
 * The s1 vector is not returned. The squared norm of (s1,s2) is
 * computed, and if it is short enough, then s2 is returned into the
 * s2[] buffer, and 1 is returned; otherwise, s2[] is untouched and 0 is
 * returned; the caller should then try again.
 *
 * tmp[] must have room for at least nine polynomials.
 */
static int
do_sign_dyn(samplerZ samp, void *samp_ctx, int16_t *s2,
        const int8_t *restrict f, const int8_t *restrict g,
        const int8_t *restrict F, const int8_t *restrict G,
        const uint16_t *hm, unsigned logn, fpr *restrict tmp)
{
        size_t n, u;
        fpr *t0, *t1, *tx, *ty;
        fpr *b00, *b01, *b10, *b11, *g00, *g01, *g11;
        fpr ni;
        uint32_t sqn, ng;
        int16_t *s2tmp;

        n = MKN(logn);

        /*
         * Lattice basis is B = [[g, -f], [G, -F]]. We convert it to FFT.
         */
        b00 = tmp;
        b01 = b00 + n;
        b10 = b01 + n;
        b11 = b10 + n;
        smallints_to_fpr(b01, f, logn);
        smallints_to_fpr(b00, g, logn);
        smallints_to_fpr(b11, F, logn);
        smallints_to_fpr(b10, G, logn);
        Zf(FFT)(b01, logn);
        Zf(FFT)(b00, logn);
        Zf(FFT)(b11, logn);
        Zf(FFT)(b10, logn);
        Zf(poly_neg)(b01, logn);
        Zf(poly_neg)(b11, logn);

        /*
         * Compute the Gram matrix G = B·B*. Formulas are:
         *   g00 = b00*adj(b00) + b01*adj(b01)
         *   g01 = b00*adj(b10) + b01*adj(b11)
         *   g10 = b10*adj(b00) + b11*adj(b01)
         *   g11 = b10*adj(b10) + b11*adj(b11)
         *
         * For historical reasons, this implementation uses
         * g00, g01 and g11 (upper triangle). g10 is not kept
         * since it is equal to adj(g01).
         *
         * We _replace_ the matrix B with the Gram matrix, but we
         * must keep b01 and b11 for computing the target vector.
         */
        t0 = b11 + n;
        t1 = t0 + n;

        memcpy(t0, b01, n * sizeof *b01);
        Zf(poly_mulselfadj_fft)(t0, logn);    // t0 <- b01*adj(b01)

        memcpy(t1, b00, n * sizeof *b00);
        Zf(poly_muladj_fft)(t1, b10, logn);   // t1 <- b00*adj(b10)
        Zf(poly_mulselfadj_fft)(b00, logn);   // b00 <- b00*adj(b00)
        Zf(poly_add)(b00, t0, logn);      // b00 <- g00
        memcpy(t0, b01, n * sizeof *b01);
        Zf(poly_muladj_fft)(b01, b11, logn);  // b01 <- b01*adj(b11)
        Zf(poly_add)(b01, t1, logn);      // b01 <- g01

        Zf(poly_mulselfadj_fft)(b10, logn);   // b10 <- b10*adj(b10)
        memcpy(t1, b11, n * sizeof *b11);
        Zf(poly_mulselfadj_fft)(t1, logn);    // t1 <- b11*adj(b11)
        Zf(poly_add)(b10, t1, logn);      // b10 <- g11

        /*
         * We rename variables to make things clearer. The three elements
         * of the Gram matrix use the first 3*n slots of tmp[], followed
         * by b11 and b01 (in that order).
         */
        g00 = b00;
        g01 = b01;
        g11 = b10;
        b01 = t0;
        t0 = b01 + n;
        t1 = t0 + n;

        /*
         * Memory layout at that point:
         *   g00 g01 g11 b11 b01 t0 t1
         */

        /*
         * Set the target vector to [hm, 0] (hm is the hashed message).
         */
        for (u = 0; u < n; u ++) {
                t0[u] = fpr_of(hm[u]);
                /* This is implicit.
                t1[u] = fpr_zero;
                */
        }

        /*
         * Apply the lattice basis to obtain the real target
         * vector (after normalization with regard to the modulus).
         */
        Zf(FFT)(t0, logn);
        ni = fpr_inverse_of_q;
        memcpy(t1, t0, n * sizeof *t0);
        Zf(poly_mul_fft)(t1, b01, logn);
        Zf(poly_mulconst)(t1, fpr_neg(ni), logn);
        Zf(poly_mul_fft)(t0, b11, logn);
        Zf(poly_mulconst)(t0, ni, logn);

        /*
         * b01 and b11 can be discarded, so we move back (t0,t1).
         * Memory layout is now:
         *      g00 g01 g11 t0 t1
         */
        memcpy(b11, t0, n * 2 * sizeof *t0);
        t0 = g11 + n;
        t1 = t0 + n;

        /*
         * Apply sampling; result is written over (t0,t1).
         */
        ffSampling_fft_dyntree(samp, samp_ctx,
                t0, t1, g00, g01, g11, logn, t1 + n);

        /*
         * We arrange the layout back to:
         *     b00 b01 b10 b11 t0 t1
         *
         * We did not preserve the basis matrix B, so we must recompute
         * it now.
         */
        b00 = tmp;
        b01 = b00 + n;
        b10 = b01 + n;
        b11 = b10 + n;
        memmove(b11 + n, t0, n * 2 * sizeof *t0);
        t0 = b11 + n;
        t1 = t0 + n;
        smallints_to_fpr(b01, f, logn);
        smallints_to_fpr(b00, g, logn);
        smallints_to_fpr(b11, F, logn);
        smallints_to_fpr(b10, G, logn);
        Zf(FFT)(b01, logn);
        Zf(FFT)(b00, logn);
        Zf(FFT)(b11, logn);
        Zf(FFT)(b10, logn);
        Zf(poly_neg)(b01, logn);
        Zf(poly_neg)(b11, logn);
        tx = t1 + n;
        ty = tx + n;

        /*
         * Get the lattice point corresponding to that tiny vector.
         */
        memcpy(tx, t0, n * sizeof *t0);
        memcpy(ty, t1, n * sizeof *t1);
        Zf(poly_mul_fft)(tx, b00, logn);
        Zf(poly_mul_fft)(ty, b10, logn);
        Zf(poly_add)(tx, ty, logn);
        memcpy(ty, t0, n * sizeof *t0);
        Zf(poly_mul_fft)(ty, b01, logn);

        memcpy(t0, tx, n * sizeof *tx);
        Zf(poly_mul_fft)(t1, b11, logn);
        Zf(poly_add)(t1, ty, logn);
        Zf(iFFT)(t0, logn);
        Zf(iFFT)(t1, logn);
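
        /*
         * Compute the squared norm of s1 = hm - round(t0), with the
         * same overflow detection and saturation as in do_sign_tree().
         */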
        sqn = 0;
        ng = 0;
        for (u = 0; u < n; u ++) {
                int32_t z;

                z = (int32_t)hm[u] - (int32_t)fpr_rint(t0[u]);
                sqn += (uint32_t)(z * z);
                ng |= sqn;
        }
        sqn |= -(ng >> 31);

        /*
         * With "normal" degrees (e.g. 512 or 1024), it is very
         * improbable that the computed vector is not short enough;
         * however, it may happen in practice for the very reduced
         * versions (e.g. degree 16 or below). In that case, the caller
         * will loop, and we must not write anything into s2[] because
         * s2[] may overlap with the hashed message hm[] and we need
         * hm[] for the next iteration.
         */
        s2tmp = (int16_t *)tmp;
        for (u = 0; u < n; u ++) {
                s2tmp[u] = (int16_t)-fpr_rint(t1[u]);
        }
        if (Zf(is_short_half)(sqn, s2tmp, logn)) {
                memcpy(s2, s2tmp, n * sizeof *s2);
                return 1;
        }
        return 0;
}

/*
 * Sample an integer value along a half-Gaussian distribution centered
 * on zero, with standard deviation 1.8205 and a precision of 72 bits.
 */
TARGET_AVX2
static int
gaussian0_sampler(prng *p)
{
#if FALCON_AVX2 // yyyAVX2+1

        /*
         * High words.
         */
        static const union {
                uint16_t u16[16];
                __m256i ymm[1];
        } rhi15 = {
                {
                        0x2E04, 0x2792, 0x192A, 0x0BD6,
                        0x041D, 0x010F, 0x0033, 0x0007,
                        0x0000, 0x0000, 0x0000, 0x0000,
                        0x0000, 0x0000, 0x0000, 0x0000
                }
        };

        static const union {
                uint64_t u64[20];
                __m256i ymm[5];
        } rlo57 = {
                {
                        0x00BD12C53C6E7FE, 0x0C916B46CBB3C80,
                        0x15D478BF7541983, 0x0588D1B8090771B,
                        0x1F0D6C8D45A2B75, 0x039A55A05CA6210,
                        0x125EA00C1FD527B, 0x082DD77C5CA070A,
                        0x1827F03DB31FFB2, 0x01DBCDAD87757BF,
                        0x001B12E8D33DA6E, 0x000123A854528DC,
                        0x0000091392D5F87, 0x0000003579E3ADD,
                        0x00000000E90113D, 0x0000000002EECE6,
                        0x000000000006FD2, 0x0000000000000C5,
                        0x000000000000001, 0x000000000000000
                }
        };

        uint64_t lo;
        unsigned hi;
        __m256i xhi, rhi, gthi, eqhi, eqm;
        __m256i xlo, gtlo0, gtlo1, gtlo2, gtlo3, gtlo4;
        __m128i t, zt;
        int r;

        /*
         * Get a 72-bit random value and split it into a low part
         * (57 bits) and a high part (15 bits).
         */
        lo = prng_get_u64(p);
        hi = prng_get_u8(p);
        hi = (hi << 7) | (unsigned)(lo >> 57);
        lo &= 0x1FFFFFFFFFFFFFF;

        /*
         * Broadcast the high part and compare it with the relevant
         * values. We need both "greater than" and "equal" comparisons.
         */
        xhi = _mm256_broadcastw_epi16(_mm_cvtsi32_si128(hi));
        rhi = _mm256_loadu_si256(&rhi15.ymm[0]);
        gthi = _mm256_cmpgt_epi16(rhi, xhi);
        eqhi = _mm256_cmpeq_epi16(rhi, xhi);

        /*
         * The result is the number of 72-bit values (among the list of 19)
         * which are greater than the 72-bit random value. We first count
         * all non-zero 16-bit elements in the first eight elements of
         * gthi. Such elements have value -1 or 0, so we negate them.
         */
        t = _mm_srli_epi16(_mm256_castsi256_si128(gthi), 15);
        zt = _mm_setzero_si128();
        t = _mm_hadd_epi16(t, zt);
        t = _mm_hadd_epi16(t, zt);
        t = _mm_hadd_epi16(t, zt);
        r = _mm_cvtsi128_si32(t);

        /*
         * We must look at the low bits for all values for which the
         * high bits are an "equal" match; values 8-18 all have the
         * same high bits (0).
         * On 32-bit systems, 'lo' really is two registers, requiring
         * some extra code.
         */
#if defined(__x86_64__) || defined(_M_X64)
        xlo = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(*(int64_t *)&lo));
#else
        {
                uint32_t e0, e1;
                int32_t f0, f1;

                e0 = (uint32_t)lo;
                e1 = (uint32_t)(lo >> 32);
                f0 = *(int32_t *)&e0;
                f1 = *(int32_t *)&e1;
                xlo = _mm256_set_epi32(f1, f0, f1, f0, f1, f0, f1, f0);
        }
#endif
        gtlo0 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[0]), xlo);
        gtlo1 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[1]), xlo);
        gtlo2 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[2]), xlo);
        gtlo3 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[3]), xlo);
        gtlo4 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[4]), xlo);

        /*
         * Keep only comparison results that correspond to the non-zero
         * elements in eqhi.
         */
        gtlo0 = _mm256_and_si256(gtlo0, _mm256_cvtepi16_epi64(
                _mm256_castsi256_si128(eqhi)));
        gtlo1 = _mm256_and_si256(gtlo1, _mm256_cvtepi16_epi64(
                _mm256_castsi256_si128(_mm256_bsrli_epi128(eqhi, 8))));
        eqm = _mm256_permute4x64_epi64(eqhi, 0xFF);
        gtlo2 = _mm256_and_si256(gtlo2, eqm);
        gtlo3 = _mm256_and_si256(gtlo3, eqm);
        gtlo4 = _mm256_and_si256(gtlo4, eqm);

        /*
         * Add all values to count the total number of "-1" elements.
         * Since the first eight "high" words are all different, only
         * one element (at most) in gtlo0:gtlo1 can be non-zero; however,
         * if the high word of the random value is zero, then many
         * elements of gtlo2:gtlo3:gtlo4 can be non-zero.
         */
        gtlo0 = _mm256_or_si256(gtlo0, gtlo1);
        gtlo0 = _mm256_add_epi64(
                _mm256_add_epi64(gtlo0, gtlo2),
                _mm256_add_epi64(gtlo3, gtlo4));
        t = _mm_add_epi64(
                _mm256_castsi256_si128(gtlo0),
                _mm256_extracti128_si256(gtlo0, 1));
        t = _mm_add_epi64(t, _mm_srli_si128(t, 8));
        r -= _mm_cvtsi128_si32(t);

        return r;

#else // yyyAVX2+0

        static const uint32_t dist[] = {
                 6031371U, 13708371U, 13035518U,
                 5186761U,  1487980U, 12270720U,
                 3298653U,  4688887U,  5511555U,
                 1551448U,  9247616U,  9467675U,
                  539632U, 14076116U,  5909365U,
                  138809U, 10836485U, 13263376U,
                   26405U, 15335617U, 16601723U,
                    3714U, 14514117U, 13240074U,
                     386U,  8324059U,  3276722U,
                      29U, 12376792U,  7821247U,
                       1U, 11611789U,  3398254U,
                       0U,  1194629U,  4532444U,
                       0U,    37177U,  2973575U,
                       0U,      855U, 10369757U,
                       0U,       14U,  9441597U,
                       0U,        0U,  3075302U,
                       0U,        0U,    28626U,
                       0U,        0U,      197U,
                       0U,        0U,        1U
        };
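
        /*
         * (Informative) Each table row above is one 72-bit cumulative
         * probability split into three 24-bit limbs, most significant
         * limb first; the loop below performs a 72-bit subtract-with-
         * borrow against each row, in constant time.
         */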

        uint32_t v0, v1, v2, hi;
        uint64_t lo;
        size_t u;
        int z;

        /*
         * Get a random 72-bit value, into three 24-bit limbs v0..v2.
         */
        lo = prng_get_u64(p);
        hi = prng_get_u8(p);
        v0 = (uint32_t)lo & 0xFFFFFF;
        v1 = (uint32_t)(lo >> 24) & 0xFFFFFF;
        v2 = (uint32_t)(lo >> 48) | (hi << 16);

        /*
         * Sampled value is z, such that v0..v2 is lower than the first
         * z elements of the table.
         */
        z = 0;
        for (u = 0; u < (sizeof dist) / sizeof(dist[0]); u += 3) {
                uint32_t w0, w1, w2, cc;

                w0 = dist[u + 2];
                w1 = dist[u + 1];
                w2 = dist[u + 0];
                cc = (v0 - w0) >> 31;
                cc = (v1 - w1 - cc) >> 31;
                cc = (v2 - w2 - cc) >> 31;
                z += (int)cc;
        }
        return z;

#endif // yyyAVX2-
}

/*
 * Sample a bit with probability exp(-x) for some x >= 0.
 */
TARGET_AVX2
static int
BerExp(prng *p, fpr x)
{
        int s, i;
        fpr r;
        uint32_t sw, w;
        uint64_t z;

        /*
         * Reduce x modulo log(2): x = s*log(2) + r, with s an integer,
         * and 0 <= r < log(2). Since x >= 0, we can use fpr_trunc().
         */
        s = (int)fpr_trunc(fpr_mul(x, fpr_inv_log2));
        r = fpr_sub(x, fpr_mul(fpr_of(s), fpr_log2));

        /*
         * It may happen (quite rarely) that s >= 64; if sigma = 1.2
         * (the minimum value for sigma), r = 0 and b = 1, then we get
         * s >= 64 if the half-Gaussian produced a z >= 13, which happens
         * with probability about 0.000000000230383991, which is
         * approximately equal to 2^(-32). In any case, if s >= 64,
         * then BerExp will be non-zero with probability less than
         * 2^(-64), so we can simply saturate s at 63.
         */
        sw = (uint32_t)s;
        sw ^= (sw ^ 63) & -((63 - sw) >> 31);
        s = (int)sw;

        /*
         * Compute exp(-r); we know that 0 <= r < log(2) at this point, so
         * we can use fpr_expm_p63(), which yields a result scaled to 2^63.
         * We scale it up to 2^64, then right-shift it by s bits because
         * we really want exp(-x) = 2^(-s)*exp(-r).
         *
         * The "-1" operation makes sure that the value fits on 64 bits
         * (i.e. if r = 0, we may get 2^64, and we prefer 2^64-1 in that
         * case). The bias is negligible since fpr_expm_p63() only computes
         * with 51 bits of precision or so.
         */
        z = ((fpr_expm_p63(r) << 1) - 1) >> s;

        /*
         * Sample a bit with probability exp(-x). Since x = s*log(2) + r,
         * exp(-x) = 2^-s * exp(-r), we compare lazily exp(-x) with the
         * PRNG output to limit its consumption; the sign of the difference
         * yields the expected result.
         */
        i = 64;
        do {
                i -= 8;
                w = prng_get_u8(p) - ((uint32_t)(z >> i) & 0xFF);
        } while (!w && i > 0);
        return (int)(w >> 31);
}
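
/*
 * (Informative) Worked example: for x = 3*log(2), we get s = 3 and
 * r = 0, so z = (2^64 - 1) >> 3 (about 2^61). The lazy comparison loop
 * returns 1 when the sampled 64-bit value is lower than z, i.e. with
 * probability almost exactly 2^(-3) = exp(-x).
 */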

typedef struct {
        prng p;
        fpr sigma_min;
} sampler_context;

/*
 * The sampler produces a random integer that follows a discrete Gaussian
 * distribution, centered on mu, and with standard deviation sigma. The
 * provided parameter isigma is equal to 1/sigma.
 *
 * The value of sigma MUST lie between 1 and 2 (i.e. isigma lies between
 * 0.5 and 1); in Falcon, sigma should always be between 1.2 and 1.9.
 */
TARGET_AVX2
static int
sampler(void *ctx, fpr mu, fpr isigma)
{
        sampler_context *spc;
        int s;
        fpr r, dss, ccs;

        spc = ctx;

        /*
         * Center is mu. We compute mu = s + r where s is an integer
         * and 0 <= r < 1.
         */
        s = (int)fpr_floor(mu);
        r = fpr_sub(mu, fpr_of(s));

        /*
         * dss = 1/(2*sigma^2) = 0.5*(isigma^2).
         */
        dss = fpr_half(fpr_sqr(isigma));

        /*
         * ccs = sigma_min / sigma = sigma_min * isigma.
         */
        ccs = fpr_mul(isigma, spc->sigma_min);

        /*
         * We now need to sample on center r.
         */
        for (;;) {
                int z0, z, b;
                fpr x;

                /*
                 * Sample z from a Gaussian distribution. Then get a
                 * random bit b to turn the sampling into a bimodal
                 * distribution: if b = 1, we use z+1, otherwise we
                 * use -z. We thus have two situations:
                 *
                 *  - b = 1: z >= 1 and sampled against a Gaussian
                 *    centered on 1.
                 *  - b = 0: z <= 0 and sampled against a Gaussian
                 *    centered on 0.
                 */
                z0 = gaussian0_sampler(&spc->p);
                b = prng_get_u8(&spc->p) & 1;
                z = b + ((b << 1) - 1) * z0;

                /*
                 * Rejection sampling. We want a Gaussian centered on r;
                 * but we sampled against a Gaussian centered on b (0 or
                 * 1). But we know that z is always in the range where
                 * our sampling distribution is greater than the Gaussian
                 * distribution, so rejection works.
                 *
                 * We got z with distribution:
                 *    G(z) = exp(-((z-b)^2)/(2*sigma0^2))
                 * We target distribution:
                 *    S(z) = exp(-((z-r)^2)/(2*sigma^2))
                 * Rejection sampling works by keeping the value z with
                 * probability S(z)/G(z), and starting again otherwise.
                 * This requires S(z) <= G(z), which is the case here.
                 * Thus, we simply need to keep our z with probability:
                 *    P = exp(-x)
                 * where:
                 *    x = ((z-r)^2)/(2*sigma^2) - ((z-b)^2)/(2*sigma0^2)
                 *
                 * Here, we scale up the Bernoulli distribution, which
                 * makes rejection more probable, but makes the rejection
                 * rate sufficiently decorrelated from the Gaussian
                 * center and standard deviation that the whole sampler
                 * can be said to be constant-time.
                 */
                x = fpr_mul(fpr_sqr(fpr_sub(fpr_of(z), r)), dss);
                x = fpr_sub(x, fpr_mul(fpr_of(z0 * z0), fpr_inv_2sqrsigma0));
                x = fpr_mul(x, ccs);
                if (BerExp(&spc->p, x)) {
                        /*
                         * Rejection sampling was centered on r, but the
                         * actual center is mu = s + r.
                         */
                        return s + z;
                }
        }
}

/* see inner.h */
void
Zf(sign_tree)(int16_t *sig, inner_shake256_context *rng,
        const fpr *restrict expanded_key,
        const uint16_t *hm, unsigned logn, uint8_t *tmp)
{
        fpr *ftmp;

        ftmp = (fpr *)tmp;
        for (;;) {
                /*
                 * Signature produces short vectors s1 and s2. The
                 * signature is acceptable only if the aggregate vector
                 * s1,s2 is short; we must use the same bound as the
                 * verifier.
                 *
                 * If the signature is acceptable, then we return only s2
                 * (the verifier recomputes s1 from s2, the hashed message,
                 * and the public key).
                 */
                sampler_context spc;
                samplerZ samp;
                void *samp_ctx;

                /*
                 * Normal sampling. We use a fast PRNG seeded from our
                 * SHAKE context ('rng').
                 */
                spc.sigma_min = (logn == 10)
                        ? fpr_sigma_min_10
                        : fpr_sigma_min_9;
                Zf(prng_init)(&spc.p, rng);
                samp = sampler;
                samp_ctx = &spc;

                /*
                 * Do the actual signature.
                 */
                if (do_sign_tree(samp, samp_ctx, sig,
                        expanded_key, hm, logn, ftmp))
                {
                        break;
                }
        }
}

/* see inner.h */
void
Zf(sign_dyn)(int16_t *sig, inner_shake256_context *rng,
        const int8_t *restrict f, const int8_t *restrict g,
        const int8_t *restrict F, const int8_t *restrict G,
        const uint16_t *hm, unsigned logn, uint8_t *tmp)
{
        fpr *ftmp;

        ftmp = (fpr *)tmp;
        for (;;) {
                /*
                 * Signature produces short vectors s1 and s2. The
                 * signature is acceptable only if the aggregate vector
                 * s1,s2 is short; we must use the same bound as the
                 * verifier.
                 *
                 * If the signature is acceptable, then we return only s2
                 * (the verifier recomputes s1 from s2, the hashed message,
                 * and the public key).
                 */
                sampler_context spc;
                samplerZ samp;
                void *samp_ctx;

                /*
                 * Normal sampling. We use a fast PRNG seeded from our
                 * SHAKE context ('rng').
                 */
                spc.sigma_min = (logn == 10)
                        ? fpr_sigma_min_10
                        : fpr_sigma_min_9;
                Zf(prng_init)(&spc.p, rng);
                samp = sampler;
                samp_ctx = &spc;

                /*
                 * Do the actual signature.
                 */
                if (do_sign_dyn(samp, samp_ctx, sig,
                        f, g, F, G, hm, logn, ftmp))
                {
                        break;
                }
        }
}
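
/*
 * (Informative) Minimal usage sketch for Zf(sign_dyn)(), assuming the
 * message has already been hashed to a point hm[] (see the
 * hash-to-point functions in common.c), and with TMP_BYTES as a
 * placeholder for the tmp[] size documented in inner.h (TMP_BYTES is
 * not part of this API):
 *
 *   inner_shake256_context rng;   // seeded with a strong random seed
 *   uint16_t hm[512];             // hashed message point, logn = 9
 *   int16_t sig[512];             // receives s2
 *   uint8_t tmp[TMP_BYTES];
 *
 *   Zf(sign_dyn)(sig, &rng, f, g, F, G, hm, 9, tmp);
 *
 * The restart-on-long-vector loop is internal to Zf(sign_dyn)(), so a
 * single call returns a valid s2.
 */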