Last active
December 15, 2020 14:28
-
-
Save luhenry/6b24ac146a110143ad31736caf7250e6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Licensed to the Apache Software Foundation (ASF) under one or more | |
* contributor license agreements. See the NOTICE file distributed with | |
* this work for additional information regarding copyright ownership. | |
* The ASF licenses this file to You under the Apache License, Version 2.0 | |
* (the "License"); you may not use this file except in compliance with | |
* the License. You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.apache.spark.ml.linalg | |
import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} | |
/** | |
* BLAS routines for MLlib's vectors and matrices. | |
*/ | |
private[spark] object BLAS extends Serializable { | |
// Cached BLAS back-ends, filled in lazily by the `javaBLAS` / `nativeBLAS` accessors below.
// They are marked @transient so they are not written out if this object is serialized.
@transient private var _javaBLAS: NetlibBLAS = _
@transient private var _nativeBLAS: NetlibBLAS = _
// Size cut-off used by `getBLAS`: inputs smaller than this use `javaBLAS`.
private val nativeL1Threshold: Int = 256

/**
 * JVM-based BLAS implementation, created on first access and cached afterwards.
 *
 * `VectorBLAS` is preferred; if instantiating it fails with a `NoClassDefFoundError`
 * (presumably because the `jdk.incubator.vector` module is not available -- TODO confirm)
 * the plain `F2jBLAS` is used instead.
 *
 * This implementation serves the small-input calls routed through `getBLAS` and is always
 * used for the level-2 routine `dspmv`.
 */
def javaBLAS: NetlibBLAS = {
  if (_javaBLAS == null) {
    _javaBLAS =
      try {
        new VectorBLAS
      } catch {
        case _: NoClassDefFoundError => new F2jBLAS
      }
  }
  _javaBLAS
}
/**
 * BLAS implementation used for the level-3 routines, resolved on first access and cached.
 *
 * The instance returned by `NetlibBLAS.getInstance` is used as-is unless it is an `F2jBLAS`,
 * in which case `javaBLAS` is substituted.
 */
def nativeBLAS: NetlibBLAS = {
  if (_nativeBLAS == null) {
    _nativeBLAS = NetlibBLAS.getInstance match {
      case _: F2jBLAS => javaBLAS
      case other => other
    }
  }
  _nativeBLAS
}
/**
 * Picks the BLAS implementation for an operation over `vectorSize` elements:
 * `javaBLAS` below `nativeL1Threshold`, `nativeBLAS` otherwise.
 */
def getBLAS(vectorSize: Int): NetlibBLAS =
  if (vectorSize >= nativeL1Threshold) nativeBLAS else javaBLAS
/**
 * y += a * x
 *
 * `y` must be a [[DenseVector]]; `x` may be dense or sparse. Both must have the same size.
 */
def axpy(a: Double, x: Vector, y: Vector): Unit = {
  require(x.size == y.size)
  (x, y) match {
    case (sx: SparseVector, dy: DenseVector) =>
      axpy(a, sx, dy)
    case (dx: DenseVector, dy: DenseVector) =>
      axpy(a, dx, dy)
    case (_, _: DenseVector) =>
      // y is fine, but x is neither dense nor sparse.
      throw new UnsupportedOperationException(
        s"axpy doesn't support x type ${x.getClass}.")
    case _ =>
      throw new IllegalArgumentException(
        s"axpy only supports adding to a dense vector but got type ${y.getClass}.")
  }
}
/**
 * y += a * x, for dense `x` and dense `y`. Delegates to `daxpy` with unit strides.
 */
private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = {
  val length = x.size
  val blas = getBLAS(length)
  blas.daxpy(length, a, x.values, 1, y.values, 1)
}
/**
 * y += a * x, for sparse `x` and dense `y`. Only the stored entries of `x` are visited.
 */
private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = {
  val idx = x.indices
  val stored = x.values
  val out = y.values
  val count = idx.length

  var p = 0
  if (a != 1.0) {
    while (p < count) {
      out(idx(p)) += a * stored(p)
      p += 1
    }
  } else {
    // a is exactly 1.0: skip the multiplication.
    while (p < count) {
      out(idx(p)) += stored(p)
      p += 1
    }
  }
}
/**
 * Y += a * X, for dense matrices of identical dimensions. The backing arrays are passed to
 * `daxpy` as flat vectors of `numRows * numCols` elements.
 */
private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
  val sameShape = X.numRows == Y.numRows && X.numCols == Y.numCols
  require(sameShape, "Dimension mismatch: " +
    s"size(X) = ${(X.numRows, X.numCols)} but size(Y) = ${(Y.numRows, Y.numCols)}.")
  val elementCount = X.numRows * X.numCols
  getBLAS(X.values.length).daxpy(elementCount, a, X.values, 1, Y.values, 1)
}
/**
 * dot(x, y)
 *
 * Dispatches on the concrete types of `x` and `y`; both must have the same size.
 */
def dot(x: Vector, y: Vector): Double = {
  require(x.size == y.size,
    "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
    " x.size = " + x.size + ", y.size = " + y.size)

  def unsupported: Nothing =
    throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).")

  x match {
    case dx: DenseVector =>
      y match {
        case dy: DenseVector => dot(dx, dy)
        // The sparse/dense kernel takes the sparse operand first.
        case sy: SparseVector => dot(sy, dx)
        case _ => unsupported
      }
    case sx: SparseVector =>
      y match {
        case dy: DenseVector => dot(sx, dy)
        case sy: SparseVector => dot(sx, sy)
        case _ => unsupported
      }
    case _ => unsupported
  }
}
/**
 * dot(x, y), for two dense vectors. Delegates to `ddot` with unit strides.
 */
