Falcon source files (reference implementation)


sign.c

    1 /*
    2  * Falcon signature generation.
    3  *
    4  * ==========================(LICENSE BEGIN)============================
    5  *
    6  * Copyright (c) 2017-2019  Falcon Project
    7  *
    8  * Permission is hereby granted, free of charge, to any person obtaining
    9  * a copy of this software and associated documentation files (the
   10  * "Software"), to deal in the Software without restriction, including
   11  * without limitation the rights to use, copy, modify, merge, publish,
   12  * distribute, sublicense, and/or sell copies of the Software, and to
   13  * permit persons to whom the Software is furnished to do so, subject to
   14  * the following conditions:
   15  *
   16  * The above copyright notice and this permission notice shall be
   17  * included in all copies or substantial portions of the Software.
   18  *
   19  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
   20  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
   21  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
   22  * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
   23  * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
   24  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
   25  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
   26  *
   27  * ===========================(LICENSE END)=============================
   28  *
   29  * @author   Thomas Pornin <thomas.pornin@nccgroup.com>
   30  */
   31 
   32 #include "inner.h"
   33 
   34 /* =================================================================== */
   35 
   36 /*
   37  * Compute degree N from logarithm 'logn'.
   38  */
   39 #define MKN(logn)   ((size_t)1 << (logn))
   40 
   41 /* =================================================================== */
   42 /*
   43  * Binary case:
   44  *   N = 2^logn
   45  *   phi = X^N+1
   46  */
   47 
   48 /*
   49  * Get the size of the LDL tree for an input with polynomials of size
   50  * 2^logn. The size is expressed in the number of elements.
   51  */
   52 static inline unsigned
   53 ffLDL_treesize(unsigned logn)
   54 {
   55         /*
   56          * For logn = 0 (polynomials are constant), the "tree" is a
   57          * single element. Otherwise, the tree node has size 2^logn, and
   58          * has two child trees for size logn-1 each. Thus, treesize s()
   59          * must fulfill these two relations:
   60          *
   61          *   s(0) = 1
   62          *   s(logn) = (2^logn) + 2*s(logn-1)
   63          */
   64         return (logn + 1) << logn;
   65 }
   66 
   67 /*
   68  * Inner function for ffLDL_fft(). It expects the matrix to be both
   69  * auto-adjoint and quasicyclic; also, it uses the source operands
   70  * as modifiable temporaries.
   71  *
   72  * tmp[] must have room for at least one polynomial.
   73  */
   74 static void
   75 ffLDL_fft_inner(fpr *restrict tree,
   76         fpr *restrict g0, fpr *restrict g1, unsigned logn, fpr *restrict tmp)
   77 {
   78         size_t n, hn;
   79 
   80         n = MKN(logn);
   81         if (n == 1) {
   82                 tree[0] = g0[0];
   83                 return;
   84         }
   85         hn = n >> 1;
   86 
   87         /*
   88          * The LDL decomposition yields L (which is written in the tree)
   89          * and the diagonal of D. Since d00 = g0, we just write d11
   90          * into tmp.
   91          */
   92         Zf(poly_LDLmv_fft)(tmp, tree, g0, g1, g0, logn);
   93 
   94         /*
   95          * Split d00 (currently in g0) and d11 (currently in tmp). We
   96          * reuse g0 and g1 as temporary storage spaces:
   97          *   d00 splits into g1, g1+hn
   98          *   d11 splits into g0, g0+hn
   99          */
  100         Zf(poly_split_fft)(g1, g1 + hn, g0, logn);
  101         Zf(poly_split_fft)(g0, g0 + hn, tmp, logn);
  102 
  103         /*
  104          * Each split result is the first row of a new auto-adjoint
  105          * quasicyclic matrix for the next recursive step.
  106          */
  107         ffLDL_fft_inner(tree + n,
  108                 g1, g1 + hn, logn - 1, tmp);
  109         ffLDL_fft_inner(tree + n + ffLDL_treesize(logn - 1),
  110                 g0, g0 + hn, logn - 1, tmp);
  111 }
  112 
  113 /*
  114  * Compute the ffLDL tree of an auto-adjoint matrix G. The matrix
  115  * is provided as three polynomials (FFT representation).
  116  *
  117  * The "tree" array is filled with the computed tree, of size
  118  * (logn+1)*(2^logn) elements (see ffLDL_treesize()).
  119  *
  120  * Input arrays MUST NOT overlap, except possibly the three unmodified
  121  * arrays g00, g01 and g11. tmp[] should have room for at least three
  122  * polynomials of 2^logn elements each.
  123  */
  124 static void
  125 ffLDL_fft(fpr *restrict tree, const fpr *restrict g00,
  126         const fpr *restrict g01, const fpr *restrict g11,
  127         unsigned logn, fpr *restrict tmp)
  128 {
  129         size_t n, hn;
  130         fpr *d00, *d11;
  131 
  132         n = MKN(logn);
  133         if (n == 1) {
  134                 tree[0] = g00[0];
  135                 return;
  136         }
  137         hn = n >> 1;
  138         d00 = tmp;
  139         d11 = tmp + n;
  140         tmp += n << 1;
  141 
  142         memcpy(d00, g00, n * sizeof *g00);
  143         Zf(poly_LDLmv_fft)(d11, tree, g00, g01, g11, logn);
  144 
  145         Zf(poly_split_fft)(tmp, tmp + hn, d00, logn);
  146         Zf(poly_split_fft)(d00, d00 + hn, d11, logn);
  147         memcpy(d11, tmp, n * sizeof *tmp);
  148         ffLDL_fft_inner(tree + n,
  149                 d11, d11 + hn, logn - 1, tmp);
  150         ffLDL_fft_inner(tree + n + ffLDL_treesize(logn - 1),
  151                 d00, d00 + hn, logn - 1, tmp);
  152 }
  153 
  154 /*
  155  * Normalize an ffLDL tree: each leaf of value x is replaced with
  156  * sigma / sqrt(x).
  157  */
  158 static void
  159 ffLDL_binary_normalize(fpr *tree, unsigned orig_logn, unsigned logn)
  160 {
  161         /*
  162          * TODO: make an iterative version.
  163          */
  164         size_t n;
  165 
  166         n = MKN(logn);
  167         if (n == 1) {
  168                 /*
  169                  * We actually store in the tree leaf the inverse of
  170                  * the value mandated by the specification: this
  171                  * saves a division both here and in the sampler.
  172                  */
  173                 tree[0] = fpr_mul(fpr_sqrt(tree[0]), fpr_inv_sigma[orig_logn]);
  174         } else {
  175                 ffLDL_binary_normalize(tree + n, orig_logn, logn - 1);
  176                 ffLDL_binary_normalize(tree + n + ffLDL_treesize(logn - 1),
  177                         orig_logn, logn - 1);
  178         }
  179 }
  180 
  181 /* =================================================================== */
  182 
  183 /*
  184  * Convert an integer polynomial (with small values) into the
  185  * representation with complex numbers.
  186  */
  187 static void
  188 smallints_to_fpr(fpr *r, const int8_t *t, unsigned logn)
  189 {
  190         size_t n, u;
  191 
  192         n = MKN(logn);
  193         for (u = 0; u < n; u ++) {
  194                 r[u] = fpr_of(t[u]);
  195         }
  196 }
  197 
  198 /*
  199  * The expanded private key contains:
  200  *  - The B0 matrix (four elements)
  201  *  - The ffLDL tree
  202  */
  203 
  204 static inline size_t
  205 skoff_b00(unsigned logn)
  206 {
  207         (void)logn;
  208         return 0;
  209 }
  210 
  211 static inline size_t
  212 skoff_b01(unsigned logn)
  213 {
  214         return MKN(logn);
  215 }
  216 
  217 static inline size_t
  218 skoff_b10(unsigned logn)
  219 {
  220         return 2 * MKN(logn);
  221 }
  222 
  223 static inline size_t
  224 skoff_b11(unsigned logn)
  225 {
  226         return 3 * MKN(logn);
  227 }
  228 
  229 static inline size_t
  230 skoff_tree(unsigned logn)
  231 {
  232         return 4 * MKN(logn);
  233 }
  234 
  235 /* see inner.h */
  236 void
  237 Zf(expand_privkey)(fpr *restrict expanded_key,
  238         const int8_t *f, const int8_t *g,
  239         const int8_t *F, const int8_t *G,
  240         unsigned logn, uint8_t *restrict tmp)
  241 {
  242         size_t n;
  243         fpr *rf, *rg, *rF, *rG;
  244         fpr *b00, *b01, *b10, *b11;
  245         fpr *g00, *g01, *g11, *gxx;
  246         fpr *tree;
  247 
  248         n = MKN(logn);
  249         b00 = expanded_key + skoff_b00(logn);
  250         b01 = expanded_key + skoff_b01(logn);
  251         b10 = expanded_key + skoff_b10(logn);
  252         b11 = expanded_key + skoff_b11(logn);
  253         tree = expanded_key + skoff_tree(logn);
  254 
  255         /*
  256          * We load the private key elements directly into the B0 matrix,
  257          * since B0 = [[g, -f], [G, -F]].
  258          */
  259         rf = b01;
  260         rg = b00;
  261         rF = b11;
  262         rG = b10;
  263 
  264         smallints_to_fpr(rf, f, logn);
  265         smallints_to_fpr(rg, g, logn);
  266         smallints_to_fpr(rF, F, logn);
  267         smallints_to_fpr(rG, G, logn);
  268 
  269         /*
  270          * Compute the FFT for the key elements, and negate f and F.
  271          */
  272         Zf(FFT)(rf, logn);
  273         Zf(FFT)(rg, logn);
  274         Zf(FFT)(rF, logn);
  275         Zf(FFT)(rG, logn);
  276         Zf(poly_neg)(rf, logn);
  277         Zf(poly_neg)(rF, logn);
  278 
  279         /*
  280          * The Gram matrix is G = B·B*. Formulas are:
  281          *   g00 = b00*adj(b00) + b01*adj(b01)
  282          *   g01 = b00*adj(b10) + b01*adj(b11)
  283          *   g10 = b10*adj(b00) + b11*adj(b01)
  284          *   g11 = b10*adj(b10) + b11*adj(b11)
  285          *
  286          * For historical reasons, this implementation uses
  287          * g00, g01 and g11 (upper triangle).
  288          */
  289         g00 = (fpr *)tmp;
  290         g01 = g00 + n;
  291         g11 = g01 + n;
  292         gxx = g11 + n;
  293 
  294         memcpy(g00, b00, n * sizeof *b00);
  295         Zf(poly_mulselfadj_fft)(g00, logn);
  296         memcpy(gxx, b01, n * sizeof *b01);
  297         Zf(poly_mulselfadj_fft)(gxx, logn);
  298         Zf(poly_add)(g00, gxx, logn);
  299 
  300         memcpy(g01, b00, n * sizeof *b00);
  301         Zf(poly_muladj_fft)(g01, b10, logn);
  302         memcpy(gxx, b01, n * sizeof *b01);
  303         Zf(poly_muladj_fft)(gxx, b11, logn);
  304         Zf(poly_add)(g01, gxx, logn);
  305 
  306         memcpy(g11, b10, n * sizeof *b10);
  307         Zf(poly_mulselfadj_fft)(g11, logn);
  308         memcpy(gxx, b11, n * sizeof *b11);
  309         Zf(poly_mulselfadj_fft)(gxx, logn);
  310         Zf(poly_add)(g11, gxx, logn);
  311 
  312         /*
  313          * Compute the Falcon tree.
  314          */
  315         ffLDL_fft(tree, g00, g01, g11, logn, gxx);
  316 
  317         /*
  318          * Normalize tree.
  319          */
  320         ffLDL_binary_normalize(tree, logn, logn);
  321 }
  322 
  323 typedef int (*samplerZ)(void *ctx, fpr mu, fpr sigma);
  324 
  325 /*
  326  * Perform Fast Fourier Sampling for target vector t. The Gram matrix
  327  * is provided (G = [[g00, g01], [adj(g01), g11]]). The sampled vector
  328  * is written over (t0,t1). The Gram matrix is modified as well. The
  329  * tmp[] buffer must have room for four polynomials.
  330  */
  331 TARGET_AVX2
  332 static void
  333 ffSampling_fft_dyntree(samplerZ samp, void *samp_ctx,
  334         fpr *restrict t0, fpr *restrict t1,
  335         fpr *restrict g00, fpr *restrict g01, fpr *restrict g11,
  336         unsigned orig_logn, unsigned logn, fpr *restrict tmp)
  337 {
  338         size_t n, hn;
  339         fpr *z0, *z1;
  340 
  341         /*
  342          * Deepest level: the LDL tree leaf value is just g00 (the
  343          * array has length only 1 at this point); we normalize it
  344          * with regards to sigma, then use it for sampling.
  345          */
  346         if (logn == 0) {
  347                 fpr leaf;
  348 
  349                 leaf = g00[0];
  350                 leaf = fpr_mul(fpr_sqrt(leaf), fpr_inv_sigma[orig_logn]);
  351                 t0[0] = fpr_of(samp(samp_ctx, t0[0], leaf));
  352                 t1[0] = fpr_of(samp(samp_ctx, t1[0], leaf));
  353                 return;
  354         }
  355 
  356         n = (size_t)1 << logn;
  357         hn = n >> 1;
  358 
  359         /*
  360          * Decompose G into LDL. We only need d00 (identical to g00),
  361          * d11, and l10; we do that in place.
  362          */
  363         Zf(poly_LDL_fft)(g00, g01, g11, logn);
  364 
  365         /*
  366          * Split d00 and d11 and expand them into half-size quasi-cyclic
  367          * Gram matrices. We also save l10 in tmp[].
  368          */
  369         Zf(poly_split_fft)(tmp, tmp + hn, g00, logn);
  370         memcpy(g00, tmp, n * sizeof *tmp);
  371         Zf(poly_split_fft)(tmp, tmp + hn, g11, logn);
  372         memcpy(g11, tmp, n * sizeof *tmp);
  373         memcpy(tmp, g01, n * sizeof *g01);
  374         memcpy(g01, g00, hn * sizeof *g00);
  375         memcpy(g01 + hn, g11, hn * sizeof *g00);
  376 
  377         /*
  378          * The half-size Gram matrices for the recursive LDL tree
  379          * building are now:
  380          *   - left sub-tree: g00, g00+hn, g01
  381          *   - right sub-tree: g11, g11+hn, g01+hn
  382          * l10 is in tmp[].
  383          */
  384 
  385         /*
  386          * We split t1 and use the first recursive call on the two
  387          * halves, using the right sub-tree. The result is merged
  388          * back into tmp + 2*n.
  389          */
  390         z1 = tmp + n;
  391         Zf(poly_split_fft)(z1, z1 + hn, t1, logn);
  392         ffSampling_fft_dyntree(samp, samp_ctx, z1, z1 + hn,
  393                 g11, g11 + hn, g01 + hn, orig_logn, logn - 1, z1 + n);
  394         Zf(poly_merge_fft)(tmp + (n << 1), z1, z1 + hn, logn);
  395 
  396         /*
  397          * Compute tb0 = t0 + (t1 - z1) * l10.
  398          * At that point, l10 is in tmp, t1 is unmodified, and z1 is
  399          * in tmp + (n << 1). The buffer in z1 is free.
  400          *
  401          * In the end, z1 is written over t1, and tb0 is in t0.
  402          */
  403         memcpy(z1, t1, n * sizeof *t1);
  404         Zf(poly_sub)(z1, tmp + (n << 1), logn);
  405         memcpy(t1, tmp + (n << 1), n * sizeof *tmp);
  406         Zf(poly_mul_fft)(tmp, z1, logn);
  407         Zf(poly_add)(t0, tmp, logn);
  408 
  409         /*
  410          * Second recursive invocation, on the split tb0 (currently in t0)
  411          * and the left sub-tree.
  412          */
  413         z0 = tmp;
  414         Zf(poly_split_fft)(z0, z0 + hn, t0, logn);
  415         ffSampling_fft_dyntree(samp, samp_ctx, z0, z0 + hn,
  416                 g00, g00 + hn, g01, orig_logn, logn - 1, z0 + n);
  417         Zf(poly_merge_fft)(t0, z0, z0 + hn, logn);
  418 }
  419 
