Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions config/zen/bli_cntx_init_zen.c
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ void bli_cntx_init_zen( cntx_t* cntx )
// setv
BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int,
BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int,
BLIS_SETV_KER, BLIS_SCOMPLEX, bli_csetv_zen_int,
BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int,

// swapv
BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8,
Expand Down
2 changes: 2 additions & 0 deletions config/zen2/bli_cntx_init_zen2.c
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ void bli_cntx_init_zen2( cntx_t* cntx )
//set
BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int,
BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int,
BLIS_SETV_KER, BLIS_SCOMPLEX, bli_csetv_zen_int,
BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int,

BLIS_VA_END
);
Expand Down
2 changes: 2 additions & 0 deletions config/zen3/bli_cntx_init_zen3.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ void bli_cntx_init_zen3( cntx_t* cntx )
// setv
BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int,
BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int,
BLIS_SETV_KER, BLIS_SCOMPLEX, bli_csetv_zen_int,
BLIS_SETV_KER, BLIS_DCOMPLEX, bli_zsetv_zen_int,

BLIS_VA_END
);
Expand Down
220 changes: 219 additions & 1 deletion kernels/zen/1/bli_setv_zen_int.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.

Copyright (C) 2020, Advanced Micro Devices, Inc.
Copyright (C) 2020-2026, Advanced Micro Devices, Inc.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
Expand Down Expand Up @@ -232,3 +232,221 @@ void bli_dsetv_zen_int
}
}

// Sets every element of the single-precision complex vector x to alpha,
// optionally conjugating alpha first.  Unit-stride vectors are handled with
// unrolled 256-bit AVX stores; non-unit strides (and the unit-stride fringe)
// fall through to a scalar loop.
void bli_csetv_zen_int
     (
             conj_t  conjalpha,   // whether to conjugate alpha before setting
             dim_t   n,           // number of scomplex elements in x
       const void*   alpha0,      // the scalar value (scomplex) to broadcast
             void*   x00,         // the output vector
             inc_t   incx,        // stride of x, in scomplex elements
       const cntx_t* cntx         // context (unused here)
     )
{
    scomplex* alpha = (scomplex *)alpha0;
    scomplex * x = x00;

    // Declaring and initializing local variables and pointers
    const dim_t num_elem_per_reg = 8;   // 8 floats (= 4 scomplex) per ymm register
    dim_t i = 0;
    float *x0 = (float *)x;             // view x as interleaved (real, imag) floats

    // If the vector dimension is zero return early.
    if ( bli_zero_dim1( n ) ) return;

    // Work on a local copy of alpha so conjugation never mutates the
    // caller's (const-qualified) scalar.
    scomplex alpha_conj = *alpha;

    // Handle conjugation of alpha
    if( bli_is_conj( conjalpha ) ) alpha_conj.imag = -alpha_conj.imag;

    if ( incx == 1 )
    {
        __m256 alphaRv, alphaIv, alphav;

        // Broadcast the scomplex alpha value.  unpacklo of (all-R, all-I)
        // yields the interleaved pattern R,I,R,I,... in every 128-bit lane.
        alphaRv = _mm256_broadcast_ss( &(alpha_conj.real) );
        alphaIv = _mm256_broadcast_ss( &(alpha_conj.imag) );
        alphav  = _mm256_unpacklo_ps( alphaRv, alphaIv );

        // The condition n & ~0x3F => n & 0xFFFFFFC0
        // This sets the lower 6 bits to 0 and results in multiples of 64
        // Thus, we iterate in blocks of 64 scomplex elements
        // (16 ymm stores x 8 floats = 128 floats = 64 scomplex per iteration).
        // Fringe loops have similar conditions to set their masks(32, 16, ...)
        for ( i = 0; i < (n & (~0x3F)); i += 64 )
        {
            _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 4, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 5, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 6, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 7, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 8, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 9, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 10, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 11, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 12, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 13, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 14, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 15, alphav);

            x0 += num_elem_per_reg * 16;
        }
        // 32 scomplex elements (8 ymm stores) per iteration.
        for ( ; i < (n & (~0x1F)); i += 32 )
        {
            _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 4, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 5, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 6, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 7, alphav);

            x0 += num_elem_per_reg * 8;
        }
        // 16 scomplex elements (4 ymm stores) per iteration.
        for ( ; i < (n & (~0x0F)); i += 16 )
        {
            _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 2, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 3, alphav);

            x0 += num_elem_per_reg * 4;
        }
        // 8 scomplex elements (2 ymm stores) per iteration.
        for ( ; i < (n & (~0x07)); i += 8 )
        {
            _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_ps(x0 + num_elem_per_reg * 1, alphav);

            x0 += num_elem_per_reg * 2;
        }
        // 4 scomplex elements (1 ymm store) per iteration.
        for ( ; i < (n & (~0x03)); i += 4 )
        {
            _mm256_storeu_ps(x0 + num_elem_per_reg * 0, alphav);
            x0 += num_elem_per_reg;
        }

        // Clear the upper ymm lanes to avoid AVX->SSE transition penalties.
        _mm256_zeroupper();

    }

    // Scalar path: handles the whole vector when incx != 1, and the
    // remaining (< 4) elements of the unit-stride case above.
    for( ; i < n; i += 1 )
    {
        *x0 = alpha_conj.real;
        *(x0 + 1) = alpha_conj.imag;

        x0 += 2 * incx;    // advance one scomplex element (2 floats per element)
    }

}