private def dot(x: DenseVector, y: DenseVector): Double = {
  val length = x.size
  val blas = getBLAS(length)
  blas.ddot(length, x.values, 1, y.values, 1)
}
/**
 * dot(x, y), for sparse `x` and dense `y`. Only the stored entries of `x` contribute.
 */
private def dot(x: SparseVector, y: DenseVector): Double = {
  val idx = x.indices
  val stored = x.values
  val dense = y.values
  val count = idx.length

  var acc = 0.0
  var p = 0
  while (p < count) {
    acc += stored(p) * dense(idx(p))
    p += 1
  }
  acc
}
/**
 * dot(x, y), for two sparse vectors.
 *
 * Walks both index arrays in step and accumulates a product whenever the current indices
 * coincide.
 */
private def dot(x: SparseVector, y: SparseVector): Double = {
  val xIdx = x.indices
  val xVals = x.values
  val yIdx = y.indices
  val yVals = y.values
  val xCount = xIdx.length
  val yCount = yIdx.length

  var i = 0
  var j = 0
  var acc = 0.0
  while (i < xCount && j < yCount) {
    val xi = xIdx(i)
    val yj = yIdx(j)
    if (yj < xi) {
      // y is behind x: let it catch up.
      j += 1
    } else {
      if (yj == xi) {
        acc += xVals(i) * yVals(j)
        j += 1
      }
      i += 1
    }
  }
  acc
}
/**
 * y = x
 *
 * `y` must be a [[DenseVector]] of the same size as `x`. For a sparse `x`, every position
 * of `y` without a stored entry in `x` is set to 0.0.
 */
def copy(x: Vector, y: Vector): Unit = {
  val n = y.size
  require(x.size == n)
  y match {
    case dy: DenseVector =>
      x match {
        case sx: SparseVector =>
          val srcIdx = sx.indices
          val srcVals = sx.values
          val dst = dy.values
          val count = srcIdx.length

          var cursor = 0 // next position of dst to be written
          var p = 0
          while (p < count) {
            val target = srcIdx(p)
            // Zero-fill the gap in front of the next stored entry.
            while (cursor < target) {
              dst(cursor) = 0.0
              cursor += 1
            }
            dst(cursor) = srcVals(p)
            cursor += 1
            p += 1
          }
          // Zero-fill everything after the last stored entry.
          while (cursor < n) {
            dst(cursor) = 0.0
            cursor += 1
          }
        case dx: DenseVector =>
          System.arraycopy(dx.values, 0, dy.values, 0, n)
      }
    case _ =>
      throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}")
  }
}
/**
 * x = a * x
 *
 * Scales the backing value array in place; for a sparse vector only the stored values are
 * touched.
 */
def scal(a: Double, x: Vector): Unit = x match {
  case sx: SparseVector =>
    val stored = sx.values
    getBLAS(stored.length).dscal(stored.length, a, stored, 1)
  case dx: DenseVector =>
    val data = dx.values
    getBLAS(dx.size).dscal(data.length, a, data, 1)
  case _ =>
    throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.")
}
/**
 * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
 *
 * Forwards to the array-based overload using the backing array of `U`.
 *
 * @param U the upper triangular part of the matrix in a [[DenseVector]](column major)
 */
def spr(alpha: Double, v: Vector, U: DenseVector): Unit = {
  val packedUpper: Array[Double] = U.values
  spr(alpha, v, packedUpper)
}
/**
 * y := alpha*A*x + beta*y
 *
 * Always executed on `javaBLAS`, with the "U" (upper) storage flag and unit strides.
 *
 * @param n The order of the n by n matrix A.
 * @param A The upper triangular part of A in a [[DenseVector]] (column major).
 * @param x The [[DenseVector]] transformed by A.
 * @param y The [[DenseVector]] to be modified in place.
 */
def dspmv(
    n: Int,
    alpha: Double,
    A: DenseVector,
    x: DenseVector,
    beta: Double,
    y: DenseVector): Unit = {
  val blas = javaBLAS
  blas.dspmv("U", n, alpha, A.values, x.values, 1, beta, y.values, 1)
}
/**
 * Adds alpha * v * v.t to a matrix in-place. This is the same as BLAS's ?SPR.
 *
 * A dense `v` is handed to `nativeBLAS.dspr`. A sparse `v` is handled here by visiting only
 * the pairs of stored entries that fall in the upper triangle.
 *
 * @param U the upper triangular part of the matrix packed in an array (column major)
 */
def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
  val n = v.size
  v match {
    case DenseVector(values) =>
      nativeBLAS.dspr("U", n, alpha, values, 1, U)
    case SparseVector(_, indices, values) =>
      val nnz = indices.length
      var colStartIdx = 0 // offset in U of the first element of the current column
      var prevCol = 0
      var j = 0
      while (j < nnz) {
        val col = indices(j)
        // Skip empty columns. Column c of the packed upper triangle holds c + 1 entries, so
        // columns prevCol until col hold (col - prevCol) * (col + prevCol + 1) / 2 entries.
        // The product is formed in Long: it can exceed Int.MaxValue even though the halved
        // result is a valid array offset.
        val skipped = (col - prevCol).toLong * (col.toLong + prevCol + 1L) / 2L
        colStartIdx += skipped.toInt
        val av = alpha * values(j)
        // Entries 0..j of v have row index <= col, i.e. they lie in the upper triangle.
        var i = 0
        while (i <= j) {
          U(colStartIdx + indices(i)) += av * values(i)
          i += 1
        }
        j += 1
        prevCol = col
      }
    case _ =>
      throw new IllegalArgumentException(s"spr doesn't support vector type ${v.getClass}.")
  }
}
/**
 * A := alpha * x * x^T^ + A
 *
 * Checks that `A` is square and matches the size of `x`, then dispatches on the type of `x`.
 *
 * @param alpha a real scalar that will be multiplied to x * x^T^.
 * @param x the vector x that contains the n elements.
 * @param A the symmetric matrix A. Size of n x n.
 */