/*
 * Perform Fast Fourier Sampling for target vector (t0, t1) and the
 * precomputed, normalized LDL tree 'tree' (for example, the tree
 * stored in the expanded key by Zf(expand_privkey)()). All vectors are
 * in FFT representation. The sampled vector is written into (z0, z1);
 * t0 and t1 are only read.
 *
 * Tree layout for degree n = 2^logn:
 *  - the node's n values (the L factor used to compute tb0) come
 *    first;
 *  - then the left sub-tree (tree0, used to sample z0);
 *  - then the right sub-tree (tree1, used to sample z1).
 * Each sub-tree holds ffLDL_treesize(logn - 1) elements. Leaves hold
 * the inverse standard deviation that is passed as-is to samp().
 *
 * z1 is sampled first (from t1), then z0 (from tb0 = t0 + (t1-z1)*L).
 * The sampler is called in a fixed order.
 *
 * logn must be at least 1. There is no live logn == 0 branch (see the
 * commented-out block below); the general case would then compute
 * ffLDL_treesize(logn - 1) with a wrapped-around argument.
 *
 * tmp[] must have size for at least two polynomials of size 2^logn.
 */
TARGET_AVX2
static void
ffSampling_fft(samplerZ samp, void *samp_ctx,
	fpr *restrict z0, fpr *restrict z1,
	const fpr *restrict tree,
	const fpr *restrict t0, const fpr *restrict t1, unsigned logn,
	fpr *restrict tmp)
{
	size_t n, hn;
	const fpr *tree0, *tree1;

	/*
	 * When logn == 2, we inline the last two recursion levels.
	 * Both variants below use the same tree layout: tree[0..3] is
	 * the root L factor, tree0 = tree + 4 and tree1 = tree + 8 are
	 * the two logn == 1 sub-trees (L in [0], [1]; leaves in [2], [3]).
	 */
	if (logn == 2) {
#if FALCON_AVX2  // yyyAVX2+1
		/*
		 * NOTE(review): this SIMD branch appears to mirror the
		 * portable #else branch, processing pairs of values per
		 * __m128d; confirm that both produce identical results.
		 * 'sigma' holds a tree leaf value, i.e. an inverse standard
		 * deviation, despite its name.
		 */
		fpr w0, w1, w2, w3, sigma;
		__m128d ww0, ww1, wa, wb, wc, wd;
		__m128d wy0, wy1, wz0, wz1;
		__m128d half, invsqrt8, invsqrt2, neghi, neglo;
		int si0, si1, si2, si3;

		tree0 = tree + 4;
		tree1 = tree + 8;

		half = _mm_set1_pd(0.5);
		invsqrt8 = _mm_set1_pd(0.353553390593273762200422181052);
		invsqrt2 = _mm_set1_pd(0.707106781186547524400844362105);
		neghi = _mm_set_pd(-0.0, 0.0);
		neglo = _mm_set_pd(0.0, -0.0);

		/*
		 * We split t1 into w*, then do the recursive invocation,
		 * with output in w*. We finally merge back into z1.
		 */
		ww0 = _mm_loadu_pd(&t1[0].v);
		ww1 = _mm_loadu_pd(&t1[2].v);
		wa = _mm_unpacklo_pd(ww0, ww1);
		wb = _mm_unpackhi_pd(ww0, ww1);
		wc = _mm_add_pd(wa, wb);
		ww0 = _mm_mul_pd(wc, half);
		wc = _mm_sub_pd(wa, wb);
		wd = _mm_xor_pd(_mm_permute_pd(wc, 1), neghi);
		ww1 = _mm_mul_pd(_mm_add_pd(wc, wd), invsqrt8);

		/* Right leaf of tree1 first (second half of the split). */
		w2.v = _mm_cvtsd_f64(ww1);
		w3.v = _mm_cvtsd_f64(_mm_permute_pd(ww1, 1));
		wa = ww1;
		sigma = tree1[3];
		si2 = samp(samp_ctx, w2, sigma);
		si3 = samp(samp_ctx, w3, sigma);
		ww1 = _mm_set_pd((double)si3, (double)si2);
		wa = _mm_sub_pd(wa, ww1);
		wb = _mm_loadu_pd(&tree1[0].v);
		wc = _mm_mul_pd(wa, wb);
		wd = _mm_mul_pd(wa, _mm_permute_pd(wb, 1));
		wa = _mm_unpacklo_pd(wc, wd);
		wb = _mm_unpackhi_pd(wc, wd);
		ww0 = _mm_add_pd(ww0, _mm_add_pd(wa, _mm_xor_pd(wb, neglo)));
		/* Then the left leaf of tree1, with the updated target. */
		w0.v = _mm_cvtsd_f64(ww0);
		w1.v = _mm_cvtsd_f64(_mm_permute_pd(ww0, 1));
		sigma = tree1[2];
		si0 = samp(samp_ctx, w0, sigma);
		si1 = samp(samp_ctx, w1, sigma);
		ww0 = _mm_set_pd((double)si1, (double)si0);

		/* Merge the four sampled integers back into z1[0..3]. */
		wc = _mm_mul_pd(
			_mm_set_pd((double)(si2 + si3), (double)(si2 - si3)),
			invsqrt2);
		wa = _mm_add_pd(ww0, wc);
		wb = _mm_sub_pd(ww0, wc);
		ww0 = _mm_unpacklo_pd(wa, wb);
		ww1 = _mm_unpackhi_pd(wa, wb);
		_mm_storeu_pd(&z1[0].v, ww0);
		_mm_storeu_pd(&z1[2].v, ww1);

		/*
		 * Compute tb0 = t0 + (t1 - z1) * L. Value tb0 ends up in w*.
		 */
		wy0 = _mm_sub_pd(_mm_loadu_pd(&t1[0].v), ww0);
		wy1 = _mm_sub_pd(_mm_loadu_pd(&t1[2].v), ww1);
		wz0 = _mm_loadu_pd(&tree[0].v);
		wz1 = _mm_loadu_pd(&tree[2].v);
		ww0 = _mm_sub_pd(_mm_mul_pd(wy0, wz0), _mm_mul_pd(wy1, wz1));
		ww1 = _mm_add_pd(_mm_mul_pd(wy0, wz1), _mm_mul_pd(wy1, wz0));
		ww0 = _mm_add_pd(ww0, _mm_loadu_pd(&t0[0].v));
		ww1 = _mm_add_pd(ww1, _mm_loadu_pd(&t0[2].v));

		/*
		 * Second recursive invocation (same steps as above, on tb0
		 * and the left sub-tree tree0, output merged into z0).
		 */
		wa = _mm_unpacklo_pd(ww0, ww1);
		wb = _mm_unpackhi_pd(ww0, ww1);
		wc = _mm_add_pd(wa, wb);
		ww0 = _mm_mul_pd(wc, half);
		wc = _mm_sub_pd(wa, wb);
		wd = _mm_xor_pd(_mm_permute_pd(wc, 1), neghi);
		ww1 = _mm_mul_pd(_mm_add_pd(wc, wd), invsqrt8);

		w2.v = _mm_cvtsd_f64(ww1);
		w3.v = _mm_cvtsd_f64(_mm_permute_pd(ww1, 1));
		wa = ww1;
		sigma = tree0[3];
		si2 = samp(samp_ctx, w2, sigma);
		si3 = samp(samp_ctx, w3, sigma);
		ww1 = _mm_set_pd((double)si3, (double)si2);
		wa = _mm_sub_pd(wa, ww1);
		wb = _mm_loadu_pd(&tree0[0].v);
		wc = _mm_mul_pd(wa, wb);
		wd = _mm_mul_pd(wa, _mm_permute_pd(wb, 1));
		wa = _mm_unpacklo_pd(wc, wd);
		wb = _mm_unpackhi_pd(wc, wd);
		ww0 = _mm_add_pd(ww0, _mm_add_pd(wa, _mm_xor_pd(wb, neglo)));
		w0.v = _mm_cvtsd_f64(ww0);
		w1.v = _mm_cvtsd_f64(_mm_permute_pd(ww0, 1));
		sigma = tree0[2];
		si0 = samp(samp_ctx, w0, sigma);
		si1 = samp(samp_ctx, w1, sigma);
		ww0 = _mm_set_pd((double)si1, (double)si0);

		wc = _mm_mul_pd(
			_mm_set_pd((double)(si2 + si3), (double)(si2 - si3)),
			invsqrt2);
		wa = _mm_add_pd(ww0, wc);
		wb = _mm_sub_pd(ww0, wc);
		ww0 = _mm_unpacklo_pd(wa, wb);
		ww1 = _mm_unpackhi_pd(wa, wb);
		_mm_storeu_pd(&z0[0].v, ww0);
		_mm_storeu_pd(&z0[2].v, ww1);

		return;
#else  // yyyAVX2+0
		/*
		 * Portable variant. 'sigma' holds a tree leaf value, i.e.
		 * an inverse standard deviation, despite its name.
		 */
		fpr x0, x1, y0, y1, w0, w1, w2, w3, sigma;
		fpr a_re, a_im, b_re, b_im, c_re, c_im;

		tree0 = tree + 4;
		tree1 = tree + 8;

		/*
		 * We split t1 into w*, then do the recursive invocation,
		 * with output in w*. We finally merge back into z1.
		 */
		a_re = t1[0];
		a_im = t1[2];
		b_re = t1[1];
		b_im = t1[3];
		c_re = fpr_add(a_re, b_re);
		c_im = fpr_add(a_im, b_im);
		w0 = fpr_half(c_re);
		w1 = fpr_half(c_im);
		c_re = fpr_sub(a_re, b_re);
		c_im = fpr_sub(a_im, b_im);
		w2 = fpr_mul(fpr_add(c_re, c_im), fpr_invsqrt8);
		w3 = fpr_mul(fpr_sub(c_im, c_re), fpr_invsqrt8);

		/*
		 * Inlined logn == 1 step on (w0,w1 | w2,w3) with tree1:
		 * right leaf tree1[3] first, then L = tree1[0..1], then
		 * left leaf tree1[2].
		 */
		x0 = w2;
		x1 = w3;
		sigma = tree1[3];
		w2 = fpr_of(samp(samp_ctx, x0, sigma));
		w3 = fpr_of(samp(samp_ctx, x1, sigma));
		a_re = fpr_sub(x0, w2);
		a_im = fpr_sub(x1, w3);
		b_re = tree1[0];
		b_im = tree1[1];
		c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
		c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
		x0 = fpr_add(c_re, w0);
		x1 = fpr_add(c_im, w1);
		sigma = tree1[2];
		w0 = fpr_of(samp(samp_ctx, x0, sigma));
		w1 = fpr_of(samp(samp_ctx, x1, sigma));

		/* Merge the sampled halves back into z1 (kept in w*). */
		a_re = w0;
		a_im = w1;
		b_re = w2;
		b_im = w3;
		c_re = fpr_mul(fpr_sub(b_re, b_im), fpr_invsqrt2);
		c_im = fpr_mul(fpr_add(b_re, b_im), fpr_invsqrt2);
		z1[0] = w0 = fpr_add(a_re, c_re);
		z1[2] = w2 = fpr_add(a_im, c_im);
		z1[1] = w1 = fpr_sub(a_re, c_re);
		z1[3] = w3 = fpr_sub(a_im, c_im);

		/*
		 * Compute tb0 = t0 + (t1 - z1) * L. Value tb0 ends up in w*.
		 */
		w0 = fpr_sub(t1[0], w0);
		w1 = fpr_sub(t1[1], w1);
		w2 = fpr_sub(t1[2], w2);
		w3 = fpr_sub(t1[3], w3);

		a_re = w0;
		a_im = w2;
		b_re = tree[0];
		b_im = tree[2];
		w0 = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
		w2 = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
		a_re = w1;
		a_im = w3;
		b_re = tree[1];
		b_im = tree[3];
		w1 = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
		w3 = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));

		w0 = fpr_add(w0, t0[0]);
		w1 = fpr_add(w1, t0[1]);
		w2 = fpr_add(w2, t0[2]);
		w3 = fpr_add(w3, t0[3]);

		/*
		 * Second recursive invocation (same steps as above, on tb0
		 * and the left sub-tree tree0, output written to z0).
		 */
		a_re = w0;
		a_im = w2;
		b_re = w1;
		b_im = w3;
		c_re = fpr_add(a_re, b_re);
		c_im = fpr_add(a_im, b_im);
		w0 = fpr_half(c_re);
		w1 = fpr_half(c_im);
		c_re = fpr_sub(a_re, b_re);
		c_im = fpr_sub(a_im, b_im);
		w2 = fpr_mul(fpr_add(c_re, c_im), fpr_invsqrt8);
		w3 = fpr_mul(fpr_sub(c_im, c_re), fpr_invsqrt8);

		x0 = w2;
		x1 = w3;
		sigma = tree0[3];
		w2 = y0 = fpr_of(samp(samp_ctx, x0, sigma));
		w3 = y1 = fpr_of(samp(samp_ctx, x1, sigma));
		a_re = fpr_sub(x0, y0);
		a_im = fpr_sub(x1, y1);
		b_re = tree0[0];
		b_im = tree0[1];
		c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
		c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
		x0 = fpr_add(c_re, w0);
		x1 = fpr_add(c_im, w1);
		sigma = tree0[2];
		w0 = fpr_of(samp(samp_ctx, x0, sigma));
		w1 = fpr_of(samp(samp_ctx, x1, sigma));

		a_re = w0;
		a_im = w1;
		b_re = w2;
		b_im = w3;
		c_re = fpr_mul(fpr_sub(b_re, b_im), fpr_invsqrt2);
		c_im = fpr_mul(fpr_add(b_re, b_im), fpr_invsqrt2);
		z0[0] = fpr_add(a_re, c_re);
		z0[2] = fpr_add(a_im, c_im);
		z0[1] = fpr_sub(a_re, c_re);
		z0[3] = fpr_sub(a_im, c_im);

		return;