// Sets every element of the double-precision complex vector x to alpha,
// optionally conjugating alpha first.  Unit-stride vectors are handled with
// unrolled 256-bit AVX stores; non-unit strides (and the unit-stride fringe)
// use 128-bit stores of one dcomplex at a time.
//
// Fix: the previous version cast away const from alpha0 and negated the
// caller's alpha->imag in place when conjugating.  That mutates the caller's
// scalar (and is undefined behavior if the object is actually const).  We now
// conjugate a local copy, matching bli_csetv_zen_int above.
void bli_zsetv_zen_int
     (
             conj_t  conjalpha,   // whether to conjugate alpha before setting
             dim_t   n,           // number of dcomplex elements in x
       const void*   alpha0,      // the scalar value (dcomplex) to broadcast
             void*   x00,         // the output vector
             inc_t   incx,        // stride of x, in dcomplex elements
       const cntx_t* cntx         // context (unused here)
     )
{
    const dcomplex* alpha = (const dcomplex *)alpha0;
    dcomplex* x = x00;

    // Declaring and initializing local variables and pointers
    const dim_t num_elem_per_reg = 4;   // 4 doubles (= 2 dcomplex) per ymm register
    dim_t i = 0;
    double *x0 = (double *)x;           // view x as interleaved (real, imag) doubles

    // If the vector dimension is zero return early.
    if ( bli_zero_dim1( n ) ) return;

    // Work on a local copy of alpha so conjugation never mutates the
    // caller's (const-qualified) scalar.
    dcomplex alpha_conj = *alpha;

    // Handle conjugation of alpha
    if( bli_is_conj( conjalpha ) ) alpha_conj.imag = -alpha_conj.imag;

    if ( incx == 1 )
    {
        __m256d alphav;

        // Broadcast the dcomplex alpha value: both 128-bit lanes hold (real, imag).
        alphav = _mm256_broadcast_pd( (const __m128d *)&alpha_conj );

        // The condition n & ~0x1F => n & 0xFFFFFFE0
        // This sets the lower 5 bits to 0 and results in multiples of 32
        // Thus, we iterate in blocks of 32 elements
        // (16 ymm stores x 2 dcomplex = 32 dcomplex per iteration).
        // Fringe loops have similar conditions to set their masks(16, 8, ...)
        for ( i = 0; i < (n & (~0x1F)); i += 32 )
        {
            _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 4, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 5, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 6, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 7, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 8, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 9, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 10, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 11, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 12, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 13, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 14, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 15, alphav);

            x0 += num_elem_per_reg * 16;
        }
        // 16 dcomplex elements (8 ymm stores) per iteration.
        for ( ; i < (n & (~0x0F)); i += 16 )
        {
            _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 4, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 5, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 6, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 7, alphav);

            x0 += num_elem_per_reg * 8;
        }
        // 8 dcomplex elements (4 ymm stores) per iteration.
        for ( ; i < (n & (~0x07)); i += 8 )
        {
            _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 2, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 3, alphav);

            x0 += num_elem_per_reg * 4;
        }
        // 4 dcomplex elements (2 ymm stores) per iteration.
        for ( ; i < (n & (~0x03)); i += 4 )
        {
            _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav);
            _mm256_storeu_pd(x0 + num_elem_per_reg * 1, alphav);

            x0 += num_elem_per_reg * 2;
        }
        // 2 dcomplex elements (1 ymm store) per iteration.
        for ( ; i < (n & (~0x01)); i += 2 )
        {
            _mm256_storeu_pd(x0 + num_elem_per_reg * 0, alphav);
            x0 += num_elem_per_reg;
        }

        // Issue vzeroupper instruction to clear upper lanes of ymm registers.
        // This avoids a performance penalty caused by false dependencies when
        // transitioning from AVX to SSE instructions (which may occur later,
        // especially if BLIS is compiled with -mfpmath=sse).
        _mm256_zeroupper();
    }

    // Scalar (xmm) path: handles the whole vector when incx != 1, and the
    // remaining (at most 1) element of the unit-stride case above.
    if ( i < n )
    {
        __m128d alphav;
        alphav = _mm_loadu_pd((const double*)&alpha_conj);

        for( ; i < n; i += 1 )
        {
            _mm_storeu_pd(x0, alphav);
            x0 += 2 * incx;    // advance one dcomplex element (2 doubles per element)
        }
    }

}

2 changes: 2 additions & 0 deletions kernels/zen/bli_kernels_zen.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ COPYV_KER_PROT( double, d, copyv_zen_int )
//
SETV_KER_PROT(float, s, setv_zen_int)
SETV_KER_PROT(double, d, setv_zen_int)
SETV_KER_PROT( scomplex, c, setv_zen_int)
SETV_KER_PROT( dcomplex, z, setv_zen_int)

// swapv (intrinsics)
SWAPV_KER_PROT(float, s, swapv_zen_int8 )
Expand Down