def syr(alpha: Double, x: Vector, A: DenseMatrix): Unit = {
  val rows = A.numRows
  val cols = A.numCols
  require(rows == cols,
    s"A is not a square matrix (and hence is not symmetric). A: $rows x $cols")
  require(rows == x.size,
    s"The size of x doesn't match the rank of A. A: $rows x $cols, x: ${x.size}")

  x match {
    case dense: DenseVector => syr(alpha, dense, A)
    case sparse: SparseVector => syr(alpha, sparse, A)
    case _ =>
      throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
  }
}
/**
 * A := alpha * x * x^T^ + A, for a dense `x`.
 *
 * `dsyr` is called with the "U" flag, after which every element above the diagonal is
 * copied to its mirrored position below the diagonal.
 */
private def syr(alpha: Double, x: DenseVector, A: DenseMatrix): Unit = {
  val rows = A.numRows
  val cols = A.numCols
  nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, rows)

  // Fill lower triangular part of A
  var c = 0
  while (c < cols) {
    var r = c + 1
    while (r < rows) {
      A(r, c) = A(c, r)
      r += 1
    }
    c += 1
  }
}
/**
 * A := alpha * x * x^T^ + A, for a sparse `x`.
 *
 * Visits every ordered pair of stored entries of `x` and updates the corresponding element
 * of the backing array of `A` directly.
 */
private def syr(alpha: Double, x: SparseVector, A: DenseMatrix): Unit = {
  val stride = A.numCols
  val idx = x.indices
  val stored = x.values
  val count = stored.length
  val target = A.values

  var outer = 0
  while (outer < count) {
    val scaled = alpha * stored(outer)
    val base = idx(outer) * stride
    var inner = 0
    while (inner < count) {
      target(idx(inner) + base) += scaled * stored(inner)
      inner += 1
    }
    outer += 1
  }
}
/**
 * C := alpha * A * B + beta * C
 *
 * When `alpha` is 0.0 no product is computed: C is only scaled by `beta`, and left untouched
 * if `beta` is 1.0.
 *
 * @param alpha a scalar to scale the multiplication A * B.
 * @param A the matrix A that will be left multiplied to B. Size of m x k.
 * @param B the matrix B that will be left multiplied by A. Size of k x n.
 * @param beta a scalar that can be used to scale matrix C.
 * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false.
 */
def gemm(
    alpha: Double,
    A: Matrix,
    B: DenseMatrix,
    beta: Double,
    C: DenseMatrix): Unit = {
  require(!C.isTransposed,
    "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.")
  if (alpha != 0.0) {
    A match {
      case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C)
      case dense: DenseMatrix => gemm(alpha, dense, B, beta, C)
      case _ =>
        throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.")
    }
  } else if (beta != 1.0) {
    // alpha is 0: the product vanishes and only the scaling of C remains.
    val cValues = C.values
    getBLAS(cValues.length).dscal(cValues.length, beta, cValues, 1)
  }
  // alpha == 0 and beta == 1: C is already the result.
}
/**
 * C := alpha * A * B + beta * C
 * For `DenseMatrix` A.
 *
 * Validates the dimensions and delegates to `nativeBLAS.dgemm`. The transpose flags and
 * leading dimensions are derived from `isTransposed` of A and B.
 */
private def gemm(
    alpha: Double,
    A: DenseMatrix,
    B: DenseMatrix,
    beta: Double,
    C: DenseMatrix): Unit = {
  require(A.numCols == B.numRows,
    s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}")
  require(A.numRows == C.numRows,
    s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}")
  require(B.numCols == C.numCols,
    s"The columns of C don't match the columns of B. C: ${C.numCols}, B: ${B.numCols}")

  val tAstr = if (A.isTransposed) "T" else "N"
  val tBstr = if (B.isTransposed) "T" else "N"
  // For a transposed matrix the leading dimension passed to dgemm is its column count.
  val lda = if (A.isTransposed) A.numCols else A.numRows
  val ldb = if (B.isTransposed) B.numCols else B.numRows

  nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda,
    B.values, ldb, beta, C.values, C.numRows)
}
/**
 * C := alpha * A * B + beta * C
 * For `SparseMatrix` A.
 *
 * Four loop nests cover the combinations of `A.isTransposed` and `B.isTransposed`; the
 * transposition checks are kept outside the loops. When A is transposed each element of C
 * is computed as one dot product; otherwise C is first scaled by `beta` and the
 * contributions of each stored element of A are accumulated into it.
 */