#endif  // yyyAVX2-
	}

	/*
	 * Case logn == 1 is reachable only when using Falcon-2 (the
	 * smallest size for which Falcon is mathematically defined, but
	 * of course way too insecure to be of any use).
	 *
	 * Tree layout here: tree[0..1] is the L factor (one complex
	 * value), tree[2] is the left leaf (for z0) and tree[3] the right
	 * leaf (for z1). 'sigma' again holds an inverse standard
	 * deviation.
	 */
	if (logn == 1) {
		fpr x0, x1, y0, y1, sigma;
		fpr a_re, a_im, b_re, b_im, c_re, c_im;

		x0 = t1[0];
		x1 = t1[1];
		sigma = tree[3];
		z1[0] = y0 = fpr_of(samp(samp_ctx, x0, sigma));
		z1[1] = y1 = fpr_of(samp(samp_ctx, x1, sigma));
		a_re = fpr_sub(x0, y0);
		a_im = fpr_sub(x1, y1);
		b_re = tree[0];
		b_im = tree[1];
		c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
		c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
		x0 = fpr_add(c_re, t0[0]);
		x1 = fpr_add(c_im, t0[1]);
		sigma = tree[2];
		z0[0] = fpr_of(samp(samp_ctx, x0, sigma));
		z0[1] = fpr_of(samp(samp_ctx, x1, sigma));

		return;
	}

	/*
	 * Normal end of recursion is for logn == 0. Since the last
	 * steps of the recursions were inlined in the blocks above
	 * (when logn == 1 or 2), this case is not reachable, and is
	 * retained here only for documentation purposes.

	if (logn == 0) {
		fpr x0, x1, sigma;

		x0 = t0[0];
		x1 = t1[0];
		sigma = tree[0];
		z0[0] = fpr_of(samp(samp_ctx, x0, sigma));
		z1[0] = fpr_of(samp(samp_ctx, x1, sigma));
		return;
	}

	 */

	/*
	 * General recursive case (logn >= 3).
	 */

	n = (size_t)1 << logn;
	hn = n >> 1;
	/* Left sub-tree right after the node, right sub-tree after it. */
	tree0 = tree + n;
	tree1 = tree + n + ffLDL_treesize(logn - 1);

	/*
	 * We split t1 into z1 (reused as temporary storage), then do
	 * the recursive invocation, with output in tmp. We finally
	 * merge back into z1. The recursive call gets tmp + n as its
	 * own scratch space (two half-size polynomials = n elements).
	 */
	Zf(poly_split_fft)(z1, z1 + hn, t1, logn);
	ffSampling_fft(samp, samp_ctx, tmp, tmp + hn,
		tree1, z1, z1 + hn, logn - 1, tmp + n);
	Zf(poly_merge_fft)(z1, tmp, tmp + hn, logn);

	/*
	 * Compute tb0 = t0 + (t1 - z1) * L. Value tb0 ends up in tmp[].
	 * L is the node's own n values, at the start of tree[].
	 */
	memcpy(tmp, t1, n * sizeof *t1);
	Zf(poly_sub)(tmp, z1, logn);
	Zf(poly_mul_fft)(tmp, tree, logn);
	Zf(poly_add)(tmp, t0, logn);

	/*
	 * Second recursive invocation, on tb0 with the left sub-tree;
	 * z0 serves as temporary storage for the split target.
	 */
	Zf(poly_split_fft)(z0, z0 + hn, tmp, logn);
	ffSampling_fft(samp, samp_ctx, tmp, tmp + hn,
		tree0, z0, z0 + hn, logn - 1, tmp + n);
	Zf(poly_merge_fft)(z0, tmp, tmp + hn, logn);
}
  764 
  765 /*
  766  * Compute a signature: the signature contains two vectors, s1 and s2.
  767  * The s1 vector is not returned. The squared norm of (s1,s2) is
  768  * computed, and if it is short enough, then s2 is returned into the
  769  * s2[] buffer, and 1 is returned; otherwise, s2[] is untouched and 0 is
  770  * returned; the caller should then try again. This function uses an
  771  * expanded key.
  772  *
  773  * tmp[] must have room for at least six polynomials.
  774  */
  775 static int
  776 do_sign_tree(samplerZ samp, void *samp_ctx, int16_t *s2,
  777         const fpr *restrict expanded_key,
  778         const uint16_t *hm,
  779         unsigned logn, fpr *restrict tmp)
  780 {
  781         size_t n, u;
  782         fpr *t0, *t1, *tx, *ty;
  783         const fpr *b00, *b01, *b10, *b11, *tree;
  784         fpr ni;
  785         uint32_t sqn, ng;
  786         int16_t *s1tmp, *s2tmp;
  787 
  788         n = MKN(logn);
  789         t0 = tmp;
  790         t1 = t0 + n;
  791         b00 = expanded_key + skoff_b00(logn);
  792         b01 = expanded_key + skoff_b01(logn);
  793         b10 = expanded_key + skoff_b10(logn);
  794         b11 = expanded_key + skoff_b11(logn);
  795         tree = expanded_key + skoff_tree(logn);
  796 
  797         /*
  798          * Set the target vector to [hm, 0] (hm is the hashed message).
  799          */
  800         for (u = 0; u < n; u ++) {
  801                 t0[u] = fpr_of(hm[u]);
  802                 /* This is implicit.
  803                 t1[u] = fpr_zero;
  804                 */
  805         }
  806 
  807         /*
  808          * Apply the lattice basis to obtain the real target
  809          * vector (after normalization with regards to modulus).
  810          */
  811         Zf(FFT)(t0, logn);
  812         ni = fpr_inverse_of_q;
  813         memcpy(t1, t0, n * sizeof *t0);
  814         Zf(poly_mul_fft)(t1, b01, logn);
  815         Zf(poly_mulconst)(t1, fpr_neg(ni), logn);
  816         Zf(poly_mul_fft)(t0, b11, logn);
  817         Zf(poly_mulconst)(t0, ni, logn);
  818 
  819         tx = t1 + n;
  820         ty = tx + n;
  821 
  822         /*
  823          * Apply sampling. Output is written back in [tx, ty].
  824          */
  825         ffSampling_fft(samp, samp_ctx, tx, ty, tree, t0, t1, logn, ty + n);
  826 
  827         /*
  828          * Get the lattice point corresponding to that tiny vector.
  829          */
  830         memcpy(t0, tx, n * sizeof *tx);
  831         memcpy(t1, ty, n * sizeof *ty);
  832         Zf(poly_mul_fft)(tx, b00, logn);
  833         Zf(poly_mul_fft)(ty, b10, logn);
  834         Zf(poly_add)(tx, ty, logn);
  835         memcpy(ty, t0, n * sizeof *t0);
  836         Zf(poly_mul_fft)(ty, b01, logn);
  837 
  838         memcpy(t0, tx, n * sizeof *tx);
  839         Zf(poly_mul_fft)(t1, b11, logn);
  840         Zf(poly_add)(t1, ty, logn);
  841 
  842         Zf(iFFT)(t0, logn);
  843         Zf(iFFT)(t1, logn);
  844 
  845         /*
  846          * Compute the signature.
  847          */
  848         s1tmp = (int16_t *)tx;
  849         sqn = 0;
  850         ng = 0;
  851         for (u = 0; u < n; u ++) {
  852                 int32_t z;
  853 
  854                 z = (int32_t)hm[u] - (int32_t)fpr_rint(t0[u]);
  855                 sqn += (uint32_t)(z * z);
  856                 ng |= sqn;
  857                 s1tmp[u] = (int16_t)z;
  858         }
  859         sqn |= -(ng >> 31);
  860 
  861         /*
  862          * With "normal" degrees (e.g. 512 or 1024), it is very
  863          * improbable that the computed vector is not short enough;
  864          * however, it may happen in practice for the very reduced
  865          * versions (e.g. degree 16 or below). In that case, the caller
  866          * will loop, and we must not write anything into s2[] because
  867          * s2[] may overlap with the hashed message hm[] and we need
  868          * hm[] for the next iteration.
  869          */
  870         s2tmp = (int16_t *)tmp;
  871         for (u = 0; u < n; u ++) {
  872                 s2tmp[u] = (int16_t)-fpr_rint(t1[u]);
  873         }
  874         if (Zf(is_short_half)(sqn, s2tmp, logn)) {
  875                 memcpy(s2, s2tmp, n * sizeof *s2);
  876                 memcpy(tmp, s1tmp, n * sizeof *s1tmp);
  877                 return 1;
  878         }
  879         return 0;
  880 }
  881 
  882 /*
  883  * Compute a signature: the signature contains two vectors, s1 and s2.
  884  * The s1 vector is not returned. The squared norm of (s1,s2) is
  885  * computed, and if it is short enough, then s2 is returned into the
  886  * s2[] buffer, and 1 is returned; otherwise, s2[] is untouched and 0 is
  887  * returned; the caller should then try again.
  888  *
  889  * tmp[] must have room for at least nine polynomials.
  890  */
  891 static int
  892 do_sign_dyn(samplerZ samp, void *samp_ctx, int16_t *s2,
  893         const int8_t *restrict f, const int8_t *restrict g,
  894         const int8_t *restrict F, const int8_t *restrict G,
  895         const uint16_t *hm, unsigned logn, fpr *restrict tmp)
  896 {
  897         size_t n, u;
  898         fpr *t0, *t1, *tx, *ty;
  899         fpr *b00, *b01, *b10, *b11, *g00, *g01, *g11;
  900         fpr ni;
  901         uint32_t sqn, ng;
  902         int16_t *s1tmp, *s2tmp;
  903 
  904         n = MKN(logn);
  905 
  906         /*
  907          * Lattice basis is B = [[g, -f], [G, -F]]. We convert it to FFT.
  908          */
  909         b00 = tmp;
  910         b01 = b00 + n;
  911         b10 = b01 + n;
  912         b11 = b10 + n;
  913         smallints_to_fpr(b01, f, logn);
  914         smallints_to_fpr(b00, g, logn);
  915         smallints_to_fpr(b11, F, logn);
  916         smallints_to_fpr(b10, G, logn);
  917         Zf(FFT)(b01, logn);
  918         Zf(FFT)(b00, logn);
  919         Zf(FFT)(b11, logn);
  920         Zf(FFT)(b10, logn);
  921         Zf(poly_neg)(b01, logn);
  922         Zf(poly_neg)(b11, logn);
  923 
  924         /*
  925          * Compute the Gram matrix G = B·B*. Formulas are:
  926          *   g00 = b00*adj(b00) + b01*adj(b01)
  927          *   g01 = b00*adj(b10) + b01*adj(b11)
  928          *   g10 = b10*adj(b00) + b11*adj(b01)
  929          *   g11 = b10*adj(b10) + b11*adj(b11)
  930          *
  931          * For historical reasons, this implementation uses
  932          * g00, g01 and g11 (upper triangle). g10 is not kept
  933          * since it is equal to adj(g01).
  934          *
  935          * We _replace_ the matrix B with the Gram matrix, but we
  936          * must keep b01 and b11 for computing the target vector.
  937          */
  938         t0 = b11 + n;
  939         t1 = t0 + n;
  940 
  941         memcpy(t0, b01, n * sizeof *b01);
  942         Zf(poly_mulselfadj_fft)(t0, logn);    // t0 <- b01*adj(b01)
  943 
  944         memcpy(t1, b00, n * sizeof *b00);
  945         Zf(poly_muladj_fft)(t1, b10, logn);   // t1 <- b00*adj(b10)
  946         Zf(poly_mulselfadj_fft)(b00, logn);   // b00 <- b00*adj(b00)
  947         Zf(poly_add)(b00, t0, logn);      // b00 <- g00
  948         memcpy(t0, b01, n * sizeof *b01);
  949         Zf(poly_muladj_fft)(b01, b11, logn);  // b01 <- b01*adj(b11)
  950         Zf(poly_add)(b01, t1, logn);      // b01 <- g01
  951 
  952         Zf(poly_mulselfadj_fft)(b10, logn);   // b10 <- b10*adj(b10)
  953         memcpy(t1, b11, n * sizeof *b11);
  954         Zf(poly_mulselfadj_fft)(t1, logn);    // t1 <- b11*adj(b11)
  955         Zf(poly_add)(b10, t1, logn);      // b10 <- g11
  956 
  957         /*
  958          * We rename variables to make things clearer. The three elements
  959          * of the Gram matrix uses the first 3*n slots of tmp[], followed
  960          * by b11 and b01 (in that order).
  961          */
  962         g00 = b00;
  963         g01 = b01;
  964         g11 = b10;
  965         b01 = t0;
  966         t0 = b01 + n;
  967         t1 = t0 + n;
  968 
  969         /*
  970          * Memory layout at that point:
  971          *   g00 g01 g11 b11 b01 t0 t1
  972          */
  973 
  974         /*
  975          * Set the target vector to [hm, 0] (hm is the hashed message).
  976          */
  977         for (u = 0; u < n; u ++) {
  978                 t0[u] = fpr_of(hm[u]);
  979                 /* This is implicit.
  980                 t1[u] = fpr_zero;
  981                 */
  982         }
  983 
  984         /*
  985          * Apply the lattice basis to obtain the real target
  986          * vector (after normalization with regards to modulus).
  987          */
  988         Zf(FFT)(t0, logn);
  989         ni = fpr_inverse_of_q;
  990         memcpy(t1, t0, n * sizeof *t0);
  991         Zf(poly_mul_fft)(t1, b01, logn);
  992         Zf(poly_mulconst)(t1, fpr_neg(ni), logn);
  993         Zf(poly_mul_fft)(t0, b11, logn);
  994         Zf(poly_mulconst)(t0, ni, logn);
  995 
  996         /*
  997          * b01 and b11 can be discarded, so we move back (t0,t1).
  998          * Memory layout is now:
  999          *      g00 g01 g11 t0 t1
 1000          */
 1001         memcpy(b11, t0, n * 2 * sizeof *t0);
 1002         t0 = g11 + n;
 1003         t1 = t0 + n;
 1004 
 1005         /*
 1006          * Apply sampling; result is written over (t0,t1).
 1007          */
 1008         ffSampling_fft_dyntree(samp, samp_ctx,
 1009                 t0, t1, g00, g01, g11, logn, logn, t1 + n);
 1010 
 1011         /*
 1012          * We arrange the layout back to:
 1013          *     b00 b01 b10 b11 t0 t1
 1014          *
 1015          * We did not conserve the matrix basis, so we must recompute
 1016          * it now.
 1017          */
 1018         b00 = tmp;
 1019         b01 = b00 + n;
 1020         b10 = b01 + n;
 1021         b11 = b10 + n;
 1022         memmove(b11 + n, t0, n * 2 * sizeof *t0);
 1023         t0 = b11 + n;
 1024         t1 = t0 + n;
 1025         smallints_to_fpr(b01, f, logn);
 1026         smallints_to_fpr(b00, g, logn);
 1027         smallints_to_fpr(b11, F, logn);
 1028         smallints_to_fpr(b10, G, logn);
 1029         Zf(FFT)(b01, logn);
 1030         Zf(FFT)(b00, logn);
 1031         Zf(FFT)(b11, logn);
 1032         Zf(FFT)(b10, logn);
 1033         Zf(poly_neg)(b01, logn);
 1034         Zf(poly_neg)(b11, logn);
 1035         tx = t1 + n;
 1036         ty = tx + n;
 1037 
 1038         /*
 1039          * Get the lattice point corresponding to that tiny vector.
 1040          */
 1041         memcpy(tx, t0, n * sizeof *t0);
 1042         memcpy(ty, t1, n * sizeof *t1);
 1043         Zf(poly_mul_fft)(tx, b00, logn);
 1044         Zf(poly_mul_fft)(ty, b10, logn);
 1045         Zf(poly_add)(tx, ty, logn);
 1046         memcpy(ty, t0, n * sizeof *t0);
 1047         Zf(poly_mul_fft)(ty, b01, logn);
 1048 
 1049         memcpy(t0, tx, n * sizeof *tx);
 1050         Zf(poly_mul_fft)(t1, b11, logn);
 1051         Zf(poly_add)(t1, ty, logn);
 1052         Zf(iFFT)(t0, logn);
 1053         Zf(iFFT)(t1, logn);
 1054 
 1055         s1tmp = (int16_t *)tx;
 1056         sqn = 0;
 1057         ng = 0;
 1058         for (u = 0; u < n; u ++) {
 1059                 int32_t z;
 1060 
 1061                 z = (int32_t)hm[u] - (int32_t)fpr_rint(t0[u]);
 1062                 sqn += (uint32_t)(z * z);
 1063                 ng |= sqn;
 1064                 s1tmp[u] = (int16_t)z;
 1065         }
 1066         sqn |= -(ng >> 31);
 1067 
 1068         /*
 1069          * With "normal" degrees (e.g. 512 or 1024), it is very
 1070          * improbable that the computed vector is not short enough;
 1071          * however, it may happen in practice for the very reduced
 1072          * versions (e.g. degree 16 or below). In that case, the caller
 1073          * will loop, and we must not write anything into s2[] because
 1074          * s2[] may overlap with the hashed message hm[] and we need
 1075          * hm[] for the next iteration.
 1076          */
 1077         s2tmp = (int16_t *)tmp;
 1078         for (u = 0; u < n; u ++) {
 1079                 s2tmp[u] = (int16_t)-fpr_rint(t1[u]);
 1080         }
 1081         if (Zf(is_short_half)(sqn, s2tmp, logn)) {
 1082                 memcpy(s2, s2tmp, n * sizeof *s2);
 1083                 memcpy(tmp, s1tmp, n * sizeof *s1tmp);
 1084                 return 1;
 1085         }
 1086         return 0;
 1087 }
 1088 
 1089 /*
 1090  * Sample an integer value along a half-gaussian distribution centered
 1091  * on zero and standard deviation 1.8205, with a precision of 72 bits.
 1092  */
 1093 TARGET_AVX2
 1094 int
 1095 Zf(gaussian0_sampler)(prng *p)
 1096 {
 1097 #if FALCON_AVX2 // yyyAVX2+1
 1098 
 1099         /*
 1100          * High words.
 1101          */
 1102         static const union {
 1103                 uint16_t u16[16];
 1104                 __m256i ymm[1];
 1105         } rhi15 = {
 1106                 {
 1107                         0x51FB, 0x2A69, 0x113E, 0x0568,
 1108                         0x014A, 0x003B, 0x0008, 0x0000,
 1109                         0x0000, 0x0000, 0x0000, 0x0000,
 1110                         0x0000, 0x0000, 0x0000, 0x0000
 1111                 }
 1112         };
 1113 
 1114         static const union {
 1115                 uint64_t u64[20];
 1116                 __m256i ymm[5];
 1117         } rlo57 = {
 1118                 {
 1119                         0x1F42ED3AC391802, 0x12B181F3F7DDB82,
 1120                         0x1CDD0934829C1FF, 0x1754377C7994AE4,
 1121                         0x1846CAEF33F1F6F, 0x14AC754ED74BD5F,
 1122                         0x024DD542B776AE4, 0x1A1FFDC65AD63DA,
 1123                         0x01F80D88A7B6428, 0x001C3FDB2040C69,
 1124                         0x00012CF24D031FB, 0x00000949F8B091F,
 1125                         0x0000003665DA998, 0x00000000EBF6EBB,
 1126                         0x0000000002F5D7E, 0x000000000007098,
 1127                         0x0000000000000C6, 0x000000000000001,
 1128                         0x000000000000000, 0x000000000000000
 1129                 }
 1130         };
 1131 
 1132         uint64_t lo;
 1133         unsigned hi;
 1134         __m256i xhi, rhi, gthi, eqhi, eqm;
 1135         __m256i xlo, gtlo0, gtlo1, gtlo2, gtlo3, gtlo4;
 1136         __m128i t, zt;
 1137         int r;
 1138 
 1139         /*
 1140          * Get a 72-bit random value and split it into a low part
 1141          * (57 bits) and a high part (15 bits)
 1142          */
 1143         lo = prng_get_u64(p);
 1144         hi = prng_get_u8(p);
 1145         hi = (hi << 7) | (unsigned)(lo >> 57);
 1146         lo &= 0x1FFFFFFFFFFFFFF;
 1147 
 1148         /*
 1149          * Broadcast the high part and compare it with the relevant
 1150          * values. We need both a "greater than" and an "equal"
 1151          * comparisons.
 1152          */
 1153         xhi = _mm256_broadcastw_epi16(_mm_cvtsi32_si128(hi));
 1154         rhi = _mm256_loadu_si256(&rhi15.ymm[0]);
 1155         gthi = _mm256_cmpgt_epi16(rhi, xhi);
 1156         eqhi = _mm256_cmpeq_epi16(rhi, xhi);
 1157 
 1158         /*
 1159          * The result is the number of 72-bit values (among the list of 19)
 1160          * which are greater than the 72-bit random value. We first count
 1161          * all non-zero 16-bit elements in the first eight of gthi. Such
 1162          * elements have value -1 or 0, so we first negate them.
 1163          */
 1164         t = _mm_srli_epi16(_mm256_castsi256_si128(gthi), 15);
 1165         zt = _mm_setzero_si128();
 1166         t = _mm_hadd_epi16(t, zt);
 1167         t = _mm_hadd_epi16(t, zt);
 1168         t = _mm_hadd_epi16(t, zt);
 1169         r = _mm_cvtsi128_si32(t);
 1170 
 1171         /*
 1172          * We must look at the low bits for all values for which the
 1173          * high bits are an "equal" match; values 8-18 all have the
 1174          * same high bits (0).
 1175          * On 32-bit systems, 'lo' really is two registers, requiring
 1176          * some extra code.
 1177          */
 1178 #if defined(__x86_64__) || defined(_M_X64)
 1179         xlo = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(*(int64_t *)&lo));
 1180 #else
 1181         {
 1182                 uint32_t e0, e1;
 1183                 int32_t f0, f1;
 1184 
 1185                 e0 = (uint32_t)lo;
 1186                 e1 = (uint32_t)(lo >> 32);
 1187                 f0 = *(int32_t *)&e0;
 1188                 f1 = *(int32_t *)&e1;
 1189                 xlo = _mm256_set_epi32(f1, f0, f1, f0, f1, f0, f1, f0);
 1190         }
 1191 #endif
 1192         gtlo0 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[0]), xlo); 
 1193         gtlo1 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[1]), xlo); 
 1194         gtlo2 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[2]), xlo); 
 1195         gtlo3 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[3]), xlo); 
 1196         gtlo4 = _mm256_cmpgt_epi64(_mm256_loadu_si256(&rlo57.ymm[4]), xlo); 
 1197 
 1198         /*
 1199          * Keep only comparison results that correspond to the non-zero
 1200          * elements in eqhi.
 1201          */
 1202         gtlo0 = _mm256_and_si256(gtlo0, _mm256_cvtepi16_epi64(
 1203                 _mm256_castsi256_si128(eqhi)));
 1204         gtlo1 = _mm256_and_si256(gtlo1, _mm256_cvtepi16_epi64(
 1205                 _mm256_castsi256_si128(_mm256_bsrli_epi128(eqhi, 8))));
 1206         eqm = _mm256_permute4x64_epi64(eqhi, 0xFF);
 1207         gtlo2 = _mm256_and_si256(gtlo2, eqm);
 1208         gtlo3 = _mm256_and_si256(gtlo3, eqm);
 1209         gtlo4 = _mm256_and_si256(gtlo4, eqm);
 1210 
 1211         /*
 1212          * Add all values to count the total number of "-1" elements.
 1213          * Since the first eight "high" words are all different, only
 1214          * one element (at most) in gtlo0:gtlo1 can be non-zero; however,
 1215          * if the high word of the random value is zero, then many
 1216          * elements of gtlo2:gtlo3:gtlo4 can be non-zero.
 1217          */
 1218         gtlo0 = _mm256_or_si256(gtlo0, gtlo1);
 1219         gtlo0 = _mm256_add_epi64(
 1220                 _mm256_add_epi64(gtlo0, gtlo2),
 1221                 _mm256_add_epi64(gtlo3, gtlo4));
 1222         t = _mm_add_epi64(
 1223                 _mm256_castsi256_si128(gtlo0),
 1224                 _mm256_extracti128_si256(gtlo0, 1));
 1225         t = _mm_add_epi64(t, _mm_srli_si128(t, 8));
 1226         r -= _mm_cvtsi128_si32(t);
 1227 
 1228         return r;
 1229 
 1230 #else // yyyAVX2+0
 1231 
 1232         static const uint32_t dist[] = {
 1233                 10745844u,  3068844u,  3741698u,
 1234                  5559083u,  1580863u,  8248194u,
 1235                  2260429u, 13669192u,  2736639u,
 1236                   708981u,  4421575u, 10046180u,
 1237                   169348u,  7122675u,  4136815u,
 1238                    30538u, 13063405u,  7650655u,
 1239                     4132u, 14505003u,  7826148u,
 1240                      417u, 16768101u, 11363290u,
 1241                       31u,  8444042u,  8086568u,
 1242                        1u, 12844466u,   265321u,
 1243                        0u,  1232676u, 13644283u,
 1244                        0u,    38047u,  9111839u,
 1245                        0u,      870u,  6138264u,
 1246                        0u,       14u, 12545723u,
 1247                        0u,        0u,  3104126u,
 1248                        0u,        0u,    28824u,
 1249                        0u,        0u,      198u,
 1250                        0u,        0u,        1u
 1251         };
 1252 
 1253         uint32_t v0, v1, v2, hi;
 1254         uint64_t lo;
 1255         size_t u;
 1256         int z;
 1257 
 1258         /*
 1259          * Get a random 72-bit value, into three 24-bit limbs v0..v2.
 1260          */
 1261         lo = prng_get_u64(p);
 1262         hi = prng_get_u8(p);
 1263         v0 = (uint32_t)lo & 0xFFFFFF;
 1264         v1 = (uint32_t)(lo >> 24) & 0xFFFFFF;
 1265         v2 = (uint32_t)(lo >> 48) | (hi << 16);
 1266 
 1267         /*
 1268          * Sampled value is z, such that v0..v2 is lower than the first
 1269          * z elements of the table.
 1270          */
 1271         z = 0;
 1272         for (u = 0; u < (sizeof dist) / sizeof(dist[0]); u += 3) {
 1273                 uint32_t w0, w1, w2, cc;
 1274 
 1275                 w0 = dist[u + 2];
 1276                 w1 = dist[u + 1];
 1277                 w2 = dist[u + 0];
 1278                 cc = (v0 - w0) >> 31;
 1279                 cc = (v1 - w1 - cc) >> 31;
 1280                 cc = (v2 - w2 - cc) >> 31;
 1281                 z += (int)cc;
 1282         }
 1283         return z;
 1284 
 1285 #endif // yyyAVX2-
 1286 }
 1287 
 1288 /*
 1289  * Sample a bit with probability exp(-x) for some x >= 0.
 1290  */
 1291 TARGET_AVX2
 1292 static int
 1293 BerExp(prng *p, fpr x, fpr ccs)
 1294 {
 1295         int s, i;
 1296         fpr r;
 1297         uint32_t sw, w;
 1298         uint64_t z;
 1299 
 1300         /*
 1301          * Reduce x modulo log(2): x = s*log(2) + r, with s an integer,
 1302          * and 0 <= r < log(2). Since x >= 0, we can use fpr_trunc().
 1303          */
 1304         s = (int)fpr_trunc(fpr_mul(x, fpr_inv_log2));
 1305         r = fpr_sub(x, fpr_mul(fpr_of(s), fpr_log2));
 1306 
 1307         /*
 1308          * It may happen (quite rarely) that s >= 64; if sigma = 1.2
 1309          * (the minimum value for sigma), r = 0 and b = 1, then we get
 1310          * s >= 64 if the half-Gaussian produced a z >= 13, which happens
 1311          * with probability about 0.000000000230383991, which is
 1312          * approximatively equal to 2^(-32). In any case, if s >= 64,
 1313          * then BerExp will be non-zero with probability less than
 1314          * 2^(-64), so we can simply saturate s at 63.
 1315          */
 1316         sw = (uint32_t)s;
 1317         sw ^= (sw ^ 63) & -((63 - sw) >> 31);
 1318         s = (int)sw;
 1319 
 1320         /*
 1321          * Compute exp(-r); we know that 0 <= r < log(2) at this point, so
 1322          * we can use fpr_expm_p63(), which yields a result scaled to 2^63.
 1323          * We scale it up to 2^64, then right-shift it by s bits because
 1324          * we really want exp(-x) = 2^(-s)*exp(-r).
 1325          *
 1326          * The "-1" operation makes sure that the value fits on 64 bits
 1327          * (i.e. if r = 0, we may get 2^64, and we prefer 2^64-1 in that
 1328          * case). The bias is negligible since fpr_expm_p63() only computes
 1329          * with 51 bits of precision or so.
 1330          */
 1331         z = ((fpr_expm_p63(r, ccs) << 1) - 1) >> s;
 1332 
 1333         /*
 1334          * Sample a bit with probability exp(-x). Since x = s*log(2) + r,
 1335          * exp(-x) = 2^-s * exp(-r), we compare lazily exp(-x) with the
 1336          * PRNG output to limit its consumption, the sign of the difference
 1337          * yields the expected result.
 1338          */
 1339         i = 64;
 1340         do {
 1341                 i -= 8;
 1342                 w = prng_get_u8(p) - ((uint32_t)(z >> i) & 0xFF);
 1343         } while (!w && i > 0);
 1344         return (int)(w >> 31);
 1345 }
 1346 
 1347 /*
 1348  * The sampler produces a random integer that follows a discrete Gaussian
 1349  * distribution, centered on mu, and with standard deviation sigma. The
 1350  * provided parameter isigma is equal to 1/sigma.
 1351  *
 1352  * The value of sigma MUST lie between 1 and 2 (i.e. isigma lies between
 1353  * 0.5 and 1); in Falcon, sigma should always be between 1.2 and 1.9.
 1354  */
 1355 TARGET_AVX2
 1356 int
 1357 Zf(sampler)(void *ctx, fpr mu, fpr isigma)
 1358 {
 1359         sampler_context *spc;
 1360         int s;
 1361         fpr r, dss, ccs;
 1362 
 1363         spc = ctx;
 1364 
 1365         /*
 1366          * Center is mu. We compute mu = s + r where s is an integer
 1367          * and 0 <= r < 1.
 1368          */
 1369         s = (int)fpr_floor(mu);
 1370         r = fpr_sub(mu, fpr_of(s));
 1371 
 1372         /*
 1373          * dss = 1/(2*sigma^2) = 0.5*(isigma^2).
 1374          */
 1375         dss = fpr_half(fpr_sqr(isigma));
 1376 
 1377         /*
 1378          * ccs = sigma_min / sigma = sigma_min * isigma.
 1379          */
 1380         ccs = fpr_mul(isigma, spc->sigma_min);
 1381 
 1382         /*
 1383          * We now need to sample on center r.
 1384          */
 1385         for (;;) {
 1386                 int z0, z, b;
 1387                 fpr x;
 1388 
 1389                 /*
 1390                  * Sample z for a Gaussian distribution. Then get a
 1391                  * random bit b to turn the sampling into a bimodal
 1392                  * distribution: if b = 1, we use z+1, otherwise we
 1393                  * use -z. We thus have two situations:
 1394                  *
 1395                  *  - b = 1: z >= 1 and sampled against a Gaussian
 1396                  *    centered on 1.
 1397                  *  - b = 0: z <= 0 and sampled against a Gaussian
 1398                  *    centered on 0.
 1399                  */
 1400                 z0 = Zf(gaussian0_sampler)(&spc->p);
 1401                 b = (int)prng_get_u8(&spc->p) & 1;
 1402                 z = b + ((b << 1) - 1) * z0;
 1403 
 1404                 /*
 1405                  * Rejection sampling. We want a Gaussian centered on r;
 1406                  * but we sampled against a Gaussian centered on b (0 or
 1407                  * 1). But we know that z is always in the range where
 1408                  * our sampling distribution is greater than the Gaussian
 1409                  * distribution, so rejection works.
 1410                  *
 1411                  * We got z with distribution:
 1412                  *    G(z) = exp(-((z-b)^2)/(2*sigma0^2))
 1413                  * We target distribution:
 1414                  *    S(z) = exp(-((z-r)^2)/(2*sigma^2))
 1415                  * Rejection sampling works by keeping the value z with
 1416                  * probability S(z)/G(z), and starting again otherwise.
 1417                  * This requires S(z) <= G(z), which is the case here.
 1418                  * Thus, we simply need to keep our z with probability:
 1419                  *    P = exp(-x)
 1420                  * where:
 1421                  *    x = ((z-r)^2)/(2*sigma^2) - ((z-b)^2)/(2*sigma0^2)
 1422                  *
 1423                  * Here, we scale up the Bernouilli distribution, which
 1424                  * makes rejection more probable, but makes rejection
 1425                  * rate sufficiently decorrelated from the Gaussian
 1426                  * center and standard deviation that the whole sampler
 1427                  * can be said to be constant-time.
 1428                  */
 1429                 x = fpr_mul(fpr_sqr(fpr_sub(fpr_of(z), r)), dss);
 1430                 x = fpr_sub(x, fpr_mul(fpr_of(z0 * z0), fpr_inv_2sqrsigma0));
 1431                 if (BerExp(&spc->p, x, ccs)) {
 1432                         /*
 1433                          * Rejection sampling was centered on r, but the
 1434                          * actual center is mu = s + r.
 1435                          */
 1436                         return s + z;
 1437                 }
 1438         }
 1439 }
 1440 
 1441 /* see inner.h */
 1442 void
 1443 Zf(sign_tree)(int16_t *sig, inner_shake256_context *rng,
 1444         const fpr *restrict expanded_key,
 1445         const uint16_t *hm, unsigned logn, uint8_t *tmp)
 1446 {
 1447         fpr *ftmp;
 1448 
 1449         ftmp = (fpr *)tmp;
 1450         for (;;) {
 1451                 /*
 1452                  * Signature produces short vectors s1 and s2. The
 1453                  * signature is acceptable only if the aggregate vector
 1454                  * s1,s2 is short; we must use the same bound as the
 1455                  * verifier.
 1456                  *
 1457                  * If the signature is acceptable, then we return only s2
 1458                  * (the verifier recomputes s1 from s2, the hashed message,
 1459                  * and the public key).
 1460                  */
 1461                 sampler_context spc;
 1462                 samplerZ samp;
 1463                 void *samp_ctx;
 1464 
 1465                 /*
 1466                  * Normal sampling. We use a fast PRNG seeded from our
 1467                  * SHAKE context ('rng').
 1468                  */
 1469                 spc.sigma_min = fpr_sigma_min[logn];
 1470                 Zf(prng_init)(&spc.p, rng);
 1471                 samp = Zf(sampler);
 1472                 samp_ctx = &spc;
 1473 
 1474                 /*
 1475                  * Do the actual signature.
 1476                  */
 1477                 if (do_sign_tree(samp, samp_ctx, sig,
 1478                         expanded_key, hm, logn, ftmp))
 1479                 {
 1480                         break;
 1481                 }
 1482         }
 1483 }
 1484 
 1485 /* see inner.h */
 1486 void
 1487 Zf(sign_dyn)(int16_t *sig, inner_shake256_context *rng,
 1488         const int8_t *restrict f, const int8_t *restrict g,
 1489         const int8_t *restrict F, const int8_t *restrict G,
 1490         const uint16_t *hm, unsigned logn, uint8_t *tmp)
 1491 {
 1492         fpr *ftmp;
 1493 
 1494         ftmp = (fpr *)tmp;
 1495         for (;;) {
 1496                 /*
 1497                  * Signature produces short vectors s1 and s2. The
 1498                  * signature is acceptable only if the aggregate vector
 1499                  * s1,s2 is short; we must use the same bound as the
 1500                  * verifier.
 1501                  *
 1502                  * If the signature is acceptable, then we return only s2
 1503                  * (the verifier recomputes s1 from s2, the hashed message,
 1504                  * and the public key).
 1505                  */
 1506                 sampler_context spc;
 1507                 samplerZ samp;
 1508                 void *samp_ctx;
 1509 
 1510                 /*
 1511                  * Normal sampling. We use a fast PRNG seeded from our
 1512                  * SHAKE context ('rng').
 1513                  */
 1514                 spc.sigma_min = fpr_sigma_min[logn];
 1515                 Zf(prng_init)(&spc.p, rng);
 1516                 samp = Zf(sampler);
 1517                 samp_ctx = &spc;
 1518 
 1519                 /*
 1520                  * Do the actual signature.
 1521                  */
 1522                 if (do_sign_dyn(samp, samp_ctx, sig,
 1523                         f, g, F, G, hm, logn, ftmp))
 1524                 {
 1525                         break;
 1526                 }
 1527         }
 1528 }