private def gemm(
    alpha: Double,
    A: SparseMatrix,
    B: DenseMatrix,
    beta: Double,
    C: DenseMatrix): Unit = {
  val mA: Int = A.numRows
  val nB: Int = B.numCols
  val kA: Int = A.numCols
  val kB: Int = B.numRows

  require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB")
  require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA")
  require(nB == C.numCols,
    s"The columns of C don't match the columns of B. C: ${C.numCols}, B: $nB")

  val Avals = A.values
  val Bvals = B.values
  val Cvals = C.values
  val ArowIndices = A.rowIndices
  val AcolPtrs = A.colPtrs

  // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices
  if (A.isTransposed) {
    // Here AcolPtrs is indexed by the row of A, and ArowIndices supplies the position
    // inside the matching column of B.
    var colCounterForB = 0
    if (!B.isTransposed) { // Expensive to put the check inside the loop
      while (colCounterForB < nB) {
        var rowCounterForA = 0
        val Cstart = colCounterForB * mA
        val Bstart = colCounterForB * kA
        while (rowCounterForA < mA) {
          var i = AcolPtrs(rowCounterForA)
          val indEnd = AcolPtrs(rowCounterForA + 1)
          var sum = 0.0
          while (i < indEnd) {
            sum += Avals(i) * Bvals(Bstart + ArowIndices(i))
            i += 1
          }
          val Cindex = Cstart + rowCounterForA
          Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
          rowCounterForA += 1
        }
        colCounterForB += 1
      }
    } else {
      // B is transposed: read its elements through B(i, j) instead of the raw array.
      while (colCounterForB < nB) {
        var rowCounterForA = 0
        val Cstart = colCounterForB * mA
        while (rowCounterForA < mA) {
          var i = AcolPtrs(rowCounterForA)
          val indEnd = AcolPtrs(rowCounterForA + 1)
          var sum = 0.0
          while (i < indEnd) {
            sum += Avals(i) * B(ArowIndices(i), colCounterForB)
            i += 1
          }
          val Cindex = Cstart + rowCounterForA
          Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha
          rowCounterForA += 1
        }
        colCounterForB += 1
      }
    }
  } else {
    // Scale matrix first if `beta` is not equal to 1.0
    if (beta != 1.0) {
      getBLAS(C.values.length).dscal(C.values.length, beta, C.values, 1)
    }
    // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of
    // B, and added to C.
    var colCounterForB = 0 // the column to be updated in C
    if (!B.isTransposed) { // Expensive to put the check inside the loop
      while (colCounterForB < nB) {
        var colCounterForA = 0 // The column of A to multiply with the row of B
        val Bstart = colCounterForB * kB
        val Cstart = colCounterForB * mA
        while (colCounterForA < kA) {
          var i = AcolPtrs(colCounterForA)
          val indEnd = AcolPtrs(colCounterForA + 1)
          val Bval = Bvals(Bstart + colCounterForA) * alpha
          while (i < indEnd) {
            Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
            i += 1
          }
          colCounterForA += 1
        }
        colCounterForB += 1
      }
    } else {
      // B is transposed: read its elements through B(i, j) instead of the raw array.
      while (colCounterForB < nB) {
        var colCounterForA = 0 // The column of A to multiply with the row of B
        val Cstart = colCounterForB * mA
        while (colCounterForA < kA) {
          var i = AcolPtrs(colCounterForA)
          val indEnd = AcolPtrs(colCounterForA + 1)
          val Bval = B(colCounterForA, colCounterForB) * alpha
          while (i < indEnd) {
            Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval
            i += 1
          }
          colCounterForA += 1
        }
        colCounterForB += 1
      }
    }
  }
}
/**
 * y := alpha * A * x + beta * y
 *
 * When `alpha` is 0.0 no product is computed: y is only scaled by `beta`, and left untouched
 * if `beta` is 1.0.
 *
 * @param alpha a scalar to scale the multiplication A * x.
 * @param A the matrix A that will be left multiplied to x. Size of m x n.
 * @param x the vector x that will be left multiplied by A. Size of n x 1.
 * @param beta a scalar that can be used to scale vector y.
 * @param y the resulting vector y. Size of m x 1.
 */
def gemv(
    alpha: Double,
    A: Matrix,
    x: Vector,
    beta: Double,
    y: DenseVector): Unit = {
  require(A.numCols == x.size,
    s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}")
  require(A.numRows == y.size,
    s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}")

  if (alpha != 0.0) {
    A match {
      case sparseA: SparseMatrix =>
        x match {
          case denseX: DenseVector => gemv(alpha, sparseA, denseX, beta, y)
          case sparseX: SparseVector => gemv(alpha, sparseA, sparseX, beta, y)
          case _ => unsupportedGemvOperands(A, x)
        }
      case denseA: DenseMatrix =>
        x match {
          case denseX: DenseVector => gemv(alpha, denseA, denseX, beta, y)
          case sparseX: SparseVector => gemv(alpha, denseA, sparseX, beta, y)
          case _ => unsupportedGemvOperands(A, x)
        }
      case _ => unsupportedGemvOperands(A, x)
    }
  } else if (beta != 1.0) {
    // alpha is 0: the product vanishes and only the scaling of y remains.
    scal(beta, y)
  }
  // alpha == 0 and beta == 1: y is already the result.
}

/** Raises the error reported by `gemv` for an unsupported matrix/vector type combination. */
private def unsupportedGemvOperands(A: Matrix, x: Vector): Nothing =
  throw new IllegalArgumentException(s"gemv doesn't support running on matrix type " +
    s"${A.getClass} and vector type ${x.getClass}.")
/**
 * y := alpha * A * x + beta * y
 * For `DenseMatrix` A and `DenseVector` x.
 *
 * Delegates to `nativeBLAS.dgemv`. For a transposed A the "T" flag is passed and the
 * row/column counts handed to dgemv are swapped.
 */
private def gemv(
    alpha: Double,
    A: DenseMatrix,
    x: DenseVector,
    beta: Double,
    y: DenseVector): Unit = {
  val transposed = A.isTransposed
  val trans = if (transposed) "T" else "N"
  val storedRows = if (transposed) A.numCols else A.numRows
  val storedCols = if (transposed) A.numRows else A.numCols
  nativeBLAS.dgemv(
    trans, storedRows, storedCols, alpha, A.values, storedRows, x.values, 1, beta, y.values, 1)
}
/**
 * y := alpha * A * x + beta * y
 * For `DenseMatrix` A and `SparseVector` x.
 *
 * Each element of y is one dot product between a row of A and the stored entries of x.
 * The two storage layouts of A differ only in how (row, column) maps to an offset in
 * `A.values`, which is captured by a pair of strides.
 */
private def gemv(
    alpha: Double,
    A: DenseMatrix,
    x: SparseVector,
    beta: Double,
    y: DenseVector): Unit = {
  val rows: Int = A.numRows
  val cols: Int = A.numCols
  val aValues = A.values

  val xIdx = x.indices
  val xCount = xIdx.length
  val xVals = x.values
  val out = y.values

  // Offset of A(r, c) in aValues is c * colStride + r * rowStride.
  val rowStride = if (A.isTransposed) cols else 1
  val colStride = if (A.isTransposed) 1 else rows

  var r = 0
  while (r < rows) {
    val rowBase = r * rowStride
    var acc = 0.0
    var p = 0
    while (p < xCount) {
      acc += xVals(p) * aValues(xIdx(p) * colStride + rowBase)
      p += 1
    }
    out(r) = acc * alpha + beta * out(r)
    r += 1
  }
}
/**
 * y := alpha * A * x + beta * y
 * For `SparseMatrix` A and `SparseVector` x.
 *
 * For a transposed A each element of y is computed as a sparse/sparse dot product. Otherwise
 * y is first scaled by `beta`, and every column of A that has a stored entry in x is added
 * to it.
 */
private def gemv(
    alpha: Double,
    A: SparseMatrix,
    x: SparseVector,
    beta: Double,
    y: DenseVector): Unit = {
  val xVals = x.values
  val xIdx = x.indices
  val xCount = xIdx.length
  val out = y.values

  val rows: Int = A.numRows
  val cols: Int = A.numCols
  val aVals = A.values

  if (A.isTransposed) {
    // For a transposed matrix colPtrs is indexed by row and rowIndices holds column indices.
    val rowPtrs = A.colPtrs
    val colIdx = A.rowIndices
    var r = 0
    while (r < rows) {
      var p = rowPtrs(r)
      val pEnd = rowPtrs(r + 1)
      var acc = 0.0
      var q = 0
      // Advance through the row of A and through x together, multiplying on index matches.
      while (p < pEnd && q < xCount) {
        val xi = xIdx(q)
        val ai = colIdx(p)
        if (xi == ai) {
          acc += aVals(p) * xVals(q)
          q += 1
          p += 1
        } else if (xi < ai) {
          q += 1
        } else {
          p += 1
        }
      }
      out(r) = acc * alpha + beta * out(r)
      r += 1
    }
  } else {
    if (beta != 1.0) scal(beta, y)

    val rowIdx = A.rowIndices
    val colPtrs = A.colPtrs
    var c = 0
    var q = 0
    while (c < cols && q < xCount) {
      if (xIdx(q) == c) {
        // Column c has a stored entry in x: add (alpha * x(c)) * A(:, c) to y.
        var p = colPtrs(c)
        val pEnd = colPtrs(c + 1)
        val scaled = xVals(q) * alpha
        while (p < pEnd) {
          out(rowIdx(p)) += aVals(p) * scaled
          p += 1
        }
        q += 1
      }
      c += 1
    }
  }
}
/**
 * y := alpha * A * x + beta * y
 * For `SparseMatrix` A and `DenseVector` x.
 *
 * For a transposed A each element of y is computed as one dot product over a row of A.
 * Otherwise y is first scaled by `beta` and the columns of A, scaled by `alpha * x(c)`, are
 * added to it.
 */
private def gemv(
    alpha: Double,
    A: SparseMatrix,
    x: DenseVector,
    beta: Double,
    y: DenseVector): Unit = {
  val xVals = x.values
  val out = y.values
  val rows: Int = A.numRows
  val cols: Int = A.numCols
  val aVals = A.values

  if (A.isTransposed) {
    // For a transposed matrix colPtrs is indexed by row and rowIndices holds column indices.
    val rowPtrs = A.colPtrs
    val colIdx = A.rowIndices
    var r = 0
    while (r < rows) {
      var p = rowPtrs(r)
      val pEnd = rowPtrs(r + 1)
      var acc = 0.0
      while (p < pEnd) {
        acc += aVals(p) * xVals(colIdx(p))
        p += 1
      }
      out(r) = beta * out(r) + acc * alpha
      r += 1
    }
  } else {
    if (beta != 1.0) scal(beta, y)

    // Perform matrix-vector multiplication and add to y
    val rowIdx = A.rowIndices
    val colPtrs = A.colPtrs
    var c = 0
    while (c < cols) {
      var p = colPtrs(c)
      val pEnd = colPtrs(c + 1)
      val scaled = xVals(c) * alpha
      while (p < pEnd) {
        out(rowIdx(p)) += aVals(p) * scaled
        p += 1
      }
      c += 1
    }
  }
}
} | |
final private class VectorBLAS extends F2jBLAS { | |
import jdk.incubator.vector.DoubleVector | |
import jdk.incubator.vector.FloatVector | |
import jdk.incubator.vector.VectorOperators | |
// Vector species used by the kernels below.
// FMAX / DMAX: the SPECIES_MAX constants for Float and Double lanes.
val FMAX: jdk.incubator.vector.VectorSpecies[java.lang.Float] = FloatVector.SPECIES_MAX
val DMAX: jdk.incubator.vector.VectorSpecies[java.lang.Double] = DoubleVector.SPECIES_MAX
// 128-bit Double species; not referenced by the kernels visible in this file.
val D128: jdk.incubator.vector.VectorSpecies[java.lang.Double] = DoubleVector.SPECIES_128
/**
 * y += alpha * x
 *
 * The vectorized path is taken only for unit strides and a non-zero `alpha`; every other
 * call is forwarded to the parent implementation.
 */
override def daxpy(
    n: Int,
    alpha: Double,
    x: Array[Double],
    incx: Int,
    y: Array[Double],
    incy: Int): Unit = {
  val unitStride = incx == 1 && incy == 1
  if (!unitStride || alpha == 0.0) {
    super.daxpy(n, alpha, x, incx, y, incy)
  } else {
    val lanes = DMAX.length()
    val vectorEnd = DMAX.loopBound(n)
    var i = 0
    // Whole vectors first ...
    while (i < vectorEnd) {
      val xChunk = DoubleVector.fromArray(DMAX, x, i)
      val yChunk = DoubleVector.fromArray(DMAX, y, i)
      xChunk.lanewise(VectorOperators.MUL, alpha).add(yChunk).intoArray(y, i)
      i += lanes
    }
    // ... then the remaining elements one at a time.
    while (i < n) {
      y(i) += alpha * x(i)
      i += 1
    }
  }
}
/**
 * sum(x * y), single precision.
 *
 * The vectorized path is taken only for unit strides; every other call is forwarded to the
 * parent implementation.
 */
override def sdot(
    n: Int,
    x: Array[Float],
    incx: Int,
    y: Array[Float],
    incy: Int): Float = {
  if (incx != 1 || incy != 1) {
    super.sdot(n, x, incx, y, incy)
  } else {
    val lanes = FMAX.length()
    val vectorEnd = FMAX.loopBound(n)
    var acc: Float = 0.0f
    var i = 0
    // Whole vectors: multiply lane-wise and fold the lanes into the scalar accumulator.
    while (i < vectorEnd) {
      val xChunk = FloatVector.fromArray(FMAX, x, i)
      val yChunk = FloatVector.fromArray(FMAX, y, i)
      acc += xChunk.mul(yChunk).reduceLanes(VectorOperators.ADD)
      i += lanes
    }
    // Remaining elements.
    while (i < n) {
      acc += x(i) * y(i)
      i += 1
    }
    acc
  }
}
/**
 * sum(x * y), double precision.
 *
 * The vectorized path is taken only for unit strides; every other call is forwarded to the
 * parent implementation.
 */
override def ddot(
    n: Int,
    x: Array[Double],
    incx: Int,
    y: Array[Double],
    incy: Int): Double = {
  if (incx != 1 || incy != 1) {
    super.ddot(n, x, incx, y, incy)
  } else {
    val lanes = DMAX.length()
    val vectorEnd = DMAX.loopBound(n)
    var acc: Double = 0.0
    var i = 0
    // Whole vectors: multiply lane-wise and fold the lanes into the scalar accumulator.
    while (i < vectorEnd) {
      val xChunk = DoubleVector.fromArray(DMAX, x, i)
      val yChunk = DoubleVector.fromArray(DMAX, y, i)
      acc += xChunk.mul(yChunk).reduceLanes(VectorOperators.ADD)
      i += lanes
    }
    // Remaining elements.
    while (i < n) {
      acc += x(i) * y(i)
      i += 1
    }
    acc
  }
}
/**
 * x = alpha * x
 *
 * The vectorized path is taken only for a unit stride; every other call is forwarded to the
 * parent implementation.
 */
override def dscal(
    n: Int,
    alpha: Double,
    x: Array[Double],
    incx: Int): Unit = {
  if (incx != 1) {
    super.dscal(n, alpha, x, incx)
  } else {
    val lanes = DMAX.length()
    val vectorEnd = DMAX.loopBound(n)
    var i = 0
    // Whole vectors first ...
    while (i < vectorEnd) {
      val chunk = DoubleVector.fromArray(DMAX, x, i)
      chunk.lanewise(VectorOperators.MUL, alpha).intoArray(x, i)
      i += lanes
    }
    // ... then the remaining elements one at a time.
    while (i < n) {
      x(i) *= alpha
      i += 1
    }
  }
}
/**
 * a += alpha * x * x.t, with `a` holding the packed upper triangle (column major).
 *
 * The vectorized path is taken only for `uplo == "U"`, a non-zero `alpha` and a unit stride;
 * every other call is forwarded to the parent implementation.
 *
 * @param uplo which triangle `a` stores; only "U" is handled here.
 * @param n order of the matrix.
 * @param alpha scalar applied to x * x.t.
 * @param x input vector.
 * @param incx stride of `x`.
 * @param a packed matrix, updated in place.
 */
override def dspr(
    uplo: String,
    n: Int,
    alpha: Double,
    x: Array[Double],
    incx: Int,
    a: Array[Double]): Unit = {
  if (uplo == "U" && alpha != 0.0 && incx == 1) {
    val lanes = DMAX.length()
    var col = 0
    while (col < n) {
      // Column `col` starts at col * (col + 1) / 2 in the packed array. The product is formed
      // in Long because it can exceed Int.MaxValue even when the halved value is a valid index.
      val colStart = (col.toLong * (col + 1L) / 2L).toInt
      val scale = alpha * x(col)
      // Column `col` of the upper triangle has col + 1 entries (rows 0..col).
      val vectorEnd = DMAX.loopBound(col + 1)
      var row = 0
      while (row < vectorEnd) {
        val vx = DoubleVector.fromArray(DMAX, x, row)
        val va = DoubleVector.fromArray(DMAX, a, colStart + row)
        vx.lanewise(VectorOperators.MUL, scale).add(va)
          .intoArray(a, colStart + row)
        row += lanes
      }
      while (row <= col) {
        a(colStart + row) += scale * x(row)
        row += 1
      }
      col += 1
    }
  } else {
    super.dspr(uplo, n, alpha, x, incx, a)
  }
}
/**
 * a += alpha * x * x.t, updating only the upper triangle of the column-major matrix `a`.
 *
 * The vectorized path is taken only for `uplo == "U"`, a non-zero `alpha`, a unit stride and
 * `lda >= n`; every other call (including an `lda` too small for the matrix) is forwarded to
 * the parent implementation.
 *
 * @param uplo which triangle of `a` to update; only "U" is handled here.
 * @param n order of the matrix.
 * @param alpha scalar applied to x * x.t.
 * @param x input vector.
 * @param incx stride of `x`.
 * @param a matrix storage, updated in place.
 * @param lda distance in `a` between the starts of two consecutive columns.
 */
override def dsyr(
    uplo: String,
    n: Int,
    alpha: Double,
    x: Array[Double],
    incx: Int,
    a: Array[Double],
    lda: Int): Unit = {
  if (uplo == "U" && alpha != 0.0 && incx == 1 && lda >= n) {
    val lanes = DMAX.length()
    var col = 0
    while (col < n) {
      // Columns are `lda` elements apart, which may be more than the `n` rows in use.
      val colStart = col * lda
      val scale = alpha * x(col)
      // Only rows 0..col (the upper triangle) of this column are updated.
      val vectorEnd = DMAX.loopBound(col + 1)
      var row = 0
      while (row < vectorEnd) {
        val vx = DoubleVector.fromArray(DMAX, x, row)
        val va = DoubleVector.fromArray(DMAX, a, colStart + row)
        vx.lanewise(VectorOperators.MUL, scale).add(va)
          .intoArray(a, colStart + row)
        row += lanes
      }
      while (row <= col) {
        a(colStart + row) += scale * x(row)
        row += 1
      }
      col += 1
    }
  } else {
    super.dsyr(uplo, n, alpha, x, incx, a, lda)
  }
}
/**
 * [trans=N] y = alpha * a * x + beta * y
 * [trans=T] y = alpha * a.t * x + beta * y
 *
 * The vectorized path covers only `trans == "T"` with a non-zero `alpha`, unit strides and
 * `lda == m`; every other call is forwarded to the parent implementation. In that path each
 * of the `n` output elements is the dot product of one stored column of `a` (length `m`)
 * with `x`.
 *
 * When `beta` is 0.0 the previous content of `y` is ignored rather than multiplied by zero,
 * so NaN or infinite values already in `y` do not leak into the result.
 */
override def dgemv( // scalastyle:ignore
    trans: String,
    m: Int,
    n: Int,
    alpha: Double,
    a: Array[Double],
    lda: Int,
    x: Array[Double],
    incx: Int,
    beta: Double,
    y: Array[Double],
    incy: Int): Unit = {
  if (trans == "T" && alpha != 0.0 && incx == 1 && incy == 1 && lda == m) {
    val lanes = DMAX.length()
    val vectorEnd = DMAX.loopBound(m)
    var col = 0
    while (col < n) {
      val colStart = col * m
      var ax = 0.0 // dot product of stored column `col` of a with x
      var row = 0
      while (row < vectorEnd) {
        val vx = DoubleVector.fromArray(DMAX, x, row)
        val va = DoubleVector.fromArray(DMAX, a, colStart + row)
        ax += vx.mul(va).reduceLanes(VectorOperators.ADD)
        row += lanes
      }
      while (row < m) {
        ax += x(row) * a(colStart + row)
        row += 1
      }
      // With beta == 0 the old y(col) must not take part in the arithmetic.
      y(col) = if (beta == 0.0) alpha * ax else alpha * ax + beta * y(col)
      col += 1
    }
  } else {
    super.dgemv(trans, m, n, alpha, a, lda, x, incx, beta, y, incy)
  }
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Licensed to the Apache Software Foundation (ASF) under one or more | |
* contributor license agreements. See the NOTICE file distributed with | |
* this work for additional information regarding copyright ownership. | |
* The ASF licenses this file to You under the Apache License, Version 2.0 | |
* (the "License"); you may not use this file except in compliance with | |
* the License. You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.apache.spark.ml.linalg | |
import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS} | |
import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} | |
/**
 * Benchmark comparing the F2jBLAS and VectorBLAS implementations of selected BLAS routines.
 * To run this benchmark:
 * {{{
 * 1. without sbt: bin/spark-submit --class <this class> <spark mllib test jar>
 * 2. build/sbt "mllib/test:runMain <this class>"
 * 3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "mllib/test:runMain <this class>"
 * Results will be written to "benchmarks/BLASBenchmark-results.txt".
 * }}}
 */
object BLASBenchmark extends BenchmarkBase {

  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
    val iters = 1e2.toInt
    val rnd = new scala.util.Random(0)

    val f2jBLAS = new F2jBLAS
    val vectorBLAS = new VectorBLAS

    println("f2jBLAS = " + f2jBLAS.getClass.getName) // scalastyle:off println
    println("vectorBLAS = " + vectorBLAS.getClass.getName) // scalastyle:off println

    // Registers the "f2j" and "vector" cases of one routine under `name` and runs them.
    def compare(name: String, size: Int)(f2j: => Unit)(vector: => Unit): Unit = {
      val benchmark = new Benchmark(name, size, iters, output = output)
      benchmark.addCase("f2j") { _ => f2j }
      benchmark.addCase("vector") { _ => vector }
      benchmark.run()
    }

    // Random input of the given length, drawn from the shared generator.
    def doubles(length: Int): Array[Double] = Array.fill(length) { rnd.nextDouble }
    def floats(length: Int): Array[Float] = Array.fill(length) { rnd.nextFloat }

    runBenchmark("daxpy") {
      val n = 1e7.toInt
      val a = rnd.nextDouble
      val x = doubles(n)
      val y = doubles(n)
      compare("daxpy", n) {
        f2jBLAS.daxpy(n, a, x, 1, y, 1)
      } {
        vectorBLAS.daxpy(n, a, x, 1, y, 1)
      }
    }

    runBenchmark("sdot") {
      val n = 1e7.toInt
      val x = floats(n)
      val y = floats(n)
      compare("sdot", n) {
        f2jBLAS.sdot(n, x, 1, y, 1)
      } {
        vectorBLAS.sdot(n, x, 1, y, 1)
      }
    }

    runBenchmark("ddot") {
      val n = 1e7.toInt
      val x = doubles(n)
      val y = doubles(n)
      compare("ddot", n) {
        f2jBLAS.ddot(n, x, 1, y, 1)
      } {
        vectorBLAS.ddot(n, x, 1, y, 1)
      }
    }

    runBenchmark("dscal") {
      val n = 1e7.toInt
      val a = rnd.nextDouble
      val x = doubles(n)
      compare("dscal", n) {
        f2jBLAS.dscal(n, a, x, 1)
      } {
        vectorBLAS.dscal(n, a, x, 1)
      }
    }

    runBenchmark("dspr") {
      val n = 1e4.toInt
      val alpha = rnd.nextDouble
      val x = doubles(n)
      // Packed upper triangle of an n x n matrix.
      val a = doubles(n * (n + 1) / 2)
      compare("dspr", n) {
        f2jBLAS.dspr("U", n, alpha, x, 1, a)
      } {
        vectorBLAS.dspr("U", n, alpha, x, 1, a)
      }
    }

    runBenchmark("dsyr") {
      val n = 1e4.toInt
      val alpha = rnd.nextDouble
      val x = doubles(n)
      val a = doubles(n * n)
      val lda = n
      compare("dsyr", n) {
        f2jBLAS.dsyr("U", n, alpha, x, 1, a, lda)
      } {
        vectorBLAS.dsyr("U", n, alpha, x, 1, a, lda)
      }
    }

    runBenchmark("dgemv[T]") {
      val m = 1e4.toInt
      val n = 1e3.toInt
      val alpha = rnd.nextDouble
      val a = doubles(n * m)
      val lda = m
      val x = doubles(m)
      val beta = rnd.nextDouble
      val y = doubles(n)
      compare("dgemv[T]", n) {
        f2jBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1)
      } {
        vectorBLAS.dgemv("T", m, n, alpha, a, lda, x, 1, beta, y, 1)
      }
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[info] daxpy: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative | |
[info] ------------------------------------------------------------------------------------------------------------------------ | |
[info] f2j 45 46 1 219.8 4.5 1.0X | |
[info] vector 24 25 3 411.5 2.4 1.9X | |
[info] | |
[info] sdot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative | |
[info] ------------------------------------------------------------------------------------------------------------------------ | |
[info] f2j 54 56 3 185.8 5.4 1.0X | |
[info] vector 18 18 1 563.6 1.8 3.0X | |
[info] | |
[info] ddot: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative | |
[info] ------------------------------------------------------------------------------------------------------------------------ | |
[info] f2j 73 75 2 137.0 7.3 1.0X | |
[info] vector 26 26 1 386.0 2.6 2.8X | |
[info] | |
[info] dscal: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative | |
[info] ------------------------------------------------------------------------------------------------------------------------ | |
[info] f2j 36 36 0 276.4 3.6 1.0X | |
[info] vector 16 16 1 629.7 1.6 2.3X | |
[info] | |
[info] dspr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative | |
[info] ------------------------------------------------------------------------------------------------------------------------ | |
[info] f2j 153 155 5 0.1 15346.4 1.0X | |
[info] vector 116 117 6 0.1 11564.4 1.3X | |
[info] | |
[info] dsyr: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative | |
[info] ------------------------------------------------------------------------------------------------------------------------ | |
[info] f2j 176 180 7 0.1 17614.8 1.0X | |
[info] vector 121 121 1 0.1 12080.8 1.5X | |
[info] | |
[info] dgemv[T]: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative | |
[info] ------------------------------------------------------------------------------------------------------------------------ | |
[info] f2j 36 36 1 0.0 35728.2 1.0X | |
[info] vector 26 26 0 0.0 25740.3 1.4X | |
[info] |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment